From 1d5308dff9a69c54f97d611655816267b3a7c203 Mon Sep 17 00:00:00 2001
From: Helmut Grohne <helmut@subdivi.de>
Date: Sat, 27 Jan 2024 22:39:03 +0100
Subject: examples/chroottar.py: support saving a tar after working inside

---
 examples/chroottar.py | 95 ++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 76 insertions(+), 19 deletions(-)

(limited to 'examples')

diff --git a/examples/chroottar.py b/examples/chroottar.py
index ed494b2..523892c 100755
--- a/examples/chroottar.py
+++ b/examples/chroottar.py
@@ -6,8 +6,10 @@
 a user and mount namespace.
 """
 
+import argparse
 import os
 import pathlib
+import socket
 import sys
 import tarfile
 import tempfile
@@ -27,46 +29,89 @@ class TarFile(tarfile.TarFile):
     def zstopen(
         cls, name: str, mode: str = "r", fileobj: None = None
     ) -> tarfile.TarFile:
-        if mode != "r":
-            raise NotImplementedError("zst only implmented for reading")
+        if mode not in ("r", "w", "x"):
+            raise NotImplementedError(f"mode `{mode}' not implemented for zst")
         if fileobj is not None:
             raise NotImplementedError("zst does not support a fileobj")
         try:
             import zstandard
         except ImportError:
             raise tarfile.CompressionError("zstandard module not available")
-        zfobj = zstandard.open(name, "rb")
+        if mode == "r":
+            zfobj = zstandard.open(name, "rb")
+        else:
+            zfobj = zstandard.open(
+                name,
+                mode + "b",
+                cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
+            )
         try:
-            tarobj = cls.taropen(name, "r", zfobj)
+            tarobj = cls.taropen(name, mode, zfobj)
         except (OSError, EOFError, zstandard.ZstdError) as exc:
             zfobj.close()
-            raise tarfile.ReadError("not a zst file") from exc
+            if mode == "r":
+                raise tarfile.ReadError("not a zst file") from exc
+            raise
         except:
             zfobj.close()
             raise
+        tarobj._extfileobj = False
         return tarobj
 
+    def get_comptype(self) -> str:
+        """Return the compression type used to compress the opened TarFile."""
+        # The tarfile module does not expose the compression method selected
+        # for open mode "r:*" in any way. We can guess it from the module that
+        # implements the fileobj.
+        compmodule = self.fileobj.__class__.__module__
+        try:
+            return {
+                "bz2": "bz2",
+                "gzip": "gz",
+                "lzma": "xz",
+                "_io": "tar",
+                "zstd": "zst",
+            }[compmodule]
+        except KeyError:
+            raise ValueError(f"cannot guess comptype for module {compmodule}")
+
 
 def main() -> None:
-    basetar = pathlib.Path(sys.argv[1])
-    assert basetar.exists()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save",
+        action="store_true",
+        help="save and replace the tarball at the end of the session",
+    )
+    parser.add_argument(
+        "basetar",
+        type=pathlib.Path,
+        action="store",
+        help="location of the tarball containing the chroot",
+    )
+    parser.add_argument(
+        "command",
+        nargs=argparse.REMAINDER,
+        help="command to run inside the chroot",
+    )
+    args = parser.parse_args()
+    assert args.basetar.exists()
     uidmap = linuxnamespaces.IDAllocation.loadsubid("uid").allocatemap(65536)
     gidmap = linuxnamespaces.IDAllocation.loadsubid("gid").allocatemap(65536)
     with tempfile.TemporaryDirectory() as tdir:
-        unshareevent = linuxnamespaces.EventFD()
-        setupevent = linuxnamespaces.EventFD()
+        parentsock, childsock = socket.socketpair()
         pid = os.fork()
         if pid == 0:
-            with TarFile.open(basetar, "r:*") as tarf:
+            parentsock.close()
+            with TarFile.open(args.basetar, "r:*") as tarf:
                 os.chdir(tdir)
                 linuxnamespaces.unshare(
                     linuxnamespaces.CloneFlags.NEWUSER
                     | linuxnamespaces.CloneFlags.NEWNS
                 )
-                unshareevent.write(1)
-                setupevent.read()
-                unshareevent.close()
-                setupevent.close()
+                childsock.send(tarf.get_comptype().encode("ascii") + b"\0")
+                childsock.recv(1)
+                childsock.close()
                 os.setreuid(0, 0)
                 os.setregid(0, 0)
                 os.setgroups([])
@@ -81,11 +126,14 @@ def main() -> None:
             linuxnamespaces.populate_dev("/", ".", pidns=False, tun=False)
             linuxnamespaces.pivot_root(".", ".")
             linuxnamespaces.umount(".", linuxnamespaces.UmountFlags.DETACH)
-            os.execlp(os.environ["SHELL"], os.environ["SHELL"])
+            if args.command:
+                os.execvp(args.command[0], args.command)
+            else:
+                os.execlp(os.environ["SHELL"], os.environ["SHELL"])
             os._exit(1)
 
-        unshareevent.read()
-        unshareevent.close()
+        childsock.close()
+        comptype = parentsock.recv(10).split(b"\0", 1)[0].decode("ascii")
         linuxnamespaces.newidmaps(pid, [uidmap], [gidmap])
         linuxnamespaces.unshare_user_idmap(
             [uidmap, linuxnamespaces.IDMapping(65536, os.getuid(), 1)],
@@ -93,9 +141,18 @@ def main() -> None:
         )
         os.chown(tdir, 0, 0)
         os.chmod(tdir, 0o755)
-        setupevent.write()
-        setupevent.close()
+        parentsock.send(b"\0")
+        parentsock.close()
         _, ret = os.waitpid(pid, 0)
+        if args.save and ret == 0:
+            tmptar = f"{args.basetar}.new"
+            try:
+                with TarFile.open(tmptar, "x:" + comptype) as tout:
+                    tout.add(tdir, ".")
+                os.rename(tmptar, args.basetar)
+            except:
+                os.unlink(tmptar)
+                raise
     sys.exit(ret)
 
 
-- 
cgit v1.2.3