# SPDX-License-Identifier: MIT
"""Common functions used by multiple backends"""

from __future__ import annotations

import argparse
import contextlib
import fnmatch
import hashlib
import importlib.resources
import json
import multiprocessing
import pathlib
import tarfile
import tempfile
import typing
import urllib.parse

import debian.deb822
import requests

try:
    import jsonschema
except ImportError:
    jsonschema = None


def json_load(filecontextmanager:
              typing.ContextManager[typing.IO[typing.AnyStr]]) -> typing.Any:
    """Load the json content from a file context manager."""
    with filecontextmanager as fileobj:
        return json.load(fileobj)


JsonObject = typing.Dict[str, typing.Any]


def buildjson_validate(buildobj: JsonObject) -> None:
    """Validate the given build json object against the schema."""
    if jsonschema:
        jsonschema.validate(
            buildobj,
            json_load(
                importlib.resources.open_text("mdbp", "build_schema.json")))


def buildjson_patch_relative(buildobj: JsonObject,
                             basedir: pathlib.PurePath) -> None:
    """Resolve relative paths used in the buildobj using the given basedir:

     * .input.dscpath
     * .output.directory

    The operation is performed in-place and modifies the given buildobj.
    """
    for attrs in (("input", "dscpath"), ("output", "directory")):
        # Walk down to the object holding the final attribute. If an
        # intermediate key is missing, the loop breaks and the else clause
        # -- and thus the patching -- is skipped.
        obj = buildobj
        for attr in attrs[:-1]:
            try:
                obj = obj[attr]
            except KeyError:
                break
        else:
            with contextlib.suppress(KeyError):
                obj[attrs[-1]] = str(basedir / pathlib.Path(obj[attrs[-1]]))


def buildjson(filename: str) -> JsonObject:
    """Type constructor for argparse validating a build json file path and
    returning the parsed json object."""
    buildobj = json_load(argparse.FileType("r")(filename))
    buildjson_validate(buildobj)
    buildjson_patch_relative(buildobj, pathlib.Path(filename).parent)
    assert isinstance(buildobj, dict)
    return buildobj


def compute_env(build: JsonObject) -> typing.Dict[str, str]:
    """Compute the process environment from the build object."""
    env = dict(PATH="/usr/bin:/bin")
    env.update(build.get("environment", {}))
    parallel = build.get("parallel")
    if parallel == "auto":
        parallel = "%d" % multiprocessing.cpu_count()
    # Copy the options list to avoid mutating the given build object.
    options = list(build.get("options", []))
    if parallel:
        options.append("parallel=" + str(parallel))
    if options:
        env["DEB_BUILD_OPTIONS"] = " ".join(options)
    return env


class HashSumMismatch(Exception):
    """Raised from `hash_check` when validation fails."""


# pylint does not grok from __future__ import annotations yet
# pylint: disable=E1101,W0212
def hash_check(iterable: typing.Iterable[bytes], hashobj: hashlib._Hash,
               expected_digest: str) -> typing.Iterator[bytes]:
    """Wrap an iterable that yields bytes. The sequence is not modified, but
    once it is exhausted, the concatenation of its elements is verified
    against the expected digest value. Upon failure, the final next() raises
    a HashSumMismatch rather than StopIteration.
    """
    for data in iterable:
        hashobj.update(data)
        yield data
    if hashobj.hexdigest() != expected_digest:
        raise HashSumMismatch()
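

# The following sketch (illustrative only, not used by any backend) shows how
# hash_check composes: chunks pass through unchanged, and a wrong digest only
# surfaces as a HashSumMismatch once the stream is exhausted.
def _example_hash_check() -> None:
    """Illustrative sketch of hash_check; not part of the module API."""
    chunks = [b"hello ", b"world"]
    good = hashlib.sha256(b"hello world").hexdigest()
    # The chunks are yielded unchanged while the digest is accumulated.
    assert b"".join(hash_check(iter(chunks), hashlib.sha256(), good)) == \
        b"hello world"
    try:
        # A bogus digest is only detected after the final chunk was yielded.
        list(hash_check(iter(chunks), hashlib.sha256(), "0" * 64))
    except HashSumMismatch:
        pass  # expected

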
""" with requests.get(uri, stream=True) as resp: resp.raise_for_status() iterable = resp.iter_content(None) for algo, csum in checksums.items(): iterable = hash_check(iterable, hashlib.new(algo), csum) try: with dest.open("wb") as out: for chunk in iterable: out.write(chunk) except HashSumMismatch: dest.unlink() raise def download_dsc(buildinput: JsonObject, destdir: pathlib.Path) -> pathlib.Path: """Download the .input.dscuri including referenced components to the given destination directory and return the path to the contained .dsc file. """ dscuri = buildinput["dscuri"] dscpath = destdir / dscuri.split("/")[-1] # mypy doesn't grok this: assert isinstance(dscpath, pathlib.Path) download(dscuri, buildinput.get("checksums", {}), dscpath) files: typing.Dict[str, typing.Dict[str, str]] = {} with dscpath.open("r") as dscf: for key, value in debian.deb822.Dsc(dscf).items(): if key.lower().startswith("checksums-"): for entry in value: algo = key[10:].lower() files.setdefault(entry["name"], dict())[algo] = entry[algo] for name, checksums in files.items(): download(urllib.parse.urljoin(dscuri, name), checksums, destdir / name) return dscpath @contextlib.contextmanager def get_dsc(build: JsonObject) -> typing.Iterator[pathlib.Path]: """A context manager that provides a path pointing at the .dsc file for the duration of the context. If the .dsc is supplied as a path, it simply is returned. If it is supplied as a uri, it and the referred components are downloaded to a temporary location. """ try: dscpath = build["input"]["dscpath"] except KeyError: with tempfile.TemporaryDirectory() as tdir: yield download_dsc(build["input"], pathlib.Path(tdir)) else: yield pathlib.Path(dscpath) def get_dsc_files(dscpath: pathlib.Path) -> typing.List[pathlib.Path]: """Get the component names referenced by the .dsc file.""" with dscpath.open("r") as dscf: dsc = debian.deb822.Dsc(dscf) return [dscpath.parent / item["name"] for item in dsc["Files"]] def make_option(optname: str, value: typing.Optional[str]) -> typing.List[str]: """Construct a valued option if a value is given.""" if not value: return [] if optname.endswith("="): return [optname + value] return [optname, value] def profile_option(build: JsonObject, optname: str) -> typing.List[str]: """Construct the option for specifying build profiles if required.""" return make_option(optname, ",".join(build.get("profiles", ()))) def tar_add(tarobj: tarfile.TarFile, path: pathlib.Path) -> None: """Add the given file as its basename to the tarobj retaining its modification time, but no mode or ownership information. """ info = tarfile.TarInfo(path.name) statres = path.stat() info.size = statres.st_size info.mtime = int(statres.st_mtime) with path.open("rb") as fobj: tarobj.addfile(info, fobj) def clean_dir(directory: pathlib.Path, patterns: typing.List[str]) -> None: """Delete all entries of `directory` that match none of the given `patterns`.""" for entry in directory.iterdir(): if not any(fnmatch.fnmatchcase(entry.name, pattern) for pattern in patterns): entry.unlink()