1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
|
#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3
"""Extensions to the tarfile module.
* ZstdTarFile extends TarFile to deal with zstd-compressed archives.
* get_comptype guesses the compression used for an open TarFile.
* XAttrTarFile extends TarFile to map extended attributes to PAX headers.
"""
import os
import tarfile
import typing
TarPath = str | bytes | os.PathLike[str] | os.PathLike[bytes]
class ZstdTarFile(tarfile.TarFile):
"""Subclass of tarfile.TarFile that can read zstd compressed archives."""
OPEN_METH = {"zst": "zstopen"} | tarfile.TarFile.OPEN_METH
@classmethod
def zstopen(
cls,
name: str,
mode: typing.Literal["r", "w", "x"] = "r",
fileobj: typing.BinaryIO | None = None,
**kwargs: typing.Any,
) -> tarfile.TarFile:
if mode not in ("r", "w", "x"):
raise ValueError("mode must be 'r', 'w' or 'x'")
openobj: str | typing.BinaryIO = name if fileobj is None else fileobj
try:
import zstandard
except ImportError as err:
raise tarfile.CompressionError(
"zstandard module not available"
) from err
if mode == "r":
zfobj = zstandard.open(openobj, "rb")
else:
zfobj = zstandard.open(
openobj,
mode + "b",
cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
)
try:
tarobj = cls.taropen(name, mode, zfobj, **kwargs)
except (OSError, EOFError, zstandard.ZstdError) as exc:
zfobj.close()
if mode == "r":
raise tarfile.ReadError("not a zst file") from exc
raise
except:
zfobj.close()
raise
# Setting the _extfileobj attribute is important to signal a need to
# close this object and thus flush the compressed stream.
# Unfortunately, tarfile.pyi doesn't know about it.
tarobj._extfileobj = False # type: ignore
return tarobj
def get_comptype(tarobj: tarfile.TarFile) -> str:
"""Return the compression type used to compress the given TarFile."""
# The tarfile module does not expose the compression method selected
# for open mode "r:*" in any way. We can guess it from the module that
# implements the fileobj.
compmodule = tarobj.fileobj.__class__.__module__
try:
return {
"bz2": "bz2",
"gzip": "gz",
"lzma": "xz",
"_io": "tar",
"zstd": "zst",
}[compmodule]
except KeyError:
# pylint: disable=raise-missing-from # no value in chaining
raise ValueError(f"cannot guess comptype for module {compmodule}")
class XAttrTarFile(tarfile.TarFile):
    """A subclass to tarfile.TarFile that adds support for extended attributes
    via SCHILY.xattr.* PAX headers to extraction and creation of archives. It
    can be used as a mixin class with others as it does not add any state.

    Attribute values are converted between bytes and str using self.encoding
    (falling back to "utf8") with the "surrogateescape" error handler.
    """

    def extract(
        self,
        member: tarfile.TarInfo | str,
        path: TarPath = "",
        set_attrs: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Refer to tarfile.TarFile.extract. In addition, SCHILY.xattr.* PAX
        headers are examined and applied as extended attributes if set_attrs is
        true-ish.

        The attributes are applied after the parent class has extracted the
        member, to the extracted path itself rather than to a symlink target.
        Errors from os.setxattr() propagate to the caller.
        """
        if not set_attrs:
            super().extract(member, path, False, **kwargs)
            return
        # We also need the tarinfo, so mimic the start of the built-in extract.
        if isinstance(member, str):
            tarinfo = self.getmember(member)
        else:
            tarinfo = member
        super().extract(tarinfo, path, True, **kwargs)
        # NOTE(review): the destination is recomputed from the unfiltered
        # member name. Extraction filters (Python 3.12+, via the filter
        # keyword or self.extraction_filter) may rename the member (e.g. strip
        # a leading "/") or skip it, in which case the attributes below would
        # target a different path. Confirm the intended interaction.
        # mypy is unhappy about the next line, but TarFile.extract performs
        # the same os.path.join() internally, so if it fails here it would
        # already have failed there.
        path = os.path.join(path, tarinfo.name)  # type: ignore
        encoding = self.encoding or "utf8"
        for attr, value in tarinfo.pax_headers.items():
            if not attr.startswith("SCHILY.xattr."):
                continue
            os.setxattr(
                path,
                attr.removeprefix("SCHILY.xattr."),
                value.encode(encoding, "surrogateescape"),
                follow_symlinks=False,
            )

    def gettarinfo(
        self,
        name: TarPath | None = None,
        arcname: str | None = None,
        fileobj: typing.IO[bytes] | None = None,
    ) -> tarfile.TarInfo:
        """Refer to tarfile.TarFile.gettarinfo. In addition, the extended
        attributes of the file are recorded as SCHILY.xattr.* PAX headers.

        Returns None, like the parent, for file types it cannot archive.

        :raises ValueError: if neither name nor fileobj is given.
        """
        tarinfo = super().gettarinfo(name, arcname, fileobj)
        if tarinfo is None:
            # TarFile.gettarinfo returns None for unsupported file types
            # (e.g. sockets) and TarFile.add skips those; there is no TarInfo
            # to attach attributes to.
            return tarinfo
        path: int | TarPath
        if fileobj is not None:
            path = fileobj.fileno()
            # The parent describes an open file via os.fstat(), i.e. the file
            # itself. os.*xattr() refuse follow_symlinks=False together with a
            # file descriptor (ValueError), so always follow here.
            follow = True
        elif name is not None:
            path = name
            # Mirror the parent, which uses os.lstat() unless
            # self.dereference is truthy, so that the attributes describe
            # the same object as the rest of the TarInfo.
            follow = bool(self.dereference)
        else:
            raise ValueError("gettarinfo requires a name or fileobj")
        encoding = self.encoding or "utf8"
        for attr in os.listxattr(path, follow_symlinks=follow):
            key = "SCHILY.xattr." + attr
            value = os.getxattr(path, attr, follow_symlinks=follow).decode(
                encoding, "surrogateescape"
            )
            # TarInfo.pax_headers is designated as (read-only) Mapping, but it
            # really is a writable dict.
            tarinfo.pax_headers[key] = value  # type: ignore[index]
        return tarinfo
|