#!/usr/bin/env python
"""Unit tests for the wsgitools package.

Applications under test are driven through a minimal fake WSGI server
(L{Request}) which wraps them in wsgiref.validate.validator and records the
response in a L{Result}.
"""

import unittest
import doctest
import re
import wsgiref.validate
# Cannot use io module as it is broken in 2.6.
# Writing a str to a io.StringIO results in an exception.
try:
    import cStringIO as io
except ImportError:
    import StringIO as io
try:
    from hashlib import md5
except ImportError:
    from md5 import md5
import sys

try:
    next
except NameError:
    def next(iterator):
        """Fallback for the next() builtin on Pythons that lack it."""
        return iterator.next()


class Request:
    """A fake WSGI request that runs an application and collects a Result."""

    def __init__(self, case):
        """
        Build a minimal GET request environ for http://localhost:80/.

        @type case: unittest.TestCase
        @param case: test case used by the returned L{Result} for assertions
        """
        self.testcase = case
        self.environ = dict(
            REQUEST_METHOD="GET",
            SERVER_NAME="localhost",
            SERVER_PORT="80",
            SCRIPT_NAME="",
            PATH_INFO="",
            QUERY_STRING="")
        self.environ.update({
            "wsgi.version": (1, 0),
            "wsgi.input": io.StringIO(),
            "wsgi.errors": sys.stderr,
            "wsgi.url_scheme": "http",
            "wsgi.multithread": False,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False})

    def setenv(self, key, value):
        """
        Set a raw environ entry.

        @type key: str
        @type value: str
        """
        self.environ[key] = value

    def setmethod(self, request_method):
        """
        Set the REQUEST_METHOD of this request.

        @type request_method: str
        """
        self.setenv("REQUEST_METHOD", request_method)

    def setheader(self, name, value):
        """
        Set an environ key derived from name by upper-casing it and turning
        dashes into underscores. No "HTTP_" prefix is added, so callers pass
        e.g. 'http-authorization' to get HTTP_AUTHORIZATION.

        @type name: str
        @type value: str
        """
        self.setenv(name.upper().replace('-', '_'), value)

    def copy(self):
        """
        Return a new Request for the same test case with a shallow copy of
        the environ. Values such as the wsgi.input stream object are shared
        with the original.

        @rtype: Request
        """
        req = Request(self.testcase)
        req.environ = dict(self.environ)
        return req

    def __call__(self, app):
        """
        Run app (wrapped in wsgiref.validate.validator) against this request.

        Data passed to the write callable is collected in
        Result.writtendata, and the returned iterable is fully consumed into
        Result.returneddata.

        @param app: WSGI application
        @rtype: Result
        """
        app = wsgiref.validate.validator(app)
        res = Result(self.testcase)

        def write(data):
            res.writtendata.append(data)

        def start_response(status, headers, exc_info=None):
            # exc_info is accepted but ignored; the latest call wins.
            res.statusdata = status
            res.headersdata = headers
            return write

        iterator = app(self.environ, start_response)
        try:
            res.returneddata = list(iterator)
        finally:
            # PEP 3333: close() must be called even if iteration failed.
            if hasattr(iterator, "close"):
                iterator.close()
        return res


class Result:
    """The recorded response of a L{Request} plus assertion helpers."""

    def __init__(self, case):
        """
        @type case: unittest.TestCase
        @param case: test case whose assert methods are used by the helpers
        """
        self.testcase = case
        self.statusdata = None    # status line passed to start_response
        self.headersdata = None   # header list passed to start_response
        self.writtendata = []     # chunks passed to the write callable
        self.returneddata = None  # list of chunks yielded by the iterable

    def status(self, check):
        """
        Assert the response status. An int is compared with the numeric
        code only; anything else with the full status line.

        @type check: int or str
        """
        if isinstance(check, int):
            status = int(self.statusdata.split()[0])
            self.testcase.assertEqual(check, status)
        else:
            self.testcase.assertEqual(check, self.statusdata)

    def getheader(self, name):
        """
        Return the value of the first header whose name equals name exactly
        (the comparison is case-sensitive).

        @type name: str
        @rtype: str
        @raises KeyError: if no such header was sent
        """
        for key, value in self.headersdata:
            if key == name:
                return value
        raise KeyError(name)

    def header(self, name, check):
        """
        Assert that header name (case-sensitive) is present. Every occurrence
        must equal check when it is a str, or satisfy check(value) otherwise.

        @type name: str
        @type check: str or (str -> bool)
        """
        found = False
        for key, value in self.headersdata:
            if key == name:
                found = True
                if isinstance(check, str):
                    self.testcase.assertEqual(check, value)
                else:
                    self.testcase.assertTrue(check(value))
        if not found:
            self.testcase.fail("header %s not found" % name)

    def get_data(self):
        """
        Return the written data followed by the returned data as one string.

        @rtype: str
        """
        return "".join(self.writtendata) + "".join(self.returneddata)


from wsgitools import applications


class StaticContentTest(unittest.TestCase):
    """Tests for applications.StaticContent."""

    def setUp(self):
        self.app = applications.StaticContent(
            "200 Found", [("Content-Type", "text/plain")], "nothing")
        self.req = Request(self)

    def testGet(self):
        res = self.req(self.app)
        res.status("200 Found")
        res.header("Content-length", "7")

    def testHead(self):
        req = self.req.copy()
        req.setmethod("HEAD")
        res = req(self.app)
        res.status(200)
        res.header("Content-length", "7")


class StaticFileTest(unittest.TestCase):
    """Tests for applications.StaticFile."""

    def setUp(self):
        self.app = applications.StaticFile(
            io.StringIO("success"), "200 Found",
            [("Content-Type", "text/plain")])
        self.req = Request(self)

    def testGet(self):
        res = self.req(self.app)
        res.status("200 Found")
        res.header("Content-length", "7")

    def testHead(self):
        req = self.req.copy()
        req.setmethod("HEAD")
        res = req(self.app)
        res.status(200)
        res.header("Content-length", "7")


from wsgitools import digest


class AuthDigestMiddlewareTest(unittest.TestCase):
    """Tests for digest.AuthDigestMiddleware (realm "foo", password "baz")."""

    def setUp(self):
        self.staticapp = applications.StaticContent(
            "200 Found", [("Content-Type", "text/plain")], "success")
        token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz")
        self.app = digest.AuthDigestMiddleware(
            wsgiref.validate.validator(self.staticapp), token_gen)
        self.req = Request(self)

    def test401(self):
        res = self.req(self.app)
        res.status(401)
        res.header("WWW-Authenticate", lambda _: True)

    def test401garbage(self):
        req = self.req.copy()
        req.setheader('http-authorization', 'Garbage')
        res = req(self.app)
        res.status(401)
        res.header("WWW-Authenticate", lambda _: True)

    def test401digestgarbage(self):
        req = self.req.copy()
        req.setheader('http-authorization', 'Digest ","')
        res = req(self.app)
        res.status(401)
        res.header("WWW-Authenticate", lambda _: True)

    def _getnonce(self):
        """
        Send an unauthenticated request and return the nonce value taken
        from the whitespace-separated token starting with 'nonce=' in the
        WWW-Authenticate challenge (the text between its first two quotes).

        @rtype: str
        """
        res = self.req(self.app)
        challenge = res.getheader("WWW-Authenticate").split()
        nonce = next(part for part in challenge if part.startswith("nonce="))
        return nonce.split('"')[1]

    def doauth(self, password="baz", status=200):
        """
        Authenticate as user "bar" without qop and assert the status.

        @type password: str
        @type status: int
        """
        nonce = self._getnonce()
        req = self.req.copy()
        # response = MD5(MD5(user:realm:password):nonce:MD5(method:uri))
        token = md5("bar:foo:%s" % password).hexdigest()
        other = md5("GET:").hexdigest()
        resp = md5("%s:%s:%s" % (token, nonce, other)).hexdigest()
        req.setheader('http-authorization',
                      ('Digest algorithm=md5,nonce="%s",'
                       'uri=,username=bar,response="%s"') % (nonce, resp))
        res = req(self.app)
        res.status(status)

    def test200(self):
        self.doauth()

    def test401authfail(self):
        self.doauth(password="spam", status=401)

    def testqopauth(self):
        nonce = self._getnonce()
        req = self.req.copy()
        # With qop: response = MD5(HA1:nonce:nc:cnonce:qop:HA2)
        token = md5("bar:foo:baz").hexdigest()
        other = md5("GET:").hexdigest()
        resp = md5("%s:%s:1:qux:auth:%s" % (token, nonce, other)).hexdigest()
        req.setheader('http-authorization',
                      ('Digest algorithm=md5,nonce="%s",'
                       'uri=,username=bar,response="%s",qop=auth,nc=1,'
                       'cnonce=qux') % (nonce, resp))
        res = req(self.app)
        res.status(200)


