Switch to yaml for lock file

This commit is contained in:
IamTheFij 2024-07-06 15:01:27 -07:00
parent 31286add39
commit b61e94005e
3 changed files with 64 additions and 51 deletions

View File

@ -26,3 +26,4 @@ repos:
exclude: docs/ exclude: docs/
additional_dependencies: additional_dependencies:
- types-requests - types-requests
- types-PyYAML

View File

@ -13,6 +13,7 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.11" python = "^3.11"
requests = "^2.32.3" requests = "^2.32.3"
pyyaml = "^6.0.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "^24.4.2" black = "^24.4.2"

View File

@ -10,9 +10,10 @@ from typing import cast
from zipfile import ZipFile from zipfile import ZipFile
import requests import requests
import yaml
DEFAULT_HASS_CONFIG_PATH: Path = Path(".") DEFAULT_HASS_CONFIG_PATH: Path = Path(".")
DEFAULT_PACKAGE_FILE = Path("unhacs.txt") DEFAULT_PACKAGE_FILE = Path("unhacs.yaml")
def extract_zip(zip_file: ZipFile, dest_dir: Path): def extract_zip(zip_file: ZipFile, dest_dir: Path):
@ -64,28 +65,27 @@ class Package:
return f"{self.name} {self.version}" return f"{self.name} {self.version}"
def __eq__(self, other): def __eq__(self, other):
return ( return self.url == other.url and self.version == other.version
self.url == other.url
and self.version == other.version
and self.name == other.name
)
def verbose_str(self): def verbose_str(self):
return f"{self.name} {self.version} ({self.url})" return f"{self.name} {self.version} ({self.url})"
def serialize(self) -> str:
return f"{self.url} {self.version} {self.package_type}"
@staticmethod @staticmethod
def deserialize(serialized: str) -> "Package": def from_yaml(yaml: dict) -> "Package":
url, version, package_type = serialized.split() # Convert package_type to enum
package_type = yaml.pop("package_type")
# TODO: Use a less ambiguous serialization format that's still easy to read. Maybe TOML? if package_type and isinstance(package_type, str):
try:
package_type = PackageType(package_type) package_type = PackageType(package_type)
except ValueError: yaml["package_type"] = package_type
package_type = PackageType.INTEGRATION
return Package(url, version, package_type=package_type) return Package(**yaml)
def to_yaml(self: "Package") -> dict:
return {
"url": self.url,
"version": self.version,
"package_type": str(self.package_type),
}
def fetch_version_release(self, version: str | None = None) -> tuple[str, str]: def fetch_version_release(self, version: str | None = None) -> tuple[str, str]:
# Fetch the releases from the GitHub API # Fetch the releases from the GitHub API
@ -155,17 +155,18 @@ class Package:
# If a file is found, write it to www/js/<filename>.js and write a file www/js/<filename>-unhacs.txt with the # If a file is found, write it to www/js/<filename>.js and write a file www/js/<filename>-unhacs.txt with the
# serialized package # serialized package
filename = f"{self.name.removeprefix('lovelace-')}.js" valid_filenames: Iterable[str]
print(filename) if filename := self.get_hacs_json().get("filename"):
valid_filenames = (cast(str, filename),)
hacs_json = self.get_hacs_json()
if hacs_json.get("filename"):
filename = hacs_json["filename"]
plugin = requests.get(
f"https://github.com/{self.owner}/{self.name}/releases/download/{self.version}/{filename}"
)
else: else:
# Get dist file path URL valid_filenames = (
f"{self.name.removeprefix('lovelace-')}.js",
f"{self.name}.js",
f"{self.name}-umd.js",
f"{self.name}-bundle.js",
)
def real_get(filename) -> requests.Response:
plugin = requests.get( plugin = requests.get(
f"https://raw.githubusercontent.com/{self.owner}/{self.version}/dist/{filename}" f"https://raw.githubusercontent.com/{self.owner}/{self.version}/dist/{filename}"
) )
@ -173,19 +174,28 @@ class Package:
plugin = requests.get( plugin = requests.get(
f"https://github.com/{self.owner}/{self.name}/releases/download/{self.version}/{filename}" f"https://github.com/{self.owner}/{self.name}/releases/download/{self.version}/{filename}"
) )
plugin.raise_for_status()
if plugin.status_code == 404: if plugin.status_code == 404:
plugin = requests.get( plugin = requests.get(
f"https://raw.githubusercontent.com/{self.owner}/{self.version}/{filename}" f"https://raw.githubusercontent.com/{self.owner}/{self.version}/{filename}"
) )
plugin.raise_for_status() plugin.raise_for_status()
return plugin
for filename in valid_filenames:
try:
plugin = real_get(filename)
break
except requests.HTTPError:
pass
else:
raise ValueError(f"No valid filename found for package {self.name}")
js_path = hass_config_path / "www" / "js" js_path = hass_config_path / "www" / "js"
js_path.mkdir(parents=True, exist_ok=True) js_path.mkdir(parents=True, exist_ok=True)
js_path.joinpath(filename).write_text(plugin.text) js_path.joinpath(filename).write_text(plugin.text)
js_path.joinpath(f"{filename}-unhacs.txt").write_text(self.serialize()) yaml.dump(self.to_yaml(), js_path.joinpath(f"{filename}-unhacs.yaml").open("w"))
def install_integration(self, hass_config_path: Path): def install_integration(self, hass_config_path: Path):
zipball_url = f"https://codeload.github.com/{self.owner}/{self.name}/zip/refs/tags/{self.version}" zipball_url = f"https://codeload.github.com/{self.owner}/{self.name}/zip/refs/tags/{self.version}"
@ -196,17 +206,13 @@ class Package:
tmpdir = Path(tempdir) tmpdir = Path(tempdir)
extract_zip(ZipFile(BytesIO(response.content)), tmpdir) extract_zip(ZipFile(BytesIO(response.content)), tmpdir)
# If an integration, check for a custom_component directory and install contents
# If not present, check the hacs.json file for content_in_root to true, if so install
# the root to custom_components/<package_name>
hacs_json = json.loads((tmpdir / "hacs.json").read_text())
source, dest = None, None source, dest = None, None
for custom_component in tmpdir.glob("custom_components/*"): for custom_component in tmpdir.glob("custom_components/*"):
source = custom_component source = custom_component
dest = hass_config_path / "custom_components" / custom_component.name dest = hass_config_path / "custom_components" / custom_component.name
break break
else: else:
hacs_json = json.loads((tmpdir / "hacs.json").read_text())
if hacs_json.get("content_in_root"): if hacs_json.get("content_in_root"):
source = tmpdir source = tmpdir
dest = hass_config_path / "custom_components" / self.name dest = hass_config_path / "custom_components" / self.name
@ -218,10 +224,9 @@ class Package:
shutil.rmtree(dest, ignore_errors=True) shutil.rmtree(dest, ignore_errors=True)
shutil.move(source, dest) shutil.move(source, dest)
dest.joinpath("unhacs.txt").write_text(self.serialize()) yaml.dump(self.to_yaml(), dest.joinpath("unhacs.yaml").open("w"))
def install(self, hass_config_path: Path): def install(self, hass_config_path: Path):
print(self.package_type)
if self.package_type == PackageType.PLUGIN: if self.package_type == PackageType.PLUGIN:
self.install_plugin(hass_config_path) self.install_plugin(hass_config_path)
elif self.package_type == PackageType.INTEGRATION: elif self.package_type == PackageType.INTEGRATION:
@ -235,6 +240,7 @@ class Package:
shutil.rmtree(self.path) shutil.rmtree(self.path)
else: else:
self.path.unlink() self.path.unlink()
self.path.with_name(f"{self.path.name}-unhacs.yaml").unlink()
return True return True
installed_package = self.installed_package(hass_config_path) installed_package = self.installed_package(hass_config_path)
@ -246,9 +252,9 @@ class Package:
def installed_package(self, hass_config_path: Path) -> "Package|None": def installed_package(self, hass_config_path: Path) -> "Package|None":
for custom_component in (hass_config_path / "custom_components").glob("*"): for custom_component in (hass_config_path / "custom_components").glob("*"):
unhacs = custom_component / "unhacs.txt" unhacs = custom_component / "unhacs.yaml"
if unhacs.exists(): if unhacs.exists():
installed_package = Package.deserialize(unhacs.read_text()) installed_package = Package.from_yaml(yaml.safe_load(unhacs.open()))
installed_package.path = custom_component installed_package.path = custom_component
if ( if (
installed_package.name == self.name installed_package.name == self.name
@ -256,10 +262,10 @@ class Package:
): ):
return installed_package return installed_package
for js_unhacs in (hass_config_path / "www" / "js").glob("*-unhacs.txt"): for js_unhacs in (hass_config_path / "www" / "js").glob("*-unhacs.yaml"):
installed_package = Package.deserialize(js_unhacs.read_text()) installed_package = Package.from_yaml(yaml.safe_load(js_unhacs.open()))
installed_package.path = js_unhacs.with_name( installed_package.path = js_unhacs.with_name(
js_unhacs.name.removesuffix("-unhacs.txt") js_unhacs.name.removesuffix("-unhacs.yaml")
) )
if ( if (
installed_package.name == self.name installed_package.name == self.name
@ -278,15 +284,19 @@ def get_installed_packages(
hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH, hass_config_path: Path = DEFAULT_HASS_CONFIG_PATH,
) -> list[Package]: ) -> list[Package]:
packages = [] packages = []
# Integration packages
for custom_component in (hass_config_path / "custom_components").glob("*"): for custom_component in (hass_config_path / "custom_components").glob("*"):
unhacs = custom_component / "unhacs.txt" unhacs = custom_component / "unhacs.yaml"
if unhacs.exists(): if unhacs.exists():
package = Package.deserialize(unhacs.read_text()) package = Package.from_yaml(yaml.safe_load(unhacs.open()))
package.path = custom_component package.path = custom_component
packages.append(package) packages.append(package)
for js_unhacs in (hass_config_path / "www" / "js").glob("*-unhacs.txt"):
package = Package.deserialize(js_unhacs.read_text()) # Plugin packages
package.path = js_unhacs.with_name(js_unhacs.name.removesuffix("-unhacs.txt")) for js_unhacs in (hass_config_path / "www" / "js").glob("*-unhacs.yaml"):
package = Package.from_yaml(yaml.safe_load(js_unhacs.open()))
package.path = js_unhacs.with_name(js_unhacs.name.removesuffix("-unhacs.yaml"))
packages.append(package) packages.append(package)
return packages return packages
@ -295,8 +305,10 @@ def get_installed_packages(
# Read a list of Packages from a text file in the plain text format "URL version name" # Read a list of Packages from a text file in the plain text format "URL version name"
def read_lock_packages(package_file: Path = DEFAULT_PACKAGE_FILE) -> list[Package]: def read_lock_packages(package_file: Path = DEFAULT_PACKAGE_FILE) -> list[Package]:
if package_file.exists(): if package_file.exists():
with package_file.open() as f: return [
return [Package.deserialize(line.strip()) for line in f] Package.from_yaml(p)
for p in yaml.safe_load(package_file.open())["packages"]
]
return [] return []
@ -304,5 +316,4 @@ def read_lock_packages(package_file: Path = DEFAULT_PACKAGE_FILE) -> list[Packag
def write_lock_packages( def write_lock_packages(
packages: Iterable[Package], package_file: Path = DEFAULT_PACKAGE_FILE packages: Iterable[Package], package_file: Path = DEFAULT_PACKAGE_FILE
): ):
with package_file.open("w") as f: yaml.dump({"packages": [p.to_yaml() for p in packages]}, package_file.open("w"))
f.writelines(sorted(f"{package.serialize()}\n" for package in packages))