summaryrefslogtreecommitdiff
path: root/linuxnamespaces/tarutils.py
blob: facb537210d0a6013ce2001452fc1de5d833e52b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Extensions to the tarfile module.
 * ZstdTarFile extends TarFile to deal with zstd-compressed archives.
 * get_comptype guesses the compression used for an open TarFile.
 * XAttrTarFile extends TarFile to map extended attributes to PAX headers.
"""

import errno
import os
import tarfile
import typing


# Type alias for file system path arguments: text or byte strings, or
# os.PathLike objects yielding either.
TarPath = str | bytes | os.PathLike[str] | os.PathLike[bytes]


class ZstdTarFile(tarfile.TarFile):
    """Subclass of tarfile.TarFile that can read zstd compressed archives."""

    OPEN_METH = {"zst": "zstopen"} | tarfile.TarFile.OPEN_METH

    @classmethod
    def zstopen(
        cls,
        name: str,
        mode: typing.Literal["r", "w", "x"] = "r",
        fileobj: typing.BinaryIO | None = None,
        **kwargs: typing.Any,
    ) -> tarfile.TarFile:
        if mode not in ("r", "w", "x"):
            raise ValueError("mode must be 'r', 'w' or 'x'")
        openobj: str | typing.BinaryIO = name if fileobj is None else fileobj
        try:
            import zstandard
        except ImportError as err:
            raise tarfile.CompressionError(
                "zstandard module not available"
            ) from err
        if mode == "r":
            zfobj = zstandard.open(openobj, "rb")
        else:
            zfobj = zstandard.open(
                openobj,
                mode + "b",
                cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
            )
        try:
            tarobj = cls.taropen(name, mode, zfobj, **kwargs)
        except (OSError, EOFError, zstandard.ZstdError) as exc:
            zfobj.close()
            if mode == "r":
                raise tarfile.ReadError("not a zst file") from exc
            raise
        except:
            zfobj.close()
            raise
        # Setting the _extfileobj attribute is important to signal a need to
        # close this object and thus flush the compressed stream.
        # Unfortunately, tarfile.pyi doesn't know about it.
        tarobj._extfileobj = False  # type: ignore
        return tarobj


def get_comptype(tarobj: tarfile.TarFile) -> str:
    """Return the compression type used to compress the given TarFile.

    The result is one of "tar" (uncompressed), "gz", "bz2", "xz" or "zst",
    i.e. a key of TarFile.OPEN_METH usable in open modes such as "w:gz".

    Raises ValueError if the compression cannot be determined.
    """
    # The tarfile module does not expose the compression method selected
    # for open mode "r:*" in any way. We can guess it from the module that
    # implements the fileobj.
    compmodule = tarobj.fileobj.__class__.__module__
    comptypes = {
        "bz2": "bz2",
        "gzip": "gz",
        "lzma": "xz",
        "_io": "tar",
        # The C backend of the zstandard package names its classes zstd.*.
        "zstd": "zst",
        # Other zstandard backends (e.g. "zstandard.backend_cffi") and the
        # Python 3.14+ standard library implementation live below these.
        "zstandard": "zst",
        "compression.zstd": "zst",
    }
    # Try the full dotted module name first, then each parent package, so
    # that submodules of the packages listed above are recognized too.
    candidate = compmodule
    while candidate:
        comptype = comptypes.get(candidate)
        if comptype is not None:
            return comptype
        candidate = candidate.rpartition(".")[0]
    raise ValueError(f"cannot guess comptype for module {compmodule}")


class XAttrTarFile(tarfile.TarFile):
    """A subclass to tarfile.TarFile that adds support for extended attributes
    via SCHILY.xattr.* PAX headers to extraction and creation of archives. It
    can be used as a mixin class with others as it does not add any state.
    """

    def extract(
        self,
        member: tarfile.TarInfo | str,
        path: TarPath = "",
        set_attrs: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Refer to tarfile.TarFile.extract. In addition, SCHILY.xattr.* PAX
        headers are examined and applied as extended attributes if set_attrs is
        true-ish.

        Attribute values are encoded with the archive's encoding (falling back
        to UTF-8) using surrogateescape, mirroring how gettarinfo decodes
        them. They are applied to the extracted entry itself, never to the
        target of a symbolic link.
        """
        if not set_attrs:
            super().extract(member, path, False, **kwargs)
            return

        # We also need the tarinfo, so mimic the start of the built-in extract.
        if isinstance(member, str):
            tarinfo = self.getmember(member)
        else:
            tarinfo = member

        super().extract(tarinfo, path, True, **kwargs)

        # NOTE(review): the destination is recomputed from the unfiltered
        # member name. With Python 3.12+ extraction filters (``filter`` or
        # ``extraction_filter``), a filter may rewrite the name (e.g. strip a
        # leading "/") or skip the member, in which case this path is not
        # the one that was written. Confirm callers only use filters that
        # keep member names unchanged.
        # mypy is unhappy about the next line, but we have the same code in
        # TarFile.extract and if it bails here, it also bails there.
        path = os.path.join(path, tarinfo.name)  # type: ignore

        encoding = self.encoding or "utf8"
        for key, value in tarinfo.pax_headers.items():
            if not key.startswith("SCHILY.xattr."):
                continue
            os.setxattr(
                path,
                key.removeprefix("SCHILY.xattr."),
                value.encode(encoding, "surrogateescape"),
                follow_symlinks=False,
            )

    def gettarinfo(
        self,
        name: TarPath | None = None,
        arcname: str | None = None,
        fileobj: typing.IO[bytes] | None = None,
    ) -> tarfile.TarInfo:
        """Refer to tarfile.TarFile.gettarinfo. In addition, the extended
        attributes of the file are recorded as SCHILY.xattr.* PAX headers,
        decoded with the archive's encoding (falling back to UTF-8) using
        surrogateescape.

        If the base class returns None (a file type it cannot represent),
        None is returned unchanged.
        """
        tarinfo = super().gettarinfo(name, arcname, fileobj)
        if tarinfo is None:
            # Unsupported file type (e.g. a socket): there is no member to
            # attach attributes to, so keep the base class' result.
            return tarinfo
        path: int | TarPath
        follow_symlinks: bool
        if fileobj is not None:
            # os.listxattr/os.getxattr raise ValueError when a file descriptor
            # is combined with follow_symlinks=False. A descriptor already
            # refers to the opened file, so following is the only option.
            path = fileobj.fileno()
            follow_symlinks = True
        elif name is not None:
            path = name
            follow_symlinks = (
                True if self.dereference is None else self.dereference
            )
        else:
            raise ValueError("gettarinfo requires a name or fileobj")
        try:
            attrs = os.listxattr(path, follow_symlinks=follow_symlinks)
        except OSError as err:
            # A file system without extended attribute support simply has no
            # attributes to record.
            if err.errno != errno.ENOTSUP:
                raise
            attrs = []
        encoding = self.encoding or "utf8"
        for attr in attrs:
            value = os.getxattr(path, attr, follow_symlinks=follow_symlinks)
            key = "SCHILY.xattr." + attr
            # TarInfo.pax_headers is designated as (read-only) Mapping, but it
            # really is a writable dict.
            tarinfo.pax_headers[key] = value.decode(  # type: ignore[index]
                encoding, "surrogateescape"
            )
        return tarinfo