diff options
Diffstat (limited to 'linuxnamespaces')
-rw-r--r-- | linuxnamespaces/__init__.py | 31 | ||||
-rw-r--r-- | linuxnamespaces/atlocation.py | 21 | ||||
-rw-r--r-- | linuxnamespaces/filedescriptor.py | 25 | ||||
-rw-r--r-- | linuxnamespaces/syscalls.py | 25 |
4 files changed, 73 insertions, 29 deletions
diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py index cd30498..ce37150 100644 --- a/linuxnamespaces/__init__.py +++ b/linuxnamespaces/__init__.py @@ -645,7 +645,12 @@ def unshare_user_idmap_nohelper( class _AsyncFilesender: bs = 65536 - def __init__(self, from_fd: int, to_fd: int, count: int | None = None): + def __init__( + self, + from_fd: FileDescriptor, + to_fd: FileDescriptor, + count: int | None = None, + ): self.from_fd = from_fd self.to_fd = to_fd self.copied = 0 @@ -678,7 +683,12 @@ class _AsyncFilesender: class _AsyncSplicer: bs = 65536 - def __init__(self, from_fd: int, to_fd: int, count: int | None = None): + def __init__( + self, + from_fd: FileDescriptor, + to_fd: FileDescriptor, + count: int | None = None, + ): self.from_fd = from_fd self.to_fd = to_fd self.copied = 0 @@ -722,7 +732,12 @@ class _AsyncSplicer: class _AsyncCopier: bs = 65536 - def __init__(self, from_fd: int, to_fd: int, count: int | None = None): + def __init__( + self, + from_fd: FileDescriptor, + to_fd: FileDescriptor, + count: int | None = None, + ): self.from_fd = from_fd self.to_fd = to_fd self.buffer = b"" @@ -786,13 +801,15 @@ class _AsyncCopier: def async_copyfd( - from_fd: int, to_fd: int, count: int | None = None + from_fd: FileDescriptor, to_fd: FileDescriptor, count: int | None = None ) -> asyncio.Future[int]: """Copy the given number of bytes from the first file descriptor to the second file descriptor in an asyncio context. Both copies are performed binary. An efficient implementation is chosen depending on the file type of file descriptors. """ + from_fd = FileDescriptor(from_fd) + to_fd = FileDescriptor(to_fd) from_mode = os.fstat(from_fd).st_mode if stat.S_ISREG(from_mode): return _AsyncFilesender(from_fd, to_fd, count).fut @@ -802,7 +819,7 @@ def async_copyfd( class _AsyncPidfdWaiter: - def __init__(self, pidfd: int, flags: int): + def __init__(self, pidfd: FileDescriptor, flags: int): self.pidfd = pidfd self.flags = flags self.loop = asyncio.get_running_loop() @@ -827,12 +844,12 @@ class _AsyncPidfdWaiter: def async_waitpidfd( - pidfd: int, flags: int + pidfd: FileDescriptorLike, flags: int ) -> asyncio.Future[os.waitid_result | None]: """Asynchronously wait for a process represented as a pidfd. This is an async variant of waitid(P_PIDFD, pidfd, flags). """ - return _AsyncPidfdWaiter(pidfd, flags).fut + return _AsyncPidfdWaiter(FileDescriptor(pidfd), flags).fut def enable_loopback_if() -> None: diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py index 0003c35..d30e88f 100644 --- a/linuxnamespaces/atlocation.py +++ b/linuxnamespaces/atlocation.py @@ -15,7 +15,7 @@ import pathlib import stat import typing -from .filedescriptor import FileDescriptor +from .filedescriptor import FileDescriptor, FileDescriptorLike, HasFileno AT_FDCWD = FileDescriptor(-100) @@ -58,7 +58,7 @@ class AtLocation: def __new__( cls, - thing: typing.Union["AtLocation", int, PathConvertible], + thing: typing.Union["AtLocation", FileDescriptorLike, PathConvertible], location: PathConvertible | None = None, flags: AtFlags = AtFlags.NONE, ) -> "AtLocation": @@ -76,13 +76,14 @@ class AtLocation: ) return thing # Don't copy. obj = super(AtLocation, cls).__new__(cls) - if isinstance(thing, int): + if not isinstance(thing, FileDescriptor): + if isinstance(thing, int) or isinstance(thing, HasFileno): + thing = FileDescriptor(thing) + if isinstance(thing, FileDescriptor): if thing < 0 and thing != AT_FDCWD: raise ValueError("fd cannot be negative") if isinstance(thing, FileDescriptor): obj.fd = thing - else: - obj.fd = FileDescriptor(thing) if location is None: obj.location = "" obj.flags = flags | AtFlags.AT_EMPTY_PATH @@ -148,7 +149,7 @@ class AtLocation: them with a slash as separator. The returned AtLocation borrows its fd if any. """ - if isinstance(other, int): + if isinstance(other, int) or isinstance(other, HasFileno): # A an fd is considered an absolute AT_EMPTY_PATH path. return AtLocation(other) non_empty_flags = self.flags & ~AtFlags.AT_EMPTY_PATH @@ -422,7 +423,7 @@ class AtLocation: assert self.location os.mknod(self.location, mode, device, dir_fd=self.fd_or_none) - def open(self, flags: int, mode: int = 0o777) -> int: + def open(self, flags: int, mode: int = 0o777) -> FileDescriptor: """Wrapper for os.open supplying path and dir_fd.""" if self.flags == AtFlags.AT_SYMLINK_NOFOLLOW: flags |= os.O_NOFOLLOW @@ -431,7 +432,9 @@ class AtLocation: "opening an AtLocation only supports flag AT_SYMLINK_NOFOLLOW" ) assert self.location - return os.open(self.location, flags, mode, dir_fd=self.fd_or_none) + return FileDescriptor( + os.open(self.location, flags, mode, dir_fd=self.fd_or_none) + ) def readlink(self) -> str: """Wrapper for os.readlink supplying path and dir_fd.""" @@ -595,4 +598,4 @@ class AtLocation: return f"{cn}({self.fd}, flags={self.flags!r})" -AtLocationLike = typing.Union[AtLocation, int, PathConvertible] +AtLocationLike = typing.Union[AtLocation, FileDescriptorLike, PathConvertible] diff --git a/linuxnamespaces/filedescriptor.py b/linuxnamespaces/filedescriptor.py index 2cf8442..d11d4d6 100644 --- a/linuxnamespaces/filedescriptor.py +++ b/linuxnamespaces/filedescriptor.py @@ -8,11 +8,28 @@ import os import typing +@typing.runtime_checkable +class HasFileno(typing.Protocol): + def fileno(self) -> int: + ... + + +FileDescriptorLike = int | HasFileno + + class FileDescriptor(int): """Type tag for integers that represent file descriptors. It also provides a few very generic file descriptor methods. """ + def __new__(cls, value: FileDescriptorLike) -> typing.Self: + """Construct a FileDescriptor from an int or HasFileno.""" + if isinstance(value, cls): + return value # No need to copy, it's immutable. + if not isinstance(value, int): + value = value.fileno() + return super(FileDescriptor, cls).__new__(cls, value) + def __enter__(self) -> "FileDescriptor": """When used as a context manager, close the file descriptor on scope exit. @@ -37,16 +54,18 @@ class FileDescriptor(int): return FileDescriptor(os.dup(self)) return FileDescriptor(fcntl.fcntl(self, fcntl.F_DUPFD_CLOEXEC, 0)) - def dup2(self, fd2: int, inheritable: bool = True) -> "FileDescriptor": + def dup2( + self, fd2: FileDescriptorLike, inheritable: bool = True + ) -> "FileDescriptor": """Duplicate the file to the given file descriptor number.""" - return FileDescriptor(os.dup2(self, fd2, inheritable)) + return FileDescriptor(os.dup2(self, FileDescriptor(fd2), inheritable)) @classmethod def pidfd_open(cls, pid: int, flags: int = 0) -> typing.Self: """Convenience wrapper for os.pidfd_open.""" return cls(os.pidfd_open(pid, flags)) - def fileno(self) -> int: + def fileno(self) -> FileDescriptor: """Return self such that it satisfies the HasFileno protocol.""" return self diff --git a/linuxnamespaces/syscalls.py b/linuxnamespaces/syscalls.py index f6af348..be0a5f7 100644 --- a/linuxnamespaces/syscalls.py +++ b/linuxnamespaces/syscalls.py @@ -15,6 +15,7 @@ import os import signal import typing +from .filedescriptor import FileDescriptor, FileDescriptorLike from .atlocation import AtFlags, AtLocation, AtLocationLike, PathConvertible @@ -498,7 +499,7 @@ class EventFD: ) -> None: if flags & ~EventFDFlags.ALL_FLAGS: raise ValueError("invalid flags for eventfd") - self.fd = os.eventfd(initval, int(flags)) + self.fd = FileDescriptor(os.eventfd(initval, int(flags))) def read(self) -> int: """Decrease the value of the eventfd using eventfd_read.""" @@ -541,7 +542,7 @@ class EventFD: raise ValueError("attempt to read from closed eventfd") os.eventfd_write(self.fd, value) - def fileno(self) -> int: + def fileno(self) -> FileDescriptor: """Return the underlying file descriptor.""" return self.fd @@ -551,7 +552,7 @@ class EventFD: try: os.close(self.fd) finally: - self.fd = -1 + self.fd = FileDescriptor(-1) __del__ = close @@ -610,7 +611,7 @@ def mount_setattr( attr_set: MountAttrFlags = MountAttrFlags.NONE, attr_clr: MountAttrFlags = MountAttrFlags.NONE, propagation: int = 0, - userns_fd: int = -1, + userns_fd: FileDescriptorLike = -1, ) -> None: """Python wrapper for mount_setattr(2).""" filesystem = AtLocation(filesystem) @@ -619,6 +620,8 @@ def mount_setattr( flags |= MountSetattrFlags.AT_RECURSIVE if attr_clr & MountAttrFlags.IDMAP: raise ValueError("cannot clear the MOUNT_ATTR_IDMAP flag") + if not isinstance(userns_fd, int): + userns_fd = userns_fd.fileno() attr = MountAttr(attr_set, attr_clr, propagation, userns_fd) call_libc( "mount_setattr", @@ -764,12 +767,14 @@ class SignalFD: sigmask: typing.Iterable[signal.Signals], flags: SignalFDFlags = SignalFDFlags.NONE, ): - self.fd = SignalFD.__signalfd(-1, sigmask, flags) + self.fd = SignalFD.__signalfd(FileDescriptor(-1), sigmask, flags) @staticmethod def __signalfd( - fd: int, sigmask: typing.Iterable[signal.Signals], flags: SignalFDFlags - ) -> int: + fd: FileDescriptor, + sigmask: typing.Iterable[signal.Signals], + flags: SignalFDFlags, + ) -> FileDescriptor: """Python wrapper for signalfd(2).""" bitsperlong = 8 * ctypes.sizeof(ctypes.c_ulong) nval = 64 // bitsperlong @@ -778,7 +783,7 @@ class SignalFD: sigval = int(sig) - 1 mask[sigval // bitsperlong] |= 1 << (sigval % bitsperlong) csigmask = (ctypes.c_ulong * nval)(*mask) - return call_libc("signalfd", fd, csigmask, int(flags)) + return FileDescriptor(call_libc("signalfd", fd, csigmask, int(flags))) def readv(self, count: int) -> list[SignalFDSigInfo]: """Read up to count signals from the signalfd.""" @@ -824,7 +829,7 @@ class SignalFD: loop.add_reader(self.fd, self.__handle_readable, self.fd, fut) return fut - def fileno(self) -> int: + def fileno(self) -> FileDescriptor: """Return the underlying file descriptor.""" return self.fd @@ -834,7 +839,7 @@ class SignalFD: try: os.close(self.fd) finally: - self.fd = -1 + self.fd = FileDescriptor(-1) __del__ = close |