summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHelmut Grohne <helmut@subdivi.de>2020-04-13 21:30:34 +0200
committerHelmut Grohne <helmut@subdivi.de>2023-06-18 23:16:57 +0200
commita41066b413489b407b9d99174af697563ad680b9 (patch)
tree2f08f9e886e13a7500d1eb527e30737d961deab6
parent4d52eaa4801df3f3169df8e58758bcccf22dc4de (diff)
downloadwsgitools-a41066b413489b407b9d99174af697563ad680b9.tar.gz
add type hints to all of the code
In order to use type hint syntax, we need to bump the minimum Python version to 3.7 and some of the features such as Literal and Protocol are opted in when a sufficiently recent Python is available. This does not make all of the code pass type checking with mypy. A number of typing issues remain, but the output of mypy becomes something one can read through. In adding type hints, a lot of epydoc @type annotations are removed as redundant. This update also adopts black-style line breaking.
-rw-r--r--README2
-rwxr-xr-xtest.py207
-rw-r--r--wsgitools/applications.py44
-rw-r--r--wsgitools/authentication.py43
-rw-r--r--wsgitools/digest.py220
-rw-r--r--wsgitools/filters.py252
-rw-r--r--wsgitools/internal.py37
-rw-r--r--wsgitools/middlewares.py176
-rw-r--r--wsgitools/scgi/__init__.py25
-rw-r--r--wsgitools/scgi/asynchronous.py98
-rw-r--r--wsgitools/scgi/forkpool.py95
11 files changed, 682 insertions, 517 deletions
diff --git a/README b/README
index 1ad2a89..ec217e3 100644
--- a/README
+++ b/README
@@ -2,7 +2,7 @@ The software should be usable by reading the docstrings. If you think that
certain features are missing or you found a bug, don't hesitate to ask me
via mail!
-Supported Python versions currently are >= 3.5 and <= 3.11. 3.12 will be
+Supported Python versions currently are >= 3.7 and <= 3.11. 3.12 will be
degraded, due to use of deprecated modules.
Installation should be easy using setup.py. I recommend running the test suite
diff --git a/test.py b/test.py
index 9690baf..5271819 100755
--- a/test.py
+++ b/test.py
@@ -8,16 +8,23 @@ import wsgiref.validate
import io
from hashlib import md5
import sys
-
-from wsgitools.internal import bytes2str, str2bytes
+import typing
+
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ StartResponse,
+ str2bytes,
+ WriteCallback,
+ WsgiApp,
+)
class Request:
- def __init__(self, case):
- """
- @type case: unittest.TestCase
- """
+ def __init__(self, case: unittest.TestCase):
self.testcase = case
- self.environ = dict(
+ self.environ: Environ = dict(
REQUEST_METHOD="GET",
SERVER_NAME="localhost",
SERVER_PORT="80",
@@ -33,86 +40,92 @@ class Request:
"wsgi.multiprocess": False,
"wsgi.run_once": False})
- def setenv(self, key, value):
+ def setenv(self, key: str, value: str) -> "Request":
"""
- @type key: str
- @type value: str
@returns: self
"""
self.environ[key] = value
return self
- def setmethod(self, request_method):
+ def setmethod(self, request_method: str) -> "Request":
"""
- @type request_method: str
@returns: self
"""
return self.setenv("REQUEST_METHOD", request_method)
- def setheader(self, name, value):
+ def setheader(self, name: str, value: str) -> "Request":
"""
- @type name: str
- @type value: str
@returns: self
"""
return self.setenv("HTTP_" + name.upper().replace('-', '_'), value)
- def copy(self):
+ def copy(self) -> "Request":
req = Request(self.testcase)
req.environ = dict(self.environ)
return req
- def __call__(self, app):
+ def __call__(self, app: WsgiApp) -> "Result":
app = wsgiref.validate.validator(app)
- res = Result(self.testcase)
- def write(data):
- res.writtendata.append(data)
- def start_response(status, headers, exc_info=None):
- res.statusdata = status
- res.headersdata = headers
+ writtendata: typing.List[bytes] = []
+ def write(data: bytes) -> None:
+ nonlocal writtendata
+ writtendata.append(data)
+ statusdata: typing.Optional[str] = None
+ headersdata: typing.Optional[HeaderList] = None
+ def start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
+ nonlocal statusdata, headersdata
+ statusdata = status
+ headersdata = headers
return write
iterator = app(self.environ, start_response)
- res.returneddata = list(iterator)
+ returneddata = list(iterator)
if hasattr(iterator, "close"):
iterator.close()
- return res
+ assert statusdata is not None
+ assert headersdata is not None
+ return Result(
+ self.testcase, statusdata, headersdata, writtendata, returneddata
+ )
class Result:
- def __init__(self, case):
- """
- @type case: unittest.TestCase
- """
+ def __init__(
+ self,
+ case: unittest.TestCase,
+ statusdata: str,
+ headersdata: HeaderList,
+ writtendata: typing.List[bytes],
+ returneddata: typing.List[bytes],
+ ):
self.testcase = case
- self.statusdata = None
- self.headersdata = None
- self.writtendata = []
- self.returneddata = None
+ self.statusdata = statusdata
+ self.headersdata = headersdata
+ self.writtendata = writtendata
+ self.returneddata = returneddata
- def status(self, check):
- """
- @type check: int or str
- """
+ def status(self, check: typing.Union[int, str]) -> None:
if isinstance(check, int):
+ assert self.statusdata is not None
status = int(self.statusdata.split()[0])
self.testcase.assertEqual(check, status)
else:
self.testcase.assertEqual(check, self.statusdata)
- def getheader(self, name):
+ def getheader(self, name: str) -> str:
"""
- @type name: str
@raises KeyError:
"""
+ assert self.headersdata is not None
for key, value in self.headersdata:
if key == name:
return value
raise KeyError
- def header(self, name, check):
- """
- @type name: str
- @type check: str or (str -> bool)
- """
+ def header(
+ self, name: str, check: typing.Union[str, typing.Callable[[str], bool]]
+ ) -> None:
+ assert self.headersdata is not None
found = False
for key, value in self.headersdata:
if key == name:
@@ -124,23 +137,23 @@ class Result:
if not found:
self.testcase.fail("header %s not found" % name)
- def get_data(self):
+ def get_data(self) -> bytes:
return b"".join(self.writtendata) + b"".join(self.returneddata)
from wsgitools import applications
class StaticContentTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.app = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"nothing")
self.req = Request(self)
- def testGet(self):
+ def testGet(self) -> None:
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testHead(self):
+ def testHead(self) -> None:
req = self.req.copy()
req.setmethod("HEAD")
res = req(self.app)
@@ -148,17 +161,17 @@ class StaticContentTest(unittest.TestCase):
res.header("Content-length", "7")
class StaticFileTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.app = applications.StaticFile(io.BytesIO(b"success"), "200 Found",
[("Content-Type", "text/plain")])
self.req = Request(self)
- def testGet(self):
+ def testGet(self) -> None:
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testHead(self):
+ def testHead(self) -> None:
req = self.req.copy()
req.setmethod("HEAD")
res = req(self.app)
@@ -168,7 +181,7 @@ class StaticFileTest(unittest.TestCase):
from wsgitools import digest
class AuthDigestMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.staticapp = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"success")
token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz")
@@ -176,26 +189,26 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
wsgiref.validate.validator(self.staticapp), token_gen)
self.req = Request(self)
- def test401(self):
+ def test401(self) -> None:
res = self.req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401garbage(self):
+ def test401garbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Garbage')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401digestgarbage(self):
+ def test401digestgarbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Digest ","')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def doauth(self, password="baz", status=200):
+ def doauth(self, password: str = "baz", status: int = 200) -> None:
res = self.req(self.app)
nonce = next(iter(filter(lambda x: x.startswith("nonce="),
res.getheader("WWW-Authenticate").split())))
@@ -209,13 +222,13 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
res = req(self.app)
res.status(status)
- def test200(self):
+ def test200(self) -> None:
self.doauth()
- def test401authfail(self):
+ def test401authfail(self) -> None:
self.doauth(password="spam", status=401)
- def testqopauth(self):
+ def testqopauth(self) -> None:
res = self.req(self.app)
nonce = next(iter(filter(lambda x: x.startswith("nonce="),
res.getheader("WWW-Authenticate").split())))
@@ -233,27 +246,31 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
from wsgitools import middlewares
-def writing_application(environ, start_response):
+def writing_application(
+ environ: Environ, start_response: StartResponse
+) -> typing.Iterable[bytes]:
write = start_response("404 Not found", [("Content-Type", "text/plain")])
write = start_response("200 Ok", [("Content-Type", "text/plain")])
write(b"first")
yield b""
yield b"second"
-def write_only_application(environ, start_response):
+def write_only_application(
+ environ: Environ, start_response: StartResponse
+) -> typing.Iterable[bytes]:
write = start_response("200 Ok", [("Content-Type", "text/plain")])
write(b"first")
write(b"second")
yield b""
class NoWriteCallableMiddlewareTest(unittest.TestCase):
- def testWrite(self):
+ def testWrite(self) -> None:
app = middlewares.NoWriteCallableMiddleware(writing_application)
res = Request(self)(app)
self.assertEqual(res.writtendata, [])
self.assertEqual(b"".join(res.returneddata), b"firstsecond")
- def testWriteOnly(self):
+ def testWriteOnly(self) -> None:
app = middlewares.NoWriteCallableMiddleware(write_only_application)
res = Request(self)(app)
self.assertEqual(res.writtendata, [])
@@ -262,28 +279,30 @@ class NoWriteCallableMiddlewareTest(unittest.TestCase):
class StupidIO:
"""file-like without tell method, so StaticFile is not able to
determine the content-length."""
- def __init__(self, content):
+ def __init__(self, content: bytes):
self.content = content
self.position = 0
- def seek(self, pos):
+ def seek(self, pos: int) -> None:
assert pos == 0
self.position = 0
- def read(self, length):
+ def read(self, length: int) -> bytes:
oldpos = self.position
self.position += length
return self.content[oldpos:self.position]
class ContentLengthMiddlewareTest(unittest.TestCase):
- def customSetUp(self, maxstore=10):
+ def customSetUp(
+ self, maxstore: typing.Union[int, typing.Tuple[()]] = 10
+ ) -> None:
self.staticapp = applications.StaticFile(StupidIO(b"success"),
"200 Found", [("Content-Type", "text/plain")])
self.app = middlewares.ContentLengthMiddleware(self.staticapp,
maxstore=maxstore)
self.req = Request(self)
- def testWithout(self):
+ def testWithout(self) -> None:
self.customSetUp()
res = self.req(self.staticapp)
res.status("200 Found")
@@ -293,24 +312,26 @@ class ContentLengthMiddlewareTest(unittest.TestCase):
except KeyError:
pass
- def testGet(self):
+ def testGet(self) -> None:
self.customSetUp()
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testInfiniteMaxstore(self):
+ def testInfiniteMaxstore(self) -> None:
self.customSetUp(maxstore=())
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
class CachingMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.cached = middlewares.CachingMiddleware(self.app)
- self.accessed = dict()
+ self.accessed: typing.Dict[str, int] = {}
- def app(self, environ, start_response):
+ def app(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
count = self.accessed.get(environ["SCRIPT_NAME"], 0) + 1
self.accessed[environ["SCRIPT_NAME"]] = count
headers = [("Content-Type", "text/plain")]
@@ -319,7 +340,7 @@ class CachingMiddlewareTest(unittest.TestCase):
start_response("200 Found", headers)
return [b"%d" % count]
- def testCache(self):
+ def testCache(self) -> None:
res = Request(self)(self.cached)
res.status(200)
self.assertEqual(res.get_data(), b"1")
@@ -327,7 +348,7 @@ class CachingMiddlewareTest(unittest.TestCase):
res.status(200)
self.assertEqual(res.get_data(), b"1")
- def testNoCache(self):
+ def testNoCache(self) -> None:
res = Request(self)(self.cached)
res.status(200)
self.assertEqual(res.get_data(), b"1")
@@ -337,7 +358,7 @@ class CachingMiddlewareTest(unittest.TestCase):
self.assertEqual(res.get_data(), b"2")
class BasicAuthMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.staticapp = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"success")
checkpw = middlewares.DictAuthChecker({"bar": "baz"})
@@ -345,26 +366,26 @@ class BasicAuthMiddlewareTest(unittest.TestCase):
wsgiref.validate.validator(self.staticapp), checkpw)
self.req = Request(self)
- def test401(self):
+ def test401(self) -> None:
res = self.req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401garbage(self):
+ def test401garbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Garbage')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401basicgarbage(self):
+ def test401basicgarbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Basic ()')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def doauth(self, password="baz", status=200):
+ def doauth(self, password: str = "baz", status: int = 200) -> None:
req = self.req.copy()
token = "bar:%s" % password
token = bytes2str(base64.b64encode(str2bytes(token)))
@@ -372,19 +393,20 @@ class BasicAuthMiddlewareTest(unittest.TestCase):
res = req(self.app)
res.status(status)
- def test200(self):
+ def test200(self) -> None:
self.doauth()
- def test401authfail(self):
+ def test401authfail(self) -> None:
self.doauth(password="spam", status=401)
from wsgitools import filters
import gzip
class RequestLogWSGIFilterTest(unittest.TestCase):
- def testSimple(self):
- app = applications.StaticContent("200 Found",
- [("Content-Type", "text/plain")], b"nothing")
+ def testSimple(self) -> None:
+ app: WsgiApp = applications.StaticContent(
+ "200 Found", [("Content-Type", "text/plain")], b"nothing"
+ )
log = io.StringIO()
logfilter = filters.RequestLogWSGIFilter.creator(log)
app = filters.WSGIFilterMiddleware(app, logfilter)
@@ -398,9 +420,10 @@ class RequestLogWSGIFilterTest(unittest.TestCase):
r'200 7 - "wsgitools-test"', logged))
class GzipWSGIFilterTest(unittest.TestCase):
- def testSimple(self):
- app = applications.StaticContent("200 Found",
- [("Content-Type", "text/plain")], b"nothing")
+ def testSimple(self) -> None:
+ app: WsgiApp = applications.StaticContent(
+ "200 Found", [("Content-Type", "text/plain")], b"nothing"
+ )
app = filters.WSGIFilterMiddleware(app, filters.GzipWSGIFilter)
req = Request(self)
req.environ["HTTP_ACCEPT_ENCODING"] = "gzip"
@@ -413,7 +436,9 @@ import asyncore
import socket
import threading
-def fetch_scgi(port, req, body=b""):
+def fetch_scgi(
+ port: int, req: typing.Dict[str, str], body: bytes = b""
+) -> bytes:
with socket.socket() as client:
client.connect(("localhost", port))
req = req.copy()
@@ -423,7 +448,7 @@ def fetch_scgi(port, req, body=b""):
return client.recv(65536)
class ScgiAsynchronousTest(unittest.TestCase):
- def testSimple(self):
+ def testSimple(self) -> None:
app = applications.StaticContent(
"200 OK", [("Content-Type", "text/plain")], b"nothing"
)
@@ -447,7 +472,7 @@ import os
import signal
class ScgiForkTest(unittest.TestCase):
- def testSimple(self):
+ def testSimple(self) -> None:
app = applications.StaticContent(
"200 OK", [("Content-Type", "text/plain")], b"nothing"
)
diff --git a/wsgitools/applications.py b/wsgitools/applications.py
index 9894cf8..f51fccf 100644
--- a/wsgitools/applications.py
+++ b/wsgitools/applications.py
@@ -1,4 +1,7 @@
import os.path
+import typing
+
+from wsgitools.internal import Environ, HeaderList, StartResponse
__all__ = []
@@ -9,17 +12,21 @@ class StaticContent:
receives with method GET or HEAD (content stripped). If not present, a
content-length header is computed.
"""
- def __init__(self, status, headers, content, anymethod=False):
+ content: typing.Iterable[bytes]
+
+ def __init__(
+ self,
+ status: str,
+ headers: HeaderList,
+ content: typing.Union[bytes, typing.Iterable[bytes]],
+ anymethod: bool = False,
+ ):
"""
- @type status: str
@param status: is the HTTP status returned to the browser (ex: "200 OK")
- @type headers: list
@param headers: is a list of C{(header, value)} pairs being delivered as
HTTP headers
- @type content: bytes
@param content: contains the data to be delivered to the client. It is
either a string or some kind of iterable yielding strings.
- @type anymethod: boolean
@param anymethod: determines whether any request method should be
answered with this response instead of a 501
"""
@@ -40,7 +47,9 @@ class StaticContent:
if length >= 0:
if not [v for h, v in headers if h.lower() == "content-length"]:
headers.append(("Content-length", str(length)))
- def __call__(self, environ, start_response):
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
"""wsgi interface"""
assert isinstance(environ, dict)
if environ["REQUEST_METHOD"].upper() not in ["GET", "HEAD"] and \
@@ -61,20 +70,21 @@ class StaticFile:
request it receives with method GET or HEAD (content stripped). If not
present, a content-length header is computed.
"""
- def __init__(self, filelike, status="200 OK", headers=list(),
- blocksize=4096):
+ def __init__(
+ self,
+ filelike: typing.Union[str, typing.BinaryIO],
+ status: str = "200 OK",
+ headers: HeaderList = list(),
+ blocksize: int = 4096,
+ ):
"""
- @type status: str
@param status: is the HTTP status returned to the browser
- @type headers: [(str, str)]
@param headers: is a list of C{(header, value)} pairs being delivered as
HTTP headers
- @type filelike: str or file-like
@param filelike: may either be an path in the local file system or a
file-like that must support C{read(size)} and C{seek(0)}. If
C{tell()} is present, C{seek(0, 2)} and C{tell()} will be used
to compute the content-length.
- @type blocksize: int
@param blocksize: the content is provided in chunks of this size
"""
self.filelike = filelike
@@ -82,7 +92,9 @@ class StaticFile:
self.headers = headers
self.blocksize = blocksize
- def _serve_in_chunks(self, stream):
+ def _serve_in_chunks(
+ self, stream: typing.BinaryIO
+ ) -> typing.Iterator[bytes]:
"""internal method yielding data from the given stream"""
while True:
data = stream.read(self.blocksize)
@@ -92,7 +104,9 @@ class StaticFile:
if isinstance(self.filelike, str):
stream.close()
- def __call__(self, environ, start_response):
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
"""wsgi interface"""
assert isinstance(environ, dict)
@@ -102,7 +116,7 @@ class StaticFile:
[("Content-length", str(len(resp)))])
return [resp]
- stream = None
+ stream: typing.Optional[typing.BinaryIO] = None
size = -1
try:
if isinstance(self.filelike, str):
diff --git a/wsgitools/authentication.py b/wsgitools/authentication.py
index c076d7f..345a0ae 100644
--- a/wsgitools/authentication.py
+++ b/wsgitools/authentication.py
@@ -1,3 +1,9 @@
+import typing
+
+from wsgitools.internal import (
+ Environ, HeaderList, OptExcInfo, StartResponse, WriteCallback, WsgiApp
+)
+
__all__ = []
class AuthenticationRequired(Exception):
@@ -15,28 +21,24 @@ class AuthenticationMiddleware:
@cvar authorization_method: the implemented Authorization method. It will
be verified against Authorization headers. Subclasses must define this
attribute.
- @type authorization_method: str
"""
- authorization_method = None
- def __init__(self, app):
+ authorization_method: typing.ClassVar[str]
+ def __init__(self, app: WsgiApp):
"""
@param app: is a WSGI application.
"""
assert self.authorization_method is not None
self.app = app
- def authenticate(self, auth, environ):
+ def authenticate(self, auth: str, environ: Environ) -> Environ:
"""Try to authenticate a request. The Authorization header is examined
and checked agains the L{authorization_method} before being passed to
this method. This method must either raise an AuthenticationRequired
instance or return a dictionary explaining what was successfully
authenticated.
- @type auth: str
@param auth: is the part of the Authorization header after the method
- @type environ: {str: object}
@param environ: is the environment passed with a WSGI request
- @rtype: {str: object}
@returns: a dictionary that provides a key "user" listing the
authenticated username as a string. It may also provide the key
"outheaders" with a [(str, str)] value to extend the response
@@ -45,11 +47,10 @@ class AuthenticationMiddleware:
"""
raise NotImplementedError
- def __call__(self, environ, start_response):
- """wsgi interface
-
- @type environ: {str: object}
- """
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
+ """wsgi interface"""
assert isinstance(environ, dict)
try:
try:
@@ -70,7 +71,9 @@ class AuthenticationMiddleware:
assert "user" in result
environ["REMOTE_USER"] = result["user"]
if "outheaders" in result:
- def modified_start_response(status, headers, exc_info=None):
+ def modified_start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
assert isinstance(headers, list)
headers.extend(result["outheaders"])
return start_response(status, headers, exc_info)
@@ -78,22 +81,26 @@ class AuthenticationMiddleware:
modified_start_response = start_response
return self.app(environ, modified_start_response)
- def www_authenticate(self, exception):
+ def www_authenticate(
+ self, exception: AuthenticationRequired
+ ) -> typing.Tuple[str, str]:
"""Generates a WWW-Authenticate header. Subclasses must implement this
method.
- @type exception: L{AuthenticationRequired}
@param exception: reason for generating the header
- @rtype: (str, str)
@returns: the header as (part_before_colon, part_after_colon)
"""
raise NotImplementedError
- def authorization_required(self, environ, start_response, exception):
+ def authorization_required(
+ self,
+ environ: Environ,
+ start_response: StartResponse,
+ exception: AuthenticationRequired,
+ ) -> typing.Iterable[bytes]:
"""Generate an error page after failed authentication. Apart from the
exception parameter, this method behaves like a WSGI application.
- @type exception: L{AuthenticationRequired}
@param exception: reason for the authentication failure
"""
status = "401 Authorization required"
diff --git a/wsgitools/digest.py b/wsgitools/digest.py
index 6eb4cb3..18925df 100644
--- a/wsgitools/digest.py
+++ b/wsgitools/digest.py
@@ -23,25 +23,22 @@ except ImportError:
import random
sysrand = random.SystemRandom()
randbits = sysrand.getrandbits
- def compare_digest(a, b):
+ def compare_digest(a: str, b: str) -> bool:
return a == b
+import sys
+import typing
-from wsgitools.internal import bytes2str, str2bytes, textopen
+from wsgitools.internal import bytes2str, Environ, str2bytes, textopen, WsgiApp
from wsgitools.authentication import AuthenticationRequired, \
ProtocolViolation, AuthenticationMiddleware
-def md5hex(data):
- """
- @type data: str
- @rtype: str
- """
+def md5hex(data: str) -> str:
return hashlib.md5(str2bytes(data)).hexdigest()
-def gen_rand_str(bytesentropy=33):
+def gen_rand_str(bytesentropy: int = 33) -> str:
"""
Generates a string of random base64 characters.
@param bytesentropy: is the number of random 8bit values to be used
- @rtype: str
>>> gen_rand_str() != gen_rand_str()
True
@@ -53,7 +50,7 @@ def gen_rand_str(bytesentropy=33):
randstr = bytes2str(randbytes)
return randstr
-def parse_digest_response(data):
+def parse_digest_response(data: str) -> typing.Dict[str, str]:
"""internal
@raises ValueError:
@@ -118,13 +115,11 @@ def parse_digest_response(data):
value, data = data.split(',', 1)
result[key] = value
-def format_digest(mapping):
+def format_digest(mapping: typing.Dict[str, typing.Tuple[str, bool]]) -> str:
"""internal
- @type mapping: {str: (str, bool)}
@param mapping: a mapping of keys to values and a boolean that
determines whether the value needs quoting.
- @rtype: str
@note: the RFC specifies which values must be quoted and which must not be
quoted.
"""
@@ -150,38 +145,36 @@ class AbstractTokenGenerator:
L{AuthDigestMiddleware}.
@ivar realm: is a string according to RFC2617.
- @type realm: str
"""
- def __init__(self, realm):
- """
- @type realm: str
- """
+ realm: str
+ def __init__(self, realm: str):
assert isinstance(realm, str)
self.realm = realm
- def __call__(self, username, algo="md5"):
+ def __call__(
+ self, username: str, algo: str = "md5"
+ ) -> typing.Optional[str]:
"""Generates an authentication token from a username.
- @type username: str
- @type algo: str
@param algo: currently the only value supported by
L{AuthDigestMiddleware} is "md5"
- @rtype: str or None
@returns: a valid token or None to signal that authentication should
fail
"""
raise NotImplementedError
- def check_password(self, username, password, environ=None):
+ def check_password(
+ self,
+ username: str,
+ password: str,
+ environ: typing.Optional[Environ] = None,
+ ) -> bool:
"""
This function implements the interface for verifying passwords
used by L{BasicAuthMiddleware}. It works by computing a token
from the user and comparing it to the token returned by the
__call__ method.
- @type username: str
- @type password: str
@param environ: ignored
- @rtype: bool
"""
assert isinstance(username, str)
assert isinstance(password, str)
@@ -191,16 +184,26 @@ class AbstractTokenGenerator:
return False
return compare_digest(md5hex(token), expected)
+if sys.version_info >= (3, 11):
+ class TokenGenerator(typing.Protocol):
+ realm: str
+ def __call__(
+ self, username: str, algo: str = "md5"
+ ) -> typing.Optional[str]:
+ ...
+else:
+ TokenGenerator = typing.Callable[[str, str], typing.Optional[str]]
+
__all__.append("AuthTokenGenerator")
class AuthTokenGenerator(AbstractTokenGenerator):
"""Generates authentication tokens for L{AuthDigestMiddleware}. The
interface consists of beeing callable with a username and having a
realm attribute being a string."""
- def __init__(self, realm, getpass):
+ def __init__(
+ self, realm: str, getpass: typing.Callable[[str], typing.Optional[str]]
+ ):
"""
- @type realm: str
@param realm: is a string according to RFC2617.
- @type getpass: str -> (str or None)
@param getpass: this function is called with a username and password is
expected as result. C{None} may be used as an invalid password.
An example for getpass would be C{{username: password}.get}.
@@ -208,7 +211,9 @@ class AuthTokenGenerator(AbstractTokenGenerator):
AbstractTokenGenerator.__init__(self, realm)
self.getpass = getpass
- def __call__(self, username, algo="md5"):
+ def __call__(
+ self, username: str, algo: str = "md5"
+ ) -> typing.Optional[str]:
assert isinstance(username, str)
assert algo.lower() in ["md5", "md5-sess"]
password = self.getpass(username)
@@ -222,12 +227,13 @@ class HtdigestTokenGenerator(AbstractTokenGenerator):
"""Reads authentication tokens for L{AuthDigestMiddleware} from an
apache htdigest file.
"""
- def __init__(self, realm, htdigestfile, ignoreparseerrors=False):
+ users: typing.Dict[str, str]
+
+ def __init__(
+ self, realm: str, htdigestfile: str, ignoreparseerrors: bool = False
+ ):
"""
- @type realm: str
- @type htdigestfile: str
@param htdigestfile: path to the .htdigest file
- @type ignoreparseerrors: bool
@param ignoreparseerrors: passed to readhtdigest
@raises IOError:
@raises ValueError:
@@ -236,10 +242,10 @@ class HtdigestTokenGenerator(AbstractTokenGenerator):
self.users = {}
self.readhtdigest(htdigestfile, ignoreparseerrors)
- def readhtdigest(self, htdigestfile, ignoreparseerrors=False):
+ def readhtdigest(
+ self, htdigestfile: str, ignoreparseerrors: bool = False
+ ) -> None:
"""
- @type htdigestfile: str
- @type ignoreparseerrors: bool
@param ignoreparseerrors: do not raise ValueErrors for bad files
@raises IOError:
@raises ValueError:
@@ -260,7 +266,7 @@ class HtdigestTokenGenerator(AbstractTokenGenerator):
raise ValueError("duplicate user in htdigest file")
self.users[user] = token
- def __call__(self, user, algo="md5"):
+ def __call__(self, user: str, algo: str = "md5") -> typing.Optional[str]:
assert algo.lower() in ["md5", "md5-sess"]
return self.users.get(user)
@@ -269,7 +275,9 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator):
"""Behaves like L{HtdigestTokenGenerator}, checks the htdigest file
for changes on each invocation.
"""
- def __init__(self, realm, htdigestfile, ignoreparseerrors=False):
+ def __init__(
+ self, realm: str, htdigestfile: str, ignoreparseerrors: bool = False
+ ):
assert isinstance(htdigestfile, str)
# Need to stat the file before calling parent ctor to detect
# modifications.
@@ -282,7 +290,7 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator):
self.htdigestfile = htdigestfile
self.ignoreparseerrors = ignoreparseerrors
- def __call__(self, user, algo="md5"):
+ def __call__(self, user: str, algo: str = "md5") -> typing.Optional[str]:
# The interface does not permit raising exceptions, so all we can do is
# fail by returning None.
try:
@@ -301,36 +309,30 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator):
__all__.append("NonceStoreBase")
class NonceStoreBase:
"""Nonce storage interface."""
- def __init__(self):
+ def __init__(self) -> None:
pass
- def newnonce(self, ident=None):
+ def newnonce(self, ident: typing.Optional[str] = None) -> str:
"""
This method is to be overriden and should return new nonces.
- @type ident: str
@param ident: is an identifier to be associated with this nonce
- @rtype: str
"""
raise NotImplementedError
- def checknonce(self, nonce, count=1, ident=None):
+ def checknonce(
+ self, nonce: str, count: int = 1, ident: typing.Optional[str] = None
+ ) -> bool:
"""
This method is to be overridden and should do a check for whether the
given nonce is valid as being used count times.
- @type nonce: str
- @type count: int
@param count: indicates how often the nonce has been used (including
this check)
- @type ident: str
@param ident: it is also checked that the nonce was associated to this
identifier when given
- @rtype: bool
"""
raise NotImplementedError
-def format_time(seconds):
+def format_time(seconds: float) -> str:
"""
internal method formatting a unix time to a fixed-length string
- @type seconds: float
- @rtype: str
"""
# the overflow will happen about 2112
return "%013X" % int(seconds * 1000000)
@@ -356,13 +358,11 @@ class StatelessNonceStore(NonceStoreBase):
>>> s.checknonce(n.rsplit(':', 1)[0] + "bad hash")
False
"""
- def __init__(self, maxage=300, secret=None):
+ def __init__(self, maxage: int = 300, secret: typing.Optional[str] = None):
"""
- @type maxage: int
@param maxage: is the number of seconds a nonce may be valid. Choosing a
large value may result in more memory usage whereas a smaller
value results in more requests. Defaults to 5 minutes.
- @type secret: str
@param secret: if not given, a secret is generated and is therefore
shared after forks. Knowing this secret permits creating nonces.
"""
@@ -373,12 +373,8 @@ class StatelessNonceStore(NonceStoreBase):
else:
self.server_secret = gen_rand_str()
- def newnonce(self, ident=None):
- """
- Generates a new nonce string.
- @type ident: None or str
- @rtype: str
- """
+ def newnonce(self, ident: typing.Optional[str] = None) -> str:
+ """Generates a new nonce string."""
nonce_time = format_time(time.time())
nonce_value = gen_rand_str()
token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret)
@@ -387,14 +383,10 @@ class StatelessNonceStore(NonceStoreBase):
token = md5hex(token)
return "%s:%s:%s" % (nonce_time, nonce_value, token)
- def checknonce(self, nonce, count=1, ident=None):
- """
- Check whether the provided string is a nonce.
- @type nonce: str
- @type count: int
- @type ident: None or str
- @rtype: bool
- """
+ def checknonce(
+ self, nonce: str, count: int = 1, ident: typing.Optional[str] = None
+ ) -> bool:
+ """Check whether the provided string is a nonce."""
if count != 1:
return False
try:
@@ -429,13 +421,13 @@ class MemoryNonceStore(NonceStoreBase):
>>> s.checknonce(n.rsplit(':', 1)[0] + "bad hash")
False
"""
- def __init__(self, maxage=300, maxuses=5):
+ nonces: typing.List[typing.Tuple[str, str, int]]
+
+ def __init__(self, maxage: int = 300, maxuses: int = 5):
"""
- @type maxage: int
@param maxage: is the number of seconds a nonce may be valid. Choosing a
large value may result in more memory usage whereas a smaller
value results in more requests. Defaults to 5 minutes.
- @type maxuses: int
@param maxuses: is the number of times a nonce may be used (with
different nc values). A value of 1 makes nonces usable exactly
once resulting in more requests. Defaults to 5.
@@ -447,18 +439,14 @@ class MemoryNonceStore(NonceStoreBase):
# as [(str (hex encoded), str, int)]
self.server_secret = gen_rand_str()
- def _cleanup(self):
+ def _cleanup(self) -> None:
"""internal methods cleaning list of valid nonces"""
old = format_time(time.time() - self.maxage)
while self.nonces and self.nonces[0][0] < old:
self.nonces.pop(0)
- def newnonce(self, ident=None):
- """
- Generates a new nonce string.
- @type ident: None or str
- @rtype: str
- """
+ def newnonce(self, ident: typing.Optional[str] = None) -> str:
+ """Generates a new nonce string."""
self._cleanup() # avoid growing self.nonces
nonce_time = format_time(time.time())
nonce_value = gen_rand_str()
@@ -469,14 +457,12 @@ class MemoryNonceStore(NonceStoreBase):
token = md5hex(token)
return "%s:%s:%s" % (nonce_time, nonce_value, token)
- def checknonce(self, nonce, count=1, ident=None):
+ def checknonce(
+ self, nonce: str, count: int = 1, ident: typing.Optional[str] = None
+ ) -> bool:
"""
Do a check for whether the provided string is a nonce and increase usage
count on returning True.
- @type nonce: str
- @type count: int
- @type ident: None or str
- @rtype: bool
"""
try:
nonce_time, nonce_value, nonce_hash = nonce.split(':')
@@ -522,19 +508,28 @@ class LazyDBAPI2Opener:
because this way each worker child opens a new database connection when
the first request is to be answered.
"""
- def __init__(self, function, *args, **kwargs):
+ _function: typing.Optional[typing.Callable[..., typing.Any]]
+ def __init__(
+ self,
+ function: typing.Callable[..., typing.Any],
+ *args,
+ **kwargs,
+ ):
"""
The database will be connected on the first method call. This is done
by calling the given function with the remaining parameters.
@param function: is the function that connects to the database
"""
self._function = function
- self._args = args
- self._kwargs = kwargs
+ self._args: typing.Optional[typing.Tuple[typing.Any, ...]] = args
+ self._kwargs: typing.Optional[typing.Dict[str, typing.Any]] = kwargs
self._dbhandle = None
- def _getdbhandle(self):
+ def _getdbhandle(self) -> typing.Any:
"""Returns an open database connection. Open if necessary."""
if self._dbhandle is None:
+ assert self._function is not None
+ assert self._args is not None
+ assert self._kwargs is not None
self._dbhandle = self._function(*self._args, **self._kwargs)
self._function = self._args = self._kwargs = None
return self._dbhandle
@@ -573,14 +568,14 @@ class DBAPI2NonceStore(NonceStoreBase):
>>> s.checknonce(n.rsplit(':', 1)[0] + "bad hash")
False
"""
- def __init__(self, dbhandle, maxage=300, maxuses=5, table="nonces"):
+ def __init__(
+ self, dbhandle, maxage: int = 300, maxuses: int = 5, table="nonces"
+ ):
"""
@param dbhandle: is a dbapi2 connection
- @type maxage: int
@param maxage: is the number of seconds a nonce may be valid. Choosing a
large value may result in more memory usage whereas a smaller
value results in more requests. Defaults to 5 minutes.
- @type maxuses: int
@param maxuses: is the number of times a nonce may be used (with
different nc values). A value of 1 makes nonces usable exactly
once resulting in more requests. Defaults to 5.
@@ -592,16 +587,13 @@ class DBAPI2NonceStore(NonceStoreBase):
self.table = table
self.server_secret = gen_rand_str()
- def _cleanup(self, cur):
+ def _cleanup(self, cur) -> None:
"""internal methods cleaning list of valid nonces"""
old = format_time(time.time() - self.maxage)
cur.execute("DELETE FROM %s WHERE key < '%s:';" % (self.table, old))
- def newnonce(self, ident=None):
- """
- Generates a new nonce string.
- @rtype: str
- """
+ def newnonce(self, ident: typing.Optional[str] = None) -> str:
+ """Generates a new nonce string."""
nonce_time = format_time(time.time())
nonce_value = gen_rand_str()
dbkey = "%s:%s" % (nonce_time, nonce_value)
@@ -615,14 +607,12 @@ class DBAPI2NonceStore(NonceStoreBase):
token = md5hex(token)
return "%s:%s:%s" % (nonce_time, nonce_value, token)
- def checknonce(self, nonce, count=1, ident=None):
+ def checknonce(
+ self, nonce: str, count: int = 1, ident: typing.Optional[str] = None
+ ) -> bool:
"""
Do a check for whether the provided string is a nonce and increase usage
count on returning True.
- @type nonce: str
- @type count: int
- @type ident: str or None
- @rtype: bool
"""
try:
nonce_time, nonce_value, nonce_hash = nonce.split(':')
@@ -668,7 +658,7 @@ class DBAPI2NonceStore(NonceStoreBase):
self.dbhandle.commit()
return True
-def check_uri(credentials, environ):
+def check_uri(credentials: typing.Dict[str, str], environ: Environ) -> None:
"""internal method for verifying the uri credential
@raises AuthenticationRequired:
"""
@@ -706,19 +696,23 @@ class AuthDigestMiddleware(AuthenticationMiddleware):
application."""
authorization_method = "digest"
algorithms = {"md5": md5hex}
- def __init__(self, app, gentoken, maxage=300, maxuses=5, store=None):
+ noncestore: NonceStoreBase
+ def __init__(
+ self,
+ app: WsgiApp,
+ gentoken: TokenGenerator,
+ maxage: int = 300,
+ maxuses: int = 5,
+ store: typing.Optional[NonceStoreBase] = None,
+ ):
"""
@param app: is the wsgi application to be served with authentication.
- @type gentoken: str -> (str or None)
@param gentoken: has to have the same functionality and interface as the
L{AuthTokenGenerator} class.
- @type maxage: int
@param maxage: deprecated, see L{MemoryNonceStore} or
L{StatelessNonceStore} and pass an instance to store
- @type maxuses: int
@param maxuses: deprecated, see L{MemoryNonceStore} and pass an
instance to store
- @type store: L{NonceStoreBase}
@param store: a nonce storage implementation object. Usage of this
parameter will override maxage and maxuses.
"""
@@ -731,7 +725,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware):
assert hasattr(store, "checknonce")
self.noncestore = store
- def authenticate(self, auth, environ):
+ def authenticate(self, auth: str, environ: Environ) -> Environ:
assert isinstance(auth, str)
try:
credentials = parse_digest_response(auth)
@@ -783,7 +777,9 @@ class AuthDigestMiddleware(AuthenticationMiddleware):
return dict(user=credentials["username"],
outheaders=[("Authentication-Info", format_digest(digest))])
- def auth_response(self, credentials, reqmethod):
+ def auth_response(
+ self, credentials: typing.Dict[str, str], reqmethod: str
+ ) -> typing.Optional[str]:
"""internal method generating authentication tokens
@raises AuthenticationRequired:
"""
@@ -818,7 +814,9 @@ class AuthDigestMiddleware(AuthenticationMiddleware):
dig.insert(0, a1h)
return self.algorithms[algo](":".join(dig))
- def www_authenticate(self, exception):
+ def www_authenticate(
+ self, exception: AuthenticationRequired
+ ) -> typing.Tuple[str, str]:
digest = dict(nonce=(self.noncestore.newnonce(), True),
realm=(self.gentoken.realm, True),
algorithm=("MD5", False),
diff --git a/wsgitools/filters.py b/wsgitools/filters.py
index 7f8543d..eada536 100644
--- a/wsgitools/filters.py
+++ b/wsgitools/filters.py
@@ -11,27 +11,37 @@ import sys
import time
import gzip
import io
+import typing
-from wsgitools.internal import str2bytes
+from wsgitools.internal import (
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ StartResponse,
+ str2bytes,
+ WriteCallback,
+ WsgiApp,
+)
__all__.append("CloseableIterator")
class CloseableIterator:
"""Concatenating iterator with close attribute."""
- def __init__(self, close_function, *iterators):
+ def __init__(
+ self,
+ close_function: typing.Optional[typing.Callable[[], None]],
+ *iterators: typing.Iterable[typing.Any]
+ ):
"""If close_function is not C{None}, it will be the C{close} attribute
of the created iterator object. Further parameters specify iterators
that are to be concatenated.
- @type close_function: a function or C{None}
"""
if close_function is not None:
self.close = close_function
self.iterators = list(map(iter, iterators))
- def __iter__(self):
- """iterator interface
- @rtype: gen()
- """
+ def __iter__(self) -> typing.Iterator[typing.Any]:
+ """iterator interface"""
return self
- def __next__(self):
+ def __next__(self) -> typing.Any:
"""iterator interface"""
if not self.iterators:
raise StopIteration
@@ -45,16 +55,17 @@ class CloseableIterator:
__all__.append("CloseableList")
class CloseableList(list):
"""A list with a close attribute."""
- def __init__(self, close_function, *args):
+ def __init__(
+ self, close_function: typing.Optional[typing.Callable[[], None]], *args
+ ):
"""If close_function is not C{None}, it will be the C{close} attribute
of the created list object. Other parameters are passed to the list
constructor.
- @type close_function: a function or C{None}
"""
if close_function is not None:
self.close = close_function
list.__init__(self, *args)
- def __iter__(self):
+ def __iter__(self) -> CloseableIterator:
"""iterator interface"""
return CloseableIterator(getattr(self, "close", None),
list.__iter__(self))
@@ -76,104 +87,101 @@ class BaseWSGIFilter:
L{BaseWSGIFilter} class to a L{WSGIFilterMiddleware} will result in not
modifying requests at all.
"""
- def __init__(self):
+ def __init__(self) -> None:
"""This constructor does nothing and can safely be overwritten. It is
only listed here to document that it must be callable without additional
parameters."""
- def filter_environ(self, environ):
+ def filter_environ(self, environ: Environ) -> Environ:
"""Receives a dict with the environment passed to the wsgi application
and a C{dict} must be returned. The default is to return the same dict.
- @type environ: {str: str}
- @rtype: {str: str}
"""
return environ
- def filter_exc_info(self, exc_info):
+ def filter_exc_info(self, exc_info: OptExcInfo) -> OptExcInfo:
"""Receives either C{None} or a tuple passed as third argument to
C{start_response} from the wrapped wsgi application. Either C{None} or
such a tuple must be returned."""
return exc_info
- def filter_status(self, status):
+ def filter_status(self, status: str) -> str:
"""Receives a status string passed as first argument to
C{start_response} from the wrapped wsgi application. A valid HTTP status
string must be returned.
- @type status: str
- @rtype: str
"""
return status
- def filter_header(self, headername, headervalue):
+ def filter_header(
+ self, headername: str, headervalue: str
+ ) -> typing.Optional[typing.Tuple[str, str]]:
"""This function is invoked for each C{(headername, headervalue)} tuple
in the second argument to the C{start_response} from the wrapped wsgi
application. Such a value or C{None} for discarding the header must be
returned.
- @type headername: str
- @type headervalue: str
- @rtype: (str, str)
"""
return (headername, headervalue)
- def filter_headers(self, headers):
+ def filter_headers(self, headers: HeaderList) -> HeaderList:
"""A list of headers passed as the second argument to the
C{start_response} from the wrapped wsgi application is passed to this
function and such a list must also be returned.
- @type headers: [(str, str)]
- @rtype: [(str, str)]
"""
return headers
- def filter_data(self, data):
+ def filter_data(self, data: bytes) -> bytes:
"""For each string that is either written by the C{write} callable or
returned from the wrapped wsgi application this method is invoked. It
must return a string.
- @type data: bytes
- @rtype: bytes
"""
return data
- def append_data(self):
+ def append_data(self) -> typing.Iterable[bytes]:
"""This function can be used to append data to the response. A list of
strings or some kind of iterable yielding strings has to be returned.
The default is to return an empty list.
- @rtype: gen([bytes])
"""
return []
- def handle_close(self):
+ def handle_close(self) -> None:
"""This method is invoked after the request has finished."""
__all__.append("WSGIFilterMiddleware")
class WSGIFilterMiddleware:
"""This wsgi middleware can be used with specialized L{BaseWSGIFilter}s to
modify wsgi requests and/or reponses."""
- def __init__(self, app, filterclass):
+ def __init__(
+ self, app: WsgiApp, filterclass: typing.Callable[[], BaseWSGIFilter]
+ ):
"""
@param app: is a wsgi application.
- @type filterclass: L{BaseWSGIFilter}s subclass
- @param filterclass: is a subclass of L{BaseWSGIFilter} or some class
- that implements the interface."""
+ @param filterclass: is factory creating L{BaseWSGIFilter} instances
+ """
self.app = app
self.filterclass = filterclass
- def __call__(self, environ, start_response):
- """wsgi interface
- @type environ: {str, str}
- @rtype: gen([bytes])
- """
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
+ """wsgi interface"""
assert isinstance(environ, dict)
reqfilter = self.filterclass()
environ = reqfilter.filter_environ(environ)
- def modified_start_response(status, headers, exc_info=None):
+ def modified_start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
assert isinstance(status, str)
assert isinstance(headers, list)
exc_info = reqfilter.filter_exc_info(exc_info)
status = reqfilter.filter_status(status)
- headers = (reqfilter.filter_header(h, v) for h, v in headers)
- headers = [h for h in headers if h]
- headers = reqfilter.filter_headers(headers)
+ headers = reqfilter.filter_headers(
+ list(
+ filter(
+ None,
+ (reqfilter.filter_header(h, v) for h, v in headers)
+ )
+ )
+ )
write = start_response(status, headers, exc_info)
- def modified_write(data):
+ def modified_write(data: bytes) -> None:
write(reqfilter.filter_data(data))
return modified_write
ret = self.app(environ, modified_start_response)
assert hasattr(ret, "__iter__")
- def modified_close():
+ def modified_close() -> None:
reqfilter.handle_close()
getattr(ret, "close", lambda:0)()
@@ -182,7 +190,7 @@ class WSGIFilterMiddleware:
list(map(reqfilter.filter_data, ret))
+ list(reqfilter.append_data()))
ret = iter(ret)
- def late_append_data():
+ def late_append_data() -> typing.Iterator[bytes]:
"""Invoke C{reqfilter.append_data()} after C{filter_data()} has seen
all data."""
for data in reqfilter.append_data():
@@ -194,38 +202,42 @@ class WSGIFilterMiddleware:
# Using map and lambda here since pylint cannot handle list comprehension in
# default arguments. Also note that neither ' nor " are considered printable.
# For escape_string to be reversible \ is also not considered printable.
-def escape_string(string, replacer=list(map(
- lambda i: chr(i) if str2bytes(chr(i)).isalnum() or
- chr(i) in '!#$%&()*+,-./:;<=>?@[]^_`{|}~ ' else
- r"\x%2.2x" % i,
- range(256)))):
- """Encodes non-printable characters in a string using \\xXX escapes.
-
- @type string: str
- @rtype: str
- """
+def escape_string(
+ string: str,
+ replacer: typing.List[str] = list(
+ map(
+ lambda i: (
+ chr(i) if (
+ str2bytes(chr(i)).isalnum() or
+ chr(i) in '!#$%&()*+,-./:;<=>?@[]^_`{|}~ '
+ ) else r"\x%2.2x" % i
+ ),
+ range(256),
+ )
+ ),
+) -> str:
+ """Encodes non-printable characters in a string using \\xXX escapes."""
return "".join(replacer[ord(char)] for char in string)
__all__.append("RequestLogWSGIFilter")
class RequestLogWSGIFilter(BaseWSGIFilter):
"""This filter logs all requests in the apache log file format."""
+ proto: typing.Optional[str]
@classmethod
- def creator(cls, log, flush=True):
+ def creator(
+ cls, log: typing.TextIO, flush: bool = True
+ ) -> typing.Callable[[], "RequestLogWSGIFilter"]:
"""Returns a function creating L{RequestLogWSGIFilter}s on given log
file. log has to be a file-like object.
- @type log: file-like
@param log: elements of type str are written to the log. That means in
Py3.X the contents are decoded and in Py2.X the log is assumed
to be encoded in latin1. This follows the spirit of WSGI.
- @type flush: bool
@param flush: if True, invoke the flush method on log after each
write invocation
"""
return lambda:cls(log, flush)
- def __init__(self, log=sys.stdout, flush=True):
+ def __init__(self, log: typing.TextIO = sys.stdout, flush: bool = True):
"""
- @type log: file-like
- @type flush: bool
@param flush: if True, invoke the flush method on log after each
write invocation
"""
@@ -244,11 +256,8 @@ class RequestLogWSGIFilter(BaseWSGIFilter):
self.length = 0
self.referrer = None
self.useragent = None
- def filter_environ(self, environ):
- """BaseWSGIFilter interface
- @type environ: {str: str}
- @rtype: {str: str}
- """
+ def filter_environ(self, environ: Environ) -> Environ:
+ """BaseWSGIFilter interface"""
assert isinstance(environ, dict)
self.remote = environ.get("REMOTE_ADDR", self.remote)
self.user = environ.get("REMOTE_USER", self.user)
@@ -260,19 +269,16 @@ class RequestLogWSGIFilter(BaseWSGIFilter):
self.referrer = environ.get("HTTP_REFERER", self.referrer)
self.useragent = environ.get("HTTP_USER_AGENT", self.useragent)
return environ
- def filter_status(self, status):
- """BaseWSGIFilter interface
- @type status: str
- @rtype: str
- """
+ def filter_status(self, status: str) -> str:
+ """BaseWSGIFilter interface"""
assert isinstance(status, str)
self.status = status.split()[0]
return status
- def filter_data(self, data):
+ def filter_data(self, data: bytes) -> bytes:
assert isinstance(data, bytes)
self.length += len(data)
return data
- def handle_close(self):
+ def handle_close(self) -> None:
"""BaseWSGIFilter interface"""
line = '%s %s - [%s]' % (self.remote, self.user, self.time)
line = '%s "%s %s' % (line, escape_string(self.reqmethod),
@@ -301,25 +307,18 @@ class TimerWSGIFilter(BaseWSGIFilter):
something like C{["spam?GenTime", "?GenTime spam", "?GenTime"]} only the
last occurance get's replaced."""
@classmethod
- def creator(cls, pattern):
+ def creator(cls, pattern: bytes) -> typing.Callable[[], "TimerWSGIFilter"]:
"""Returns a function creating L{TimerWSGIFilter}s with a given pattern
beeing a string of exactly eight bytes.
- @type pattern: bytes
"""
return lambda:cls(pattern)
- def __init__(self, pattern=b"?GenTime"):
- """
- @type pattern: str
- """
+ def __init__(self, pattern: bytes = b"?GenTime"):
BaseWSGIFilter.__init__(self)
assert isinstance(pattern, bytes)
self.pattern = pattern
self.start = time.time()
- def filter_data(self, data):
- """BaseWSGIFilter interface
- @type data: bytes
- @rtype: bytes
- """
+ def filter_data(self, data: bytes) -> bytes:
+ """BaseWSGIFilter interface"""
if data == self.pattern:
return b"%8.3g" % (time.time() - self.start)
return data
@@ -331,30 +330,19 @@ class EncodeWSGIFilter(BaseWSGIFilter):
whereas wsgi mandates the use of bytes.
"""
@classmethod
- def creator(cls, charset):
+ def creator(cls, charset: str) -> typing.Callable[[], "EncodeWSGIFilter"]:
"""Returns a function creating L{EncodeWSGIFilter}s with a given
charset.
- @type charset: str
"""
return lambda:cls(charset)
- def __init__(self, charset="utf-8"):
- """
- @type charset: str
- """
+ def __init__(self, charset: str = "utf-8"):
BaseWSGIFilter.__init__(self)
self.charset = charset
- def filter_data(self, data):
- """BaseWSGIFilter interface
- @type data: str
- @rtype: bytes
- """
+ def filter_data(self, data: str) -> bytes:
+ """BaseWSGIFilter interface"""
return data.encode(self.charset)
- def filter_header(self, header, value):
- """BaseWSGIFilter interface
- @type header: str
- @type value: str
- @rtype: (str, str)
- """
+ def filter_header(self, header: str, value: str) -> typing.Tuple[str, str]:
+ """BaseWSGIFilter interface"""
if header.lower() != "content-type":
return (header, value)
return (header, "%s; charset=%s" % (value, self.charset))
@@ -362,17 +350,20 @@ class EncodeWSGIFilter(BaseWSGIFilter):
__all__.append("GzipWSGIFilter")
class GzipWSGIFilter(BaseWSGIFilter):
"""Compresses content using gzip."""
+ gzip: typing.Optional[gzip.GzipFile]
+ sio: typing.Optional[io.BytesIO]
+
@classmethod
- def creator(cls, flush=True):
+ def creator(
+ cls, flush: bool = True
+ ) -> typing.Callable[[], "GzipWSGIFilter"]:
"""
Returns a function creating L{GzipWSGIFilter}s.
- @type flush: bool
@param flush: whether or not the filter should always flush the buffer
"""
return lambda:cls(flush)
- def __init__(self, flush=True):
+ def __init__(self, flush: bool = True):
"""
- @type flush: bool
@param flush: whether or not the filter should always flush the buffer
"""
BaseWSGIFilter.__init__(self)
@@ -380,10 +371,8 @@ class GzipWSGIFilter(BaseWSGIFilter):
self.compress = False
self.sio = None
self.gzip = None
- def filter_environ(self, environ):
- """BaseWSGIFilter interface
- @type environ: {str: str}
- """
+ def filter_environ(self, environ: Environ) -> Environ:
+ """BaseWSGIFilter interface"""
assert isinstance(environ, dict)
if "HTTP_ACCEPT_ENCODING" in environ:
acceptenc = environ["HTTP_ACCEPT_ENCODING"].split(',')
@@ -392,28 +381,25 @@ class GzipWSGIFilter(BaseWSGIFilter):
self.sio = io.BytesIO()
self.gzip = gzip.GzipFile(fileobj=self.sio, mode="w")
return environ
- def filter_header(self, headername, headervalue):
- """ BaseWSGIFilter interface
- @type headername: str
- @type headervalue: str
- @rtype: (str, str) or None
- """
+ def filter_header(
+ self, headername: str, headervalue: str
+ ) -> typing.Optional[typing.Tuple[str, str]]:
+ """BaseWSGIFilter interface"""
if self.compress:
if headername.lower() == "content-length":
return None
return (headername, headervalue)
- def filter_headers(self, headers):
- """BaseWSGIFilter interface
- @type headers: [(str, str)]
- @rtype: [(str, str)]
- """
+ def filter_headers(self, headers: HeaderList) -> HeaderList:
+ """BaseWSGIFilter interface"""
assert isinstance(headers, list)
if self.compress:
headers.append(("Content-encoding", "gzip"))
return headers
- def filter_data(self, data):
+ def filter_data(self, data: bytes) -> bytes:
if not self.compress:
return data
+ assert self.gzip is not None
+ assert self.sio is not None
self.gzip.write(data)
if self.flush:
self.gzip.flush()
@@ -421,9 +407,11 @@ class GzipWSGIFilter(BaseWSGIFilter):
self.sio.truncate(0)
self.sio.seek(0)
return data
- def append_data(self):
+ def append_data(self) -> typing.List[bytes]:
if not self.compress:
return []
+ assert self.gzip is not None
+ assert self.sio is not None
self.gzip.close()
data = self.sio.getvalue()
return [data]
@@ -435,30 +423,28 @@ class ReusableWSGIInputFilter(BaseWSGIFilter):
C{BytesIO} instance which provides a C{seek} method.
"""
@classmethod
- def creator(cls, maxrequestsize):
+ def creator(
+ cls, maxrequestsize: int
+ ) -> typing.Callable[[], "ReusableWSGIInputFilter"]:
"""
Returns a function creating L{ReusableWSGIInputFilter}s with desired
maxrequestsize being set. If there is more data than maxrequestsize is
available in C{wsgi.input} the rest will be ignored. (It is up to the
adapter to eat this data.)
- @type maxrequestsize: int
@param maxrequestsize: is the maximum number of bytes to store in the
C{BytesIO}
"""
return lambda:cls(maxrequestsize)
- def __init__(self, maxrequestsize=65536):
+ def __init__(self, maxrequestsize: int = 65536):
"""ReusableWSGIInputFilters constructor.
- @type maxrequestsize: int
@param maxrequestsize: is the maximum number of bytes to store in the
C{BytesIO}, see L{creator}
"""
BaseWSGIFilter.__init__(self)
self.maxrequestsize = maxrequestsize
- def filter_environ(self, environ):
- """BaseWSGIFilter interface
- @type environ: {str: str}
- """
+ def filter_environ(self, environ: Environ) -> Environ:
+ """BaseWSGIFilter interface"""
if isinstance(environ["wsgi.input"], io.BytesIO):
return environ # nothing to be done
diff --git a/wsgitools/internal.py b/wsgitools/internal.py
index 9bf7ded..86a9d5a 100644
--- a/wsgitools/internal.py
+++ b/wsgitools/internal.py
@@ -1,11 +1,42 @@
-def bytes2str(bstr):
+import sys
+import typing
+
+def bytes2str(bstr: bytes) -> str:
assert isinstance(bstr, bytes)
return bstr.decode("iso-8859-1") # always successful
-def str2bytes(sstr):
+def str2bytes(sstr: str) -> bytes:
assert isinstance(sstr, str)
return sstr.encode("iso-8859-1") # might fail, but spec says it doesn't
-def textopen(filename, mode):
+def textopen(filename: str, mode: str) -> typing.TextIO:
# We use the same encoding as for all wsgi strings here.
return open(filename, mode, encoding="iso-8859-1")
+
+Environ = typing.Dict[str, typing.Any]
+
+HeaderList = typing.List[typing.Tuple[str, str]]
+
+if sys.version_info >= (3, 11):
+ OptExcInfo = typing.Optional[
+ typing.Tuple[type[BaseException], BaseException, typing.Any]
+ ]
+else:
+ OptExcInfo = typing.Optional[
+ typing.Tuple[typing.Any, BaseException, typing.Any]
+ ]
+
+WriteCallback = typing.Callable[[bytes], None]
+
+if sys.version_info >= (3, 11):
+ class StartResponse(typing.Protocol):
+ def __call__(
+ self, status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
+ ...
+else:
+ StartResponse = typing.Callable[
+ [str, HeaderList, OptExcInfo], WriteCallback
+ ]
+
+WsgiApp = typing.Callable[[Environ, StartResponse], typing.Iterable[bytes]]
diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py
index ef9fe84..8577384 100644
--- a/wsgitools/middlewares.py
+++ b/wsgitools/middlewares.py
@@ -6,27 +6,28 @@ import sys
import cgitb
import collections
import io
+import typing
from wsgitools.internal import bytes2str, str2bytes
from wsgitools.filters import CloseableList, CloseableIterator
from wsgitools.authentication import AuthenticationRequired, \
ProtocolViolation, AuthenticationMiddleware
+from wsgitools.internal import (
+ Environ, HeaderList, OptExcInfo, StartResponse, WriteCallback, WsgiApp
+)
__all__.append("SubdirMiddleware")
class SubdirMiddleware:
"""Middleware choosing wsgi applications based on a dict."""
- def __init__(self, default, mapping={}):
- """
- @type default: wsgi app
- @type mapping: {str: wsgi app}
- """
+ def __init__(
+ self, default: WsgiApp, mapping: typing.Dict[str, WsgiApp] = {}
+ ):
self.default = default
self.mapping = mapping
- def __call__(self, environ, start_response):
- """wsgi interface
- @type environ: {str: str}
- @rtype: gen([bytes])
- """
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
+ """wsgi interface"""
assert isinstance(environ, dict)
app = None
script = environ["PATH_INFO"]
@@ -51,22 +52,24 @@ class NoWriteCallableMiddleware:
C{start_response} function to a wsgi application that doesn't need one by
writing the data to a C{BytesIO} and then making it be the first result
element."""
- def __init__(self, app):
+ def __init__(self, app: WsgiApp):
"""Wraps wsgi application app."""
self.app = app
- def __call__(self, environ, start_response):
- """wsgi interface
- @type environ: {str, str}
- @rtype: gen([bytes])
- """
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
+ """wsgi interface"""
assert isinstance(environ, dict)
- todo = [None]
+ todo: typing.Optional[typing.Tuple[str, HeaderList]] = None
sio = io.BytesIO()
gotiterdata = False
- def write_calleable(data):
+ def write_calleable(data: bytes) -> None:
assert not gotiterdata
sio.write(data)
- def modified_start_response(status, headers, exc_info=None):
+ def modified_start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
+ nonlocal todo
try:
if sio.tell() > 0 or gotiterdata:
assert exc_info is not None
@@ -75,7 +78,7 @@ class NoWriteCallableMiddleware:
exc_info = None
assert isinstance(status, str)
assert isinstance(headers, list)
- todo[0] = (status, headers)
+ todo = (status, headers)
return write_calleable
ret = self.app(environ, modified_start_response)
@@ -96,8 +99,8 @@ class NoWriteCallableMiddleware:
else:
gotiterdata = True
- assert todo[0] is not None
- status, headers = todo[0]
+ assert todo is not None
+ status, headers = todo
data = sio.getvalue()
if isinstance(ret, list):
@@ -117,27 +120,36 @@ class ContentLengthMiddleware:
"""Guesses the content length header if possible.
@note: The application used must not use the C{write} callable returned by
C{start_response}."""
- def __init__(self, app, maxstore=0):
+ maxstore: typing.Union[float, int]
+ def __init__(
+ self, app: WsgiApp, maxstore: typing.Union[int, typing.Tuple[()]] = 0
+ ):
"""Wraps wsgi application app. If the application returns a list, the
total length of strings is available and the content length header is
set unless there already is one. For an iterator data is accumulated up
to a total of maxstore bytes (where maxstore=() means infinity). If the
iterator is exhaused within maxstore bytes a content length header is
added unless already present.
- @type maxstore: int or ()
@note: that setting maxstore to a value other than 0 will violate the
wsgi standard
"""
self.app = app
if maxstore == ():
- maxstore = float("inf")
- self.maxstore = maxstore
- def __call__(self, environ, start_response):
+ self.maxstore = float("inf")
+ else:
+ assert isinstance(maxstore, int)
+ self.maxstore = maxstore
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
"""wsgi interface"""
assert isinstance(environ, dict)
- todo = []
+ todo: typing.Optional[typing.Tuple[str, HeaderList]] = None
gotdata = False
- def modified_start_response(status, headers, exc_info=None):
+ def modified_start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
+ nonlocal todo
try:
if gotdata:
assert exc_info is not None
@@ -146,8 +158,8 @@ class ContentLengthMiddleware:
exc_info = None
assert isinstance(status, str)
assert isinstance(headers, list)
- todo[:] = ((status, headers),)
- def raise_not_imp(*args):
+ todo = (status, headers)
+ def raise_not_imp(_: bytes) -> None:
raise NotImplementedError
return raise_not_imp
@@ -156,8 +168,8 @@ class ContentLengthMiddleware:
if isinstance(ret, list):
gotdata = True
- assert bool(todo)
- status, headers = todo[0]
+ assert todo is not None
+ status, headers = todo
if all(k.lower() != "content-length" for k, _ in headers):
length = sum(map(len, ret))
headers.append(("Content-Length", str(length)))
@@ -173,8 +185,8 @@ class ContentLengthMiddleware:
except StopIteration:
stopped = True
gotdata = True
- assert bool(todo)
- status, headers = todo[0]
+ assert todo is not None
+ status, headers = todo
data = CloseableList(getattr(ret, "close", None))
if first:
data.append(first)
@@ -197,12 +209,12 @@ class ContentLengthMiddleware:
return CloseableIterator(getattr(ret, "close", None), data, ret)
-def storable(environ):
+def storable(environ: Environ) -> bool:
if environ["REQUEST_METHOD"] != "GET":
return False
return True
-def cacheable(environ):
+def cacheable(environ: Environ) -> bool:
if environ.get("HTTP_CACHE_CONTROL", "") == "max-age=0":
return False
return True
@@ -213,16 +225,24 @@ class CachingMiddleware:
C{QUERY_STRING}."""
class CachedRequest:
- def __init__(self, timestamp):
+ def __init__(self, timestamp: float):
self.timestamp = timestamp
self.status = ""
- self.headers = []
- self.body = []
-
- def __init__(self, app, maxage=60, storable=storable, cacheable=cacheable):
+ self.headers: HeaderList = []
+ self.body: typing.List[bytes] = []
+
+ cache: typing.Dict[str, CachedRequest]
+ lastcached: typing.Deque[typing.Tuple[str, float]]
+
+ def __init__(
+ self,
+ app: WsgiApp,
+ maxage: int = 60,
+ storable: typing.Callable[[Environ], bool] = storable,
+ cacheable: typing.Callable[[Environ], bool] = cacheable,
+ ):
"""
@param app: is a wsgi application to be cached.
- @type maxage: int
@param maxage: is the number of seconds a reponse may be cached.
@param storable: is a predicate that determines whether the response
may be cached at all based on the C{environ} dict.
@@ -235,13 +255,20 @@ class CachingMiddleware:
self.cache = {}
self.lastcached = collections.deque()
- def insert_cache(self, key, obj, now=None):
+ def insert_cache(
+ self,
+ key: str,
+ obj: CachedRequest,
+ now: typing.Optional[float] = None,
+ ) -> None:
if now is None:
now = time.time()
self.cache[key] = obj
self.lastcached.append((key, now))
- def prune_cache(self, maxclean=16, now=None):
+ def prune_cache(
+ self, maxclean: int = 16, now: typing.Optional[float] = None
+ ) -> None:
if now is None:
now = time.time()
old = now - self.maxage
@@ -258,10 +285,10 @@ class CachingMiddleware:
if obj.timestamp <= old:
del self.cache[key]
- def __call__(self, environ, start_response):
- """wsgi interface
- @type environ: {str: str}
- """
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
+ """wsgi interface"""
assert isinstance(environ, dict)
now = time.time()
self.prune_cache(now=now)
@@ -279,7 +306,9 @@ class CachingMiddleware:
else:
del self.cache[path]
cache_object = self.CachedRequest(now)
- def modified_start_respesponse(status, headers, exc_info=None):
+ def modified_start_respesponse(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
try:
if cache_object.body:
assert exc_info is not None
@@ -291,7 +320,7 @@ class CachingMiddleware:
cache_object.status = status
cache_object.headers = headers
write = start_response(status, list(headers))
- def modified_write(data):
+ def modified_write(data: bytes) -> None:
cache_object.body.append(data)
write(data)
return modified_write
@@ -303,7 +332,7 @@ class CachingMiddleware:
cache_object.body.extend(ret)
self.insert_cache(path, cache_object, now)
return ret
- def pass_through():
+ def pass_through() -> typing.Iterator[bytes]:
for data in ret:
cache_object.body.append(data)
yield data
@@ -313,18 +342,13 @@ class CachingMiddleware:
__all__.append("DictAuthChecker")
class DictAuthChecker:
"""Verifies usernames and passwords by looking them up in a dict."""
- def __init__(self, users):
+ def __init__(self, users: typing.Dict[str, str]):
"""
- @type users: {str: str}
@param users: is a dict mapping usernames to password."""
self.users = users
- def __call__(self, username, password, environ):
+ def __call__(self, username: str, password: str, environ: Environ) -> bool:
"""check_function interface taking username and password and resulting
in a bool.
- @type username: str
- @type password: str
- @type environ: {str: object}
- @rtype: bool
"""
return username in self.users and self.users[username] == password
@@ -334,7 +358,13 @@ class BasicAuthMiddleware(AuthenticationMiddleware):
the warpped application the environ dictionary is augmented by a REMOTE_USER
key."""
authorization_method = "basic"
- def __init__(self, app, check_function, realm='www', app401=None):
+ def __init__(
+ self,
+ app: WsgiApp,
+ check_function: typing.Callable[[str, str, Environ], bool],
+ realm: str = 'www',
+ app401: typing.Optional[WsgiApp] = None,
+ ):
"""
@param app: is a WSGI application.
@param check_function: is a function taking three arguments username,
@@ -342,7 +372,6 @@ class BasicAuthMiddleware(AuthenticationMiddleware):
request may is allowed. The older interface of taking only the
first two arguments is still supported via catching a
C{TypeError}.
- @type realm: str
@param app401: is an optional WSGI application to be used for error
messages
"""
@@ -351,7 +380,9 @@ class BasicAuthMiddleware(AuthenticationMiddleware):
self.realm = realm
self.app401 = app401
- def authenticate(self, auth, environ):
+ def authenticate(
+ self, auth: str, environ: Environ
+ ) -> typing.Dict[str, str]:
assert isinstance(auth, str)
assert isinstance(environ, dict)
authb = str2bytes(auth)
@@ -372,10 +403,17 @@ class BasicAuthMiddleware(AuthenticationMiddleware):
return dict(user=username)
raise AuthenticationRequired("credentials not valid")
- def www_authenticate(self, exception):
+ def www_authenticate(
+ self, exception: AuthenticationRequired
+ ) -> typing.Tuple[str, str]:
return ("WWW-Authenticate", 'Basic realm="%s"' % self.realm)
- def authorization_required(self, environ, start_response, exception):
+ def authorization_required(
+ self,
+ environ: Environ,
+ start_response: StartResponse,
+ exception: AuthenticationRequired,
+ ) -> typing.Iterable[bytes]:
if self.app401 is not None:
return self.app401(environ, start_response)
return AuthenticationMiddleware.authorization_required(
@@ -385,13 +423,13 @@ __all__.append("TracebackMiddleware")
class TracebackMiddleware:
"""In case the application throws an exception this middleware will show an
html-formatted traceback using C{cgitb}."""
- def __init__(self, app):
+ def __init__(self, app: WsgiApp):
"""app is the wsgi application to proxy."""
self.app = app
- def __call__(self, environ, start_response):
- """wsgi interface
- @type environ: {str: str}
- """
+ def __call__(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
+ """wsgi interface"""
try:
assert isinstance(environ, dict)
ret = self.app(environ, start_response)
diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py
index e2a68c2..677e3b5 100644
--- a/wsgitools/scgi/__init__.py
+++ b/wsgitools/scgi/__init__.py
@@ -1,5 +1,8 @@
__all__ = []
+import socket
+import typing
+
try:
import sendfile
except ImportError:
@@ -7,6 +10,8 @@ except ImportError:
else:
have_sendfile = True
+from wsgitools.internal import Environ
+
class FileWrapper:
"""
@ivar offset: Initially 0. Becomes -1 when reading using next and
@@ -14,18 +19,20 @@ class FileWrapper:
counts the number of bytes sent. It also ensures that next and
transfer are never mixed.
"""
- def __init__(self, filelike, blksize=8192):
+ def __init__(self, filelike, blksize: int = 8192):
self.filelike = filelike
self.blksize = blksize
self.offset = 0
if hasattr(filelike, "close"):
self.close = filelike.close
- def can_transfer(self):
+ def can_transfer(self) -> bool:
return have_sendfile and hasattr(self.filelike, "fileno") and \
self.offset >= 0
- def transfer(self, sock, blksize=None):
+ def transfer(
+ self, sock: socket.socket, blksize: typing.Optional[int] = None
+ ) -> int:
assert self.offset >= 0
if blksize is None:
blksize = self.blksize
@@ -42,10 +49,10 @@ class FileWrapper:
self.offset += sent
return sent
- def __iter__(self):
+ def __iter__(self) -> typing.Iterator[bytes]:
return self
- def __next__(self):
+ def __next__(self) -> bytes:
assert self.offset <= 0
self.offset = -1
data = self.filelike.read(self.blksize)
@@ -53,8 +60,12 @@ class FileWrapper:
return data
raise StopIteration
-def _convert_environ(environ, multithread=False, multiprocess=False,
- run_once=False):
+def _convert_environ(
+ environ: Environ,
+ multithread: bool = False,
+ multiprocess: bool = False,
+ run_once: bool = False,
+) -> None:
environ.update({
"wsgi.version": (1, 0),
"wsgi.url_scheme": "http",
diff --git a/wsgitools/scgi/asynchronous.py b/wsgitools/scgi/asynchronous.py
index 61bbc6b..264e43e 100644
--- a/wsgitools/scgi/asynchronous.py
+++ b/wsgitools/scgi/asynchronous.py
@@ -5,10 +5,24 @@ import io
import socket
import sys
import errno
+import typing
-from wsgitools.internal import bytes2str, str2bytes
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ str2bytes,
+ WriteCallback,
+ WsgiApp,
+)
from wsgitools.scgi import _convert_environ, FileWrapper
+if sys.version_info >= (3, 8):
+ LiteralTrue = typing.Literal[True]
+else:
+ LiteralTrue = bool
+
class SCGIConnection(asyncore.dispatcher):
"""SCGI connection class used by L{SCGIServer}."""
# connection states
@@ -19,11 +33,27 @@ class SCGIConnection(asyncore.dispatcher):
RESP = 3*4 | 2 # sending response, end state
RESPH = 4*4 | 2 # buffered response headers, sending headers only, to TRANS
TRANS = 5*4 | 2 # transferring using FileWrapper, end state
- def __init__(self, server, connection, maxrequestsize=65536,
- maxpostsize=8<<20, blocksize=4096, config={}):
+
+ outheaders: typing.Union[
+ typing.Tuple[()], # unset
+ typing.Tuple[str, HeaderList], # headers
+ LiteralTrue # sent
+ ]
+ wsgihandler: typing.Optional[typing.Iterable[bytes]]
+ wsgiiterator: typing.Optional[typing.Iterator[bytes]]
+
+ def __init__(
+ self,
+ server: "SCGIServer",
+ connection: socket.socket,
+ maxrequestsize: int = 65536,
+ maxpostsize: int = 8<<20,
+ blocksize: int = 4096,
+ config: Environ = {},
+ ):
asyncore.dispatcher.__init__(self, connection)
- self.server = server # WSGISCGIServer instance
+ self.server = server
self.maxrequestsize = maxrequestsize
self.maxpostsize = maxpostsize
self.blocksize = blocksize
@@ -35,10 +65,9 @@ class SCGIConnection(asyncore.dispatcher):
self.wsgihandler = None # wsgi application
self.wsgiiterator = None # wsgi application iterator
self.outheaders = () # headers to be sent
- # () -> unset, (..,..) -> set, True -> sent
self.body = io.BytesIO() # request body
- def _try_send_headers(self):
+ def _try_send_headers(self) -> None:
if self.outheaders != True:
assert not self.outbuff
status, headers = self.outheaders
@@ -47,7 +76,7 @@ class SCGIConnection(asyncore.dispatcher):
self.outbuff = str2bytes(headdata)
self.outheaders = True
- def _wsgi_write(self, data):
+ def _wsgi_write(self, data: bytes) -> None:
assert self.state >= SCGIConnection.RESP
assert self.state < SCGIConnection.TRANS
assert isinstance(data, bytes)
@@ -55,15 +84,15 @@ class SCGIConnection(asyncore.dispatcher):
self._try_send_headers()
self.outbuff += data
- def readable(self):
+ def readable(self) -> bool:
"""C{asyncore} interface"""
return self.state & 1 == 1
- def writable(self):
+ def writable(self) -> bool:
"""C{asyncore} interface"""
return self.state & 2 == 2
- def handle_read(self):
+ def handle_read(self) -> None:
"""C{asyncore} interface"""
data = self.recv(self.blocksize)
self.inbuff += data
@@ -122,6 +151,7 @@ class SCGIConnection(asyncore.dispatcher):
self.environ["wsgi.errors"] = self.server.error
self.wsgihandler = self.server.wsgiapp(self.environ,
self.start_response)
+ assert self.wsgihandler is not None
if isinstance(self.wsgihandler, FileWrapper) and \
self.wsgihandler.can_transfer():
self._try_send_headers()
@@ -134,7 +164,12 @@ class SCGIConnection(asyncore.dispatcher):
self.reqlen -= len(self.inbuff)
self.inbuff = b""
- def start_response(self, status, headers, exc_info=None):
+ def start_response(
+ self,
+ status: str,
+ headers: HeaderList,
+ exc_info: OptExcInfo = None,
+ ) -> WriteCallback:
assert isinstance(status, str)
assert isinstance(headers, list)
if exc_info:
@@ -147,7 +182,7 @@ class SCGIConnection(asyncore.dispatcher):
self.outheaders = (status, headers)
return self._wsgi_write
- def send_buff(self):
+ def send_buff(self) -> None:
try:
sentbytes = self.send(self.outbuff[:self.blocksize])
except socket.error:
@@ -155,11 +190,12 @@ class SCGIConnection(asyncore.dispatcher):
else:
self.outbuff = self.outbuff[sentbytes:]
- def handle_write(self):
+ def handle_write(self) -> None:
"""C{asyncore} interface"""
if self.state == SCGIConnection.RESP:
if len(self.outbuff) < self.blocksize:
self._try_send_headers()
+ assert self.wsgiiterator is not None
for data in self.wsgiiterator:
assert isinstance(data, bytes)
if data:
@@ -171,23 +207,26 @@ class SCGIConnection(asyncore.dispatcher):
self.send_buff()
elif self.state == SCGIConnection.RESPH:
assert len(self.outbuff) > 0
+ assert isinstance(self.wsgihandler, FileWrapper)
self.send_buff()
if not self.outbuff:
self.state = SCGIConnection.TRANS
else:
assert self.state == SCGIConnection.TRANS
+ assert isinstance(self.wsgihandler, FileWrapper)
assert self.wsgihandler.can_transfer()
sent = self.wsgihandler.transfer(self.socket, self.blocksize)
if sent <= 0:
self.close()
- def close(self):
+ def close(self) -> None:
# None doesn't have a close attribute
if hasattr(self.wsgihandler, "close"):
+ assert self.wsgihandler is not None
self.wsgihandler.close()
asyncore.dispatcher.close(self)
- def handle_close(self):
+ def handle_close(self) -> None:
"""C{asyncore} interface"""
self.close()
@@ -195,33 +234,36 @@ __all__.append("SCGIServer")
class SCGIServer(asyncore.dispatcher):
"""SCGI Server for WSGI applications. It does not use multiple processes or
multiple threads."""
- def __init__(self, wsgiapp, port, interface="localhost", error=sys.stderr,
- maxrequestsize=None, maxpostsize=None, blocksize=None,
- config={}, reusesocket=None):
+
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ port: int,
+ interface: str = "localhost",
+ error: typing.TextIO = sys.stderr,
+ maxrequestsize: typing.Optional[int] = None,
+ maxpostsize: typing.Optional[int] = None,
+ blocksize: typing.Optional[int] = None,
+ config: Environ = {},
+ reusesocket: typing.Optional[socket.socket] = None
+ ):
"""
@param wsgiapp: is the wsgi application to be run.
- @type port: int
@param port: is an int representing the TCP port number to be used.
- @type interface: str
@param interface: is a string specifying the network interface to bind
which defaults to C{"localhost"} making the server inaccessible
over network.
@param error: is a file-like object being passed as C{wsgi.error} in the
environ parameter defaulting to stderr.
- @type maxrequestsize: int
@param maxrequestsize: limit the size of request blocks in scgi
connections. Connections are dropped when this limit is hit.
- @type maxpostsize: int
@param maxpostsize: limit the size of post bodies that may be processed
by this instance. Connections are dropped when this limit is
hit.
- @type blocksize: int
@param blocksize: is amount of data to read or write from or to the
network at once
- @type config: {}
@param config: the environ dictionary is updated using these values for
each request.
- @type reusesocket: None or socket.socket
@param reusesocket: If a socket is passed, do not create a socket.
Instead use given socket as listen socket. The passed socket
must be set up for accepting tcp connections (i.e. C{AF_INET},
@@ -234,7 +276,7 @@ class SCGIServer(asyncore.dispatcher):
self.wsgiapp = wsgiapp
self.error = error
- self.conf = {}
+ self.conf: Environ = {}
if maxrequestsize is not None:
self.conf["maxrequestsize"] = maxrequestsize
if maxpostsize is not None:
@@ -251,7 +293,7 @@ class SCGIServer(asyncore.dispatcher):
else:
self.accepting = True
- def handle_accept(self):
+ def handle_accept(self) -> None:
"""asyncore interface"""
try:
ret = self.accept()
@@ -264,7 +306,7 @@ class SCGIServer(asyncore.dispatcher):
conn, _ = ret
SCGIConnection(self, conn, **self.conf)
- def run(self):
+ def run(self) -> None:
"""Runs the server. It will not return and you can invoke
C{asyncore.loop()} instead achieving the same effect."""
asyncore.loop()
diff --git a/wsgitools/scgi/forkpool.py b/wsgitools/scgi/forkpool.py
index df8a92f..f864a6a 100644
--- a/wsgitools/scgi/forkpool.py
+++ b/wsgitools/scgi/forkpool.py
@@ -12,21 +12,24 @@ import select
import sys
import errno
import signal
+import typing
-from wsgitools.internal import bytes2str, str2bytes
+from wsgitools.internal import (
+ bytes2str, HeaderList, OptExcInfo, str2bytes, WriteCallback, WsgiApp
+)
from wsgitools.scgi import _convert_environ, FileWrapper
__all__ = []
class SocketFileWrapper:
"""Wraps a socket to a wsgi-compliant file-like object."""
- def __init__(self, sock, toread):
+ def __init__(self, sock: socket.socket, toread: int):
"""@param sock: is a C{socket.socket()}"""
self.sock = sock
self.buff = b""
self.toread = toread
- def _recv(self, size=4096):
+ def _recv(self, size: int = 4096) -> bytes:
"""
internal method for receiving and counting incoming data
@raises socket.error:
@@ -44,7 +47,7 @@ class SocketFileWrapper:
self.toread -= len(data)
return data
- def close(self):
+ def close(self) -> None:
"""Does not close the socket, because it might still be needed. It
reads all data that should have been read as given by C{CONTENT_LENGTH}.
"""
@@ -55,7 +58,7 @@ class SocketFileWrapper:
except socket.error:
pass
- def read(self, size=None):
+ def read(self, size: typing.Optional[int] = None) -> bytes:
"""
see pep333
@raises socket.error:
@@ -92,7 +95,7 @@ class SocketFileWrapper:
ret, self.buff = self.buff, b""
return ret
- def readline(self, size=None):
+ def readline(self, size: typing.Optional[int] = None) -> bytes:
"""
see pep333
@raises socket.error:
@@ -118,7 +121,7 @@ class SocketFileWrapper:
return ret
self.buff += data
- def readlines(self):
+ def readlines(self) -> typing.Iterator[bytes]:
"""
see pep333
@raises socket.error:
@@ -127,10 +130,10 @@ class SocketFileWrapper:
while data:
yield data
data = self.readline()
- def __iter__(self):
+ def __iter__(self) -> typing.Iterator[bytes]:
"""see pep333"""
return self
- def __next__(self):
+ def __next__(self) -> bytes:
"""
see pep333
@raises socket.error:
@@ -139,9 +142,9 @@ class SocketFileWrapper:
if not data:
raise StopIteration
return data
- def flush(self):
+ def flush(self) -> None:
"""see pep333"""
- def write(self, data):
+ def write(self, data: bytes) -> None:
"""see pep333"""
assert isinstance(data, bytes)
try:
@@ -149,7 +152,7 @@ class SocketFileWrapper:
except socket.error:
# ignore all socket errors: there is no way to report
return
- def writelines(self, lines):
+ def writelines(self, lines: typing.List[bytes]) -> None:
"""see pep333"""
for line in lines:
self.write(line)
@@ -161,49 +164,51 @@ class SCGIServer:
class WorkerState:
"""state: 0 means idle and 1 means working.
These values are also sent as strings '0' and '1' over the socket."""
- def __init__(self, pid, sock, state):
- """
- @type pid: int
- @type state: int
- """
+ def __init__(self, pid: int, sock: socket.socket, state: int):
self.pid = pid
self.sock = sock
self.state = state
- def __init__(self, wsgiapp, port, interface="localhost", error=sys.stderr,
- minworkers=2, maxworkers=32, maxrequests=1000, config={},
- reusesocket=None, cpulimit=None, timelimit=None):
+ server: typing.Optional[socket.socket]
+ workers: typing.Dict[int, WorkerState]
+ sigpipe: typing.Optional[typing.Tuple[socket.socket, socket.socket]]
+
+ def __init__(
+ self,
+ wsgiapp: WsgiApp,
+ port: int,
+ interface: str = "localhost",
+ error: typing.TextIO = sys.stderr,
+ minworkers: int = 2,
+ maxworkers: int = 32,
+ maxrequests: int = 1000,
+ config: typing.Dict[typing.Any, typing.Any] = {},
+ reusesocket: typing.Optional[socket.socket] = None,
+ cpulimit: typing.Optional[typing.Tuple[int, int]] = None,
+ timelimit: typing.Optional[int] = None,
+ ):
"""
@param wsgiapp: is the WSGI application to be run.
- @type port: int
@param port: is the tcp port to listen on
- @type interface: str
@param interface: is the interface to bind to (default: C{"localhost"})
@param error: is a file-like object beeing passed as C{wsgi.errors} in
environ
- @type minworkers: int
@param minworkers: is the number of worker processes to spawn
- @type maxworkers: int
@param maxworkers: is the maximum number of workers that can be spawned
on demand
- @type maxrequests: int
@param maxrequests: is the number of requests a worker processes before
dying
- @type config: {}
@param config: the environ dictionary is updated using these values for
each request.
- @type reusesocket: None or socket.socket
@param reusesocket: If a socket is passed, do not create a socket.
Instead use given socket as listen socket. The passed socket
must be set up for accepting tcp connections (i.e. C{AF_INET},
C{SOCK_STREAM} with bind and listen called).
- @type cpulimit: (int, int)
@param cpulimit: a pair of soft and hard cpu time limit in seconds.
This limit is installed for each worker using RLIMIT_CPU if
resource limits are available to this platform. After reaching
the soft limit workers will continue to process the current
request and then cleanly terminate.
- @type timelimit: int
@param timelimit: The maximum number of wall clock seconds processing
a request should take. If this is specified, an alarm timer is
installed and the default action is to kill the worker.
@@ -229,7 +234,7 @@ class SCGIServer:
self.running = False
self.ischild = False
- def enable_sighandler(self, sig=signal.SIGTERM):
+ def enable_sighandler(self, sig: int = signal.SIGTERM) -> "SCGIServer":
"""
Changes the signal handler for the given signal to terminate the run()
loop.
@@ -239,7 +244,7 @@ class SCGIServer:
signal.signal(sig, self.shutdownhandler)
return self
- def run(self):
+ def run(self) -> None:
"""
Serve the wsgi application.
"""
@@ -296,7 +301,7 @@ class SCGIServer:
self.sigpipe = None
self.killworkers()
- def killworkers(self, sig=signal.SIGTERM):
+ def killworkers(self, sig: int = signal.SIGTERM) -> None:
"""
Kills all worker children.
@param sig: is the signal used to kill the children
@@ -307,7 +312,9 @@ class SCGIServer:
os.kill(state.pid, sig)
# TODO: handle working children with a timeout
- def shutdownhandler(self, sig=None, stackframe=None):
+ def shutdownhandler(
+ self, sig: typing.Optional[int] = None, stackframe=None
+ ) -> None:
"""
Signal handler function for stopping the run() loop. It works by
setting a variable that run() evaluates in each loop. As a signal
@@ -319,10 +326,13 @@ class SCGIServer:
if self.ischild:
sys.exit()
elif self.running:
+ assert self.sigpipe is not None
self.running = False
self.sigpipe[1].send(b' ')
- def sigxcpuhandler(self, sig=None, stackframe=None):
+ def sigxcpuhandler(
+ self, sig: typing.Optional[int] = None, stackframe=None
+ ) -> None:
"""
Signal hanlder function for the SIGXCUP signal. It is sent to a
worker when the soft RLIMIT_CPU is crossed.
@@ -331,7 +341,7 @@ class SCGIServer:
"""
self.cpulimit = True
- def spawnworker(self):
+ def spawnworker(self) -> None:
"""
internal! spawns a single worker
"""
@@ -366,11 +376,12 @@ class SCGIServer:
else:
raise RuntimeError("fork failed")
- def work(self, worksock):
+ def work(self, worksock: socket.socket) -> None:
"""
internal! serves maxrequests times
@raises socket.error:
"""
+ assert self.server is not None
for _ in range(self.maxrequests):
(con, addr) = self.server.accept()
# we cannot handle socket.errors here.
@@ -384,7 +395,7 @@ class SCGIServer:
if self.cpulimit:
break
- def process(self, con):
+ def process(self, con: socket.socket) -> None:
"""
internal! processes a single request on the connection con.
"""
@@ -439,9 +450,9 @@ class SCGIServer:
# 0 -> False: set but unsent
# 0 -> True: sent
# 1 -> bytes of the complete header
- response_head = [None, None]
+ response_head: typing.List[typing.Any] = [None, None]
- def sendheaders():
+ def sendheaders() -> None:
assert response_head[0] is not None # headers set
if response_head[0] != True:
response_head[0] = True
@@ -450,14 +461,16 @@ class SCGIServer:
except socket.error:
pass
- def dumbsend(data):
+ def dumbsend(data: bytes) -> None:
sendheaders()
try:
con.sendall(data)
except socket.error:
pass
- def start_response(status, headers, exc_info=None):
+ def start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
if exc_info and response_head[0]:
try:
raise exc_info[1].with_traceback(exc_info[2])