summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHelmut Grohne <helmut@subdivi.de>2024-05-25 10:22:21 +0200
committerHelmut Grohne <helmut@subdivi.de>2024-05-25 10:22:21 +0200
commitc5c9fe325782a790d563a0a8b1cf62a855a50d81 (patch)
tree16fb8834cd3e260615497faf03f0f78e2929c305
parent992e877614476dd40abd11a82ffedc6e261dabdf (diff)
downloadpython-linuxnamespaces-c5c9fe325782a790d563a0a8b1cf62a855a50d81.tar.gz
add a FileDescriptor type
It serves two main purposes. For one thing, it allows telling bare integers and file descriptors apart on a typing level similar to a NewType. For another it adds common methods to a file descriptor and enables closing it via a context manager.
-rw-r--r--linuxnamespaces/__init__.py1
-rw-r--r--linuxnamespaces/atlocation.py32
-rw-r--r--linuxnamespaces/filedescriptor.py75
-rw-r--r--tests/test_simple.py62
4 files changed, 123 insertions, 47 deletions
diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py
index 3302867..b50f113 100644
--- a/linuxnamespaces/__init__.py
+++ b/linuxnamespaces/__init__.py
@@ -16,6 +16,7 @@ import stat
import subprocess
import typing
+from .filedescriptor import FileDescriptor
from .atlocation import *
from .syscalls import *
diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py
index 8da5982..8a38650 100644
--- a/linuxnamespaces/atlocation.py
+++ b/linuxnamespaces/atlocation.py
@@ -9,15 +9,16 @@ code for doing so.
import enum
import errno
-import fcntl
import os
import os.path
import pathlib
import stat
import typing
+from .filedescriptor import FileDescriptor
-AT_FDCWD = -100
+
+AT_FDCWD = FileDescriptor(-100)
PathConvertible = typing.Union[str, os.PathLike]
@@ -51,7 +52,7 @@ class AtLocation:
management.
"""
- fd: int
+ fd: FileDescriptor
location: PathConvertible
flags: AtFlags
@@ -63,10 +64,10 @@ class AtLocation:
) -> "AtLocation":
"""The argument thing can be many different thing. If it is an
AtLocation, it is copied and all other arguments must be unset. If it
- is an integer, it is considered to be a file descriptor and the
- location must be unset if flags contains AT_EMPTY_PATH. flags are used
- as is except that AT_EMPTY_PATH is automatically added when given a
- file descriptor and no location.
+ is an integer (e.g. FileDescriptor), it is considered to be a file
+ descriptor and the location must be unset if flags contains
+ AT_EMPTY_PATH. flags are used as is except that AT_EMPTY_PATH is
+ automatically added when given a file descriptor and no location.
"""
if isinstance(thing, AtLocation):
if location is not None or flags != AtFlags.NONE:
@@ -78,7 +79,10 @@ class AtLocation:
if isinstance(thing, int):
if thing < 0 and thing != AT_FDCWD:
raise ValueError("fd cannot be negative")
- obj.fd = thing
+ if isinstance(thing, FileDescriptor):
+ obj.fd = thing
+ else:
+ obj.fd = FileDescriptor(thing)
if location is None:
obj.location = ""
obj.flags = flags | AtFlags.AT_EMPTY_PATH
@@ -100,7 +104,7 @@ class AtLocation:
def close(self) -> None:
"""Close the underlying file descriptor."""
if self.fd >= 0:
- os.close(self.fd)
+ self.fd.close()
self.fd = AT_FDCWD
def as_emptypath(self, inheritable: bool = True) -> "AtLocation":
@@ -109,11 +113,7 @@ class AtLocation:
all cases, the caller is responsible for closing the result object.
"""
if self.flags & AtFlags.AT_EMPTY_PATH:
- newfd = fcntl.fcntl(
- self.fd,
- fcntl.F_DUPFD if inheritable else fcntl.F_DUPFD_CLOEXEC,
- 0,
- )
+ newfd = self.fd.dup(inheritable=inheritable)
return AtLocation(newfd, flags=self.flags)
return AtLocation(
self.open(flags=os.O_PATH | (0 if inheritable else os.O_CLOEXEC))
@@ -175,7 +175,7 @@ class AtLocation:
def __truediv__(self, name: "AtLocationLike") -> "AtLocation":
return self.joinpath(name)
- def fileno(self) -> int:
+ def fileno(self) -> FileDescriptor:
"""Return the underlying file descriptor if this is an AT_EMPTY_PATH
location and raise a ValueError otherwise.
"""
@@ -186,7 +186,7 @@ class AtLocation:
return self.fd
@property
- def fd_or_none(self) -> int | None:
+ def fd_or_none(self) -> FileDescriptor | None:
"""A variant of the fd attribute that replaces AT_FDCWD with None."""
return None if self.fd == AT_FDCWD else self.fd
diff --git a/linuxnamespaces/filedescriptor.py b/linuxnamespaces/filedescriptor.py
new file mode 100644
index 0000000..4395a54
--- /dev/null
+++ b/linuxnamespaces/filedescriptor.py
@@ -0,0 +1,75 @@
+# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
+# SPDX-License-Identifier: GPL-3
+
+"""A type tag for integers that represent file descriptors."""
+
+import fcntl
+import os
+import typing
+
+
+class FileDescriptor(int):
+ """Type tag for integers that represent file descriptors. It also provides
+ a few very generic file descriptor methods.
+ """
+
+ def __enter__(self) -> "FileDescriptor":
+ """When used as a context manager, close the file descriptor on scope
+ exit.
+ """
+ return self
+
+ def __exit__(self, *args: typing.Any) -> None:
+ """When used as a context manager, close the file descriptor on scope
+ exit.
+ """
+ self.close()
+
+ def close(self) -> None:
+ """Close the file descriptor. Since int is immutable, the caller is
+ responsibe for not closing twice.
+ """
+ os.close(self)
+
+ def dup(self, inheritable: bool = True) -> "FileDescriptor":
+ """Return a duplicate of the file descriptor."""
+ if inheritable:
+ return FileDescriptor(os.dup(self))
+ return FileDescriptor(fcntl.fcntl(self, fcntl.F_DUPFD_CLOEXEC, 0))
+
+ def dup2(self, fd2: int, inheritable: bool = True) -> "FileDescriptor":
+ """Duplicate the file to the given file descriptor number."""
+ return FileDescriptor(os.dup2(self, fd2, inheritable))
+
+ def fileno(self) -> int:
+ """Return self such that it satisfies the HasFileno protocol."""
+ return self
+
+ def get_blocking(self) -> bool:
+ """Get the blocking mode of the file descriptor."""
+ return os.get_blocking(self)
+
+ def get_inheritable(self) -> bool:
+ """Get the close-on-exec flag of the file descriptor."""
+ return os.get_inheritable(self)
+
+ @classmethod
+ def pipe(
+ cls, blocking: bool = True, inheritable: bool = True
+ ) -> tuple["FileDescriptor", "FileDescriptor"]:
+ """Create a pipe with flags set atomically. This actually corresponds
+ to the pipe2 syscall, but skipping flags is equivalent to calling pipe.
+ """
+ rfd, wfd = os.pipe2(
+ (0 if blocking else os.O_NONBLOCK)
+ | (0 if inheritable else os.O_CLOEXEC),
+ )
+ return (cls(rfd), cls(wfd))
+
+ def set_blocking(self, blocking: bool) -> None:
+ """Set the blocking mode of the file descriptor."""
+ os.set_blocking(self, blocking)
+
+ def set_inheritable(self, inheritable: bool) -> None:
+ """Set the close-on-exec flag of the file descriptor."""
+ os.set_inheritable(self, inheritable)
diff --git a/tests/test_simple.py b/tests/test_simple.py
index b3331a3..212f414 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -76,40 +76,40 @@ class AsnycioTest(unittest.IsolatedAsyncioTestCase):
await asyncio.wait_for(fut, 10)
async def test_copyfd(self) -> None:
- rfd1, wfd1 = os.pipe2(os.O_NONBLOCK)
- rfd2, wfd2 = os.pipe2(os.O_NONBLOCK)
- fut = asyncio.ensure_future(linuxnamespaces.async_copyfd(rfd1, wfd2))
- os.write(wfd1, b"hello")
- await asyncio.sleep(0.000001) # Let the loop run
- os.write(wfd1, b"world")
- loop = asyncio.get_running_loop()
- fut2 = loop.create_future()
- def callback() -> None:
- loop.remove_reader(rfd2)
- fut2.set_result(None)
- loop.add_reader(rfd2, callback)
- await fut2
- self.assertEqual(os.read(rfd2, 11), b"helloworld")
- self.assertFalse(fut.done())
- os.close(wfd1)
- await asyncio.sleep(0.000001) # Let the loop run
- self.assertTrue(fut.done())
- os.close(rfd1)
- os.close(rfd2)
- os.close(wfd2)
+ rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
+ rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
+ with wfd2, rfd2, rfd1:
+ with wfd1:
+ fut = asyncio.ensure_future(
+ linuxnamespaces.async_copyfd(rfd1, wfd2)
+ )
+ os.write(wfd1, b"hello")
+ await asyncio.sleep(0.000001) # Let the loop run
+ os.write(wfd1, b"world")
+ loop = asyncio.get_running_loop()
+ fut2 = loop.create_future()
+ def callback() -> None:
+ loop.remove_reader(rfd2)
+ fut2.set_result(None)
+ loop.add_reader(rfd2, callback)
+ await fut2
+ self.assertEqual(os.read(rfd2, 11), b"helloworld")
+ self.assertFalse(fut.done())
+ await asyncio.sleep(0.000001) # Let the loop run
+ self.assertTrue(fut.done())
self.assertEqual(await fut, 10)
async def test_copyfd_epipe(self) -> None:
- rfd1, wfd1 = os.pipe2(os.O_NONBLOCK)
- rfd2, wfd2 = os.pipe2(os.O_NONBLOCK)
- fut = asyncio.ensure_future(linuxnamespaces.async_copyfd(rfd1, wfd2))
- os.close(rfd2)
- os.write(wfd1, b"hello")
- await asyncio.sleep(0.000001) # Let the loop run
- self.assertTrue(fut.done())
- os.close(rfd1)
- os.close(wfd1)
- os.close(wfd2)
+ rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
+ rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
+ with wfd2, wfd1, rfd1:
+ with rfd2:
+ fut = asyncio.ensure_future(
+ linuxnamespaces.async_copyfd(rfd1, wfd2)
+ )
+ os.write(wfd1, b"hello")
+ await asyncio.sleep(0.000001) # Let the loop run
+ self.assertTrue(fut.done())
exc = fut.exception()
self.assertIsInstance(exc, OSError)
self.assertEqual(exc.errno, errno.EPIPE)