from wsgitools import middlewares


def writing_application(environ, start_response):
    """
    WSGI app that mixes the write callable with yielded data.

    It calls start_response twice without exc_info (first "404 Not found",
    then "200 Ok"); this looks deliberate for exercising
    NoWriteCallableMiddleware -- NOTE(review): confirm intent.
    """
    start_response("404 Not found", [("Content-Type", "text/plain")])
    write = start_response("200 Ok", [("Content-Type", "text/plain")])
    write("first")
    yield ""
    yield "second"


def write_only_application(environ, start_response):
    """WSGI app that emits its whole body through the write callable."""
    write = start_response("200 Ok", [("Content-Type", "text/plain")])
    write("first")
    write("second")
    yield ""
class NoWriteCallableMiddlewareTest(unittest.TestCase):
    """Tests for middlewares.NoWriteCallableMiddleware."""

    def _checkbuffered(self, application):
        """
        Run application behind the middleware and assert that nothing reached
        the write callable while the returned chunks join to "firstsecond".
        """
        app = middlewares.NoWriteCallableMiddleware(application)
        res = Request(self)(app)
        self.assertEqual(res.writtendata, [])
        self.assertEqual("".join(res.returneddata), "firstsecond")

    def testWrite(self):
        self._checkbuffered(writing_application)

    def testWriteOnly(self):
        self._checkbuffered(write_only_application)


class StupidIO:
    """file-like without tell method, so StaticFile is not able to determine
    the content-length."""

    def __init__(self, content):
        self.content = content
        self.position = 0

    def seek(self, pos):
        # Only rewinding to the start is supported.
        assert pos == 0
        self.position = 0

    def read(self, length):
        oldpos = self.position
        self.position += length
        return self.content[oldpos:self.position]


class ContentLengthMiddlewareTest(unittest.TestCase):
    """Tests for middlewares.ContentLengthMiddleware."""

    def setUp(self):
        self.staticapp = applications.StaticFile(
            StupidIO("success"), "200 Found",
            [("Content-Type", "text/plain")])
        self.app = middlewares.ContentLengthMiddleware(self.staticapp,
                                                       maxstore=10)
        self.req = Request(self)

    def testWithout(self):
        # Sanity check: the bare app must not send Content-length itself,
        # otherwise testGet would prove nothing.
        res = self.req(self.staticapp)
        res.status("200 Found")
        try:
            res.getheader("Content-length")
        except KeyError:
            pass
        else:
            self.fail("Content-length header found, test is useless")

    def testGet(self):
        res = self.req(self.app)
        res.status("200 Found")
        res.header("Content-length", "7")


