diff options
Diffstat (limited to 'test.py')
-rwxr-xr-x | test.py | 207 |
1 files changed, 116 insertions, 91 deletions
@@ -8,16 +8,23 @@ import wsgiref.validate import io from hashlib import md5 import sys - -from wsgitools.internal import bytes2str, str2bytes +import typing + +from wsgitools.internal import ( + bytes2str, + Environ, + HeaderList, + OptExcInfo, + StartResponse, + str2bytes, + WriteCallback, + WsgiApp, +) class Request: - def __init__(self, case): - """ - @type case: unittest.TestCase - """ + def __init__(self, case: unittest.TestCase): self.testcase = case - self.environ = dict( + self.environ: Environ = dict( REQUEST_METHOD="GET", SERVER_NAME="localhost", SERVER_PORT="80", @@ -33,86 +40,92 @@ class Request: "wsgi.multiprocess": False, "wsgi.run_once": False}) - def setenv(self, key, value): + def setenv(self, key: str, value: str) -> "Request": """ - @type key: str - @type value: str @returns: self """ self.environ[key] = value return self - def setmethod(self, request_method): + def setmethod(self, request_method: str) -> "Request": """ - @type request_method: str @returns: self """ return self.setenv("REQUEST_METHOD", request_method) - def setheader(self, name, value): + def setheader(self, name: str, value: str) -> "Request": """ - @type name: str - @type value: str @returns: self """ return self.setenv("HTTP_" + name.upper().replace('-', '_'), value) - def copy(self): + def copy(self) -> "Request": req = Request(self.testcase) req.environ = dict(self.environ) return req - def __call__(self, app): + def __call__(self, app: WsgiApp) -> "Result": app = wsgiref.validate.validator(app) - res = Result(self.testcase) - def write(data): - res.writtendata.append(data) - def start_response(status, headers, exc_info=None): - res.statusdata = status - res.headersdata = headers + writtendata: typing.List[bytes] = [] + def write(data: bytes) -> None: + nonlocal writtendata + writtendata.append(data) + statusdata: typing.Optional[str] = None + headersdata: typing.Optional[HeaderList] = None + def start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: + nonlocal statusdata, headersdata + statusdata = status + headersdata = headers return write iterator = app(self.environ, start_response) - res.returneddata = list(iterator) + returneddata = list(iterator) if hasattr(iterator, "close"): iterator.close() - return res + assert statusdata is not None + assert headersdata is not None + return Result( + self.testcase, statusdata, headersdata, writtendata, returneddata + ) class Result: - def __init__(self, case): - """ - @type case: unittest.TestCase - """ + def __init__( + self, + case: unittest.TestCase, + statusdata: str, + headersdata: HeaderList, + writtendata: typing.List[bytes], + returneddata: typing.List[bytes], + ): self.testcase = case - self.statusdata = None - self.headersdata = None - self.writtendata = [] - self.returneddata = None + self.statusdata = statusdata + self.headersdata = headersdata + self.writtendata = writtendata + self.returneddata = returneddata - def status(self, check): - """ - @type check: int or str - """ + def status(self, check: typing.Union[int, str]) -> None: if isinstance(check, int): + assert self.statusdata is not None status = int(self.statusdata.split()[0]) self.testcase.assertEqual(check, status) else: self.testcase.assertEqual(check, self.statusdata) - def getheader(self, name): + def getheader(self, name: str) -> str: """ - @type name: str @raises KeyError: """ + assert self.headersdata is not None for key, value in self.headersdata: if key == name: return value raise KeyError - def header(self, name, check): - """ - @type name: str - @type check: str or (str -> bool) - """ + def header( + self, name: str, check: typing.Union[str, typing.Callable[[str], bool]] + ) -> None: + assert self.headersdata is not None found = False for key, value in self.headersdata: if key == name: @@ -124,23 +137,23 @@ class Result: if not found: self.testcase.fail("header %s not found" % name) - def get_data(self): + def get_data(self) -> bytes: return b"".join(self.writtendata) + b"".join(self.returneddata) from wsgitools import applications class StaticContentTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = applications.StaticContent( "200 Found", [("Content-Type", "text/plain")], b"nothing") self.req = Request(self) - def testGet(self): + def testGet(self) -> None: res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") - def testHead(self): + def testHead(self) -> None: req = self.req.copy() req.setmethod("HEAD") res = req(self.app) @@ -148,17 +161,17 @@ class StaticContentTest(unittest.TestCase): res.header("Content-length", "7") class StaticFileTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = applications.StaticFile(io.BytesIO(b"success"), "200 Found", [("Content-Type", "text/plain")]) self.req = Request(self) - def testGet(self): + def testGet(self) -> None: res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") - def testHead(self): + def testHead(self) -> None: req = self.req.copy() req.setmethod("HEAD") res = req(self.app) @@ -168,7 +181,7 @@ class StaticFileTest(unittest.TestCase): from wsgitools import digest class AuthDigestMiddlewareTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.staticapp = applications.StaticContent( "200 Found", [("Content-Type", "text/plain")], b"success") token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz") @@ -176,26 +189,26 @@ class AuthDigestMiddlewareTest(unittest.TestCase): wsgiref.validate.validator(self.staticapp), token_gen) self.req = Request(self) - def test401(self): + def test401(self) -> None: res = self.req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401garbage(self): + def test401garbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Garbage') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401digestgarbage(self): + def test401digestgarbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Digest ","') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def doauth(self, password="baz", status=200): + def doauth(self, password: str = "baz", status: int = 200) -> None: res = self.req(self.app) nonce = next(iter(filter(lambda x: x.startswith("nonce="), res.getheader("WWW-Authenticate").split()))) @@ -209,13 +222,13 @@ class AuthDigestMiddlewareTest(unittest.TestCase): res = req(self.app) res.status(status) - def test200(self): + def test200(self) -> None: self.doauth() - def test401authfail(self): + def test401authfail(self) -> None: self.doauth(password="spam", status=401) - def testqopauth(self): + def testqopauth(self) -> None: res = self.req(self.app) nonce = next(iter(filter(lambda x: x.startswith("nonce="), res.getheader("WWW-Authenticate").split()))) @@ -233,27 +246,31 @@ class AuthDigestMiddlewareTest(unittest.TestCase): from wsgitools import middlewares -def writing_application(environ, start_response): +def writing_application( + environ: Environ, start_response: StartResponse +) -> typing.Iterable[bytes]: write = start_response("404 Not found", [("Content-Type", "text/plain")]) write = start_response("200 Ok", [("Content-Type", "text/plain")]) write(b"first") yield b"" yield b"second" -def write_only_application(environ, start_response): +def write_only_application( + environ: Environ, start_response: StartResponse +) -> typing.Iterable[bytes]: write = start_response("200 Ok", [("Content-Type", "text/plain")]) write(b"first") write(b"second") yield b"" class NoWriteCallableMiddlewareTest(unittest.TestCase): - def testWrite(self): + def testWrite(self) -> None: app = middlewares.NoWriteCallableMiddleware(writing_application) res = Request(self)(app) self.assertEqual(res.writtendata, []) self.assertEqual(b"".join(res.returneddata), b"firstsecond") - def testWriteOnly(self): + def testWriteOnly(self) -> None: app = middlewares.NoWriteCallableMiddleware(write_only_application) res = Request(self)(app) self.assertEqual(res.writtendata, []) @@ -262,28 +279,30 @@ class NoWriteCallableMiddlewareTest(unittest.TestCase): class StupidIO: """file-like without tell method, so StaticFile is not able to determine the content-length.""" - def __init__(self, content): + def __init__(self, content: bytes): self.content = content self.position = 0 - def seek(self, pos): + def seek(self, pos: int) -> None: assert pos == 0 self.position = 0 - def read(self, length): + def read(self, length: int) -> bytes: oldpos = self.position self.position += length return self.content[oldpos:self.position] class ContentLengthMiddlewareTest(unittest.TestCase): - def customSetUp(self, maxstore=10): + def customSetUp( + self, maxstore: typing.Union[int, typing.Tuple[()]] = 10 + ) -> None: self.staticapp = applications.StaticFile(StupidIO(b"success"), "200 Found", [("Content-Type", "text/plain")]) self.app = middlewares.ContentLengthMiddleware(self.staticapp, maxstore=maxstore) self.req = Request(self) - def testWithout(self): + def testWithout(self) -> None: self.customSetUp() res = self.req(self.staticapp) res.status("200 Found") @@ -293,24 +312,26 @@ class ContentLengthMiddlewareTest(unittest.TestCase): except KeyError: pass - def testGet(self): + def testGet(self) -> None: self.customSetUp() res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") - def testInfiniteMaxstore(self): + def testInfiniteMaxstore(self) -> None: self.customSetUp(maxstore=()) res = self.req(self.app) res.status("200 Found") res.header("Content-length", "7") class CachingMiddlewareTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.cached = middlewares.CachingMiddleware(self.app) - self.accessed = dict() + self.accessed: typing.Dict[str, int] = {} - def app(self, environ, start_response): + def app( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: count = self.accessed.get(environ["SCRIPT_NAME"], 0) + 1 self.accessed[environ["SCRIPT_NAME"]] = count headers = [("Content-Type", "text/plain")] @@ -319,7 +340,7 @@ class CachingMiddlewareTest(unittest.TestCase): start_response("200 Found", headers) return [b"%d" % count] - def testCache(self): + def testCache(self) -> None: res = Request(self)(self.cached) res.status(200) self.assertEqual(res.get_data(), b"1") @@ -327,7 +348,7 @@ class CachingMiddlewareTest(unittest.TestCase): res.status(200) self.assertEqual(res.get_data(), b"1") - def testNoCache(self): + def testNoCache(self) -> None: res = Request(self)(self.cached) res.status(200) self.assertEqual(res.get_data(), b"1") @@ -337,7 +358,7 @@ class CachingMiddlewareTest(unittest.TestCase): self.assertEqual(res.get_data(), b"2") class BasicAuthMiddlewareTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.staticapp = applications.StaticContent( "200 Found", [("Content-Type", "text/plain")], b"success") checkpw = middlewares.DictAuthChecker({"bar": "baz"}) @@ -345,26 +366,26 @@ class BasicAuthMiddlewareTest(unittest.TestCase): wsgiref.validate.validator(self.staticapp), checkpw) self.req = Request(self) - def test401(self): + def test401(self) -> None: res = self.req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401garbage(self): + def test401garbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Garbage') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def test401basicgarbage(self): + def test401basicgarbage(self) -> None: req = self.req.copy() req.setheader('Authorization', 'Basic ()') res = req(self.app) res.status(401) res.header("WWW-Authenticate", lambda _: True) - def doauth(self, password="baz", status=200): + def doauth(self, password: str = "baz", status: int = 200) -> None: req = self.req.copy() token = "bar:%s" % password token = bytes2str(base64.b64encode(str2bytes(token))) @@ -372,19 +393,20 @@ class BasicAuthMiddlewareTest(unittest.TestCase): res = req(self.app) res.status(status) - def test200(self): + def test200(self) -> None: self.doauth() - def test401authfail(self): + def test401authfail(self) -> None: self.doauth(password="spam", status=401) from wsgitools import filters import gzip class RequestLogWSGIFilterTest(unittest.TestCase): - def testSimple(self): - app = applications.StaticContent("200 Found", - [("Content-Type", "text/plain")], b"nothing") + def testSimple(self) -> None: + app: WsgiApp = applications.StaticContent( + "200 Found", [("Content-Type", "text/plain")], b"nothing" + ) log = io.StringIO() logfilter = filters.RequestLogWSGIFilter.creator(log) app = filters.WSGIFilterMiddleware(app, logfilter) @@ -398,9 +420,10 @@ class RequestLogWSGIFilterTest(unittest.TestCase): r'200 7 - "wsgitools-test"', logged)) class GzipWSGIFilterTest(unittest.TestCase): - def testSimple(self): - app = applications.StaticContent("200 Found", - [("Content-Type", "text/plain")], b"nothing") + def testSimple(self) -> None: + app: WsgiApp = applications.StaticContent( + "200 Found", [("Content-Type", "text/plain")], b"nothing" + ) app = filters.WSGIFilterMiddleware(app, filters.GzipWSGIFilter) req = Request(self) req.environ["HTTP_ACCEPT_ENCODING"] = "gzip" @@ -413,7 +436,9 @@ import asyncore import socket import threading -def fetch_scgi(port, req, body=b""): +def fetch_scgi( + port: int, req: typing.Dict[str, str], body: bytes = b"" +) -> bytes: with socket.socket() as client: client.connect(("localhost", port)) req = req.copy() @@ -423,7 +448,7 @@ def fetch_scgi(port, req, body=b""): return client.recv(65536) class ScgiAsynchronousTest(unittest.TestCase): - def testSimple(self): + def testSimple(self) -> None: app = applications.StaticContent( "200 OK", [("Content-Type", "text/plain")], b"nothing" ) @@ -447,7 +472,7 @@ import os import signal class ScgiForkTest(unittest.TestCase): - def testSimple(self): + def testSimple(self) -> None: app = applications.StaticContent( "200 OK", [("Content-Type", "text/plain")], b"nothing" ) |