From f1e580b1a14f980bf001662c74b0382e644850be Mon Sep 17 00:00:00 2001
From: Helmut Grohne
Date: Sat, 17 Jun 2023 15:10:02 +0200
Subject: add a wsgitools.scgi.asyncio module

This adds an asyncio implementation of the server side of the SCGI
protocol, because asyncore is being deprecated. Unlike the asyncore
implementation, this does not yet support sendfile.
---
 test.py                    |  33 +++++
 wsgitools/scgi/__init__.py |   3 +-
 wsgitools/scgi/asyncio.py  | 295 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 330 insertions(+), 1 deletion(-)
 create mode 100644 wsgitools/scgi/asyncio.py

diff --git a/test.py b/test.py
index 5271819..7d6a463 100755
--- a/test.py
+++ b/test.py
@@ -498,6 +498,38 @@ class ScgiForkTest(unittest.TestCase):
         self.assertTrue(data.startswith(b"Status: 200 OK\r\n"))
         self.assertTrue(data.endswith(b"\r\n\r\nnothing"))
 
+from wsgitools.scgi.asyncio import SCGIProtocolFactory
+import asyncio
+
+class ScgiAsyncioTest(unittest.TestCase):
+    def testSimple(self) -> None:
+        asyncio.run(self.asyncTestSimple())
+
+    async def asyncTestSimple(self) -> None:
+        app = applications.StaticContent(
+            "200 OK", [("Content-Type", "text/plain")], b"nothing"
+        )
+        server = await asyncio.get_running_loop().create_server(
+            SCGIProtocolFactory(app),
+            family=socket.AF_INET,
+            host="localhost",
+            port=0,
+        )
+        port = server.sockets[0].getsockname()[1]
+        reader, writer = await asyncio.open_connection("localhost", port)
+        req = {"CONTENT_LENGTH": "0", "REQUEST_METHOD": "GET"}
+        reqb = str2bytes("".join(map("%s\0%s\0".__mod__, req.items())))
+        writer.write(b"%d:%s," % (len(reqb), reqb))
+        await writer.drain()
+        data = await reader.read()
+        writer.close()
+        await writer.wait_closed()
+        server.close()
+        await server.wait_closed()
+        self.assertTrue(data.startswith(b"Status: 200 OK\r\n"))
+        self.assertTrue(data.endswith(b"\r\n\r\nnothing"))
+
+
 def alltests(case):
     return unittest.TestLoader().loadTestsFromTestCase(case)
 
@@ -514,6 +546,7 @@ fullsuite.addTest(alltests(RequestLogWSGIFilterTest))
 fullsuite.addTest(alltests(GzipWSGIFilterTest))
 fullsuite.addTest(alltests(ScgiAsynchronousTest))
 fullsuite.addTest(alltests(ScgiForkTest))
+fullsuite.addTest(alltests(ScgiAsyncioTest))
 
 if __name__ == "__main__":
     runner = unittest.TextTestRunner(verbosity=2)
diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py
index 677e3b5..38a5bef 100644
--- a/wsgitools/scgi/__init__.py
+++ b/wsgitools/scgi/__init__.py
@@ -65,6 +65,7 @@ def _convert_environ(
     multithread: bool = False,
     multiprocess: bool = False,
     run_once: bool = False,
+    enable_sendfile: bool = True,
 ) -> None:
     environ.update({
         "wsgi.version": (1, 0),
@@ -79,5 +80,5 @@ def _convert_environ(
     except KeyError:
         pass
     environ.pop("HTTP_CONTENT_LENGTH", None) # TODO: better way?
-    if have_sendfile:
+    if have_sendfile and enable_sendfile:
         environ["wsgi.file_wrapper"] = FileWrapper
diff --git a/wsgitools/scgi/asyncio.py b/wsgitools/scgi/asyncio.py
new file mode 100644
index 0000000..9f82622
--- /dev/null
+++ b/wsgitools/scgi/asyncio.py
@@ -0,0 +1,295 @@
+__all__ = []
+
+import asyncio
+import io
+import sys
+import typing
+
+from wsgitools.internal import (
+    bytes2str,
+    Environ,
+    HeaderList,
+    OptExcInfo,
+    str2bytes,
+    StartResponse,
+    WriteCallback,
+    WsgiApp,
+)
+from wsgitools.scgi import _convert_environ
+
+__all__.append("SCGIProtocol")
+class SCGIProtocol(asyncio.Protocol):
+    """Serve one SCGI request per connection using a WSGI application.
+
+    The request is parsed incrementally: self.parse refers to the parser
+    for the next part of the request. It is None once the request has been
+    read completely and the application has been called.
+    """
+
+    def __init__(
+        self,
+        wsgiapp: WsgiApp,
+        baseenviron: Environ,
+        *,
+        maxrequestsize: int = 65536,
+        maxpostsize: int = 8 << 20,
+    ):
+        """Set up the state for a single connection.
+
+        wsgiapp is the application to call. baseenviron is copied and then
+        extended with the request variables. The SCGI header may announce
+        at most maxrequestsize bytes of environment and at most maxpostsize
+        bytes of body. Connections exceeding a limit are aborted.
+        """
+        # configuration
+        self.wsgiapp = wsgiapp
+        self.maxrequestsize = maxrequestsize
+        self.maxpostsize = maxpostsize
+        self.transport: typing.Optional[asyncio.Transport] = None
+
+        # request state
+        self.inbuff = b""
+        self.parse: typing.Optional[typing.Callable[[], bool]]
+        self.parse = self.parse_reqsize
+        self.reqlen = -1
+        self.environ: Environ = baseenviron.copy()
+        self.body = io.BytesIO()
+
+        # response state
+        self.outheaders: typing.Union[bool, bytes] = False
+        # outheaders is a three-state
+        #  * False -> start_response not yet called
+        #  * bytes -> headers set by start_response, but not yet sent
+        #  * True  -> headers sent
+        self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
+        self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
+        self.writeable = True
+
+    def connection_made(self, transport: asyncio.BaseTransport) -> None:
+        """Remember the transport of the new connection."""
+        assert isinstance(transport, asyncio.Transport)
+        self.transport = transport
+
+    def data_received(self, data: bytes) -> None:
+        """Buffer data and run the parsers while they make progress."""
+        self.inbuff += data
+        # A parser returns True if the next parser may run right away and
+        # False if it needs more data or is done with the connection.
+        while self.parse is not None and self.parse():
+            pass
+
+    def parse_reqsize(self) -> bool:
+        """Parse the "%d:" part of the SCGI header."""
+        assert self.transport is not None
+        head, colon, tail = self.inbuff.partition(b':')
+        if not colon:
+            # Without a colon, the length is incomplete and more data is
+            # needed unless what we have cannot become a valid length.
+            if len(head) > 21 or (head and not head.isdigit()):
+                self.transport.abort()
+            return False
+        if not head.isdigit():
+            self.transport.abort()
+            return False  # invalid request format
+        reqlen = int(head)
+        if reqlen > self.maxrequestsize:
+            self.transport.abort()
+            return False
+        self.inbuff = tail
+        self.reqlen = reqlen
+        self.parse = self.parse_environ
+        return True
+
+    def parse_environ(self) -> bool:
+        """Parse the sequence of strings representing environ."""
+        assert self.transport is not None
+        consumed = 0
+        strings = self.inbuff[:self.reqlen].split(b'\0')
+        # The last string may be unterminated and the one before it may be
+        # a key lacking its value, so stop before the last two.
+        for i in range(0, len(strings) - 2, 2):
+            key, value = strings[i], strings[i + 1]
+            self.environ[bytes2str(key)] = bytes2str(value)
+            consumed += len(key) + len(value) + 2
+        self.inbuff = self.inbuff[consumed:]
+        self.reqlen -= consumed
+        if self.reqlen == 0:
+            self.parse = self.parse_tail
+            return True
+        if len(self.inbuff) >= self.reqlen:
+            # The header is complete, but not a sequence of pairs.
+            self.transport.abort()
+        return False
+
+    def parse_tail(self) -> bool:
+        """Parse the comma and validate the environ."""
+        assert self.transport is not None
+        if not self.inbuff:
+            return False
+        if not self.inbuff.startswith(b","):
+            self.transport.abort()
+            return False  # invalid request format
+        self.inbuff = self.inbuff[1:]
+        try:
+            self.reqlen = int(self.environ["CONTENT_LENGTH"])
+        except (KeyError, ValueError):
+            self.transport.abort()
+            return False
+        if not 0 <= self.reqlen <= self.maxpostsize:
+            self.transport.abort()
+            return False
+        self.parse = self.parse_body
+        return True
+
+    def parse_body(self) -> bool:
+        """Read the request body and run the application once complete."""
+        assert self.transport is not None
+        if len(self.inbuff) < self.reqlen:
+            self.body.write(self.inbuff)
+            self.reqlen -= len(self.inbuff)
+            self.inbuff = b""
+            # More data is needed. Returning True would make data_received
+            # call this parser again and again on the empty buffer.
+            return False
+
+        self.transport.pause_reading()
+        self.parse = None
+        self.body.write(self.inbuff[:self.reqlen])
+        self.body.seek(0)
+        self.inbuff = b""
+        self.reqlen = 0
+        _convert_environ(self.environ, enable_sendfile=False)
+        self.environ["wsgi.input"] = self.body
+        self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
+        assert self.wsgihandler is not None
+        self.wsgiiterator = iter(self.wsgihandler)
+        self.resume_writing()
+        return False
+
+    def start_response(
+        self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
+    ) -> WriteCallback:
+        """Record the response headers for the WSGI application.
+
+        The headers are sent together with the first non-empty chunk of
+        the body or when the application iterable is exhausted.
+        """
+        assert isinstance(status, str)
+        assert isinstance(headers, list)
+        if exc_info:
+            if self.outheaders is True:
+                # The headers are out already. Reraise the original
+                # exception instead of wrapping it into a new instance.
+                try:
+                    raise exc_info[1].with_traceback(exc_info[2])
+                finally:
+                    exc_info = None
+        assert self.outheaders is not True
+        self.outheaders = str2bytes(
+            "Status: %s\r\n%s\r\n" % (
+                status, "".join(map("%s: %s\r\n".__mod__, headers))
+            )
+        )
+        return self.wsgi_write
+
+    def send_headers(self) -> None:
+        """Send the recorded headers unless they were sent before."""
+        if self.outheaders is True:
+            return
+        assert self.transport is not None
+        assert isinstance(self.outheaders, bytes)
+        self.transport.write(self.outheaders)
+        self.outheaders = True
+
+    def wsgi_write(self, data: bytes) -> None:
+        """Send a chunk of the response body preceded by pending headers."""
+        assert isinstance(data, bytes)
+        assert self.transport is not None
+        if not data:
+            return
+        assert self.parse is None
+        self.send_headers()
+        self.transport.write(data)
+
+    def pause_writing(self) -> None:
+        """Stop pulling output from the application."""
+        self.writeable = False
+
+    def resume_writing(self) -> None:
+        """Pass application output to the transport while it accepts it."""
+        assert self.transport is not None
+        if self.wsgiiterator is None:
+            return  # no response in progress
+        self.writeable = True
+        while self.writeable:
+            try:
+                data = next(self.wsgiiterator)
+            except StopIteration:
+                self.wsgiiterator = None
+                self.send_headers()
+                # Reading is paused, so after a mere write_eof the peer
+                # closing the connection would go unnoticed. close still
+                # flushes buffered data before dropping the connection.
+                self.transport.close()
+                return
+            self.wsgi_write(data)
+
+    def eof_received(self) -> None:
+        """Abort if the peer stops sending before the request is complete."""
+        if self.parse is not None:
+            assert self.transport is not None
+            self.transport.abort()
+
+    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
+        """Close the application iterable once the connection is gone."""
+        assert self.transport is not None
+        self.transport.abort()
+        handler, self.wsgihandler = self.wsgihandler, None
+        self.wsgiiterator = None
+        close = getattr(handler, "close", None)
+        if close is not None:
+            close()
+
+__all__.append("SCGIProtocolFactory")
+class SCGIProtocolFactory:
+    """An asyncio.Protocol factory for the SCGI protocol.
+
+    Typical use:
+
+        await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
+    """
+    def __init__(
+        self,
+        wsgiapp: WsgiApp,
+        *,
+        error: typing.Optional[typing.TextIO] = None,
+        maxrequestsize: typing.Optional[int] = None,
+        maxpostsize: typing.Optional[int] = None,
+        config: typing.Optional[Environ] = None,
+    ):
+        """Store the settings shared by all connections.
+
+        error becomes wsgi.errors and defaults to sys.stderr. config is
+        the base environ copied for every request. Size limits left as
+        None fall back to the SCGIProtocol defaults.
+        """
+        self.wsgiapp = wsgiapp
+        self.environ = {} if config is None else config.copy()
+        self.environ["wsgi.errors"] = sys.stderr if error is None else error
+        self.kwargs: typing.Dict[str, typing.Any] = {}
+        if maxrequestsize is not None:
+            self.kwargs["maxrequestsize"] = maxrequestsize
+        if maxpostsize is not None:
+            self.kwargs["maxpostsize"] = maxpostsize
+
+    def __call__(self) -> SCGIProtocol:
+        """Create the protocol instance for a new connection."""
+        return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)
+
+    def create_server(
+        self, *args, **kwargs
+    ) -> typing.Awaitable[asyncio.base_events.Server]:
+        """Convenience wrapper around asyncio.create_server filling in the loop
+        object and the factory argument.
+        """
+        return asyncio.get_running_loop().create_server(self, *args, **kwargs)
-- 
cgit v1.2.3