# Copyright 2024-2025 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Provide functionality related to mapping user and group ids in a user
namespace.
"""

import bisect
import dataclasses
import os
import pwd
import subprocess
import typing

from .atlocation import AtLocation, AtLocationLike


def subidranges(
    kind: typing.Literal["uid", "gid"], login: str | None = None
) -> typing.Iterator[tuple[int, int]]:
    """Parse a `/etc/sub?id` file for ranges allocated to the given or current
    user. Return all ranges as (start, count) pairs.
    """
    if login is None:
        login = os.getlogin()
    with open(f"/etc/sub{kind}") as filelike:
        for line in filelike:
            parts = line.strip().split(":")
            if parts[0] == login:
                yield (int(parts[1]), int(parts[2]))


@dataclasses.dataclass(frozen=True)
class IDMapping:
    """Represent one range in a user or group id mapping."""

    innerstart: int
    outerstart: int
    count: int

    def __post_init__(self) -> None:
        if self.outerstart < 0:
            raise ValueError("outerstart must not be negative")
        if self.innerstart < 0:
            raise ValueError("innerstart must not be negative")
        if self.count <= 0:
            raise ValueError("count must be positive")
        if self.outerstart + self.count >= 1 << 64:
            raise ValueError("outerstart + count exceed 64bits")
        if self.innerstart + self.count >= 1 << 64:
            raise ValueError("innerstart + count exceed 64bits")

    @classmethod
    def identity(cls, idn: int, count: int = 1) -> typing.Self:
        """Construct an identity mapping for the given identifier."""
        return cls(idn, idn, count)


class IDAllocation:
    """This represents a subset of IDs (user or group). It can be used to
    allocate a contiguous range for use with a user namespace.
    """

    def __init__(self) -> None:
        self.ranges: list[tuple[int, int]] = []

    def add_range(self, start: int, count: int) -> None:
        """Add count ids starting from start to this allocation."""
        if start < 0 or count <= 0:
            raise ValueError("invalid range")
        index = bisect.bisect_right(self.ranges, (start, 0))
        prevrange = None
        if index > 0:
            prevrange = self.ranges[index - 1]
            if prevrange[0] + prevrange[1] > start:
                raise ValueError("attempt to add overlapping range")
        nextrange = None
        if index < len(self.ranges):
            nextrange = self.ranges[index]
            if nextrange[0] < start + count:
                raise ValueError("attempt to add overlapping range")
        if prevrange and prevrange[0] + prevrange[1] == start:
            if nextrange and nextrange[0] == start + count:
                self.ranges[index - 1] = (
                    prevrange[0],
                    prevrange[1] + count + nextrange[1],
                )
                del self.ranges[index]
            else:
                self.ranges[index - 1] = (prevrange[0], prevrange[1] + count)
        elif nextrange and nextrange[0] == start + count:
            self.ranges[index] = (start, count + nextrange[1])
        else:
            self.ranges.insert(index, (start, count))

    @classmethod
    def loadsubid(
        cls, kind: typing.Literal["uid", "gid"], login: str | None = None,
    ) -> "IDAllocation":
        """Load a `/etc/sub?id` file and return ids allocated to the given
        login or current user.
        """
        self = cls()
        for start, count in subidranges(kind, login):
            self.add_range(start, count)
        return self

    def find(self, count: int) -> int:
        """Locate count contiguous ids from this allocation. The start of
        the allocation is returned. The allocation object is left unchanged.
        """
        for start, available in self.ranges:
            if available >= count:
                return start
        raise ValueError("could not satisfy allocation request")

    def allocate(self, count: int) -> int:
        """Allocate count contiguous ids from this allocation. The start of
        the allocation is returned and the ids are removed from this
        IDAllocation object.
        """
        for index, (start, available) in enumerate(self.ranges):
            if available > count:
                self.ranges[index] = (start + count, available - count)
                return start
            if available == count:
                del self.ranges[index]
                return start
        raise ValueError("could not satisfy allocation request")

    def allocatemap(self, count: int, target: int = 0) -> IDMapping:
        """Allocate count contiguous ids from this allocation. An IDMapping
        with its innerstart set to target is returned. The allocation is
        removed from this IDAllocation object.
        """
        return IDMapping(target, self.allocate(count), count)

    def reserve(self, start: int, count: int) -> None:
        """Reserve (and remove) the given range from this allocation. If the
        range is not fully contained in this allocation, a ValueError is
        raised.
        """
        if count < 0:
            raise ValueError("negative count")
        index = bisect.bisect_right(self.ranges, (start, float("inf"))) - 1
        if index < 0:
            raise ValueError("range to reserve not found")
        cur_start, cur_count = self.ranges[index]
        assert cur_start <= start
        if cur_start == start:
            # Requested range starts at range boundary
            if cur_count < count:
                raise ValueError("range to reserve not found")
            if cur_count == count:
                # Requested range matches a range exactly
                del self.ranges[index]
            else:
                # Requested range is a head of the matched range
                self.ranges[index] = (start + count, cur_count - count)
        elif cur_start + cur_count >= start + count:
            # Requested range fits into a matched range
            self.ranges[index] = (cur_start, start - cur_start)
            if cur_start + cur_count > start + count:
                # Requested range punches a hole into a matched range
                self.ranges.insert(
                    index + 1,
                    (start + count, cur_start + cur_count - (start + count)),
                )
            # else: Requested range is a tail of a matched range
        else:
            raise ValueError("range to reserve not found")


def newidmap(
    kind: typing.Literal["uid", "gid"],
    pid: int,
    mapping: list[IDMapping],
    helper: bool | None = None,
    *,
    proc: AtLocationLike | None = None,
) -> None:
    """Apply the given uid or gid mapping to the given process. A positive pid
    identifies a process, other values identify the currently running process.
    Whether setuid binaries newuidmap and newgidmap are used is determined via
    the helper argument. A None value indicates automatic detection of whether
    a helper is required for setting up the given mapping.

    Without a helper, the mapping is written to ``<proc>/<pid>/<kind>_map``
    where proc defaults to ``/proc``.

    :raises ValueError: if kind is neither "uid" nor "gid" or mapping is
        empty.
    :raises subprocess.CalledProcessError: if the helper binary fails.
    """
    if kind not in ("uid", "gid"):
        # Explicit check rather than an assert: kind is interpolated into a
        # program name and a path below, and asserts vanish under -O.
        raise ValueError(f"invalid id kind {kind!r}")
    if not mapping:
        # An empty map cannot be applied and would make the automatic helper
        # detection below fail with an IndexError.
        raise ValueError("empty id mapping")
    if pid <= 0:
        pid = os.getpid()
    if helper is None:
        # We cannot reliably test whether we have the right EUID and we don't
        # implement checking whether setgroups has been denied either. Please
        # be explicit about the helper choice in such cases.
        helper = len(mapping) > 1 or mapping[0].count > 1
    if helper:
        argv = [f"new{kind}map", str(pid)]
        for idblock in mapping:
            argv.extend(map(str, dataclasses.astuple(idblock)))
        subprocess.check_call(argv)
    else:
        proc = AtLocation("/proc" if proc is None else proc)
        # NOTE(review): the kernel accepts a map only from a single write(2);
        # this assumes AtLocation.write_text issues exactly one write — confirm.
        (proc / f"{pid}/{kind}_map").write_text(
            "".join(
                "%d %d %d\n" % dataclasses.astuple(idblock)
                for idblock in mapping
            ),
            encoding="ascii",
        )


def newuidmap(
    pid: int,
    mapping: list[IDMapping],
    helper: bool | None = True,
    *,
    proc: AtLocationLike | None = None,
) -> None:
    """Apply a given uid mapping to the given process. Refer to newidmap for
    details.

    Unlike newidmap, helper defaults to True here, so the setuid newuidmap
    binary is used unless helper is False. Passing None requests the
    automatic detection performed by newidmap.
    """
    newidmap("uid", pid, mapping, helper, proc=proc)


def newgidmap(
    pid: int,
    mapping: list[IDMapping],
    helper: bool | None = True,
    *,
    proc: AtLocationLike | None = None,
) -> None:
    """Apply a given gid mapping to the given process. Refer to newidmap for
    details.

    Unlike newidmap, helper defaults to True here, so the setuid newgidmap
    binary is used unless helper is False. Passing None requests the
    automatic detection performed by newidmap.
    """
    newidmap("gid", pid, mapping, helper, proc=proc)


def newidmaps(
    pid: int,
    uidmapping: list[IDMapping],
    gidmapping: list[IDMapping],
    helper: bool | None = True,
    *,
    proc: AtLocationLike | None = None,
) -> None:
    """Apply a given uid and gid mapping to the given process. Refer to
    newidmap for details.

    The gid mapping is applied before the uid mapping. helper defaults to
    True and is used for both mappings; None requests automatic detection,
    which newidmap performs separately for each mapping.
    """
    newgidmap(pid, gidmapping, helper, proc=proc)
    newuidmap(pid, uidmapping, helper, proc=proc)