summaryrefslogtreecommitdiff
path: root/linuxnamespaces/tarutils.py
blob: c7a065cb02bf2a02157929360df67a485b0ed288 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Extensions to the tarfile module.
 * ZstdTarFile extends TarFile to deal with zstd-compressed archives.
 * get_comptype guesses the compression used for an open TarFile.
"""

import tarfile
import typing


class ZstdTarFile(tarfile.TarFile):
    """Subclass of tarfile.TarFile that can read zstd compressed archives."""

    OPEN_METH = {"zst": "zstopen"} | tarfile.TarFile.OPEN_METH

    @classmethod
    def zstopen(
        cls,
        name: str,
        mode: typing.Literal["r", "w", "x"] = "r",
        fileobj: typing.BinaryIO | None = None,
        **kwargs: typing.Any,
    ) -> tarfile.TarFile:
        if mode not in ("r", "w", "x"):
            raise ValueError("mode must be 'r', 'w' or 'x'")
        openobj: str | typing.BinaryIO = name if fileobj is None else fileobj
        try:
            import zstandard
        except ImportError as err:
            raise tarfile.CompressionError(
                "zstandard module not available"
            ) from err
        if mode == "r":
            zfobj = zstandard.open(openobj, "rb")
        else:
            zfobj = zstandard.open(
                openobj,
                mode + "b",
                cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
            )
        try:
            tarobj = cls.taropen(name, mode, zfobj, **kwargs)
        except (OSError, EOFError, zstandard.ZstdError) as exc:
            zfobj.close()
            if mode == "r":
                raise tarfile.ReadError("not a zst file") from exc
            raise
        except:
            zfobj.close()
            raise
        # Setting the _extfileobj attribute is important to signal a need to
        # close this object and thus flush the compressed stream.
        # Unfortunately, tarfile.pyi doesn't know about it.
        tarobj._extfileobj = False  # type: ignore
        return tarobj


def get_comptype(tarobj: tarfile.TarFile) -> str:
    """Return the compression type used to compress the given TarFile.

    The result is one of "tar" (uncompressed), "gz", "bz2", "xz" or "zst",
    matching the compression suffixes used in tarfile modes like "r:gz".

    Raises ValueError when the class of tarobj.fileobj is not recognised.
    """
    # The tarfile module does not expose the compression method selected
    # for open mode "r:*" in any way. We can guess it from the module that
    # implements the fileobj.
    comptypes = {
        "bz2": "bz2",
        "gzip": "gz",
        "lzma": "xz",
        "_io": "tar",
        # Presumably the module name reported by the zstandard C backend's
        # reader/writer types (as created by ZstdTarFile.zstopen).
        "zstd": "zst",
        # zstandard's pure-Python CFFI backend.
        "zstandard.backend_cffi": "zst",
        # Standard-library zstd support (Python 3.14+).
        "compression.zstd._zstdfile": "zst",
    }
    compmodule = tarobj.fileobj.__class__.__module__
    # A plain lookup avoids raising the ValueError while a KeyError is being
    # handled, so no irrelevant exception context is attached to it.
    comptype = comptypes.get(compmodule)
    if comptype is None:
        raise ValueError(f"cannot guess comptype for module {compmodule}")
    return comptype