Allow templating values into extract file names
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
IamTheFij 2024-11-06 16:21:19 -08:00
parent 7380fa99ec
commit 35b07836e8
3 changed files with 60 additions and 29 deletions

View File

@ -9,8 +9,6 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from shutil import copy from shutil import copy
from shutil import copytree from shutil import copytree
from shutil import move
from subprocess import check_call
import toml import toml
from wheel.wheelfile import WheelFile from wheel.wheelfile import WheelFile
@ -40,27 +38,19 @@ class Config:
def download(config: Config, wheel_scripts: Path) -> list[Path]: def download(config: Config, wheel_scripts: Path) -> list[Path]:
release = rg.fetch_release( """Download and extract files to the wheel_scripts directory"""
rg.GitRemoteInfo(config.hostname, config.owner, config.repo), config.version return rg.download_release(
) rg.GitRemoteInfo(config.hostname, config.owner, config.repo),
asset = rg.match_asset( wheel_scripts,
release,
config.format, config.format,
version=config.version, version=config.version,
system_mapping=config.map_system, system_mapping=config.map_system,
arch_mapping=config.map_arch, arch_mapping=config.map_arch,
extract_files=config.extract_files,
pre_release=config.pre_release,
exec=config.exec,
) )
files = rg.download_asset(
asset, extract_files=config.extract_files, destination=wheel_scripts
)
# Optionally execute post command
if config.exec:
check_call(config.exec, shell=True, cwd=wheel_scripts)
return files
def read_metadata() -> Config: def read_metadata() -> Config:
"""Read configuration from pyproject.toml""" """Read configuration from pyproject.toml"""

View File

@ -14,6 +14,7 @@ from subprocess import check_output
from tarfile import TarFile from tarfile import TarFile
from tarfile import TarInfo from tarfile import TarInfo
from typing import Any from typing import Any
from typing import NamedTuple
from urllib.parse import urlparse from urllib.parse import urlparse
from zipfile import ZipFile from zipfile import ZipFile
@ -73,6 +74,12 @@ def get_synonyms(value: str, thesaurus: list[list[str]]) -> list[str]:
return results return results
class MatchedValues(NamedTuple):
version: str
system: str
arch: str
@dataclass @dataclass
class GitRemoteInfo: class GitRemoteInfo:
"""Extracts information about a repository""" """Extracts information about a repository"""
@ -225,7 +232,7 @@ def match_asset(
version: str | None = None, version: str | None = None,
system_mapping: dict[str, str] | None = None, system_mapping: dict[str, str] | None = None,
arch_mapping: dict[str, str] | None = None, arch_mapping: dict[str, str] | None = None,
) -> dict[Any, Any]: ) -> tuple[dict[Any, Any], MatchedValues]:
"""Accepts a release and searches for an appropriate asset attached using """Accepts a release and searches for an appropriate asset attached using
a provided template and some alternative mappings for version, system, and machine info a provided template and some alternative mappings for version, system, and machine info
@ -286,7 +293,7 @@ def match_asset(
version=version_opt, version=version_opt,
system=system_opt, system=system_opt,
arch=arch_opt, arch=arch_opt,
) ): MatchedValues(version=version_opt, system=system_opt, arch=arch_opt)
for version_opt, system_opt, arch_opt in product( for version_opt, system_opt, arch_opt in product(
( (
version.lstrip("v"), version.lstrip("v"),
@ -299,7 +306,7 @@ def match_asset(
for asset in release["assets"]: for asset in release["assets"]:
if asset["name"] in expected_names: if asset["name"] in expected_names:
return asset return (asset, expected_names[asset["name"]])
raise ValueError( raise ValueError(
f"Could not find asset named {expected_names} on release {release['name']}" f"Could not find asset named {expected_names} on release {release['name']}"
@ -581,41 +588,61 @@ def download_release(
arch_mapping: dict[str, str] | None = None, arch_mapping: dict[str, str] | None = None,
extract_files: list[str] | None = None, extract_files: list[str] | None = None,
pre_release=False, pre_release=False,
exec: str | None = None,
) -> list[Path]: ) -> list[Path]:
"""Convenience method for fetching, downloading and extracting a release""" """Convenience method for fetching, downloading, and extracting a release
This is slightly different than running off the commandline, it will execute the shell script
from the destination directory, not the current working directory.
"""
release = fetch_release( release = fetch_release(
remote_info, remote_info,
version=version, version=version,
pre_release=pre_release, pre_release=pre_release,
) )
asset = match_asset( asset, matched_values = match_asset(
release, release,
format, format,
version=version, version=version,
system_mapping=system_mapping, system_mapping=system_mapping,
arch_mapping=arch_mapping, arch_mapping=arch_mapping,
) )
formatted_files = (
[file.format(**matched_values._asdict()) for file in extract_files]
if extract_files
else None
)
files = download_asset( files = download_asset(
asset, asset,
extract_files=extract_files, extract_files=formatted_files,
destination=destination, destination=destination,
) )
if exec:
check_call(exec.format(asset["name"]), shell=True, cwd=destination)
return files return files
def main(): def main():
args = _parse_args() args = _parse_args()
# Fetch the release
release = fetch_release( release = fetch_release(
GitRemoteInfo(args.hostname, args.owner, args.repo), GitRemoteInfo(args.hostname, args.owner, args.repo),
version=args.version, version=args.version,
pre_release=args.prerelease, pre_release=args.prerelease,
) )
asset = match_asset(
version = args.version or release["tag_name"]
# Find the asset to download using mapping rules
asset, matched_values = match_asset(
release, release,
args.format, args.format,
version=args.version, version=version,
system_mapping=args.map_system, system_mapping=args.map_system,
arch_mapping=args.map_arch, arch_mapping=args.map_arch,
) )
@ -627,9 +654,16 @@ def main():
print(asset["browser_download_url"]) print(asset["browser_download_url"])
return return
# Format files to extract with version info, as this is sometimes included
formatted_files = (
[file.format(**matched_values._asdict()) for file in args.extract_files]
if args.extract_files
else None
)
files = download_asset( files = download_asset(
asset, asset,
extract_files=args.extract_files, extract_files=formatted_files,
destination=args.destination, destination=args.destination,
) )

View File

@ -199,6 +199,13 @@ class TestContentTypeDetection(unittest.TestCase):
) )
def first_result(f):
def wrapper(*args, **kwargs):
return f(*args, **kwargs)[0]
return wrapper
class TestMatchAsset(unittest.TestCase): class TestMatchAsset(unittest.TestCase):
def test_match_asset_versions(self, *_): def test_match_asset_versions(self, *_):
# Input variations: # Input variations:
@ -233,7 +240,7 @@ class TestMatchAsset(unittest.TestCase):
) )
] ]
for test_case in happy_cases: for test_case in happy_cases:
test_case.run(release_gitter.match_asset) test_case.run(first_result(release_gitter.match_asset))
def test_match_asset_systems(self, *_): def test_match_asset_systems(self, *_):
# Input variations: # Input variations:
@ -347,7 +354,7 @@ class TestMatchAsset(unittest.TestCase):
), ),
) )
for test_case in test_cases: for test_case in test_cases:
test_case.run(run_with_context) test_case.run(first_result(run_with_context))
def test_match_asset_archs(self, *_): def test_match_asset_archs(self, *_):
# Input variations: # Input variations:
@ -468,7 +475,7 @@ class TestMatchAsset(unittest.TestCase):
), ),
) )
for test_case in test_cases: for test_case in test_cases:
test_case.run(run_with_context) test_case.run(first_result(run_with_context))
if __name__ == "__main__": if __name__ == "__main__":