diff options
author | Helmut Grohne <helmut@subdivi.de> | 2007-04-14 22:37:26 +0200 |
---|---|---|
committer | Helmut Grohne <helmut@subdivi.de> | 2007-04-14 22:37:26 +0200 |
commit | 2435f82361f6bc4dcd51e1305905ecbbb5757f50 (patch) | |
tree | df738a4ffdbd212383b6d8df22cb3b74e1fc83f9 /wsgitools/middlewares.py | |
download | wsgitools-2435f82361f6bc4dcd51e1305905ecbbb5757f50.tar.gz |
initial tree
Diffstat (limited to 'wsgitools/middlewares.py')
-rw-r--r-- | wsgitools/middlewares.py | 267 |
1 files changed, 267 insertions, 0 deletions
diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py new file mode 100644 index 0000000..9eab2a1 --- /dev/null +++ b/wsgitools/middlewares.py @@ -0,0 +1,267 @@ +__all__ = [] + +import time +from filters import CloseableList, CloseableIterator +try: + import cStringIO as StringIO +except ImportError: + import StringIO + +__all__.append("SubdirMiddleware") +class SubdirMiddleware: + def __init__(self, default, mapping={}): + self.default = default + self.mapping = mapping + def __call__(self, environ, start_response): + app = None + script = environ["PATH_INFO"] + path_info = "" + while '/' in script: + if script in self.mapping: + app = self.mapping[script] + break + script, tail = script.rsplit('/', 1) + path_info = "/%s%s" % (tail, path_info) + if app is None: + app = self.mapping.get(script, None) + if app is None: + app = self.default + environ["SCRIPT_NAME"] += script + environ["PATH_INFO"] = path_info + return app(environ, start_response) + +__all__.append("NoWriteCallableMiddleware") +class NoWriteCallableMiddleware: + """This middleware wraps a wsgi application that needs the return value of + start_response function to a wsgi application that doesn't need one by + writing the data to a StringIO and then making it be the first result + element.""" + def __init__(self, app): + """Wraps wsgi application app.""" + self.app = app + def __call__(self, environ, start_response): + """wsgi interface""" + todo = [] + def modified_start_response(status, headers, exc_info=None): + if exc_info is not None: + todo.append(None) + return start_response(status, headers) + else: + sio = StringIO.StringIO() + todo.append((status, headers, sio)) + return sio.write + + ret = self.app(environ, modified_start_response) + + if todo and todo[0] is None: + return ret + + if isinstance(ret, list): + status, headers, data = todo[0] + data = data.getvalue() + if data: + ret.insert(0, data) + start_response(status, headers) + return ret + + ret = iter(ret) + stopped = False + try: + first = ret.next() + except StopIteration: + stopped = True + + status, headers, data = todo[0] + data = data.getvalue() + start_response(status, headers) + + if stopped: + return CloseableList(getattr(ret, "close", None), (data,)) + + return CloseableIterator(getattr(ret, "close", None), + (data, first), ret) + +__all__.append("ContentLengthMiddleware") +class ContentLengthMiddleware: + """Guesses the content length header if possible. + Note: The application used must not use the write callable returned by + start_response.""" + def __init__(self, app, maxstore=0): + """Wraps wsgi application app. It can also store the first result bytes + to possibly return a list of strings which will make guessing the size + of iterators possible. At most maxstore bytes will be accumulated. + Please note that a value larger than 0 will violate the wsgi standard. + The magical value () will make it always gather all data. + """ + self.app = app + self.maxstore = maxstore + def __call__(self, environ, start_response): + """wsgi interface""" + todo = [] + def modified_start_response(status, headers, exc_info=None): + if (exc_info is not None or + [v for h, v in headers if h.lower() == "content-length"]): + todo[:] = (None,) + return start_response(status, headers, exc_info) + else: + todo[:] = ((status, headers),) + def raise_not_imp(*args): + raise NotImplementedError + return raise_not_imp + + ret = self.app(environ, modified_start_response) + + if todo and todo[0] is None: # nothing to do + #print "content-length: nothing" + return ret + + if isinstance(ret, list): + #print "content-length: simple" + status, headers = todo[0] + length = sum(map(len, ret)) + headers.append(("Content-length", str(length))) + start_response(status, headers) + return ret + + ret = iter(ret) + stopped = False + data = CloseableList(getattr(ret, "close", None)) + length = 0 + try: + data.append(ret.next()) # fills todo + length += len(data[-1]) + except StopIteration: + stopped = True + + status, headers = todo[0] + + while (not stopped) and length < self.maxstore: + try: + data.append(ret.next()) + length += len(data[-1]) + except StopIteration: + stopped = True + + if stopped: + #print "content-length: gathered" + headers.append(("Content-length", str(length))) + start_response(status, headers) + return data + + #print "content-length: passthrough" + start_response(status, headers) + + return CloseableIterator(getattr(ret, "close", None), data, ret) + +def storable(environ): + if environ["REQUEST_METHOD"] != "GET": + return False + return True + +def cacheable(environ): + if environ.get("HTTP_CACHE_CONTROL", "") == "max-age=0": + return False + return True + +__all__.append("CachingMiddleware") +class CachingMiddleware: + """Caches reponses to requests based on SCRIPT_NAME, PATH_INFO and + QUERY_STRING.""" + def __init__(self, app, maxage=60, storable=storable, cacheable=cacheable): + """app is a wsgi application to be cached. + maxage is the number of seconds a reponse may be cached. + storable is a predicated that determines whether the response may be + cached at all based on the environ dict. + cacheable is a predicate that determines whether this request + invalidates the cache.""" + self.app = app + self.maxage = maxage + self.storable = storable + self.cacheable = cacheable + self.cache = {} + def __call__(self, environ, start_response): + """wsgi interface""" + if not self.storable(environ): + return self.app(environ, start_response) + path = environ.get("SCRIPT_NAME", "/") + path += environ.get("PATH_INFO", '') + path += "?" + environ.get("QUERY_STRING", "") + if self.cacheable(environ) and path in self.cache: + if self.cache[path][0] + self.maxage >= time.time(): + start_response(self.cache[path][1], self.cache[path][2]) + return self.cache[path][3] + else: + del self.cache[path] + cache_object = [time.time(), "", [], []] + def modified_start_respesponse(status, headers, exc_info): + if exc_info is not None: + return self.app(status, headers, exc_info) + cache_object[1] = status + cache_object[2] = headers + write = start_response(status, headers) + def modified_write(data): + cache_object[3].append(data) + write(data) + return modified_write + ret = self.app(environ, modified_start_respesponse) + if isinstance(ret, list): + cache_object[3].extend(ret) + self.cache[path] = cache_object + return ret + def pass_through(): + for data in ret: + cache_object[3].append(data) + yield data + self.cache[path] = cache_object + return CloseableIterator(getattr(ret, "close", None), pass_through()) + +__all__.append("DictAuthChecker") +class DictAuthChecker: + def __init__(self, users): + self.users = users + def __call__(self, username, password): + return username in self.users and self.users[username] == password + +__all__.append("BasicAuthMiddleware") +class BasicAuthMiddleware: + """Middleware implementing HTTP Basic Auth.""" + def __init__(self, app, check_function, realm='www'): + """app is a WSGI application. + check_function is a function taking two arguments username and password + returning a bool indicating whether the request may is + allowed.""" + self.app = app + self.check_function = check_function + self.realm = realm + + def __call__(self, environ, start_response): + """wsgi interface""" + auth = environ.get("HTTP_AUTHORIZATION") + if not auth or ' ' not in auth: + return self.authorization_required(environ, start_response) + auth_type, enc_auth_info = auth.split(None, 1) + try: + auth_info = enc_auth_info.decode("base64") + except: # It throws some non-standard exception. + return self.authorization_required(environ, start_response) + if auth_type.lower() != "basic" or ':' not in auth_info: + return self.authorization_required(environ, start_response) + username, password = auth_info.split(':', 1) + if self.check_function(username, password): + environ["REMOTE_USER"] = username + return self.app(environ, start_response) + return self.authorization_required(environ, start_response) + + def authorization_required(self, environ, start_response): + """wsgi application for indicating authorization is required.""" + status = "401 Authorization required" + headers = [('Content-type', 'text/html'), + ('WWW-Authenticate', 'Basic realm="%s"' % self.realm)] + if environ["REQUEST_METHOD"] == "HEAD": + start_response(status, headers) + return [] + html = "<html><head><title>Authorization required</title></head>" + \ + "<body><h1>Authorization required</h1></body></html>\n" + headers.append(('Content-length', len(html))) + start_response(status, headers) + return [html] |