__all__ = []

import base64
import time
import sys
try:
    import cgitb
except ImportError:  # cgitb was removed from the stdlib in Python 3.13 (PEP 594)
    cgitb = None  # type: ignore[assignment]
import collections
import io
import typing

from wsgitools.internal import bytes2str, str2bytes
from wsgitools.filters import CloseableList, CloseableIterator
from wsgitools.authentication import AuthenticationRequired, \
    ProtocolViolation, AuthenticationMiddleware
from wsgitools.internal import (
    Environ, HeaderList, OptExcInfo, StartResponse, WriteCallback, WsgiApp
)


_CloseFunction = typing.Optional[typing.Callable[[], None]]


def _close_function(
    iterable: typing.Iterable[bytes], iterator: typing.Iterator[bytes]
) -> _CloseFunction:
    """Find the close method that must be called for a wsgi response.

    PEP 3333 requires calling C{close} on the iterable returned by the
    application. When C{iter()} returned a different object, the iterable's
    C{close} takes precedence; the iterator's C{close} is only used as a
    fallback (e.g. for an C{__iter__} implemented as a generator).

    @param iterable: is the object returned by the wsgi application.
    @param iterator: is C{iter(iterable)}.
    @returns: the close callable, or None if neither object has one.
    """
    close = getattr(iterable, "close", None)
    if close is None:
        close = getattr(iterator, "close", None)
    return close


def _first_nonempty(
    iterator: typing.Iterator[bytes], close: _CloseFunction
) -> typing.Tuple[bytes, bool]:
    """Advance iterator until it yields a non-empty chunk or is exhausted.

    @param iterator: is the iterator over a wsgi response body.
    @param close: is invoked before re-raising when the iterator raises,
        because the caller then never hands the response to the server,
        which would otherwise be responsible for closing it.
    @returns: a pair (first, stopped) where first is the first non-empty
        chunk (C{b""} if there is none) and stopped tells whether the
        iterator was exhausted while searching.
    """
    try:
        for chunk in iterator:
            if chunk:
                return chunk, False
    except BaseException:
        if close is not None:
            close()
        raise
    return b"", True


def _has_content_length(headers: HeaderList) -> bool:
    """Tell whether headers contain a Content-Length (case-insensitively)."""
    return any(name.lower() == "content-length" for name, _ in headers)


__all__.append("SubdirMiddleware")
class SubdirMiddleware:
    """Middleware choosing wsgi applications based on a dict.

    C{PATH_INFO} is shortened one C{/}-separated segment at a time (longest
    candidate first) until it matches a key of the mapping. The matched
    prefix is moved from C{PATH_INFO} to C{SCRIPT_NAME}. If no key matches,
    the default application is used."""
    def __init__(
        self,
        default: WsgiApp,
        mapping: typing.Optional[typing.Dict[str, WsgiApp]] = None,
    ):
        """
        @param default: is the wsgi application used when nothing matches.
        @param mapping: maps path prefixes to wsgi applications. If omitted,
            a fresh empty dict is used for this instance."""
        self.default = default
        # A shared ``{}`` default would leak entries between instances if the
        # mapping were mutated after construction.
        self.mapping = {} if mapping is None else mapping

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> typing.Iterable[bytes]:
        """wsgi interface"""
        assert isinstance(environ, dict)
        app = None
        # PEP 3333 allows PATH_INFO and SCRIPT_NAME to be omitted when empty.
        script = environ.get("PATH_INFO", "")
        path_info = ""
        while '/' in script:
            if script in self.mapping:
                app = self.mapping[script]
                break
            script, tail = script.rsplit('/', 1)
            path_info = "/%s%s" % (tail, path_info)
        if app is None:
            app = self.mapping.get(script, None)
            if app is None:
                app = self.default
        environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + script
        environ["PATH_INFO"] = path_info
        return app(environ, start_response)


__all__.append("NoWriteCallableMiddleware")
class NoWriteCallableMiddleware:
    """This middleware wraps a wsgi application that needs the return value
    of C{start_response} function to a wsgi application that doesn't need one
    by writing the data to a C{BytesIO} and then making it be the first result
    element."""
    def __init__(self, app: WsgiApp):
        """Wraps wsgi application app."""
        self.app = app

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> typing.Iterable[bytes]:
        """wsgi interface"""
        assert isinstance(environ, dict)
        todo: typing.Optional[typing.Tuple[str, HeaderList]] = None
        sio = io.BytesIO()
        # Set once the response iterable has been inspected. From then on
        # write() is refused, and start_response only re-raises exc_info.
        gotiterdata = False

        def write_callable(data: bytes) -> None:
            assert not gotiterdata
            sio.write(data)

        def modified_start_response(
            status: str, headers: HeaderList, exc_info: OptExcInfo = None
        ) -> WriteCallback:
            nonlocal todo
            try:
                if sio.tell() > 0 or gotiterdata:
                    assert exc_info is not None
                    raise exc_info[1].with_traceback(exc_info[2])
            finally:
                exc_info = None
            assert isinstance(status, str)
            assert isinstance(headers, list)
            todo = (status, headers)
            return write_callable

        ret = self.app(environ, modified_start_response)
        assert hasattr(ret, "__iter__")

        if isinstance(ret, list):
            gotiterdata = True
            assert todo is not None
            status, headers = todo
            data = sio.getvalue()
            if data:
                # Build a new list instead of inserting into the application's
                # list, which may be shared (e.g. a module-level constant)
                # and must not grow on every request.
                ret = CloseableList(getattr(ret, "close", None), [data, *ret])
            start_response(status, headers)
            return ret

        iterator = iter(ret)
        close = _close_function(ret, iterator)
        # Running the iterator up to its first chunk lets lazy applications
        # call start_response/write before we forward the headers.
        first, stopped = _first_nonempty(iterator, close)
        gotiterdata = True
        assert todo is not None
        status, headers = todo
        data = sio.getvalue() + first
        start_response(status, headers)
        if stopped:
            return CloseableList(close, [data] if data else [])
        return CloseableIterator(close, (data,), iterator)


__all__.append("ContentLengthMiddleware")
class ContentLengthMiddleware:
    """Guesses the content length header if possible.
    @note: The application used must not use the C{write} callable returned by
        C{start_response}."""
    maxstore: typing.Union[float, int]

    def __init__(
        self, app: WsgiApp, maxstore: typing.Union[int, typing.Tuple[()]] = 0
    ):
        """Wraps wsgi application app. If the application returns a list, the
        total length of strings is available and the content length header is
        set unless there already is one. For an iterator data is accumulated up
        to a total of maxstore bytes (where maxstore=() means infinity). If the
        iterator is exhausted within maxstore bytes a content length header is
        added unless already present.
        @note: that setting maxstore to a value other than 0 will violate the
            wsgi standard
        """
        self.app = app
        if maxstore == ():
            self.maxstore = float("inf")
        else:
            assert isinstance(maxstore, int)
            self.maxstore = maxstore

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> typing.Iterable[bytes]:
        """wsgi interface"""
        assert isinstance(environ, dict)
        todo: typing.Optional[typing.Tuple[str, HeaderList]] = None
        gotdata = False

        def modified_start_response(
            status: str, headers: HeaderList, exc_info: OptExcInfo = None
        ) -> WriteCallback:
            nonlocal todo
            try:
                if gotdata:
                    assert exc_info is not None
                    raise exc_info[1].with_traceback(exc_info[2])
            finally:
                exc_info = None
            assert isinstance(status, str)
            assert isinstance(headers, list)
            todo = (status, headers)

            def raise_not_imp(_: bytes) -> None:
                raise NotImplementedError
            return raise_not_imp

        ret = self.app(environ, modified_start_response)
        assert hasattr(ret, "__iter__")

        if isinstance(ret, list):
            gotdata = True
            assert todo is not None
            status, headers = todo
            if not _has_content_length(headers):
                length = sum(map(len, ret))
                headers.append(("Content-Length", str(length)))
            start_response(status, headers)
            return ret

        iterator = iter(ret)
        close = _close_function(ret, iterator)
        first, stopped = _first_nonempty(iterator, close)
        gotdata = True
        assert todo is not None
        status, headers = todo
        data = CloseableList(close)
        if first:
            data.append(first)
        length = len(first)

        if not _has_content_length(headers):
            try:
                while not stopped and length < self.maxstore:
                    try:
                        chunk = next(iterator)
                    except StopIteration:
                        stopped = True
                    else:
                        data.append(chunk)
                        length += len(chunk)
            except BaseException:
                # The iterable is never returned, so nobody else closes it.
                if close is not None:
                    close()
                raise
            if stopped:
                # Same header spelling as the list branch above.
                headers.append(("Content-Length", str(length)))
                start_response(status, headers)
                return data

        start_response(status, headers)
        return CloseableIterator(close, data, iterator)


def storable(environ: Environ) -> bool:
    """Default C{storable} predicate of C{CachingMiddleware}: only responses
    to C{GET} requests may be stored."""
    return environ["REQUEST_METHOD"] == "GET"
def cacheable(environ: Environ) -> bool:
    """Default C{cacheable} predicate of C{CachingMiddleware}: a request with
    C{Cache-Control: max-age=0} does not get a cached response. Instead the
    response is regenerated, and the new one is stored."""
    return environ.get("HTTP_CACHE_CONTROL", "") != "max-age=0"


__all__.append("CachingMiddleware")
class CachingMiddleware:
    """Caches responses to requests based on C{SCRIPT_NAME}, C{PATH_INFO} and
    C{QUERY_STRING}."""
    class CachedRequest:
        def __init__(self, timestamp: float):
            self.timestamp = timestamp
            self.status = ""
            self.headers: HeaderList = []
            self.body: typing.List[bytes] = []

    # maps the request key (method, script, path, query) to its response
    cache: typing.Dict[str, CachedRequest]
    # (key, insertion time) pairs in insertion order, consumed by prune_cache
    lastcached: typing.Deque[typing.Tuple[str, float]]

    def __init__(
        self,
        app: WsgiApp,
        maxage: int = 60,
        storable: typing.Callable[[Environ], bool] = storable,
        cacheable: typing.Callable[[Environ], bool] = cacheable,
    ):
        """
        @param app: is a wsgi application to be cached.
        @param maxage: is the number of seconds a response may be cached.
        @param storable: is a predicate that determines whether the response
            may be cached at all based on the C{environ} dict.
        @param cacheable: is a predicate that determines whether this request
            invalidates the cache."""
        self.app = app
        self.maxage = maxage
        self.storable = storable
        self.cacheable = cacheable
        self.cache = {}
        self.lastcached = collections.deque()

    def insert_cache(
        self,
        key: str,
        obj: CachedRequest,
        now: typing.Optional[float] = None,
    ) -> None:
        """Store obj under key and remember when it was stored."""
        if now is None:
            now = time.time()
        self.cache[key] = obj
        self.lastcached.append((key, now))

    def prune_cache(
        self, maxclean: int = 16, now: typing.Optional[float] = None
    ) -> None:
        """Drop expired entries, examining at most maxclean queue records."""
        if now is None:
            now = time.time()
        old = now - self.maxage
        while self.lastcached and maxclean > 0:
            # don't do too much work at once
            maxclean -= 1
            if self.lastcached[0][1] > old:
                break
            key, _ = self.lastcached.popleft()
            obj = self.cache.get(key)
            # The entry may have been replaced by a newer one since this
            # queue record was made; only drop it if it is expired itself.
            if obj is not None and obj.timestamp <= old:
                self.cache.pop(key, None)

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> typing.Iterable[bytes]:
        """wsgi interface"""
        assert isinstance(environ, dict)
        now = time.time()
        self.prune_cache(now=now)
        if not self.storable(environ):
            return self.app(environ, start_response)
        path = environ.get("REQUEST_METHOD", "GET") + " "
        path += environ.get("SCRIPT_NAME", "/")
        path += environ.get("PATH_INFO", '')
        path += "?" + environ.get("QUERY_STRING", "")

        cached = self.cache.get(path)
        if cached is not None and self.cacheable(environ):
            if cached.timestamp + self.maxage >= now:
                start_response(cached.status, list(cached.headers))
                # Hand out a copy (like the headers above) so that callers
                # cannot mutate the cached body.
                return list(cached.body)
            self.cache.pop(path, None)

        cache_object = self.CachedRequest(now)

        def modified_start_response(
            status: str, headers: HeaderList, exc_info: OptExcInfo = None
        ) -> WriteCallback:
            try:
                if cache_object.body:
                    assert exc_info is not None
                    raise exc_info[1].with_traceback(exc_info[2])
            finally:
                exc_info = None
            assert isinstance(status, str)
            assert isinstance(headers, list)
            cache_object.status = status
            # Snapshot, so later changes to the application's list do not
            # leak into the cache.
            cache_object.headers = list(headers)
            write = start_response(status, list(headers))

            def modified_write(data: bytes) -> None:
                cache_object.body.append(data)
                write(data)
            return modified_write

        ret = self.app(environ, modified_start_response)
        assert hasattr(ret, "__iter__")

        if isinstance(ret, list):
            cache_object.body.extend(ret)
            self.insert_cache(path, cache_object, now)
            return ret

        def pass_through() -> typing.Iterator[bytes]:
            for data in ret:
                cache_object.body.append(data)
                yield data
            # Only reached when the body was fully consumed, so partial
            # responses (errors, early close) are never cached.
            self.insert_cache(path, cache_object, now)

        return CloseableIterator(getattr(ret, "close", None), pass_through())


