#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Emulate schroot using namespaces sufficiently well that sbuild can deal with
it but not any better. It assumes that ~/.cache/sbuild contains tars suitable
for sbuild --chroot-mode=unshare. Additionally, those tars are expected to
contain the non-essential passwd package. The actual sessions are stored in
~/.cache/unschroot. For using it with sbuild, your sbuildrc should contain:

    $chroot_mode = "schroot";
    $schroot = "/path/to/unschroot";
"""


import argparse
import grp
import os
import pathlib
import pwd
import shutil
import signal
import socket
import stat
import sys
import tempfile
import typing

# When this script lives in a directory named "examples", put that directory's
# parent at the front of sys.path (presumably so that the linuxnamespaces
# imports below resolve to the surrounding source checkout).
if __file__.split("/")[-2:-1] == ["examples"]:
    sys.path.insert(0, "/".join(__file__.split("/")[:-2]))

import linuxnamespaces
import linuxnamespaces.tarutils


class TarFile(
    linuxnamespaces.tarutils.ZstdTarFile, linuxnamespaces.tarutils.XAttrTarFile
):
    """Tar reader combining both linuxnamespaces.tarutils TarFile variants.

    Judging by the base class names this adds zstd decompression and extended
    attribute handling; the details live in linuxnamespaces.tarutils.
    """

    pass


def write_etc_hosts(root: os.PathLike[str] | str) -> None:
    etc_hosts = pathlib.Path(root) / "etc/hosts"
    if not etc_hosts.exists():
        etc_hosts.write_text(
            """127.0.0.1 localhost
