#!/usr/bin/python3 # Copyright 2024 Helmut Grohne # SPDX-License-Identifier: GPL-3 """Extensions to the tarfile module. * ZstdTarFile extends TarFile to deal with zstd-compressed archives. * get_comptype guesses the compression used for an open TarFile. * XAttrTarFile extends TarFile to map extended attributes to PAX headers. """ import os import tarfile import typing TarPath = str | bytes | os.PathLike[str] | os.PathLike[bytes] class ZstdTarFile(tarfile.TarFile): """Subclass of tarfile.TarFile that can read zstd compressed archives.""" # mypy types OPEN_METH as Mapping rather than dict while it really is a # dict. Hence, it complains that there is no __or__ for dict and Mapping. OPEN_METH = { "zst": "zstopen", } | tarfile.TarFile.OPEN_METH # type: ignore[operator] @classmethod def zstopen( cls, name: str, mode: typing.Literal["r", "w", "x"] = "r", fileobj: typing.BinaryIO | None = None, *, compresslevel: int | None = None, threads: int | None = None, **kwargs: typing.Any, ) -> tarfile.TarFile: """Open a zstd compressed tar archive with the given name for readin or writing. Appending is not supported. The class allows customizing the compression level and the compression concurrency (default parallel) while decompression ignores those arguments. """ if mode not in ("r", "w", "x"): raise ValueError("mode must be 'r', 'w' or 'x'") openobj: str | typing.BinaryIO = name if fileobj is None else fileobj try: import zstandard except ImportError as err: raise tarfile.CompressionError( "zstandard module not available" ) from err if mode == "r": zfobj = zstandard.open(openobj, "rb") else: if threads is None: threads = -1 if compresslevel is not None: if compresslevel > 22: raise ValueError( "invalid compression level {compresslevel}" ) cctx = zstandard.ZstdCompressor( write_checksum=True, threads=threads, level=compresslevel ) else: cctx = zstandard.ZstdCompressor( write_checksum=True, threads=threads ) zfobj = zstandard.open(openobj, mode + "b", cctx=cctx) try: tarobj = cls.taropen(name, mode, zfobj, **kwargs) except (OSError, EOFError, zstandard.ZstdError) as exc: zfobj.close() if mode == "r": raise tarfile.ReadError("not a zst file") from exc raise except: zfobj.close() raise # Setting the _extfileobj attribute is important to signal a need to # close this object and thus flush the compressed stream. # Unfortunately, tarfile.pyi doesn't know about it. tarobj._extfileobj = False # type: ignore return tarobj def get_comptype(tarobj: tarfile.TarFile) -> str: """Return the compression type used to compress the given TarFile.""" # The tarfile module does not expose the compression method selected # for open mode "r:*" in any way. We can guess it from the module that # implements the fileobj. compmodule = tarobj.fileobj.__class__.__module__ try: return { "bz2": "bz2", "gzip": "gz", "lzma": "xz", "_io": "tar", "zstd": "zst", }[compmodule] except KeyError: # pylint: disable=raise-missing-from # no value in chaining raise ValueError(f"cannot guess comptype for module {compmodule}") class XAttrTarFile(tarfile.TarFile): """A subclass to tarfile.TarFile that adds support for extended attributes via SCHILY.xattr.* PAX headers to extraction and creation of archives. It can be used as a mixin class with others as it does not add any state. """ def extract( self, member: tarfile.TarInfo | str, path: TarPath = "", set_attrs: bool = True, **kwargs: typing.Any, ) -> None: """Refer to tarfile.TarFile.extract. In addition, SCHILY.xattr.* PAX headers are examined and applied as extended attributes if set_attrs is true-ish. """ if not set_attrs: super().extract(member, path, False, **kwargs) return # We also need the tarinfo, so mimic the start of the built-in extract. if isinstance(member, str): tarinfo = self.getmember(member) else: tarinfo = member super().extract(tarinfo, path, True, **kwargs) # mypy is unhappy about the next line, but we have the same code in # TarFile.extract and if it bails here, it also bails there. path = os.path.join(path, tarinfo.name) # type: ignore for attr, value in tarinfo.pax_headers.items(): if not attr.startswith("SCHILY.xattr."): continue attr = attr.removeprefix("SCHILY.xattr.") os.setxattr( path, attr, value.encode(self.encoding or "utf8", "surrogateescape"), follow_symlinks=False, ) def gettarinfo( self, name: TarPath | None = None, arcname: str | None = None, fileobj: typing.IO[bytes] | None = None, ) -> tarfile.TarInfo: tarinfo = super().gettarinfo(name, arcname, fileobj) path: int | TarPath if fileobj is not None: path = fileobj.fileno() elif name is not None: path = name else: raise ValueError("gettarinfo requires a name or fileobj") dereference = True if self.dereference is None else self.dereference for attr in os.listxattr(path, follow_symlinks=dereference): key = "SCHILY.xattr." + attr value = os.getxattr( path, attr, follow_symlinks=dereference ).decode(self.encoding or "utf8", "surrogateescape") # TarInfo.pax_headers is designated as (read-only) Mapping, but it # really is a writable dict. tarinfo.pax_headers[key] = value # type: ignore[index] return tarinfo