summaryrefslogtreecommitdiff
path: root/linuxnamespaces
diff options
context:
space:
mode:
Diffstat (limited to 'linuxnamespaces')
-rw-r--r--linuxnamespaces/__init__.py400
-rw-r--r--linuxnamespaces/atlocation.py49
-rw-r--r--linuxnamespaces/filedescriptor.py35
-rw-r--r--linuxnamespaces/idmap.py250
-rw-r--r--linuxnamespaces/syscalls.py358
-rw-r--r--linuxnamespaces/systemd/__init__.py53
-rw-r--r--linuxnamespaces/systemd/dbussy.py12
-rw-r--r--linuxnamespaces/tarutils.py28
8 files changed, 838 insertions, 347 deletions
diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py
index a2c7985..83358b6 100644
--- a/linuxnamespaces/__init__.py
+++ b/linuxnamespaces/__init__.py
@@ -6,258 +6,61 @@ Python.
"""
import asyncio
-import bisect
import contextlib
-import dataclasses
import errno
import fcntl
+import logging
import os
import pathlib
import socket
import stat
import struct
-import subprocess
import typing
from .filedescriptor import *
+from .idmap import *
from .atlocation import *
from .syscalls import *
-def subidranges(
- kind: typing.Literal["uid", "gid"], login: str | None = None
-) -> typing.Iterator[tuple[int, int]]:
- """Parse a `/etc/sub?id` file for ranges allocated to the given or current
- user. Return all ranges as (start, count) pairs.
- """
- if login is None:
- login = os.getlogin()
- with open(f"/etc/sub{kind}") as filelike:
- for line in filelike:
- parts = line.strip().split(":")
- if parts[0] == login:
- yield (int(parts[1]), int(parts[2]))
-
-
-@dataclasses.dataclass(frozen=True)
-class IDMapping:
- """Represent one range in a user or group id mapping."""
-
- innerstart: int
- outerstart: int
- count: int
-
- def __post_init__(self) -> None:
- if self.outerstart < 0:
- raise ValueError("outerstart must not be negative")
- if self.innerstart < 0:
- raise ValueError("innerstart must not be negative")
- if self.count <= 0:
- raise ValueError("count must be positive")
- if self.outerstart + self.count >= 1 << 64:
- raise ValueError("outerstart + count exceed 64bits")
- if self.innerstart + self.count >= 1 << 64:
- raise ValueError("innerstart + count exceed 64bits")
-
- @classmethod
- def identity(cls, idn: int, count: int = 1) -> typing.Self:
- """Construct an identity mapping for the given identifier."""
- return cls(idn, idn, count)
-
-
-class IDAllocation:
- """This represents a subset of IDs (user or group). It can be used to
- allocate a contiguous range for use with a user namespace.
- """
-
- def __init__(self) -> None:
- self.ranges: list[tuple[int, int]] = []
-
- def add_range(self, start: int, count: int) -> None:
- """Add count ids starting from start to this allocation."""
- if start < 0 or count <= 0:
- raise ValueError("invalid range")
- index = bisect.bisect_right(self.ranges, (start, 0))
- prevrange = None
- if index > 0:
- prevrange = self.ranges[index - 1]
- if prevrange[0] + prevrange[1] > start:
- raise ValueError("attempt to add overlapping range")
- nextrange = None
- if index < len(self.ranges):
- nextrange = self.ranges[index]
- if nextrange[0] < start + count:
- raise ValueError("attempt to add overlapping range")
- if prevrange and prevrange[0] + prevrange[1] == start:
- if nextrange and nextrange[0] == start + count:
- self.ranges[index - 1] = (
- prevrange[0],
- prevrange[1] + count + nextrange[1],
- )
- del self.ranges[index]
- else:
- self.ranges[index - 1] = (prevrange[0], prevrange[1] + count)
- elif nextrange and nextrange[0] == start + count:
- self.ranges[index] = (start, count + nextrange[1])
- else:
- self.ranges.insert(index, (start, count))
-
- @classmethod
- def loadsubid(
- cls, kind: typing.Literal["uid", "gid"], login: str | None = None,
- ) -> "IDAllocation":
- """Load a `/etc/sub?id` file and return ids allocated to the given
- login or current user.
- """
- self = cls()
- for start, count in subidranges(kind, login):
- self.add_range(start, count)
- return self
-
- def find(self, count: int) -> int:
- """Locate count contiguous ids from this allocation. The start of
- the allocation is returned. The allocation object is left unchanged.
- """
- for start, available in self.ranges:
- if available >= count:
- return start
- raise ValueError("could not satisfy allocation request")
-
- def allocate(self, count: int) -> int:
- """Allocate count contiguous ids from this allocation. The start of
- the allocation is returned and the ids are removed from this
- IDAllocation object.
- """
- for index, (start, available) in enumerate(self.ranges):
- if available > count:
- self.ranges[index] = (start + count, available - count)
- return start
- if available == count:
- del self.ranges[index]
- return start
- raise ValueError("could not satisfy allocation request")
-
- def allocatemap(self, count: int, target: int = 0) -> IDMapping:
- """Allocate count contiguous ids from this allocation. An IDMapping
- with its innerstart set to target is returned. The allocation is
- removed from this IDAllocation object.
- """
- return IDMapping(target, self.allocate(count), count)
-
- def reserve(self, start: int, count: int) -> None:
- """Reserve (and remove) the given range from this allocation. If the
- range is not fully contained in this allocation, a ValueError is
- raised.
- """
- if count < 0:
- raise ValueError("negative count")
- index = bisect.bisect_right(self.ranges, (start, float("inf"))) - 1
- if index < 0:
- raise ValueError("range to reserve not found")
- cur_start, cur_count = self.ranges[index]
- assert cur_start <= start
- if cur_start == start:
- # Requested range starts at range boundary
- if cur_count < count:
- raise ValueError("range to reserve not found")
- if cur_count == count:
- # Requested range matches a range exactly
- del self.ranges[index]
- else:
- # Requested range is a head of the matched range
- self.ranges[index] = (start + count, cur_count - count)
- elif cur_start + cur_count >= start + count:
- # Requested range fits into a matched range
- self.ranges[index] = (cur_start, start - cur_start)
- if cur_start + cur_count > start + count:
- # Requested range punches a hole into a matched range
- self.ranges.insert(
- index + 1,
- (start + count, cur_start + cur_count - (start + count)),
- )
- # else: Requested range is a tail of a matched range
- else:
- raise ValueError("range to reserve not found")
-
-
-def newidmap(
- kind: typing.Literal["uid", "gid"],
- pid: int,
- mapping: list[IDMapping],
- helper: bool | None = None,
-) -> None:
- """Apply the given uid or gid mapping to the given process. A positive pid
- identifies a process, other values identify the currently running process.
- Whether setuid binaries newuidmap and newgidmap are used is determined via
- the helper argument. A None value indicate automatic detection of whether
- a helper is required for setting up the given mapping.
- """
-
- assert kind in ("uid", "gid")
- if pid <= 0:
- pid = os.getpid()
- if helper is None:
- # We cannot reliably test whether we have the right EUID and we don't
- # implement checking whether setgroups has been denied either. Please
- # be explicit about the helper choice in such cases.
- helper = len(mapping) > 1 or mapping[0].count > 1
- if helper:
- argv = [f"new{kind}map", str(pid)]
- for idblock in mapping:
- argv.extend(map(str, dataclasses.astuple(idblock)))
- subprocess.check_call(argv)
- else:
- pathlib.Path(f"/proc/{pid}/{kind}_map").write_text(
- "".join(
- "%d %d %d\n" % dataclasses.astuple(idblock)
- for idblock in mapping
- ),
- encoding="ascii",
- )
-
-
-def newuidmap(pid: int, mapping: list[IDMapping], helper: bool = True) -> None:
- """Apply a given uid mapping to the given process. Refer to newidmap for
- details.
- """
- newidmap("uid", pid, mapping, helper)
-
-
-def newgidmap(pid: int, mapping: list[IDMapping], helper: bool = True) -> None:
- """Apply a given gid mapping to the given process. Refer to newidmap for
- details.
- """
- newidmap("gid", pid, mapping, helper)
-
-
-def newidmaps(
- pid: int,
- uidmapping: list[IDMapping],
- gidmapping: list[IDMapping],
- helper: bool = True,
-) -> None:
- """Apply a given uid and gid mapping to the given process. Refer to
- newidmap for details.
- """
- newgidmap(pid, gidmapping, helper)
- newuidmap(pid, uidmapping, helper)
+_logger = logging.getLogger(__name__)
class run_in_fork:
"""Decorator for running the decorated function once in a separate process.
"""
- def __init__(self, function: typing.Callable[[], None]):
- """Fork a new process that will eventually run the given function and
- then exit.
+ def __init__(
+ self, function: typing.Callable[[], None], start: bool = False
+ ):
+ """Fork a new process that will run the given function and then exit.
+ If start is true, run it immediately, otherwise the start or __call__
+ method should be used.
"""
- self.efd = EventFD()
+ self.efd = None if start else EventFD()
self.pid = os.fork()
if self.pid == 0:
- self.efd.read()
- self.efd.close()
- function()
- os._exit(0)
+ code = 0
+ try:
+ if self.efd is not None:
+ self.efd.read()
+ self.efd.close()
+ self.efd = None
+ function()
+ except SystemExit as err:
+ code = err.code
+ except:
+ _logger.exception(
+ "uncaught exception in run_in_fork %r", function
+ )
+ code = 1
+ os._exit(code)
+
+ @classmethod
+ def now(cls, function: typing.Callable[[], None]) -> typing.Self:
+ """Fork a new process that will immediately run the given function and
+ then exit."""
+ return cls(function, start=True)
def start(self) -> None:
"""Start the decorated function. It can only be started once."""
@@ -265,6 +68,7 @@ class run_in_fork:
raise ValueError("this function can only be called once")
self.efd.write(1)
self.efd.close()
+ self.efd = None
def wait(self) -> None:
"""Wait for the process running the decorated function to finish."""
@@ -275,8 +79,11 @@ class run_in_fork:
raise ValueError("something failed")
def __call__(self) -> None:
- """Start the decorated function and wait for its process to finish."""
- self.start()
+ """Start the decorated function if needed and wait for its process to
+ finish.
+ """
+ if self.efd:
+ self.start()
self.wait()
@@ -287,9 +94,12 @@ class async_run_in_fork:
synchronous and it must not access the event loop of the main process.
"""
- def __init__(self, function: typing.Callable[[], None]):
- """Fork a new process that will eventually run the given function and
- then exit.
+ def __init__(
+ self, function: typing.Callable[[], None], start: bool = False
+ ):
+ """Fork a new process that will run the given function and then exit.
+ If start is true, run it immediately, otherwise the start or __call__
+ method should be used.
"""
loop = asyncio.get_running_loop()
with asyncio.get_child_watcher() as watcher:
@@ -298,15 +108,33 @@ class async_run_in_fork:
"active child watcher required for creating a process"
)
self.future = loop.create_future()
- self.efd = EventFD()
+ self.efd = None if start else EventFD()
self.pid = os.fork()
if self.pid == 0:
- self.efd.read()
- self.efd.close()
- function()
- os._exit(0)
+ code = 0
+ try:
+ if self.efd:
+ self.efd.read()
+ self.efd.close()
+ self.efd = None
+ asyncio.set_event_loop(None)
+ function()
+ except SystemExit as err:
+ code = err.code
+ except:
+ _logger.exception(
+ "uncaught exception in run_in_fork %r", function
+ )
+ code = 1
+ os._exit(code)
watcher.add_child_handler(self.pid, self._child_callback)
+ @classmethod
+ def now(cls, function: typing.Callable[[], None]) -> typing.Self:
+ """Fork a new process that will immediately run the given function and
+ then exit."""
+ return cls(function, start=True)
+
def _child_callback(self, pid: int, returncode: int) -> None:
if self.pid != pid:
return
@@ -318,6 +146,7 @@ class async_run_in_fork:
raise ValueError("this function can only be called once")
self.efd.write(1)
self.efd.close()
+ self.efd = None
async def wait(self) -> None:
"""Wait for the process running the decorated function to finish."""
@@ -328,8 +157,11 @@ class async_run_in_fork:
raise ValueError("something failed")
async def __call__(self) -> None:
- """Start the decorated function and wait for its process to finish."""
- self.start()
+ """Start the decorated function if needed and wait for its process to
+ finish.
+ """
+ if self.efd:
+ self.start()
await self.wait()
@@ -353,7 +185,7 @@ def bind_mount(
srcloc = os.fspath(source)
tgtloc = os.fspath(target)
except ValueError:
- otflags = OpenTreeFlags.OPEN_TREE_CLONE
+ otflags = OpenTreeFlags.CLONE
if recursive:
otflags |= OpenTreeFlags.AT_RECURSIVE
with open_tree(source, otflags) as srcfd:
@@ -367,6 +199,17 @@ def bind_mount(
mount(srcloc, tgtloc, None, mflags)
+def get_cgroup(pid: int = -1) -> pathlib.PurePath:
+ """Look up the cgroup that the given pid or the current process belongs
+ to.
+ """
+ return pathlib.PurePath(
+ pathlib.Path(
+ f"/proc/{pid}/cgroup" if pid > 0 else "/proc/self/cgroup"
+ ).read_text().split(":", 2)[2].strip()
+ )
+
+
_P = typing.ParamSpec("_P")
class _ExceptionExitCallback:
@@ -402,7 +245,7 @@ def populate_dev(
newroot: PathConvertible,
*,
fuse: bool = True,
- pidns: bool = True,
+ pts: typing.Literal["defer", "host", "new", "absent"] = "new",
tun: bool = True,
) -> None:
"""Mount a tmpfs to the dev directory beneath newroot and populate it with
@@ -412,6 +255,12 @@ def populate_dev(
Even though a CAP_SYS_ADMIN-enabled process can umount components of the
/dev hierarchy, they they cannot gain privileges in doing so as no
hierarchies are restricted via tmpfs mounts or read-only bind mounts.
+
+ The /dev/fuse and /dev/net/tun devices are optional and can be enabled or
+ disabled as desired. /dev/pts (and /dev/ptmx) can be shared with the host
+ or mounted as a new instance. Since a PID namespace is usually required for
+ mounting a new instance, it can also be deferred to a later manual mount.
+ If not desired, it can be left absent.
"""
origdev = AtLocation(origroot) / "dev"
newdev = AtLocation(newroot) / "dev"
@@ -428,31 +277,31 @@ def populate_dev(
for fn in "null zero full random urandom tty".split():
files.add(fn)
bind_mounts[fn] = exitstack.enter_context(
- open_tree(origdev / fn, OpenTreeFlags.OPEN_TREE_CLONE)
+ open_tree(origdev / fn, OpenTreeFlags.CLONE)
)
if fuse:
files.add("fuse")
bind_mounts["fuse"] = exitstack.enter_context(
- open_tree(origdev / "fuse", OpenTreeFlags.OPEN_TREE_CLONE)
+ open_tree(origdev / "fuse", OpenTreeFlags.CLONE)
)
- if pidns:
- symlinks["ptmx"] = "pts/ptmx"
- else:
+ if pts == "host":
bind_mounts["pts"] = exitstack.enter_context(
open_tree(
origdev / "pts",
- OpenTreeFlags.AT_RECURSIVE | OpenTreeFlags.OPEN_TREE_CLONE,
+ OpenTreeFlags.AT_RECURSIVE | OpenTreeFlags.CLONE,
)
)
files.add("ptmx")
bind_mounts["ptmx"] = exitstack.enter_context(
- open_tree(origdev / "ptmx", OpenTreeFlags.OPEN_TREE_CLONE)
+ open_tree(origdev / "ptmx", OpenTreeFlags.CLONE)
)
+ elif pts != "absent":
+ symlinks["ptmx"] = "pts/ptmx"
if tun:
directories.add("net")
files.add("net/tun")
bind_mounts["net/tun"] = exitstack.enter_context(
- open_tree(origdev / "net/tun", OpenTreeFlags.OPEN_TREE_CLONE)
+ open_tree(origdev / "net/tun", OpenTreeFlags.CLONE)
)
mount(
"devtmpfs",
@@ -477,7 +326,7 @@ def populate_dev(
(newdev / fn).mknod(stat.S_IFREG)
for fn, target in symlinks.items():
(newdev / fn).symlink_to(target)
- if pidns:
+ if pts == "new":
mount(
"devpts",
newdev / "pts",
@@ -519,7 +368,7 @@ def populate_proc(
if namespaces & CloneFlags.NEWNET == CloneFlags.NEWNET:
psn = open_tree(
newproc / "sys/net",
- OpenTreeFlags.OPEN_TREE_CLONE | OpenTreeFlags.AT_RECURSIVE,
+ OpenTreeFlags.CLONE | OpenTreeFlags.AT_RECURSIVE,
)
bind_mount(newproc / "sys", newproc / "sys", True, True)
if psn is not None:
@@ -575,7 +424,7 @@ def populate_sys(
bindfd = exitstack.enter_context(
open_tree(
AtLocation(origroot) / "sys" / source,
- OpenTreeFlags.OPEN_TREE_CLONE | OpenTreeFlags.AT_RECURSIVE,
+ OpenTreeFlags.CLONE | OpenTreeFlags.AT_RECURSIVE,
),
)
if rdonly:
@@ -617,8 +466,13 @@ def unshare_user_idmap(
unshare(flags)
setup_idmaps()
+
def unshare_user_idmap_nohelper(
- uid: int, gid: int, flags: CloneFlags = CloneFlags.NEWUSER
+ uid: int,
+ gid: int,
+ flags: CloneFlags = CloneFlags.NEWUSER,
+ *,
+ proc: AtLocationLike | None = None,
) -> None:
"""Unshare the given namespaces (must include user) and
map the current user and group to the given uid and gid
@@ -627,14 +481,20 @@ def unshare_user_idmap_nohelper(
uidmap = IDMapping(uid, os.getuid(), 1)
gidmap = IDMapping(gid, os.getgid(), 1)
unshare(flags)
- pathlib.Path("/proc/self/setgroups").write_bytes(b"deny")
- newidmaps(-1, [uidmap], [gidmap], False)
+ proc = AtLocation("/proc" if proc is None else proc)
+ (proc / "self/setgroups").write_bytes(b"deny")
+ newidmaps(-1, [uidmap], [gidmap], False, proc=proc)
class _AsyncFilesender:
bs = 65536
- def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ def __init__(
+ self,
+ from_fd: FileDescriptor,
+ to_fd: FileDescriptor,
+ count: int | None = None,
+ ):
self.from_fd = from_fd
self.to_fd = to_fd
self.copied = 0
@@ -667,7 +527,12 @@ class _AsyncFilesender:
class _AsyncSplicer:
bs = 65536
- def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ def __init__(
+ self,
+ from_fd: FileDescriptor,
+ to_fd: FileDescriptor,
+ count: int | None = None,
+ ):
self.from_fd = from_fd
self.to_fd = to_fd
self.copied = 0
@@ -711,7 +576,12 @@ class _AsyncSplicer:
class _AsyncCopier:
bs = 65536
- def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ def __init__(
+ self,
+ from_fd: FileDescriptor,
+ to_fd: FileDescriptor,
+ count: int | None = None,
+ ):
self.from_fd = from_fd
self.to_fd = to_fd
self.buffer = b""
@@ -775,13 +645,17 @@ class _AsyncCopier:
def async_copyfd(
- from_fd: int, to_fd: int, count: int | None = None
+ from_fd: FileDescriptorLike,
+ to_fd: FileDescriptorLike,
+ count: int | None = None,
) -> asyncio.Future[int]:
"""Copy the given number of bytes from the first file descriptor to the
second file descriptor in an asyncio context. Both copies are performed
binary. An efficient implementation is chosen depending on the file type
of file descriptors.
"""
+ from_fd = FileDescriptor(from_fd)
+ to_fd = FileDescriptor(to_fd)
from_mode = os.fstat(from_fd).st_mode
if stat.S_ISREG(from_mode):
return _AsyncFilesender(from_fd, to_fd, count).fut
@@ -791,7 +665,7 @@ def async_copyfd(
class _AsyncPidfdWaiter:
- def __init__(self, pidfd: int, flags: int):
+ def __init__(self, pidfd: FileDescriptor, flags: int):
self.pidfd = pidfd
self.flags = flags
self.loop = asyncio.get_running_loop()
@@ -816,12 +690,12 @@ class _AsyncPidfdWaiter:
def async_waitpidfd(
- pidfd: int, flags: int
+ pidfd: FileDescriptorLike, flags: int
) -> asyncio.Future[os.waitid_result | None]:
"""Asynchronously wait for a process represented as a pidfd. This is an
async variant of waitid(P_PIDFD, pidfd, flags).
"""
- return _AsyncPidfdWaiter(pidfd, flags).fut
+ return _AsyncPidfdWaiter(FileDescriptor(pidfd), flags).fut
def enable_loopback_if() -> None:
diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py
index 20d402a..46ac541 100644
--- a/linuxnamespaces/atlocation.py
+++ b/linuxnamespaces/atlocation.py
@@ -9,13 +9,14 @@ code for doing so.
import enum
import errno
+import locale
import os
import os.path
import pathlib
import stat
import typing
-from .filedescriptor import FileDescriptor
+from .filedescriptor import FileDescriptor, FileDescriptorLike, HasFileno
AT_FDCWD = FileDescriptor(-100)
@@ -58,7 +59,7 @@ class AtLocation:
def __new__(
cls,
- thing: typing.Union["AtLocation", int, PathConvertible],
+ thing: typing.Union["AtLocation", FileDescriptorLike, PathConvertible],
location: PathConvertible | None = None,
flags: AtFlags = AtFlags.NONE,
) -> "AtLocation":
@@ -76,13 +77,14 @@ class AtLocation:
)
return thing # Don't copy.
obj = super(AtLocation, cls).__new__(cls)
- if isinstance(thing, int):
+ if not isinstance(thing, FileDescriptor):
+ if isinstance(thing, (int, HasFileno)):
+ thing = FileDescriptor(thing)
+ if isinstance(thing, FileDescriptor):
if thing < 0 and thing != AT_FDCWD:
raise ValueError("fd cannot be negative")
if isinstance(thing, FileDescriptor):
obj.fd = thing
- else:
- obj.fd = FileDescriptor(thing)
if location is None:
obj.location = ""
obj.flags = flags | AtFlags.AT_EMPTY_PATH
@@ -148,7 +150,7 @@ class AtLocation:
them with a slash as separator. The returned AtLocation borrows its fd
if any.
"""
- if isinstance(other, int):
+ if isinstance(other, (int, HasFileno)):
# A an fd is considered an absolute AT_EMPTY_PATH path.
return AtLocation(other)
non_empty_flags = self.flags & ~AtFlags.AT_EMPTY_PATH
@@ -218,7 +220,12 @@ class AtLocation:
"chdir on AtLocation only supports flag AT_EMPTY_PATH"
)
assert self.location
- return os.chdir(self.location)
+ if self.fd == AT_FDCWD:
+ return os.chdir(self.location)
+ with FileDescriptor(
+ self.open(flags=os.O_PATH | os.O_CLOEXEC)
+ ) as dirfd:
+ return os.fchdir(dirfd)
def chmod(self, mode: int) -> None:
"""Wrapper for os.chmod or os.fchmod."""
@@ -417,7 +424,7 @@ class AtLocation:
assert self.location
os.mknod(self.location, mode, device, dir_fd=self.fd_or_none)
- def open(self, flags: int, mode: int = 0o777) -> int:
+ def open(self, flags: int, mode: int = 0o777) -> FileDescriptor:
"""Wrapper for os.open supplying path and dir_fd."""
if self.flags == AtFlags.AT_SYMLINK_NOFOLLOW:
flags |= os.O_NOFOLLOW
@@ -426,7 +433,9 @@ class AtLocation:
"opening an AtLocation only supports flag AT_SYMLINK_NOFOLLOW"
)
assert self.location
- return os.open(self.location, flags, mode, dir_fd=self.fd_or_none)
+ return FileDescriptor(
+ os.open(self.location, flags, mode, dir_fd=self.fd_or_none)
+ )
def readlink(self) -> str:
"""Wrapper for os.readlink supplying path and dir_fd."""
@@ -543,6 +552,26 @@ class AtLocation:
AtLocation(dirfd),
)
+ def write_bytes(self, data: bytes) -> None:
+ """Overwrite the file with the given data bytes."""
+ dataview = memoryview(data)
+ with self.open(os.O_CREAT | os.O_WRONLY) as fd:
+ while dataview:
+ written = os.write(fd, dataview)
+ dataview = dataview[written:]
+
+ def write_text(
+ self, data: str, encoding: str | None = None, errors: str | None = None
+ ) -> None:
+ """Overwrite the file with the given data string."""
+ if encoding is None:
+ encoding = locale.getencoding()
+ if errors is None:
+ databytes = data.encode(encoding=encoding)
+ else:
+ databytes = data.encode(encoding=encoding, errors=errors)
+ self.write_bytes(databytes)
+
def __enter__(self) -> "AtLocation":
"""When used as a context manager, the associated fd will be closed on
scope exit.
@@ -590,4 +619,4 @@ class AtLocation:
return f"{cn}({self.fd}, flags={self.flags!r})"
-AtLocationLike = typing.Union[AtLocation, int, PathConvertible]
+AtLocationLike = typing.Union[AtLocation, FileDescriptorLike, PathConvertible]
diff --git a/linuxnamespaces/filedescriptor.py b/linuxnamespaces/filedescriptor.py
index e4eff9b..ee96a94 100644
--- a/linuxnamespaces/filedescriptor.py
+++ b/linuxnamespaces/filedescriptor.py
@@ -8,11 +8,33 @@ import os
import typing
+# pylint: disable=too-few-public-methods # It's that one method we describe.
+@typing.runtime_checkable
+class HasFileno(typing.Protocol):
+ """A typing protocol representing a file-like object and looking up the
+ underlying file descriptor.
+ """
+
+ def fileno(self) -> int:
+ """Return the underlying file descriptor."""
+
+
+FileDescriptorLike = int | HasFileno
+
+
class FileDescriptor(int):
"""Type tag for integers that represent file descriptors. It also provides
a few very generic file descriptor methods.
"""
+ def __new__(cls, value: FileDescriptorLike) -> typing.Self:
+ """Construct a FileDescriptor from an int or HasFileno."""
+ if isinstance(value, cls):
+ return value # No need to copy, it's immutable.
+ if not isinstance(value, int):
+ value = value.fileno()
+ return super(FileDescriptor, cls).__new__(cls, value)
+
def __enter__(self) -> "FileDescriptor":
"""When used as a context manager, close the file descriptor on scope
exit.
@@ -37,11 +59,18 @@ class FileDescriptor(int):
return FileDescriptor(os.dup(self))
return FileDescriptor(fcntl.fcntl(self, fcntl.F_DUPFD_CLOEXEC, 0))
- def dup2(self, fd2: int, inheritable: bool = True) -> "FileDescriptor":
+ def dup2(
+ self, fd2: FileDescriptorLike, inheritable: bool = True
+ ) -> "FileDescriptor":
"""Duplicate the file to the given file descriptor number."""
- return FileDescriptor(os.dup2(self, fd2, inheritable))
+ return FileDescriptor(os.dup2(self, FileDescriptor(fd2), inheritable))
- def fileno(self) -> int:
+ @classmethod
+ def pidfd_open(cls, pid: int, flags: int = 0) -> typing.Self:
+ """Convenience wrapper for os.pidfd_open."""
+ return cls(os.pidfd_open(pid, flags))
+
+ def fileno(self) -> "FileDescriptor":
"""Return self such that it satisfies the HasFileno protocol."""
return self
diff --git a/linuxnamespaces/idmap.py b/linuxnamespaces/idmap.py
new file mode 100644
index 0000000..a10ec12
--- /dev/null
+++ b/linuxnamespaces/idmap.py
@@ -0,0 +1,250 @@
+# Copyright 2024-2025 Helmut Grohne <helmut@subdivi.de>
+# SPDX-License-Identifier: GPL-3
+
+"""Provide functionalit related to mapping user and group ids in a user
+namespace.
+"""
+
+import bisect
+import dataclasses
+import os
+import subprocess
+import typing
+
+from .atlocation import AtLocation, AtLocationLike
+
+
+def subidranges(
+ kind: typing.Literal["uid", "gid"], login: str | None = None
+) -> typing.Iterator[tuple[int, int]]:
+ """Parse a `/etc/sub?id` file for ranges allocated to the given or current
+ user. Return all ranges as (start, count) pairs.
+ """
+ if login is None:
+ login = os.getlogin()
+ with open(f"/etc/sub{kind}") as filelike:
+ for line in filelike:
+ parts = line.strip().split(":")
+ if parts[0] == login:
+ yield (int(parts[1]), int(parts[2]))
+
+
+@dataclasses.dataclass(frozen=True)
+class IDMapping:
+ """Represent one range in a user or group id mapping."""
+
+ innerstart: int
+ outerstart: int
+ count: int
+
+ def __post_init__(self) -> None:
+ if self.outerstart < 0:
+ raise ValueError("outerstart must not be negative")
+ if self.innerstart < 0:
+ raise ValueError("innerstart must not be negative")
+ if self.count <= 0:
+ raise ValueError("count must be positive")
+ if self.outerstart + self.count >= 1 << 64:
+ raise ValueError("outerstart + count exceed 64bits")
+ if self.innerstart + self.count >= 1 << 64:
+ raise ValueError("innerstart + count exceed 64bits")
+
+ @classmethod
+ def identity(cls, idn: int, count: int = 1) -> typing.Self:
+ """Construct an identity mapping for the given identifier."""
+ return cls(idn, idn, count)
+
+
+class IDAllocation:
+ """This represents a subset of IDs (user or group). It can be used to
+ allocate a contiguous range for use with a user namespace.
+ """
+
+ def __init__(self) -> None:
+ self.ranges: list[tuple[int, int]] = []
+
+ def add_range(self, start: int, count: int) -> None:
+ """Add count ids starting from start to this allocation."""
+ if start < 0 or count <= 0:
+ raise ValueError("invalid range")
+ index = bisect.bisect_right(self.ranges, (start, 0))
+ prevrange = None
+ if index > 0:
+ prevrange = self.ranges[index - 1]
+ if prevrange[0] + prevrange[1] > start:
+ raise ValueError("attempt to add overlapping range")
+ nextrange = None
+ if index < len(self.ranges):
+ nextrange = self.ranges[index]
+ if nextrange[0] < start + count:
+ raise ValueError("attempt to add overlapping range")
+ if prevrange and prevrange[0] + prevrange[1] == start:
+ if nextrange and nextrange[0] == start + count:
+ self.ranges[index - 1] = (
+ prevrange[0],
+ prevrange[1] + count + nextrange[1],
+ )
+ del self.ranges[index]
+ else:
+ self.ranges[index - 1] = (prevrange[0], prevrange[1] + count)
+ elif nextrange and nextrange[0] == start + count:
+ self.ranges[index] = (start, count + nextrange[1])
+ else:
+ self.ranges.insert(index, (start, count))
+
+ @classmethod
+ def loadsubid(
+ cls, kind: typing.Literal["uid", "gid"], login: str | None = None,
+ ) -> "IDAllocation":
+ """Load a `/etc/sub?id` file and return ids allocated to the given
+ login or current user.
+ """
+ self = cls()
+ for start, count in subidranges(kind, login):
+ self.add_range(start, count)
+ return self
+
+ def find(self, count: int) -> int:
+ """Locate count contiguous ids from this allocation. The start of
+ the allocation is returned. The allocation object is left unchanged.
+ """
+ for start, available in self.ranges:
+ if available >= count:
+ return start
+ raise ValueError("could not satisfy allocation request")
+
+ def allocate(self, count: int) -> int:
+ """Allocate count contiguous ids from this allocation. The start of
+ the allocation is returned and the ids are removed from this
+ IDAllocation object.
+ """
+ for index, (start, available) in enumerate(self.ranges):
+ if available > count:
+ self.ranges[index] = (start + count, available - count)
+ return start
+ if available == count:
+ del self.ranges[index]
+ return start
+ raise ValueError("could not satisfy allocation request")
+
+ def allocatemap(self, count: int, target: int = 0) -> IDMapping:
+ """Allocate count contiguous ids from this allocation. An IDMapping
+ with its innerstart set to target is returned. The allocation is
+ removed from this IDAllocation object.
+ """
+ return IDMapping(target, self.allocate(count), count)
+
+ def reserve(self, start: int, count: int) -> None:
+ """Reserve (and remove) the given range from this allocation. If the
+ range is not fully contained in this allocation, a ValueError is
+ raised.
+ """
+ if count < 0:
+ raise ValueError("negative count")
+ index = bisect.bisect_right(self.ranges, (start, float("inf"))) - 1
+ if index < 0:
+ raise ValueError("range to reserve not found")
+ cur_start, cur_count = self.ranges[index]
+ assert cur_start <= start
+ if cur_start == start:
+ # Requested range starts at range boundary
+ if cur_count < count:
+ raise ValueError("range to reserve not found")
+ if cur_count == count:
+ # Requested range matches a range exactly
+ del self.ranges[index]
+ else:
+ # Requested range is a head of the matched range
+ self.ranges[index] = (start + count, cur_count - count)
+ elif cur_start + cur_count >= start + count:
+ # Requested range fits into a matched range
+ self.ranges[index] = (cur_start, start - cur_start)
+ if cur_start + cur_count > start + count:
+ # Requested range punches a hole into a matched range
+ self.ranges.insert(
+ index + 1,
+ (start + count, cur_start + cur_count - (start + count)),
+ )
+ # else: Requested range is a tail of a matched range
+ else:
+ raise ValueError("range to reserve not found")
+
+
+def newidmap(
+ kind: typing.Literal["uid", "gid"],
+ pid: int,
+ mapping: list[IDMapping],
+ helper: bool | None = None,
+ *,
+ proc: AtLocationLike | None = None,
+) -> None:
+ """Apply the given uid or gid mapping to the given process. A positive pid
+ identifies a process, other values identify the currently running process.
+ Whether setuid binaries newuidmap and newgidmap are used is determined via
+ the helper argument. A None value indicate automatic detection of whether
+ a helper is required for setting up the given mapping.
+ """
+
+ assert kind in ("uid", "gid")
+ if pid <= 0:
+ pid = os.getpid()
+ if helper is None:
+ # We cannot reliably test whether we have the right EUID and we don't
+ # implement checking whether setgroups has been denied either. Please
+ # be explicit about the helper choice in such cases.
+ helper = len(mapping) > 1 or mapping[0].count > 1
+ if helper:
+ argv = [f"new{kind}map", str(pid)]
+ for idblock in mapping:
+ argv.extend(map(str, dataclasses.astuple(idblock)))
+ subprocess.check_call(argv)
+ else:
+ proc = AtLocation("/proc" if proc is None else proc)
+ (proc / f"{pid}/{kind}_map").write_text(
+ "".join(
+ "%d %d %d\n" % dataclasses.astuple(idblock)
+ for idblock in mapping
+ ),
+ encoding="ascii",
+ )
+
+
+def newuidmap(
+ pid: int,
+ mapping: list[IDMapping],
+ helper: bool = True,
+ *,
+ proc: AtLocationLike | None = None,
+) -> None:
+ """Apply a given uid mapping to the given process. Refer to newidmap for
+ details.
+ """
+ newidmap("uid", pid, mapping, helper, proc=proc)
+
+
+def newgidmap(
+ pid: int,
+ mapping: list[IDMapping],
+ helper: bool = True,
+ *,
+ proc: AtLocationLike | None = None,
+) -> None:
+ """Apply a given gid mapping to the given process. Refer to newidmap for
+ details.
+ """
+ newidmap("gid", pid, mapping, helper, proc=proc)
+
+
+def newidmaps(
+ pid: int,
+ uidmapping: list[IDMapping],
+ gidmapping: list[IDMapping],
+ helper: bool = True,
+ *,
+ proc: AtLocationLike | None = None,
+) -> None:
+ """Apply a given uid and gid mapping to the given process. Refer to
+ newidmap for details.
+ """
+ newgidmap(pid, gidmapping, helper, proc=proc)
+ newuidmap(pid, uidmapping, helper, proc=proc)
diff --git a/linuxnamespaces/syscalls.py b/linuxnamespaces/syscalls.py
index dd4a332..e9b0e44 100644
--- a/linuxnamespaces/syscalls.py
+++ b/linuxnamespaces/syscalls.py
@@ -12,17 +12,49 @@ import enum
import errno
import logging
import os
+import signal
import typing
+from .filedescriptor import FileDescriptor, FileDescriptorLike
from .atlocation import AtFlags, AtLocation, AtLocationLike, PathConvertible
-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)
LIBC_SO = ctypes.CDLL(None, use_errno=True)
+if typing.TYPE_CHECKING:
+ CDataType = ctypes._CDataType # pylint: disable=protected-access
+else:
+ CDataType = typing.Any
+
+
+def _pad_fields(
+ fields: list[tuple[str, type[CDataType]]],
+ totalsize: int,
+ name: str,
+ padtype: type[CDataType] = ctypes.c_uint8,
+) -> list[tuple[str, type[CDataType]]]:
+ """Append a padding element to a ctypes.Structure _fields_ sequence such
+ that its total size matches a given value.
+ """
+ fieldssize = sum(ctypes.sizeof(ft) for _, ft in fields)
+ padsize = totalsize - fieldssize
+ if padsize < 0:
+ raise TypeError(
+ f"requested padding to {totalsize}, but fields consume {fieldssize}"
+ )
+ eltsize = ctypes.sizeof(padtype)
+ elements, remainder = divmod(padsize, eltsize)
+ if remainder:
+ raise TypeError(
+ f"padding {padsize} is not a multiple of the element size {eltsize}"
+ )
+ return fields + [(name, padtype * elements)]
+
+
class CloneFlags(enum.IntFlag):
"""This value may be supplied to
* unshare(2) flags
@@ -121,39 +153,38 @@ class MountFlags(enum.IntFlag):
# Map each flag to:
# * The flag value
# * Whether the flag value is negated
- # * Whether the flag must be negated
# * Whether the flag can be negated
__flagstrmap = {
- "acl": (POSIXACL, False, False, False),
- "async": (SYNCHRONOUS, True, False, False),
- "atime": (NOATIME, True, False, True),
- "bind": (BIND, False, False, False),
- "dev": (NODEV, True, False, True),
- "diratime": (NODIRATIME, True, False, True),
- "dirsync": (DIRSYNC, False, False, False),
- "exec": (NOEXEC, True, False, True),
- "iversion": (I_VERSION, False, False, True),
- "lazytime": (LAZYTIME, False, False, True),
- "loud": (SILENT, True, False, False),
- "mand": (MANDLOCK, False, False, True),
- "private": (PRIVATE, False, False, False),
- "rbind": (BIND | REC, False, False, False),
- "relatime": (RELATIME, False, False, True),
- "remount": (REMOUNT, False, False, True),
- "ro": (RDONLY, False, False, False),
- "rprivate": (PRIVATE | REC, False, False, False),
- "rshared": (SHARED | REC, False, False, False),
- "rslave": (SLAVE | REC, False, False, False),
- "runbindable": (UNBINDABLE | REC, False, False, False),
- "rw": (RDONLY, True, False, False),
- "shared": (SHARED, False, False, False),
- "silent": (SILENT, False, False, False),
- "slave": (SLAVE, False, False, False),
- "strictatime": (STRICTATIME, False, False, True),
- "suid": (NOSUID, True, False, True),
- "symfollow": (NOSYMFOLLOW, True, False, True),
- "sync": (SYNCHRONOUS, False, False, False),
- "unbindable": (UNBINDABLE, False, False, False),
+ "acl": (POSIXACL, False, False),
+ "async": (SYNCHRONOUS, True, False),
+ "atime": (NOATIME, True, True),
+ "bind": (BIND, False, False),
+ "dev": (NODEV, True, True),
+ "diratime": (NODIRATIME, True, True),
+ "dirsync": (DIRSYNC, False, False),
+ "exec": (NOEXEC, True, True),
+ "iversion": (I_VERSION, False, True),
+ "lazytime": (LAZYTIME, False, True),
+ "loud": (SILENT, True, False),
+ "mand": (MANDLOCK, False, True),
+ "private": (PRIVATE, False, False),
+ "rbind": (BIND | REC, False, False),
+ "relatime": (RELATIME, False, True),
+ "remount": (REMOUNT, False, True),
+ "ro": (RDONLY, False, False),
+ "rprivate": (PRIVATE | REC, False, False),
+ "rshared": (SHARED | REC, False, False),
+ "rslave": (SLAVE | REC, False, False),
+ "runbindable": (UNBINDABLE | REC, False, False),
+ "rw": (RDONLY, True, False),
+ "shared": (SHARED, False, False),
+ "silent": (SILENT, False, False),
+ "slave": (SLAVE, False, False),
+ "strictatime": (STRICTATIME, False, True),
+ "suid": (NOSUID, True, True),
+ "symfollow": (NOSYMFOLLOW, True, True),
+ "sync": (SYNCHRONOUS, False, False),
+ "unbindable": (UNBINDABLE, False, False),
}
def change(self, flagsstr: str) -> "MountFlags":
@@ -165,19 +196,23 @@ class MountFlags(enum.IntFlag):
for flagstr in flagsstr.split(","):
if not flagstr:
continue
- flag, negated, mustnegate, cannegate = self.__flagstrmap.get(
- flagstr.removeprefix("no"),
- (MountFlags.NONE, False, True, False),
- )
- if mustnegate <= flagstr.startswith("no") <= cannegate:
+ try:
+ flag, negated, cannegate = self.__flagstrmap[
+ flagstr.removeprefix("no")
+ ]
+ except KeyError:
+ raise ValueError(
+ f"not a valid mount flag: {flagstr!r}"
+ ) from None
+ else:
+ if flagstr.startswith("no") > cannegate:
+ raise ValueError(f"not a valid mount flag: {flagstr!r}")
if negated ^ flagstr.startswith("no"):
ret &= ~flag
else:
if flag & MountFlags.PROPAGATION_FLAGS:
ret &= ~MountFlags.PROPAGATION_FLAGS
ret |= flag
- else:
- raise ValueError(f"not a valid mount flag: {flagstr!r}")
return ret
@staticmethod
@@ -221,22 +256,25 @@ class MountFlags(enum.IntFlag):
reverse=True,
)
- def tostr(self) -> str:
- """Attempt to represent the flags in a comma-separated, textual way."""
+ def tonames(self) -> list[str]:
+ """Represent the flags as a sequence of list of flag names."""
if (self & MountFlags.PROPAGATION_FLAGS).bit_count() > 1:
raise ValueError("cannot represent conflicting propagation flags")
parts: list[str] = []
remain = self
for val, text in MountFlags.__flagvals:
- # Older mypy think MountFlags.__flagvals and thus text was of type
- # MountFlags.
+ # Older mypy wrongly deduces the type of MountFlags.__flagvals.
assert isinstance(text, str)
if remain & val == val:
parts.insert(0, text)
remain &= ~val
if remain:
raise ValueError("cannot represent flags {remain}")
- return ",".join(parts)
+ return parts
+
+ def tostr(self) -> str:
+ """Represent the flags in a comma-separated, textual way."""
+ return ",".join(self.tonames())
class MountSetattrFlags(enum.IntFlag):
@@ -328,15 +366,15 @@ class OpenTreeFlags(enum.IntFlag):
"""This value may be supplied to open_tree(2) as flags."""
NONE = 0
- OPEN_TREE_CLONE = 0x1
- OPEN_TREE_CLOEXEC = os.O_CLOEXEC
+ CLONE = 0x1
+ CLOEXEC = os.O_CLOEXEC
AT_SYMLINK_NOFOLLOW = 0x100
AT_NO_AUTOMOUNT = 0x800
AT_EMPTY_PATH = 0x1000
AT_RECURSIVE = 0x8000
ALL_FLAGS = (
- OPEN_TREE_CLONE
- | OPEN_TREE_CLOEXEC
+ CLONE
+ | CLOEXEC
| AT_SYMLINK_NOFOLLOW
| AT_NO_AUTOMOUNT
| AT_EMPTY_PATH
@@ -348,10 +386,47 @@ class PrctlOption(enum.IntEnum):
"""This value may be supplied to prctl(2) as option."""
PR_SET_PDEATHSIG = 1
+ PR_SET_DUMPABLE = 4
PR_SET_CHILD_SUBREAPER = 36
PR_CAP_AMBIENT = 47
+class SignalFDSigInfo(ctypes.Structure):
+ """Information about a received signal by reading from a signalfd(2)."""
+
+ _fields_ = _pad_fields(
+ [
+ ("ssi_signo", ctypes.c_uint32),
+ ("ssi_errno", ctypes.c_int32),
+ ("ssi_code", ctypes.c_int32),
+ ("ssi_pid", ctypes.c_uint32),
+ ("ssi_uid", ctypes.c_uint32),
+ ("ssi_fd", ctypes.c_int32),
+ ("ssi_tid", ctypes.c_uint32),
+ ("ssi_band", ctypes.c_uint32),
+ ("ssi_overrun", ctypes.c_uint32),
+ ("ssi_trapno", ctypes.c_uint32),
+ ("ssi_status", ctypes.c_int32),
+ ("ssi_int", ctypes.c_int32),
+ ("ssi_ptr", ctypes.c_uint64),
+ ("ssi_utime", ctypes.c_uint64),
+ ("ssi_stime", ctypes.c_uint64),
+ ("ssi_addr", ctypes.c_uint64),
+ ("ssi_addr_lsb", ctypes.c_uint16),
+ ],
+ 128,
+ "padding",
+ )
+
+
+class SignalFDFlags(enum.IntFlag):
+ """This value may be supplied as flags to signalfd(2)."""
+
+ NONE = 0
+ CLOEXEC = os.O_CLOEXEC
+ NONBLOCK = os.O_NONBLOCK
+
+
class UmountFlags(enum.IntFlag):
"""This value may be supplied to umount2(2) as flags."""
@@ -368,9 +443,9 @@ def call_libc(funcname: str, *args: typing.Any) -> int:
the function returns an integer that is non-negative on success. On
failure, an OSError with errno is raised.
"""
- logger.debug("calling libc function %s%r", funcname, args)
+ _logger.debug("calling libc function %s%r", funcname, args)
ret: int = LIBC_SO[funcname](*args)
- logger.debug("%s returned %d", funcname, ret)
+ _logger.debug("%s returned %d", funcname, ret)
if ret < 0:
err = ctypes.get_errno()
raise OSError(
@@ -428,7 +503,7 @@ class EventFD:
) -> None:
if flags & ~EventFDFlags.ALL_FLAGS:
raise ValueError("invalid flags for eventfd")
- self.fd = os.eventfd(initval, int(flags))
+ self.fd = FileDescriptor(os.eventfd(initval, int(flags)))
def read(self) -> int:
"""Decrease the value of the eventfd using eventfd_read."""
@@ -471,7 +546,7 @@ class EventFD:
raise ValueError("attempt to read from closed eventfd")
os.eventfd_write(self.fd, value)
- def fileno(self) -> int:
+ def fileno(self) -> FileDescriptor:
"""Return the underlying file descriptor."""
return self.fd
@@ -481,7 +556,7 @@ class EventFD:
try:
os.close(self.fd)
finally:
- self.fd = -1
+ self.fd = FileDescriptor(-1)
__del__ = close
@@ -489,7 +564,7 @@ class EventFD:
"""Return True unless the eventfd is closed."""
return self.fd >= 0
- def __enter__(self) -> "EventFD":
+ def __enter__(self) -> typing.Self:
"""When used as a context manager, the EventFD is closed on scope exit.
"""
return self
@@ -508,7 +583,7 @@ def mount(
target: PathConvertible,
filesystemtype: str | None,
flags: MountFlags = MountFlags.NONE,
- data: str | list[str] | None = None,
+ data: str | list[str] | dict[str, str | int | None] | None = None,
) -> None:
"""Python wrapper for mount(2)."""
if (flags & MountFlags.PROPAGATION_FLAGS).bit_count() > 1:
@@ -520,6 +595,11 @@ def mount(
)
):
raise ValueError("invalid flags for mount")
+ if isinstance(data, dict):
+ data = [
+ key if value is None else f"{key}={value}"
+ for key, value in data.items()
+ ]
if isinstance(data, list):
if any("," in s for s in data):
raise ValueError("data elements must not contain a comma")
@@ -540,7 +620,7 @@ def mount_setattr(
attr_set: MountAttrFlags = MountAttrFlags.NONE,
attr_clr: MountAttrFlags = MountAttrFlags.NONE,
propagation: int = 0,
- userns_fd: int = -1,
+ userns_fd: FileDescriptorLike = -1,
) -> None:
"""Python wrapper for mount_setattr(2)."""
filesystem = AtLocation(filesystem)
@@ -549,6 +629,8 @@ def mount_setattr(
flags |= MountSetattrFlags.AT_RECURSIVE
if attr_clr & MountAttrFlags.IDMAP:
raise ValueError("cannot clear the MOUNT_ATTR_IDMAP flag")
+ if not isinstance(userns_fd, int):
+ userns_fd = userns_fd.fileno()
attr = MountAttr(attr_set, attr_clr, propagation, userns_fd)
call_libc(
"mount_setattr",
@@ -613,7 +695,7 @@ def open_tree(
raise ValueError("invalid flags for open_tree")
if (
flags & OpenTreeFlags.AT_RECURSIVE
- and not flags & OpenTreeFlags.OPEN_TREE_CLONE
+ and not flags & OpenTreeFlags.CLONE
):
raise ValueError("invalid flags for open_tree")
if source.flags & AtFlags.AT_SYMLINK_NOFOLLOW:
@@ -670,6 +752,11 @@ def prctl_set_child_subreaper(enabled: bool = True) -> None:
prctl(PrctlOption.PR_SET_CHILD_SUBREAPER, int(enabled))
+def prctl_set_dumpable(enabled: bool) -> None:
+ """Set or clear the dumpable flag."""
+ prctl(PrctlOption.PR_SET_DUMPABLE, int(enabled))
+
+
def prctl_set_pdeathsig(signum: int) -> None:
"""Set the parent-death signal of the calling process."""
if signum < 0:
@@ -686,6 +773,163 @@ def setns(fd: int, nstype: CloneFlags = CloneFlags.NONE) -> None:
call_libc("setns", fd, int(nstype))
+class SignalFD:
+ """Represent a file descriptor returned from signalfd(2)."""
+
+ _ReadIterFut = asyncio.Future[tuple[list[SignalFDSigInfo], "_ReadIterFut"]]
+
+ def __init__(
+ self,
+ sigmask: typing.Iterable[signal.Signals],
+ flags: SignalFDFlags = SignalFDFlags.NONE,
+ ):
+ self.fd = SignalFD.__signalfd(FileDescriptor(-1), sigmask, flags)
+
+ @staticmethod
+ def __signalfd(
+ fd: FileDescriptor,
+ sigmask: typing.Iterable[signal.Signals],
+ flags: SignalFDFlags,
+ ) -> FileDescriptor:
+ """Python wrapper for signalfd(2)."""
+ bitsperlong = 8 * ctypes.sizeof(ctypes.c_ulong)
+ nval = 64 // bitsperlong
+ mask = [0] * nval
+ for sig in sigmask:
+ sigval = int(sig) - 1
+ mask[sigval // bitsperlong] |= 1 << (sigval % bitsperlong)
+ csigmask = (ctypes.c_ulong * nval)(*mask)
+ return FileDescriptor(call_libc("signalfd", fd, csigmask, int(flags)))
+
+ def readv(self, count: int) -> list[SignalFDSigInfo]:
+ """Read up to count signals from the signalfd."""
+ if count < 0:
+ raise ValueError("read count must be positive")
+ if self.fd < 0:
+ raise ValueError("attempt to read from closed signalfd")
+ res = [SignalFDSigInfo() for _ in range(count)]
+ cnt = os.readv(self.fd, res)
+ cnt //= ctypes.sizeof(SignalFDSigInfo)
+ return res[:cnt]
+
+ def read(self) -> SignalFDSigInfo:
+ """Read one signal from the signalfd."""
+ res = self.readv(1)
+ return res[0]
+
+ def __handle_read(
+ self, fd: int, fut: asyncio.Future[SignalFDSigInfo]
+ ) -> None:
+ try:
+ if fd != self.fd:
+ raise RuntimeError("SignalFD file descriptor changed")
+ try:
+ result = self.read()
+ except OSError as err:
+ if err.errno == errno.EAGAIN:
+ return
+ raise
+ except Exception as exc:
+ fut.get_loop().remove_reader(fd)
+ fut.set_exception(exc)
+ else:
+ fut.get_loop().remove_reader(fd)
+ fut.set_result(result)
+
+ def aread(self) -> typing.Awaitable[SignalFDSigInfo]:
+ """Asynchronously read one signal from the signalfd."""
+ if self.fd < 0:
+ raise ValueError("attempt to read from closed signalfd")
+ loop = asyncio.get_running_loop()
+ fut: asyncio.Future[SignalFDSigInfo] = loop.create_future()
+ loop.add_reader(self.fd, self.__handle_read, self.fd, fut)
+ return fut
+
+ def __handle_readiter(self, fd: int, fut: _ReadIterFut) -> None:
+ loop = fut.get_loop()
+ try:
+ if fd != self.fd:
+ raise RuntimeError("SignalFD file descriptor changed")
+ try:
+ # Attempt to read a full page worth of queued signals.
+ results = self.readv(32)
+ except OSError as err:
+ if err.errno == errno.EAGAIN:
+ return
+ raise
+ except Exception as exc:
+ loop.remove_reader(fd)
+ fut.set_exception(exc)
+ else:
+ nextfut: SignalFD._ReadIterFut = loop.create_future()
+ loop.add_reader(fd, self.__handle_readiter, self.fd, nextfut)
+ fut.set_result((results, nextfut))
+
+ async def areaditer(self) -> typing.AsyncIterator[SignalFDSigInfo]:
+ """Asynchronously read signals from the signalfd forever."""
+ if self.fd < 0:
+ raise ValueError("attempt to read from closed signalfd")
+ loop = asyncio.get_running_loop()
+ fut: SignalFD._ReadIterFut = loop.create_future()
+ loop.add_reader(self.fd, self.__handle_readiter, self.fd, fut)
+ while True:
+ results, fut = await fut
+ for result in results:
+ yield result
+
+ def fileno(self) -> FileDescriptor:
+ """Return the underlying file descriptor."""
+ return self.fd
+
+ def close(self) -> None:
+ """Close the underlying file descriptor."""
+ if self.fd >= 0:
+ try:
+ os.close(self.fd)
+ finally:
+ self.fd = FileDescriptor(-1)
+
+ __del__ = close
+
+ def __bool__(self) -> bool:
+ """Return True unless the signalfd is closed."""
+ return self.fd >= 0
+
+ def __enter__(self) -> typing.Self:
+ """When used as a context manager, the SignalFD is closed on scope
+ exit.
+ """
+ return self
+
+ def __exit__(
+ self,
+ exc_type: typing.Any,
+ exc_value: typing.Any,
+ traceback: typing.Any,
+ ) -> None:
+ self.close()
+
+
+class _SigqueueSigval(ctypes.Union):
+ _fields_ = [
+ ("sival_int", ctypes.c_int),
+ ("sival_ptr", ctypes.c_void_p),
+ ]
+
+
+def sigqueue(
+ pid: int, sig: signal.Signals, value: int | ctypes.c_void_p | None = None
+) -> None:
+ """Python wrapper for sigqueue(2)."""
+ if value is None:
+ sigval = _SigqueueSigval()
+ elif isinstance(value, int):
+ sigval = _SigqueueSigval(sival_int=value)
+ else:
+ sigval = _SigqueueSigval(sival_ptr=value)
+ call_libc("sigqueue", pid, int(sig), sigval)
+
+
def umount(
path: PathConvertible, flags: UmountFlags = UmountFlags.NONE
) -> None:
diff --git a/linuxnamespaces/systemd/__init__.py b/linuxnamespaces/systemd/__init__.py
index d8e7f86..84cb135 100644
--- a/linuxnamespaces/systemd/__init__.py
+++ b/linuxnamespaces/systemd/__init__.py
@@ -8,6 +8,48 @@ import sys
import typing
+_DBUS_INTEGER_BOUNDS = (
+ ("q", 0, 1 << 16),
+ ("n", -(1 << 15), 1 << 15),
+ ("u", 0, 1 << 32),
+ ("i", -(1 << 31), 1 << 31),
+ ("t", 0, 1 << 64),
+ ("x", -(1 << 63), 1 << 63),
+)
+
+
+def _guess_dbus_type(value: typing.Any) -> typing.Iterator[str]:
+ """Guess the type of a Python value in dbus. May yield multiple candidates.
+ """
+ if isinstance(value, bool):
+ yield "b"
+ elif isinstance(value, str):
+ yield "s"
+ elif isinstance(value, int):
+ found = False
+ for guess, low, high in _DBUS_INTEGER_BOUNDS:
+ if low <= value < high:
+ found = True
+ yield guess
+ if not found:
+ raise ValueError("integer out of bounds for dbus")
+ elif isinstance(value, float):
+ yield "d"
+ elif isinstance(value, list):
+ if not value:
+ raise ValueError("cannot guess dbus type for empty list")
+ types = [list(_guess_dbus_type(v)) for v in value]
+ found = False
+ for guess in types[0]:
+ if all(guess in guesses for guesses in types):
+ found = True
+ yield "a" + guess
+ if not found:
+ raise ValueError("could not determine homogeneous type of list")
+ else:
+ raise ValueError("failed to guess dbus type")
+
+
async def start_transient_unit(
unitname: str,
pids: list[int] | None = None,
@@ -20,14 +62,13 @@ async def start_transient_unit(
pids = [os.getpid()]
dbus_properties.append(("PIDs", ("au", pids)))
for key, value in ({} if properties is None else properties).items():
- if isinstance(value, bool):
- dbus_properties.append((key, ("b", value)))
- elif isinstance(value, str):
- dbus_properties.append((key, ("s", value)))
- else:
+ try:
+ guess = next(_guess_dbus_type(value))
+ except ValueError as err:
raise ValueError(
f"cannot infer dbus type for property {key} value"
- )
+ ) from err
+ dbus_properties.append((key, (guess, value)))
if dbusdriver in ("auto", "jeepney"):
try:
from .jeepney import start_transient_unit as jeepney_impl
diff --git a/linuxnamespaces/systemd/dbussy.py b/linuxnamespaces/systemd/dbussy.py
index 77410df..60b74fc 100644
--- a/linuxnamespaces/systemd/dbussy.py
+++ b/linuxnamespaces/systemd/dbussy.py
@@ -52,6 +52,7 @@ class SystemdJobWaiter:
try:
return self.jobs_removed[job]
except KeyError:
+ self.jobs_removed.clear()
return await asyncio.wait_for(self.job_done, timeout)
def __exit__(self, *exc_info: typing.Any) -> None:
@@ -72,10 +73,15 @@ async def start_transient_unit(
"""
bus = await ravel.session_bus_async()
with SystemdJobWaiter(bus) as wait:
+ systemd1 = bus["org.freedesktop.systemd1"]["/org/freedesktop/systemd1"]
result = await wait(
- bus["org.freedesktop.systemd1"]["/org/freedesktop/systemd1"]
- .get_interface("org.freedesktop.systemd1.Manager")
- .StartTransientUnit(unitname, "fail", properties, [])[0],
+ (
+ await (
+ await systemd1.get_async_interface(
+ "org.freedesktop.systemd1.Manager"
+ )
+ ).StartTransientUnit(unitname, "fail", properties, [])
+ )[0],
)
if result != "done":
raise OSError("StartTransientUnit failed: " + result)
diff --git a/linuxnamespaces/tarutils.py b/linuxnamespaces/tarutils.py
index 6285d5a..5ad60cd 100644
--- a/linuxnamespaces/tarutils.py
+++ b/linuxnamespaces/tarutils.py
@@ -31,8 +31,16 @@ class ZstdTarFile(tarfile.TarFile):
name: str,
mode: typing.Literal["r", "w", "x"] = "r",
fileobj: typing.BinaryIO | None = None,
+ *,
+ compresslevel: int | None = None,
+ threads: int | None = None,
**kwargs: typing.Any,
) -> tarfile.TarFile:
+ """Open a zstd compressed tar archive with the given name for readin or
+ writing. Appending is not supported. The class allows customizing the
+ compression level and the compression concurrency (default parallel)
+ while decompression ignores those arguments.
+ """
if mode not in ("r", "w", "x"):
raise ValueError("mode must be 'r', 'w' or 'x'")
openobj: str | typing.BinaryIO = name if fileobj is None else fileobj
@@ -45,11 +53,21 @@ class ZstdTarFile(tarfile.TarFile):
if mode == "r":
zfobj = zstandard.open(openobj, "rb")
else:
- zfobj = zstandard.open(
- openobj,
- mode + "b",
- cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
- )
+ if threads is None:
+ threads = -1
+ if compresslevel is not None:
+ if compresslevel > 22:
+ raise ValueError(
+ "invalid compression level {compresslevel}"
+ )
+ cctx = zstandard.ZstdCompressor(
+ write_checksum=True, threads=threads, level=compresslevel
+ )
+ else:
+ cctx = zstandard.ZstdCompressor(
+ write_checksum=True, threads=threads
+ )
+ zfobj = zstandard.open(openobj, mode + "b", cctx=cctx)
try:
tarobj = cls.taropen(name, mode, zfobj, **kwargs)
except (OSError, EOFError, zstandard.ZstdError) as exc: