1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
|
__all__ = []
import asyncio
import io
import sys
import typing
from wsgitools.internal import (
bytes2str,
Environ,
HeaderList,
OptExcInfo,
str2bytes,
StartResponse,
WriteCallback,
WsgiApp,
)
from wsgitools.scgi import _convert_environ
__all__.append("SCGIProtocol")
class SCGIProtocol(asyncio.Protocol):
    """asyncio protocol that serves one SCGI request per connection with a
    WSGI application.

    The request is parsed incrementally: data_received buffers bytes and
    runs the current parse_* step (netstring length, environ strings,
    terminating comma, body) for as long as that step reports progress.
    Once the body is complete, reading is paused, the application is
    invoked, and its response iterator is drained in resume_writing, which
    honours the transport's pause_writing/resume_writing flow control.
    """

    def __init__(
        self,
        wsgiapp: WsgiApp,
        baseenviron: Environ,
        *,
        maxrequestsize: int = 65536,
        maxpostsize: int = 8 << 20,
    ):
        """
        wsgiapp: the WSGI application that handles the request.
        baseenviron: entries copied into the request environ before the
            SCGI header variables are merged in.
        maxrequestsize: upper bound (bytes) on the SCGI header netstring
            length; larger requests are aborted.
        maxpostsize: upper bound (bytes) on CONTENT_LENGTH; larger bodies
            are aborted.
        """
        # configuration
        self.wsgiapp = wsgiapp
        self.maxrequestsize = maxrequestsize
        self.maxpostsize = maxpostsize
        self.transport: typing.Optional[asyncio.Transport] = None
        # request state
        self.inbuff = b""
        # Current parser step; None once the whole request has been read.
        self.parse: typing.Optional[typing.Callable[[], bool]]
        self.parse = self.parse_reqsize
        # Remaining bytes of the current section (header, then body).
        self.reqlen = -1
        self.environ: Environ = baseenviron.copy()
        self.body = io.BytesIO()
        # response state
        self.outheaders: typing.Union[bool, bytes] = False
        # outheaders is a three-state
        # * False -> start_response not yet called
        # * bytes -> headers set by start_response, but not yet sent
        # * True -> headers sent
        self.wsgihandler: typing.Optional[typing.Iterable[bytes]] = None
        self.wsgiiterator: typing.Optional[typing.Iterator[bytes]] = None
        self.writeable = True

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Remember the (stream) transport of the new connection."""
        assert isinstance(transport, asyncio.Transport)
        self.transport = transport

    def data_received(self, data: bytes) -> None:
        """Buffer incoming bytes and advance the parser as far as possible."""
        self.inbuff += data
        assert self.parse is not None
        # Each parse_* step returns True when it completed and the next step
        # should run right away, and False when it needs more data, aborted
        # the connection, or finished the request.
        while self.parse():
            pass

    def parse_reqsize(self) -> bool:
        """Parse the "%d:" part of the SCGI header."""
        assert self.transport is not None
        sizestr, colon, rest = self.inbuff.partition(b":")
        if not colon:
            # No colon yet: wait for more data unless the length prefix is
            # already implausibly long.
            if len(self.inbuff) > 21:
                self.transport.abort()
            return False
        # Only plain ASCII decimal digits are valid; this also rejects
        # negative, signed, whitespace-padded and empty lengths.
        if not sizestr.isdigit():
            self.transport.abort()
            return False  # invalid request format
        reqlen = int(sizestr)
        if reqlen > self.maxrequestsize:
            self.transport.abort()
            return False
        self.inbuff = rest
        self.reqlen = reqlen
        self.parse = self.parse_environ
        return True

    def parse_environ(self) -> bool:
        """Parse the sequence of NUL-terminated strings representing environ.

        Complete key/value pairs are consumed as they arrive; a trailing
        partial pair stays buffered until more data is received.
        """
        assert self.transport is not None
        strings = self.inbuff[:self.reqlen].split(b"\0")
        # The last element is either an unterminated string or the empty
        # remainder after a final NUL, so it never belongs to a complete
        # pair.  zip(it, it) pairs up the rest and drops an odd leftover.
        fields = iter(strings[:-1])
        consumed = 0
        for key, value in zip(fields, fields):
            self.environ[bytes2str(key)] = bytes2str(value)
            consumed += len(key) + len(value) + 2
        self.inbuff = self.inbuff[consumed:]
        self.reqlen -= consumed
        if self.reqlen != 0:
            if len(self.inbuff) >= self.reqlen:
                # The whole header is buffered yet does not end in a
                # complete key/value pair: waiting would hang forever.
                self.transport.abort()
            return False
        self.parse = self.parse_tail
        return True

    def parse_tail(self) -> bool:
        """Parse the comma ending the header and read CONTENT_LENGTH."""
        assert self.transport is not None
        if not self.inbuff:
            return False
        if not self.inbuff.startswith(b","):
            self.transport.abort()
            return False  # invalid request format
        self.inbuff = self.inbuff[1:]
        try:
            self.reqlen = int(self.environ["CONTENT_LENGTH"])
        except (KeyError, ValueError):
            self.transport.abort()
            return False
        # A negative length would make parse_body slice the buffer from the
        # end and hand the wrong bytes to the application.
        if not 0 <= self.reqlen <= self.maxpostsize:
            self.transport.abort()
            return False
        self.parse = self.parse_body
        return True

    def parse_body(self) -> bool:
        """Accumulate the request body; once complete, run the application.

        Always returns False: either more data is needed, or the request is
        fully read (self.parse is then None).
        """
        assert self.transport is not None
        if len(self.inbuff) < self.reqlen:
            self.body.write(self.inbuff)
            self.reqlen -= len(self.inbuff)
            self.inbuff = b""
            # Body incomplete: wait for the next data_received call.
            # Returning True here would spin forever on an empty buffer.
            return False
        self.transport.pause_reading()
        self.parse = None
        self.body.write(self.inbuff[:self.reqlen])
        self.body.seek(0)
        # Any bytes following the body are discarded.
        self.inbuff = b""
        self.reqlen = 0
        _convert_environ(self.environ, enable_sendfile=False)
        self.environ["wsgi.input"] = self.body
        self.wsgihandler = self.wsgiapp(self.environ, self.start_response)
        assert self.wsgihandler is not None
        self.wsgiiterator = iter(self.wsgihandler)
        # Start draining the response right away.
        self.resume_writing()
        return False

    def start_response(
        self, status: str, headers: HeaderList, exc_info: OptExcInfo = None,
    ) -> WriteCallback:
        """WSGI start_response: format the status and headers and keep them
        until the first non-empty body chunk is written.

        Raises the exception from exc_info if headers were already sent.
        """
        assert isinstance(status, str)
        assert isinstance(headers, list)
        if exc_info:
            if self.outheaders is True:
                # Headers are on the wire and cannot be replaced: re-raise
                # the original exception object with its traceback, as
                # PEP 3333 specifies.
                try:
                    raise exc_info[1].with_traceback(exc_info[2])
                finally:
                    exc_info = None  # break the traceback reference cycle
        assert self.outheaders is not True
        self.outheaders = str2bytes(
            "Status: %s\r\n%s\r\n" % (
                status, "".join(map("%s: %s\r\n".__mod__, headers))
            )
        )
        return self.wsgi_write

    def send_headers(self) -> None:
        """Write the stored header block once; later calls do nothing."""
        if self.outheaders is True:
            return
        assert self.transport is not None
        # Fails if start_response has not been called yet.
        assert isinstance(self.outheaders, bytes)
        self.transport.write(self.outheaders)
        self.outheaders = True

    def wsgi_write(self, data: bytes) -> None:
        """WSGI write callable, also used for response iterator chunks.

        Empty chunks are ignored, so headers are delayed until the first
        non-empty chunk (or the end of the response).
        """
        assert isinstance(data, bytes)
        assert self.transport is not None
        if not data:
            return
        assert self.parse is None
        self.send_headers()
        self.transport.write(data)

    def pause_writing(self) -> None:
        """Transport buffer is full: stop pulling from the response iterator."""
        self.writeable = False

    def resume_writing(self) -> None:
        """Drain the response iterator until writing is paused or it ends.

        Invoked once by parse_body to start the response and afterwards by
        the transport when its write buffer has drained.
        """
        assert self.transport is not None
        assert self.wsgiiterator is not None
        self.writeable = True
        while self.writeable:
            try:
                data = next(self.wsgiiterator)
            except StopIteration:
                # Ensure headers go out even for an empty body, then
                # half-close the connection to mark the end of the response.
                self.send_headers()
                self.transport.write_eof()
                return
            self.wsgi_write(data)

    def eof_received(self) -> None:
        """Abort if the peer stops sending before the request is complete."""
        if self.parse is not None:
            assert self.transport is not None
            self.transport.abort()

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        """Tear down the transport and close the app's response iterable."""
        assert self.transport is not None
        self.transport.abort()
        # PEP 3333: call close() on the returned iterable if it has one.
        if hasattr(self.wsgihandler, "close"):
            assert self.wsgihandler is not None
            self.wsgihandler.close()
__all__.append("SCGIProtocolFactory")
class SCGIProtocolFactory:
    """An asyncio.Protocol factory for the SCGI protocol.

    Typical use:

        await loop.create_server(SCGIProtocolFactory(app, ...), port=...)
    """

    def __init__(
        self,
        wsgiapp: WsgiApp,
        *,
        error: typing.Optional[typing.TextIO] = None,
        maxrequestsize: typing.Optional[int] = None,
        maxpostsize: typing.Optional[int] = None,
        config: typing.Optional[Environ] = None,
    ):
        """
        wsgiapp: the WSGI application served on every connection.
        error: stream stored as environ["wsgi.errors"]; defaults to
            sys.stderr when omitted.
        maxrequestsize, maxpostsize: forwarded to SCGIProtocol when given;
            otherwise SCGIProtocol's own defaults apply.
        config: extra environ entries; copied, never mutated.
        """
        self.wsgiapp = wsgiapp
        self.environ = {} if config is None else config.copy()
        # Use the caller's stream when given, falling back to sys.stderr.
        self.environ["wsgi.errors"] = sys.stderr if error is None else error
        # Only forward explicitly supplied limits.
        self.kwargs: typing.Dict[str, typing.Any] = {}
        if maxrequestsize is not None:
            self.kwargs["maxrequestsize"] = maxrequestsize
        if maxpostsize is not None:
            self.kwargs["maxpostsize"] = maxpostsize

    def __call__(self) -> SCGIProtocol:
        """Create a fresh protocol instance for a new connection."""
        return SCGIProtocol(self.wsgiapp, self.environ, **self.kwargs)

    def create_server(
        self, *args, **kwargs
    ) -> typing.Awaitable[asyncio.base_events.Server]:
        """Convenience wrapper around asyncio.create_server filling in the loop
        object and the factory argument.

        Must be called while an event loop is running; remaining arguments
        are passed through to loop.create_server.
        """
        return asyncio.get_running_loop().create_server(self, *args, **kwargs)
|