# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Provide plumbing-layer functionality for working with Linux namespaces in
Python.
"""

import asyncio
import bisect
import contextlib
import dataclasses
import errno
import fcntl
import os
import pathlib
import socket
import stat
import struct
import subprocess
import typing

from .filedescriptor import *
from .atlocation import *
from .syscalls import *


def subidranges(
    kind: typing.Literal["uid", "gid"], login: str | None = None
) -> typing.Iterator[tuple[int, int]]:
    """Parse a `/etc/sub?id` file for ranges allocated to the given or current
    user. Yield all ranges as (start, count) pairs.
    """
    if login is None:
        login = os.getlogin()
    with open(f"/etc/sub{kind}") as filelike:
        for line in filelike:
            parts = line.strip().split(":")
            if parts[0] == login:
                yield (int(parts[1]), int(parts[2]))


@dataclasses.dataclass(frozen=True)
class IDMapping:
    """Represent one range in a user or group id mapping."""

    innerstart: int
    outerstart: int
    count: int

    def __post_init__(self) -> None:
        if self.outerstart < 0:
            raise ValueError("outerstart must not be negative")
        if self.innerstart < 0:
            raise ValueError("innerstart must not be negative")
        if self.count <= 0:
            raise ValueError("count must be positive")
        if self.outerstart + self.count >= 1 << 64:
            raise ValueError("outerstart + count must fit into 64 bits")
        if self.innerstart + self.count >= 1 << 64:
            raise ValueError("innerstart + count must fit into 64 bits")

    @classmethod
    def identity(cls, idn: int, count: int = 1) -> typing.Self:
        """Construct an identity mapping for the given identifier."""
        return cls(idn, idn, count)
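
# Illustrative sketch (not part of the module API): the concrete id values
# are assumptions for illustration only.
def _example_idmapping() -> list[IDMapping]:
    return [
        # Inner ids 0..65535 become outer ids 100000..165535.
        IDMapping(innerstart=0, outerstart=100000, count=65536),
        # Map the current user's uid to itself.
        IDMapping.identity(os.getuid()),
    ]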


class IDAllocation:
    """This represents a subset of IDs (user or group). It can be used to
    allocate a contiguous range for use with a user namespace.
    """

    def __init__(self) -> None:
        self.ranges: list[tuple[int, int]] = []

    def add_range(self, start: int, count: int) -> None:
        """Add count ids starting from start to this allocation."""
        if start < 0 or count <= 0:
            raise ValueError("invalid range")
        index = bisect.bisect_right(self.ranges, (start, 0))
        prevrange = None
        if index > 0:
            prevrange = self.ranges[index - 1]
            if prevrange[0] + prevrange[1] > start:
                raise ValueError("attempt to add overlapping range")
        nextrange = None
        if index < len(self.ranges):
            nextrange = self.ranges[index]
            if nextrange[0] < start + count:
                raise ValueError("attempt to add overlapping range")
        if prevrange and prevrange[0] + prevrange[1] == start:
            if nextrange and nextrange[0] == start + count:
                self.ranges[index - 1] = (
                    prevrange[0],
                    prevrange[1] + count + nextrange[1],
                )
                del self.ranges[index]
            else:
                self.ranges[index - 1] = (prevrange[0], prevrange[1] + count)
        elif nextrange and nextrange[0] == start + count:
            self.ranges[index] = (start, count + nextrange[1])
        else:
            self.ranges.insert(index, (start, count))

    @classmethod
    def loadsubid(
        cls, kind: typing.Literal["uid", "gid"], login: str | None = None,
    ) -> "IDAllocation":
        """Load a `/etc/sub?id` file and return ids allocated to the given
        login or current user.
        """
        self = cls()
        for start, count in subidranges(kind, login):
            self.add_range(start, count)
        return self

    def find(self, count: int) -> int:
        """Locate count contiguous ids from this allocation. The start of
        the allocation is returned. The allocation object is left unchanged.
        """
        for start, available in self.ranges:
            if available >= count:
                return start
        raise ValueError("could not satisfy allocation request")

    def allocate(self, count: int) -> int:
        """Allocate count contiguous ids from this allocation. The start of
        the allocated range is returned and its ids are removed from this
        IDAllocation object.
        """
        for index, (start, available) in enumerate(self.ranges):
            if available > count:
                self.ranges[index] = (start + count, available - count)
                return start
            if available == count:
                del self.ranges[index]
                return start
        raise ValueError("could not satisfy allocation request")

    def allocatemap(self, count: int, target: int = 0) -> IDMapping:
        """Allocate count contiguous ids from this allocation. An IDMapping
        with its innerstart set to target is returned. The allocation is
        removed from this IDAllocation object.
        """
        return IDMapping(target, self.allocate(count), count)

    def reserve(self, start: int, count: int) -> None:
        """Reserve (and remove) the given range from this allocation. If the
        range is not fully contained in this allocation, a ValueError is
        raised.
        """
        if count < 0:
            raise ValueError("negative count")
        index = bisect.bisect_right(self.ranges, (start, float("inf"))) - 1
        if index < 0:
            raise ValueError("range to reserve not found")
        cur_start, cur_count = self.ranges[index]
        assert cur_start <= start
        if cur_start == start:
            # Requested range starts at range boundary
            if cur_count < count:
                raise ValueError("range to reserve not found")
            if cur_count == count:
                # Requested range matches a range exactly
                del self.ranges[index]
            else:
                # Requested range is a head of the matched range
                self.ranges[index] = (start + count, cur_count - count)
        elif cur_start + cur_count >= start + count:
            # Requested range fits into a matched range
            self.ranges[index] = (cur_start, start - cur_start)
            if cur_start + cur_count > start + count:
                # Requested range punches a hole into a matched range
                self.ranges.insert(
                    index + 1,
                    (start + count, cur_start + cur_count - (start + count)),
                )
            # else: Requested range is a tail of a matched range
        else:
            raise ValueError("range to reserve not found")


def newidmap(
    kind: typing.Literal["uid", "gid"],
    pid: int,
    mapping: list[IDMapping],
    helper: bool | None = None,
) -> None:
    """Apply the given uid or gid mapping to the given process. A positive pid
    identifies a process; other values identify the currently running process.
    Whether setuid binaries newuidmap and newgidmap are used is determined via
    the helper argument. A None value indicates automatic detection of
    whether a helper is required for setting up the given mapping.
    """

    assert kind in ("uid", "gid")
    if pid <= 0:
        pid = os.getpid()
    if helper is None:
        # We cannot reliably test whether we have the right EUID and we don't
        # implement checking whether setgroups has been denied either. Please
        # be explicit about the helper choice in such cases.
        helper = len(mapping) > 1 or mapping[0].count > 1
    if helper:
        argv = [f"new{kind}map", str(pid)]
        for idblock in mapping:
            argv.extend(map(str, dataclasses.astuple(idblock)))
        subprocess.check_call(argv)
    else:
        pathlib.Path(f"/proc/{pid}/{kind}_map").write_text(
            "".join(
                "%d %d %d\n" % dataclasses.astuple(idblock)
                for idblock in mapping
            ),
            encoding="ascii",
        )


def newuidmap(pid: int, mapping: list[IDMapping], helper: bool = True) -> None:
    """Apply a given uid mapping to the given process. Refer to newidmap for
    details.
    """
    newidmap("uid", pid, mapping, helper)


def newgidmap(pid: int, mapping: list[IDMapping], helper: bool = True) -> None:
    """Apply a given gid mapping to the given process. Refer to newidmap for
    details.
    """
    newidmap("gid", pid, mapping, helper)


def newidmaps(
    pid: int,
    uidmapping: list[IDMapping],
    gidmapping: list[IDMapping],
    helper: bool = True,
) -> None:
    """Apply a given uid and gid mapping to the given process. Refer to
    newidmap for details.
    """
    newgidmap(pid, gidmapping, helper)
    newuidmap(pid, uidmapping, helper)
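
# Illustrative sketch (not part of the module API): for the mapping
# [IDMapping(0, 100000, 65536)] the non-helper code path writes the line
# "0 100000 65536" to /proc/<pid>/uid_map or /proc/<pid>/gid_map while the
# helper code path runs "newuidmap <pid> 0 100000 65536" (or newgidmap). The
# numbers are assumptions for illustration only.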


class run_in_fork:
    """Decorator for running the decorated function once in a separate process.
    """

    def __init__(self, function: typing.Callable[[], None]):
        """Fork a new process that will eventually run the given function and
        then exit.
        """
        self.efd = EventFD()
        self.pid = os.fork()
        if self.pid == 0:
            self.efd.read()
            self.efd.close()
            function()
            os._exit(0)

    def start(self) -> None:
        """Start the decorated function. It can only be started once."""
        if not self.efd:
            raise ValueError("this function can only be called once")
        self.efd.write(1)
        self.efd.close()

    def wait(self) -> None:
        """Wait for the process running the decorated function to finish."""
        if self.efd:
            raise ValueError("start must be called before wait")
        ret = os.waitpid(self.pid, 0)
        if ret != (self.pid, 0):
            raise ValueError("something failed")

    def __call__(self) -> None:
        """Start the decorated function and wait for its process to finish."""
        self.start()
        self.wait()
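
# Illustrative sketch (not part of the module API): fork a child early and
# run the decorated function in it later.
def _example_run_in_fork() -> None:
    @run_in_fork
    def hello() -> None:
        print("hello from pid", os.getpid())

    # The child process already exists, but hello only runs once started.
    hello.start()
    hello.wait()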


class async_run_in_fork:
    """Decorator for running the decorated function once in a separate process.
    Note that the decorator can only be used inside asynchronous code as it
    uses the running event loop. The decorated function itself must be
    synchronous and must not access the event loop of the main process.
    """

    def __init__(self, function: typing.Callable[[], None]):
        """Fork a new process that will eventually run the given function and
        then exit.
        """
        loop = asyncio.get_running_loop()
        with asyncio.get_child_watcher() as watcher:
            if not watcher.is_active():
                raise RuntimeError(
                    "active child watcher required for creating a process"
                )
            self.future = loop.create_future()
            self.efd = EventFD()
            self.pid = os.fork()
            if self.pid == 0:
                self.efd.read()
                self.efd.close()
                function()
                os._exit(0)
            watcher.add_child_handler(self.pid, self._child_callback)

    def _child_callback(self, pid: int, returncode: int) -> None:
        if self.pid != pid:
            return
        self.future.set_result(returncode)

    def start(self) -> None:
        """Start the decorated function. It can only be started once."""
        if not self.efd:
            raise ValueError("this function can only be called once")
        self.efd.write(1)
        self.efd.close()

    async def wait(self) -> None:
        """Wait for the process running the decorated function to finish."""
        if self.efd:
            raise ValueError("start must be called before wait")
        ret = await self.future
        if ret != 0:
            raise ValueError("something failed")

    async def __call__(self) -> None:
        """Start the decorated function and wait for its process to finish."""
        self.start()
        await self.wait()


def bind_mount(
    source: AtLocationLike,
    target: AtLocationLike,
    recursive: bool = False,
    readonly: bool = False,
) -> None:
    """Create a bind mount from source to target. Depending on whether one of
    the locations involves a file descriptor or not, the new or old mount API
    will be used.
    """
    source = AtLocation(source)
    target = AtLocation(target)
    try:
        if readonly:
            # We would have to remount to apply the readonly flag, see
            # https://git.kernel.org/pub/scm/utils/util-linux/util-linux.git/commit/?id=9ac77b8a78452eab0612523d27fee52159f5016a
            raise ValueError()
        srcloc = os.fspath(source)
        tgtloc = os.fspath(target)
    except ValueError:
        otflags = OpenTreeFlags.OPEN_TREE_CLONE
        if recursive:
            otflags |= OpenTreeFlags.AT_RECURSIVE
        with open_tree(source, otflags) as srcfd:
            if readonly:
                mount_setattr(srcfd, recursive, MountAttrFlags.RDONLY)
            move_mount(srcfd, target)
    else:
        mflags = MountFlags.BIND
        if recursive:
            mflags |= MountFlags.REC
        mount(srcloc, tgtloc, None, mflags)
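
# Illustrative sketch (not part of the module API): the paths are assumptions
# for illustration only and the caller needs CAP_SYS_ADMIN in its namespaces.
def _example_bind_mount() -> None:
    # Recursively bind mount an existing tree read-only onto a target
    # directory. Since readonly is requested, the new mount API is used.
    bind_mount(
        "/srv/data", "/mnt/container/data", recursive=True, readonly=True
    )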


def get_cgroup(pid: int = -1) -> pathlib.PurePath:
    """Look up the cgroup that the given pid or the current process belongs
    to.
    """
    return pathlib.PurePath(
        pathlib.Path(
            f"/proc/{pid}/cgroup" if pid > 0 else "/proc/self/cgroup"
        ).read_text().split(":", 2)[2].strip()
    )


_P = typing.ParamSpec("_P")


class _ExceptionExitCallback:
    """Helper class that invokes a callback when a context manager exits with
    a failure.
    """

    def __init__(
        self,
        callback: typing.Callable[_P, typing.Any],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> None:
        self.callback = callback
        self.args = args
        self.kwargs = kwargs

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: typing.Any,
    ) -> None:
        if exc_type is not None:
            self.callback(*self.args, **self.kwargs)


def populate_dev(
    origroot: AtLocationLike,
    newroot: PathConvertible,
    *,
    fuse: bool = True,
    pidns: bool = True,
    tun: bool = True,
) -> None:
    """Mount a tmpfs to the dev directory beneath newroot and populate it with
    basic devices by bind mounting them from the dev directory beneath
    origroot. Also mount a new pts instance.

    Even though a CAP_SYS_ADMIN-enabled process can umount components of the
    /dev hierarchy, it cannot gain privileges in doing so as no hierarchies
    are restricted via tmpfs mounts or read-only bind mounts.
    """
    origdev = AtLocation(origroot) / "dev"
    newdev = AtLocation(newroot) / "dev"
    directories = {"pts", "shm"}
    files = set()
    symlinks = {
        "fd": "/proc/self/fd",
        "stdin": "/proc/self/fd/0",
        "stdout": "/proc/self/fd/1",
        "stderr": "/proc/self/fd/2",
    }
    bind_mounts: dict[str, AtLocation] = {}
    with contextlib.ExitStack() as exitstack:
        for fn in "null zero full random urandom tty".split():
            files.add(fn)
            bind_mounts[fn] = exitstack.enter_context(
                open_tree(origdev / fn, OpenTreeFlags.OPEN_TREE_CLONE)
            )
        if fuse:
            files.add("fuse")
            bind_mounts["fuse"] = exitstack.enter_context(
                open_tree(origdev / "fuse", OpenTreeFlags.OPEN_TREE_CLONE)
            )
        if pidns:
            symlinks["ptmx"] = "pts/ptmx"
        else:
            bind_mounts["pts"] = exitstack.enter_context(
                open_tree(
                    origdev / "pts",
                    OpenTreeFlags.AT_RECURSIVE | OpenTreeFlags.OPEN_TREE_CLONE,
                )
            )
            files.add("ptmx")
            bind_mounts["ptmx"] = exitstack.enter_context(
                open_tree(origdev / "ptmx", OpenTreeFlags.OPEN_TREE_CLONE)
            )
        if tun:
            directories.add("net")
            files.add("net/tun")
            bind_mounts["net/tun"] = exitstack.enter_context(
                open_tree(origdev / "net/tun", OpenTreeFlags.OPEN_TREE_CLONE)
            )
        mount(
            "devtmpfs",
            newdev,
            "tmpfs",
            MountFlags.NOSUID | MountFlags.NOEXEC,
            "mode=0755",
        )
        exitstack.enter_context(
            _ExceptionExitCallback(umount, newdev, UmountFlags.DETACH)
        )
        for fn in directories:
            (newdev / fn).mkdir()
            (newdev / fn).chmod(0o755)
        mount(
            "tmpfs",
            newdev / "shm",
            "tmpfs",
            MountFlags.NOSUID | MountFlags.NODEV,
        )
        for fn in files:
            (newdev / fn).mknod(stat.S_IFREG)
        for fn, target in symlinks.items():
            (newdev / fn).symlink_to(target)
        if pidns:
            mount(
                "devpts",
                newdev / "pts",
                "devpts",
                MountFlags.NOSUID | MountFlags.NOEXEC,
                "gid=5,mode=620,ptmxmode=666",
            )
        for fn, fd in bind_mounts.items():
            move_mount(fd, newdev / fn)
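
# Illustrative sketch (not part of the module API): populate /dev below a
# prepared root directory whose dev directory already exists. The path and
# option values are assumptions for illustration only.
def _example_populate_dev() -> None:
    # Without a pid namespace the caller's devpts instance is bind mounted
    # rather than mounting a fresh one.
    populate_dev("/", "/mnt/container", fuse=False, pidns=False, tun=False)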


def populate_proc(
    origroot: AtLocationLike,
    newroot: PathConvertible,
    namespaces: CloneFlags,
) -> None:
    """Mount a /proc hierarchy.

    Note that a user with CAP_SYS_ADMIN can change read-only bind mounts.
    Still, those bind mounts provide guidance to the container as to which
    aspects it should manage.
    """
    assert namespaces & CloneFlags.NEWNS == CloneFlags.NEWNS
    newproc = AtLocation(newroot) / "proc"
    rwns = CloneFlags.NEWPID
    if namespaces & rwns == rwns:
        mount(
            "proc",
            newproc,
            "proc",
            MountFlags.NOSUID | MountFlags.NODEV | MountFlags.NOEXEC,
        )
    else:
        bind_mount(AtLocation(origroot) / "proc", newproc, True)
    with _ExceptionExitCallback(umount, newproc, UmountFlags.DETACH):
        rwns |= CloneFlags.NEWUSER | CloneFlags.NEWIPC | CloneFlags.NEWUTS
        if namespaces & rwns != rwns:
            psn: AtLocation | None = None
            if namespaces & CloneFlags.NEWNET == CloneFlags.NEWNET:
                psn = open_tree(
                    newproc / "sys/net",
                    OpenTreeFlags.OPEN_TREE_CLONE | OpenTreeFlags.AT_RECURSIVE,
                )
            bind_mount(newproc / "sys", newproc / "sys", True, True)
            if psn is not None:
                move_mount(psn, newproc / "sys/net")
        elif namespaces & CloneFlags.NEWNET != CloneFlags.NEWNET:
            bind_mount(newproc / "sys/net", newproc / "sys/net", True, True)


def populate_sys(
    origroot: AtLocationLike,
    newroot: PathConvertible,
    namespaces: CloneFlags,
    rootcgroup: PathConvertible | None = None,
    module: bool = True,
    devices: bool = False,
) -> None:
    """Create a /sys hierarchy below newroot. Bind the cgroup hierarchy. The
    cgroup hierarchy will be mounted read-only if mounting the root group.

    The module parameter indicates whether /sys/module should be made
    available read-only (True) or not at all (False). The devices parameter
    indicates whether the devices hierarchy should be made available. If the
    given namespaces happen to include a network namespace, virtual network
    devices will be modifiable.

    A process with CAP_SYS_ADMIN can remount the created bind mounts
    read-write or unmount mounts that hide parts of the hierarchy and thus
    elevate its privileges.
    """
    newsys = AtLocation(newroot) / "sys"
    mflags = MountFlags.NOSUID | MountFlags.NOEXEC | MountFlags.NODEV
    bind_mounts: dict[
        str, tuple[typing.Union[str, pathlib.PurePath], bool]
    ] = {}
    if rootcgroup is None:
        bind_mounts["fs/cgroup"] = ("fs/cgroup", True)
    else:
        bind_mounts["fs/cgroup"] = (
            "fs/cgroup" / pathlib.PurePath(rootcgroup).relative_to("/"), False
        )
    if module:
        bind_mounts["module"] = ("module", True)
    if devices:
        for subdir in ("bus", "class", "dev", "devices"):
            bind_mounts[subdir] = (subdir, True)
    if namespaces & CloneFlags.NEWNET:
        if not devices:
            bind_mounts["class/net"] = ("class/net", True)
        bind_mounts["devices/virtual/net"] = ("devices/virtual/net", False)

    bind_fds: dict[str, AtLocation] = {}
    with contextlib.ExitStack() as exitstack:
        for target, (source, rdonly) in bind_mounts.items():
            bindfd = exitstack.enter_context(
                open_tree(
                    AtLocation(origroot) / "sys" / source,
                    OpenTreeFlags.OPEN_TREE_CLONE | OpenTreeFlags.AT_RECURSIVE,
                ),
            )
            if rdonly:
                mount_setattr(bindfd, True, MountAttrFlags.RDONLY)
            bind_fds[target] = bindfd

        mount("sysfs", newsys, "tmpfs", mflags, "mode=0755")

        exitstack.enter_context(
            _ExceptionExitCallback(umount, newsys, UmountFlags.DETACH)
        )
        dirs = set()
        for subdir in bind_fds:
            dirs.add(subdir)
            while "/" in subdir:
                subdir = subdir.rsplit("/", 1)[0]
                dirs.add(subdir)
        for subdir in sorted(dirs):
            (newsys / subdir).mkdir()
            (newsys / subdir).chmod(0o755)
        mflags |= MountFlags.REMOUNT | MountFlags.RDONLY
        mount("sysfs", newsys, "tmpfs", mflags, "mode=0755")
        for subdir, bindfd in sorted(bind_fds.items()):
            move_mount(bindfd, newsys / subdir)
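
# Illustrative sketch (not part of the module API): populate /proc and /sys
# below a prepared root directory for a set of namespaces. The namespace
# selection is an assumption for illustration only.
def _example_populate_proc_sys(newroot: PathConvertible) -> None:
    namespaces = (
        CloneFlags.NEWUSER
        | CloneFlags.NEWNS
        | CloneFlags.NEWPID
        | CloneFlags.NEWNET
    )
    populate_proc("/", newroot, namespaces)
    populate_sys("/", newroot, namespaces)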


def unshare_user_idmap(
    uidmap: list[IDMapping],
    gidmap: list[IDMapping],
    flags: CloneFlags = CloneFlags.NEWUSER,
) -> None:
    """Unshare the given namespaces (must include user) and set up the given
    id mappings.
    """
    pid = os.getpid()

    @run_in_fork
    def setup_idmaps() -> None:
        newidmaps(pid, uidmap, gidmap)

    unshare(flags)
    setup_idmaps()


def unshare_user_idmap_nohelper(
    uid: int, gid: int, flags: CloneFlags = CloneFlags.NEWUSER
) -> None:
    """Unshare the given namespaces (must include user) and
    map the current user and group to the given uid and gid
    without using the setuid helpers.
    """
    uidmap = IDMapping(uid, os.getuid(), 1)
    gidmap = IDMapping(gid, os.getgid(), 1)
    unshare(flags)
    pathlib.Path("/proc/self/setgroups").write_bytes(b"deny")
    newidmaps(-1, [uidmap], [gidmap], False)
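
# Illustrative sketch (not part of the module API): unshare a user and mount
# namespace, mapping the current user to root and a range of subordinate ids
# above it. The range size of 65535 is an assumption.
def _example_unshare_user_idmap() -> None:
    uidalloc = IDAllocation.loadsubid("uid")
    gidalloc = IDAllocation.loadsubid("gid")
    uidmap = [IDMapping(0, os.getuid(), 1), uidalloc.allocatemap(65535, 1)]
    gidmap = [IDMapping(0, os.getgid(), 1), gidalloc.allocatemap(65535, 1)]
    unshare_user_idmap(uidmap, gidmap, CloneFlags.NEWUSER | CloneFlags.NEWNS)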


class _AsyncFilesender:
    bs = 65536

    def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
        self.from_fd = from_fd
        self.to_fd = to_fd
        self.copied = 0
        self.remain = count
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[int] = self.loop.create_future()
        self.loop.add_writer(self.to_fd, self.handle_write)

    def handle_write(self) -> None:
        try:
            ret = os.sendfile(
                self.to_fd,
                self.from_fd,
                None,
                self.bs if self.remain is None else min(self.bs, self.remain),
            )
        except OSError as err:
            if err.errno != errno.EAGAIN:
                self.loop.remove_writer(self.to_fd)
                self.fut.set_exception(err)
        else:
            self.copied += ret
            if self.remain is not None:
                self.remain -= ret
            if ret == 0 or self.remain == 0:
                self.loop.remove_writer(self.to_fd)
                self.fut.set_result(self.copied)


class _AsyncSplicer:
    bs = 65536

    def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
        self.from_fd = from_fd
        self.to_fd = to_fd
        self.copied = 0
        self.remain = count
        self.wait_read = True
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[int] = self.loop.create_future()
        self.loop.add_reader(self.from_fd, self.handle_io)

    def handle_io(self) -> None:
        try:
            ret = os.splice(
                self.from_fd,
                self.to_fd,
                self.bs if self.remain is None else min(self.bs, self.remain),
                flags=os.SPLICE_F_NONBLOCK,
            )
        except OSError as err:
            if err.errno == errno.EAGAIN:
                self.wait_read = not self.wait_read
                if self.wait_read:
                    self.loop.remove_writer(self.to_fd)
                    self.loop.add_reader(self.from_fd, self.handle_io)
                else:
                    self.loop.remove_reader(self.from_fd)
                    self.loop.add_writer(self.to_fd, self.handle_io)
            else:
                self.loop.remove_reader(self.from_fd)
                self.loop.remove_writer(self.to_fd)
                self.fut.set_exception(err)
        else:
            self.copied += ret
            if self.remain is not None:
                self.remain -= ret
            if ret == 0 or self.remain == 0:
                self.loop.remove_reader(self.from_fd)
                self.loop.remove_writer(self.to_fd)
                self.fut.set_result(self.copied)


class _AsyncCopier:
    bs = 65536

    def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
        self.from_fd = from_fd
        self.to_fd = to_fd
        self.buffer = b""
        self.copied = 0  # bytes read and written
        self.remain = count  # remaining bytes not yet read
        # eof can be an exception when a read failed and otherwise indicates
        # whether a read returned 0.
        self.eof: bool | OSError = False
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[int] = self.loop.create_future()
        self.loop.add_reader(self.from_fd, self.handle_readable)

    def handle_readable(self) -> None:
        try:
            data = os.read(
                self.from_fd,
                self.bs if self.remain is None else min(self.bs, self.remain),
            )
        except OSError as err:
            if err.errno != errno.EAGAIN:
                self.loop.remove_reader(self.from_fd)
                if self.buffer:
                    self.eof = err
                else:
                    self.fut.set_exception(err)
        else:
            if data:
                if self.remain is not None:
                    self.remain -= len(data)
                self.buffer += data
                if len(self.buffer) == len(data):
                    self.loop.add_writer(self.to_fd, self.handle_writeable)
                if self.remain == 0 or len(self.buffer) >= self.bs:
                    self.loop.remove_reader(self.from_fd)
            else:
                self.eof = True
                self.loop.remove_reader(self.from_fd)

    def handle_writeable(self) -> None:
        try:
            written = os.write(self.to_fd, self.buffer)
        except OSError as err:
            if err.errno != errno.EAGAIN:
                self.loop.remove_writer(self.to_fd)
                if isinstance(self.eof, OSError):
                    self.fut.set_exception(self.eof)
                else:
                    self.loop.remove_reader(self.from_fd)
                    self.fut.set_exception(err)
        else:
            self.buffer = self.buffer[written:]
            self.copied += written
            if not self.buffer:
                self.loop.remove_writer(self.to_fd)
                if self.eof is True or self.remain == 0:
                    self.fut.set_result(self.copied)
                elif isinstance(self.eof, OSError):
                    self.fut.set_exception(self.eof)
                else:
                    # The buffer drained completely even though more data may
                    # follow. Re-add the reader since it may have been removed
                    # while the buffer was full.
                    self.loop.add_reader(self.from_fd, self.handle_readable)
            elif not self.eof and self.remain and len(self.buffer) < self.bs:
                self.loop.add_reader(self.from_fd, self.handle_readable)


def async_copyfd(
    from_fd: int, to_fd: int, count: int | None = None
) -> asyncio.Future[int]:
    """Copy the given number of bytes from the first file descriptor to the
    second file descriptor in an asyncio context. Both copies are performed
    binary. An efficient implementation is chosen depending on the file type
    of file descriptors.
    """
    from_mode = os.fstat(from_fd).st_mode
    if stat.S_ISREG(from_mode):
        return _AsyncFilesender(from_fd, to_fd, count).fut
    if stat.S_ISFIFO(from_mode) or stat.S_ISFIFO(os.fstat(to_fd).st_mode):
        return _AsyncSplicer(from_fd, to_fd, count).fut
    return _AsyncCopier(from_fd, to_fd, count).fut
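
# Illustrative sketch (not part of the module API): copy everything from one
# file descriptor to another inside a running event loop. Non-blocking file
# descriptors are assumed here.
async def _example_async_copyfd(from_fd: int, to_fd: int) -> int:
    os.set_blocking(from_fd, False)
    os.set_blocking(to_fd, False)
    return await async_copyfd(from_fd, to_fd)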


class _AsyncPidfdWaiter:
    def __init__(self, pidfd: int, flags: int):
        self.pidfd = pidfd
        self.flags = flags
        self.loop = asyncio.get_running_loop()
        self.fut: asyncio.Future[
            os.waitid_result | None
        ] = self.loop.create_future()
        self.loop.add_reader(pidfd, self.handle_readable)

    def handle_readable(self) -> None:
        try:
            result = os.waitid(os.P_PIDFD, self.pidfd, self.flags)
        except OSError as err:
            if err.errno != errno.EAGAIN:
                self.loop.remove_reader(self.pidfd)
                self.fut.set_exception(err)
        except Exception as err:
            self.loop.remove_reader(self.pidfd)
            self.fut.set_exception(err)
        else:
            self.loop.remove_reader(self.pidfd)
            self.fut.set_result(result)


def async_waitpidfd(
    pidfd: int, flags: int
) -> asyncio.Future[os.waitid_result | None]:
    """Asynchronously wait for a process represented as a pidfd. This is an
    async variant of waitid(P_PIDFD, pidfd, flags).
    """
    return _AsyncPidfdWaiter(pidfd, flags).fut
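
# Illustrative sketch (not part of the module API): wait for a child process
# given by its pid without blocking the event loop.
async def _example_async_waitpidfd(pid: int) -> os.waitid_result | None:
    pidfd = os.pidfd_open(pid)
    try:
        return await async_waitpidfd(pidfd, os.WEXITED)
    finally:
        os.close(pidfd)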


def enable_loopback_if() -> None:
    """Enable the loopback network interface that is initially down in a new
    network namespace.
    """
    # We use the old and deprecated ioctl API rather than netlink, because it
    # is way simpler and good enough for our purpose. The interface is always
    # created as "lo" by the kernel and it'll have loopback addresses
    # configured automatically. All that we have to do is "up" it.
    SIOCSIFFLAGS = 0x8914
    IFF_UP = 1
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) as sock:
        fcntl.ioctl(sock, SIOCSIFFLAGS, struct.pack("@16sH", b"lo", IFF_UP))
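
# Illustrative sketch (not part of the module API): an unprivileged process
# can create a new network namespace owned by a new user namespace and then
# bring up the loopback interface inside it.
def _example_enable_loopback_if() -> None:
    unshare_user_idmap_nohelper(0, 0, CloneFlags.NEWUSER | CloneFlags.NEWNET)
    enable_loopback_if()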