diff --git a/unhacs/main.py b/unhacs/main.py index 67acc95..d06a2b0 100644 --- a/unhacs/main.py +++ b/unhacs/main.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from collections.abc import Iterable from pathlib import Path from unhacs.packages import DEFAULT_HASS_CONFIG_PATH @@ -59,6 +60,20 @@ def create_parser(): class Unhacs: + def __init__( + self, + hass_config: Path = DEFAULT_HASS_CONFIG_PATH, + package_file: Path = DEFAULT_PACKAGE_FILE, + ): + self.hass_config = hass_config + self.package_file = package_file + + def read_lock_packages(self) -> list[Package]: + return read_lock_packages(self.package_file) + + def write_lock_packages(self, packages: Iterable[Package]): + return write_lock_packages(packages, self.package_file) + def add_package( self, package_url: str, @@ -66,8 +81,9 @@ class Unhacs: version: str | None = None, update: bool = False, ): + """Install and add a package to the lock or install a specific version.""" package = Package(name=package_name, url=package_url, version=version) - packages = read_lock_packages() + packages = self.read_lock_packages() # Raise an error if the package is already in the list if package in packages: @@ -77,18 +93,20 @@ class Unhacs: else: raise ValueError("Package already exists in the list") - package.install() + package.install(self.hass_config) packages.append(package) - write_lock_packages(packages) + self.write_lock_packages(packages) def upgrade_packages(self, package_names: list[str]): """Uograde to latest version of packages and update lock.""" if not package_names: - installed_packages = get_installed_packages() + installed_packages = get_installed_packages(self.hass_config) else: installed_packages = [ - p for p in get_installed_packages() if p.name in package_names + p + for p in get_installed_packages(self.hass_config) + if p.name in package_names ] upgrade_packages: list[Package] = [] @@ -106,19 +124,21 @@ class Unhacs: return for installed_package in upgrade_packages: - installed_package.install() + installed_package.install(self.hass_config) - # Update lock file to latest now that we know they are upgraded + # Update lock file to latest now that we know they are uograded latest_lookup = {p.url: p for p in latest_packages} - packages = [latest_lookup.get(p.url, p) for p in read_lock_packages()] + packages = [latest_lookup.get(p.url, p) for p in self.read_lock_packages()] - write_lock_packages(packages) + self.write_lock_packages(packages) def list_packages(self, verbose: bool = False): + """List installed packages and their versions.""" for package in get_installed_packages(): print(package.verbose_str() if verbose else str(package)) def remove_packages(self, package_names: list[str]): + """Remove installed packages and uodate lock.""" packages_to_remove = [ package for package in get_installed_packages() @@ -126,14 +146,14 @@ class Unhacs: ] remaining_packages = [ package - for package in read_lock_packages() + for package in self.read_lock_packages() if package not in packages_to_remove ] for package in packages_to_remove: - package.uninstall() + package.uninstall(self.hass_config) - write_lock_packages(remaining_packages) + self.write_lock_packages(remaining_packages) def main(): @@ -141,7 +161,7 @@ def main(): parser = create_parser() args = parser.parse_args() - unhacs = Unhacs() + unhacs = Unhacs(args.config, args.package_file) if args.subcommand == "add": # If a file was provided, update all packages based on the lock file @@ -163,6 +183,7 @@ def main(): unhacs.upgrade_packages(args.packages) else: print(f"Command {args.subcommand} is not implemented") + exit(1) if __name__ == "__main__": diff --git a/unhacs/packages.py b/unhacs/packages.py index d546786..3c5ba99 100644 --- a/unhacs/packages.py +++ b/unhacs/packages.py @@ -9,7 +9,7 @@ from zipfile import ZipFile import requests DEFAULT_HASS_CONFIG_PATH: Path = Path(".") -DEFAULT_PACKAGE_FILE = "unhacs.txt" +DEFAULT_PACKAGE_FILE = Path("unhacs.txt") def extract_zip(zip_file: ZipFile, dest_dir: Path): @@ -88,9 +88,7 @@ class Package: # If no version is provided, use the latest release return releases[0]["tag_name"], releases[0]["zipball_url"] - def install( - self, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH, replace: bool = True - ): + def install(self, hass_config_path: Path, replace: bool = True): # Fetch the release zip with the specified version if not self.zip_url: _, self.zip_url = self.fetch_version_release(self.version) @@ -114,7 +112,7 @@ class Package: shutil.move(custom_component, dest) dest.joinpath("unhacs.txt").write_text(self.serialize()) - def uninstall(self, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH) -> bool: + def uninstall(self, hass_config_path: Path) -> bool: if self.path: shutil.rmtree(self.path) return True @@ -126,9 +124,7 @@ class Package: return False - def installed_package( - self, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH - ) -> "Package|None": + def installed_package(self, hass_config_path: Path) -> "Package|None": for custom_component in (hass_config_path / "custom_components").glob("*"): unhacs = custom_component / "unhacs.txt" if unhacs.exists(): @@ -141,8 +137,8 @@ class Package: return installed_package return None - def outdated(self) -> bool: - installed_package = self.installed_package() + def is_update(self, hass_config_path: Path) -> bool: + installed_package = self.installed_package(hass_config_path) return installed_package is None or installed_package.version != self.version @@ -161,17 +157,16 @@ def get_installed_packages( # Read a list of Packages from a text file in the plain text format "URL version name" -def read_lock_packages(package_file: str = DEFAULT_PACKAGE_FILE) -> list[Package]: - path = Path(package_file) - if path.exists(): - with path.open() as f: +def read_lock_packages(package_file: Path = DEFAULT_PACKAGE_FILE) -> list[Package]: + if package_file.exists(): + with package_file.open() as f: return [Package.deserialize(line.strip()) for line in f] return [] # Write a list of Packages to a text file in the format URL version name def write_lock_packages( - packages: Iterable[Package], package_file: str = DEFAULT_PACKAGE_FILE + packages: Iterable[Package], package_file: Path = DEFAULT_PACKAGE_FILE ): - with open(package_file, "w") as f: + with package_file.open("w") as f: f.writelines(sorted(f"{package.serialize()}\n" for package in packages))