__all__.append("DictAuthChecker")
class DictAuthChecker:
    """Verifies usernames and passwords by looking them up in a dict."""
    def __init__(self, users: typing.Dict[str, str]):
        """
        @param users: is a dict mapping usernames to password."""
        self.users = users

    def __call__(self, username: str, password: str, environ: Environ) -> bool:
        """check_function interface taking username and password and resulting
        in a bool.

        The password comparison uses C{hmac.compare_digest}, which avoids
        content-based short-circuiting, so the response time does not reveal
        how much of a guessed password was right."""
        import hmac  # local import keeps this block self-contained
        expected = self.users.get(username)
        if not isinstance(expected, str):
            return False
        # compare_digest only accepts ASCII str, so compare the encoded bytes.
        # "surrogatepass" makes encoding total and injective over str.
        return hmac.compare_digest(
            expected.encode("utf-8", "surrogatepass"),
            password.encode("utf-8", "surrogatepass"),
        )


__all__.append("BasicAuthMiddleware")
class BasicAuthMiddleware(AuthenticationMiddleware):
    """Middleware implementing HTTP Basic Auth. Upon forwarding the request to
    the wrapped application the environ dictionary is augmented by a REMOTE_USER
    key."""
    authorization_method = "basic"

    def __init__(
        self,
        app: WsgiApp,
        check_function: typing.Callable[[str, str, Environ], bool],
        realm: str = 'www',
        app401: typing.Optional[WsgiApp] = None,
    ):
        """
        @param app: is a WSGI application.
        @param check_function: is a function taking three arguments username,
            password and environment returning a bool indicating whether the
            request may is allowed. The older interface of taking only the
            first two arguments is still supported via catching a
            C{TypeError}.
        @param realm: is the realm announced in the WWW-Authenticate header.
        @param app401: is an optional WSGI application to be used for error
            messages
        """
        AuthenticationMiddleware.__init__(self, app)
        self.check_function = check_function
        self.realm = realm
        self.app401 = app401

    def authenticate(
        self, auth: str, environ: Environ
    ) -> typing.Dict[str, str]:
        """Check base64 encoded C{username:password} credentials.

        @returns: C{dict(user=username)} when check_function accepts them.
        @raises ProtocolViolation: if auth is not valid base64 or the decoded
            value contains no colon.
        @raises AuthenticationRequired: if check_function rejects them."""
        assert isinstance(auth, str)
        assert isinstance(environ, dict)
        authb = str2bytes(auth)
        try:
            auth_infob = base64.b64decode(authb)
        except (TypeError, ValueError) as err:
            # Malformed input raises binascii.Error, a ValueError subclass.
            # Catching only TypeError let it escape as a server error.
            raise ProtocolViolation(
                "failed to base64 decode auth_info"
            ) from err
        auth_info = bytes2str(auth_infob)
        username, colon, password = auth_info.partition(':')
        if not colon:
            raise ProtocolViolation("no colon found in auth_info")
        try:
            result = self.check_function(username, password, environ)
        except TypeError:
            # catch old interface check_function(username, password).
            # NOTE: this also catches a TypeError raised inside a
            # three-argument check_function, which is then called again
            # with only two arguments.
            result = self.check_function(username, password)
        if result:
            return dict(user=username)
        raise AuthenticationRequired("credentials not valid")

    def www_authenticate(
        self, exception: AuthenticationRequired
    ) -> typing.Tuple[str, str]:
        return ("WWW-Authenticate", 'Basic realm="%s"' % self.realm)

    def authorization_required(
        self,
        environ: Environ,
        start_response: StartResponse,
        exception: AuthenticationRequired,
    ) -> typing.Iterable[bytes]:
        """Serve the 401 response, through app401 if one was given."""
        if self.app401 is not None:
            return self.app401(environ, start_response)
        return AuthenticationMiddleware.authorization_required(
            self, environ, start_response, exception)


