Allow templating values into extract file names
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
IamTheFij 2024-11-06 16:21:19 -08:00
parent 7380fa99ec
commit 35b07836e8
3 changed files with 60 additions and 29 deletions

View File

@ -9,8 +9,6 @@ from dataclasses import dataclass
from pathlib import Path
from shutil import copy
from shutil import copytree
from shutil import move
from subprocess import check_call
import toml
from wheel.wheelfile import WheelFile
@ -40,27 +38,19 @@ class Config:
def download(config: Config, wheel_scripts: Path) -> list[Path]:
release = rg.fetch_release(
rg.GitRemoteInfo(config.hostname, config.owner, config.repo), config.version
)
asset = rg.match_asset(
release,
"""Download and extract files to the wheel_scripts directory"""
return rg.download_release(
rg.GitRemoteInfo(config.hostname, config.owner, config.repo),
wheel_scripts,
config.format,
version=config.version,
system_mapping=config.map_system,
arch_mapping=config.map_arch,
extract_files=config.extract_files,
pre_release=config.pre_release,
exec=config.exec,
)
files = rg.download_asset(
asset, extract_files=config.extract_files, destination=wheel_scripts
)
# Optionally execute post command
if config.exec:
check_call(config.exec, shell=True, cwd=wheel_scripts)
return files
def read_metadata() -> Config:
"""Read configuration from pyproject.toml"""

View File

@ -14,6 +14,7 @@ from subprocess import check_output
from tarfile import TarFile
from tarfile import TarInfo
from typing import Any
from typing import NamedTuple
from urllib.parse import urlparse
from zipfile import ZipFile
@ -73,6 +74,12 @@ def get_synonyms(value: str, thesaurus: list[list[str]]) -> list[str]:
return results
class MatchedValues(NamedTuple):
version: str
system: str
arch: str
@dataclass
class GitRemoteInfo:
"""Extracts information about a repository"""
@ -225,7 +232,7 @@ def match_asset(
version: str | None = None,
system_mapping: dict[str, str] | None = None,
arch_mapping: dict[str, str] | None = None,
) -> dict[Any, Any]:
) -> tuple[dict[Any, Any], MatchedValues]:
"""Accepts a release and searches for an appropriate asset attached using
a provided template and some alternative mappings for version, system, and machine info
@ -286,7 +293,7 @@ def match_asset(
version=version_opt,
system=system_opt,
arch=arch_opt,
)
): MatchedValues(version=version_opt, system=system_opt, arch=arch_opt)
for version_opt, system_opt, arch_opt in product(
(
version.lstrip("v"),
@ -299,7 +306,7 @@ def match_asset(
for asset in release["assets"]:
if asset["name"] in expected_names:
return asset
return (asset, expected_names[asset["name"]])
raise ValueError(
f"Could not find asset named {expected_names} on release {release['name']}"
@ -581,41 +588,61 @@ def download_release(
arch_mapping: dict[str, str] | None = None,
extract_files: list[str] | None = None,
pre_release=False,
exec: str | None = None,
) -> list[Path]:
"""Convenience method for fetching, downloading and extracting a release"""
"""Convenience method for fetching, downloading, and extracting a release
This is slightly different than running off the commandline, it will execute the shell script
from the destination directory, not the current working directory.
"""
release = fetch_release(
remote_info,
version=version,
pre_release=pre_release,
)
asset = match_asset(
asset, matched_values = match_asset(
release,
format,
version=version,
system_mapping=system_mapping,
arch_mapping=arch_mapping,
)
formatted_files = (
[file.format(**matched_values._asdict()) for file in extract_files]
if extract_files
else None
)
files = download_asset(
asset,
extract_files=extract_files,
extract_files=formatted_files,
destination=destination,
)
if exec:
check_call(exec.format(asset["name"]), shell=True, cwd=destination)
return files
def main():
args = _parse_args()
# Fetch the release
release = fetch_release(
GitRemoteInfo(args.hostname, args.owner, args.repo),
version=args.version,
pre_release=args.prerelease,
)
asset = match_asset(
version = args.version or release["tag_name"]
# Find the asset to download using mapping rules
asset, matched_values = match_asset(
release,
args.format,
version=args.version,
version=version,
system_mapping=args.map_system,
arch_mapping=args.map_arch,
)
@ -627,9 +654,16 @@ def main():
print(asset["browser_download_url"])
return
# Format files to extract with version info, as this is sometimes included
formatted_files = (
[file.format(**matched_values._asdict()) for file in args.extract_files]
if args.extract_files
else None
)
files = download_asset(
asset,
extract_files=args.extract_files,
extract_files=formatted_files,
destination=args.destination,
)

View File

@ -199,6 +199,13 @@ class TestContentTypeDetection(unittest.TestCase):
)
def first_result(f):
def wrapper(*args, **kwargs):
return f(*args, **kwargs)[0]
return wrapper
class TestMatchAsset(unittest.TestCase):
def test_match_asset_versions(self, *_):
# Input variations:
@ -233,7 +240,7 @@ class TestMatchAsset(unittest.TestCase):
)
]
for test_case in happy_cases:
test_case.run(release_gitter.match_asset)
test_case.run(first_result(release_gitter.match_asset))
def test_match_asset_systems(self, *_):
# Input variations:
@ -347,7 +354,7 @@ class TestMatchAsset(unittest.TestCase):
),
)
for test_case in test_cases:
test_case.run(run_with_context)
test_case.run(first_result(run_with_context))
def test_match_asset_archs(self, *_):
# Input variations:
@ -468,7 +475,7 @@ class TestMatchAsset(unittest.TestCase):
),
)
for test_case in test_cases:
test_case.run(run_with_context)
test_case.run(first_result(run_with_context))
if __name__ == "__main__":