From c5c9fe325782a790d563a0a8b1cf62a855a50d81 Mon Sep 17 00:00:00 2001 From: Helmut Grohne Date: Sat, 25 May 2024 10:22:21 +0200 Subject: add a FileDescriptor type It serves two main purposes. For one thing, it allows telling bare integers and file descriptors apart on a typing level similar to a NewType. For another it adds common methods to a file descriptor and enables closing it via a context manager. --- linuxnamespaces/__init__.py | 1 + linuxnamespaces/atlocation.py | 32 ++++++++--------- linuxnamespaces/filedescriptor.py | 75 +++++++++++++++++++++++++++++++++++++++ tests/test_simple.py | 62 ++++++++++++++++---------------- 4 files changed, 123 insertions(+), 47 deletions(-) create mode 100644 linuxnamespaces/filedescriptor.py diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py index 3302867..b50f113 100644 --- a/linuxnamespaces/__init__.py +++ b/linuxnamespaces/__init__.py @@ -16,6 +16,7 @@ import stat import subprocess import typing +from .filedescriptor import FileDescriptor from .atlocation import * from .syscalls import * diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py index 8da5982..8a38650 100644 --- a/linuxnamespaces/atlocation.py +++ b/linuxnamespaces/atlocation.py @@ -9,15 +9,16 @@ code for doing so. import enum import errno -import fcntl import os import os.path import pathlib import stat import typing +from .filedescriptor import FileDescriptor -AT_FDCWD = -100 + +AT_FDCWD = FileDescriptor(-100) PathConvertible = typing.Union[str, os.PathLike] @@ -51,7 +52,7 @@ class AtLocation: management. """ - fd: int + fd: FileDescriptor location: PathConvertible flags: AtFlags @@ -63,10 +64,10 @@ class AtLocation: ) -> "AtLocation": """The argument thing can be many different thing. If it is an AtLocation, it is copied and all other arguments must be unset. If it - is an integer, it is considered to be a file descriptor and the - location must be unset if flags contains AT_EMPTY_PATH. flags are used - as is except that AT_EMPTY_PATH is automatically added when given a - file descriptor and no location. + is an integer (e.g. FileDescriptor), it is considered to be a file + descriptor and the location must be unset if flags contains + AT_EMPTY_PATH. flags are used as is except that AT_EMPTY_PATH is + automatically added when given a file descriptor and no location. """ if isinstance(thing, AtLocation): if location is not None or flags != AtFlags.NONE: @@ -78,7 +79,10 @@ class AtLocation: if isinstance(thing, int): if thing < 0 and thing != AT_FDCWD: raise ValueError("fd cannot be negative") - obj.fd = thing + if isinstance(thing, FileDescriptor): + obj.fd = thing + else: + obj.fd = FileDescriptor(thing) if location is None: obj.location = "" obj.flags = flags | AtFlags.AT_EMPTY_PATH @@ -100,7 +104,7 @@ class AtLocation: def close(self) -> None: """Close the underlying file descriptor.""" if self.fd >= 0: - os.close(self.fd) + self.fd.close() self.fd = AT_FDCWD def as_emptypath(self, inheritable: bool = True) -> "AtLocation": @@ -109,11 +113,7 @@ class AtLocation: all cases, the caller is responsible for closing the result object. """ if self.flags & AtFlags.AT_EMPTY_PATH: - newfd = fcntl.fcntl( - self.fd, - fcntl.F_DUPFD if inheritable else fcntl.F_DUPFD_CLOEXEC, - 0, - ) + newfd = self.fd.dup(inheritable=inheritable) return AtLocation(newfd, flags=self.flags) return AtLocation( self.open(flags=os.O_PATH | (0 if inheritable else os.O_CLOEXEC)) @@ -175,7 +175,7 @@ class AtLocation: def __truediv__(self, name: "AtLocationLike") -> "AtLocation": return self.joinpath(name) - def fileno(self) -> int: + def fileno(self) -> FileDescriptor: """Return the underlying file descriptor if this is an AT_EMPTY_PATH location and raise a ValueError otherwise. """ @@ -186,7 +186,7 @@ class AtLocation: return self.fd @property - def fd_or_none(self) -> int | None: + def fd_or_none(self) -> FileDescriptor | None: """A variant of the fd attribute that replaces AT_FDCWD with None.""" return None if self.fd == AT_FDCWD else self.fd diff --git a/linuxnamespaces/filedescriptor.py b/linuxnamespaces/filedescriptor.py new file mode 100644 index 0000000..4395a54 --- /dev/null +++ b/linuxnamespaces/filedescriptor.py @@ -0,0 +1,75 @@ +# Copyright 2024 Helmut Grohne +# SPDX-License-Identifier: GPL-3 + +"""A type tag for integers that represent file descriptors.""" + +import fcntl +import os +import typing + + +class FileDescriptor(int): + """Type tag for integers that represent file descriptors. It also provides + a few very generic file descriptor methods. + """ + + def __enter__(self) -> "FileDescriptor": + """When used as a context manager, close the file descriptor on scope + exit. + """ + return self + + def __exit__(self, *args: typing.Any) -> None: + """When used as a context manager, close the file descriptor on scope + exit. + """ + self.close() + + def close(self) -> None: + """Close the file descriptor. Since int is immutable, the caller is + responsibe for not closing twice. + """ + os.close(self) + + def dup(self, inheritable: bool = True) -> "FileDescriptor": + """Return a duplicate of the file descriptor.""" + if inheritable: + return FileDescriptor(os.dup(self)) + return FileDescriptor(fcntl.fcntl(self, fcntl.F_DUPFD_CLOEXEC, 0)) + + def dup2(self, fd2: int, inheritable: bool = True) -> "FileDescriptor": + """Duplicate the file to the given file descriptor number.""" + return FileDescriptor(os.dup2(self, fd2, inheritable)) + + def fileno(self) -> int: + """Return self such that it satisfies the HasFileno protocol.""" + return self + + def get_blocking(self) -> bool: + """Get the blocking mode of the file descriptor.""" + return os.get_blocking(self) + + def get_inheritable(self) -> bool: + """Get the close-on-exec flag of the file descriptor.""" + return os.get_inheritable(self) + + @classmethod + def pipe( + cls, blocking: bool = True, inheritable: bool = True + ) -> tuple["FileDescriptor", "FileDescriptor"]: + """Create a pipe with flags set atomically. This actually corresponds + to the pipe2 syscall, but skipping flags is equivalent to calling pipe. + """ + rfd, wfd = os.pipe2( + (0 if blocking else os.O_NONBLOCK) + | (0 if inheritable else os.O_CLOEXEC), + ) + return (cls(rfd), cls(wfd)) + + def set_blocking(self, blocking: bool) -> None: + """Set the blocking mode of the file descriptor.""" + os.set_blocking(self, blocking) + + def set_inheritable(self, inheritable: bool) -> None: + """Set the close-on-exec flag of the file descriptor.""" + os.set_inheritable(self, inheritable) diff --git a/tests/test_simple.py b/tests/test_simple.py index b3331a3..212f414 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -76,40 +76,40 @@ class AsnycioTest(unittest.IsolatedAsyncioTestCase): await asyncio.wait_for(fut, 10) async def test_copyfd(self) -> None: - rfd1, wfd1 = os.pipe2(os.O_NONBLOCK) - rfd2, wfd2 = os.pipe2(os.O_NONBLOCK) - fut = asyncio.ensure_future(linuxnamespaces.async_copyfd(rfd1, wfd2)) - os.write(wfd1, b"hello") - await asyncio.sleep(0.000001) # Let the loop run - os.write(wfd1, b"world") - loop = asyncio.get_running_loop() - fut2 = loop.create_future() - def callback() -> None: - loop.remove_reader(rfd2) - fut2.set_result(None) - loop.add_reader(rfd2, callback) - await fut2 - self.assertEqual(os.read(rfd2, 11), b"helloworld") - self.assertFalse(fut.done()) - os.close(wfd1) - await asyncio.sleep(0.000001) # Let the loop run - self.assertTrue(fut.done()) - os.close(rfd1) - os.close(rfd2) - os.close(wfd2) + rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False) + rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False) + with wfd2, rfd2, rfd1: + with wfd1: + fut = asyncio.ensure_future( + linuxnamespaces.async_copyfd(rfd1, wfd2) + ) + os.write(wfd1, b"hello") + await asyncio.sleep(0.000001) # Let the loop run + os.write(wfd1, b"world") + loop = asyncio.get_running_loop() + fut2 = loop.create_future() + def callback() -> None: + loop.remove_reader(rfd2) + fut2.set_result(None) + loop.add_reader(rfd2, callback) + await fut2 + self.assertEqual(os.read(rfd2, 11), b"helloworld") + self.assertFalse(fut.done()) + await asyncio.sleep(0.000001) # Let the loop run + self.assertTrue(fut.done()) self.assertEqual(await fut, 10) async def test_copyfd_epipe(self) -> None: - rfd1, wfd1 = os.pipe2(os.O_NONBLOCK) - rfd2, wfd2 = os.pipe2(os.O_NONBLOCK) - fut = asyncio.ensure_future(linuxnamespaces.async_copyfd(rfd1, wfd2)) - os.close(rfd2) - os.write(wfd1, b"hello") - await asyncio.sleep(0.000001) # Let the loop run - self.assertTrue(fut.done()) - os.close(rfd1) - os.close(wfd1) - os.close(wfd2) + rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False) + rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False) + with wfd2, wfd1, rfd1: + with rfd2: + fut = asyncio.ensure_future( + linuxnamespaces.async_copyfd(rfd1, wfd2) + ) + os.write(wfd1, b"hello") + await asyncio.sleep(0.000001) # Let the loop run + self.assertTrue(fut.done()) exc = fut.exception() self.assertIsInstance(exc, OSError) self.assertEqual(exc.errno, errno.EPIPE) -- cgit v1.2.3