diff options
Diffstat (limited to 'linuxnamespaces/syscalls.py')
-rw-r--r-- | linuxnamespaces/syscalls.py | 358 |
1 files changed, 301 insertions, 57 deletions
diff --git a/linuxnamespaces/syscalls.py b/linuxnamespaces/syscalls.py index dd4a332..e9b0e44 100644 --- a/linuxnamespaces/syscalls.py +++ b/linuxnamespaces/syscalls.py @@ -12,17 +12,49 @@ import enum import errno import logging import os +import signal import typing +from .filedescriptor import FileDescriptor, FileDescriptorLike from .atlocation import AtFlags, AtLocation, AtLocationLike, PathConvertible -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) LIBC_SO = ctypes.CDLL(None, use_errno=True) +if typing.TYPE_CHECKING: + CDataType = ctypes._CDataType # pylint: disable=protected-access +else: + CDataType = typing.Any + + +def _pad_fields( + fields: list[tuple[str, type[CDataType]]], + totalsize: int, + name: str, + padtype: type[CDataType] = ctypes.c_uint8, +) -> list[tuple[str, type[CDataType]]]: + """Append a padding element to a ctypes.Structure _fields_ sequence such + that its total size matches a given value. + """ + fieldssize = sum(ctypes.sizeof(ft) for _, ft in fields) + padsize = totalsize - fieldssize + if padsize < 0: + raise TypeError( + f"requested padding to {totalsize}, but fields consume {fieldssize}" + ) + eltsize = ctypes.sizeof(padtype) + elements, remainder = divmod(padsize, eltsize) + if remainder: + raise TypeError( + f"padding {padsize} is not a multiple of the element size {eltsize}" + ) + return fields + [(name, padtype * elements)] + + class CloneFlags(enum.IntFlag): """This value may be supplied to * unshare(2) flags @@ -121,39 +153,38 @@ class MountFlags(enum.IntFlag): # Map each flag to: # * The flag value # * Whether the flag value is negated - # * Whether the flag must be negated # * Whether the flag can be negated __flagstrmap = { - "acl": (POSIXACL, False, False, False), - "async": (SYNCHRONOUS, True, False, False), - "atime": (NOATIME, True, False, True), - "bind": (BIND, False, False, False), - "dev": (NODEV, True, False, True), - "diratime": (NODIRATIME, True, False, True), - "dirsync": (DIRSYNC, False, False, False), - "exec": (NOEXEC, True, False, True), - "iversion": (I_VERSION, False, False, True), - "lazytime": (LAZYTIME, False, False, True), - "loud": (SILENT, True, False, False), - "mand": (MANDLOCK, False, False, True), - "private": (PRIVATE, False, False, False), - "rbind": (BIND | REC, False, False, False), - "relatime": (RELATIME, False, False, True), - "remount": (REMOUNT, False, False, True), - "ro": (RDONLY, False, False, False), - "rprivate": (PRIVATE | REC, False, False, False), - "rshared": (SHARED | REC, False, False, False), - "rslave": (SLAVE | REC, False, False, False), - "runbindable": (UNBINDABLE | REC, False, False, False), - "rw": (RDONLY, True, False, False), - "shared": (SHARED, False, False, False), - "silent": (SILENT, False, False, False), - "slave": (SLAVE, False, False, False), - "strictatime": (STRICTATIME, False, False, True), - "suid": (NOSUID, True, False, True), - "symfollow": (NOSYMFOLLOW, True, False, True), - "sync": (SYNCHRONOUS, False, False, False), - "unbindable": (UNBINDABLE, False, False, False), + "acl": (POSIXACL, False, False), + "async": (SYNCHRONOUS, True, False), + "atime": (NOATIME, True, True), + "bind": (BIND, False, False), + "dev": (NODEV, True, True), + "diratime": (NODIRATIME, True, True), + "dirsync": (DIRSYNC, False, False), + "exec": (NOEXEC, True, True), + "iversion": (I_VERSION, False, True), + "lazytime": (LAZYTIME, False, True), + "loud": (SILENT, True, False), + "mand": (MANDLOCK, False, True), + "private": (PRIVATE, False, False), + "rbind": (BIND | REC, False, False), + "relatime": (RELATIME, False, True), + "remount": (REMOUNT, False, True), + "ro": (RDONLY, False, False), + "rprivate": (PRIVATE | REC, False, False), + "rshared": (SHARED | REC, False, False), + "rslave": (SLAVE | REC, False, False), + "runbindable": (UNBINDABLE | REC, False, False), + "rw": (RDONLY, True, False), + "shared": (SHARED, False, False), + "silent": (SILENT, False, False), + "slave": (SLAVE, False, False), + "strictatime": (STRICTATIME, False, True), + "suid": (NOSUID, True, True), + "symfollow": (NOSYMFOLLOW, True, True), + "sync": (SYNCHRONOUS, False, False), + "unbindable": (UNBINDABLE, False, False), } def change(self, flagsstr: str) -> "MountFlags": @@ -165,19 +196,23 @@ class MountFlags(enum.IntFlag): for flagstr in flagsstr.split(","): if not flagstr: continue - flag, negated, mustnegate, cannegate = self.__flagstrmap.get( - flagstr.removeprefix("no"), - (MountFlags.NONE, False, True, False), - ) - if mustnegate <= flagstr.startswith("no") <= cannegate: + try: + flag, negated, cannegate = self.__flagstrmap[ + flagstr.removeprefix("no") + ] + except KeyError: + raise ValueError( + f"not a valid mount flag: {flagstr!r}" + ) from None + else: + if flagstr.startswith("no") > cannegate: + raise ValueError(f"not a valid mount flag: {flagstr!r}") if negated ^ flagstr.startswith("no"): ret &= ~flag else: if flag & MountFlags.PROPAGATION_FLAGS: ret &= ~MountFlags.PROPAGATION_FLAGS ret |= flag - else: - raise ValueError(f"not a valid mount flag: {flagstr!r}") return ret @staticmethod @@ -221,22 +256,25 @@ class MountFlags(enum.IntFlag): reverse=True, ) - def tostr(self) -> str: - """Attempt to represent the flags in a comma-separated, textual way.""" + def tonames(self) -> list[str]: + """Represent the flags as a sequence of list of flag names.""" if (self & MountFlags.PROPAGATION_FLAGS).bit_count() > 1: raise ValueError("cannot represent conflicting propagation flags") parts: list[str] = [] remain = self for val, text in MountFlags.__flagvals: - # Older mypy think MountFlags.__flagvals and thus text was of type - # MountFlags. + # Older mypy wrongly deduces the type of MountFlags.__flagvals. assert isinstance(text, str) if remain & val == val: parts.insert(0, text) remain &= ~val if remain: raise ValueError("cannot represent flags {remain}") - return ",".join(parts) + return parts + + def tostr(self) -> str: + """Represent the flags in a comma-separated, textual way.""" + return ",".join(self.tonames()) class MountSetattrFlags(enum.IntFlag): @@ -328,15 +366,15 @@ class OpenTreeFlags(enum.IntFlag): """This value may be supplied to open_tree(2) as flags.""" NONE = 0 - OPEN_TREE_CLONE = 0x1 - OPEN_TREE_CLOEXEC = os.O_CLOEXEC + CLONE = 0x1 + CLOEXEC = os.O_CLOEXEC AT_SYMLINK_NOFOLLOW = 0x100 AT_NO_AUTOMOUNT = 0x800 AT_EMPTY_PATH = 0x1000 AT_RECURSIVE = 0x8000 ALL_FLAGS = ( - OPEN_TREE_CLONE - | OPEN_TREE_CLOEXEC + CLONE + | CLOEXEC | AT_SYMLINK_NOFOLLOW | AT_NO_AUTOMOUNT | AT_EMPTY_PATH @@ -348,10 +386,47 @@ class PrctlOption(enum.IntEnum): """This value may be supplied to prctl(2) as option.""" PR_SET_PDEATHSIG = 1 + PR_SET_DUMPABLE = 4 PR_SET_CHILD_SUBREAPER = 36 PR_CAP_AMBIENT = 47 +class SignalFDSigInfo(ctypes.Structure): + """Information about a received signal by reading from a signalfd(2).""" + + _fields_ = _pad_fields( + [ + ("ssi_signo", ctypes.c_uint32), + ("ssi_errno", ctypes.c_int32), + ("ssi_code", ctypes.c_int32), + ("ssi_pid", ctypes.c_uint32), + ("ssi_uid", ctypes.c_uint32), + ("ssi_fd", ctypes.c_int32), + ("ssi_tid", ctypes.c_uint32), + ("ssi_band", ctypes.c_uint32), + ("ssi_overrun", ctypes.c_uint32), + ("ssi_trapno", ctypes.c_uint32), + ("ssi_status", ctypes.c_int32), + ("ssi_int", ctypes.c_int32), + ("ssi_ptr", ctypes.c_uint64), + ("ssi_utime", ctypes.c_uint64), + ("ssi_stime", ctypes.c_uint64), + ("ssi_addr", ctypes.c_uint64), + ("ssi_addr_lsb", ctypes.c_uint16), + ], + 128, + "padding", + ) + + +class SignalFDFlags(enum.IntFlag): + """This value may be supplied as flags to signalfd(2).""" + + NONE = 0 + CLOEXEC = os.O_CLOEXEC + NONBLOCK = os.O_NONBLOCK + + class UmountFlags(enum.IntFlag): """This value may be supplied to umount2(2) as flags.""" @@ -368,9 +443,9 @@ def call_libc(funcname: str, *args: typing.Any) -> int: the function returns an integer that is non-negative on success. On failure, an OSError with errno is raised. """ - logger.debug("calling libc function %s%r", funcname, args) + _logger.debug("calling libc function %s%r", funcname, args) ret: int = LIBC_SO[funcname](*args) - logger.debug("%s returned %d", funcname, ret) + _logger.debug("%s returned %d", funcname, ret) if ret < 0: err = ctypes.get_errno() raise OSError( @@ -428,7 +503,7 @@ class EventFD: ) -> None: if flags & ~EventFDFlags.ALL_FLAGS: raise ValueError("invalid flags for eventfd") - self.fd = os.eventfd(initval, int(flags)) + self.fd = FileDescriptor(os.eventfd(initval, int(flags))) def read(self) -> int: """Decrease the value of the eventfd using eventfd_read.""" @@ -471,7 +546,7 @@ class EventFD: raise ValueError("attempt to read from closed eventfd") os.eventfd_write(self.fd, value) - def fileno(self) -> int: + def fileno(self) -> FileDescriptor: """Return the underlying file descriptor.""" return self.fd @@ -481,7 +556,7 @@ class EventFD: try: os.close(self.fd) finally: - self.fd = -1 + self.fd = FileDescriptor(-1) __del__ = close @@ -489,7 +564,7 @@ class EventFD: """Return True unless the eventfd is closed.""" return self.fd >= 0 - def __enter__(self) -> "EventFD": + def __enter__(self) -> typing.Self: """When used as a context manager, the EventFD is closed on scope exit. """ return self @@ -508,7 +583,7 @@ def mount( target: PathConvertible, filesystemtype: str | None, flags: MountFlags = MountFlags.NONE, - data: str | list[str] | None = None, + data: str | list[str] | dict[str, str | int | None] | None = None, ) -> None: """Python wrapper for mount(2).""" if (flags & MountFlags.PROPAGATION_FLAGS).bit_count() > 1: @@ -520,6 +595,11 @@ def mount( ) ): raise ValueError("invalid flags for mount") + if isinstance(data, dict): + data = [ + key if value is None else f"{key}={value}" + for key, value in data.items() + ] if isinstance(data, list): if any("," in s for s in data): raise ValueError("data elements must not contain a comma") @@ -540,7 +620,7 @@ def mount_setattr( attr_set: MountAttrFlags = MountAttrFlags.NONE, attr_clr: MountAttrFlags = MountAttrFlags.NONE, propagation: int = 0, - userns_fd: int = -1, + userns_fd: FileDescriptorLike = -1, ) -> None: """Python wrapper for mount_setattr(2).""" filesystem = AtLocation(filesystem) @@ -549,6 +629,8 @@ def mount_setattr( flags |= MountSetattrFlags.AT_RECURSIVE if attr_clr & MountAttrFlags.IDMAP: raise ValueError("cannot clear the MOUNT_ATTR_IDMAP flag") + if not isinstance(userns_fd, int): + userns_fd = userns_fd.fileno() attr = MountAttr(attr_set, attr_clr, propagation, userns_fd) call_libc( "mount_setattr", @@ -613,7 +695,7 @@ def open_tree( raise ValueError("invalid flags for open_tree") if ( flags & OpenTreeFlags.AT_RECURSIVE - and not flags & OpenTreeFlags.OPEN_TREE_CLONE + and not flags & OpenTreeFlags.CLONE ): raise ValueError("invalid flags for open_tree") if source.flags & AtFlags.AT_SYMLINK_NOFOLLOW: @@ -670,6 +752,11 @@ def prctl_set_child_subreaper(enabled: bool = True) -> None: prctl(PrctlOption.PR_SET_CHILD_SUBREAPER, int(enabled)) +def prctl_set_dumpable(enabled: bool) -> None: + """Set or clear the dumpable flag.""" + prctl(PrctlOption.PR_SET_DUMPABLE, int(enabled)) + + def prctl_set_pdeathsig(signum: int) -> None: """Set the parent-death signal of the calling process.""" if signum < 0: @@ -686,6 +773,163 @@ def setns(fd: int, nstype: CloneFlags = CloneFlags.NONE) -> None: call_libc("setns", fd, int(nstype)) +class SignalFD: + """Represent a file descriptor returned from signalfd(2).""" + + _ReadIterFut = asyncio.Future[tuple[list[SignalFDSigInfo], "_ReadIterFut"]] + + def __init__( + self, + sigmask: typing.Iterable[signal.Signals], + flags: SignalFDFlags = SignalFDFlags.NONE, + ): + self.fd = SignalFD.__signalfd(FileDescriptor(-1), sigmask, flags) + + @staticmethod + def __signalfd( + fd: FileDescriptor, + sigmask: typing.Iterable[signal.Signals], + flags: SignalFDFlags, + ) -> FileDescriptor: + """Python wrapper for signalfd(2).""" + bitsperlong = 8 * ctypes.sizeof(ctypes.c_ulong) + nval = 64 // bitsperlong + mask = [0] * nval + for sig in sigmask: + sigval = int(sig) - 1 + mask[sigval // bitsperlong] |= 1 << (sigval % bitsperlong) + csigmask = (ctypes.c_ulong * nval)(*mask) + return FileDescriptor(call_libc("signalfd", fd, csigmask, int(flags))) + + def readv(self, count: int) -> list[SignalFDSigInfo]: + """Read up to count signals from the signalfd.""" + if count < 0: + raise ValueError("read count must be positive") + if self.fd < 0: + raise ValueError("attempt to read from closed signalfd") + res = [SignalFDSigInfo() for _ in range(count)] + cnt = os.readv(self.fd, res) + cnt //= ctypes.sizeof(SignalFDSigInfo) + return res[:cnt] + + def read(self) -> SignalFDSigInfo: + """Read one signal from the signalfd.""" + res = self.readv(1) + return res[0] + + def __handle_read( + self, fd: int, fut: asyncio.Future[SignalFDSigInfo] + ) -> None: + try: + if fd != self.fd: + raise RuntimeError("SignalFD file descriptor changed") + try: + result = self.read() + except OSError as err: + if err.errno == errno.EAGAIN: + return + raise + except Exception as exc: + fut.get_loop().remove_reader(fd) + fut.set_exception(exc) + else: + fut.get_loop().remove_reader(fd) + fut.set_result(result) + + def aread(self) -> typing.Awaitable[SignalFDSigInfo]: + """Asynchronously read one signal from the signalfd.""" + if self.fd < 0: + raise ValueError("attempt to read from closed signalfd") + loop = asyncio.get_running_loop() + fut: asyncio.Future[SignalFDSigInfo] = loop.create_future() + loop.add_reader(self.fd, self.__handle_read, self.fd, fut) + return fut + + def __handle_readiter(self, fd: int, fut: _ReadIterFut) -> None: + loop = fut.get_loop() + try: + if fd != self.fd: + raise RuntimeError("SignalFD file descriptor changed") + try: + # Attempt to read a full page worth of queued signals. + results = self.readv(32) + except OSError as err: + if err.errno == errno.EAGAIN: + return + raise + except Exception as exc: + loop.remove_reader(fd) + fut.set_exception(exc) + else: + nextfut: SignalFD._ReadIterFut = loop.create_future() + loop.add_reader(fd, self.__handle_readiter, self.fd, nextfut) + fut.set_result((results, nextfut)) + + async def areaditer(self) -> typing.AsyncIterator[SignalFDSigInfo]: + """Asynchronously read signals from the signalfd forever.""" + if self.fd < 0: + raise ValueError("attempt to read from closed signalfd") + loop = asyncio.get_running_loop() + fut: SignalFD._ReadIterFut = loop.create_future() + loop.add_reader(self.fd, self.__handle_readiter, self.fd, fut) + while True: + results, fut = await fut + for result in results: + yield result + + def fileno(self) -> FileDescriptor: + """Return the underlying file descriptor.""" + return self.fd + + def close(self) -> None: + """Close the underlying file descriptor.""" + if self.fd >= 0: + try: + os.close(self.fd) + finally: + self.fd = FileDescriptor(-1) + + __del__ = close + + def __bool__(self) -> bool: + """Return True unless the signalfd is closed.""" + return self.fd >= 0 + + def __enter__(self) -> typing.Self: + """When used as a context manager, the SignalFD is closed on scope + exit. + """ + return self + + def __exit__( + self, + exc_type: typing.Any, + exc_value: typing.Any, + traceback: typing.Any, + ) -> None: + self.close() + + +class _SigqueueSigval(ctypes.Union): + _fields_ = [ + ("sival_int", ctypes.c_int), + ("sival_ptr", ctypes.c_void_p), + ] + + +def sigqueue( + pid: int, sig: signal.Signals, value: int | ctypes.c_void_p | None = None +) -> None: + """Python wrapper for sigqueue(2).""" + if value is None: + sigval = _SigqueueSigval() + elif isinstance(value, int): + sigval = _SigqueueSigval(sival_int=value) + else: + sigval = _SigqueueSigval(sival_ptr=value) + call_libc("sigqueue", pid, int(sig), sigval) + + def umount( path: PathConvertible, flags: UmountFlags = UmountFlags.NONE ) -> None: |