class BasicAuthMiddlewareTest(unittest.TestCase):
    """Tests for middlewares.BasicAuthMiddleware (user "bar", pw "baz")."""

    def setUp(self):
        self.staticapp = applications.StaticContent(
            "200 Found", [("Content-Type", "text/plain")], "success")
        checkpw = middlewares.DictAuthChecker({"bar": "baz"})
        self.app = middlewares.BasicAuthMiddleware(
            wsgiref.validate.validator(self.staticapp), checkpw)
        self.req = Request(self)

    def test401(self):
        res = self.req(self.app)
        res.status(401)
        res.header("WWW-Authenticate", lambda _: True)

    def test401garbage(self):
        req = self.req.copy()
        req.setheader('http-authorization', 'Garbage')
        res = req(self.app)
        res.status(401)
        res.header("WWW-Authenticate", lambda _: True)

    def test401basicgarbage(self):
        req = self.req.copy()
        req.setheader('http-authorization', 'Basic ()')
        res = req(self.app)
        res.status(401)
        res.header("WWW-Authenticate", lambda _: True)

    def doauth(self, password="baz", status=200):
        """
        Authenticate as user "bar" with password and assert the status.

        @type password: str
        @type status: int
        """
        # b64encode never inserts line breaks, whereas the legacy
        # str.encode("base64") codec wraps its output every 76 characters.
        import base64
        req = self.req.copy()
        token = base64.b64encode("bar:%s" % password)
        req.setheader('http-authorization', 'Basic %s' % token)
        res = req(self.app)
        res.status(status)

    def test200(self):
        self.doauth()

    def test401authfail(self):
        self.doauth(password="spam", status=401)


from wsgitools import filters
import gzip


class RequestLogWSGIFilterTest(unittest.TestCase):
    """Tests for filters.RequestLogWSGIFilter."""

    def testSimple(self):
        app = applications.StaticContent(
            "200 Found", [("Content-Type", "text/plain")], "nothing")
        log = io.StringIO()
        logfilter = filters.RequestLogWSGIFilter.creator(log)
        app = filters.WSGIFilterMiddleware(app, logfilter)
        req = Request(self)
        req.environ["REMOTE_ADDR"] = "1.2.3.4"
        req.environ["PATH_INFO"] = "/"
        req.environ["HTTP_USER_AGENT"] = "wsgitools-test"
        req(app)
        logged = log.getvalue()
        self.assertTrue(re.match(r'^1\.2\.3\.4 - - \[[^]]+\] "GET /" '
                                 r'200 7 - "wsgitools-test"', logged))


class GzipWSGIFilterTest(unittest.TestCase):
    """Tests for filters.GzipWSGIFilter."""

    def testSimple(self):
        app = applications.StaticContent(
            "200 Found", [("Content-Type", "text/plain")], "nothing")
        app = filters.WSGIFilterMiddleware(app, filters.GzipWSGIFilter)
        req = Request(self)
        req.environ["HTTP_ACCEPT_ENCODING"] = "gzip"
        res = req(app)
        gzfile = gzip.GzipFile(fileobj=io.StringIO(res.get_data()))
        try:
            data = gzfile.read()
        finally:
            gzfile.close()
        self.assertEqual(data, "nothing")


def alltests(case):
    """Return a suite holding every test method of the TestCase class case."""
    return unittest.TestLoader().loadTestsFromTestCase(case)


fullsuite = unittest.TestSuite()
fullsuite.addTest(doctest.DocTestSuite("wsgitools.digest"))
fullsuite.addTest(alltests(StaticContentTest))
fullsuite.addTest(alltests(StaticFileTest))
fullsuite.addTest(alltests(AuthDigestMiddlewareTest))
fullsuite.addTest(alltests(ContentLengthMiddlewareTest))
fullsuite.addTest(alltests(BasicAuthMiddlewareTest))
fullsuite.addTest(alltests(NoWriteCallableMiddlewareTest))
fullsuite.addTest(alltests(RequestLogWSGIFilterTest))
fullsuite.addTest(alltests(GzipWSGIFilterTest))

if __name__ == "__main__":
    runner = unittest.TextTestRunner(verbosity=2)
    if "profile" in sys.argv:
        # Profile the whole run and dump stats instead of setting an exit code.
        try:
            import cProfile as profile
        except ImportError:
            import profile
        prof = profile.Profile()
        prof.runcall(runner.run, fullsuite)
        prof.dump_stats("wsgitools.pstat")
    else:
        result = runner.run(fullsuite)
        # Exit status is the number of failing or erroring tests. Errors
        # (unexpected exceptions) count too, and the value is clamped to 255
        # so it cannot wrap around to a "success" status of 0.
        problems = len(result.failures) + len(result.errors)
        sys.exit(min(problems, 255))