From 91eeb7df51850ebfb28db28642ed6af1c9ffe93a Mon Sep 17 00:00:00 2001
From: Helmut Grohne <helmut@subdivi.de>
Date: Sun, 27 Nov 2011 13:24:14 +0100
Subject: added new base class AuthenticationMiddleware

The BasicAuthMiddleware and AuthDigestMiddleware now derive from
AuthenticationMiddleware which provides common functionality.
---
 wsgitools/authentication.py |  80 ++++++++++++++++++++++++
 wsgitools/digest.py         | 145 ++++++++++++++++----------------------------
 wsgitools/middlewares.py    |  59 ++++++++----------
 3 files changed, 158 insertions(+), 126 deletions(-)
 create mode 100644 wsgitools/authentication.py

(limited to 'wsgitools')

diff --git a/wsgitools/authentication.py b/wsgitools/authentication.py
new file mode 100644
index 0000000..a36b794
--- /dev/null
+++ b/wsgitools/authentication.py
@@ -0,0 +1,80 @@
+__all__ = []
+
+class AuthenticationRequired(Exception):
+    """
+    Internal Exception class that is thrown inside L{AuthenticationMiddleware},
+    but not visible to other code.
+    """
+
+class ProtocolViolation(AuthenticationRequired):
+    pass
+
+class AuthenticationMiddleware:
+    """Base class for HTTP authorization schemes.
+
+    @cvar authorization_required: the implemented Authorization method. It will
+        be verified against Authorization headers. Subclasses must define this
+        attribute.
+    @type authorization_required: str
+    """
+    authorization_method = None
+    def __init__(self, app):
+        """
+        @param app: is a WSGI application.
+        """
+        assert self.authorization_method is not None
+        self.app = app
+
+    def authenticate(self, auth, environ, start_response):
+        """
+        @type auth: str
+        @param auth: is the part of the Authorization header after the method
+        """
+        raise NotImplementedError
+
+    def __call__(self, environ, start_response):
+        assert isinstance(environ, dict)
+        try:
+            try:
+                auth = environ["HTTP_AUTHORIZATION"]
+            except KeyError:
+                raise AuthenticationRequired("no Authorization header found")
+            try:
+                method, rest = auth.split(' ', 1)
+            except ValueError:
+                method, rest = auth, ""
+            if method.lower() != self.authorization_method:
+                raise AuthenticationRequired(
+                    "authorization method not implemented: %r" % method)
+            return self.authenticate(rest, environ, start_response)
+        except AuthenticationRequired, exc:
+            return self.authorization_required(environ, start_response, exc)
+
+    def www_authenticate(self, exception):
+        """Generates a WWW-Authenticate header. Subclasses must implement this
+        method.
+
+        @type exception: AuthenticationRequired
+        @param exception: reason for generating the header
+        @rtype: (str, str)
+        @returns: the header as (part_before_colon, part_after_colon)
+        """
+        raise NotImplementedError
+
+    def authorization_required(self, environ, start_response, exception):
+        """Generate an error page after failed authentication. Apart from the
+        exception parameter, this method behaves like a WSGI application.
+
+        @type exception: AuthenticationRequired
+        @param exception: reason for the authentication failure
+        """
+        status = "401 Authorization required"
+        html = "<html><head><title>401 Authorization required</title></head>" \
+               "<body><h1>401 Authorization required</h1></body></html>"
+        headers = [("Content-Type", "text/html"),
+                   self.www_authenticate(exception),
+                   ("Content-Length", str(len(html)))]
+        start_response(status, headers)
+        if environ["REQUEST_METHOD"].upper() == "HEAD":
+            return []
+        return [html]
diff --git a/wsgitools/digest.py b/wsgitools/digest.py
index 53b7dea..c98517f 100644
--- a/wsgitools/digest.py
+++ b/wsgitools/digest.py
@@ -24,6 +24,9 @@ import base64
 import time
 import os
 
+from wsgitools.authentication import AuthenticationRequired, \
+        ProtocolViolation, AuthenticationMiddleware
+
 sysrand = random.SystemRandom()
 
 def gen_rand_str(bytes=33):
