"""asyncio-based SCGI server protocol for WSGI applications."""

__all__ = []

import asyncio
import io
import sys
import typing

from wsgitools.internal import (
    bytes2str,
    Environ,
    HeaderList,
    OptExcInfo,
    str2bytes,
    StartResponse,
    WriteCallback,
    WsgiApp,
)
from wsgitools.scgi import _convert_environ

__all__.append("SCGIProtocol")


class SCGIProtocol(asyncio.Protocol):
    """asyncio protocol serving one SCGI request with a WSGI application.

    The request is parsed incrementally from data_received through a chain
    of parser steps: request size, environ netstring payload, terminating
    comma, body. Once the body is complete, reading is paused, the WSGI
    application is invoked and its output is streamed to the transport
    while honouring the transport's flow control (pause_writing /
    resume_writing). The connection is closed after the response.
    """

    def __init__(
        self,
        wsgiapp: WsgiApp,
        baseenviron: Environ,
        *,
        maxrequestsize: int = 65536,
        maxpostsize: int = 8 << 20,
    ):
        """
        :param wsgiapp: the WSGI application invoked for the request
        :param baseenviron: environ entries common to all requests; it is
            copied, so per-request modifications do not leak back
        :param maxrequestsize: maximum size in bytes of the SCGI header
            netstring; larger requests are aborted
        :param maxpostsize: maximum accepted CONTENT_LENGTH in bytes;
            larger requests are aborted
        """
        # configuration
        self.wsgiapp = wsgiapp
        self.maxrequestsize = maxrequestsize
        self.maxpostsize = maxpostsize
        self.transport: typing.Optional[asyncio.Transport] = None
        # request state
        self.inbuff = b""
        # Current parser step; None once the request has been fully read.
        self.parse: typing.Optional[typing.Callable[[], bool]]
        self.parse = self.parse_reqsize
        # Bytes still expected for the header netstring or for the body.
        self.reqlen = -1
        self.environ: Environ = baseenviron.copy()
        self.body = io.BytesIO()
        # response state
        self.outheaders: typing.Union[bool, bytes] = False
        # outheaders is a three-state
        #  * False -> start_response not yet called
        #  * bytes -> headers set by start_response, but not yet sent
        #  * True -> headers sent
        self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
        self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
        self.writeable = True

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the (stream) transport for this connection."""
        assert isinstance(transport, asyncio.Transport)
        self.transport = transport

    def data_received(self, data: bytes) -> None:
        """Buffer incoming data and run parser steps while they progress.

        A parser step returns True when it completed and installed the next
        step, and False when it needs more data, aborted the connection or
        handed the request to the application.
        """
        self.inbuff += data
        assert self.parse is not None
        while self.parse():
            pass

    def parse_reqsize(self) -> bool:
        """Parse the "%d:" part of the SCGI header."""
        assert self.transport is not None
        sizestr, sep, rest = self.inbuff.partition(b":")
        if not sep:
            # No colon yet. Reject obvious garbage early, otherwise wait.
            if len(sizestr) > 21 or (sizestr and not sizestr.isdigit()):
                self.transport.abort()  # implausibly large or not a number
            return False
        # Netstring lengths are plain ASCII digits: no sign, no whitespace.
        if not sizestr.isdigit():
            self.transport.abort()
            return False  # invalid request format
        reqlen = int(sizestr)
        if reqlen > self.maxrequestsize:
            self.transport.abort()
            return False
        self.inbuff = rest
        self.reqlen = reqlen
        self.parse = self.parse_environ
        return True

    def parse_environ(self) -> bool:
        """Parse the sequence of strings representing environ.

        The header netstring payload consists of NUL-terminated key and
        value strings. Complete pairs are stored in self.environ as soon as
        they are available; an incomplete trailing pair stays buffered.
        """
        assert self.transport is not None
        strings = self.inbuff[:self.reqlen].split(b"\0")
        # The last element follows the last NUL seen and thus is never a
        # complete string. Pair up the others; zip over one shared iterator
        # drops an unpaired key, which then also stays buffered.
        complete = iter(strings[:-1])
        consumed = 0
        for key, value in zip(complete, complete):
            self.environ[bytes2str(key)] = bytes2str(value)
            consumed += len(key) + len(value) + 2
        self.inbuff = self.inbuff[consumed:]
        self.reqlen -= consumed
        if self.reqlen == 0:
            self.parse = self.parse_tail
            return True
        if len(self.inbuff) >= self.reqlen:
            # The whole payload is present but does not end in a complete
            # key/value pair; more data cannot fix that.
            self.transport.abort()
        return False

    def parse_tail(self) -> bool:
        """Parse the comma and validate the environ.

        The request is aborted unless CONTENT_LENGTH is present and is an
        integer in the range 0..maxpostsize.
        """
        assert self.transport is not None
        if not self.inbuff:
            return False
        if not self.inbuff.startswith(b","):
            self.transport.abort()
            return False  # invalid request format
        self.inbuff = self.inbuff[1:]
        try:
            reqlen = int(self.environ["CONTENT_LENGTH"])
        except (KeyError, ValueError):
            self.transport.abort()
            return False
        if not 0 <= reqlen <= self.maxpostsize:
            self.transport.abort()
            return False
        self.reqlen = reqlen
        self.parse = self.parse_body
        return True

    def parse_body(self) -> bool:
        """Read the request body and, once complete, start the application.

        Always returns False: either more data is needed, or parsing is
        finished and the request has been handed to the application.
        """
        assert self.transport is not None
        if len(self.inbuff) < self.reqlen:
            self.body.write(self.inbuff)
            self.reqlen -= len(self.inbuff)
            self.inbuff = b""
            # Wait for more data. Returning True here would make
            # data_received call us again forever on the empty buffer.
            return False
        self.transport.pause_reading()
        self.parse = None
        self.body.write(self.inbuff[:self.reqlen])
        self.body.seek(0)
        self.inbuff = b""
        self.reqlen = 0
        _convert_environ(self.environ, enable_sendfile=False)
        self.environ["wsgi.input"] = self.body
        self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
        assert self.wsgihandler is not None
        self.wsgiiterator = iter(self.wsgihandler)
        self.resume_writing()
        return False

    def start_response(
        self,
        status: str,
        headers: HeaderList,
        exc_info: OptExcInfo = None,
    ) -> WriteCallback:
        """WSGI start_response callable.

        Formats status and headers into self.outheaders. They are sent
        lazily, before the first non-empty body chunk or at the end of the
        response. If exc_info is given after the headers have already been
        sent, the application's exception is re-raised.
        """
        assert isinstance(status, str)
        assert isinstance(headers, list)
        if exc_info:
            try:
                if self.outheaders is True:
                    # Re-raise the original exception object with its
                    # traceback, as PEP 3333 prescribes.
                    raise exc_info[1].with_traceback(exc_info[2])
            finally:
                exc_info = None  # drop the traceback reference
        assert self.outheaders is not True
        self.outheaders = str2bytes(
            "Status: %s\r\n%s\r\n" % (
                status,
                "".join(map("%s: %s\r\n".__mod__, headers))
            )
        )
        return self.wsgi_write

    def send_headers(self) -> None:
        """Send the pending headers unless they have been sent already."""
        if self.outheaders is True:
            return
        assert self.transport is not None
        assert isinstance(self.outheaders, bytes)
        self.transport.write(self.outheaders)
        self.outheaders = True

    def wsgi_write(self, data: bytes) -> None:
        """Write a body chunk, sending the headers first if needed.

        Empty chunks are ignored so that they do not force the headers out.
        """
        assert isinstance(data, bytes)
        assert self.transport is not None
        if not data:
            return
        assert self.parse is None
        self.send_headers()
        self.transport.write(data)

    def pause_writing(self) -> None:
        """Flow control: stop pulling chunks from the application."""
        self.writeable = False

    def resume_writing(self) -> None:
        """Stream application output until flow control pauses us.

        Also invoked by parse_body to start the response. When the
        application iterator is exhausted, pending headers are flushed and
        the transport is closed; asyncio then calls connection_lost, which
        closes the application iterable.
        """
        assert self.transport is not None
        assert self.wsgiiterator is not None
        self.writeable = True
        # Stop pulling chunks once the transport is closing (finished or
        # the peer went away); further writes would be discarded anyway.
        while self.writeable and not self.transport.is_closing():
            try:
                data = next(self.wsgiiterator)
            except StopIteration:
                self.send_headers()
                self.transport.write_eof()
                # Reading is paused, so the peer's EOF would never be seen
                # and connection_lost might never run. Close explicitly;
                # buffered data is flushed before the socket is closed.
                self.transport.close()
                return
            self.wsgi_write(data)

    def eof_received(self) -> None:
        """Abort if the peer closes before the request was fully read."""
        if self.parse is not None:
            assert self.transport is not None
            self.transport.abort()

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        """Tear down the transport and close the application iterable."""
        assert self.transport is not None
        self.transport.abort()
        if hasattr(self.wsgihandler, "close"):
            assert self.wsgihandler is not None
            self.wsgihandler.close()


__all__.append("SCGIProtocolFactory")


class SCGIProtocolFactory:
    """An asyncio.Protocol factory for the SCGI protocol.

    Typical use:

        await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
    """

    def __init__(
        self,
        wsgiapp: WsgiApp,
        *,
        error: typing.Optional[typing.TextIO] = None,
        maxrequestsize: typing.Optional[int] = None,
        maxpostsize: typing.Optional[int] = None,
        config: typing.Optional[Environ] = None,
    ):
        """
        :param wsgiapp: the WSGI application served by every connection
        :param error: stream exposed as wsgi.errors; sys.stderr if None
        :param maxrequestsize: forwarded to SCGIProtocol unless None
        :param maxpostsize: forwarded to SCGIProtocol unless None
        :param config: extra environ entries for every request (copied)
        """
        self.wsgiapp = wsgiapp
        self.environ: Environ = {} if config is None else config.copy()
        # Use the caller's stream; fall back to sys.stderr only if none.
        self.environ["wsgi.errors"] = sys.stderr if error is None else error
        self.kwargs: typing.Dict[str, typing.Any] = {}
        if maxrequestsize is not None:
            self.kwargs["maxrequestsize"] = maxrequestsize
        if maxpostsize is not None:
            self.kwargs["maxpostsize"] = maxpostsize

    def __call__(self) -> SCGIProtocol:
        """Create a protocol instance for a new connection."""
        return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)

    def create_server(
        self, *args, **kwargs
    ) -> typing.Awaitable[asyncio.base_events.Server]:
        """Convenience wrapper around asyncio.create_server filling in the
        loop object and the factory argument.
        """
        return asyncio.get_running_loop().create_server(self, *args, **kwargs)