summaryrefslogtreecommitdiff
path: root/wsgitools/scgi
diff options
context:
space:
mode:
authorHelmut Grohne <helmut@subdivi.de>2023-06-17 15:10:02 +0200
committerHelmut Grohne <helmut@subdivi.de>2023-06-18 23:18:50 +0200
commitf1e580b1a14f980bf001662c74b0382e644850be (patch)
tree44706914b7b1e864307be6d7e36f503c6e8b5946 /wsgitools/scgi
parenta41066b413489b407b9d99174af697563ad680b9 (diff)
downloadwsgitools-f1e580b1a14f980bf001662c74b0382e644850be.tar.gz
add a wsgitools.scgi.asyncio module
This adds an asyncio implementation of the server side of the SCGI protocol, because asyncore is being deprecated. Unlike the asyncore implementation, this does not yet support sendfile.
Diffstat (limited to 'wsgitools/scgi')
-rw-r--r--wsgitools/scgi/__init__.py3
-rw-r--r--wsgitools/scgi/asyncio.py242
2 files changed, 244 insertions, 1 deletions
diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py
index 677e3b5..38a5bef 100644
--- a/wsgitools/scgi/__init__.py
+++ b/wsgitools/scgi/__init__.py
@@ -65,6 +65,7 @@ def _convert_environ(
multithread: bool = False,
multiprocess: bool = False,
run_once: bool = False,
+ enable_sendfile: bool = True,
) -> None:
environ.update({
"wsgi.version": (1, 0),
@@ -79,5 +80,5 @@ def _convert_environ(
except KeyError:
pass
environ.pop("HTTP_CONTENT_LENGTH", None) # TODO: better way?
- if have_sendfile:
+ if have_sendfile and enable_sendfile:
environ["wsgi.file_wrapper"] = FileWrapper
diff --git a/wsgitools/scgi/asyncio.py b/wsgitools/scgi/asyncio.py
new file mode 100644
index 0000000..9f82622
--- /dev/null
+++ b/wsgitools/scgi/asyncio.py
@@ -0,0 +1,242 @@
+__all__ = []
+
+import asyncio
+import io
+import sys
+import typing
+
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ str2bytes,
+ StartResponse,
+ WriteCallback,
+ WsgiApp,
+)
+from wsgitools.scgi import _convert_environ
+
+__all__.append("SCGIProtocol")
+class SCGIProtocol(asyncio.Protocol):
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ baseenviron: Environ,
+ *,
+ maxrequestsize: int = 65536,
+ maxpostsize: int = 8 << 20,
+ ):
+ # configuration
+ self.wsgiapp = wsgiapp
+ self.maxrequestsize = maxrequestsize
+ self.maxpostsize = maxpostsize
+ self.transport: typing.Optional[asyncio.Transport] = None
+
+ # request state
+ self.inbuff = b""
+ self.parse: typing.Optional[typing.Callable[[], bool]]
+ self.parse = self.parse_reqsize
+ self.reqlen = -1
+ self.environ: Environ = baseenviron.copy()
+ self.body = io.BytesIO()
+
+ # response state
+ self.outheaders: typing.Union[bool, bytes] = False
+ # outheaders is a three-state
+ # * False -> start_response not yet called
+ # * bytes -> headers set by start_response, but not yet sent
+ # * True -> headers sent
+ self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
+ self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
+ self.writeable = True
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ assert isinstance(transport, asyncio.Transport)
+ self.transport = transport
+
+ def data_received(self, data: bytes) -> None:
+ self.inbuff += data
+ assert self.parse is not None
+ while self.parse():
+ pass
+
+ def parse_reqsize(self) -> bool:
+ """Parse the "%d:" part of the SCGI header."""
+ assert self.transport is not None
+ parts = self.inbuff.split(b':', 1)
+ if len(parts) < 2 and len(self.inbuff) > 21:
+ self.transport.abort()
+ return False # request size implausibly large
+ try:
+ reqlen = int(parts[0])
+ except ValueError:
+ self.transport.abort()
+ return False # invalid request format
+ self.inbuff = parts[1]
+ if reqlen > self.maxrequestsize:
+ self.transport.abort()
+ return False
+ self.reqlen = reqlen
+ self.parse = self.parse_environ
+ return True
+
+ def parse_environ(self) -> bool:
+ """Parse the sequence of strings representing environ."""
+ consumed = 0
+ strings = self.inbuff[:self.reqlen].split(b'\0')
+ while len(strings) > 2:
+ key = strings.pop(0)
+ value = strings.pop(0)
+ self.environ[bytes2str(key)] = bytes2str(value)
+ consumed += len(key) + len(value) + 2
+ self.inbuff = self.inbuff[consumed:]
+ self.reqlen -= consumed
+ if self.reqlen != 0:
+ return False
+ self.parse = self.parse_tail
+ return True
+
+ def parse_tail(self) -> bool:
+ """Parse the comma and validate the environ."""
+ assert self.transport is not None
+ if not self.inbuff:
+ return False
+ if not self.inbuff.startswith(b","):
+ self.transport.abort()
+ return False # invalid request format
+ self.inbuff = self.inbuff[1:]
+ try:
+ self.reqlen = int(self.environ["CONTENT_LENGTH"])
+ except (KeyError, ValueError):
+ self.transport.abort()
+ return False
+ if self.reqlen > self.maxpostsize:
+ self.transport.abort()
+ return False
+ self.parse = self.parse_body
+ return True
+
+ def parse_body(self) -> bool:
+ """Read the request body."""
+ assert self.transport is not None
+ if len(self.inbuff) < self.reqlen:
+ self.body.write(self.inbuff)
+ self.reqlen -= len(self.inbuff)
+ self.inbuff = b""
+ return True
+
+ self.transport.pause_reading()
+ self.parse = None
+ self.body.write(self.inbuff[:self.reqlen])
+ self.body.seek(0)
+ self.inbuff = b""
+ self.reqlen = 0
+ _convert_environ(self.environ, enable_sendfile=False)
+ self.environ["wsgi.input"] = self.body
+ self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
+ assert self.wsgihandler is not None
+ self.wsgiiterator = iter(self.wsgihandler)
+ self.resume_writing()
+ return False
+
+ def start_response(
+ self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
+ ) -> WriteCallback:
+ assert isinstance(status, str)
+ assert isinstance(headers, list)
+ if exc_info:
+ if self.outheaders is True:
+ try:
+ raise exc_info[0](exc_info[1]).with_traceback(exc_info[2])
+ finally:
+ exc_info = None
+ assert self.outheaders is not True
+ self.outheaders = str2bytes(
+ "Status: %s\r\n%s\r\n" % (
+ status, "".join(map("%s: %s\r\n".__mod__, headers))
+ )
+ )
+ return self.wsgi_write
+
+ def send_headers(self) -> None:
+ if self.outheaders is True:
+ return
+ assert self.transport is not None
+ assert isinstance(self.outheaders, bytes)
+ self.transport.write(self.outheaders)
+ self.outheaders = True
+
+ def wsgi_write(self, data: bytes) -> None:
+ assert isinstance(data, bytes)
+ assert self.transport is not None
+ if not data:
+ return
+ assert self.parse is None
+ self.send_headers()
+ self.transport.write(data)
+
+ def pause_writing(self) -> None:
+ self.writeable = False
+
+ def resume_writing(self) -> None:
+ assert self.transport is not None
+ assert self.wsgiiterator is not None
+ self.writeable = True
+ while self.writeable:
+ try:
+ data = next(self.wsgiiterator)
+ except StopIteration:
+ self.send_headers()
+ self.transport.write_eof()
+ return
+ self.wsgi_write(data)
+
+ def eof_received(self) -> None:
+ if self.parse is not None:
+ assert self.transport is not None
+ self.transport.abort()
+
+ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
+ assert self.transport is not None
+ self.transport.abort()
+ if hasattr(self.wsgihandler, "close"):
+ assert self.wsgihandler is not None
+ self.wsgihandler.close()
+
+__all__.append("SCGIProtocolFactory")
+class SCGIProtocolFactory:
+ """An asyncio.Protocol factory for the SCGI protocol.
+
+ Typical use:
+
+ await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
+ """
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ *,
+ error: typing.Optional[typing.TextIO] = None,
+ maxrequestsize: typing.Optional[int] = None,
+ maxpostsize: typing.Optional[int] = None,
+ config: typing.Optional[Environ] = None,
+ ):
+ self.wsgiapp = wsgiapp
+ self.environ = {} if config is None else config.copy()
+ self.environ["wsgi.errors"] = error if error is None else sys.stderr
+ self.kwargs: typing.Dict[str, typing.Any] = {}
+ if maxrequestsize is not None:
+ self.kwargs["maxrequestsize"] = maxrequestsize
+ if maxpostsize is not None:
+ self.kwargs["maxpostsize"] = maxpostsize
+
+ def __call__(self) -> SCGIProtocol:
+ return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)
+
+ def create_server(
+ self, *args, **kwargs
+ ) -> typing.Awaitable[asyncio.base_events.Server]:
+ """Convenience wrapper around asyncio.create_server filling in the loop
+ object and the factory argument.
+ """
+ return asyncio.get_running_loop().create_server(self, *args, **kwargs)