diff options
Diffstat (limited to 'wsgitools')
-rw-r--r-- | wsgitools/scgi/__init__.py | 3 | ||||
-rw-r--r-- | wsgitools/scgi/asyncio.py | 242 |
2 files changed, 244 insertions, 1 deletions
diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py
index 677e3b5..38a5bef 100644
--- a/wsgitools/scgi/__init__.py
+++ b/wsgitools/scgi/__init__.py
@@ -65,6 +65,7 @@ def _convert_environ(
     multithread: bool = False,
     multiprocess: bool = False,
     run_once: bool = False,
+    enable_sendfile: bool = True,
 ) -> None:
     environ.update({
         "wsgi.version": (1, 0),
@@ -79,5 +80,5 @@ def _convert_environ(
     except KeyError:
         pass
     environ.pop("HTTP_CONTENT_LENGTH", None) # TODO: better way?
-    if have_sendfile:
+    if have_sendfile and enable_sendfile:
         environ["wsgi.file_wrapper"] = FileWrapper
diff --git a/wsgitools/scgi/asyncio.py b/wsgitools/scgi/asyncio.py
new file mode 100644
index 0000000..9f82622
--- /dev/null
+++ b/wsgitools/scgi/asyncio.py
@@ -0,0 +1,287 @@
+__all__ = []
+
+import asyncio
+import io
+import sys
+import typing
+
+from wsgitools.internal import (
+    bytes2str,
+    Environ,
+    HeaderList,
+    OptExcInfo,
+    str2bytes,
+    StartResponse,
+    WriteCallback,
+    WsgiApp,
+)
+from wsgitools.scgi import _convert_environ
+
+__all__.append("SCGIProtocol")
+class SCGIProtocol(asyncio.Protocol):
+    """asyncio protocol serving a single SCGI request with a WSGI app.
+
+    The request is parsed incrementally: self.parse refers to the parser
+    for the next piece of input and becomes None once the whole request
+    has been read and the application has been started.
+    """
+
+    def __init__(
+        self,
+        wsgiapp: WsgiApp,
+        baseenviron: Environ,
+        *,
+        maxrequestsize: int = 65536,
+        maxpostsize: int = 8 << 20,
+    ):
+        """Set up the state of one connection.
+
+        baseenviron is copied and used as the initial environ. A request
+        is aborted if its SCGI header is longer than maxrequestsize bytes
+        or its CONTENT_LENGTH is larger than maxpostsize bytes.
+        """
+        # configuration
+        self.wsgiapp = wsgiapp
+        self.maxrequestsize = maxrequestsize
+        self.maxpostsize = maxpostsize
+        self.transport: typing.Optional[asyncio.Transport] = None
+
+        # request state
+        self.inbuff = b""
+        self.parse: typing.Optional[typing.Callable[[], bool]]
+        self.parse = self.parse_reqsize
+        self.reqlen = -1
+        self.environ: Environ = baseenviron.copy()
+        self.body = io.BytesIO()
+
+        # response state
+        self.outheaders: typing.Union[bool, bytes] = False
+        # outheaders is a three-state
+        #  * False -> start_response not yet called
+        #  * bytes -> headers set by start_response, but not yet sent
+        #  * True  -> headers sent
+        self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
+        self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
+        self.writeable = True
+
+    def connection_made(self, transport: asyncio.BaseTransport) -> None:
+        """Remember the transport of the new connection."""
+        assert isinstance(transport, asyncio.Transport)
+        self.transport = transport
+
+    def data_received(self, data: bytes) -> None:
+        """Buffer data and run the parsers while they make progress."""
+        self.inbuff += data
+        assert self.parse is not None
+        while self.parse():
+            pass
+
+    def parse_reqsize(self) -> bool:
+        """Parse the "%d:" part of the SCGI header."""
+        assert self.transport is not None
+        parts = self.inbuff.split(b':', 1)
+        if len(parts) < 2:
+            if len(self.inbuff) > 21:
+                self.transport.abort() # request size implausibly large
+            return False # otherwise wait for the rest of the length
+        try:
+            reqlen = int(parts[0])
+        except ValueError:
+            self.transport.abort()
+            return False # invalid request format
+        if not 0 <= reqlen <= self.maxrequestsize:
+            self.transport.abort()
+            return False
+        self.inbuff = parts[1]
+        self.reqlen = reqlen
+        self.parse = self.parse_environ
+        return True
+
+    def parse_environ(self) -> bool:
+        """Parse the sequence of strings representing environ."""
+        assert self.transport is not None
+        complete = len(self.inbuff) >= self.reqlen
+        consumed = 0
+        strings = self.inbuff[:self.reqlen].split(b'\0')
+        # The trailing element is empty or an unfinished string. Only whole
+        # key/value pairs in front of it are consumed.
+        while len(strings) > 2:
+            key = strings.pop(0)
+            value = strings.pop(0)
+            self.environ[bytes2str(key)] = bytes2str(value)
+            consumed += len(key) + len(value) + 2
+        self.inbuff = self.inbuff[consumed:]
+        self.reqlen -= consumed
+        if self.reqlen != 0:
+            if complete:
+                # All header bytes are here, but they do not form pairs of
+                # NUL-terminated strings. Waiting would never succeed.
+                self.transport.abort()
+            return False
+        self.parse = self.parse_tail
+        return True
+
+    def parse_tail(self) -> bool:
+        """Parse the comma and validate the environ."""
+        assert self.transport is not None
+        if not self.inbuff:
+            return False
+        if not self.inbuff.startswith(b","):
+            self.transport.abort()
+            return False # invalid request format
+        self.inbuff = self.inbuff[1:]
+        try:
+            self.reqlen = int(self.environ["CONTENT_LENGTH"])
+        except (KeyError, ValueError):
+            self.transport.abort()
+            return False
+        if not 0 <= self.reqlen <= self.maxpostsize:
+            self.transport.abort()
+            return False
+        self.parse = self.parse_body
+        return True
+
+    def parse_body(self) -> bool:
+        """Read the request body and start the application when complete."""
+        assert self.transport is not None
+        if len(self.inbuff) < self.reqlen:
+            self.body.write(self.inbuff)
+            self.reqlen -= len(self.inbuff)
+            self.inbuff = b""
+            # The buffer is drained. Returning True would make
+            # data_received call this parser again in an endless loop.
+            return False
+
+        self.transport.pause_reading()
+        self.parse = None
+        self.body.write(self.inbuff[:self.reqlen])
+        self.body.seek(0)
+        self.inbuff = b""
+        self.reqlen = 0
+        _convert_environ(self.environ, enable_sendfile=False)
+        self.environ["wsgi.input"] = self.body
+        self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
+        assert self.wsgihandler is not None
+        self.wsgiiterator = iter(self.wsgihandler)
+        self.resume_writing()
+        return False
+
+    def start_response(
+        self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
+    ) -> WriteCallback:
+        """Record status and headers for a later send_headers call."""
+        assert isinstance(status, str)
+        assert isinstance(headers, list)
+        if exc_info:
+            if self.outheaders is True:
+                try:
+                    # Headers are out already: re-raise the original
+                    # exception object instead of constructing a new one.
+                    raise exc_info[1].with_traceback(exc_info[2])
+                finally:
+                    exc_info = None
+        assert self.outheaders is not True
+        self.outheaders = str2bytes(
+            "Status: %s\r\n%s\r\n" % (
+                status, "".join(map("%s: %s\r\n".__mod__, headers))
+            )
+        )
+        return self.wsgi_write
+
+    def send_headers(self) -> None:
+        """Write the recorded headers unless that already happened."""
+        if self.outheaders is True:
+            return
+        assert self.transport is not None
+        assert isinstance(self.outheaders, bytes)
+        self.transport.write(self.outheaders)
+        self.outheaders = True
+
+    def wsgi_write(self, data: bytes) -> None:
+        """Send a chunk of the response body; empty chunks are ignored."""
+        assert isinstance(data, bytes)
+        assert self.transport is not None
+        if not data:
+            return
+        assert self.parse is None
+        self.send_headers()
+        self.transport.write(data)
+
+    def pause_writing(self) -> None:
+        """Stop pulling data from the application."""
+        self.writeable = False
+
+    def resume_writing(self) -> None:
+        """Pull and send response chunks until paused or exhausted."""
+        assert self.transport is not None
+        assert self.wsgiiterator is not None
+        self.writeable = True
+        while self.writeable:
+            try:
+                data = next(self.wsgiiterator)
+            except StopIteration:
+                self.send_headers()
+                # Reading is paused, so after a mere write_eof() the peer
+                # closing its end would go unnoticed. close() flushes the
+                # pending output and then leads to connection_lost.
+                self.transport.close()
+                return
+            self.wsgi_write(data)
+
+    def eof_received(self) -> None:
+        """Abort when the peer stops sending in the middle of a request."""
+        if self.parse is not None:
+            assert self.transport is not None
+            self.transport.abort()
+
+    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
+        """Close the application iterable, if it offers a close method."""
+        assert self.transport is not None
+        self.transport.abort()
+        if hasattr(self.wsgihandler, "close"):
+            assert self.wsgihandler is not None
+            self.wsgihandler.close()
+
+__all__.append("SCGIProtocolFactory")
+class SCGIProtocolFactory:
+    """An asyncio.Protocol factory for the SCGI protocol.
+
+    Typical use:
+
+        await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
+    """
+    def __init__(
+        self,
+        wsgiapp: WsgiApp,
+        *,
+        error: typing.Optional[typing.TextIO] = None,
+        maxrequestsize: typing.Optional[int] = None,
+        maxpostsize: typing.Optional[int] = None,
+        config: typing.Optional[Environ] = None,
+    ):
+        """Collect the settings shared by all connections.
+
+        error is used as wsgi.errors and defaults to sys.stderr. config is
+        copied and becomes the base environ of every request. Limits left
+        at None fall back to the defaults of SCGIProtocol.
+        """
+        self.wsgiapp = wsgiapp
+        self.environ = {} if config is None else config.copy()
+        self.environ["wsgi.errors"] = sys.stderr if error is None else error
+        self.kwargs: typing.Dict[str, typing.Any] = {}
+        if maxrequestsize is not None:
+            self.kwargs["maxrequestsize"] = maxrequestsize
+        if maxpostsize is not None:
+            self.kwargs["maxpostsize"] = maxpostsize
+
+    def __call__(self) -> SCGIProtocol:
+        """Create the protocol instance for a new connection."""
+        return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)
+
+    def create_server(
+        self, *args, **kwargs
+    ) -> typing.Awaitable[asyncio.base_events.Server]:
+        """Convenience wrapper around loop.create_server filling in the
+        running loop object and the factory argument.
+        """
+        return asyncio.get_running_loop().create_server(self, *args, **kwargs)