summary refs log tree commit diff
path: root/linuxnamespaces
diff options
context:
space:
mode:
Diffstat (limited to 'linuxnamespaces')
-rw-r--r-- linuxnamespaces/tarutils.py 77
1 files changed, 77 insertions, 0 deletions
diff --git a/linuxnamespaces/tarutils.py b/linuxnamespaces/tarutils.py
new file mode 100644
index 0000000..c7a065c
--- /dev/null
+++ b/linuxnamespaces/tarutils.py
@@ -0,0 +1,77 @@
+#!/usr/bin/python3
+# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
+# SPDX-License-Identifier: GPL-3
+
+"""Extensions to the tarfile module.
+ * ZstdTarFile extends TarFile to deal with zstd-compressed archives.
+ * get_comptype guesses the compression used for an open TarFile.
+"""
+
+import tarfile
+import typing
+
+
+class ZstdTarFile(tarfile.TarFile):
+ """Subclass of tarfile.TarFile that can read zstd compressed archives."""
+
+ OPEN_METH = {"zst": "zstopen"} | tarfile.TarFile.OPEN_METH
+
+ @classmethod
+ def zstopen(
+ cls,
+ name: str,
+ mode: typing.Literal["r", "w", "x"] = "r",
+ fileobj: typing.BinaryIO | None = None,
+ **kwargs: typing.Any,
+ ) -> tarfile.TarFile:
+ if mode not in ("r", "w", "x"):
+ raise ValueError("mode must be 'r', 'w' or 'x'")
+ openobj: str | typing.BinaryIO = name if fileobj is None else fileobj
+ try:
+ import zstandard
+ except ImportError as err:
+ raise tarfile.CompressionError(
+ "zstandard module not available"
+ ) from err
+ if mode == "r":
+ zfobj = zstandard.open(openobj, "rb")
+ else:
+ zfobj = zstandard.open(
+ openobj,
+ mode + "b",
+ cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
+ )
+ try:
+ tarobj = cls.taropen(name, mode, zfobj, **kwargs)
+ except (OSError, EOFError, zstandard.ZstdError) as exc:
+ zfobj.close()
+ if mode == "r":
+ raise tarfile.ReadError("not a zst file") from exc
+ raise
+ except:
+ zfobj.close()
+ raise
+ # Setting the _extfileobj attribute is important to signal a need to
+ # close this object and thus flush the compressed stream.
+ # Unfortunately, tarfile.pyi doesn't know about it.
+ tarobj._extfileobj = False # type: ignore
+ return tarobj
+
+
def get_comptype(tarobj: tarfile.TarFile) -> str:
    """Return the compression type used to compress the given TarFile.

    The result is one of "tar" (uncompressed), "gz", "bz2", "xz" or "zst".
    ValueError is raised when the type cannot be determined.
    """
    # The tarfile module does not expose the compression method selected
    # for open mode "r:*" in any way. We can guess it from the module that
    # implements the fileobj.
    compmodule = tarobj.fileobj.__class__.__module__
    try:
        return {
            "bz2": "bz2",
            "gzip": "gz",
            "lzma": "xz",
            "_io": "tar",
            "zstd": "zst",
        }[compmodule]
    except KeyError:
        # The KeyError adds no information beyond the message, so the
        # exception context is suppressed explicitly.
        raise ValueError(
            f"cannot guess comptype for module {compmodule}"
        ) from None