diff options
-rw-r--r-- | linuxnamespaces/syscalls.py | 188 | ||||
-rw-r--r-- | tests/test_simple.py | 21 |
2 files changed, 209 insertions, 0 deletions
diff --git a/linuxnamespaces/syscalls.py b/linuxnamespaces/syscalls.py index dd4a332..ad8eb78 100644 --- a/linuxnamespaces/syscalls.py +++ b/linuxnamespaces/syscalls.py @@ -12,6 +12,7 @@ import enum import errno import logging import os +import signal import typing from .atlocation import AtFlags, AtLocation, AtLocationLike, PathConvertible @@ -23,6 +24,36 @@ logger = logging.getLogger(__name__) LIBC_SO = ctypes.CDLL(None, use_errno=True) +if typing.TYPE_CHECKING: + CData = ctypes._CData # pylint: disable=protected-access +else: + CData = typing.Any + + +def _pad_fields( + fields: list[tuple[str, type[CData]]], + totalsize: int, + name: str, + padtype: type[CData] = ctypes.c_uint8, +) -> list[tuple[str, type[CData]]]: + """Append a padding element to a ctypes.Structure _fields_ sequence such + that its total size matches a given value. + """ + fieldssize = sum(ctypes.sizeof(ft) for _, ft in fields) + padsize = totalsize - fieldssize + if padsize < 0: + raise TypeError( + f"requested padding to {totalsize}, but fields consume {fieldssize}" + ) + eltsize = ctypes.sizeof(padtype) + elements, remainder = divmod(padsize, eltsize) + if remainder: + raise TypeError( + f"padding {padsize} is not a multiple of the element size {eltsize}" + ) + return fields + [(name, padtype * elements)] + + class CloneFlags(enum.IntFlag): """This value may be supplied to * unshare(2) flags @@ -352,6 +383,42 @@ class PrctlOption(enum.IntEnum): PR_CAP_AMBIENT = 47 +class SignalFDSigInfo(ctypes.Structure): + """Information about a received signal by reading from a signalfd(2).""" + + _fields_ = _pad_fields( + [ + ("ssi_signo", ctypes.c_uint32), + ("ssi_errno", ctypes.c_int32), + ("ssi_code", ctypes.c_int32), + ("ssi_pid", ctypes.c_uint32), + ("ssi_uid", ctypes.c_uint32), + ("ssi_fd", ctypes.c_int32), + ("ssi_tid", ctypes.c_uint32), + ("ssi_band", ctypes.c_uint32), + ("ssi_overrun", ctypes.c_uint32), + ("ssi_trapno", ctypes.c_uint32), + ("ssi_status", ctypes.c_int32), + ("ssi_int", ctypes.c_int32), + ("ssi_ptr", ctypes.c_uint64), + ("ssi_utime", ctypes.c_uint64), + ("ssi_stime", ctypes.c_uint64), + ("ssi_addr", ctypes.c_uint64), + ("ssi_addr_lsb", ctypes.c_uint16), + ], + 128, + "padding", + ) + + +class SignalFDFlags(enum.IntFlag): + """This value may be supplied as flags to signalfd(2).""" + + NONE = 0 + CLOEXEC = os.O_CLOEXEC + NONBLOCK = os.O_NONBLOCK + + class UmountFlags(enum.IntFlag): """This value may be supplied to umount2(2) as flags.""" @@ -686,6 +753,127 @@ def setns(fd: int, nstype: CloneFlags = CloneFlags.NONE) -> None: call_libc("setns", fd, int(nstype)) +class SignalFD: + """Represent a file descriptor returned from signalfd(2).""" + + def __init__( + self, + sigmask: typing.Iterable[signal.Signals], + flags: SignalFDFlags = SignalFDFlags.NONE, + ): + self.fd = SignalFD.__signalfd(-1, sigmask, flags) + + @staticmethod + def __signalfd( + fd: int, sigmask: typing.Iterable[signal.Signals], flags: SignalFDFlags + ) -> int: + """Python wrapper for signalfd(2).""" + bitsperlong = 8 * ctypes.sizeof(ctypes.c_ulong) + nval = 64 // bitsperlong + mask = [0] * nval + for sig in sigmask: + sigval = int(sig) - 1 + mask[sigval // bitsperlong] |= 1 << (sigval % bitsperlong) + csigmask = (ctypes.c_ulong * nval)(*mask) + return call_libc("signalfd", fd, csigmask, int(flags)) + + def readv(self, count: int) -> list[SignalFDSigInfo]: + """Read up to count signals from the signalfd.""" + if count < 0: + raise ValueError("read count must be positive") + if self.fd < 0: + raise ValueError("attempt to read from closed signalfd") + res = [SignalFDSigInfo() for _ in range(count)] + cnt = os.readv(self.fd, res) + cnt //= ctypes.sizeof(SignalFDSigInfo) + return res[:cnt] + + def read(self) -> SignalFDSigInfo: + """Read one signal from the signalfd.""" + res = self.readv(1) + return res[0] + + def __handle_readable( + self, fd: int, fut: asyncio.Future[SignalFDSigInfo] + ) -> None: + try: + if fd != self.fd: + raise RuntimeError("SignalFD file descriptor changed") + try: + result = self.read() + except OSError as err: + if err.errno == errno.EAGAIN: + return + raise + except Exception as exc: + fut.get_loop().remove_reader(fd) + fut.set_exception(exc) + else: + fut.get_loop().remove_reader(fd) + fut.set_result(result) + + def aread(self) -> typing.Awaitable[SignalFDSigInfo]: + """Asynchronously read one signal from the signalfd.""" + if self.fd < 0: + raise ValueError("attempt to read from closed signalfd") + loop = asyncio.get_running_loop() + fut: asyncio.Future[SignalFDSigInfo] = loop.create_future() + loop.add_reader(self.fd, self.__handle_readable, self.fd, fut) + return fut + + def fileno(self) -> int: + """Return the underlying file descriptor.""" + return self.fd + + def close(self) -> None: + """Close the underlying file descriptor.""" + if self.fd >= 0: + try: + os.close(self.fd) + finally: + self.fd = -1 + + __del__ = close + + def __bool__(self) -> bool: + """Return True unless the signalfd is closed.""" + return self.fd >= 0 + + def __enter__(self) -> "EventFD": + """When used as a context manager, the SignalFD is closed on scope + exit. + """ + return self + + def __exit__( + self, + exc_type: typing.Any, + exc_value: typing.Any, + traceback: typing.Any, + ) -> None: + self.close() + + +class _SigqueueSigval(ctypes.Union): + _fields_ = [ + ("sival_int", ctypes.c_int), + ("sival_ptr", ctypes.c_void_p), + ] + + +def sigqueue( + pid: int, sig: signal.Signals, value: int | ctypes.c_void_p | None = None +) -> None: + """Python wrapper for sigqueue(2).""" + sigval = _SigqueueSigval() + if value is not None: + if isinstance(value, int): + sigval.sival_int = value + else: + sigval.sival_ptr = value + call_libc("sigqueue", pid, int(sig), sigval) + + def umount( path: PathConvertible, flags: UmountFlags = UmountFlags.NONE ) -> None: diff --git a/tests/test_simple.py b/tests/test_simple.py index eb03384..07a0740 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -5,6 +5,7 @@ import asyncio import errno import os import pathlib +import signal import socket import unittest @@ -92,6 +93,26 @@ class AsnycioTest(unittest.IsolatedAsyncioTestCase): efd.write() self.assertEqual(await fut, 1) + async def test_signalfd(self) -> None: + testsig = signal.SIGUSR1 + sfd = linuxnamespaces.SignalFD( + [testsig], linuxnamespaces.SignalFDFlags.NONBLOCK + ) + self.addCleanup(sfd.close) + oldmask = signal.pthread_sigmask(signal.SIG_SETMASK, [testsig]) + self.addCleanup(signal.pthread_sigmask, signal.SIG_SETMASK, oldmask) + fut = asyncio.ensure_future(sfd.aread()) + await asyncio.sleep(0.000001) # Let the loop run + self.assertFalse(fut.done()) + sigval = 123456789 + mypid = os.getpid() + linuxnamespaces.sigqueue(mypid, testsig, sigval) + siginfo = await fut + self.assertEqual(siginfo.ssi_signo, testsig) + self.assertEqual(siginfo.ssi_pid, mypid) + self.assertEqual(siginfo.ssi_uid, os.getuid()) + self.assertEqual(siginfo.ssi_int, sigval) + async def test_run_in_fork(self) -> None: with linuxnamespaces.EventFD( 0, linuxnamespaces.EventFDFlags.NONBLOCK |