summaryrefslogtreecommitdiff
path: root/wsgitools/middlewares.py
diff options
context:
space:
mode:
authorHelmut Grohne <helmut@subdivi.de>2007-04-14 22:37:26 +0200
committerHelmut Grohne <helmut@subdivi.de>2007-04-14 22:37:26 +0200
commit2435f82361f6bc4dcd51e1305905ecbbb5757f50 (patch)
treedf738a4ffdbd212383b6d8df22cb3b74e1fc83f9 /wsgitools/middlewares.py
downloadwsgitools-2435f82361f6bc4dcd51e1305905ecbbb5757f50.tar.gz
initial tree
Diffstat (limited to 'wsgitools/middlewares.py')
-rw-r--r--wsgitools/middlewares.py267
1 files changed, 267 insertions, 0 deletions
diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py
new file mode 100644
index 0000000..9eab2a1
--- /dev/null
+++ b/wsgitools/middlewares.py
@@ -0,0 +1,267 @@
+__all__ = []
+
+import time
+from filters import CloseableList, CloseableIterator
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+
+__all__.append("SubdirMiddleware")
+class SubdirMiddleware:
+ def __init__(self, default, mapping={}):
+ self.default = default
+ self.mapping = mapping
+ def __call__(self, environ, start_response):
+ app = None
+ script = environ["PATH_INFO"]
+ path_info = ""
+ while '/' in script:
+ if script in self.mapping:
+ app = self.mapping[script]
+ break
+ script, tail = script.rsplit('/', 1)
+ path_info = "/%s%s" % (tail, path_info)
+ if app is None:
+ app = self.mapping.get(script, None)
+ if app is None:
+ app = self.default
+ environ["SCRIPT_NAME"] += script
+ environ["PATH_INFO"] = path_info
+ return app(environ, start_response)
+
+__all__.append("NoWriteCallableMiddleware")
+class NoWriteCallableMiddleware:
+ """This middleware wraps a wsgi application that needs the return value of
+ start_response function to a wsgi application that doesn't need one by
+ writing the data to a StringIO and then making it be the first result
+ element."""
+ def __init__(self, app):
+ """Wraps wsgi application app."""
+ self.app = app
+ def __call__(self, environ, start_response):
+ """wsgi interface"""
+ todo = []
+ def modified_start_response(status, headers, exc_info=None):
+ if exc_info is not None:
+ todo.append(None)
+ return start_response(status, headers)
+ else:
+ sio = StringIO.StringIO()
+ todo.append((status, headers, sio))
+ return sio.write
+
+ ret = self.app(environ, modified_start_response)
+
+ if todo and todo[0] is None:
+ return ret
+
+ if isinstance(ret, list):
+ status, headers, data = todo[0]
+ data = data.getvalue()
+ if data:
+ ret.insert(0, data)
+ start_response(status, headers)
+ return ret
+
+ ret = iter(ret)
+ stopped = False
+ try:
+ first = ret.next()
+ except StopIteration:
+ stopped = True
+
+ status, headers, data = todo[0]
+ data = data.getvalue()
+ start_response(status, headers)
+
+ if stopped:
+ return CloseableList(getattr(ret, "close", None), (data,))
+
+ return CloseableIterator(getattr(ret, "close", None),
+ (data, first), ret)
+
+__all__.append("ContentLengthMiddleware")
+class ContentLengthMiddleware:
+ """Guesses the content length header if possible.
+ Note: The application used must not use the write callable returned by
+ start_response."""
+ def __init__(self, app, maxstore=0):
+ """Wraps wsgi application app. It can also store the first result bytes
+ to possibly return a list of strings which will make guessing the size
+ of iterators possible. At most maxstore bytes will be accumulated.
+ Please note that a value larger than 0 will violate the wsgi standard.
+ The magical value () will make it always gather all data.
+ """
+ self.app = app
+ self.maxstore = maxstore
+ def __call__(self, environ, start_response):
+ """wsgi interface"""
+ todo = []
+ def modified_start_response(status, headers, exc_info=None):
+ if (exc_info is not None or
+ [v for h, v in headers if h.lower() == "content-length"]):
+ todo[:] = (None,)
+ return start_response(status, headers, exc_info)
+ else:
+ todo[:] = ((status, headers),)
+ def raise_not_imp(*args):
+ raise NotImplementedError
+ return raise_not_imp
+
+ ret = self.app(environ, modified_start_response)
+
+ if todo and todo[0] is None: # nothing to do
+ #print "content-length: nothing"
+ return ret
+
+ if isinstance(ret, list):
+ #print "content-length: simple"
+ status, headers = todo[0]
+ length = sum(map(len, ret))
+ headers.append(("Content-length", str(length)))
+ start_response(status, headers)
+ return ret
+
+ ret = iter(ret)
+ stopped = False
+ data = CloseableList(getattr(ret, "close", None))
+ length = 0
+ try:
+ data.append(ret.next()) # fills todo
+ length += len(data[-1])
+ except StopIteration:
+ stopped = True
+
+ status, headers = todo[0]
+
+ while (not stopped) and length < self.maxstore:
+ try:
+ data.append(ret.next())
+ length += len(data[-1])
+ except StopIteration:
+ stopped = True
+
+ if stopped:
+ #print "content-length: gathered"
+ headers.append(("Content-length", str(length)))
+ start_response(status, headers)
+ return data
+
+ #print "content-length: passthrough"
+ start_response(status, headers)
+
+ return CloseableIterator(getattr(ret, "close", None), data, ret)
+
+def storable(environ):
+ if environ["REQUEST_METHOD"] != "GET":
+ return False
+ return True
+
+def cacheable(environ):
+ if environ.get("HTTP_CACHE_CONTROL", "") == "max-age=0":
+ return False
+ return True
+
+__all__.append("CachingMiddleware")
+class CachingMiddleware:
+ """Caches reponses to requests based on SCRIPT_NAME, PATH_INFO and
+ QUERY_STRING."""
+ def __init__(self, app, maxage=60, storable=storable, cacheable=cacheable):
+ """app is a wsgi application to be cached.
+ maxage is the number of seconds a reponse may be cached.
+ storable is a predicated that determines whether the response may be
+ cached at all based on the environ dict.
+ cacheable is a predicate that determines whether this request
+ invalidates the cache."""
+ self.app = app
+ self.maxage = maxage
+ self.storable = storable
+ self.cacheable = cacheable
+ self.cache = {}
+ def __call__(self, environ, start_response):
+ """wsgi interface"""
+ if not self.storable(environ):
+ return self.app(environ, start_response)
+ path = environ.get("SCRIPT_NAME", "/")
+ path += environ.get("PATH_INFO", '')
+ path += "?" + environ.get("QUERY_STRING", "")
+ if self.cacheable(environ) and path in self.cache:
+ if self.cache[path][0] + self.maxage >= time.time():
+ start_response(self.cache[path][1], self.cache[path][2])
+ return self.cache[path][3]
+ else:
+ del self.cache[path]
+ cache_object = [time.time(), "", [], []]
+ def modified_start_respesponse(status, headers, exc_info):
+ if exc_info is not None:
+ return self.app(status, headers, exc_info)
+ cache_object[1] = status
+ cache_object[2] = headers
+ write = start_response(status, headers)
+ def modified_write(data):
+ cache_object[3].append(data)
+ write(data)
+ return modified_write
+ ret = self.app(environ, modified_start_respesponse)
+ if isinstance(ret, list):
+ cache_object[3].extend(ret)
+ self.cache[path] = cache_object
+ return ret
+ def pass_through():
+ for data in ret:
+ cache_object[3].append(data)
+ yield data
+ self.cache[path] = cache_object
+ return CloseableIterator(getattr(ret, "close", None), pass_through())
+
+__all__.append("DictAuthChecker")
+class DictAuthChecker:
+ def __init__(self, users):
+ self.users = users
+ def __call__(self, username, password):
+ return username in self.users and self.users[username] == password
+
+__all__.append("BasicAuthMiddleware")
+class BasicAuthMiddleware:
+ """Middleware implementing HTTP Basic Auth."""
+ def __init__(self, app, check_function, realm='www'):
+ """app is a WSGI application.
+ check_function is a function taking two arguments username and password
+ returning a bool indicating whether the request may is
+ allowed."""
+ self.app = app
+ self.check_function = check_function
+ self.realm = realm
+
+ def __call__(self, environ, start_response):
+ """wsgi interface"""
+ auth = environ.get("HTTP_AUTHORIZATION")
+ if not auth or ' ' not in auth:
+ return self.authorization_required(environ, start_response)
+ auth_type, enc_auth_info = auth.split(None, 1)
+ try:
+ auth_info = enc_auth_info.decode("base64")
+ except: # It throws some non-standard exception.
+ return self.authorization_required(environ, start_response)
+ if auth_type.lower() != "basic" or ':' not in auth_info:
+ return self.authorization_required(environ, start_response)
+ username, password = auth_info.split(':', 1)
+ if self.check_function(username, password):
+ environ["REMOTE_USER"] = username
+ return self.app(environ, start_response)
+ return self.authorization_required(environ, start_response)
+
+ def authorization_required(self, environ, start_response):
+ """wsgi application for indicating authorization is required."""
+ status = "401 Authorization required"
+ headers = [('Content-type', 'text/html'),
+ ('WWW-Authenticate', 'Basic realm="%s"' % self.realm)]
+ if environ["REQUEST_METHOD"] == "HEAD":
+ start_response(status, headers)
+ return []
+ html = "<html><head><title>Authorization required</title></head>" + \
+ "<body><h1>Authorization required</h1></body></html>\n"
+ headers.append(('Content-length', len(html)))
+ start_response(status, headers)
+ return [html]