127.0.1.1 %s
::1 localhost ip6-localhost ip6-loopback
"""
            % socket.gethostname(),
            encoding="ascii",
        )


def load_subids() -> (
    tuple[linuxnamespaces.IDMapping, linuxnamespaces.IDMapping]
):
    """Return a ``(uidmap, gidmap)`` pair of 65536 ids each.

    Both mappings are allocated from the subordinate id allocations that
    linuxnamespaces loads for "uid" and "gid" respectively.
    """
    uidmap = linuxnamespaces.IDAllocation.loadsubid("uid").allocatemap(65536)
    gidmap = linuxnamespaces.IDAllocation.loadsubid("gid").allocatemap(65536)
    return uidmap, gidmap


# Take the home directory from the passwd database and ignore $HOME, because
# sbuild sets it to something invalid.
HOME = pathlib.Path(pwd.getpwuid(os.getuid()).pw_dir)
# Tarballs that serve as source chroots (see the module docstring).
CACHE_SBUILD = HOME / ".cache/sbuild"
# Per-session directories created by --begin-session.
CACHE_UNSCHROOT = HOME / ".cache/unschroot"
# Plain directories that serve as source chroots.
CACHE_DIRECTORY_CHROOTS = HOME / ".cache/directory_chroots"


class ChrootBase:
    """Shared state and info rendering for source chroots and sessions.

    Subclasses are expected to provide ``namespace`` (the heading used in the
    info output) and ``name``.
    """

    namespace: str
    name: str

    def __init__(self) -> None:
        self.aliases: set[str] = set()

    def infodata(self) -> dict[str, str]:
        """Return the info fields as an ordered key/value mapping."""
        data = {"Name": self.name}
        data["Aliases"] = " ".join(sorted(self.aliases))
        return data

    def infostr(self) -> str:
        """Render the info block: a heading line plus one line per field."""
        lines = [f"--- {self.namespace} ---\n"]
        for key, value in self.infodata().items():
            lines.append(f"{key} {value}\n")
        return "".join(lines)


class SourceChroot(ChrootBase):
    """Base class for chroots that sessions can be created from."""

    namespace = "Chroot"

    def newsession(self) -> "SessionChroot":
        """Create a new session from this chroot; subclasses must override."""
        raise NotImplementedError


class SessionChroot(ChrootBase):
    """Base class for sessions, i.e. instantiated chroots that can be run."""

    namespace = "Session"

    def infodata(self) -> dict[str, str]:
        """Extend the base info fields with the session specific ones."""
        return super().infodata() | {
            "Session Purged": "true",
            "Type": "unshare",
        }

    def mount(self) -> pathlib.Path:
        """Mount the session and return the mount point; must be overridden."""
        raise NotImplementedError


class TarSourceChroot(SourceChroot):
    """Source chroot stored as a tarball; sessions are extracted copies."""

    def __init__(self, path: pathlib.Path):
        super().__init__()
        self.path = path
        # The chroot name is the file name up to its first dot plus "-sbuild".
        self.name = path.name.split(".", 1)[0] + "-sbuild"

    def infodata(self) -> dict[str, str]:
        """Extend the base info fields with the tarball type and location."""
        data = super().infodata()
        data["Type"] = "file"
        data["File"] = str(self.path)
        return data

    def newsession(self) -> "TarSessionChroot":
        """Extract the tarball into a fresh "tar-" directory below
        CACHE_UNSCHROOT and return the resulting session.

        Two helper processes are forked: one extracts the tarball inside a new
        user and mount namespace, the other prepares ownership and mode of the
        session directory. Both must exit with status 0.
        """
        CACHE_UNSCHROOT.mkdir(parents=True, exist_ok=True)
        session = TarSessionChroot(
            pathlib.Path(tempfile.mkdtemp(prefix="tar-", dir=CACHE_UNSCHROOT)),
        )

        uidmap, gidmap = load_subids()
        # The socket pair synchronizes the extracting child with this process.
        mainsock, childsock = socket.socketpair()
        with TarFile.open(self.path, "r:*") as tarf:
            pid = os.fork()
            if pid == 0:
                # Extracting child.
                mainsock.close()
                os.chdir(session.path)
                linuxnamespaces.unshare(
                    linuxnamespaces.CloneFlags.NEWUSER
                    | linuxnamespaces.CloneFlags.NEWNS,
                )
                # Tell the parent that the namespaces exist, then wait until
                # it has installed the id maps and prepared the directory.
                childsock.send(b"\0")
                childsock.recv(1)
                childsock.close()
                os.setgid(0)
                os.setuid(0)
                for tmem in tarf:
                    # Members below dev/ are not extracted.
                    if not tmem.name.startswith(("dev/", "./dev/")):
                        tarf.extract(tmem, numeric_owner=True)
                write_etc_hosts(".")
                sys.exit(0)
        childsock.close()
        # Wait for the extracting child to have unshared its namespaces.
        mainsock.recv(1)
        pid2 = os.fork()
        if pid2 == 0:
            # Second helper: enter a user namespace that maps the subordinate
            # id ranges plus the invoking uid/gid (the extra IDMapping with
            # 65536 -- presumably as inner id 65536; confirm against
            # linuxnamespaces.IDMapping) and hand the session directory to
            # id 0 of that namespace.
            linuxnamespaces.unshare_user_idmap(
                [uidmap, linuxnamespaces.IDMapping(65536, os.getuid(), 1)],
                [gidmap, linuxnamespaces.IDMapping(65536, os.getgid(), 1)],
            )
            os.chown(session.path, 0, 0)
            session.path.chmod(0o755)
            sys.exit(0)
        # Install the id maps for the extracting child's user namespace.
        linuxnamespaces.newidmaps(pid, [uidmap], [gidmap])
        _, ret = os.waitpid(pid2, 0)
        assert ret == 0
        # Directory is prepared and maps are installed: release the extractor.
        mainsock.send(b"\0")
        mainsock.close()
        _, ret = os.waitpid(pid, 0)
        assert ret == 0
        return session


class TarSessionChroot(SessionChroot):
    """Session whose root file system is a plain directory at ``path``."""

    def __init__(self, path: pathlib.Path):
        super().__init__()
        # The session is named after its directory.
        self.name = path.name
        self.path = path

    def mount(self) -> pathlib.Path:
        """Recursively bind mount the session directory and return the
        mount point.
        """
        mnt = "/mnt"
        linuxnamespaces.bind_mount(self.path, mnt, recursive=True)
        return pathlib.Path(mnt)


class DirectorySourceChroot(SourceChroot):
    """Source chroot stored as a directory; sessions are overlays on top."""

    def __init__(self, path: pathlib.Path):
        super().__init__()
        self.path = path
        # The chroot name is the directory name plus "-sbuild".
        self.name = path.name + "-sbuild"

    def infodata(self) -> dict[str, str]:
        """Extend the base info fields with the directory type and location."""
        data = super().infodata()
        data["Type"] = "directory"
        data["Directory"] = str(self.path)
        return data

    def newsession(self) -> "DirectorySessionChroot":
        """Create an "overlay-<name>-" session directory below CACHE_UNSCHROOT
        containing empty "upper" and "work" directories and return the session.

        The directory is prepared by a forked helper running in a user
        namespace; the helper must exit with status 0.
        """
        CACHE_UNSCHROOT.mkdir(parents=True, exist_ok=True)
        path = pathlib.Path(
            tempfile.mkdtemp(
                prefix=f"overlay-{self.name}-", dir=CACHE_UNSCHROOT
            ),
        )
        session = DirectorySessionChroot(self, path)
        uidmap, gidmap = load_subids()
        pid = os.fork()
        if pid == 0:
            # Helper: enter a user namespace that maps the subordinate id
            # ranges plus the invoking uid/gid, then set up the session
            # directory as id 0 of that namespace.
            linuxnamespaces.unshare_user_idmap(
                [uidmap, linuxnamespaces.IDMapping(65536, os.getuid(), 1)],
                [gidmap, linuxnamespaces.IDMapping(65536, os.getgid(), 1)],
            )
            os.setgid(0)
            os.setuid(0)
            os.chown(path, 0, 0)
            path.chmod(0o755)
            (path / "upper").mkdir()
            (path / "work").mkdir()
            # If the source lacks etc/hosts, provide one via the upper layer.
            if not (self.path / "etc/hosts").exists():
                (path / "upper/etc").mkdir()
                write_etc_hosts(path / "upper")
            sys.exit(0)
        _, ret = os.waitpid(pid, 0)
        assert ret == 0
        return session


class DirectorySessionChroot(SessionChroot):
    """Session that overlays a writable layer onto a directory chroot."""

    def __init__(self, source: DirectorySourceChroot, path: pathlib.Path):
        super().__init__()
        self.path = path
        self.name = path.name
        self.source = source

    def infodata(self) -> dict[str, str]:
        """Extend the session info fields with the source directory."""
        lowerdir = str(self.source.path)
        data = super().infodata()
        data.update(
            {
                "Type": "directory",
                "Directory": lowerdir,
                # It's a gross lie, but sbuild does not work without. It has
                # to actually exist and should not occur inside build logs.
                "Location": lowerdir,
            }
        )
        return data

    def mount(self) -> pathlib.Path:
        """Mount the overlay (source below, session "upper" on top) and
        return the mount point.
        """
        target = "/mnt"
        upperdir = self.path / "upper"
        workdir = self.path / "work"
        options = [
            f"lowerdir={self.source.path}",
            f"upperdir={upperdir}",
            f"workdir={workdir}",
            "userxattr",
        ]
        linuxnamespaces.mount("overlay", target, "overlay", data=options)
        return pathlib.Path(target)


def scan_chroots() -> dict[str, ChrootBase]:
    """Discover all source chroots and sessions on disk.

    Source chroots are the non-symlink entries of CACHE_SBUILD (tarballs) and
    CACHE_DIRECTORY_CHROOTS (directories). A symbolic link in one of those
    directories defines an alias for the entry whose file name equals the
    link target. Sessions are the entries of CACHE_UNSCHROOT whose name starts
    with "tar-" or "overlay-".

    Overlay sessions whose base chroot is missing or is not a directory
    chroot are skipped rather than aborting the scan, so that one stale
    session directory cannot break every subcommand.

    Returns a mapping from each chroot name, alias and session name to its
    object. The first object registered under a name wins.
    """
    found: dict[str, ChrootBase] = {}
    locations = (
        (CACHE_SBUILD, TarSourceChroot),
        (CACHE_DIRECTORY_CHROOTS, DirectorySourceChroot),
    )
    for location, factory in locations:
        if not location.is_dir():
            continue
        entries: list[pathlib.Path] = []
        # Maps a link target to the alias names pointing at it.
        links: dict[str, set[str]] = {}
        for entry in location.iterdir():
            if entry.is_symlink():
                alias = entry.name.split(".", 1)[0] + "-sbuild"
                links.setdefault(str(entry.readlink()), set()).add(alias)
            else:
                entries.append(entry)
        for entry in entries:
            source = factory(entry)
            source_aliases = links.get(entry.name, set())
            source.aliases.update(source_aliases)
            found.setdefault(source.name, source)
            for alias in source_aliases:
                found.setdefault(alias, source)

    if not CACHE_UNSCHROOT.is_dir():
        return found

    for entry in CACHE_UNSCHROOT.iterdir():
        session: SessionChroot
        if entry.name.startswith("tar-"):
            session = TarSessionChroot(entry)
        elif entry.name.startswith("overlay-"):
            # Session directories are named "overlay-<source name>-<random>".
            base = "-".join(entry.name.split("-")[1:-1])
            origin = found.get(base)
            if not isinstance(origin, DirectorySourceChroot):
                # Stale session: its base vanished or is no longer a
                # directory chroot. Ignore it instead of crashing.
                continue
            session = DirectorySessionChroot(origin, entry)
        else:
            continue
        found.setdefault(session.name, session)

    return found


def do_info(args: argparse.Namespace) -> None:
    """Show information about selected chroots"""
    # NOTE: the docstring above doubles as the --info help text in main().
    #
    # With -c, print the info block of that chroot or session (an optional
    # "chroot:" or "session:" prefix is stripped). Without -c, print all
    # known chroots and sessions, separated by blank lines.
    chrootmap = scan_chroots()
    chroots: typing.Iterable[ChrootBase]
    if args.chroot:
        name = args.chroot.removeprefix("chroot:").removeprefix("session:")
        chroots = [chrootmap[name]]
    else:
        # Aliases are stored as additional keys for the same object. List
        # every chroot only once, keeping the order of first appearance.
        chroots = dict.fromkeys(chrootmap.values())
    sys.stdout.write("\n".join(chroot.infostr() for chroot in chroots))


def do_begin_session(args: argparse.Namespace) -> None:
    """Begin a session; returns the session ID"""
    # NOTE: the docstring above doubles as the --begin-session help text.
    #
    # Look up the requested source chroot (an optional "chroot:" prefix is
    # stripped), create a session from it and print the session name.
    available = scan_chroots()
    requested = args.chroot.removeprefix("chroot:")
    source = available[requested]
    assert isinstance(source, SourceChroot)
    print(source.newsession().name)


def exec_perl_dumb_init(pid: int) -> typing.NoReturn:
    """Replace the current process with a minimal perl child reaper.

    The perl script waits for children until the child with the given pid
    terminates and then exits with that child's status: a death by signal
    becomes 128 + signal, a regular exit keeps its exit code. If there are no
    children left before that happens, it exits with 255.
    """
    script = "".join(
        [
            "$r=255<<8;",  # exit 255 when we run out of children
            "do{",
            "$p=wait;",
            f"$r=$?,$p=0 if $p=={pid};",
            "}while($p>0);",
            "exit(0<$r<256?128|$r:$r>>8);",  # sig -> 128+sig; exit -> exit
        ]
    )
    os.execlp("perl", "perl", "-e", script)


def do_run_session(args: argparse.Namespace) -> None:
    """Run an existing session"""
    # NOTE: the docstring above doubles as the --run-session help text.
    #
    # Processes involved (in order of creation):
    #  * this process: installs the id maps, later waits for the namespace
    #    init via a pidfd and exits with its status,
    #  * first child: unshares the namespaces, forks the init and exits,
    #  * init (pid 1 of the new pid namespace): mounts the session, pivots
    #    into it, forks the command and finally execs the perl reaper,
    #  * command: switches to the requested user and execs the command.
    chrootmap = scan_chroots()
    session = chrootmap[args.chroot]
    assert isinstance(session, SessionChroot)
    uidmap, gidmap = load_subids()
    # The socket pair carries the synchronization bytes and the pidfd.
    mainsock, childsock = socket.socketpair()
    pid = os.fork()
    pidfd: int
    if pid == 0:
        mainsock.close()
        # Open up permissions on stdout/stderr if they are pipes (presumably
        # so that processes running under other ids can reopen them).
        for fd in (1, 2):
            if stat.S_ISFIFO(os.fstat(fd).st_mode):
                os.fchmod(fd, 0o666)
        ns = (
            linuxnamespaces.CloneFlags.NEWUSER
            | linuxnamespaces.CloneFlags.NEWNS
            | linuxnamespaces.CloneFlags.NEWPID
        )
        if args.isolate_network:
            ns |= linuxnamespaces.CloneFlags.NEWNET
        linuxnamespaces.unshare(ns)
        # Tell the parent that the namespaces exist and wait until it has
        # installed the id maps.
        childsock.send(b"\0")
        childsock.recv(1)
        # The next child is the first process of the new pid namespace. The
        # intermediate process has served its purpose and exits.
        if os.fork() != 0:
            sys.exit(0)
        assert os.getpid() == 1
        # Hand a pidfd for ourselves to the parent so that it can wait for
        # us even though we are not its direct child.
        with linuxnamespaces.FileDescriptor(os.pidfd_open(1, 0)) as pidfd:
            socket.send_fds(childsock, [b"\0"], [pidfd])
        os.setgid(0)
        os.setuid(0)
        root = session.mount()
        os.chdir(root)
        # Provide sys, proc and dev inside the new root before pivoting.
        linuxnamespaces.populate_sys("/", ".", ns, devices=True)
        linuxnamespaces.populate_proc("/", ".", ns)
        linuxnamespaces.populate_dev(
            "/", ".", tun=bool(ns & linuxnamespaces.CloneFlags.NEWNET)
        )
        # Make the session the root file system and detach the old root.
        linuxnamespaces.pivot_root(".", ".")
        linuxnamespaces.umount(".", linuxnamespaces.UmountFlags.DETACH)
        os.chdir("/")
        if ns & linuxnamespaces.CloneFlags.NEWNET:
            linuxnamespaces.enable_loopback_if()
        # These lookups happen after pivot_root, so presumably against the
        # session's own user and group databases.
        if args.user.isdigit():
            spw = pwd.getpwuid(int(args.user))
        else:
            spw = pwd.getpwnam(args.user)
        supplementary = [
            sgr.gr_gid for sgr in grp.getgrall() if spw.pw_name in sgr.gr_mem
        ]

        # Wait until the parent has reaped the intermediate process.
        childsock.recv(1)
        childsock.close()
        # The pipe delays the command until this process has exec'ed perl.
        rfd, wfd = linuxnamespaces.FileDescriptor.pipe(inheritable=False)
        pid = os.fork()
        if pid == 0:
            wfd.close()
            if args.directory:
                os.chdir(args.directory)
            # Drop to the requested user.
            os.setgroups(supplementary)
            os.setgid(spw.pw_gid)
            os.setuid(spw.pw_uid)
            if "PATH" not in os.environ:
                if spw.pw_uid == 0:
                    os.environ["PATH"] = "/usr/sbin:/sbin:/usr/bin:/bin"
                else:
                    os.environ["PATH"] = "/usr/bin:/bin"
            # Without a command, run a shell.
            if not args.command:
                args.command.append("bash")
            # Wait until Python has handed off to Perl.
            os.read(rfd, 1)
            os.execvp(args.command[0], args.command)
        else:
            rfd.close()
            # Have this process killed when its parent goes away.
            linuxnamespaces.prctl_set_pdeathsig(signal.SIGKILL)
            os.close(0)
            # It is important that we now exec to get rid of our previous
            # execution context that carries pieces such as memory maps from
            # different namespaces that could allow escalating privileges. The
            # exec will close wfd and allow the target process to exec.
            exec_perl_dumb_init(pid)
    childsock.close()
    # Wait for the first child to have unshared its namespaces.
    mainsock.recv(1)
    linuxnamespaces.newidmaps(pid, [uidmap], [gidmap])
    # Become a child subreaper while the intermediate process exits
    # (presumably so that the namespace init is reparented to this process).
    linuxnamespaces.prctl_set_child_subreaper(True)
    mainsock.send(b"\0")
    # Receive the pidfd referring to the namespace init.
    _data, fds, _flags, _address = socket.recv_fds(mainsock, 1, 1)
    pidfd = fds[0]
    # Reap the intermediate process.
    os.waitpid(pid, 0)
    linuxnamespaces.prctl_set_child_subreaper(False)
    # Allow the namespace init to proceed to running the command.
    mainsock.send(b"\0")
    # Wait for the namespace init and propagate its status.
    wres = os.waitid(os.P_PIDFD, pidfd, os.WEXITED)
    assert wres is not None
    sys.exit(wres.si_status)


def do_end_session(args: argparse.Namespace) -> None:
    """End an existing session"""
    # NOTE: the docstring above doubles as the --end-session help text.
    chrootmap = scan_chroots()
    session = chrootmap[args.chroot]
    assert isinstance(session, (TarSessionChroot, DirectorySessionChroot))
    # Use the shared helper like the session-creating code does, so that the
    # id ranges used for removal match those used for creation.
    uidmap, gidmap = load_subids()
    # Enter a user namespace that maps the subordinate id ranges plus the
    # invoking uid/gid, so that files owned by those ids can be removed.
    linuxnamespaces.unshare_user_idmap(
        [uidmap, linuxnamespaces.IDMapping(65536, os.getuid(), 1)],
        [gidmap, linuxnamespaces.IDMapping(65536, os.getgid(), 1)],
    )
    shutil.rmtree(session.path)


def main() -> None:
    """Parse the schroot-like command line and run the selected action.

    Exactly one of --info, --begin-session, --run-session and --end-session
    must be given. Each maps to the module-level ``do_*`` function of the
    same name, whose docstring is used as the option's help text.
    """
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=True)
    for comm in ("info", "begin-session", "run-session", "end-session"):
        func = globals()["do_" + comm.replace("-", "_")]
        group.add_argument(
            f"-{comm[0]}",
            f"--{comm}",
            dest="subcommand",
            action="store_const",
            const=func,
            help=func.__doc__,
        )
    parser.add_argument(
        "-c",
        "--chroot",
        dest="chroot",
        action="store",
        help="Use specified chroot",
    )
    parser.add_argument("-d", "--directory", action="store")
    # -p and -q are accepted, but none of the do_* handlers in this file
    # reads them.
    parser.add_argument("-p", "--preserve-environment", action="store_true")
    parser.add_argument("-q", "--quiet", action="store_true")
    try:
        default_user = os.getlogin()
    except OSError:
        # os.getlogin() needs a controlling terminal, which is missing when
        # run from cron, CI or a daemon. Fall back to the passwd entry of the
        # real uid instead of failing before the arguments are even parsed.
        default_user = pwd.getpwuid(os.getuid()).pw_name
    parser.add_argument("-u", "--user", action="store", default=default_user)
    parser.add_argument("--isolate-network", action="store_true")
    parser.add_argument("command", nargs="*")
    args = parser.parse_args()
    assert args.subcommand is not None
    args.subcommand(args)


# Run the command line interface only when executed as a script.
if __name__ == "__main__":
    main()