summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--linuxnamespaces/__init__.py41
-rw-r--r--tests/test_simple.py47
2 files changed, 76 insertions, 12 deletions
diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py
index a39ab1d..65dfc8d 100644
--- a/linuxnamespaces/__init__.py
+++ b/linuxnamespaces/__init__.py
@@ -585,6 +585,39 @@ def unshare_user_idmap_nohelper(
newidmaps(-1, [uidmap], [gidmap], False)
+class _AsyncFilesender:
+ bs = 65536
+
+ def __init__(self, from_fd: int, to_fd: int, count: int | None = None):
+ self.from_fd = from_fd
+ self.to_fd = to_fd
+ self.copied = 0
+ self.remain = count
+ self.loop = asyncio.get_running_loop()
+ self.fut: asyncio.Future[int] = self.loop.create_future()
+ self.loop.add_writer(self.to_fd, self.handle_write)
+
+ def handle_write(self) -> None:
+ try:
+ ret = os.sendfile(
+ self.to_fd,
+ self.from_fd,
+ None,
+ self.bs if self.remain is None else min(self.bs, self.remain),
+ )
+ except OSError as err:
+ if err.errno != errno.EAGAIN:
+ self.loop.remove_writer(self.to_fd)
+ self.fut.set_exception(err)
+ else:
+ self.copied += ret
+ if self.remain is not None:
+ self.remain -= ret
+ if ret == 0 or self.remain == 0:
+ self.loop.remove_writer(self.to_fd)
+ self.fut.set_result(self.copied)
+
+
class _AsyncSplicer:
bs = 65536
@@ -703,10 +736,10 @@ def async_copyfd(
binary. An efficient implementation is chosen depending on the file type
of file descriptors.
"""
- if (
- stat.S_ISFIFO(os.fstat(from_fd).st_mode)
- or stat.S_ISFIFO(os.fstat(to_fd).st_mode)
- ):
+ from_mode = os.fstat(from_fd).st_mode
+ if stat.S_ISREG(from_mode):
+ return _AsyncFilesender(from_fd, to_fd, count).fut
+ if stat.S_ISFIFO(from_mode) or stat.S_ISFIFO(os.fstat(to_fd).st_mode):
return _AsyncSplicer(from_fd, to_fd, count).fut
return _AsyncCopier(from_fd, to_fd, count).fut
diff --git a/tests/test_simple.py b/tests/test_simple.py
index 456e088..0c4e2b9 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -68,6 +68,17 @@ class IDAllocationTest(unittest.TestCase):
class AsnycioTest(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self) -> None:
+ self.loop = asyncio.get_running_loop()
+
+ def wait_readable(self, rfd: int) -> asyncio.Future[None]:
+ fut = self.loop.create_future()
+ def callback() -> None:
+ self.loop.remove_reader(rfd)
+ fut.set_result(None)
+ self.loop.add_reader(rfd, callback)
+ return fut
+
async def test_eventfd(self) -> None:
with linuxnamespaces.EventFD(
1, linuxnamespaces.EventFDFlags.NONBLOCK
@@ -93,7 +104,33 @@ class AsnycioTest(unittest.IsolatedAsyncioTestCase):
await set_ready()
await asyncio.wait_for(fut, 10)
- async def test_copyfd(self) -> None:
+ async def test_copyfd_file_sock(self) -> None:
+ sock1, sock2 = socket.socketpair()
+ with sock1, sock2, linuxnamespaces.FileDescriptor(
+ os.open("/etc/passwd", os.O_RDONLY)
+ ) as rfd:
+ fut = asyncio.ensure_future(
+ linuxnamespaces.async_copyfd(rfd, sock1.fileno(), 999)
+ )
+ await self.wait_readable(sock2.fileno())
+ self.assertGreater(len(sock2.recv(999)), 0)
+ self.assertTrue(fut.done())
+ self.assertGreater(await fut, 0)
+
+ async def test_copyfd_file_pipe(self) -> None:
+ rfdp, wfdp = linuxnamespaces.FileDescriptor.pipe(blocking=False)
+ with rfdp, wfdp, linuxnamespaces.FileDescriptor(
+ os.open("/etc/passwd", os.O_RDONLY)
+ ) as rfd:
+ fut = asyncio.ensure_future(
+ linuxnamespaces.async_copyfd(rfd, wfdp, 999)
+ )
+ await self.wait_readable(rfdp)
+ self.assertGreater(len(os.read(rfdp, 999)), 0)
+ self.assertTrue(fut.done())
+ self.assertGreater(await fut, 0)
+
+ async def test_copyfd_pipe_pipe(self) -> None:
rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
with wfd2, rfd2, rfd1:
@@ -104,13 +141,7 @@ class AsnycioTest(unittest.IsolatedAsyncioTestCase):
os.write(wfd1, b"hello")
await asyncio.sleep(0.000001) # Let the loop run
os.write(wfd1, b"world")
- loop = asyncio.get_running_loop()
- fut2 = loop.create_future()
- def callback() -> None:
- loop.remove_reader(rfd2)
- fut2.set_result(None)
- loop.add_reader(rfd2, callback)
- await fut2
+ await self.wait_readable(rfd2)
self.assertEqual(os.read(rfd2, 11), b"helloworld")
self.assertFalse(fut.done())
await asyncio.sleep(0.000001) # Let the loop run