summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test.py')
-rwxr-xr-xtest.py207
1 files changed, 116 insertions, 91 deletions
diff --git a/test.py b/test.py
index 9690baf..5271819 100755
--- a/test.py
+++ b/test.py
@@ -8,16 +8,23 @@ import wsgiref.validate
import io
from hashlib import md5
import sys
-
-from wsgitools.internal import bytes2str, str2bytes
+import typing
+
+from wsgitools.internal import (
+ bytes2str,
+ Environ,
+ HeaderList,
+ OptExcInfo,
+ StartResponse,
+ str2bytes,
+ WriteCallback,
+ WsgiApp,
+)
class Request:
- def __init__(self, case):
- """
- @type case: unittest.TestCase
- """
+ def __init__(self, case: unittest.TestCase):
self.testcase = case
- self.environ = dict(
+ self.environ: Environ = dict(
REQUEST_METHOD="GET",
SERVER_NAME="localhost",
SERVER_PORT="80",
@@ -33,86 +40,92 @@ class Request:
"wsgi.multiprocess": False,
"wsgi.run_once": False})
- def setenv(self, key, value):
+ def setenv(self, key: str, value: str) -> "Request":
"""
- @type key: str
- @type value: str
@returns: self
"""
self.environ[key] = value
return self
- def setmethod(self, request_method):
+ def setmethod(self, request_method: str) -> "Request":
"""
- @type request_method: str
@returns: self
"""
return self.setenv("REQUEST_METHOD", request_method)
- def setheader(self, name, value):
+ def setheader(self, name: str, value: str) -> "Request":
"""
- @type name: str
- @type value: str
@returns: self
"""
return self.setenv("HTTP_" + name.upper().replace('-', '_'), value)
- def copy(self):
+ def copy(self) -> "Request":
req = Request(self.testcase)
req.environ = dict(self.environ)
return req
- def __call__(self, app):
+ def __call__(self, app: WsgiApp) -> "Result":
app = wsgiref.validate.validator(app)
- res = Result(self.testcase)
- def write(data):
- res.writtendata.append(data)
- def start_response(status, headers, exc_info=None):
- res.statusdata = status
- res.headersdata = headers
+ writtendata: typing.List[bytes] = []
+ def write(data: bytes) -> None:
+ nonlocal writtendata
+ writtendata.append(data)
+ statusdata: typing.Optional[str] = None
+ headersdata: typing.Optional[HeaderList] = None
+ def start_response(
+ status: str, headers: HeaderList, exc_info: OptExcInfo = None
+ ) -> WriteCallback:
+ nonlocal statusdata, headersdata
+ statusdata = status
+ headersdata = headers
return write
iterator = app(self.environ, start_response)
- res.returneddata = list(iterator)
+ returneddata = list(iterator)
if hasattr(iterator, "close"):
iterator.close()
- return res
+ assert statusdata is not None
+ assert headersdata is not None
+ return Result(
+ self.testcase, statusdata, headersdata, writtendata, returneddata
+ )
class Result:
- def __init__(self, case):
- """
- @type case: unittest.TestCase
- """
+ def __init__(
+ self,
+ case: unittest.TestCase,
+ statusdata: str,
+ headersdata: HeaderList,
+ writtendata: typing.List[bytes],
+ returneddata: typing.List[bytes],
+ ):
self.testcase = case
- self.statusdata = None
- self.headersdata = None
- self.writtendata = []
- self.returneddata = None
+ self.statusdata = statusdata
+ self.headersdata = headersdata
+ self.writtendata = writtendata
+ self.returneddata = returneddata
- def status(self, check):
- """
- @type check: int or str
- """
+ def status(self, check: typing.Union[int, str]) -> None:
if isinstance(check, int):
+ assert self.statusdata is not None
status = int(self.statusdata.split()[0])
self.testcase.assertEqual(check, status)
else:
self.testcase.assertEqual(check, self.statusdata)
- def getheader(self, name):
+ def getheader(self, name: str) -> str:
"""
- @type name: str
@raises KeyError:
"""
+ assert self.headersdata is not None
for key, value in self.headersdata:
if key == name:
return value
raise KeyError
- def header(self, name, check):
- """
- @type name: str
- @type check: str or (str -> bool)
- """
+ def header(
+ self, name: str, check: typing.Union[str, typing.Callable[[str], bool]]
+ ) -> None:
+ assert self.headersdata is not None
found = False
for key, value in self.headersdata:
if key == name:
@@ -124,23 +137,23 @@ class Result:
if not found:
self.testcase.fail("header %s not found" % name)
- def get_data(self):
+ def get_data(self) -> bytes:
return b"".join(self.writtendata) + b"".join(self.returneddata)
from wsgitools import applications
class StaticContentTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.app = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"nothing")
self.req = Request(self)
- def testGet(self):
+ def testGet(self) -> None:
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testHead(self):
+ def testHead(self) -> None:
req = self.req.copy()
req.setmethod("HEAD")
res = req(self.app)
@@ -148,17 +161,17 @@ class StaticContentTest(unittest.TestCase):
res.header("Content-length", "7")
class StaticFileTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.app = applications.StaticFile(io.BytesIO(b"success"), "200 Found",
[("Content-Type", "text/plain")])
self.req = Request(self)
- def testGet(self):
+ def testGet(self) -> None:
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testHead(self):
+ def testHead(self) -> None:
req = self.req.copy()
req.setmethod("HEAD")
res = req(self.app)
@@ -168,7 +181,7 @@ class StaticFileTest(unittest.TestCase):
from wsgitools import digest
class AuthDigestMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.staticapp = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"success")
token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz")
@@ -176,26 +189,26 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
wsgiref.validate.validator(self.staticapp), token_gen)
self.req = Request(self)
- def test401(self):
+ def test401(self) -> None:
res = self.req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401garbage(self):
+ def test401garbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Garbage')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401digestgarbage(self):
+ def test401digestgarbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Digest ","')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def doauth(self, password="baz", status=200):
+ def doauth(self, password: str = "baz", status: int = 200) -> None:
res = self.req(self.app)
nonce = next(iter(filter(lambda x: x.startswith("nonce="),
res.getheader("WWW-Authenticate").split())))
@@ -209,13 +222,13 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
res = req(self.app)
res.status(status)
- def test200(self):
+ def test200(self) -> None:
self.doauth()
- def test401authfail(self):
+ def test401authfail(self) -> None:
self.doauth(password="spam", status=401)
- def testqopauth(self):
+ def testqopauth(self) -> None:
res = self.req(self.app)
nonce = next(iter(filter(lambda x: x.startswith("nonce="),
res.getheader("WWW-Authenticate").split())))
@@ -233,27 +246,31 @@ class AuthDigestMiddlewareTest(unittest.TestCase):
from wsgitools import middlewares
-def writing_application(environ, start_response):
+def writing_application(
+ environ: Environ, start_response: StartResponse
+) -> typing.Iterable[bytes]:
write = start_response("404 Not found", [("Content-Type", "text/plain")])
write = start_response("200 Ok", [("Content-Type", "text/plain")])
write(b"first")
yield b""
yield b"second"
-def write_only_application(environ, start_response):
+def write_only_application(
+ environ: Environ, start_response: StartResponse
+) -> typing.Iterable[bytes]:
write = start_response("200 Ok", [("Content-Type", "text/plain")])
write(b"first")
write(b"second")
yield b""
class NoWriteCallableMiddlewareTest(unittest.TestCase):
- def testWrite(self):
+ def testWrite(self) -> None:
app = middlewares.NoWriteCallableMiddleware(writing_application)
res = Request(self)(app)
self.assertEqual(res.writtendata, [])
self.assertEqual(b"".join(res.returneddata), b"firstsecond")
- def testWriteOnly(self):
+ def testWriteOnly(self) -> None:
app = middlewares.NoWriteCallableMiddleware(write_only_application)
res = Request(self)(app)
self.assertEqual(res.writtendata, [])
@@ -262,28 +279,30 @@ class NoWriteCallableMiddlewareTest(unittest.TestCase):
class StupidIO:
"""file-like without tell method, so StaticFile is not able to
determine the content-length."""
- def __init__(self, content):
+ def __init__(self, content: bytes):
self.content = content
self.position = 0
- def seek(self, pos):
+ def seek(self, pos: int) -> None:
assert pos == 0
self.position = 0
- def read(self, length):
+ def read(self, length: int) -> bytes:
oldpos = self.position
self.position += length
return self.content[oldpos:self.position]
class ContentLengthMiddlewareTest(unittest.TestCase):
- def customSetUp(self, maxstore=10):
+ def customSetUp(
+ self, maxstore: typing.Union[int, typing.Tuple[()]] = 10
+ ) -> None:
self.staticapp = applications.StaticFile(StupidIO(b"success"),
"200 Found", [("Content-Type", "text/plain")])
self.app = middlewares.ContentLengthMiddleware(self.staticapp,
maxstore=maxstore)
self.req = Request(self)
- def testWithout(self):
+ def testWithout(self) -> None:
self.customSetUp()
res = self.req(self.staticapp)
res.status("200 Found")
@@ -293,24 +312,26 @@ class ContentLengthMiddlewareTest(unittest.TestCase):
except KeyError:
pass
- def testGet(self):
+ def testGet(self) -> None:
self.customSetUp()
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
- def testInfiniteMaxstore(self):
+ def testInfiniteMaxstore(self) -> None:
self.customSetUp(maxstore=())
res = self.req(self.app)
res.status("200 Found")
res.header("Content-length", "7")
class CachingMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.cached = middlewares.CachingMiddleware(self.app)
- self.accessed = dict()
+ self.accessed: typing.Dict[str, int] = {}
- def app(self, environ, start_response):
+ def app(
+ self, environ: Environ, start_response: StartResponse
+ ) -> typing.Iterable[bytes]:
count = self.accessed.get(environ["SCRIPT_NAME"], 0) + 1
self.accessed[environ["SCRIPT_NAME"]] = count
headers = [("Content-Type", "text/plain")]
@@ -319,7 +340,7 @@ class CachingMiddlewareTest(unittest.TestCase):
start_response("200 Found", headers)
return [b"%d" % count]
- def testCache(self):
+ def testCache(self) -> None:
res = Request(self)(self.cached)
res.status(200)
self.assertEqual(res.get_data(), b"1")
@@ -327,7 +348,7 @@ class CachingMiddlewareTest(unittest.TestCase):
res.status(200)
self.assertEqual(res.get_data(), b"1")
- def testNoCache(self):
+ def testNoCache(self) -> None:
res = Request(self)(self.cached)
res.status(200)
self.assertEqual(res.get_data(), b"1")
@@ -337,7 +358,7 @@ class CachingMiddlewareTest(unittest.TestCase):
self.assertEqual(res.get_data(), b"2")
class BasicAuthMiddlewareTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.staticapp = applications.StaticContent(
"200 Found", [("Content-Type", "text/plain")], b"success")
checkpw = middlewares.DictAuthChecker({"bar": "baz"})
@@ -345,26 +366,26 @@ class BasicAuthMiddlewareTest(unittest.TestCase):
wsgiref.validate.validator(self.staticapp), checkpw)
self.req = Request(self)
- def test401(self):
+ def test401(self) -> None:
res = self.req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401garbage(self):
+ def test401garbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Garbage')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def test401basicgarbage(self):
+ def test401basicgarbage(self) -> None:
req = self.req.copy()
req.setheader('Authorization', 'Basic ()')
res = req(self.app)
res.status(401)
res.header("WWW-Authenticate", lambda _: True)
- def doauth(self, password="baz", status=200):
+ def doauth(self, password: str = "baz", status: int = 200) -> None:
req = self.req.copy()
token = "bar:%s" % password
token = bytes2str(base64.b64encode(str2bytes(token)))
@@ -372,19 +393,20 @@ class BasicAuthMiddlewareTest(unittest.TestCase):
res = req(self.app)
res.status(status)
- def test200(self):
+ def test200(self) -> None:
self.doauth()
- def test401authfail(self):
+ def test401authfail(self) -> None:
self.doauth(password="spam", status=401)
from wsgitools import filters
import gzip
class RequestLogWSGIFilterTest(unittest.TestCase):
- def testSimple(self):
- app = applications.StaticContent("200 Found",
- [("Content-Type", "text/plain")], b"nothing")
+ def testSimple(self) -> None:
+ app: WsgiApp = applications.StaticContent(
+ "200 Found", [("Content-Type", "text/plain")], b"nothing"
+ )
log = io.StringIO()
logfilter = filters.RequestLogWSGIFilter.creator(log)
app = filters.WSGIFilterMiddleware(app, logfilter)
@@ -398,9 +420,10 @@ class RequestLogWSGIFilterTest(unittest.TestCase):
r'200 7 - "wsgitools-test"', logged))
class GzipWSGIFilterTest(unittest.TestCase):
- def testSimple(self):
- app = applications.StaticContent("200 Found",
- [("Content-Type", "text/plain")], b"nothing")
+ def testSimple(self) -> None:
+ app: WsgiApp = applications.StaticContent(
+ "200 Found", [("Content-Type", "text/plain")], b"nothing"
+ )
app = filters.WSGIFilterMiddleware(app, filters.GzipWSGIFilter)
req = Request(self)
req.environ["HTTP_ACCEPT_ENCODING"] = "gzip"
@@ -413,7 +436,9 @@ import asyncore
import socket
import threading
-def fetch_scgi(port, req, body=b""):
+def fetch_scgi(
+ port: int, req: typing.Dict[str, str], body: bytes = b""
+) -> bytes:
with socket.socket() as client:
client.connect(("localhost", port))
req = req.copy()
@@ -423,7 +448,7 @@ def fetch_scgi(port, req, body=b""):
return client.recv(65536)
class ScgiAsynchronousTest(unittest.TestCase):
- def testSimple(self):
+ def testSimple(self) -> None:
app = applications.StaticContent(
"200 OK", [("Content-Type", "text/plain")], b"nothing"
)
@@ -447,7 +472,7 @@ import os
import signal
class ScgiForkTest(unittest.TestCase):
- def testSimple(self):
+ def testSimple(self) -> None:
app = applications.StaticContent(
"200 OK", [("Content-Type", "text/plain")], b"nothing"
)