summaryrefslogtreecommitdiff
path: root/examples/chroottar.py
blob: d2299569b0ccfceb2bae5ca16ae4766d344d64c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Extract a given tarball into a temporary location and chroot into it inside
a user and mount namespace.
"""

import argparse
import os
import pathlib
import socket
import sys
import tarfile
import tempfile
import typing

# When this script is located in a directory named "examples", put the
# directory above it first on sys.path. Presumably this lets the import below
# pick up linuxnamespaces from the surrounding source tree without it being
# installed.
if __file__.split("/")[-2:-1] == ["examples"]:
    sys.path.insert(0, "/".join(__file__.split("/")[:-2]))

import linuxnamespaces


class TarFile(tarfile.TarFile):
    """Subclass of tarfile.TarFile that can read and write zstd compressed
    archives by means of the optional zstandard module.
    """

    # Register "zst" as a compression type such that TarFile.open accepts
    # modes like "r:zst" or "x:zst" and also probes it for "r:*".
    OPEN_METH = {"zst": "zstopen"} | tarfile.TarFile.OPEN_METH

    @classmethod
    def zstopen(
        cls,
        name: str,
        mode: typing.Literal["r", "w", "x"] = "r",
        fileobj: None = None,
        **kwargs: typing.Any,
    ) -> tarfile.TarFile:
        """Open a zstd compressed tar archive for reading or writing.

        This is the opener that TarFile.open dispatches to for the "zst"
        compression type.

        :param name: path of the archive.
        :param mode: one of "r", "w" or "x".
        :param fileobj: must be None. Operating on an existing file object
            is not supported.
        :param kwargs: further keyword arguments are forwarded to taropen
            just like the openers of the tarfile module do.
        :raises NotImplementedError: for unsupported modes or when a fileobj
            is given.
        :raises tarfile.CompressionError: if the zstandard module is missing.
        :raises tarfile.ReadError: if reading fails, which usually indicates
            that the file is not zstd compressed.
        """
        if mode not in ("r", "w", "x"):
            raise NotImplementedError(f"mode `{mode}' not implemented for zst")
        if fileobj is not None:
            raise NotImplementedError("zst does not support a fileobj")
        try:
            import zstandard
        except ImportError as err:
            raise tarfile.CompressionError(
                "zstandard module not available"
            ) from err
        if mode == "r":
            zfobj = zstandard.open(name, "rb")
        else:
            zfobj = zstandard.open(
                name,
                mode + "b",
                cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
            )
        try:
            tarobj = cls.taropen(name, mode, zfobj, **kwargs)
        except (OSError, EOFError, zstandard.ZstdError) as exc:
            zfobj.close()
            if mode == "r":
                raise tarfile.ReadError("not a zst file") from exc
            raise
        except BaseException:
            # Do not leak the compressed stream, whatever went wrong.
            zfobj.close()
            raise
        # Setting the _extfileobj attribute is important to signal a need to
        # close this object and thus flush the compressed stream.
        # Unfortunately, tarfile.pyi doesn't know about it.
        tarobj._extfileobj = False  # type: ignore
        return tarobj

    def get_comptype(self) -> str:
        """Return the compression type used to compress the opened TarFile.

        The result is one of "bz2", "gz", "xz", "zst" or "tar" (meaning
        uncompressed) and can be appended to a mode such as "x:".

        :raises ValueError: if the compression type cannot be determined.
        """
        # The tarfile module does not expose the compression method selected
        # for open mode "r:*" in any way. We can guess it from the module that
        # implements the fileobj.
        compmodule = self.fileobj.__class__.__module__
        comptype = {
            "bz2": "bz2",
            "gzip": "gz",
            "lzma": "xz",
            "_io": "tar",
            "zstd": "zst",
        }.get(compmodule)
        if comptype is None and compmodule.split(".", 1)[0] == "zstandard":
            # The C backend of zstandard reports its types as living in
            # "zstd". Other backends report a submodule of the zstandard
            # package instead.
            comptype = "zst"
        if comptype is None:
            raise ValueError(f"cannot guess comptype for module {compmodule}")
        return comptype


def main() -> None:
    """Unpack a chroot tarball into a temporary directory and run a command
    chrooted into it inside a new user and mount namespace.

    Arguments are taken from the command line. A forked child extracts the
    tarball and executes the command (or $SHELL, falling back to /bin/sh).
    The parent sets up the id mappings, waits for the child and, if --save
    was given and the child succeeded, repacks the directory over the
    original tarball using the same compression.

    The process exits with the exit code of the child, or with 128 plus the
    signal number if the child was killed by a signal.
    """
    # Only needed for reporting failures of the forked child.
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save",
        action="store_true",
        help="save and replace the tarball at the end of the session",
    )
    parser.add_argument(
        "basetar",
        type=pathlib.Path,
        action="store",
        help="location of the tarball containing the chroot",
    )
    parser.add_argument(
        "command",
        nargs=argparse.REMAINDER,
        help="command to run inside the chroot",
    )
    args = parser.parse_args()
    # Report a missing tarball as a usage error. An assert would vanish under
    # python -O and greet the user with a traceback otherwise.
    if not args.basetar.exists():
        parser.error(f"tarball {args.basetar} does not exist")
    # A chroot tarball legitimately carries ownership, setuid bits and
    # absolute symlinks. Newer Python versions default to a restrictive
    # extraction filter that drops or rejects those, so request the
    # permissive one whenever the filter feature is available.
    extract_kwargs: dict[str, typing.Any] = {"numeric_owner": True}
    if hasattr(tarfile, "fully_trusted_filter"):
        extract_kwargs["filter"] = "fully_trusted"
    uidmap = linuxnamespaces.IDAllocation.loadsubid("uid").allocatemap(65536)
    gidmap = linuxnamespaces.IDAllocation.loadsubid("gid").allocatemap(65536)
    with tempfile.TemporaryDirectory() as tdir:
        parentsock, childsock = socket.socketpair()
        pid = os.fork()
        if pid == 0:
            # The child must only ever leave via exec or os._exit. Letting an
            # exception propagate would unwind into the code of the parent
            # and run the cleanup of the temporary directory a second time.
            try:
                parentsock.close()
                # Once we drop privileges via setreuid and friends, we may
                # become unable to open basetar or to chdir to tdir, so do
                # those early.
                with TarFile.open(args.basetar, "r:*") as tarf:
                    os.chdir(tdir)
                    linuxnamespaces.unshare(
                        linuxnamespaces.CloneFlags.NEWUSER
                        | linuxnamespaces.CloneFlags.NEWNS
                    )
                    # Tell the parent which compression was detected and
                    # thereby signal that the namespaces exist. Then wait
                    # for its go-ahead.
                    childsock.send(
                        tarf.get_comptype().encode("ascii") + b"\0"
                    )
                    childsock.recv(1)
                    childsock.close()
                    # The other process will now have set up our id mapping
                    # and will have changed ownership of our working
                    # directory.
                    os.setreuid(0, 0)
                    os.setregid(0, 0)
                    os.setgroups([])
                    for tmem in tarf:
                        # Device nodes are populated separately below.
                        if tmem.name.removeprefix("./").startswith("dev/"):
                            continue
                        # Our namespace has privileged uids allocated high.
                        # Hence clamp unpacking.
                        if tmem.uid >= 65536 or tmem.gid >= 65536:
                            tmem.mode &= ~0o7000
                        if tmem.uid >= 65536:
                            tmem.uid = 0
                        if tmem.gid >= 65536:
                            tmem.gid = 0
                        tarf.extract(tmem, **extract_kwargs)
                linuxnamespaces.bind_mount(".", "/mnt", recursive=True)
                os.chdir("/mnt")
                linuxnamespaces.bind_mount("/proc", "proc", recursive=True)
                linuxnamespaces.bind_mount("/sys", "sys", recursive=True)
                linuxnamespaces.populate_dev("/", ".", pidns=False, tun=False)
                linuxnamespaces.pivot_root(".", ".")
                linuxnamespaces.umount(".", linuxnamespaces.UmountFlags.DETACH)
                if args.command:
                    os.execvp(args.command[0], args.command)
                else:
                    shell = os.environ.get("SHELL", "/bin/sh")
                    os.execlp(shell, shell)
            except BaseException:
                traceback.print_exc()
                # os._exit does not flush stdio buffers.
                sys.stderr.flush()
            finally:
                os._exit(1)

        childsock.close()
        comptype = parentsock.recv(10).split(b"\0", 1)[0].decode("ascii")
        linuxnamespaces.newidmaps(pid, [uidmap], [gidmap])
        # We still had to be in the initial namespace to call newidmaps and
        # now we transition to a namespace that can access both the container
        # and the files of the invoking user.
        linuxnamespaces.unshare_user_idmap(
            [uidmap, linuxnamespaces.IDMapping(65536, os.getuid(), 1)],
            [gidmap, linuxnamespaces.IDMapping(65536, os.getgid(), 1)],
        )
        os.chown(tdir, 0, 0)
        os.chmod(tdir, 0o755)
        parentsock.send(b"\0")
        parentsock.close()
        _, status = os.waitpid(pid, 0)
        # waitpid yields an encoded wait status rather than an exit code.
        # Passing it to sys.exit unconverted would turn a child exit code of
        # 1 (status 256) into a successful exit.
        exitcode = os.waitstatus_to_exitcode(status)
        if args.save and exitcode == 0:
            tmptar = f"{args.basetar}.new"
            try:
                with TarFile.open(tmptar, "x:" + comptype) as tout:
                    tout.add(tdir, ".")
                os.rename(tmptar, args.basetar)
            except BaseException:
                # Remove the partial archive, but do not let a missing file
                # mask the error that brought us here.
                try:
                    os.unlink(tmptar)
                except FileNotFoundError:
                    pass
                raise
    if exitcode < 0:
        # The child was killed by a signal. Follow the shell convention.
        exitcode = 128 - exitcode
    sys.exit(exitcode)


# Only run the tool when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()