From a41066b413489b407b9d99174af697563ad680b9 Mon Sep 17 00:00:00 2001 From: Helmut Grohne Date: Mon, 13 Apr 2020 21:30:34 +0200 Subject: add type hints to all of the code In order to use type hint syntax, we need to bump the minimum Python version to 3.7 and some of the features such as Literal and Protocol are opted in when a sufficiently recent Python is available. This does not make all of the code pass type checking with mypy. A number of typing issues remain, but the output of mypy becomes something one can read through. In adding type hints, a lot of epydoc @type annotations are removed as redundant. This update also adopts black-style line breaking. --- test.py | 207 ++++++++++++++++++++++++++++++++++++---------------------------- 1 file changed, 116 insertions(+), 91 deletions(-) (limited to 'test.py') diff --git a/test.py b/test.py index 9690baf..5271819 100755 --- a/test.py +++ b/test.py @@ -8,16 +8,23 @@ import wsgiref.validate import io from hashlib import md5 import sys - -from wsgitools.internal import bytes2str, str2bytes +import typing + +from wsgitools.internal import ( + bytes2str, + Environ, + HeaderList, + OptExcInfo, + StartResponse, + str2bytes, + WriteCallback, + WsgiApp, +) class Request: - def __init__(self, case): - """ - @type case: unittest.TestCase - """ + def __init__(self, case: unittest.TestCase): self.testcase = case - self.environ = dict( + self.environ: Environ = dict( REQUEST_METHOD="GET", SERVER_NAME="localhost", SERVER_PORT="80", @@ -33,86 +40,92 @@ class Request: "wsgi.multiprocess": False, "wsgi.run_once": False}) - def setenv(self, key, value): + def setenv(self, key: str, value: str) -> "Request": """ - @type key: str - @type value: str @returns: self """ self.environ[key] = value return self - def setmethod(self, request_method): + def setmethod(self, request_method: str) -> "Request": """ - @type request_method: str @returns: self """ return self.setenv("REQUEST_METHOD", request_method) - def setheader(self, name, value): + def setheader(self, name: str, value: str) -> "Request": """ - @type name: str - @type value: str @returns: self """ return self.setenv("HTTP_" + name.upper().replace('-', '_'), value) - def copy(self): + def copy(self) -> "Request": req = Request(self.testcase) req.environ = dict(self.environ) return req - def __call__(self, app): + def __call__(self, app: WsgiApp) -> "Result": app = wsgiref.validate.validator(app) - res = Result(self.testcase) - def write(data): - res.writtendata.append(data) - def start_response(status, headers, exc_info=None): - res.statusdata = status - res.headersdata = headers + writtendata: typing.List[bytes] = [] + def write(data: bytes) -> None: + nonlocal writtendata + writtendata.append(data) + statusdata: typing.Optional[str] = None + headersdata: typing.Optional[HeaderList] = None + def start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: + nonlocal statusdata, headersdata + statusdata = status + headersdata = headers return write iterator = app(self.environ, start_response) - res.returneddata = list(iterator) + returneddata = list(iterator) if hasattr(iterator, "close"): iterator.close() - return res + assert statusdata is not None + assert headersdata is not None + return Result( + self.testcase, statusdata, headersdata, writtendata, returneddata + ) class Result: - def __init__(self, case): - """ - @type case: unittest.TestCase - """ + def __init__( + self, + case: unittest.TestCase, + statusdata: str, + headersdata: HeaderList, + writtendata: typing.List[bytes], + returneddata: typing.List[bytes], + ): self.testcase = case - self.statusdata = None - self.headersdata = None - self.writtendata = [] - self.returneddata = None + self.statusdata = statusdata + self.headersdata = headersdata + self.writtendata = writtendata + self.returneddata = returneddata - def status(self, check): - """ - @type check: int or str - """ + def status(self, check: typing.Union[int, str]) -> None: if isinstance(check, int): + assert self.statusdata is not None status = int(self.statusdata.split()[0]) self.testcase.assertEqual(check, status) else: self.testcase.assertEqual(check, self.statusdata) - def getheader(self, name): + def getheader(self, name: str) -> str: """ - @type name: str @raises KeyError: """ + assert self.headersdata is not None for key, value in self.headersdata: if key == name: return value raise KeyError - def header(self, name, check): - """ - @type name: str - @type check: str or (str -> bool) - """ + def header( + self, name: str, check: typing.Union[str, typing.Callable[[str], bool]] + ) -> None: + assert self.headersdata is not None found = False for key, value in self.headersdata: if key == name: @@ -124,23 +137,23 @@ class Result: if not found: self.testcase.fail("header %s not found" % name) - def get_data(self): + def get_data(self) -> bytes: return b"".join(self.writtendata) + b"".join(self.returneddata) from wsgitools import applications class StaticContentTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = applications.StaticContent( "200 Found", [("Content-Type", "text/plain")], b"nothing") self.req = Request(self) - def testGet(self): + def testGet(self) -> None: res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") - def testHead(self): + def testHead(self) -> None: req = self.req.copy() req.setmethod("HEAD") res = req(self.app) @@ -148,17 +161,17 @@ class StaticContentTest(unittest.TestCase): res.header("Content-length", "7") class StaticFileTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = applications.StaticFile(io.BytesIO(b"success"), "200 Found", [("Content-Type", "text/plain")]) self.req = Request(self) - def testGet(self): + def testGet(self) -> None: res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") - def testHead(self): + def testHead(self) -> None: req = self.req.copy() req.setmethod("HEAD") res = req(self.app) @@ -168,7 +181,7 @@ class StaticFileTest(unittest.TestCase): from wsgitools import digest class AuthDigestMiddlewareTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.staticapp = applications.StaticContent( "200 Found", [("Content-Type", "text/plain")], b"success") token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz") @@ -176,26 +189,26 @@ class AuthDigestMiddlewareTest(unittest.TestCase): wsgiref.validate.validator(self.staticapp), token_gen) self.req = Request(self) - def test401(self): + def test401(self) -> None: res = self.req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401garbage(self): + def test401garbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Garbage') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401digestgarbage(self): + def test401digestgarbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Digest ","') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def doauth(self, password="baz", status=200): + def doauth(self, password: str = "baz", status: int = 200) -> None: res = self.req(self.app) nonce = next(iter(filter(lambda x: x.startswith("nonce="), res.getheader("WWW-Authenticate").split()))) @@ -209,13 +222,13 @@ class AuthDigestMiddlewareTest(unittest.TestCase): res = req(self.app) res.status(status) - def test200(self): + def test200(self) -> None: self.doauth() - def test401authfail(self): + def test401authfail(self) -> None: self.doauth(password="spam", status=401) - def testqopauth(self): + def testqopauth(self) -> None: res = self.req(self.app) nonce = next(iter(filter(lambda x: x.startswith("nonce="), res.getheader("WWW-Authenticate").split()))) @@ -233,27 +246,31 @@ class AuthDigestMiddlewareTest(unittest.TestCase): from wsgitools import middlewares -def writing_application(environ, start_response): +def writing_application( + environ: Environ, start_response: StartResponse +) -> typing.Iterable[bytes]: write = start_response("404 Not found", [("Content-Type", "text/plain")]) write = start_response("200 Ok", [("Content-Type", "text/plain")]) write(b"first") yield b"" yield b"second" -def write_only_application(environ, start_response): +def write_only_application( + environ: Environ, start_response: StartResponse +) -> typing.Iterable[bytes]: write = start_response("200 Ok", [("Content-Type", "text/plain")]) write(b"first") write(b"second") yield b"" class NoWriteCallableMiddlewareTest(unittest.TestCase): - def testWrite(self): + def testWrite(self) -> None: app = middlewares.NoWriteCallableMiddleware(writing_application) res = Request(self)(app) self.assertEqual(res.writtendata, []) self.assertEqual(b"".join(res.returneddata), b"firstsecond") - def testWriteOnly(self): + def testWriteOnly(self) -> None: app = middlewares.NoWriteCallableMiddleware(write_only_application) res = Request(self)(app) self.assertEqual(res.writtendata, []) @@ -262,28 +279,30 @@ class NoWriteCallableMiddlewareTest(unittest.TestCase): class StupidIO: """file-like without tell method, so StaticFile is not able to determine the content-length.""" - def __init__(self, content): + def __init__(self, content: bytes): self.content = content self.position = 0 - def seek(self, pos): + def seek(self, pos: int) -> None: assert pos == 0 self.position = 0 - def read(self, length): + def read(self, length: int) -> bytes: oldpos = self.position self.position += length return self.content[oldpos:self.position] class ContentLengthMiddlewareTest(unittest.TestCase): - def customSetUp(self, maxstore=10): + def customSetUp( + self, maxstore: typing.Union[int, typing.Tuple[()]] = 10 + ) -> None: self.staticapp = applications.StaticFile(StupidIO(b"success"), "200 Found", [("Content-Type", "text/plain")]) self.app = middlewares.ContentLengthMiddleware(self.staticapp, maxstore=maxstore) self.req = Request(self) - def testWithout(self): + def testWithout(self) -> None: self.customSetUp() res = self.req(self.staticapp) res.status("200 Found") @@ -293,24 +312,26 @@ class ContentLengthMiddlewareTest(unittest.TestCase): except KeyError: pass - def testGet(self): + def testGet(self) -> None: self.customSetUp() res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") - def testInfiniteMaxstore(self): + def testInfiniteMaxstore(self) -> None: self.customSetUp(maxstore=()) res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") class CachingMiddlewareTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.cached = middlewares.CachingMiddleware(self.app) - self.accessed = dict() + self.accessed: typing.Dict[str, int] = {} - def app(self, environ, start_response): + def app( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: count = self.accessed.get(environ["SCRIPT_NAME"], 0) + 1 self.accessed[environ["SCRIPT_NAME"]] = count headers = [("Content-Type", "text/plain")] @@ -319,7 +340,7 @@ class CachingMiddlewareTest(unittest.TestCase): start_response("200 Found", headers) return [b"%d" % count] - def testCache(self): + def testCache(self) -> None: res = Request(self)(self.cached) res.status(200) self.assertEqual(res.get_data(), b"1") @@ -327,7 +348,7 @@ class CachingMiddlewareTest(unittest.TestCase): res.status(200) self.assertEqual(res.get_data(), b"1") - def testNoCache(self): + def testNoCache(self) -> None: res = Request(self)(self.cached) res.status(200) self.assertEqual(res.get_data(), b"1") @@ -337,7 +358,7 @@ class CachingMiddlewareTest(unittest.TestCase): self.assertEqual(res.get_data(), b"2") class BasicAuthMiddlewareTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.staticapp = applications.StaticContent( "200 Found", [("Content-Type", "text/plain")], b"success") checkpw = middlewares.DictAuthChecker({"bar": "baz"}) @@ -345,26 +366,26 @@ class BasicAuthMiddlewareTest(unittest.TestCase): wsgiref.validate.validator(self.staticapp), checkpw) self.req = Request(self) - def test401(self): + def test401(self) -> None: res = self.req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401garbage(self): + def test401garbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Garbage') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401basicgarbage(self): + def test401basicgarbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Basic ()') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def doauth(self, password="baz", status=200): + def doauth(self, password: str = "baz", status: int = 200) -> None: req = self.req.copy() token = "bar:%s" % password token = bytes2str(base64.b64encode(str2bytes(token))) @@ -372,19 +393,20 @@ class BasicAuthMiddlewareTest(unittest.TestCase): res = req(self.app) res.status(status) - def test200(self): + def test200(self) -> None: self.doauth() - def test401authfail(self): + def test401authfail(self) -> None: self.doauth(password="spam", status=401) from wsgitools import filters import gzip class RequestLogWSGIFilterTest(unittest.TestCase): - def testSimple(self): - app = applications.StaticContent("200 Found", - [("Content-Type", "text/plain")], b"nothing") + def testSimple(self) -> None: + app: WsgiApp = applications.StaticContent( + "200 Found", [("Content-Type", "text/plain")], b"nothing" + ) log = io.StringIO() logfilter = filters.RequestLogWSGIFilter.creator(log) app = filters.WSGIFilterMiddleware(app, logfilter) @@ -398,9 +420,10 @@ class RequestLogWSGIFilterTest(unittest.TestCase): r'200 7 - "wsgitools-test"', logged)) class GzipWSGIFilterTest(unittest.TestCase): - def testSimple(self): - app = applications.StaticContent("200 Found", - [("Content-Type", "text/plain")], b"nothing") + def testSimple(self) -> None: + app: WsgiApp = applications.StaticContent( + "200 Found", [("Content-Type", "text/plain")], b"nothing" + ) app = filters.WSGIFilterMiddleware(app, filters.GzipWSGIFilter) req = Request(self) req.environ["HTTP_ACCEPT_ENCODING"] = "gzip" @@ -413,7 +436,9 @@ import asyncore import socket import threading -def fetch_scgi(port, req, body=b""): +def fetch_scgi( + port: int, req: typing.Dict[str, str], body: bytes = b"" +) -> bytes: with socket.socket() as client: client.connect(("localhost", port)) req = req.copy() @@ -423,7 +448,7 @@ def fetch_scgi(port, req, body=b""): return client.recv(65536) class ScgiAsynchronousTest(unittest.TestCase): - def testSimple(self): + def testSimple(self) -> None: app = applications.StaticContent( "200 OK", [("Content-Type", "text/plain")], b"nothing" ) @@ -447,7 +472,7 @@ import os import signal class ScgiForkTest(unittest.TestCase): - def testSimple(self): + def testSimple(self) -> None: app = applications.StaticContent( "200 OK", [("Content-Type", "text/plain")], b"nothing" ) -- cgit v1.2.3