summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--linuxnamespaces/__init__.py31
-rw-r--r--linuxnamespaces/atlocation.py21
-rw-r--r--linuxnamespaces/filedescriptor.py25
-rw-r--r--linuxnamespaces/syscalls.py25
4 files changed, 73 insertions, 29 deletions
diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py
index cd30498..ce37150 100644
--- a/linuxnamespaces/__init__.py
+++ b/linuxnamespaces/__init__.py
@@ -645,7 +645,12 @@ def unshare_user_idmap_nohelper(
class _AsyncFilesender:
bs = 65536
- def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ def __init__(
+ self,
+ from_fd: FileDescriptor,
+ to_fd: FileDescriptor,
+ count: int | None = None,
+ ):
self.from_fd = from_fd
self.to_fd = to_fd
self.copied = 0
@@ -678,7 +683,12 @@ class _AsyncFilesender:
class _AsyncSplicer:
bs = 65536
- def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ def __init__(
+ self,
+ from_fd: FileDescriptor,
+ to_fd: FileDescriptor,
+ count: int | None = None,
+ ):
self.from_fd = from_fd
self.to_fd = to_fd
self.copied = 0
@@ -722,7 +732,12 @@ class _AsyncSplicer:
class _AsyncCopier:
bs = 65536
- def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ def __init__(
+ self,
+ from_fd: FileDescriptor,
+ to_fd: FileDescriptor,
+ count: int | None = None,
+ ):
self.from_fd = from_fd
self.to_fd = to_fd
self.buffer = b""
@@ -786,13 +801,15 @@ class _AsyncCopier:
def async_copyfd(
- from_fd: int, to_fd: int, count: int | None = None
+ from_fd: FileDescriptor, to_fd: FileDescriptor, count: int | None = None
) -> asyncio.Future[int]:
"""Copy the given number of bytes from the first file descriptor to the
second file descriptor in an asyncio context. Both copies are performed
binary. An efficient implementation is chosen depending on the file type
of file descriptors.
"""
+ from_fd = FileDescriptor(from_fd)
+ to_fd = FileDescriptor(to_fd)
from_mode = os.fstat(from_fd).st_mode
if stat.S_ISREG(from_mode):
return _AsyncFilesender(from_fd, to_fd, count).fut
@@ -802,7 +819,7 @@ def async_copyfd(
class _AsyncPidfdWaiter:
- def __init__(self, pidfd: int, flags: int):
+ def __init__(self, pidfd: FileDescriptor, flags: int):
self.pidfd = pidfd
self.flags = flags
self.loop = asyncio.get_running_loop()
@@ -827,12 +844,12 @@ class _AsyncPidfdWaiter:
def async_waitpidfd(
- pidfd: int, flags: int
+ pidfd: FileDescriptorLike, flags: int
) -> asyncio.Future[os.waitid_result | None]:
"""Asynchronously wait for a process represented as a pidfd. This is an
async variant of waitid(P_PIDFD, pidfd, flags).
"""
- return _AsyncPidfdWaiter(pidfd, flags).fut
+ return _AsyncPidfdWaiter(FileDescriptor(pidfd), flags).fut
def enable_loopback_if() -> None:
diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py
index 0003c35..d30e88f 100644
--- a/linuxnamespaces/atlocation.py
+++ b/linuxnamespaces/atlocation.py
@@ -15,7 +15,7 @@ import pathlib
import stat
import typing
-from .filedescriptor import FileDescriptor
+from .filedescriptor import FileDescriptor, FileDescriptorLike, HasFileno
AT_FDCWD = FileDescriptor(-100)
@@ -58,7 +58,7 @@ class AtLocation:
def __new__(
cls,
- thing: typing.Union["AtLocation", int, PathConvertible],
+ thing: typing.Union["AtLocation", FileDescriptorLike, PathConvertible],
location: PathConvertible | None = None,
flags: AtFlags = AtFlags.NONE,
) -> "AtLocation":
@@ -76,13 +76,14 @@ class AtLocation:
)
return thing # Don't copy.
obj = super(AtLocation, cls).__new__(cls)
- if isinstance(thing, int):
+ if not isinstance(thing, FileDescriptor):
+ if isinstance(thing, int) or isinstance(thing, HasFileno):
+ thing = FileDescriptor(thing)
+ if isinstance(thing, FileDescriptor):
if thing < 0 and thing != AT_FDCWD:
raise ValueError("fd cannot be negative")
if isinstance(thing, FileDescriptor):
obj.fd = thing
- else:
- obj.fd = FileDescriptor(thing)
if location is None:
obj.location = ""
obj.flags = flags | AtFlags.AT_EMPTY_PATH
@@ -148,7 +149,7 @@ class AtLocation:
them with a slash as separator. The returned AtLocation borrows its fd
if any.
"""
- if isinstance(other, int):
+ if isinstance(other, int) or isinstance(other, HasFileno):
# A an fd is considered an absolute AT_EMPTY_PATH path.
return AtLocation(other)
non_empty_flags = self.flags & ~AtFlags.AT_EMPTY_PATH
@@ -422,7 +423,7 @@ class AtLocation:
assert self.location
os.mknod(self.location, mode, device, dir_fd=self.fd_or_none)
- def open(self, flags: int, mode: int = 0o777) -> int:
+ def open(self, flags: int, mode: int = 0o777) -> FileDescriptor:
"""Wrapper for os.open supplying path and dir_fd."""
if self.flags == AtFlags.AT_SYMLINK_NOFOLLOW:
flags |= os.O_NOFOLLOW
@@ -431,7 +432,9 @@ class AtLocation:
"opening an AtLocation only supports flag AT_SYMLINK_NOFOLLOW"
)
assert self.location
- return os.open(self.location, flags, mode, dir_fd=self.fd_or_none)
+ return FileDescriptor(
+ os.open(self.location, flags, mode, dir_fd=self.fd_or_none)
+ )
def readlink(self) -> str:
"""Wrapper for os.readlink supplying path and dir_fd."""
@@ -595,4 +598,4 @@ class AtLocation:
return f"{cn}({self.fd}, flags={self.flags!r})"
-AtLocationLike = typing.Union[AtLocation, int, PathConvertible]
+AtLocationLike = typing.Union[AtLocation, FileDescriptorLike, PathConvertible]
diff --git a/linuxnamespaces/filedescriptor.py b/linuxnamespaces/filedescriptor.py
index 2cf8442..d11d4d6 100644
--- a/linuxnamespaces/filedescriptor.py
+++ b/linuxnamespaces/filedescriptor.py
@@ -8,11 +8,28 @@ import os
import typing
+@typing.runtime_checkable
+class HasFileno(typing.Protocol):
+ def fileno(self) -> int:
+ ...
+
+
+FileDescriptorLike = int | HasFileno
+
+
class FileDescriptor(int):
"""Type tag for integers that represent file descriptors. It also provides
a few very generic file descriptor methods.
"""
+ def __new__(cls, value: FileDescriptorLike) -> typing.Self:
+ """Construct a FileDescriptor from an int or HasFileno."""
+ if isinstance(value, cls):
+ return value # No need to copy, it's immutable.
+ if not isinstance(value, int):
+ value = value.fileno()
+ return super(FileDescriptor, cls).__new__(cls, value)
+
def __enter__(self) -> "FileDescriptor":
"""When used as a context manager, close the file descriptor on scope
exit.
@@ -37,16 +54,18 @@ class FileDescriptor(int):
return FileDescriptor(os.dup(self))
return FileDescriptor(fcntl.fcntl(self, fcntl.F_DUPFD_CLOEXEC, 0))
- def dup2(self, fd2: int, inheritable: bool = True) -> "FileDescriptor":
+ def dup2(
+ self, fd2: FileDescriptorLike, inheritable: bool = True
+ ) -> "FileDescriptor":
"""Duplicate the file to the given file descriptor number."""
- return FileDescriptor(os.dup2(self, fd2, inheritable))
+ return FileDescriptor(os.dup2(self, FileDescriptor(fd2), inheritable))
@classmethod
def pidfd_open(cls, pid: int, flags: int = 0) -> typing.Self:
"""Convenience wrapper for os.pidfd_open."""
return cls(os.pidfd_open(pid, flags))
- def fileno(self) -> int:
+ def fileno(self) -> FileDescriptor:
"""Return self such that it satisfies the HasFileno protocol."""
return self
diff --git a/linuxnamespaces/syscalls.py b/linuxnamespaces/syscalls.py
index f6af348..be0a5f7 100644
--- a/linuxnamespaces/syscalls.py
+++ b/linuxnamespaces/syscalls.py
@@ -15,6 +15,7 @@ import os
import signal
import typing
+from .filedescriptor import FileDescriptor, FileDescriptorLike
from .atlocation import AtFlags, AtLocation, AtLocationLike, PathConvertible
@@ -498,7 +499,7 @@ class EventFD:
) -> None:
if flags & ~EventFDFlags.ALL_FLAGS:
raise ValueError("invalid flags for eventfd")
- self.fd = os.eventfd(initval, int(flags))
+ self.fd = FileDescriptor(os.eventfd(initval, int(flags)))
def read(self) -> int:
"""Decrease the value of the eventfd using eventfd_read."""
@@ -541,7 +542,7 @@ class EventFD:
raise ValueError("attempt to read from closed eventfd")
os.eventfd_write(self.fd, value)
- def fileno(self) -> int:
+ def fileno(self) -> FileDescriptor:
"""Return the underlying file descriptor."""
return self.fd
@@ -551,7 +552,7 @@ class EventFD:
try:
os.close(self.fd)
finally:
- self.fd = -1
+ self.fd = FileDescriptor(-1)
__del__ = close
@@ -610,7 +611,7 @@ def mount_setattr(
attr_set: MountAttrFlags = MountAttrFlags.NONE,
attr_clr: MountAttrFlags = MountAttrFlags.NONE,
propagation: int = 0,
- userns_fd: int = -1,
+ userns_fd: FileDescriptorLike = -1,
) -> None:
"""Python wrapper for mount_setattr(2)."""
filesystem = AtLocation(filesystem)
@@ -619,6 +620,8 @@ def mount_setattr(
flags |= MountSetattrFlags.AT_RECURSIVE
if attr_clr & MountAttrFlags.IDMAP:
raise ValueError("cannot clear the MOUNT_ATTR_IDMAP flag")
+ if not isinstance(userns_fd, int):
+ userns_fd = userns_fd.fileno()
attr = MountAttr(attr_set, attr_clr, propagation, userns_fd)
call_libc(
"mount_setattr",
@@ -764,12 +767,14 @@ class SignalFD:
sigmask: typing.Iterable[signal.Signals],
flags: SignalFDFlags = SignalFDFlags.NONE,
):
- self.fd = SignalFD.__signalfd(-1, sigmask, flags)
+ self.fd = SignalFD.__signalfd(FileDescriptor(-1), sigmask, flags)
@staticmethod
def __signalfd(
- fd: int, sigmask: typing.Iterable[signal.Signals], flags: SignalFDFlags
- ) -> int:
+ fd: FileDescriptor,
+ sigmask: typing.Iterable[signal.Signals],
+ flags: SignalFDFlags,
+ ) -> FileDescriptor:
"""Python wrapper for signalfd(2)."""
bitsperlong = 8 * ctypes.sizeof(ctypes.c_ulong)
nval = 64 // bitsperlong
@@ -778,7 +783,7 @@ class SignalFD:
sigval = int(sig) - 1
mask[sigval // bitsperlong] |= 1 << (sigval % bitsperlong)
csigmask = (ctypes.c_ulong * nval)(*mask)
- return call_libc("signalfd", fd, csigmask, int(flags))
+ return FileDescriptor(call_libc("signalfd", fd, csigmask, int(flags)))
def readv(self, count: int) -> list[SignalFDSigInfo]:
"""Read up to count signals from the signalfd."""
@@ -824,7 +829,7 @@ class SignalFD:
loop.add_reader(self.fd, self.__handle_readable, self.fd, fut)
return fut
- def fileno(self) -> int:
+ def fileno(self) -> FileDescriptor:
"""Return the underlying file descriptor."""
return self.fd
@@ -834,7 +839,7 @@ class SignalFD:
try:
os.close(self.fd)
finally:
- self.fd = -1
+ self.fd = FileDescriptor(-1)
__del__ = close