diff options
-rwxr-xr-x | test.py | 178 |
1 files changed, 178 insertions, 0 deletions
@@ -0,0 +1,178 @@ +#!/usr/bin/env python2.5 + +import unittest +import doctest + +try: + from hashlib import md5 +except ImportError: + from md5 import md5 + +try: + next +except NameError: + def next(iterator): + return iterator.next() + +class Request: + def __init__(self, case): + """ + @type case: unittest.TestCase + """ + self.testcase = case + self.environ = dict(REQUEST_METHOD="GET") + + def setenv(self, key, value): + """ + @type key: str + @type value: str + """ + self.environ[key] = value + + def setmethod(self, request_method): + """ + @type request_method: str + """ + self.setenv("REQUEST_METHOD", request_method) + + def setheader(self, name, value): + """ + @type name: str + @type value: str + """ + self.setenv(name.upper().replace('-', '_'), value) + + def copy(self): + req = Request(self.testcase) + req.environ = dict(self.environ) + return req + + def __call__(self, app): + res = Result(self.testcase) + def write(data): + res.writtendata.append(data) + def start_response(status, headers, exc_info=None): + res.statusdata = status + res.headersdata = headers + return write + res.returneddata = app(self.environ, start_response) + return res + +class Result: + def __init__(self, case): + """ + @type case: unittest.TestCase + """ + self.testcase = case + self.statusdata = None + self.headersdata = None + self.writtendata = [] + self.returneddata = None + + def status(self, check): + """ + @type check: int or str + """ + if isinstance(check, int): + status = int(self.statusdata.split()[0]) + self.testcase.assertEqual(check, status) + else: + self.testcase.assertEqual(check, self.statusdata) + + def getheader(self, name): + """ + @type name: str + @raises KeyError: + """ + for key, value in self.headersdata: + if key == name: + return value + raise KeyError + + def header(self, name, check): + """ + @type name: str + @type check: str or (str -> bool) + """ + for key, value in self.headersdata: + if key == name: + if isinstance(check, str): + self.testcase.assertEqual(check, value) + else: + self.testcase.assert_(check(value)) + +from wsgitools import applications + +class StaticContentTest(unittest.TestCase): + def setUp(self): + self.app = applications.StaticContent("200 Found", [("Spam", "Egg")], + "nothing") + self.req = Request(self) + + def testGet(self): + res = self.req(self.app) + res.status("200 Found") + res.header("Content-length", "7") + + def testHead(self): + req = self.req.copy() + req.setmethod("HEAD") + res = req(self.app) + res.status(200) + res.header("Content-length", "7") + +from wsgitools import digest + +class AuthDigestMiddlewareTest(unittest.TestCase): + def setUp(self): + self.staticapp = applications.StaticContent("200 Found", [], "success") + token_gen = digest.AuthTokenGenerator("foo", lambda _: "baz") + self.app = digest.AuthDigestMiddleware(self.staticapp, token_gen) + self.req = Request(self) + + def test401(self): + res = self.req(self.app) + res.status(401) + res.header("WWW-Authenticate", lambda _: True) + + def doauth(self, password="baz", status=200): + res = self.req(self.app) + nonce = next(iter(filter(lambda x: x.startswith("nonce="), + res.getheader("WWW-Authenticate").split()))) + nonce = nonce.split('"')[1] + req = self.req.copy() + token = md5("bar:foo:%s" % password).hexdigest() + other = md5("GET:").hexdigest() + resp = md5("%s:%s:%s" % (token, nonce, other)).hexdigest() + req.setheader('http-authorization', 'Digest algorithm=md5,nonce="%s",' \ + 'uri=,username=bar,response="%s"' % (nonce, resp)) + res = req(self.app) + res.status(status) + + def test200(self): + self.doauth() + + def test401authfail(self): + self.doauth(password="spam", status=401) + +def alltests(case): + return unittest.TestLoader().loadTestsFromTestCase(case) + +fullsuite = unittest.TestSuite() +fullsuite.addTest(doctest.DocTestSuite("wsgitools.digest")) +fullsuite.addTest(alltests(StaticContentTest)) +fullsuite.addTest(alltests(AuthDigestMiddlewareTest)) + +if __name__ == "__main__": + runner = unittest.TextTestRunner(verbosity=2) + import sys + if "profile" in sys.argv: + try: + import cProfile as profile + except ImportError: + import profile + prof = profile.Profile() + prof.runcall(runner.run, fullsuite) + prof.dump_stats("wsgitools.pstat") + else: + runner.run(fullsuite) |