summaryrefslogtreecommitdiff
path: root/wsgitools/scgi/asyncio.py
diff options
context:
space:
mode:
Diffstat (limited to 'wsgitools/scgi/asyncio.py')
-rw-r--r--wsgitools/scgi/asyncio.py242
1 files changed, 242 insertions, 0 deletions
diff --git a/wsgitools/scgi/asyncio.py b/wsgitools/scgi/asyncio.py
new file mode 100644
index 0000000..9f82622
--- /dev/null
+++ b/wsgitools/scgi/asyncio.py
@@ -0,0 +1,242 @@
+__all__ = []
+
+import asyncio
+import io
+import sys
+import typing
+
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ str2bytes,
+ StartResponse,
+ WriteCallback,
+ WsgiApp,
+)
+from wsgitools.scgi import _convert_environ
+
+__all__.append("SCGIProtocol")
+class SCGIProtocol(asyncio.Protocol):
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ baseenviron: Environ,
+ *,
+ maxrequestsize: int = 65536,
+ maxpostsize: int = 8 << 20,
+ ):
+ # configuration
+ self.wsgiapp = wsgiapp
+ self.maxrequestsize = maxrequestsize
+ self.maxpostsize = maxpostsize
+ self.transport: typing.Optional[asyncio.Transport] = None
+
+ # request state
+ self.inbuff = b""
+ self.parse: typing.Optional[typing.Callable[[], bool]]
+ self.parse = self.parse_reqsize
+ self.reqlen = -1
+ self.environ: Environ = baseenviron.copy()
+ self.body = io.BytesIO()
+
+ # response state
+ self.outheaders: typing.Union[bool, bytes] = False
+ # outheaders is a three-state
+ # * False -> start_response not yet called
+ # * bytes -> headers set by start_response, but not yet sent
+ # * True -> headers sent
+ self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
+ self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
+ self.writeable = True
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ assert isinstance(transport, asyncio.Transport)
+ self.transport = transport
+
+ def data_received(self, data: bytes) -> None:
+ self.inbuff += data
+ assert self.parse is not None
+ while self.parse():
+ pass
+
+ def parse_reqsize(self) -> bool:
+ """Parse the "%d:" part of the SCGI header."""
+ assert self.transport is not None
+ parts = self.inbuff.split(b':', 1)
+ if len(parts) < 2 and len(self.inbuff) > 21:
+ self.transport.abort()
+ return False # request size implausibly large
+ try:
+ reqlen = int(parts[0])
+ except ValueError:
+ self.transport.abort()
+ return False # invalid request format
+ self.inbuff = parts[1]
+ if reqlen > self.maxrequestsize:
+ self.transport.abort()
+ return False
+ self.reqlen = reqlen
+ self.parse = self.parse_environ
+ return True
+
+ def parse_environ(self) -> bool:
+ """Parse the sequence of strings representing environ."""
+ consumed = 0
+ strings = self.inbuff[:self.reqlen].split(b'\0')
+ while len(strings) > 2:
+ key = strings.pop(0)
+ value = strings.pop(0)
+ self.environ[bytes2str(key)] = bytes2str(value)
+ consumed += len(key) + len(value) + 2
+ self.inbuff = self.inbuff[consumed:]
+ self.reqlen -= consumed
+ if self.reqlen != 0:
+ return False
+ self.parse = self.parse_tail
+ return True
+
+ def parse_tail(self) -> bool:
+ """Parse the comma and validate the environ."""
+ assert self.transport is not None
+ if not self.inbuff:
+ return False
+ if not self.inbuff.startswith(b","):
+ self.transport.abort()
+ return False # invalid request format
+ self.inbuff = self.inbuff[1:]
+ try:
+ self.reqlen = int(self.environ["CONTENT_LENGTH"])
+ except (KeyError, ValueError):
+ self.transport.abort()
+ return False
+ if self.reqlen > self.maxpostsize:
+ self.transport.abort()
+ return False
+ self.parse = self.parse_body
+ return True
+
+ def parse_body(self) -> bool:
+ """Read the request body."""
+ assert self.transport is not None
+ if len(self.inbuff) < self.reqlen:
+ self.body.write(self.inbuff)
+ self.reqlen -= len(self.inbuff)
+ self.inbuff = b""
+ return True
+
+ self.transport.pause_reading()
+ self.parse = None
+ self.body.write(self.inbuff[:self.reqlen])
+ self.body.seek(0)
+ self.inbuff = b""
+ self.reqlen = 0
+ _convert_environ(self.environ, enable_sendfile=False)
+ self.environ["wsgi.input"] = self.body
+ self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
+ assert self.wsgihandler is not None
+ self.wsgiiterator = iter(self.wsgihandler)
+ self.resume_writing()
+ return False
+
+ def start_response(
+ self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
+ ) -> WriteCallback:
+ assert isinstance(status, str)
+ assert isinstance(headers, list)
+ if exc_info:
+ if self.outheaders is True:
+ try:
+ raise exc_info[0](exc_info[1]).with_traceback(exc_info[2])
+ finally:
+ exc_info = None
+ assert self.outheaders is not True
+ self.outheaders = str2bytes(
+ "Status: %s\r\n%s\r\n" % (
+ status, "".join(map("%s: %s\r\n".__mod__, headers))
+ )
+ )
+ return self.wsgi_write
+
+ def send_headers(self) -> None:
+ if self.outheaders is True:
+ return
+ assert self.transport is not None
+ assert isinstance(self.outheaders, bytes)
+ self.transport.write(self.outheaders)
+ self.outheaders = True
+
+ def wsgi_write(self, data: bytes) -> None:
+ assert isinstance(data, bytes)
+ assert self.transport is not None
+ if not data:
+ return
+ assert self.parse is None
+ self.send_headers()
+ self.transport.write(data)
+
+ def pause_writing(self) -> None:
+ self.writeable = False
+
+ def resume_writing(self) -> None:
+ assert self.transport is not None
+ assert self.wsgiiterator is not None
+ self.writeable = True
+ while self.writeable:
+ try:
+ data = next(self.wsgiiterator)
+ except StopIteration:
+ self.send_headers()
+ self.transport.write_eof()
+ return
+ self.wsgi_write(data)
+
+ def eof_received(self) -> None:
+ if self.parse is not None:
+ assert self.transport is not None
+ self.transport.abort()
+
+ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
+ assert self.transport is not None
+ self.transport.abort()
+ if hasattr(self.wsgihandler, "close"):
+ assert self.wsgihandler is not None
+ self.wsgihandler.close()
+
+__all__.append("SCGIProtocolFactory")
+class SCGIProtocolFactory:
+ """An asyncio.Protocol factory for the SCGI protocol.
+
+ Typical use:
+
+ await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
+ """
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ *,
+ error: typing.Optional[typing.TextIO] = None,
+ maxrequestsize: typing.Optional[int] = None,
+ maxpostsize: typing.Optional[int] = None,
+ config: typing.Optional[Environ] = None,
+ ):
+ self.wsgiapp = wsgiapp
+ self.environ = {} if config is None else config.copy()
+ self.environ["wsgi.errors"] = error if error is None else sys.stderr
+ self.kwargs: typing.Dict[str, typing.Any] = {}
+ if maxrequestsize is not None:
+ self.kwargs["maxrequestsize"] = maxrequestsize
+ if maxpostsize is not None:
+ self.kwargs["maxpostsize"] = maxpostsize
+
+ def __call__(self) -> SCGIProtocol:
+ return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)
+
+ def create_server(
+ self, *args, **kwargs
+ ) -> typing.Awaitable[asyncio.base_events.Server]:
+ """Convenience wrapper around asyncio.create_server filling in the loop
+ object and the factory argument.
+ """
+ return asyncio.get_running_loop().create_server(self, *args, **kwargs)