# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Provide plumbing-layer functionality for working with Linux namespaces in
Python.
"""

import asyncio
import contextlib
import errno
import fcntl
import logging
import os
import pathlib
import socket
import stat
import struct
import typing

from .filedescriptor import *
from .idmap import *
from .atlocation import *
from .syscalls import *


# Logger of this module. It is used to report exceptions that escape
# functions run inside forked child processes.
_logger = logging.getLogger(__name__)


class run_in_fork:
    """Decorator for running the decorated function once in a separate process.
    """

    def __init__(
        self, function: typing.Callable[[], None], start: bool = False
    ):
        """Fork a new process that will run the given function and then exit.
        If start is true, run it immediately, otherwise the start or __call__
        method should be used.
        """
        self.efd = None if start else EventFD()
        self.pid = os.fork()
        if self.pid == 0:
            code = 0
            try:
                if self.efd is not None:
                    self.efd.read()
                    self.efd.close()
                    self.efd = None
                function()
            except SystemExit as err:
                code = err.code
            except:
                _logger.exception(
                    "uncaught exception in run_in_fork %r", function
                )
                code = 1
            os._exit(code)

    @classmethod
    def now(cls, function: typing.Callable[[], None]) -> typing.Self:
        """Fork a new process that will immediately run the given function and
        then exit."""
        return cls(function, start=True)

    def start(self) -> None:
        """Start the decorated function. It can only be started once."""
        if not self.efd:
            raise ValueError("this function can only be called once")
        self.efd.write(1)
        self.efd.close()
        self.efd = None

    def wait(self) -> None:
        """Wait for the process running the decorated function to finish."""
        if self.efd:
            raise ValueError("start must be called before wait")
        ret = os.waitpid(self.pid, 0)
        if ret != (self.pid, 0):
            raise ValueError("something failed")

    def __call__(self) -> None:
        """Start the decorated function if needed and wait for its process to
        finish.
        """
        if self.efd:
            self.start()
        self.wait()


class async_run_in_fork:
    """Decorator for running the decorated function once in a separate process.
    Note that the decorator can only be used inside asynchronous code as it
    uses the running event loop. The decorated function instead must be
    synchronous and it must not access the event loop of the main process.

    The child's exit status is delivered through the asyncio child watcher.
    The child always terminates via os._exit and never returns into the
    caller's code.
    """

    def __init__(
        self, function: typing.Callable[[], None], start: bool = False
    ):
        """Fork a new process that will run the given function and then exit.
        If start is true, run it immediately, otherwise the start or __call__
        method should be used.

        Raises RuntimeError if the child watcher is not active.
        """
        loop = asyncio.get_running_loop()
        with asyncio.get_child_watcher() as watcher:
            if not watcher.is_active():
                raise RuntimeError(
                    "active child watcher required for creating a process"
                )
            self.future = loop.create_future()
            self.efd = None if start else EventFD()
            self.pid = os.fork()
            if self.pid == 0:
                self._run_child(function)
            watcher.add_child_handler(self.pid, self._child_callback)

    def _run_child(
        self, function: typing.Callable[[], None]
    ) -> typing.NoReturn:
        """Body of the forked child: wait for the start signal if needed,
        detach from the parent's event loop, run function and terminate the
        process with a suitable exit status.
        """
        code = 1
        try:
            if self.efd is not None:
                self.efd.read()
                self.efd.close()
                self.efd = None
            asyncio.set_event_loop(None)
            function()
            code = 0
        except SystemExit as err:
            code = self._exit_status(err.code)
        except BaseException:
            _logger.exception(
                "uncaught exception in run_in_fork %r", function
            )
        finally:
            # Whatever happens, the child must never continue executing the
            # parent's code after fork.
            os._exit(code)

    @staticmethod
    def _exit_status(code: object) -> int:
        """Translate SystemExit.code into a process exit status the way the
        interpreter treats sys.exit arguments: None means 0, an int is used
        as is and any other object is printed to stderr and yields 1.
        """
        if code is None:
            return 0
        if isinstance(code, int):
            return code
        import sys

        print(code, file=sys.stderr)
        sys.stderr.flush()
        return 1

    @classmethod
    def now(cls, function: typing.Callable[[], None]) -> typing.Self:
        """Fork a new process that will immediately run the given function and
        then exit."""
        return cls(function, start=True)

    def _child_callback(self, pid: int, returncode: int) -> None:
        # The future is already done (cancelled) when a task awaiting wait()
        # was cancelled; setting a result would raise InvalidStateError.
        if self.pid != pid or self.future.done():
            return
        self.future.set_result(returncode)

    def start(self) -> None:
        """Start the decorated function. It can only be started once.

        Raises ValueError if it was already started.
        """
        if self.efd is None:
            raise ValueError("this function can only be called once")
        self.efd.write(1)
        self.efd.close()
        self.efd = None

    async def wait(self) -> None:
        """Wait for the process running the decorated function to finish.

        Raises ValueError if start has not been called yet or if the process
        did not exit with return code 0.
        """
        if self.efd is not None:
            raise ValueError("start must be called before wait")
        ret = await self.future
        if ret != 0:
            raise ValueError("something failed")

    async def __call__(self) -> None:
        """Start the decorated function if needed and wait for its process to
        finish.
        """
        if self.efd is not None:
            self.start()
        await self.wait()


def bind_mount(
    source: AtLocationLike,
    target: AtLocationLike,
    recursive: bool = False,
    readonly: bool = False,
) -> None:
    """Bind mount source onto target, including the mounts beneath source if
    recursive is true.

    The classic mount(2) API is used when both locations can be expressed as
    plain paths and no read-only mount is requested. Otherwise the new mount
    API (open_tree, optionally mount_setattr, then move_mount) is used.
    """
    srcat = AtLocation(source)
    tgtat = AtLocation(target)
    paths: tuple[str | bytes, str | bytes] | None = None
    # The old API would need an additional remount to apply the readonly
    # flag, see
    # https://git.kernel.org/pub/scm/utils/util-linux/util-linux.git/commit/?id=9ac77b8a78452eab0612523d27fee52159f5016a
    if not readonly:
        try:
            paths = (os.fspath(srcat), os.fspath(tgtat))
        except ValueError:
            # A location involving a file descriptor has no plain path.
            paths = None

    if paths is None:
        otflags = (
            (OpenTreeFlags.CLONE | OpenTreeFlags.AT_RECURSIVE)
            if recursive
            else OpenTreeFlags.CLONE
        )
        with open_tree(srcat, otflags) as treefd:
            if readonly:
                mount_setattr(treefd, recursive, MountAttrFlags.RDONLY)
            move_mount(treefd, tgtat)
        return

    srcpath, tgtpath = paths
    mount(
        srcpath,
        tgtpath,
        None,
        (MountFlags.BIND | MountFlags.REC) if recursive else MountFlags.BIND,
    )


def _parse_cgroup(content: str) -> pathlib.PurePath:
    """Extract the cgroup v2 (unified hierarchy) path from the contents of a
    /proc/<pid>/cgroup file.

    Each line has the form hierarchy-ID:controller-list:cgroup-path. The
    unified hierarchy is the entry with ID 0 and an empty controller list.
    Hybrid systems list cgroup v1 entries as well, so the whole file cannot
    simply be split once.

    Raises ValueError if there is no unified entry.
    """
    for line in content.splitlines():
        fields = line.split(":", 2)
        if len(fields) == 3 and fields[0] == "0" and not fields[1]:
            return pathlib.PurePath(fields[2])
    raise ValueError("no cgroup v2 entry found in cgroup membership list")


def get_cgroup(pid: int = -1) -> pathlib.PurePath:
    """Look up the cgroup v2 path that the given pid belongs to, or that the
    current process belongs to if pid is not positive.

    Raises ValueError if the process is not a member of a cgroup v2
    hierarchy, e.g. on a system that uses cgroup v1 only.
    """
    procfile = f"/proc/{pid}/cgroup" if pid > 0 else "/proc/self/cgroup"
    return _parse_cgroup(pathlib.Path(procfile).read_text())


_P = typing.ParamSpec("_P")

class _ExceptionExitCallback:
    """Helper class that invokes a callback when a context manager exists with
    a failure.
    """

    def __init__(
        self,
        callback: typing.Callable[_P, typing.Any],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> None:
        self.callback = callback
        self.args = args
        self.kwargs = kwargs

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: typing.Any,
    ) -> None:
        if exc_type is not None:
            self.callback(*self.args, **self.kwargs)


def populate_dev(
    origroot: AtLocationLike,
    newroot: PathConvertible,
    *,
    fuse: bool = True,
    pidns: bool = True,
    tun: bool = True,
) -> None:
    """Mount a tmpfs to the dev directory beneath newroot and populate it with
    basic devices by bind mounting them from the dev directory beneath
    origroot. Also mount a new pts instance.

    The devices null, zero, full, random, urandom and tty are always bind
    mounted. If fuse is true, fuse is bind mounted as well. If pidns is true,
    a new devpts instance is mounted on pts and ptmx becomes a symlink to
    pts/ptmx; otherwise the pts hierarchy and ptmx of origroot are bind
    mounted. If tun is true, net/tun is bind mounted. shm receives its own
    tmpfs, and fd, stdin, stdout and stderr become symlinks into
    /proc/self/fd. If anything fails after the tmpfs was mounted, it is
    lazily unmounted again.

    Even though a CAP_SYS_ADMIN-enabled process can umount components of the
    /dev hierarchy, they cannot gain privileges in doing so as no
    hierarchies are restricted via tmpfs mounts or read-only bind mounts.
    """
    origdev = AtLocation(origroot) / "dev"
    newdev = AtLocation(newroot) / "dev"
    # Directories, empty regular files (used as bind mount targets) and
    # symlinks to create below newdev. All paths are relative to newdev.
    directories = {"pts", "shm"}
    files: set[str] = set()
    symlinks = {
        "fd": "/proc/self/fd",
        "stdin": "/proc/self/fd/0",
        "stdout": "/proc/self/fd/1",
        "stderr": "/proc/self/fd/2",
    }
    # Detached mount clones obtained from open_tree, keyed by their target
    # path relative to newdev.
    bind_mounts: dict[str, AtLocation] = {}
    with contextlib.ExitStack() as exitstack:
        # All sources are cloned before anything is mounted on newdev.
        # NOTE(review): presumably so that the originals stay reachable even
        # when origroot and newroot refer to the same tree -- confirm.
        for fn in "null zero full random urandom tty".split():
            files.add(fn)
            bind_mounts[fn] = exitstack.enter_context(
                open_tree(origdev / fn, OpenTreeFlags.CLONE)
            )
        if fuse:
            files.add("fuse")
            bind_mounts["fuse"] = exitstack.enter_context(
                open_tree(origdev / "fuse", OpenTreeFlags.CLONE)
            )
        if pidns:
            # A fresh devpts instance is mounted below, providing pts/ptmx.
            symlinks["ptmx"] = "pts/ptmx"
        else:
            # Reuse the existing pts hierarchy including its submounts.
            bind_mounts["pts"] = exitstack.enter_context(
                open_tree(
                    origdev / "pts",
                    OpenTreeFlags.AT_RECURSIVE | OpenTreeFlags.CLONE,
                )
            )
            files.add("ptmx")
            bind_mounts["ptmx"] = exitstack.enter_context(
                open_tree(origdev / "ptmx", OpenTreeFlags.CLONE)
            )
        if tun:
            directories.add("net")
            files.add("net/tun")
            bind_mounts["net/tun"] = exitstack.enter_context(
                open_tree(origdev / "net/tun", OpenTreeFlags.CLONE)
            )
        # "devtmpfs" is merely the source label; the filesystem is a tmpfs.
        mount(
            "devtmpfs",
            newdev,
            "tmpfs",
            MountFlags.NOSUID | MountFlags.NOEXEC,
            "mode=0755",
        )
        # On failure, lazily detach the new tmpfs again.
        exitstack.enter_context(
            _ExceptionExitCallback(umount, newdev, UmountFlags.DETACH)
        )
        # All entries are direct children of newdev, so order is irrelevant.
        for fn in directories:
            (newdev / fn).mkdir()
            (newdev / fn).chmod(0o755)
        mount(
            "tmpfs",
            newdev / "shm",
            "tmpfs",
            MountFlags.NOSUID | MountFlags.NODEV,
        )
        # Empty regular files serve as mount points for the device clones.
        for fn in files:
            (newdev / fn).mknod(stat.S_IFREG)
        for fn, target in symlinks.items():
            (newdev / fn).symlink_to(target)
        if pidns:
            mount(
                "devpts",
                newdev / "pts",
                "devpts",
                MountFlags.NOSUID | MountFlags.NOEXEC,
                "gid=5,mode=620,ptmxmode=666",
            )
        # Attach the cloned device mounts on top of the prepared mount points.
        for fn, fd in bind_mounts.items():
            move_mount(fd, newdev / fn)


def populate_proc(
    origroot: AtLocationLike,
    newroot: PathConvertible,
    namespaces: CloneFlags,
) -> None:
    """Mount a /proc hierarchy beneath newroot.

    namespaces describes the namespaces that have been unshared and must
    include a mount namespace. With a pid namespace, a fresh procfs is
    mounted; otherwise the proc hierarchy beneath origroot is bind mounted
    recursively.

    Unless the pid, user, ipc and uts namespaces are all included, /proc/sys
    is made read-only. In that case /proc/sys/net is kept writable when a
    network namespace is included. If all four are included but no network
    namespace is, only /proc/sys/net is made read-only. On failure the new
    /proc mount is lazily detached again.

    Note that a user with CAP_SYS_ADMIN can change read-only bind mounts.
    Still those bind mounts provide guidance to the container as to which
    aspects it should manage.
    """
    assert namespaces & CloneFlags.NEWNS == CloneFlags.NEWNS
    newproc = AtLocation(newroot) / "proc"
    rwns = CloneFlags.NEWPID
    if namespaces & rwns == rwns:
        mount(
            "proc",
            newproc,
            "proc",
            MountFlags.NOSUID | MountFlags.NODEV | MountFlags.NOEXEC,
        )
    else:
        bind_mount(AtLocation(origroot) / "proc", newproc, True)
    with _ExceptionExitCallback(umount, newproc, UmountFlags.DETACH):
        rwns |= CloneFlags.NEWUSER | CloneFlags.NEWIPC | CloneFlags.NEWUTS
        if namespaces & rwns != rwns:
            # The ExitStack closes the sys/net clone in all cases; previously
            # its file descriptor was leaked.
            with contextlib.ExitStack() as exitstack:
                psn: AtLocation | None = None
                if namespaces & CloneFlags.NEWNET == CloneFlags.NEWNET:
                    # Clone the still writable sys/net before /proc/sys is
                    # made read-only and put it back on top afterwards.
                    psn = exitstack.enter_context(
                        open_tree(
                            newproc / "sys/net",
                            OpenTreeFlags.CLONE | OpenTreeFlags.AT_RECURSIVE,
                        )
                    )
                bind_mount(newproc / "sys", newproc / "sys", True, True)
                if psn is not None:
                    move_mount(psn, newproc / "sys/net")
        elif namespaces & CloneFlags.NEWNET != CloneFlags.NEWNET:
            bind_mount(newproc / "sys/net", newproc / "sys/net", True, True)


def populate_sys(
    origroot: AtLocationLike,
    newroot: PathConvertible,
    namespaces: CloneFlags,
    rootcgroup: PathConvertible | None = None,
    module: bool = True,
    devices: bool = False,
) -> None:
    """Create a /sys hierarchy below newroot. Bind the cgroup hierarchy. The
    cgroup hierarchy will be mounted read-only if mounting the root group.

    The sys directory below newroot becomes a tmpfs. It is made read-only
    once the mount points exist, and selected subtrees of the sys directory
    beneath origroot are recursively bind mounted into it.

    rootcgroup, if given, must be an absolute cgroup path (relative_to("/")
    raises ValueError otherwise). Only the corresponding subtree of fs/cgroup
    is bind mounted, and it stays writable. Without rootcgroup the whole of
    fs/cgroup is bind mounted read-only.

    The module parameter indicates whether the /sys/module should be made
    available read-only (True) or not at all (False). The devices parameter
    indicates whether the devices hierarchy should be made available. If the
    given namespaces happen to include a network namespace, virtual network
    devices will be modifiable.

    A process with CAP_SYS_ADMIN can remount the created bind mounts read-write
    or umount hiding mounts and thus elevate their privileges.
    """
    newsys = AtLocation(newroot) / "sys"
    mflags = MountFlags.NOSUID | MountFlags.NOEXEC | MountFlags.NODEV
    # Maps a target path relative to newsys to a tuple of (source path
    # relative to the original sys directory, whether to make it read-only).
    bind_mounts: dict[
        str, tuple[typing.Union[str, pathlib.PurePath], bool]
    ] = {}
    if rootcgroup is None:
        bind_mounts["fs/cgroup"] = ("fs/cgroup", True)
    else:
        bind_mounts["fs/cgroup"] = (
            "fs/cgroup" / pathlib.PurePath(rootcgroup).relative_to("/"), False
        )
    if module:
        bind_mounts["module"] = ("module", True)
    if devices:
        for subdir in ("bus", "class", "dev", "devices"):
            bind_mounts[subdir] = (subdir, True)
    if namespaces & CloneFlags.NEWNET:
        # With a private network namespace, virtual network devices are
        # exposed writable.
        if not devices:
            bind_mounts["class/net"] = ("class/net", True)
        bind_mounts["devices/virtual/net"] = ("devices/virtual/net", False)

    bind_fds: dict[str, AtLocation] = {}
    with contextlib.ExitStack() as exitstack:
        # Clone every source subtree up front. The clones are entered into
        # the ExitStack and released when it unwinds.
        for target, (source, rdonly) in bind_mounts.items():
            bindfd = exitstack.enter_context(
                open_tree(
                    AtLocation(origroot) / "sys" / source,
                    OpenTreeFlags.CLONE | OpenTreeFlags.AT_RECURSIVE,
                ),
            )
            if rdonly:
                mount_setattr(bindfd, True, MountAttrFlags.RDONLY)
            bind_fds[target] = bindfd

        # A tmpfs (labelled "sysfs") serves as skeleton for the mount points.
        mount("sysfs", newsys, "tmpfs", mflags, "mode=0755")

        # On failure, lazily detach the skeleton again.
        exitstack.enter_context(
            _ExceptionExitCallback(umount, newsys, UmountFlags.DETACH)
        )
        # Collect every target together with all of its parent directories.
        dirs: set[str] = set()
        for subdir in bind_fds:
            dirs.add(subdir)
            while "/" in subdir:
                subdir = subdir.rsplit("/", 1)[0]
                dirs.add(subdir)
        # A path sorts before every path it is a prefix of, so parents are
        # created before their children.
        for subdir in sorted(dirs):
            (newsys / subdir).mkdir()
            (newsys / subdir).chmod(0o755)
        # Make the skeleton read-only once all mount points exist.
        mflags |= MountFlags.REMOUNT | MountFlags.RDONLY
        mount("sysfs", newsys, "tmpfs", mflags, "mode=0755")
        # Sorted order attaches enclosing mounts (e.g. devices) before mounts
        # nested inside them (e.g. devices/virtual/net).
        for subdir, bindfd in sorted(bind_fds.items()):
            move_mount(bindfd, newsys / subdir)


def unshare_user_idmap(
    uidmap: list[IDMapping],
    gidmap: list[IDMapping],
    flags: CloneFlags = CloneFlags.NEWUSER,
) -> None:
    """Unshare the given namespaces (must include user) and set up the given
    id mappings.

    The mappings are written by a helper process that is forked before
    unsharing and therefore stays in the original user namespace. It is
    passed the pid of this process. If unsharing fails, the helper is killed
    and reaped instead of being left blocked forever, and the original
    exception is re-raised.
    """
    import signal

    pid = os.getpid()

    @run_in_fork
    def setup_idmaps() -> None:
        newidmaps(pid, uidmap, gidmap)

    try:
        unshare(flags)
    except BaseException:
        # The helper is still waiting for its start signal and would
        # otherwise linger indefinitely.
        os.kill(setup_idmaps.pid, signal.SIGKILL)
        os.waitpid(setup_idmaps.pid, 0)
        raise
    setup_idmaps()


def unshare_user_idmap_nohelper(
    uid: int,
    gid: int,
    flags: CloneFlags = CloneFlags.NEWUSER,
    *,
    proc: AtLocationLike | None = None,
) -> None:
    """Unshare the given namespaces, which must include a user namespace, and
    map the calling user and group to uid and gid inside it without using the
    setuid helpers. Exactly one id of each kind is mapped. proc names the
    procfs location that the maps are written through and defaults to /proc.
    """
    # The outer ids must be sampled before unsharing. Inside the new user
    # namespace they are not mapped yet.
    mapping_pair = (
        [IDMapping(uid, os.getuid(), 1)],
        [IDMapping(gid, os.getgid(), 1)],
    )
    unshare(flags)
    procfs = AtLocation("/proc" if proc is None else proc)
    # An unprivileged writer of gid_map must deny setgroups first, see
    # user_namespaces(7).
    (procfs / "self/setgroups").write_bytes(b"deny")
    newidmaps(-1, *mapping_pair, False, proc=procfs)


class _AsyncFilesender:
    """Copy data from from_fd to to_fd with sendfile(2) whenever to_fd is
    writable.

    sendfile is called without an offset, so from_fd is read from its
    current file position. The future fut resolves to the number of bytes
    copied once count bytes were copied or sendfile returned 0 (end of
    input). Cancelling fut stops the copy.
    """

    bs = 65536

    def __init__(
        self,
        from_fd: FileDescriptor,
        to_fd: FileDescriptor,
        count: int | None = None,
    ):
        self.from_fd = from_fd
        self.to_fd = to_fd
        self.copied = 0
        self.remain = count
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[int] = self.loop.create_future()
        # Unregister on any completion, including cancellation by a waiter.
        self.fut.add_done_callback(self._stop)
        self.loop.add_writer(self.to_fd, self.handle_write)

    def _stop(self, _fut: object = None) -> None:
        """Unregister the write callback (idempotent)."""
        self.loop.remove_writer(self.to_fd)

    def handle_write(self) -> None:
        if self.fut.done():
            # Cancelled while this callback was already pending.
            self._stop()
            return
        try:
            ret = os.sendfile(
                self.to_fd,
                self.from_fd,
                None,
                self.bs if self.remain is None else min(self.bs, self.remain),
            )
        except OSError as err:
            if err.errno != errno.EAGAIN:
                self._stop()
                self.fut.set_exception(err)
            return
        self.copied += ret
        if self.remain is not None:
            self.remain -= ret
        if ret == 0 or self.remain == 0:
            self._stop()
            self.fut.set_result(self.copied)


class _AsyncSplicer:
    """Copy data from from_fd to to_fd with splice(2), which requires at least
    one of them to be a pipe.

    A non-blocking splice reports EAGAIN no matter which side is not ready.
    The copier therefore alternates between waiting for from_fd to become
    readable and waiting for to_fd to become writable. The future fut
    resolves to the number of bytes copied once count bytes were copied or
    splice returned 0 (end of input). Cancelling fut stops the copy.
    """

    bs = 65536

    def __init__(
        self,
        from_fd: FileDescriptor,
        to_fd: FileDescriptor,
        count: int | None = None,
    ):
        self.from_fd = from_fd
        self.to_fd = to_fd
        self.copied = 0
        self.remain = count
        self.wait_read = True
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[int] = self.loop.create_future()
        # Unregister on any completion, including cancellation by a waiter.
        self.fut.add_done_callback(self._stop)
        self.loop.add_reader(self.from_fd, self.handle_io)

    def _stop(self, _fut: object = None) -> None:
        """Unregister both event loop callbacks (idempotent)."""
        self.loop.remove_reader(self.from_fd)
        self.loop.remove_writer(self.to_fd)

    def handle_io(self) -> None:
        if self.fut.done():
            # Cancelled while this callback was already pending.
            self._stop()
            return
        try:
            ret = os.splice(
                self.from_fd,
                self.to_fd,
                self.bs if self.remain is None else min(self.bs, self.remain),
                flags=os.SPLICE_F_NONBLOCK,
            )
        except OSError as err:
            if err.errno != errno.EAGAIN:
                self._stop()
                self.fut.set_exception(err)
                return
            # Switch to waiting on the other side.
            self.wait_read = not self.wait_read
            if self.wait_read:
                self.loop.remove_writer(self.to_fd)
                self.loop.add_reader(self.from_fd, self.handle_io)
            else:
                self.loop.remove_reader(self.from_fd)
                self.loop.add_writer(self.to_fd, self.handle_io)
            return
        self.copied += ret
        if self.remain is not None:
            self.remain -= ret
        if ret == 0 or self.remain == 0:
            self._stop()
            self.fut.set_result(self.copied)


class _AsyncCopier:
    """Copy data from from_fd to to_fd with plain read(2) and write(2) calls
    driven by the event loop.

    Reading pauses once about bs bytes are buffered and resumes as the buffer
    drains. The future fut resolves to the number of bytes written once end
    of input is reached or count bytes were copied. A read error is reported
    only after the data read before it was written. Cancelling fut stops the
    copy.
    """

    bs = 65536

    def __init__(
        self,
        from_fd: FileDescriptor,
        to_fd: FileDescriptor,
        count: int | None = None,
    ):
        self.from_fd = from_fd
        self.to_fd = to_fd
        self.buffer = b""
        self.copied = 0  # bytes read and written
        self.remain = count  # remaining bytes not yet read
        # eof can be an exception when a read failed and otherwise indicates
        # whether a read returned 0.
        self.eof: bool | OSError = False
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[int] = self.loop.create_future()
        # Unregister on any completion, including cancellation by a waiter.
        self.fut.add_done_callback(self._stop)
        self.loop.add_reader(self.from_fd, self.handle_readable)

    def _stop(self, _fut: object = None) -> None:
        """Unregister both event loop callbacks (idempotent)."""
        self.loop.remove_reader(self.from_fd)
        self.loop.remove_writer(self.to_fd)

    def handle_readable(self) -> None:
        if self.fut.done():
            # Cancelled while this callback was already pending.
            self._stop()
            return
        try:
            data = os.read(
                self.from_fd,
                self.bs if self.remain is None else min(self.bs, self.remain),
            )
        except OSError as err:
            if err.errno == errno.EAGAIN:
                return
            self.loop.remove_reader(self.from_fd)
            if self.buffer:
                # Flush the pending data first; the writer reports the error.
                self.eof = err
            else:
                self._stop()
                self.fut.set_exception(err)
            return
        if not data:
            self.eof = True
            self.loop.remove_reader(self.from_fd)
            if not self.buffer:
                # No write callback is pending that could finish the copy.
                self.fut.set_result(self.copied)
            return
        if self.remain is not None:
            self.remain -= len(data)
        if not self.buffer:
            self.loop.add_writer(self.to_fd, self.handle_writeable)
        self.buffer += data
        if self.remain == 0 or len(self.buffer) >= self.bs:
            # Done reading or buffer full; the writer resumes reading.
            self.loop.remove_reader(self.from_fd)

    def handle_writeable(self) -> None:
        if self.fut.done():
            # Cancelled while this callback was already pending.
            self._stop()
            return
        try:
            written = os.write(self.to_fd, self.buffer)
        except OSError as err:
            if err.errno == errno.EAGAIN:
                return
            self._stop()
            # An earlier read error takes precedence over the write error.
            self.fut.set_exception(
                self.eof if isinstance(self.eof, OSError) else err
            )
            return
        self.buffer = self.buffer[written:]
        self.copied += written
        more_input = not self.eof and (self.remain is None or self.remain > 0)
        if not self.buffer:
            self.loop.remove_writer(self.to_fd)
            if isinstance(self.eof, OSError):
                self.fut.set_exception(self.eof)
            elif not more_input:
                self.fut.set_result(self.copied)
            else:
                # Reading may have been paused on a full buffer.
                self.loop.add_reader(self.from_fd, self.handle_readable)
        elif more_input and len(self.buffer) < self.bs:
            self.loop.add_reader(self.from_fd, self.handle_readable)


def async_copyfd(
    from_fd: FileDescriptorLike,
    to_fd: FileDescriptorLike,
    count: int | None = None,
) -> asyncio.Future[int]:
    """Asynchronously copy count bytes, or everything up to end of input if
    count is None, from from_fd to to_fd. Data is copied as raw bytes.

    The mechanism is chosen from the file types: sendfile(2) for a regular
    source file, splice(2) when either end is a pipe, and plain read/write
    otherwise. The returned future resolves to the number of bytes copied.
    """
    src = FileDescriptor(from_fd)
    dst = FileDescriptor(to_fd)
    srcmode = os.fstat(src).st_mode
    if not stat.S_ISREG(srcmode):
        if stat.S_ISFIFO(srcmode) or stat.S_ISFIFO(os.fstat(dst).st_mode):
            # splice(2) needs at least one pipe end.
            return _AsyncSplicer(src, dst, count).fut
        return _AsyncCopier(src, dst, count).fut
    return _AsyncFilesender(src, dst, count).fut


class _AsyncPidfdWaiter:
    """Call waitid(P_PIDFD, pidfd, flags) once pidfd becomes readable.

    The future fut receives the result of os.waitid, or the exception it
    raised (EAGAIN keeps waiting). Cancelling fut stops waiting without
    calling waitid afterwards.
    """

    def __init__(self, pidfd: FileDescriptor, flags: int):
        self.pidfd = pidfd
        self.flags = flags
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[
            os.waitid_result | None
        ] = self.loop.create_future()
        # Unregister on any completion, including cancellation by a waiter.
        self.fut.add_done_callback(self._stop)
        self.loop.add_reader(pidfd, self.handle_readable)

    def _stop(self, _fut: object = None) -> None:
        """Unregister the read callback (idempotent)."""
        self.loop.remove_reader(self.pidfd)

    def handle_readable(self) -> None:
        if self.fut.done():
            # Cancelled while this callback was already pending. Do not call
            # waitid, which would typically reap the child for nobody.
            self._stop()
            return
        try:
            result = os.waitid(os.P_PIDFD, self.pidfd, self.flags)
        except OSError as err:
            if err.errno == errno.EAGAIN:
                return
            self._stop()
            self.fut.set_exception(err)
        except Exception as err:
            self._stop()
            self.fut.set_exception(err)
        else:
            self._stop()
            self.fut.set_result(result)


def async_waitpidfd(
    pidfd: FileDescriptorLike, flags: int
) -> asyncio.Future[os.waitid_result | None]:
    """Return a future for the result of waitid(P_PIDFD, pidfd, flags), which
    is called once pidfd becomes readable in the running event loop. This is
    the asynchronous counterpart of os.waitid with P_PIDFD.
    """
    waiter = _AsyncPidfdWaiter(FileDescriptor(pidfd), flags)
    return waiter.fut


def enable_loopback_if() -> None:
    """Bring up the loopback interface "lo", which starts out down in a newly
    created network namespace.
    """
    # We use the old and deprecated ioctl API rather than netlink, because it
    # is way simpler and good enough for our purpose. The interface is always
    # created as "lo" by the kernel and it'll have loopback addresses
    # configured automatically. All that we have to do is "up" it.
    siocsifflags = 0x8914  # SIOCSIFFLAGS from <linux/sockios.h>
    iff_up = 1  # IFF_UP from <net/if.h>
    # Leading part of struct ifreq: 16 byte interface name, then the flags.
    ifreq = struct.pack("@16sH", b"lo", iff_up)
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    with sock:
        fcntl.ioctl(sock, siocsifflags, ifreq)