summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
authorHelmut Grohne <helmut@subdivi.de>2020-04-13 21:30:34 +0200
committerHelmut Grohne <helmut@subdivi.de>2023-06-18 23:16:57 +0200
commita41066b413489b407b9d99174af697563ad680b9 (patch)
tree2f08f9e886e13a7500d1eb527e30737d961deab6 /test.py
parent4d52eaa4801df3f3169df8e58758bcccf22dc4de (diff)
downloadwsgitools-a41066b413489b407b9d99174af697563ad680b9.tar.gz
add type hints to all of the code
In order to use type hint syntax, we need to bump the minimum Python version to 3.7 and some of the features such as Literal and Protocol are opted in when a sufficiently recent Python is available. This does not make all of the code pass type checking with mypy. A number of typing issues remain, but the output of mypy becomes something one can read through. In adding type hints, a lot of epydoc @type annotations are removed as redundant. This update also adopts black-style line breaking.
Diffstat (limited to 'test.py')
-rwxr-xr-xtest.py207
1 files changed, 116 insertions, 91 deletions
diff --git a/test.py b/test.py
index 9690baf..5271819 100755
--- a/test.py
+++ b/test.py
@@ -8,16 +8,23 @@ import wsgiref.validate
import io
from hashlib import md5
import sys
-
-from wsgitools.internal import bytes2str, str2bytes
+import typing
+
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ StartResponse,
+ str2bytes,
+ WriteCallback,
+ WsgiApp,
+)
class Request:
- def __init__(self, case):
- """
- @type case: unittest.TestCase
- """
+ def __init__(self, case: unittest.TestCase):
self.testcase = case
- self.environ = dict(
+ self.environ: Environ = dict(
REQUEST_METHOD="GET",
SERVER_NAME="localhost",
SERVER_PORT="80",
@@ -33,86 +40,92 @@ class Request:
"wsgi.multiprocess": False,
"wsgi.run_once": False})
- def setenv(self, key, value):
+ def setenv(self, key: str, value: str) -> "Request":
"""
- @type key: str
- @type value: str
@returns: self
"""
self.environ[key] = value
return self
- def setmethod(self, request_method):
+ def setmethod(self, request_method: str) -> "Request":
"""
- @type request_method: str
@returns: self
"""
return self.setenv("REQUEST_METHOD", request_method)
- def setheader(self, name, value):
+ def setheader(self, name: str, value: str) -> "Request":
"""
- @type name: str
- @type value: str
@returns: self
"""
return self.setenv("HTTP_" + name.upper().replace('-', '_'), value)
- def copy(self):
+ def copy(self) -> "Request":
req = Request(self.testcase)
req.environ = dict(self.environ)
return req
- def __call__(self, app):
+ def __call__(self, app: WsgiApp) -> "Result":
app = wsgiref.validate.validator(app)
- res = Result(self.testcase)
- def write(data):
- res.writtendata.append(data)
- def start_response(status, headers, exc_info=None):
- res.statusdata = status
- res.headersdata = headers
+ writtendata: typing.List[bytes] = []
+ def write(data: bytes) -> None:
+ nonlocal writtendata
+ writtendata.append(data)
+ statusdata: typing.Optional[str] = None
+ headersdata: typing.Optional[HeaderList] = None
+ def start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
+ nonlocal statusdata, headersdata
+ statusdata = status
+ headersdata = headers
return write
iterator = app(self.environ, start_response)
- res.returneddata = list(iterator)
+ returneddata = list(iterator)
if hasattr(iterator, "close"):
iterator.close()
- return res
+ assert statusdata is not None
+ assert headersdata is not None
+ return Result(
+ self.testcase, statusdata, headersdata, writtendata, returneddata
+ )
class Result:
- def __init__(self, case):
- """
- @type case: unittest.TestCase
- """
+ def __init__(
+ self,
+ case: unittest.TestCase,
+ statusdata: str,
+ headersdata: HeaderList,
+ writtendata: typing.List[bytes],
+ returneddata: typing.List[bytes],
+ ):
self.testcase = case
- self.statusdata = None
- self.headersdata = None
- self.writtendata = []
- self.returneddata = None
+ self.statusdata = statusdata
+ self.headersdata = headersdata
+ self.writtendata = writtendata
+ self.returneddata = returneddata
- def status(self, check):
- """
- @type check: int or str
- """
+ def status(self, check: typing.Union[int, str]) -> None:
if isinstance(check, int):
+ assert self.statusdata is not None
status = int(self.statusdata.split()[0])
self.testcase.assertEqual(check, status)
else:
self.testcase.assertEqual(check, self.statusdata)
- def getheader(self, name):
+ def getheader(self, name: str) -> str:
"""
- @type name: str
@raises KeyError:
"""
+ assert self.headersdata is not None
for key, value in self.headersdata:
if key == name:
return value
raise KeyError
- def header(self, name, check):
- """
- @type name: str
- @type check: str or (str -> bool)
- """
+ def header(
+ self, name: str, check: typing.Union[str, typing.Callable[[str], bool]]
+ ) -> None:
+ assert self.headersdata is not None
found = False
for key, value in self.headersdata:
if key == name:
@@ -124,23 +137,23 @@ class Result:
if not found:
self.testcase.fail("header %s not found" % name)
- def get_data(self):
+ def get_data(self) -> bytes:
return b"".join(self.writtendata) + b"".join(self.returneddata)
from wsgitools import applications
class StaticContentTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.app = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"nothing")
self.req = Request(self)
- def testGet(self):
+ def testGet(self) -> None:
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testHead(self):
+ def testHead(self) -> None:
req = self.req.copy()
req.setmethod("HEAD")
res = req(self.app)
@@ -148,17 +161,17 @@ class StaticContentTest(unittest.TestCase):
res.header("Content-length", "7")
class StaticFileTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.app = applications.StaticFile(io.BytesIO(b"success"), "200 Found",
[("Content-Type", "text/plain")])
self.req = Request(self)
- def testGet(self):
+ def testGet(self) -> None:
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testHead(self):
+ def testHead(self) -> None:
req = self.req.copy()
req.setmethod("HEAD")
res = req(self.app)
@@ -168,7 +181,7 @@ class StaticFileTest(unittest.TestCase):
from wsgitools import digest
class AuthDigestMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.staticapp = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"success")
token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz")
@@ -176,26 +189,26 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
wsgiref.validate.validator(self.staticapp), token_gen)
self.req = Request(self)
- def test401(self):
+ def test401(self) -> None:
res = self.req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401garbage(self):
+ def test401garbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Garbage')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401digestgarbage(self):
+ def test401digestgarbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Digest ","')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def doauth(self, password="baz", status=200):
+ def doauth(self, password: str = "baz", status: int = 200) -> None:
res = self.req(self.app)
nonce = next(iter(filter(lambda x: x.startswith("nonce="),
res.getheader("WWW-Authenticate").split())))
@@ -209,13 +222,13 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
res = req(self.app)
res.status(status)
- def test200(self):
+ def test200(self) -> None:
self.doauth()
- def test401authfail(self):
+ def test401authfail(self) -> None:
self.doauth(password="spam", status=401)
- def testqopauth(self):
+ def testqopauth(self) -> None:
res = self.req(self.app)
nonce = next(iter(filter(lambda x: x.startswith("nonce="),
res.getheader("WWW-Authenticate").split())))
@@ -233,27 +246,31 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
from wsgitools import middlewares
-def writing_application(environ, start_response):
+def writing_application(
+ environ: Environ, start_response: StartResponse
+) -> typing.Iterable[bytes]:
write = start_response("404 Not found", [("Content-Type", "text/plain")])
write = start_response("200 Ok", [("Content-Type", "text/plain")])
write(b"first")
yield b""
yield b"second"
-def write_only_application(environ, start_response):
+def write_only_application(
+ environ: Environ, start_response: StartResponse
+) -> typing.Iterable[bytes]:
write = start_response("200 Ok", [("Content-Type", "text/plain")])
write(b"first")
write(b"second")
yield b""
class NoWriteCallableMiddlewareTest(unittest.TestCase):
- def testWrite(self):
+ def testWrite(self) -> None:
app = middlewares.NoWriteCallableMiddleware(writing_application)
res = Request(self)(app)
self.assertEqual(res.writtendata, [])
self.assertEqual(b"".join(res.returneddata), b"firstsecond")
- def testWriteOnly(self):
+ def testWriteOnly(self) -> None:
app = middlewares.NoWriteCallableMiddleware(write_only_application)
res = Request(self)(app)
self.assertEqual(res.writtendata, [])
@@ -262,28 +279,30 @@ class NoWriteCallableMiddlewareTest(unittest.TestCase):
class StupidIO:
"""file-like without tell method, so StaticFile is not able to
determine the content-length."""
- def __init__(self, content):
+ def __init__(self, content: bytes):
self.content = content
self.position = 0
- def seek(self, pos):
+ def seek(self, pos: int) -> None:
assert pos == 0
self.position = 0
- def read(self, length):
+ def read(self, length: int) -> bytes:
oldpos = self.position
self.position += length
return self.content[oldpos:self.position]
class ContentLengthMiddlewareTest(unittest.TestCase):
- def customSetUp(self, maxstore=10):
+ def customSetUp(
+ self, maxstore: typing.Union[int, typing.Tuple[()]] = 10
+ ) -> None:
self.staticapp = applications.StaticFile(StupidIO(b"success"),
"200 Found", [("Content-Type", "text/plain")])
self.app = middlewares.ContentLengthMiddleware(self.staticapp,
maxstore=maxstore)
self.req = Request(self)
- def testWithout(self):
+ def testWithout(self) -> None:
self.customSetUp()
res = self.req(self.staticapp)
res.status("200 Found")
@@ -293,24 +312,26 @@ class ContentLengthMiddlewareTest(unittest.TestCase):
except KeyError:
pass
- def testGet(self):
+ def testGet(self) -> None:
self.customSetUp()
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testInfiniteMaxstore(self):
+ def testInfiniteMaxstore(self) -> None:
self.customSetUp(maxstore=())
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
class CachingMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.cached = middlewares.CachingMiddleware(self.app)
- self.accessed = dict()
+ self.accessed: typing.Dict[str, int] = {}
- def app(self, environ, start_response):
+ def app(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
count = self.accessed.get(environ["SCRIPT_NAME"], 0) + 1
self.accessed[environ["SCRIPT_NAME"]] = count
headers = [("Content-Type", "text/plain")]
@@ -319,7 +340,7 @@ class CachingMiddlewareTest(unittest.TestCase):
start_response("200 Found", headers)
return [b"%d" % count]
- def testCache(self):
+ def testCache(self) -> None:
res = Request(self)(self.cached)
res.status(200)
self.assertEqual(res.get_data(), b"1")
@@ -327,7 +348,7 @@ class CachingMiddlewareTest(unittest.TestCase):
res.status(200)
self.assertEqual(res.get_data(), b"1")
- def testNoCache(self):
+ def testNoCache(self) -> None:
res = Request(self)(self.cached)
res.status(200)
self.assertEqual(res.get_data(), b"1")
@@ -337,7 +358,7 @@ class CachingMiddlewareTest(unittest.TestCase):
self.assertEqual(res.get_data(), b"2")
class BasicAuthMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.staticapp = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"success")
checkpw = middlewares.DictAuthChecker({"bar": "baz"})
@@ -345,26 +366,26 @@ class BasicAuthMiddlewareTest(unittest.TestCase):
wsgiref.validate.validator(self.staticapp), checkpw)
self.req = Request(self)
- def test401(self):
+ def test401(self) -> None:
res = self.req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401garbage(self):
+ def test401garbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Garbage')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401basicgarbage(self):
+ def test401basicgarbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Basic ()')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def doauth(self, password="baz", status=200):
+ def doauth(self, password: str = "baz", status: int = 200) -> None:
req = self.req.copy()
token = "bar:%s" % password
token = bytes2str(base64.b64encode(str2bytes(token)))
@@ -372,19 +393,20 @@ class BasicAuthMiddlewareTest(unittest.TestCase):
res = req(self.app)
res.status(status)
- def test200(self):
+ def test200(self) -> None:
self.doauth()
- def test401authfail(self):
+ def test401authfail(self) -> None:
self.doauth(password="spam", status=401)
from wsgitools import filters
import gzip
class RequestLogWSGIFilterTest(unittest.TestCase):
- def testSimple(self):
- app = applications.StaticContent("200 Found",
- [("Content-Type", "text/plain")], b"nothing")
+ def testSimple(self) -> None:
+ app: WsgiApp = applications.StaticContent(
+ "200 Found", [("Content-Type", "text/plain")], b"nothing"
+ )
log = io.StringIO()
logfilter = filters.RequestLogWSGIFilter.creator(log)
app = filters.WSGIFilterMiddleware(app, logfilter)
@@ -398,9 +420,10 @@ class RequestLogWSGIFilterTest(unittest.TestCase):
r'200 7 - "wsgitools-test"', logged))
class GzipWSGIFilterTest(unittest.TestCase):
- def testSimple(self):
- app = applications.StaticContent("200 Found",
- [("Content-Type", "text/plain")], b"nothing")
+ def testSimple(self) -> None:
+ app: WsgiApp = applications.StaticContent(
+ "200 Found", [("Content-Type", "text/plain")], b"nothing"
+ )
app = filters.WSGIFilterMiddleware(app, filters.GzipWSGIFilter)
req = Request(self)
req.environ["HTTP_ACCEPT_ENCODING"] = "gzip"
@@ -413,7 +436,9 @@ import asyncore
import socket
import threading
-def fetch_scgi(port, req, body=b""):
+def fetch_scgi(
+ port: int, req: typing.Dict[str, str], body: bytes = b""
+) -> bytes:
with socket.socket() as client:
client.connect(("localhost", port))
req = req.copy()
@@ -423,7 +448,7 @@ def fetch_scgi(port, req, body=b""):
return client.recv(65536)
class ScgiAsynchronousTest(unittest.TestCase):
- def testSimple(self):
+ def testSimple(self) -> None:
app = applications.StaticContent(
"200 OK", [("Content-Type", "text/plain")], b"nothing"
)
@@ -447,7 +472,7 @@ import os
import signal
class ScgiForkTest(unittest.TestCase):
- def testSimple(self):
+ def testSimple(self) -> None:
app = applications.StaticContent(
"200 OK", [("Content-Type", "text/plain")], b"nothing"
)