summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--linuxnamespaces/syscalls.py188
-rw-r--r--tests/test_simple.py21
2 files changed, 209 insertions, 0 deletions
diff --git a/linuxnamespaces/syscalls.py b/linuxnamespaces/syscalls.py
index dd4a332..ad8eb78 100644
--- a/linuxnamespaces/syscalls.py
+++ b/linuxnamespaces/syscalls.py
@@ -12,6 +12,7 @@ import enum
import errno
import logging
import os
+import signal
import typing
from .atlocation import AtFlags, AtLocation, AtLocationLike, PathConvertible
@@ -23,6 +24,36 @@ logger = logging.getLogger(__name__)
LIBC_SO = ctypes.CDLL(None, use_errno=True)
+if typing.TYPE_CHECKING:
+ CData = ctypes._CData # pylint: disable=protected-access
+else:
+ CData = typing.Any
+
+
+def _pad_fields(
+ fields: list[tuple[str, type[CData]]],
+ totalsize: int,
+ name: str,
+ padtype: type[CData] = ctypes.c_uint8,
+) -> list[tuple[str, type[CData]]]:
+ """Append a padding element to a ctypes.Structure _fields_ sequence such
+ that its total size matches a given value.
+ """
+ fieldssize = sum(ctypes.sizeof(ft) for _, ft in fields)
+ padsize = totalsize - fieldssize
+ if padsize < 0:
+ raise TypeError(
+ f"requested padding to {totalsize}, but fields consume {fieldssize}"
+ )
+ eltsize = ctypes.sizeof(padtype)
+ elements, remainder = divmod(padsize, eltsize)
+ if remainder:
+ raise TypeError(
+ f"padding {padsize} is not a multiple of the element size {eltsize}"
+ )
+ return fields + [(name, padtype * elements)]
+
+
class CloneFlags(enum.IntFlag):
"""This value may be supplied to
* unshare(2) flags
@@ -352,6 +383,42 @@ class PrctlOption(enum.IntEnum):
PR_CAP_AMBIENT = 47
+class SignalFDSigInfo(ctypes.Structure):
+ """Information about a received signal by reading from a signalfd(2)."""
+
+ _fields_ = _pad_fields(
+ [
+ ("ssi_signo", ctypes.c_uint32),
+ ("ssi_errno", ctypes.c_int32),
+ ("ssi_code", ctypes.c_int32),
+ ("ssi_pid", ctypes.c_uint32),
+ ("ssi_uid", ctypes.c_uint32),
+ ("ssi_fd", ctypes.c_int32),
+ ("ssi_tid", ctypes.c_uint32),
+ ("ssi_band", ctypes.c_uint32),
+ ("ssi_overrun", ctypes.c_uint32),
+ ("ssi_trapno", ctypes.c_uint32),
+ ("ssi_status", ctypes.c_int32),
+ ("ssi_int", ctypes.c_int32),
+ ("ssi_ptr", ctypes.c_uint64),
+ ("ssi_utime", ctypes.c_uint64),
+ ("ssi_stime", ctypes.c_uint64),
+ ("ssi_addr", ctypes.c_uint64),
+ ("ssi_addr_lsb", ctypes.c_uint16),
+ ],
+ 128,
+ "padding",
+ )
+
+
+class SignalFDFlags(enum.IntFlag):
+ """This value may be supplied as flags to signalfd(2)."""
+
+ NONE = 0
+ CLOEXEC = os.O_CLOEXEC
+ NONBLOCK = os.O_NONBLOCK
+
+
class UmountFlags(enum.IntFlag):
"""This value may be supplied to umount2(2) as flags."""
@@ -686,6 +753,127 @@ def setns(fd: int, nstype: CloneFlags = CloneFlags.NONE) -> None:
call_libc("setns", fd, int(nstype))
+class SignalFD:
+ """Represent a file descriptor returned from signalfd(2)."""
+
+ def __init__(
+ self,
+ sigmask: typing.Iterable[signal.Signals],
+ flags: SignalFDFlags = SignalFDFlags.NONE,
+ ):
+ self.fd = SignalFD.__signalfd(-1, sigmask, flags)
+
+ @staticmethod
+ def __signalfd(
+ fd: int, sigmask: typing.Iterable[signal.Signals], flags: SignalFDFlags
+ ) -> int:
+ """Python wrapper for signalfd(2)."""
+ bitsperlong = 8 * ctypes.sizeof(ctypes.c_ulong)
+ nval = 64 // bitsperlong
+ mask = [0] * nval
+ for sig in sigmask:
+ sigval = int(sig) - 1
+ mask[sigval // bitsperlong] |= 1 << (sigval % bitsperlong)
+ csigmask = (ctypes.c_ulong * nval)(*mask)
+ return call_libc("signalfd", fd, csigmask, int(flags))
+
+ def readv(self, count: int) -> list[SignalFDSigInfo]:
+ """Read up to count signals from the signalfd."""
+ if count < 0:
+ raise ValueError("read count must be positive")
+ if self.fd < 0:
+ raise ValueError("attempt to read from closed signalfd")
+ res = [SignalFDSigInfo() for _ in range(count)]
+ cnt = os.readv(self.fd, res)
+ cnt //= ctypes.sizeof(SignalFDSigInfo)
+ return res[:cnt]
+
+ def read(self) -> SignalFDSigInfo:
+ """Read one signal from the signalfd."""
+ res = self.readv(1)
+ return res[0]
+
+ def __handle_readable(
+ self, fd: int, fut: asyncio.Future[SignalFDSigInfo]
+ ) -> None:
+ try:
+ if fd != self.fd:
+ raise RuntimeError("SignalFD file descriptor changed")
+ try:
+ result = self.read()
+ except OSError as err:
+ if err.errno == errno.EAGAIN:
+ return
+ raise
+ except Exception as exc:
+ fut.get_loop().remove_reader(fd)
+ fut.set_exception(exc)
+ else:
+ fut.get_loop().remove_reader(fd)
+ fut.set_result(result)
+
+ def aread(self) -> typing.Awaitable[SignalFDSigInfo]:
+ """Asynchronously read one signal from the signalfd."""
+ if self.fd < 0:
+ raise ValueError("attempt to read from closed signalfd")
+ loop = asyncio.get_running_loop()
+ fut: asyncio.Future[SignalFDSigInfo] = loop.create_future()
+ loop.add_reader(self.fd, self.__handle_readable, self.fd, fut)
+ return fut
+
+ def fileno(self) -> int:
+ """Return the underlying file descriptor."""
+ return self.fd
+
+ def close(self) -> None:
+ """Close the underlying file descriptor."""
+ if self.fd >= 0:
+ try:
+ os.close(self.fd)
+ finally:
+ self.fd = -1
+
+ __del__ = close
+
+ def __bool__(self) -> bool:
+ """Return True unless the signalfd is closed."""
+ return self.fd >= 0
+
+ def __enter__(self) -> "EventFD":
+ """When used as a context manager, the SignalFD is closed on scope
+ exit.
+ """
+ return self
+
+ def __exit__(
+ self,
+ exc_type: typing.Any,
+ exc_value: typing.Any,
+ traceback: typing.Any,
+ ) -> None:
+ self.close()
+
+
+class _SigqueueSigval(ctypes.Union):
+ _fields_ = [
+ ("sival_int", ctypes.c_int),
+ ("sival_ptr", ctypes.c_void_p),
+ ]
+
+
+def sigqueue(
+ pid: int, sig: signal.Signals, value: int | ctypes.c_void_p | None = None
+) -> None:
+ """Python wrapper for sigqueue(2)."""
+ sigval = _SigqueueSigval()
+ if value is not None:
+ if isinstance(value, int):
+ sigval.sival_int = value
+ else:
+ sigval.sival_ptr = value
+ call_libc("sigqueue", pid, int(sig), sigval)
+
+
def umount(
path: PathConvertible, flags: UmountFlags = UmountFlags.NONE
) -> None:
diff --git a/tests/test_simple.py b/tests/test_simple.py
index eb03384..07a0740 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -5,6 +5,7 @@ import asyncio
import errno
import os
import pathlib
+import signal
import socket
import unittest
@@ -92,6 +93,26 @@ class AsnycioTest(unittest.IsolatedAsyncioTestCase):
efd.write()
self.assertEqual(await fut, 1)
+ async def test_signalfd(self) -> None:
+ testsig = signal.SIGUSR1
+ sfd = linuxnamespaces.SignalFD(
+ [testsig], linuxnamespaces.SignalFDFlags.NONBLOCK
+ )
+ self.addCleanup(sfd.close)
+ oldmask = signal.pthread_sigmask(signal.SIG_SETMASK, [testsig])
+ self.addCleanup(signal.pthread_sigmask, signal.SIG_SETMASK, oldmask)
+ fut = asyncio.ensure_future(sfd.aread())
+ await asyncio.sleep(0.000001) # Let the loop run
+ self.assertFalse(fut.done())
+ sigval = 123456789
+ mypid = os.getpid()
+ linuxnamespaces.sigqueue(mypid, testsig, sigval)
+ siginfo = await fut
+ self.assertEqual(siginfo.ssi_signo, testsig)
+ self.assertEqual(siginfo.ssi_pid, mypid)
+ self.assertEqual(siginfo.ssi_uid, os.getuid())
+ self.assertEqual(siginfo.ssi_int, sigval)
+
async def test_run_in_fork(self) -> None:
with linuxnamespaces.EventFD(
0, linuxnamespaces.EventFDFlags.NONBLOCK