@@ -83,15 +86,6 @@ def parse_digest_response(data, ret=None):
     ret[key] = value
     return parse_digest_response(rest, ret)
 
-class AuthenticationRequired(Exception):
-    """
-    Internal Exception class that is thrown inside L{AuthDigestMiddleware}, but
-    not visible to other code.
-    """
-
-class ProtocolViolation(AuthenticationRequired):
-    pass
-
 class StaleNonce(AuthenticationRequired):
     pass
 
@@ -632,15 +626,16 @@ def check_uri(credentials, environ):
         raise AuthenticationRequired("url mismatch")
 
 __all__.append("AuthDigestMiddleware")
-class AuthDigestMiddleware:
+class AuthDigestMiddleware(AuthenticationMiddleware):
     """Middleware partly implementing RFC2617. (md5-sess was omited)
     Upon successful authentication the environ dict will be extended
     by a REMOTE_USER key before being passed to the wrapped
     application."""
+    authorization_method = "digest"
     algorithms = {"md5": lambda data: md5(data).hexdigest()}
     def __init__(self, app, gentoken, maxage=300, maxuses=5, store=None):
         """
-        @param app: is the wsgi application to be served with authentification.
+        @param app: is the wsgi application to be served with authentication.
         @type gentoken: str -> (str or None)
         @param gentoken: has to have the same functionality and interface as the
                 L{AuthTokenGenerator} class.
@@ -653,7 +648,7 @@ class AuthDigestMiddleware:
         @param store: a nonce storage implementation object. Usage of this
                 parameter will override maxage and maxuses.
         """
-        self.app = app
+        AuthenticationMiddleware.__init__(self, app)
         self.gentoken = gentoken
         if store is None:
             self.noncestore = MemoryNonceStore(maxage, maxuses)
@@ -662,78 +657,62 @@ class AuthDigestMiddleware:
             assert hasattr(store, "checknonce")
             self.noncestore = store
 
-    def __call__(self, environ, start_response):
+    def authenticate(self, auth, environ, start_response):
         """wsgi interface"""
 
         try:
+            credentials = parse_digest_response(auth)
+        except ValueError:
+            raise ProtocolViolation("failed to parse digest response")
+
+        ### Check algorithm field
+        credentials["algorithm"] = credentials.get("algorithm",
+                                                   "md5").lower()
+        if not credentials["algorithm"] in self.algorithms:
+            raise ProtocolViolation("algorithm not implemented: %r" %
+                                    credentials["algorithm"])
+
+        check_uri(credentials, environ)
+
+        try:
+            nonce = credentials["nonce"]
+            credresponse = credentials["response"]
+        except KeyError, err:
+            raise ProtocolViolation("%s missing in credentials" %
+                                    err.args[0])
+        noncecount = 1
+        if "qop" in credentials:
+            if credentials["qop"] != "auth":
+                raise ProtocolViolation("unimplemented qop: %r" %
+                                        credentials["qop"])
             try:
-                auth = environ["HTTP_AUTHORIZATION"]
+                noncecount = int(credentials["nc"], 16)
             except KeyError:
-                raise AuthenticationRequired("no Authorization header found")
-            try:
-                method, rest = auth.split(' ', 1)
+                raise ProtocolViolation("nc missing in qop=auth")
             except ValueError:
-                method, rest = auth, ""
+                raise ProtocolViolation("non hexdigit found in nonce count")
 
-            if method.lower() != "digest":
-                raise AuthenticationRequired(
-                    "authorization method not implemented: %r" % method)
-            try:
-                credentials = parse_digest_response(rest)
-            except ValueError:
-                raise ProtocolViolation("failed to parse digest response")
+        # raises AuthenticationRequired
+        response = self.auth_response(credentials,
+                                      environ["REQUEST_METHOD"])
 
-            ### Check algorithm field
-            credentials["algorithm"] = credentials.get("algorithm",
-                                                       "md5").lower()
-            if not credentials["algorithm"] in self.algorithms:
-                raise ProtocolViolation("algorithm not implemented: %r" %
-                                        credentials["algorithm"])
+        if not self.noncestore.checknonce(nonce, noncecount):
+            raise StaleNonce()
 
