summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHelmut Grohne <helmut@subdivi.de>2023-06-17 15:10:02 +0200
committerHelmut Grohne <helmut@subdivi.de>2023-06-18 23:18:50 +0200
commitf1e580b1a14f980bf001662c74b0382e644850be (patch)
tree44706914b7b1e864307be6d7e36f503c6e8b5946
parenta41066b413489b407b9d99174af697563ad680b9 (diff)
downloadwsgitools-f1e580b1a14f980bf001662c74b0382e644850be.tar.gz
add a wsgitools.scgi.asyncio module
This adds an asyncio implementation of the server side of the SCGI protocol, because asyncore is being deprecated. Unlike the asyncore implementation, this does not yet support sendfile.
-rwxr-xr-xtest.py31
-rw-r--r--wsgitools/scgi/__init__.py3
-rw-r--r--wsgitools/scgi/asyncio.py242
3 files changed, 275 insertions, 1 deletions
diff --git a/test.py b/test.py
index 5271819..7d6a463 100755
--- a/test.py
+++ b/test.py
@@ -498,6 +498,36 @@ class ScgiForkTest(unittest.TestCase):
self.assertTrue(data.startswith(b"Status: 200 OK\r\n"))
self.assertTrue(data.endswith(b"\r\n\r\nnothing"))
+from wsgitools.scgi.asyncio import SCGIProtocolFactory
+import asyncio
+
+class ScgiAsyncioTest(unittest.TestCase):
+ def testSimple(self) -> None:
+ asyncio.get_event_loop().run_until_complete(self.asyncTestSimple())
+
+ async def asyncTestSimple(self) -> None:
+ app = applications.StaticContent(
+ "200 OK", [("Content-Type", "text/plain")], b"nothing"
+ )
+ server = await asyncio.get_running_loop().create_server(
+ SCGIProtocolFactory(app),
+ family=socket.AF_INET,
+ host="localhost",
+ port=0,
+ )
+ port = server.sockets[0].getsockname()[1]
+ reader, writer = await asyncio.open_connection("localhost", port)
+ req = {"CONTENT_LENGTH": "0", "REQUEST_METHOD": "GET"}
+ reqb = str2bytes("".join(map("%s\0%s\0".__mod__, req.items())))
+ writer.write(b"%d:%s," % (len(reqb), reqb))
+ await writer.drain()
+ data = await reader.read()
+ writer.close()
+ await writer.wait_closed()
+ self.assertTrue(data.startswith(b"Status: 200 OK\r\n"))
+ self.assertTrue(data.endswith(b"\r\n\r\nnothing"))
+
+
def alltests(case):
return unittest.TestLoader().loadTestsFromTestCase(case)
@@ -514,6 +544,7 @@ fullsuite.addTest(alltests(RequestLogWSGIFilterTest))
fullsuite.addTest(alltests(GzipWSGIFilterTest))
fullsuite.addTest(alltests(ScgiAsynchronousTest))
fullsuite.addTest(alltests(ScgiForkTest))
+fullsuite.addTest(alltests(ScgiAsyncioTest))
if __name__ == "__main__":
runner = unittest.TextTestRunner(verbosity=2)
diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py
index 677e3b5..38a5bef 100644
--- a/wsgitools/scgi/__init__.py
+++ b/wsgitools/scgi/__init__.py
@@ -65,6 +65,7 @@ def _convert_environ(
multithread: bool = False,
multiprocess: bool = False,
run_once: bool = False,
+ enable_sendfile: bool = True,
) -> None:
environ.update({
"wsgi.version": (1, 0),
@@ -79,5 +80,5 @@ def _convert_environ(
except KeyError:
pass
environ.pop("HTTP_CONTENT_LENGTH", None) # TODO: better way?
- if have_sendfile:
+ if have_sendfile and enable_sendfile:
environ["wsgi.file_wrapper"] = FileWrapper
diff --git a/wsgitools/scgi/asyncio.py b/wsgitools/scgi/asyncio.py
new file mode 100644
index 0000000..9f82622
--- /dev/null
+++ b/wsgitools/scgi/asyncio.py
@@ -0,0 +1,242 @@
+__all__ = []
+
+import asyncio
+import io
+import sys
+import typing
+
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ str2bytes,
+ StartResponse,
+ WriteCallback,
+ WsgiApp,
+)
+from wsgitools.scgi import _convert_environ
+
+__all__.append("SCGIProtocol")
+class SCGIProtocol(asyncio.Protocol):
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ baseenviron: Environ,
+ *,
+ maxrequestsize: int = 65536,
+ maxpostsize: int = 8 << 20,
+ ):
+ # configuration
+ self.wsgiapp = wsgiapp
+ self.maxrequestsize = maxrequestsize
+ self.maxpostsize = maxpostsize
+ self.transport: typing.Optional[asyncio.Transport] = None
+
+ # request state
+ self.inbuff = b""
+ self.parse: typing.Optional[typing.Callable[[], bool]]
+ self.parse = self.parse_reqsize
+ self.reqlen = -1
+ self.environ: Environ = baseenviron.copy()
+ self.body = io.BytesIO()
+
+ # response state
+ self.outheaders: typing.Union[bool, bytes] = False
+ # outheaders is a three-state
+ # * False -> start_response not yet called
+ # * bytes -> headers set by start_response, but not yet sent
+ # * True -> headers sent
+ self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
+ self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
+ self.writeable = True
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ assert isinstance(transport, asyncio.Transport)
+ self.transport = transport
+
+ def data_received(self, data: bytes) -> None:
+ self.inbuff += data
+ assert self.parse is not None
+ while self.parse():
+ pass
+
+ def parse_reqsize(self) -> bool:
+ """Parse the "%d:" part of the SCGI header."""
+ assert self.transport is not None
+ parts = self.inbuff.split(b':', 1)
+ if len(parts) < 2 and len(self.inbuff) > 21:
+ self.transport.abort()
+ return False # request size implausibly large
+ try:
+ reqlen = int(parts[0])
+ except ValueError:
+ self.transport.abort()
+ return False # invalid request format
+ self.inbuff = parts[1]
+ if reqlen > self.maxrequestsize:
+ self.transport.abort()
+ return False
+ self.reqlen = reqlen
+ self.parse = self.parse_environ
+ return True
+
+ def parse_environ(self) -> bool:
+ """Parse the sequence of strings representing environ."""
+ consumed = 0
+ strings = self.inbuff[:self.reqlen].split(b'\0')
+ while len(strings) > 2:
+ key = strings.pop(0)
+ value = strings.pop(0)
+ self.environ[bytes2str(key)] = bytes2str(value)
+ consumed += len(key) + len(value) + 2
+ self.inbuff = self.inbuff[consumed:]
+ self.reqlen -= consumed
+ if self.reqlen != 0:
+ return False
+ self.parse = self.parse_tail
+ return True
+
+ def parse_tail(self) -> bool:
+ """Parse the comma and validate the environ."""
+ assert self.transport is not None
+ if not self.inbuff:
+ return False
+ if not self.inbuff.startswith(b","):
+ self.transport.abort()
+ return False # invalid request format
+ self.inbuff = self.inbuff[1:]
+ try:
+ self.reqlen = int(self.environ["CONTENT_LENGTH"])
+ except (KeyError, ValueError):
+ self.transport.abort()
+ return False
+ if self.reqlen > self.maxpostsize:
+ self.transport.abort()
+ return False
+ self.parse = self.parse_body
+ return True
+
+ def parse_body(self) -> bool:
+ """Read the request body."""
+ assert self.transport is not None
+ if len(self.inbuff) < self.reqlen:
+ self.body.write(self.inbuff)
+ self.reqlen -= len(self.inbuff)
+ self.inbuff = b""
+ return True
+
+ self.transport.pause_reading()
+ self.parse = None
+ self.body.write(self.inbuff[:self.reqlen])
+ self.body.seek(0)
+ self.inbuff = b""
+ self.reqlen = 0
+ _convert_environ(self.environ, enable_sendfile=False)
+ self.environ["wsgi.input"] = self.body
+ self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
+ assert self.wsgihandler is not None
+ self.wsgiiterator = iter(self.wsgihandler)
+ self.resume_writing()
+ return False
+
+ def start_response(
+ self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
+ ) -> WriteCallback:
+ assert isinstance(status, str)
+ assert isinstance(headers, list)
+ if exc_info:
+ if self.outheaders is True:
+ try:
+ raise exc_info[0](exc_info[1]).with_traceback(exc_info[2])
+ finally:
+ exc_info = None
+ assert self.outheaders is not True
+ self.outheaders = str2bytes(
+ "Status: %s\r\n%s\r\n" % (
+ status, "".join(map("%s: %s\r\n".__mod__, headers))
+ )
+ )
+ return self.wsgi_write
+
+ def send_headers(self) -> None:
+ if self.outheaders is True:
+ return
+ assert self.transport is not None
+ assert isinstance(self.outheaders, bytes)
+ self.transport.write(self.outheaders)
+ self.outheaders = True
+
+ def wsgi_write(self, data: bytes) -> None:
+ assert isinstance(data, bytes)
+ assert self.transport is not None
+ if not data:
+ return
+ assert self.parse is None
+ self.send_headers()
+ self.transport.write(data)
+
+ def pause_writing(self) -> None:
+ self.writeable = False
+
+ def resume_writing(self) -> None:
+ assert self.transport is not None
+ assert self.wsgiiterator is not None
+ self.writeable = True
+ while self.writeable:
+ try:
+ data = next(self.wsgiiterator)
+ except StopIteration:
+ self.send_headers()
+ self.transport.write_eof()
+ return
+ self.wsgi_write(data)
+
+ def eof_received(self) -> None:
+ if self.parse is not None:
+ assert self.transport is not None
+ self.transport.abort()
+
+ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
+ assert self.transport is not None
+ self.transport.abort()
+ if hasattr(self.wsgihandler, "close"):
+ assert self.wsgihandler is not None
+ self.wsgihandler.close()
+
+__all__.append("SCGIProtocolFactory")
+class SCGIProtocolFactory:
+ """An asyncio.Protocol factory for the SCGI protocol.
+
+ Typical use:
+
+ await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
+ """
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ *,
+ error: typing.Optional[typing.TextIO] = None,
+ maxrequestsize: typing.Optional[int] = None,
+ maxpostsize: typing.Optional[int] = None,
+ config: typing.Optional[Environ] = None,
+ ):
+ self.wsgiapp = wsgiapp
+ self.environ = {} if config is None else config.copy()
+ self.environ["wsgi.errors"] = error if error is None else sys.stderr
+ self.kwargs: typing.Dict[str, typing.Any] = {}
+ if maxrequestsize is not None:
+ self.kwargs["maxrequestsize"] = maxrequestsize
+ if maxpostsize is not None:
+ self.kwargs["maxpostsize"] = maxpostsize
+
+ def __call__(self) -> SCGIProtocol:
+ return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)
+
+ def create_server(
+ self, *args, **kwargs
+ ) -> typing.Awaitable[asyncio.base_events.Server]:
+ """Convenience wrapper around asyncio.create_server filling in the loop
+ object and the factory argument.
+ """
+ return asyncio.get_running_loop().create_server(self, *args, **kwargs)