From d541fba28338c66b6c9bea2f20a441d02754dd82 Mon Sep 17 00:00:00 2001 From: Ian Fijolek Date: Mon, 10 Jun 2024 13:50:50 -0700 Subject: [PATCH] Refactor interface and handle installed packages --- unhacs/main.py | 180 +++++++++++++++++---------------------------- unhacs/packages.py | 155 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 215 insertions(+), 120 deletions(-) diff --git a/unhacs/main.py b/unhacs/main.py index ab8c799..f473daa 100644 --- a/unhacs/main.py +++ b/unhacs/main.py @@ -1,32 +1,12 @@ -import json -import shutil -import tempfile from argparse import ArgumentParser -from io import BytesIO from pathlib import Path -from zipfile import ZipFile - -import requests +from unhacs.packages import DEFAULT_HASS_CONFIG_PATH from unhacs.packages import DEFAULT_PACKAGE_FILE from unhacs.packages import Package -from unhacs.packages import read_packages -from unhacs.packages import write_packages - -DEFAULT_HASS_CONFIG_PATH = Path(".") - - -def extract_zip(zip_file: ZipFile, dest_dir: Path): - for info in zip_file.infolist(): - if info.is_dir(): - continue - file = Path(info.filename) - # Strip top directory from path - file = Path(*file.parts[1:]) - path = dest_dir / file - path.parent.mkdir(parents=True, exist_ok=True) - with zip_file.open(info) as source, open(path, "wb") as dest: - dest.write(source.read()) +from unhacs.packages import get_installed_packages +from unhacs.packages import read_lock_packages +from unhacs.packages import write_lock_packages def create_parser(): @@ -52,7 +32,10 @@ def create_parser(): list_parser.add_argument("--verbose", "-v", action="store_true") add_parser = subparsers.add_parser("add") - add_parser.add_argument("url", type=str, help="The URL of the package.") + add_parser.add_argument( + "--file", "-f", type=Path, help="The path to a package file." + ) + add_parser.add_argument("url", nargs="?", type=str, help="The URL of the package.") add_parser.add_argument( "name", type=str, nargs="?", help="The name of the package." ) @@ -67,9 +50,9 @@ def create_parser(): ) remove_parser = subparsers.add_parser("remove") - remove_parser.add_argument("packages", nargs="*") + remove_parser.add_argument("packages", nargs="+") - update_parser = subparsers.add_parser("update") + update_parser = subparsers.add_parser("upgrade") update_parser.add_argument("packages", nargs="*") return parser @@ -83,104 +66,65 @@ class Unhacs: version: str | None = None, update: bool = False, ): - # Parse the package URL to get the owner and repo name - parts = package_url.split("/") - owner = parts[-2] - repo = parts[-1] - - # Fetch the releases from the GitHub API - response = requests.get(f"https://api.github.com/repos/{owner}/{repo}/releases") - response.raise_for_status() - releases = response.json() - - # If a version is provided, check if it exists in the releases - if version: - for release in releases: - if release["tag_name"] == version: - break - else: - raise ValueError(f"Version {version} does not exist for this package") - else: - # If no version is provided, use the latest release - version = releases[0]["tag_name"] - - if not version: - raise ValueError("No releases found for this package") - - package = Package(name=package_name or repo, url=package_url, version=version) - packages = read_packages() + package = Package(name=package_name, url=package_url, version=version) + packages = read_lock_packages() # Raise an error if the package is already in the list if package in packages: if update: # Remove old version of the package - packages = [p for p in packages if p.url != package_url] + packages = [p for p in packages if p != package] else: raise ValueError("Package already exists in the list") + package.install() + packages.append(package) - write_packages(packages) + write_lock_packages(packages) - self.download_package(package) - - def download_package(self, package: Package, replace: bool = True): - # Parse the package URL to get the owner and repo name - parts = package.url.split("/") - owner = parts[-2] - repo = parts[-1] - - # Fetch the releases from the GitHub API - response = requests.get(f"https://api.github.com/repos/{owner}/{repo}/releases") - response.raise_for_status() - releases = response.json() - - # Find the release with the specified version - for release in releases: - if release["tag_name"] == package.version: - break - else: - raise ValueError(f"Version {package.version} not found for this package") - - # Download the release zip with the specified name - response = requests.get(release["zipball_url"]) - response.raise_for_status() - - release_zip = ZipFile(BytesIO(response.content)) - - with tempfile.TemporaryDirectory(prefix="unhacs-") as tempdir: - tmpdir = Path(tempdir) - extract_zip(release_zip, tmpdir) - - for file in tmpdir.glob("*"): - print(file) - hacs = json.loads((tmpdir / "hacs.json").read_text()) - print(hacs) - - for custom_component in tmpdir.glob("custom_components/*"): - dest = ( - DEFAULT_HASS_CONFIG_PATH - / "custom_components" - / custom_component.name - ) - if replace: - shutil.rmtree(dest, ignore_errors=True) - - shutil.move(custom_component, dest) - - def update_packages(self, package_names: list[str]): + def upgrade_packages(self, package_names: list[str]): if not package_names: - package_urls = [p.url for p in read_packages()] + packages = read_lock_packages() else: - package_urls = [p.url for p in read_packages() if p.name in package_names] + packages = [p for p in read_lock_packages() if p.name in package_names] - for package in package_urls: - print("Updating", package) - self.add_package(package, update=True) + latest_packages = [Package(name=p.name, url=p.url) for p in packages] + for package, latest_package in zip(packages, latest_packages): + if latest_package.outdated(): + print( + f"upgrade {package.name} from {package.version} to {latest_package.version}" + ) + + # Prompt the user to press Y to continue and upgrade all packages, otherwise exit + if input("Upgrade all packages? (y/N) ").lower() != "y": + return + + for package in latest_packages: + package.install() + + write_lock_packages(set(latest_packages) | set(packages)) def list_packages(self, verbose: bool = False): - for package in read_packages(): + for package in get_installed_packages(): print(package.verbose_str() if verbose else str(package)) + def remove_packages(self, package_names: list[str]): + packages_to_remove = [ + package + for package in get_installed_packages() + if package.name in package_names + ] + remaining_packages = [ + package + for package in read_lock_packages() + if package not in packages_to_remove + ] + + for package in packages_to_remove: + package.uninstall() + + write_lock_packages(remaining_packages) + def main(): # If the sub command is add package, it should pass the parsed arguments to the add_package function and return @@ -190,15 +134,25 @@ def main(): unhacs = Unhacs() if args.subcommand == "add": - unhacs.add_package(args.url, args.name, args.version, args.update) + # If a file was provided, update all packages based on the lock file + if args.file: + packages = read_lock_packages(args.file) + for package in packages: + unhacs.add_package( + package.url, package.name, package.version, update=True + ) + elif args.url: + unhacs.add_package(args.url, args.name, args.version, args.update) + else: + raise ValueError("Either a file or a URL must be provided") elif args.subcommand == "list": unhacs.list_packages(args.verbose) elif args.subcommand == "remove": - print("Not implemented") - elif args.subcommand == "update": - unhacs.update_packages(args.packages) + unhacs.remove_packages(args.packages) + elif args.subcommand == "upgrade": + unhacs.upgrade_packages(args.packages) else: - print("Not implemented") + print(f"Command {args.subcommand} is not implemented") if __name__ == "__main__": diff --git a/unhacs/packages.py b/unhacs/packages.py index 25ac26b..27bc1e3 100644 --- a/unhacs/packages.py +++ b/unhacs/packages.py @@ -1,33 +1,174 @@ -from dataclasses import dataclass +import json +import shutil +import tempfile +from collections.abc import Iterable +from io import BytesIO from pathlib import Path +from zipfile import ZipFile +import requests + +DEFAULT_HASS_CONFIG_PATH: Path = Path(".") DEFAULT_PACKAGE_FILE = "unhacs.txt" -@dataclass +def extract_zip(zip_file: ZipFile, dest_dir: Path): + for info in zip_file.infolist(): + if info.is_dir(): + continue + file = Path(info.filename) + # Strip top directory from path + file = Path(*file.parts[1:]) + path = dest_dir / file + path.parent.mkdir(parents=True, exist_ok=True) + with zip_file.open(info) as source, open(path, "wb") as dest: + dest.write(source.read()) + + class Package: url: str version: str + zip_url: str name: str + path: Path | None = None + + def __init__(self, url: str, version: str | None = None, name: str | None = None): + self.url = url + + self.version, self.zip_url = self.fetch_version_release(version) + + parts = url.split("/") + repo = parts[-1] + self.name = name or repo def __str__(self): return f"{self.name} {self.version}" + def __eq__(self, other): + return ( + self.url == other.url + and self.version == other.version + and self.name == other.name + ) + def verbose_str(self): return f"{self.name} {self.version} ({self.url})" + def serialize(self) -> str: + return f"{self.url} {self.version} {self.name}" + + @staticmethod + def deserialize(serialized: str) -> "Package": + url, version, name = serialized.split() + return Package(url, version, name) + + def fetch_version_release(self, version: str | None = None) -> tuple[str, str]: + # Fetch the releases from the GitHub API + parts = self.url.split("/") + owner = parts[-2] + repo = parts[-1] + + response = requests.get(f"https://api.github.com/repos/{owner}/{repo}/releases") + response.raise_for_status() + releases = response.json() + + if not releases: + raise ValueError(f"No releases found for package {self.name}") + + # If a version is provided, check if it exists in the releases + if version: + for release in releases: + if release["tag_name"] == version: + return version, release["zipball_url"] + else: + raise ValueError(f"Version {version} does not exist for this package") + # If no version is provided, use the latest release + return releases[0]["tag_name"], releases[0]["zipball_url"] + + def install( + self, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH, replace: bool = True + ): + # Fetch the release zip with the specified version + if not self.zip_url: + _, self.zip_url = self.fetch_version_release(self.version) + + response = requests.get(self.zip_url) + response.raise_for_status() + + # Extract the zip to a temporary directory + with tempfile.TemporaryDirectory(prefix="unhacs-") as tempdir: + tmpdir = Path(tempdir) + extract_zip(ZipFile(BytesIO(response.content)), tmpdir) + + hacs = json.loads((tmpdir / "hacs.json").read_text()) + print("Hacs?", hacs) + + for custom_component in tmpdir.glob("custom_components/*"): + dest = hass_config_path / "custom_components" / custom_component.name + if replace: + shutil.rmtree(dest, ignore_errors=True) + + shutil.move(custom_component, dest) + dest.joinpath("unhacs.txt").write_text(self.serialize()) + + def uninstall(self, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH) -> bool: + if self.path: + shutil.rmtree(self.path) + return True + + installed_package = self.installed_package(hass_config_path) + if installed_package and installed_package.path: + shutil.rmtree(installed_package.path) + return True + + return False + + def installed_package( + self, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH + ) -> "Package|None": + for custom_component in (hass_config_path / "custom_components").glob("*"): + unhacs = custom_component / "unhacs.txt" + if unhacs.exists(): + installed_package = Package.deserialize(unhacs.read_text()) + installed_package.path = custom_component + if ( + installed_package.name == self.name + and installed_package.url == self.url + ): + return installed_package + return None + + def outdated(self) -> bool: + installed_package = self.installed_package() + return installed_package is None or installed_package.version != self.version + + +def get_installed_packages( + hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH, +) -> list[Package]: + packages = [] + for custom_component in (hass_config_path / "custom_components").glob("*"): + unhacs = custom_component / "unhacs.txt" + if unhacs.exists(): + package = Package.deserialize(unhacs.read_text()) + package.path = custom_component + packages.append(package) + + return packages + # Read a list of Packages from a text file in the plain text format "URL version name" -def read_packages(package_file: str = DEFAULT_PACKAGE_FILE) -> list[Package]: +def read_lock_packages(package_file: str = DEFAULT_PACKAGE_FILE) -> list[Package]: path = Path(package_file) if path.exists(): with path.open() as f: - return [Package(*line.strip().split()) for line in f] + return [Package.deserialize(line.strip()) for line in f] return [] # Write a list of Packages to a text file in the format URL version name -def write_packages(packages: list[Package], package_file: str = DEFAULT_PACKAGE_FILE): +def write_lock_packages( + packages: Iterable[Package], package_file: str = DEFAULT_PACKAGE_FILE +): with open(package_file, "w") as f: - for package in packages: - f.write(f"{package.url} {package.version} {package.name}\n") + f.writelines(f"{package.serialize()}\n" for package in packages)