__all__.append("TracebackMiddleware")
class TracebackMiddleware:
    """In case the application throws an exception this middleware will show
    an html-formatted traceback using C{cgitb}. Only exceptions raised while
    calling the application or fetching the first body chunk are caught.
    Later exceptions propagate to the server."""
    def __init__(self, app: WsgiApp):
        """app is the wsgi application to proxy."""
        self.app = app

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> typing.Iterable[bytes]:
        """wsgi interface"""
        try:
            assert isinstance(environ, dict)
            ret = self.app(environ, start_response)
            assert hasattr(ret, "__iter__")
            if isinstance(ret, list):
                return ret
            # Take the first element of the iterator and possibly catch an
            # exception there.
            iterator = iter(ret)
            # PEP 3333: close() belongs to the returned iterable, which may
            # differ from the iterator obtained from it.
            close = getattr(ret, "close", None)
            if close is None:
                close = getattr(iterator, "close", None)
            try:
                first = next(iterator)
            except StopIteration:
                return CloseableList(close, [])
            except BaseException:
                # The iterable is never returned, so nobody else closes it.
                if close is not None:
                    close()
                raise
            return CloseableIterator(close, [first], iterator)
        except Exception:
            exc_info = sys.exc_info()
            try:
                if cgitb is not None:
                    text = cgitb.html(exc_info)
                else:  # cgitb is unavailable on Python >= 3.13
                    import html
                    import traceback
                    formatted = "".join(traceback.format_exception(*exc_info))
                    text = "<html><body><pre>%s</pre></body></html>" % (
                        html.escape(formatted),)
                # Encode before measuring so the length counts bytes. Character
                # references keep non-ASCII text valid HTML regardless of the
                # charset the client assumes.
                body = text.encode("ascii", "xmlcharrefreplace")
                # The application may already have called start_response.
                # PEP 3333 makes a second call without exc_info a fatal error.
                start_response("200 OK", [("Content-type", "text/html"),
                                          ("Content-length", str(len(body)))],
                               exc_info)
            finally:
                exc_info = None  # break the traceback reference cycle
            if environ["REQUEST_METHOD"].upper() == "HEAD":
                return []
            return [body]