diff options
Diffstat (limited to 'wsgitools/middlewares.py')
-rw-r--r-- | wsgitools/middlewares.py | 59 |
1 files changed, 26 insertions, 33 deletions
diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py index c3f2871..13ba41c 100644 --- a/wsgitools/middlewares.py +++ b/wsgitools/middlewares.py @@ -5,7 +5,6 @@ import sys import cgitb import binascii import collections -from wsgitools.filters import CloseableList, CloseableIterator # Cannot use io module as it is broken in 2.6. # Writing a str to a io.StringIO results in an exception. try: @@ -15,8 +14,12 @@ except ImportError: try: next except NameError: - def next(it): - return it.next() + def next(iterator): + return iterator.next() + +from wsgitools.filters import CloseableList, CloseableIterator +from wsgitools.authentication import AuthenticationRequired, \ + ProtocolViolation, AuthenticationMiddleware __all__.append("SubdirMiddleware") class SubdirMiddleware: @@ -306,16 +309,17 @@ class DictAuthChecker: in a bool. @type username: str @type password: str - @type environ: {str: str} + @type environ: {str: object} @rtype: bool """ return username in self.users and self.users[username] == password __all__.append("BasicAuthMiddleware") -class BasicAuthMiddleware: +class BasicAuthMiddleware(AuthenticationMiddleware): """Middleware implementing HTTP Basic Auth. Upon forwarding the request to the warpped application the environ dictionary is augmented by a REMOTE_USER key.""" + authorization_method = "basic" def __init__(self, app, check_function, realm='www', app401=None): """ @param app: is a WSGI application. @@ -328,27 +332,24 @@ class BasicAuthMiddleware: @param app401: is an optional WSGI application to be used for error messages """ - self.app = app + AuthenticationMiddleware.__init__(self, app) self.check_function = check_function self.realm = realm self.app401 = app401 - def __call__(self, environ, start_response): + def authenticate(self, auth, environ, start_response): """wsgi interface @type environ: {str: str} """ assert isinstance(environ, dict) - auth = environ.get("HTTP_AUTHORIZATION") - if not auth or ' ' not in auth: - return self.authorization_required(environ, start_response) - auth_type, enc_auth_info = auth.split(None, 1) try: - auth_info = enc_auth_info.decode("base64") + auth_info = auth.decode("base64") except binascii.Error: - return self.authorization_required(environ, start_response) - if auth_type.lower() != "basic" or ':' not in auth_info: - return self.authorization_required(environ, start_response) - username, password = auth_info.split(':', 1) + raise ProtocolViolation("failed to base64 decode auth_info") + try: + username, password = auth_info.split(':', 1) + except ValueError: + raise ProtocolViolation("no colon found in auth_info") try: result = self.check_function(username, password, environ) except TypeError: # catch old interface @@ -356,24 +357,16 @@ class BasicAuthMiddleware: if result: environ["REMOTE_USER"] = username return self.app(environ, start_response) - return self.authorization_required(environ, start_response) + raise AuthenticationRequired("credentials not valid") - def authorization_required(self, environ, start_response): - """wsgi application for indicating authorization is required. - @type environ: {str: str} - """ - if self.app401 is None: - status = "401 Authorization required" - html = "<html><head><title>Authorization required</title></head>" \ - "<body><h1>Authorization required</h1></body></html>\n" - headers = [('Content-type', 'text/html'), - ('WWW-Authenticate', 'Basic realm="%s"' % self.realm), - ("Content-length", str(len(html)))] - start_response(status, headers) - if environ["REQUEST_METHOD"].upper() == "HEAD": - return [] - return [html] - return self.app401(environ, start_response) + def www_authenticate(self, exception): + return ("WWW-Authenticate", 'Basic realm="%s"' % self.realm) + + def authorization_required(self, environ, start_response, exception): + if self.app401 is not None: + return self.app401(environ, start_response) + return AuthenticationMiddleware.authorization_required( + self, environ, start_response, exception) __all__.append("TracebackMiddleware") class TracebackMiddleware: |