-            check_uri(credentials, environ)
+        if response is None or response != credresponse:
+            raise AuthenticationRequired("wrong response")
 
-            try:
-                nonce = credentials["nonce"]
-                credresponse = credentials["response"]
-            except KeyError, err:
-                raise ProtocolViolation("%s missing in credentials" %
-                                        err.args[0])
-            noncecount = 1
+        environ["REMOTE_USER"] = credentials["username"]
+        def modified_start_response(status, headers, exc_info=None):
+            digest = dict(nextnonce=self.noncestore.newnonce())
             if "qop" in credentials:
-                if credentials["qop"] != "auth":
-                    raise ProtocolViolation("unimplemented qop: %r" %
-                                            credentials["qop"])
-                try:
-                    noncecount = int(credentials["nc"], 16)
-                except KeyError:
-                    raise ProtocolViolation("nc missing in qop=auth")
-                except ValueError:
-                    raise ProtocolViolation("non hexdigit found in nonce count")
-
-            # raises AuthenticationRequired
-            response = self.auth_response(credentials,
-                                          environ["REQUEST_METHOD"])
-
-            if not self.noncestore.checknonce(nonce, noncecount):
-                raise StaleNonce()
-
-            if response is None or response != credresponse:
-                raise AuthenticationRequired("wrong response")
-
-        except AuthenticationRequired, exc:
-            return self.authorization_required(environ, start_response, exc)
-        else:
-            environ["REMOTE_USER"] = credentials["username"]
-            def modified_start_response(status, headers, exc_info=None):
-                digest = dict(nextnonce=self.noncestore.newnonce())
-                if "qop" in credentials:
-                    digest["qop"] = "auth"
-                    digest["cnonce"] = credentials["cnonce"] # no KeyError
-                    digest["rspauth"] = self.auth_response(credentials, "")
-                challenge = ", ".join(map('%s="%s"'.__mod__, digest.items()))
-                headers.append(("Authentication-Info", challenge))
-                return start_response(status, headers, exc_info)
-            return self.app(environ, modified_start_response)
+                digest["qop"] = "auth"
+                digest["cnonce"] = credentials["cnonce"] # no KeyError
+                digest["rspauth"] = self.auth_response(credentials, "")
+            challenge = ", ".join(map('%s="%s"'.__mod__, digest.items()))
+            headers.append(("Authentication-Info", challenge))
+            return start_response(status, headers, exc_info)
+        return self.app(environ, modified_start_response)
 
     def auth_response(self, credentials, reqmethod):
         """internal method generating authentication tokens
@@ -771,13 +750,6 @@ class AuthDigestMiddleware:
         return self.algorithms[algo](":".join(dig))
 
     def www_authenticate(self, exception):
-        """Generates a WWW-Authenticate header.
-
-        @type exception: AuthenticationRequired
-        @param exception: reason for generating the header
-        @rtype: (str, str)
-        @returns: the header as (part_before_colon, part_after_colon)
-        """
         digest = dict(nonce=self.noncestore.newnonce(),
                       realm=self.gentoken.realm,
                       algorithm="md5",
@@ -786,16 +758,3 @@ class AuthDigestMiddleware:
             digest["stale"] = "TRUE"
         challenge = ", ".join(map('%s="%s"'.__mod__, digest.items()))
         return ("WWW-Authenticate", "Digest %s" % challenge)
-
-    def authorization_required(self, environ, start_response, exception):
-        """internal method implementing wsgi interface, serving 401 page"""
-        status = "401 Not authorized"
-        headers = [("Content-type", "text/html"),
-                   self.www_authenticate(exception)]
-        data = "<html><head><title>401 Not authorized</title></head><body><h1>"
-        data += "401 Not authorized</h1></body></html>"
-        headers.append(("Content-length", str(len(data))))
-        start_response(status, headers)
-        if environ["REQUEST_METHOD"] == "HEAD":
-            return []
-        return [data]
diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py
index c3f2871..13ba41c 100644
--- a/wsgitools/middlewares.py
+++ b/wsgitools/middlewares.py
@@ -5,7 +5,6 @@ import sys
 import cgitb
 import binascii
 import collections
-from wsgitools.filters import CloseableList, CloseableIterator
 # Cannot use io module as it is broken in 2.6.
 # Writing a str to a io.StringIO results in an exception.
 try:
@@ -15,8 +14,12 @@ except ImportError:
 try:
     next
 except NameError:
-    def next(it):
-        return it.next()
+    def next(iterator):
+        return iterator.next()
+
+from wsgitools.filters import CloseableList, CloseableIterator
+from wsgitools.authentication import AuthenticationRequired, \
+        ProtocolViolation, AuthenticationMiddleware
 
 __all__.append("SubdirMiddleware")
 class SubdirMiddleware:
@@ -306,16 +309,17 @@ class DictAuthChecker:
         in a bool.
         @type username: str
         @type password: str
-        @type environ: {str: str}
+        @type environ: {str: object}
         @rtype: bool
         """
         return username in self.users and self.users[username] == password
 
 __all__.append("BasicAuthMiddleware")
