summaryrefslogtreecommitdiff
path: root/wsgitools/scgi/asyncio.py
blob: 9f8262225e65f427c97bd5a94baaf840882acce3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Public names of this module; each exported definition below appends its own
# name right before it is defined.
__all__ = []

import asyncio
import io
import sys
import typing

from wsgitools.internal import (
    bytes2str,
    Environ,
    HeaderList,
    OptExcInfo,
    str2bytes,
    StartResponse,
    WriteCallback,
    WsgiApp,
)
from wsgitools.scgi import _convert_environ

# Register SCGIProtocol as part of the module's public names.
__all__.append("SCGIProtocol")
class SCGIProtocol(asyncio.Protocol):
    """asyncio protocol that reads one SCGI request and answers it via WSGI.

    Request parsing is a small state machine: ``self.parse`` refers to the
    parser of the current stage (``parse_reqsize`` -> ``parse_environ`` ->
    ``parse_tail`` -> ``parse_body``) and is set to ``None`` once the request
    has been read completely.  Every parser returns ``True`` when the next
    stage should be tried on the buffered data right away and ``False`` when
    it has to wait for more data, aborted the connection or completed the
    request.
    """

    def __init__(
        self,
        wsgiapp: WsgiApp,
        baseenviron: Environ,
        *,
        maxrequestsize: int = 65536,
        maxpostsize: int = 8 << 20,
    ):
        """Set up the per-connection state.

        :param wsgiapp: WSGI application invoked once the request is read.
        :param baseenviron: initial environ; a copy of it is extended with
            the values transmitted in the request.
        :param maxrequestsize: largest accepted length of the SCGI header
            block; connections announcing more are aborted.
        :param maxpostsize: largest accepted CONTENT_LENGTH; connections
            announcing more are aborted.
        """
        # configuration
        self.wsgiapp = wsgiapp
        self.maxrequestsize = maxrequestsize
        self.maxpostsize = maxpostsize
        self.transport: typing.Optional[asyncio.Transport] = None

        # request state
        self.inbuff = b""
        self.parse: typing.Optional[typing.Callable[[], bool]]
        self.parse = self.parse_reqsize
        # Number of bytes still expected by the current parsing stage.
        self.reqlen = -1
        self.environ: Environ = baseenviron.copy()
        self.body = io.BytesIO()

        # response state
        self.outheaders: typing.Union[bool, bytes] = False
        # outheaders is a three-state
        #  * False -> start_response not yet called
        #  * bytes -> headers set by start_response, but not yet sent
        #  * True -> headers sent
        self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
        self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
        self.writeable = True

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the transport of the freshly established connection."""
        assert isinstance(transport, asyncio.Transport)
        self.transport = transport

    def data_received(self, data: bytes) -> None:
        """Buffer incoming data and run the parsers as far as possible."""
        self.inbuff += data
        assert self.parse is not None
        while self.parse():
            pass

    def parse_reqsize(self) -> bool:
        """Parse the "%d:" part of the SCGI header."""
        assert self.transport is not None
        parts = self.inbuff.split(b':', 1)
        if len(parts) < 2:
            if len(self.inbuff) > 21:
                self.transport.abort()  # request size implausibly large
            # Otherwise the colon simply has not arrived yet. Wait for more
            # data instead of indexing the missing second part.
            return False
        try:
            reqlen = int(parts[0])
        except ValueError:
            self.transport.abort()
            return False  # invalid request format
        if reqlen < 0 or reqlen > self.maxrequestsize:
            # A negative length would corrupt the slicing in parse_environ.
            self.transport.abort()
            return False
        self.inbuff = parts[1]
        self.reqlen = reqlen
        self.parse = self.parse_environ
        return True

    def parse_environ(self) -> bool:
        """Parse the sequence of strings representing environ."""
        assert self.transport is not None
        consumed = 0
        strings = self.inbuff[:self.reqlen].split(b'\0')
        # The final element is either empty (it follows a terminating NUL) or
        # a possibly incomplete string, and a key without its value cannot be
        # stored yet. Hence only complete key/value pairs are taken.
        paired = (len(strings) - 1) // 2 * 2
        for key, value in zip(strings[0:paired:2], strings[1:paired:2]):
            self.environ[bytes2str(key)] = bytes2str(value)
            consumed += len(key) + len(value) + 2
        self.inbuff = self.inbuff[consumed:]
        self.reqlen -= consumed
        if self.reqlen == 0:
            self.parse = self.parse_tail
            return True
        if len(self.inbuff) >= self.reqlen:
            # The whole header block is buffered, yet it is not a sequence of
            # NUL-terminated pairs. More data cannot repair that.
            self.transport.abort()
        return False

    def parse_tail(self) -> bool:
        """Parse the comma and validate the environ."""
        assert self.transport is not None
        if not self.inbuff:
            return False
        if not self.inbuff.startswith(b","):
            self.transport.abort()
            return False  # invalid request format
        self.inbuff = self.inbuff[1:]
        try:
            self.reqlen = int(self.environ["CONTENT_LENGTH"])
        except (KeyError, ValueError):
            self.transport.abort()
            return False
        if self.reqlen < 0 or self.reqlen > self.maxpostsize:
            # A negative length would corrupt the slicing in parse_body.
            self.transport.abort()
            return False
        self.parse = self.parse_body
        return True

    def parse_body(self) -> bool:
        """Read the request body and invoke the WSGI application."""
        assert self.transport is not None
        if len(self.inbuff) < self.reqlen:
            self.body.write(self.inbuff)
            self.reqlen -= len(self.inbuff)
            self.inbuff = b""
            # All buffered data is consumed. Returning True here would make
            # data_received call this parser again and again without any new
            # data, i.e. loop forever.
            return False

        # The request is complete: stop reading and switch to responding.
        self.transport.pause_reading()
        self.parse = None
        self.body.write(self.inbuff[:self.reqlen])
        self.body.seek(0)
        self.inbuff = b""
        self.reqlen = 0
        _convert_environ(self.environ, enable_sendfile=False)
        self.environ["wsgi.input"] = self.body
        self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
        assert self.wsgihandler is not None
        self.wsgiiterator = iter(self.wsgihandler)
        self.resume_writing()
        return False

    def start_response(
        self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
    ) -> WriteCallback:
        """WSGI start_response callable.

        Stores the serialized status line and headers; they are sent together
        with the first non-empty body chunk or when the response ends.

        :param status: status string such as "200 OK".
        :param headers: list of (name, value) pairs.
        :param exc_info: optional exception triple. If given after the
            headers were sent, the original exception is re-raised.
        :returns: the write callable of the response.
        """
        assert isinstance(status, str)
        assert isinstance(headers, list)
        if exc_info:
            if self.outheaders is True:
                try:
                    # Re-raise the very exception instance that was caught
                    # rather than wrapping it into a new instance of its type.
                    raise exc_info[1].with_traceback(exc_info[2])
                finally:
                    # Break the reference cycle through the traceback.
                    exc_info = None
        assert self.outheaders is not True
        self.outheaders = str2bytes(
            "Status: %s\r\n%s\r\n" % (
                status, "".join(map("%s: %s\r\n".__mod__, headers))
            )
        )
        return self.wsgi_write

    def send_headers(self) -> None:
        """Write the pending headers unless they were sent already."""
        if self.outheaders is True:
            return
        assert self.transport is not None
        assert isinstance(self.outheaders, bytes)
        self.transport.write(self.outheaders)
        self.outheaders = True

    def wsgi_write(self, data: bytes) -> None:
        """Send a chunk of the response body, preceded by pending headers.

        Empty chunks are ignored and do not trigger sending the headers.
        """
        assert isinstance(data, bytes)
        assert self.transport is not None
        if not data:
            return
        assert self.parse is None
        self.send_headers()
        self.transport.write(data)

    def pause_writing(self) -> None:
        """Stop pulling data from the WSGI iterator."""
        self.writeable = False

    def resume_writing(self) -> None:
        """Pull chunks from the WSGI iterator until paused or exhausted.

        When the iterator is exhausted, pending headers are sent and the
        write side of the transport is closed.
        """
        assert self.transport is not None
        assert self.wsgiiterator is not None
        self.writeable = True
        while self.writeable:
            try:
                data = next(self.wsgiiterator)
            except StopIteration:
                self.send_headers()
                self.transport.write_eof()
                return
            self.wsgi_write(data)

    def eof_received(self) -> None:
        """Abort the connection if the peer stops sending mid-request."""
        if self.parse is not None:
            assert self.transport is not None
            self.transport.abort()

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        """Abort the transport and close the WSGI response iterable."""
        assert self.transport is not None
        self.transport.abort()
        if hasattr(self.wsgihandler, "close"):
            assert self.wsgihandler is not None
            self.wsgihandler.close()

# Register SCGIProtocolFactory as part of the module's public names.
__all__.append("SCGIProtocolFactory")
class SCGIProtocolFactory:
    """An asyncio.Protocol factory for the SCGI protocol.

    Typical use:

        await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
    """
    def __init__(
        self,
        wsgiapp: WsgiApp,
        *,
        error: typing.Optional[typing.TextIO] = None,
        maxrequestsize: typing.Optional[int] = None,
        maxpostsize: typing.Optional[int] = None,
        config: typing.Optional[Environ] = None,
    ):
        """Collect the settings shared by all created protocol instances.

        :param wsgiapp: WSGI application handed to every SCGIProtocol.
        :param error: stream stored as "wsgi.errors"; sys.stderr is used
            when omitted.
        :param maxrequestsize: forwarded to SCGIProtocol if given, otherwise
            the default of SCGIProtocol applies.
        :param maxpostsize: forwarded to SCGIProtocol if given, otherwise
            the default of SCGIProtocol applies.
        :param config: initial environ entries; the mapping is copied.
        """
        self.wsgiapp = wsgiapp
        self.environ = {} if config is None else config.copy()
        # Use the caller's stream when one is given, else fall back to stderr.
        self.environ["wsgi.errors"] = sys.stderr if error is None else error
        # Only explicitly given limits are forwarded.
        self.kwargs: typing.Dict[str, typing.Any] = {}
        if maxrequestsize is not None:
            self.kwargs["maxrequestsize"] = maxrequestsize
        if maxpostsize is not None:
            self.kwargs["maxpostsize"] = maxpostsize

    def __call__(self) -> SCGIProtocol:
        """Create the protocol instance for a new connection."""
        return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)

    def create_server(
        self, *args, **kwargs
    ) -> typing.Awaitable[asyncio.base_events.Server]:
        """Convenience wrapper around the create_server method of the running
        event loop, filling in this factory as the protocol factory argument.
        All other arguments are passed through unchanged.
        """
        return asyncio.get_running_loop().create_server(self, *args, **kwargs)