From bb46ba775280198325cc8b86eb43be7cce75a399 Mon Sep 17 00:00:00 2001 From: Helmut Grohne Date: Mon, 20 May 2024 07:50:32 +0200 Subject: add function async_copyfd It is a bit like an async version of shutil.copyfileobj but for bare file descriptors and has an optimized version for pipes. --- linuxnamespaces/__init__.py | 127 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/linuxnamespaces/__init__.py b/linuxnamespaces/__init__.py index 54a6883..1a8b0ee 100644 --- a/linuxnamespaces/__init__.py +++ b/linuxnamespaces/__init__.py @@ -9,6 +9,7 @@ import asyncio import bisect import contextlib import dataclasses +import errno import os import pathlib import stat @@ -546,3 +547,129 @@ def unshare_user_idmap_nohelper( unshare(flags) pathlib.Path("/proc/self/setgroups").write_bytes(b"deny") newidmaps(-1, [uidmap], [gidmap], False) + + +class _AsyncSplicer: + bs = 65536 + + def __init__(self, from_fd: int, to_fd: int, count: int | None = None): + self.from_fd = from_fd + self.to_fd = to_fd + self.copied = 0 + self.remain = count + self.wait_read = True + self.loop = asyncio.get_running_loop() + self.fut: asyncio.Future[int] = self.loop.create_future() + self.loop.add_reader(self.from_fd, self.handle_io) + + def handle_io(self) -> None: + try: + ret = os.splice( + self.from_fd, + self.to_fd, + self.bs if self.remain is None else min(self.bs, self.remain), + flags=os.SPLICE_F_NONBLOCK, + ) + except OSError as err: + if err.errno == errno.EAGAIN: + self.wait_read = not self.wait_read + if self.wait_read: + self.loop.remove_writer(self.to_fd) + self.loop.add_reader(self.from_fd, self.handle_io) + else: + self.loop.remove_reader(self.from_fd) + self.loop.add_writer(self.to_fd, self.handle_io) + else: + self.loop.remove_reader(self.from_fd) + self.loop.remove_writer(self.to_fd) + self.fut.set_exception(err) + else: + self.copied += ret + if self.remain is not None: + self.remain -= ret + if ret == 0 or self.remain == 0: + self.loop.remove_reader(self.from_fd) + self.loop.remove_writer(self.to_fd) + self.fut.set_result(self.copied) + + +class _AsyncCopier: + bs = 65536 + + def __init__(self, from_fd: int, to_fd: int, count: int | None = None): + self.from_fd = from_fd + self.to_fd = to_fd + self.buffer = b"" + self.copied = 0 # bytes read and written + self.remain = count # remaining bytes not yet read + # eof can be an exception when a read failed and otherwise indicates + # whether a read returned 0. + self.eof: bool | OSError = False + self.loop = asyncio.get_running_loop() + self.fut: asyncio.Future[int] = self.loop.create_future() + self.loop.add_reader(self.from_fd, self.handle_readable) + + def handle_readable(self) -> None: + try: + data = os.read( + self.from_fd, + self.bs if self.remain is None else min(self.bs, self.remain), + ) + except OSError as err: + if err.errno != errno.EAGAIN: + self.loop.remove_reader(self.from_fd) + if self.buffer: + self.eof = err + else: + self.fut.set_exception(err) + else: + if data: + if self.remain is not None: + self.remain -= len(data) + self.buffer += data + if len(self.buffer) == len(data): + self.loop.add_writer(self.to_fd, self.handle_writeable) + if self.remain == 0 or len(self.buffer) >= self.bs: + self.loop.remove_reader(self.from_fd) + else: + self.eof = True + self.loop.remove_reader(self.from_fd) + + def handle_writeable(self) -> None: + try: + written = os.write(self.to_fd, self.buffer) + except OSError as err: + if err.errno != errno.EAGAIN: + self.loop.remove_writer(self.to_fd) + if isinstance(self.eof, OSError): + self.fut.set_exception(self.eof) + else: + self.loop.remove_reader(self.from_fd) + self.fut.set_exception(err) + else: + self.buffer = self.buffer[written:] + self.copied += written + if not self.buffer: + self.loop.remove_writer(self.to_fd) + if self.eof is True or self.remain == 0: + self.fut.set_result(self.copied) + elif isinstance(self.eof, OSError): + self.fut.set_exception(self.eof) + elif not self.eof and self.remain and len(self.buffer) < self.bs: + self.loop.add_reader(self.from_fd, self.handle_readable) + + +def async_copyfd( + from_fd: int, to_fd: int, count: int | None = None +) -> asyncio.Future[int]: + """Copy the given number of bytes from the first file descriptor to the + second file descriptor in an asyncio context. Both copies are performed + binary. An efficient implementation is chosen depending on the file type + of file descriptors. + """ + if ( + stat.S_ISFIFO(os.fstat(from_fd).st_mode) + or stat.S_ISFIFO(os.fstat(to_fd).st_mode) + ): + return _AsyncSplicer(from_fd, to_fd, count).fut + return _AsyncCopier(from_fd, to_fd, count).fut -- cgit v1.2.3