Improve content type detection
continuous-integration/drone/push Build is passing Details

Cycle through detected content types and use the first supported one.

Adds tests to cover cases of priority and exceptions.
This commit is contained in:
IamTheFij 2022-10-11 12:20:57 -07:00
parent e6a269af3d
commit ab0603d1b9
2 changed files with 86 additions and 12 deletions

View File

@ -24,6 +24,10 @@ import requests
# Extract metadata from repo
class UnsupportedContentTypeError(ValueError):
pass
class InvalidRemoteError(ValueError):
pass
@ -285,7 +289,9 @@ class PackageAdapter:
):
self._package = TarFile.open(fileobj=BytesIO(response.content), mode="r:*")
else:
raise ValueError(f"Unknown or unsupported content type {content_type}")
raise UnsupportedContentTypeError(
f"Unknown or unsupported content type {content_type}"
)
def get_names(self) -> list[str]:
"""Get list of all file names in package"""
@ -322,6 +328,27 @@ class PackageAdapter:
return members
def get_asset_package(
asset: dict[str, Any], result: requests.Response
) -> PackageAdapter:
possible_content_types = (
asset.get("content_type"),
"+".join(t for t in guess_type(asset["name"]) if t is not None),
)
for content_type in possible_content_types:
if not content_type:
continue
try:
return PackageAdapter(content_type, result)
except UnsupportedContentTypeError:
continue
else:
raise UnsupportedContentTypeError(
"Cannot extract files from archive because we don't recognize the content type"
)
def download_asset(
asset: dict[Any, Any],
extract_files: Optional[list[str]] = None,
@ -344,18 +371,8 @@ def download_asset(
result = requests.get(asset["browser_download_url"])
content_type = asset.get(
"content_type",
guess_type(asset["name"]),
)
if extract_files is not None:
if isinstance(content_type, tuple):
content_type = "+".join(t for t in content_type if t is not None)
if not content_type:
raise TypeError(
"Cannot extract files from archive because we don't recognize the content type"
)
package = PackageAdapter(content_type, result)
package = get_asset_package(asset, result)
extract_files = package.extractall(path=destination, members=extract_files)
return [destination / name for name in extract_files]

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import unittest
from pathlib import Path
from tarfile import TarFile
from typing import Any
from typing import Callable
from typing import NamedTuple
@ -9,6 +10,7 @@ from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import mock_open
from unittest.mock import patch
from zipfile import ZipFile
import requests
@ -141,5 +143,60 @@ class TestVersionInfo(unittest.TestCase):
release_gitter.read_version()
@patch("release_gitter.ZipFile", autospec=True)
@patch("release_gitter.BytesIO", autospec=True)
class TestContentTypeDetection(unittest.TestCase):
def test_asset_encoding_priority(self, *_):
package = release_gitter.get_asset_package(
{
"content_type": "application/x-tar",
"name": "test.zip",
},
MagicMock(spec=["raw", "content"]),
)
# Tar should take priority over the file name zip extension
self.assertIsInstance(package._package, TarFile)
def test_fallback_to_supported_encoding(self, *_):
package = release_gitter.get_asset_package(
{
"content_type": "application/octetstream",
"name": "test.zip",
},
MagicMock(spec=["raw", "content"]),
)
# Should fall back to zip extension
self.assertIsInstance(package._package, ZipFile)
def test_missing_only_name_content_type(self, *_):
package = release_gitter.get_asset_package(
{
"name": "test.zip",
},
MagicMock(spec=["raw", "content"]),
)
# Should fall back to zip extension
self.assertIsInstance(package._package, ZipFile)
def test_no_content_types(self, *_):
with self.assertRaises(release_gitter.UnsupportedContentTypeError):
release_gitter.get_asset_package(
{
"name": "test",
},
MagicMock(spec=["raw", "content"]),
)
def test_no_supported_content_types(self, *_):
with self.assertRaises(release_gitter.UnsupportedContentTypeError):
release_gitter.get_asset_package(
{
"content_type": "application/octetstream",
"name": "test",
},
MagicMock(spec=["raw", "content"]),
)
if __name__ == "__main__":
unittest.main()