-class BasicAuthMiddleware:
+class BasicAuthMiddleware(AuthenticationMiddleware):
     """Middleware implementing HTTP Basic Auth. Upon forwarding the request to
     the warpped application the environ dictionary is augmented by a REMOTE_USER
     key."""
+    authorization_method = "basic"
     def __init__(self, app, check_function, realm='www', app401=None):
         """
         @param app: is a WSGI application.
@@ -328,27 +332,24 @@ class BasicAuthMiddleware:
         @param app401: is an optional WSGI application to be used for error
                 messages
         """
-        self.app = app
+        AuthenticationMiddleware.__init__(self, app)
         self.check_function = check_function
         self.realm = realm
         self.app401 = app401
 
-    def __call__(self, environ, start_response):
+    def authenticate(self, auth, environ, start_response):
         """wsgi interface
         @type environ: {str: str}
         """
         assert isinstance(environ, dict)
-        auth = environ.get("HTTP_AUTHORIZATION")
-        if not auth or ' ' not in auth:
-            return self.authorization_required(environ, start_response)
-        auth_type, enc_auth_info = auth.split(None, 1)
         try:
-            auth_info = enc_auth_info.decode("base64")
+            auth_info = auth.decode("base64")
         except binascii.Error:
-            return self.authorization_required(environ, start_response)
-        if auth_type.lower() != "basic" or ':' not in auth_info:
-            return self.authorization_required(environ, start_response)
-        username, password = auth_info.split(':', 1)
+            raise ProtocolViolation("failed to base64 decode auth_info")
+        try:
+            username, password = auth_info.split(':', 1)
+        except ValueError:
+            raise ProtocolViolation("no colon found in auth_info")
         try:
             result = self.check_function(username, password, environ)
         except TypeError: # catch old interface
@@ -356,24 +357,16 @@ class BasicAuthMiddleware:
         if result:
             environ["REMOTE_USER"] = username
             return self.app(environ, start_response)
-        return self.authorization_required(environ, start_response)
+        raise AuthenticationRequired("credentials not valid")
 
-    def authorization_required(self, environ, start_response):
-        """wsgi application for indicating authorization is required.
-        @type environ: {str: str}
-        """
-        if self.app401 is None:
-            status = "401 Authorization required"
-            html = "<html><head><title>Authorization required</title></head>" \
-                    "<body><h1>Authorization required</h1></body></html>\n"
-            headers = [('Content-type', 'text/html'),
-                       ('WWW-Authenticate', 'Basic realm="%s"' % self.realm),
-                       ("Content-length", str(len(html)))]
-            start_response(status, headers)
-            if environ["REQUEST_METHOD"].upper() == "HEAD":
-                return []
-            return [html]
-        return self.app401(environ, start_response)
+    def www_authenticate(self, exception):
+        return ("WWW-Authenticate", 'Basic realm="%s"' % self.realm)
+
+    def authorization_required(self, environ, start_response, exception):
+        if self.app401 is not None:
+            return self.app401(environ, start_response)
+        return AuthenticationMiddleware.authorization_required(
+                self, environ, start_response, exception)
 
 __all__.append("TracebackMiddleware")
 class TracebackMiddleware:
-- 
cgit v1.2.3