From a41066b413489b407b9d99174af697563ad680b9 Mon Sep 17 00:00:00 2001 From: Helmut Grohne Date: Mon, 13 Apr 2020 21:30:34 +0200 Subject: add type hints to all of the code In order to use type hint syntax, we need to bump the minimum Python version to 3.7 and some of the features such as Literal and Protocol are opted in when a sufficiently recent Python is available. This does not make all of the code pass type checking with mypy. A number of typing issues remain, but the output of mypy becomes something one can read through. In adding type hints, a lot of epydoc @type annotations are removed as redundant. This update also adopts black-style line breaking. --- wsgitools/applications.py | 44 ++++--- wsgitools/authentication.py | 43 ++++--- wsgitools/digest.py | 220 ++++++++++++++++++----------------- wsgitools/filters.py | 252 +++++++++++++++++++---------------------- wsgitools/internal.py | 37 +++++- wsgitools/middlewares.py | 176 +++++++++++++++++----------- wsgitools/scgi/__init__.py | 25 ++-- wsgitools/scgi/asynchronous.py | 98 +++++++++++----- wsgitools/scgi/forkpool.py | 95 +++++++++------- 9 files changed, 565 insertions(+), 425 deletions(-) (limited to 'wsgitools') diff --git a/wsgitools/applications.py b/wsgitools/applications.py index 9894cf8..f51fccf 100644 --- a/wsgitools/applications.py +++ b/wsgitools/applications.py @@ -1,4 +1,7 @@ import os.path +import typing + +from wsgitools.internal import Environ, HeaderList, StartResponse __all__ = [] @@ -9,17 +12,21 @@ class StaticContent: receives with method GET or HEAD (content stripped). If not present, a content-length header is computed. """ - def __init__(self, status, headers, content, anymethod=False): + content: typing.Iterable[bytes] + + def __init__( + self, + status: str, + headers: HeaderList, + content: typing.Union[bytes, typing.Iterable[bytes]], + anymethod: bool = False, + ): """ - @type status: str @param status: is the HTTP status returned to the browser (ex: "200 OK") - @type headers: list @param headers: is a list of C{(header, value)} pairs being delivered as HTTP headers - @type content: bytes @param content: contains the data to be delivered to the client. It is either a string or some kind of iterable yielding strings. - @type anymethod: boolean @param anymethod: determines whether any request method should be answered with this response instead of a 501 """ @@ -40,7 +47,9 @@ class StaticContent: if length >= 0: if not [v for h, v in headers if h.lower() == "content-length"]: headers.append(("Content-length", str(length))) - def __call__(self, environ, start_response): + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: """wsgi interface""" assert isinstance(environ, dict) if environ["REQUEST_METHOD"].upper() not in ["GET", "HEAD"] and \ @@ -61,20 +70,21 @@ class StaticFile: request it receives with method GET or HEAD (content stripped). If not present, a content-length header is computed. """ - def __init__(self, filelike, status="200 OK", headers=list(), - blocksize=4096): + def __init__( + self, + filelike: typing.Union[str, typing.BinaryIO], + status: str = "200 OK", + headers: HeaderList = list(), + blocksize: int = 4096, + ): """ - @type status: str @param status: is the HTTP status returned to the browser - @type headers: [(str, str)] @param headers: is a list of C{(header, value)} pairs being delivered as HTTP headers - @type filelike: str or file-like @param filelike: may either be an path in the local file system or a file-like that must support C{read(size)} and C{seek(0)}. If C{tell()} is present, C{seek(0, 2)} and C{tell()} will be used to compute the content-length. - @type blocksize: int @param blocksize: the content is provided in chunks of this size """ self.filelike = filelike @@ -82,7 +92,9 @@ class StaticFile: self.headers = headers self.blocksize = blocksize - def _serve_in_chunks(self, stream): + def _serve_in_chunks( + self, stream: typing.BinaryIO + ) -> typing.Iterator[bytes]: """internal method yielding data from the given stream""" while True: data = stream.read(self.blocksize) @@ -92,7 +104,9 @@ class StaticFile: if isinstance(self.filelike, str): stream.close() - def __call__(self, environ, start_response): + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: """wsgi interface""" assert isinstance(environ, dict) @@ -102,7 +116,7 @@ class StaticFile: [("Content-length", str(len(resp)))]) return [resp] - stream = None + stream: typing.Optional[typing.BinaryIO] = None size = -1 try: if isinstance(self.filelike, str): diff --git a/wsgitools/authentication.py b/wsgitools/authentication.py index c076d7f..345a0ae 100644 --- a/wsgitools/authentication.py +++ b/wsgitools/authentication.py @@ -1,3 +1,9 @@ +import typing + +from wsgitools.internal import ( + Environ, HeaderList, OptExcInfo, StartResponse, WriteCallback, WsgiApp +) + __all__ = [] class AuthenticationRequired(Exception): @@ -15,28 +21,24 @@ class AuthenticationMiddleware: @cvar authorization_method: the implemented Authorization method. It will be verified against Authorization headers. Subclasses must define this attribute. - @type authorization_method: str """ - authorization_method = None - def __init__(self, app): + authorization_method: typing.ClassVar[str] + def __init__(self, app: WsgiApp): """ @param app: is a WSGI application. """ assert self.authorization_method is not None self.app = app - def authenticate(self, auth, environ): + def authenticate(self, auth: str, environ: Environ) -> Environ: """Try to authenticate a request. The Authorization header is examined and checked agains the L{authorization_method} before being passed to this method. This method must either raise an AuthenticationRequired instance or return a dictionary explaining what was successfully authenticated. - @type auth: str @param auth: is the part of the Authorization header after the method - @type environ: {str: object} @param environ: is the environment passed with a WSGI request - @rtype: {str: object} @returns: a dictionary that provides a key "user" listing the authenticated username as a string. It may also provide the key "outheaders" with a [(str, str)] value to extend the response @@ -45,11 +47,10 @@ class AuthenticationMiddleware: """ raise NotImplementedError - def __call__(self, environ, start_response): - """wsgi interface - - @type environ: {str: object} - """ + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: + """wsgi interface""" assert isinstance(environ, dict) try: try: @@ -70,7 +71,9 @@ class AuthenticationMiddleware: assert "user" in result environ["REMOTE_USER"] = result["user"] if "outheaders" in result: - def modified_start_response(status, headers, exc_info=None): + def modified_start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: assert isinstance(headers, list) headers.extend(result["outheaders"]) return start_response(status, headers, exc_info) @@ -78,22 +81,26 @@ class AuthenticationMiddleware: modified_start_response = start_response return self.app(environ, modified_start_response) - def www_authenticate(self, exception): + def www_authenticate( + self, exception: AuthenticationRequired + ) -> typing.Tuple[str, str]: """Generates a WWW-Authenticate header. Subclasses must implement this method. - @type exception: L{AuthenticationRequired} @param exception: reason for generating the header - @rtype: (str, str) @returns: the header as (part_before_colon, part_after_colon) """ raise NotImplementedError - def authorization_required(self, environ, start_response, exception): + def authorization_required( + self, + environ: Environ, + start_response: StartResponse, + exception: AuthenticationRequired, + ) -> typing.Iterable[bytes]: """Generate an error page after failed authentication. Apart from the exception parameter, this method behaves like a WSGI application. - @type exception: L{AuthenticationRequired} @param exception: reason for the authentication failure """ status = "401 Authorization required" diff --git a/wsgitools/digest.py b/wsgitools/digest.py index 6eb4cb3..18925df 100644 --- a/wsgitools/digest.py +++ b/wsgitools/digest.py @@ -23,25 +23,22 @@ except ImportError: import random sysrand = random.SystemRandom() randbits = sysrand.getrandbits - def compare_digest(a, b): + def compare_digest(a: str, b: str) -> bool: return a == b +import sys +import typing -from wsgitools.internal import bytes2str, str2bytes, textopen +from wsgitools.internal import bytes2str, Environ, str2bytes, textopen, WsgiApp from wsgitools.authentication import AuthenticationRequired, \ ProtocolViolation, AuthenticationMiddleware -def md5hex(data): - """ - @type data: str - @rtype: str - """ +def md5hex(data: str) -> str: return hashlib.md5(str2bytes(data)).hexdigest() -def gen_rand_str(bytesentropy=33): +def gen_rand_str(bytesentropy: int = 33) -> str: """ Generates a string of random base64 characters. @param bytesentropy: is the number of random 8bit values to be used - @rtype: str >>> gen_rand_str() != gen_rand_str() True @@ -53,7 +50,7 @@ def gen_rand_str(bytesentropy=33): randstr = bytes2str(randbytes) return randstr -def parse_digest_response(data): +def parse_digest_response(data: str) -> typing.Dict[str, str]: """internal @raises ValueError: @@ -118,13 +115,11 @@ def parse_digest_response(data): value, data = data.split(',', 1) result[key] = value -def format_digest(mapping): +def format_digest(mapping: typing.Dict[str, typing.Tuple[str, bool]]) -> str: """internal - @type mapping: {str: (str, bool)} @param mapping: a mapping of keys to values and a boolean that determines whether the value needs quoting. - @rtype: str @note: the RFC specifies which values must be quoted and which must not be quoted. """ @@ -150,38 +145,36 @@ class AbstractTokenGenerator: L{AuthDigestMiddleware}. @ivar realm: is a string according to RFC2617. - @type realm: str """ - def __init__(self, realm): - """ - @type realm: str - """ + realm: str + def __init__(self, realm: str): assert isinstance(realm, str) self.realm = realm - def __call__(self, username, algo="md5"): + def __call__( + self, username: str, algo: str = "md5" + ) -> typing.Optional[str]: """Generates an authentication token from a username. - @type username: str - @type algo: str @param algo: currently the only value supported by L{AuthDigestMiddleware} is "md5" - @rtype: str or None @returns: a valid token or None to signal that authentication should fail """ raise NotImplementedError - def check_password(self, username, password, environ=None): + def check_password( + self, + username: str, + password: str, + environ: typing.Optional[Environ] = None, + ) -> bool: """ This function implements the interface for verifying passwords used by L{BasicAuthMiddleware}. It works by computing a token from the user and comparing it to the token returned by the __call__ method. - @type username: str - @type password: str @param environ: ignored - @rtype: bool """ assert isinstance(username, str) assert isinstance(password, str) @@ -191,16 +184,26 @@ class AbstractTokenGenerator: return False return compare_digest(md5hex(token), expected) +if sys.version_info >= (3, 11): + class TokenGenerator(typing.Protocol): + realm: str + def __call__( + self, username: str, algo: str = "md5" + ) -> typing.Optional[str]: + ... +else: + TokenGenerator = typing.Callable[[str, str], typing.Optional[str]] + __all__.append("AuthTokenGenerator") class AuthTokenGenerator(AbstractTokenGenerator): """Generates authentication tokens for L{AuthDigestMiddleware}. The interface consists of beeing callable with a username and having a realm attribute being a string.""" - def __init__(self, realm, getpass): + def __init__( + self, realm: str, getpass: typing.Callable[[str], typing.Optional[str]] + ): """ - @type realm: str @param realm: is a string according to RFC2617. - @type getpass: str -> (str or None) @param getpass: this function is called with a username and password is expected as result. C{None} may be used as an invalid password. An example for getpass would be C{{username: password}.get}. @@ -208,7 +211,9 @@ class AuthTokenGenerator(AbstractTokenGenerator): AbstractTokenGenerator.__init__(self, realm) self.getpass = getpass - def __call__(self, username, algo="md5"): + def __call__( + self, username: str, algo: str = "md5" + ) -> typing.Optional[str]: assert isinstance(username, str) assert algo.lower() in ["md5", "md5-sess"] password = self.getpass(username) @@ -222,12 +227,13 @@ class HtdigestTokenGenerator(AbstractTokenGenerator): """Reads authentication tokens for L{AuthDigestMiddleware} from an apache htdigest file. """ - def __init__(self, realm, htdigestfile, ignoreparseerrors=False): + users: typing.Dict[str, str] + + def __init__( + self, realm: str, htdigestfile: str, ignoreparseerrors: bool = False + ): """ - @type realm: str - @type htdigestfile: str @param htdigestfile: path to the .htdigest file - @type ignoreparseerrors: bool @param ignoreparseerrors: passed to readhtdigest @raises IOError: @raises ValueError: @@ -236,10 +242,10 @@ class HtdigestTokenGenerator(AbstractTokenGenerator): self.users = {} self.readhtdigest(htdigestfile, ignoreparseerrors) - def readhtdigest(self, htdigestfile, ignoreparseerrors=False): + def readhtdigest( + self, htdigestfile: str, ignoreparseerrors: bool = False + ) -> None: """ - @type htdigestfile: str - @type ignoreparseerrors: bool @param ignoreparseerrors: do not raise ValueErrors for bad files @raises IOError: @raises ValueError: @@ -260,7 +266,7 @@ class HtdigestTokenGenerator(AbstractTokenGenerator): raise ValueError("duplicate user in htdigest file") self.users[user] = token - def __call__(self, user, algo="md5"): + def __call__(self, user: str, algo: str = "md5") -> typing.Optional[str]: assert algo.lower() in ["md5", "md5-sess"] return self.users.get(user) @@ -269,7 +275,9 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator): """Behaves like L{HtdigestTokenGenerator}, checks the htdigest file for changes on each invocation. """ - def __init__(self, realm, htdigestfile, ignoreparseerrors=False): + def __init__( + self, realm: str, htdigestfile: str, ignoreparseerrors: bool = False + ): assert isinstance(htdigestfile, str) # Need to stat the file before calling parent ctor to detect # modifications. @@ -282,7 +290,7 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator): self.htdigestfile = htdigestfile self.ignoreparseerrors = ignoreparseerrors - def __call__(self, user, algo="md5"): + def __call__(self, user: str, algo: str = "md5") -> typing.Optional[str]: # The interface does not permit raising exceptions, so all we can do is # fail by returning None. try: @@ -301,36 +309,30 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator): __all__.append("NonceStoreBase") class NonceStoreBase: """Nonce storage interface.""" - def __init__(self): + def __init__(self) -> None: pass - def newnonce(self, ident=None): + def newnonce(self, ident: typing.Optional[str] = None) -> str: """ This method is to be overriden and should return new nonces. - @type ident: str @param ident: is an identifier to be associated with this nonce - @rtype: str """ raise NotImplementedError - def checknonce(self, nonce, count=1, ident=None): + def checknonce( + self, nonce: str, count: int = 1, ident: typing.Optional[str] = None + ) -> bool: """ This method is to be overridden and should do a check for whether the given nonce is valid as being used count times. - @type nonce: str - @type count: int @param count: indicates how often the nonce has been used (including this check) - @type ident: str @param ident: it is also checked that the nonce was associated to this identifier when given - @rtype: bool """ raise NotImplementedError -def format_time(seconds): +def format_time(seconds: float) -> str: """ internal method formatting a unix time to a fixed-length string - @type seconds: float - @rtype: str """ # the overflow will happen about 2112 return "%013X" % int(seconds * 1000000) @@ -356,13 +358,11 @@ class StatelessNonceStore(NonceStoreBase): >>> s.checknonce(n.rsplit(':', 1)[0] + "bad hash") False """ - def __init__(self, maxage=300, secret=None): + def __init__(self, maxage: int = 300, secret: typing.Optional[str] = None): """ - @type maxage: int @param maxage: is the number of seconds a nonce may be valid. Choosing a large value may result in more memory usage whereas a smaller value results in more requests. Defaults to 5 minutes. - @type secret: str @param secret: if not given, a secret is generated and is therefore shared after forks. Knowing this secret permits creating nonces. """ @@ -373,12 +373,8 @@ class StatelessNonceStore(NonceStoreBase): else: self.server_secret = gen_rand_str() - def newnonce(self, ident=None): - """ - Generates a new nonce string. - @type ident: None or str - @rtype: str - """ + def newnonce(self, ident: typing.Optional[str] = None) -> str: + """Generates a new nonce string.""" nonce_time = format_time(time.time()) nonce_value = gen_rand_str() token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret) @@ -387,14 +383,10 @@ class StatelessNonceStore(NonceStoreBase): token = md5hex(token) return "%s:%s:%s" % (nonce_time, nonce_value, token) - def checknonce(self, nonce, count=1, ident=None): - """ - Check whether the provided string is a nonce. - @type nonce: str - @type count: int - @type ident: None or str - @rtype: bool - """ + def checknonce( + self, nonce: str, count: int = 1, ident: typing.Optional[str] = None + ) -> bool: + """Check whether the provided string is a nonce.""" if count != 1: return False try: @@ -429,13 +421,13 @@ class MemoryNonceStore(NonceStoreBase): >>> s.checknonce(n.rsplit(':', 1)[0] + "bad hash") False """ - def __init__(self, maxage=300, maxuses=5): + nonces: typing.List[typing.Tuple[str, str, int]] + + def __init__(self, maxage: int = 300, maxuses: int = 5): """ - @type maxage: int @param maxage: is the number of seconds a nonce may be valid. Choosing a large value may result in more memory usage whereas a smaller value results in more requests. Defaults to 5 minutes. - @type maxuses: int @param maxuses: is the number of times a nonce may be used (with different nc values). A value of 1 makes nonces usable exactly once resulting in more requests. Defaults to 5. @@ -447,18 +439,14 @@ class MemoryNonceStore(NonceStoreBase): # as [(str (hex encoded), str, int)] self.server_secret = gen_rand_str() - def _cleanup(self): + def _cleanup(self) -> None: """internal methods cleaning list of valid nonces""" old = format_time(time.time() - self.maxage) while self.nonces and self.nonces[0][0] < old: self.nonces.pop(0) - def newnonce(self, ident=None): - """ - Generates a new nonce string. - @type ident: None or str - @rtype: str - """ + def newnonce(self, ident: typing.Optional[str] = None) -> str: + """Generates a new nonce string.""" self._cleanup() # avoid growing self.nonces nonce_time = format_time(time.time()) nonce_value = gen_rand_str() @@ -469,14 +457,12 @@ class MemoryNonceStore(NonceStoreBase): token = md5hex(token) return "%s:%s:%s" % (nonce_time, nonce_value, token) - def checknonce(self, nonce, count=1, ident=None): + def checknonce( + self, nonce: str, count: int = 1, ident: typing.Optional[str] = None + ) -> bool: """ Do a check for whether the provided string is a nonce and increase usage count on returning True. - @type nonce: str - @type count: int - @type ident: None or str - @rtype: bool """ try: nonce_time, nonce_value, nonce_hash = nonce.split(':') @@ -522,19 +508,28 @@ class LazyDBAPI2Opener: because this way each worker child opens a new database connection when the first request is to be answered. """ - def __init__(self, function, *args, **kwargs): + _function: typing.Optional[typing.Callable[..., typing.Any]] + def __init__( + self, + function: typing.Callable[..., typing.Any], + *args, + **kwargs, + ): """ The database will be connected on the first method call. This is done by calling the given function with the remaining parameters. @param function: is the function that connects to the database """ self._function = function - self._args = args - self._kwargs = kwargs + self._args: typing.Optional[typing.Tuple[typing.Any, ...]] = args + self._kwargs: typing.Optional[typing.Dict[str, typing.Any]] = kwargs self._dbhandle = None - def _getdbhandle(self): + def _getdbhandle(self) -> typing.Any: """Returns an open database connection. Open if necessary.""" if self._dbhandle is None: + assert self._function is not None + assert self._args is not None + assert self._kwargs is not None self._dbhandle = self._function(*self._args, **self._kwargs) self._function = self._args = self._kwargs = None return self._dbhandle @@ -573,14 +568,14 @@ class DBAPI2NonceStore(NonceStoreBase): >>> s.checknonce(n.rsplit(':', 1)[0] + "bad hash") False """ - def __init__(self, dbhandle, maxage=300, maxuses=5, table="nonces"): + def __init__( + self, dbhandle, maxage: int = 300, maxuses: int = 5, table="nonces" + ): """ @param dbhandle: is a dbapi2 connection - @type maxage: int @param maxage: is the number of seconds a nonce may be valid. Choosing a large value may result in more memory usage whereas a smaller value results in more requests. Defaults to 5 minutes. - @type maxuses: int @param maxuses: is the number of times a nonce may be used (with different nc values). A value of 1 makes nonces usable exactly once resulting in more requests. Defaults to 5. @@ -592,16 +587,13 @@ class DBAPI2NonceStore(NonceStoreBase): self.table = table self.server_secret = gen_rand_str() - def _cleanup(self, cur): + def _cleanup(self, cur) -> None: """internal methods cleaning list of valid nonces""" old = format_time(time.time() - self.maxage) cur.execute("DELETE FROM %s WHERE key < '%s:';" % (self.table, old)) - def newnonce(self, ident=None): - """ - Generates a new nonce string. - @rtype: str - """ + def newnonce(self, ident: typing.Optional[str] = None) -> str: + """Generates a new nonce string.""" nonce_time = format_time(time.time()) nonce_value = gen_rand_str() dbkey = "%s:%s" % (nonce_time, nonce_value) @@ -615,14 +607,12 @@ class DBAPI2NonceStore(NonceStoreBase): token = md5hex(token) return "%s:%s:%s" % (nonce_time, nonce_value, token) - def checknonce(self, nonce, count=1, ident=None): + def checknonce( + self, nonce: str, count: int = 1, ident: typing.Optional[str] = None + ) -> bool: """ Do a check for whether the provided string is a nonce and increase usage count on returning True. - @type nonce: str - @type count: int - @type ident: str or None - @rtype: bool """ try: nonce_time, nonce_value, nonce_hash = nonce.split(':') @@ -668,7 +658,7 @@ class DBAPI2NonceStore(NonceStoreBase): self.dbhandle.commit() return True -def check_uri(credentials, environ): +def check_uri(credentials: typing.Dict[str, str], environ: Environ) -> None: """internal method for verifying the uri credential @raises AuthenticationRequired: """ @@ -706,19 +696,23 @@ class AuthDigestMiddleware(AuthenticationMiddleware): application.""" authorization_method = "digest" algorithms = {"md5": md5hex} - def __init__(self, app, gentoken, maxage=300, maxuses=5, store=None): + noncestore: NonceStoreBase + def __init__( + self, + app: WsgiApp, + gentoken: TokenGenerator, + maxage: int = 300, + maxuses: int = 5, + store: typing.Optional[NonceStoreBase] = None, + ): """ @param app: is the wsgi application to be served with authentication. - @type gentoken: str -> (str or None) @param gentoken: has to have the same functionality and interface as the L{AuthTokenGenerator} class. - @type maxage: int @param maxage: deprecated, see L{MemoryNonceStore} or L{StatelessNonceStore} and pass an instance to store - @type maxuses: int @param maxuses: deprecated, see L{MemoryNonceStore} and pass an instance to store - @type store: L{NonceStoreBase} @param store: a nonce storage implementation object. Usage of this parameter will override maxage and maxuses. """ @@ -731,7 +725,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware): assert hasattr(store, "checknonce") self.noncestore = store - def authenticate(self, auth, environ): + def authenticate(self, auth: str, environ: Environ) -> Environ: assert isinstance(auth, str) try: credentials = parse_digest_response(auth) @@ -783,7 +777,9 @@ class AuthDigestMiddleware(AuthenticationMiddleware): return dict(user=credentials["username"], outheaders=[("Authentication-Info", format_digest(digest))]) - def auth_response(self, credentials, reqmethod): + def auth_response( + self, credentials: typing.Dict[str, str], reqmethod: str + ) -> typing.Optional[str]: """internal method generating authentication tokens @raises AuthenticationRequired: """ @@ -818,7 +814,9 @@ class AuthDigestMiddleware(AuthenticationMiddleware): dig.insert(0, a1h) return self.algorithms[algo](":".join(dig)) - def www_authenticate(self, exception): + def www_authenticate( + self, exception: AuthenticationRequired + ) -> typing.Tuple[str, str]: digest = dict(nonce=(self.noncestore.newnonce(), True), realm=(self.gentoken.realm, True), algorithm=("MD5", False), diff --git a/wsgitools/filters.py b/wsgitools/filters.py index 7f8543d..eada536 100644 --- a/wsgitools/filters.py +++ b/wsgitools/filters.py @@ -11,27 +11,37 @@ import sys import time import gzip import io +import typing -from wsgitools.internal import str2bytes +from wsgitools.internal import ( + Environ, + HeaderList, + OptExcInfo, + StartResponse, + str2bytes, + WriteCallback, + WsgiApp, +) __all__.append("CloseableIterator") class CloseableIterator: """Concatenating iterator with close attribute.""" - def __init__(self, close_function, *iterators): + def __init__( + self, + close_function: typing.Optional[typing.Callable[[], None]], + *iterators: typing.Iterable[typing.Any] + ): """If close_function is not C{None}, it will be the C{close} attribute of the created iterator object. Further parameters specify iterators that are to be concatenated. - @type close_function: a function or C{None} """ if close_function is not None: self.close = close_function self.iterators = list(map(iter, iterators)) - def __iter__(self): - """iterator interface - @rtype: gen() - """ + def __iter__(self) -> typing.Iterator[typing.Any]: + """iterator interface""" return self - def __next__(self): + def __next__(self) -> typing.Any: """iterator interface""" if not self.iterators: raise StopIteration @@ -45,16 +55,17 @@ class CloseableIterator: __all__.append("CloseableList") class CloseableList(list): """A list with a close attribute.""" - def __init__(self, close_function, *args): + def __init__( + self, close_function: typing.Optional[typing.Callable[[], None]], *args + ): """If close_function is not C{None}, it will be the C{close} attribute of the created list object. Other parameters are passed to the list constructor. - @type close_function: a function or C{None} """ if close_function is not None: self.close = close_function list.__init__(self, *args) - def __iter__(self): + def __iter__(self) -> CloseableIterator: """iterator interface""" return CloseableIterator(getattr(self, "close", None), list.__iter__(self)) @@ -76,104 +87,101 @@ class BaseWSGIFilter: L{BaseWSGIFilter} class to a L{WSGIFilterMiddleware} will result in not modifying requests at all. """ - def __init__(self): + def __init__(self) -> None: """This constructor does nothing and can safely be overwritten. It is only listed here to document that it must be callable without additional parameters.""" - def filter_environ(self, environ): + def filter_environ(self, environ: Environ) -> Environ: """Receives a dict with the environment passed to the wsgi application and a C{dict} must be returned. The default is to return the same dict. - @type environ: {str: str} - @rtype: {str: str} """ return environ - def filter_exc_info(self, exc_info): + def filter_exc_info(self, exc_info: OptExcInfo) -> OptExcInfo: """Receives either C{None} or a tuple passed as third argument to C{start_response} from the wrapped wsgi application. Either C{None} or such a tuple must be returned.""" return exc_info - def filter_status(self, status): + def filter_status(self, status: str) -> str: """Receives a status string passed as first argument to C{start_response} from the wrapped wsgi application. A valid HTTP status string must be returned. - @type status: str - @rtype: str """ return status - def filter_header(self, headername, headervalue): + def filter_header( + self, headername: str, headervalue: str + ) -> typing.Optional[typing.Tuple[str, str]]: """This function is invoked for each C{(headername, headervalue)} tuple in the second argument to the C{start_response} from the wrapped wsgi application. Such a value or C{None} for discarding the header must be returned. - @type headername: str - @type headervalue: str - @rtype: (str, str) """ return (headername, headervalue) - def filter_headers(self, headers): + def filter_headers(self, headers: HeaderList) -> HeaderList: """A list of headers passed as the second argument to the C{start_response} from the wrapped wsgi application is passed to this function and such a list must also be returned. - @type headers: [(str, str)] - @rtype: [(str, str)] """ return headers - def filter_data(self, data): + def filter_data(self, data: bytes) -> bytes: """For each string that is either written by the C{write} callable or returned from the wrapped wsgi application this method is invoked. It must return a string. - @type data: bytes - @rtype: bytes """ return data - def append_data(self): + def append_data(self) -> typing.Iterable[bytes]: """This function can be used to append data to the response. A list of strings or some kind of iterable yielding strings has to be returned. The default is to return an empty list. - @rtype: gen([bytes]) """ return [] - def handle_close(self): + def handle_close(self) -> None: """This method is invoked after the request has finished.""" __all__.append("WSGIFilterMiddleware") class WSGIFilterMiddleware: """This wsgi middleware can be used with specialized L{BaseWSGIFilter}s to modify wsgi requests and/or reponses.""" - def __init__(self, app, filterclass): + def __init__( + self, app: WsgiApp, filterclass: typing.Callable[[], BaseWSGIFilter] + ): """ @param app: is a wsgi application. - @type filterclass: L{BaseWSGIFilter}s subclass - @param filterclass: is a subclass of L{BaseWSGIFilter} or some class - that implements the interface.""" + @param filterclass: is factory creating L{BaseWSGIFilter} instances + """ self.app = app self.filterclass = filterclass - def __call__(self, environ, start_response): - """wsgi interface - @type environ: {str, str} - @rtype: gen([bytes]) - """ + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: + """wsgi interface""" assert isinstance(environ, dict) reqfilter = self.filterclass() environ = reqfilter.filter_environ(environ) - def modified_start_response(status, headers, exc_info=None): + def modified_start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: assert isinstance(status, str) assert isinstance(headers, list) exc_info = reqfilter.filter_exc_info(exc_info) status = reqfilter.filter_status(status) - headers = (reqfilter.filter_header(h, v) for h, v in headers) - headers = [h for h in headers if h] - headers = reqfilter.filter_headers(headers) + headers = reqfilter.filter_headers( + list( + filter( + None, + (reqfilter.filter_header(h, v) for h, v in headers) + ) + ) + ) write = start_response(status, headers, exc_info) - def modified_write(data): + def modified_write(data: bytes) -> None: write(reqfilter.filter_data(data)) return modified_write ret = self.app(environ, modified_start_response) assert hasattr(ret, "__iter__") - def modified_close(): + def modified_close() -> None: reqfilter.handle_close() getattr(ret, "close", lambda:0)() @@ -182,7 +190,7 @@ class WSGIFilterMiddleware: list(map(reqfilter.filter_data, ret)) + list(reqfilter.append_data())) ret = iter(ret) - def late_append_data(): + def late_append_data() -> typing.Iterator[bytes]: """Invoke C{reqfilter.append_data()} after C{filter_data()} has seen all data.""" for data in reqfilter.append_data(): @@ -194,38 +202,42 @@ class WSGIFilterMiddleware: # Using map and lambda here since pylint cannot handle list comprehension in # default arguments. Also note that neither ' nor " are considered printable. # For escape_string to be reversible \ is also not considered printable. -def escape_string(string, replacer=list(map( - lambda i: chr(i) if str2bytes(chr(i)).isalnum() or - chr(i) in '!#$%&()*+,-./:;<=>?@[]^_`{|}~ ' else - r"\x%2.2x" % i, - range(256)))): - """Encodes non-printable characters in a string using \\xXX escapes. - - @type string: str - @rtype: str - """ +def escape_string( + string: str, + replacer: typing.List[str] = list( + map( + lambda i: ( + chr(i) if ( + str2bytes(chr(i)).isalnum() or + chr(i) in '!#$%&()*+,-./:;<=>?@[]^_`{|}~ ' + ) else r"\x%2.2x" % i + ), + range(256), + ) + ), +) -> str: + """Encodes non-printable characters in a string using \\xXX escapes.""" return "".join(replacer[ord(char)] for char in string) __all__.append("RequestLogWSGIFilter") class RequestLogWSGIFilter(BaseWSGIFilter): """This filter logs all requests in the apache log file format.""" + proto: typing.Optional[str] @classmethod - def creator(cls, log, flush=True): + def creator( + cls, log: typing.TextIO, flush: bool = True + ) -> typing.Callable[[], "RequestLogWSGIFilter"]: """Returns a function creating L{RequestLogWSGIFilter}s on given log file. log has to be a file-like object. - @type log: file-like @param log: elements of type str are written to the log. That means in Py3.X the contents are decoded and in Py2.X the log is assumed to be encoded in latin1. This follows the spirit of WSGI. - @type flush: bool @param flush: if True, invoke the flush method on log after each write invocation """ return lambda:cls(log, flush) - def __init__(self, log=sys.stdout, flush=True): + def __init__(self, log: typing.TextIO = sys.stdout, flush: bool = True): """ - @type log: file-like - @type flush: bool @param flush: if True, invoke the flush method on log after each write invocation """ @@ -244,11 +256,8 @@ class RequestLogWSGIFilter(BaseWSGIFilter): self.length = 0 self.referrer = None self.useragent = None - def filter_environ(self, environ): - """BaseWSGIFilter interface - @type environ: {str: str} - @rtype: {str: str} - """ + def filter_environ(self, environ: Environ) -> Environ: + """BaseWSGIFilter interface""" assert isinstance(environ, dict) self.remote = environ.get("REMOTE_ADDR", self.remote) self.user = environ.get("REMOTE_USER", self.user) @@ -260,19 +269,16 @@ class RequestLogWSGIFilter(BaseWSGIFilter): self.referrer = environ.get("HTTP_REFERER", self.referrer) self.useragent = environ.get("HTTP_USER_AGENT", self.useragent) return environ - def filter_status(self, status): - """BaseWSGIFilter interface - @type status: str - @rtype: str - """ + def filter_status(self, status: str) -> str: + """BaseWSGIFilter interface""" assert isinstance(status, str) self.status = status.split()[0] return status - def filter_data(self, data): + def filter_data(self, data: bytes) -> bytes: assert isinstance(data, bytes) self.length += len(data) return data - def handle_close(self): + def handle_close(self) -> None: """BaseWSGIFilter interface""" line = '%s %s - [%s]' % (self.remote, self.user, self.time) line = '%s "%s %s' % (line, escape_string(self.reqmethod), @@ -301,25 +307,18 @@ class TimerWSGIFilter(BaseWSGIFilter): something like C{["spam?GenTime", "?GenTime spam", "?GenTime"]} only the last occurance get's replaced.""" @classmethod - def creator(cls, pattern): + def creator(cls, pattern: bytes) -> typing.Callable[[], "TimerWSGIFilter"]: """Returns a function creating L{TimerWSGIFilter}s with a given pattern beeing a string of exactly eight bytes. - @type pattern: bytes """ return lambda:cls(pattern) - def __init__(self, pattern=b"?GenTime"): - """ - @type pattern: str - """ + def __init__(self, pattern: bytes = b"?GenTime"): BaseWSGIFilter.__init__(self) assert isinstance(pattern, bytes) self.pattern = pattern self.start = time.time() - def filter_data(self, data): - """BaseWSGIFilter interface - @type data: bytes - @rtype: bytes - """ + def filter_data(self, data: bytes) -> bytes: + """BaseWSGIFilter interface""" if data == self.pattern: return b"%8.3g" % (time.time() - self.start) return data @@ -331,30 +330,19 @@ class EncodeWSGIFilter(BaseWSGIFilter): whereas wsgi mandates the use of bytes. """ @classmethod - def creator(cls, charset): + def creator(cls, charset: str) -> typing.Callable[[], "EncodeWSGIFilter"]: """Returns a function creating L{EncodeWSGIFilter}s with a given charset. - @type charset: str """ return lambda:cls(charset) - def __init__(self, charset="utf-8"): - """ - @type charset: str - """ + def __init__(self, charset: str = "utf-8"): BaseWSGIFilter.__init__(self) self.charset = charset - def filter_data(self, data): - """BaseWSGIFilter interface - @type data: str - @rtype: bytes - """ + def filter_data(self, data: str) -> bytes: + """BaseWSGIFilter interface""" return data.encode(self.charset) - def filter_header(self, header, value): - """BaseWSGIFilter interface - @type header: str - @type value: str - @rtype: (str, str) - """ + def filter_header(self, header: str, value: str) -> typing.Tuple[str, str]: + """BaseWSGIFilter interface""" if header.lower() != "content-type": return (header, value) return (header, "%s; charset=%s" % (value, self.charset)) @@ -362,17 +350,20 @@ class EncodeWSGIFilter(BaseWSGIFilter): __all__.append("GzipWSGIFilter") class GzipWSGIFilter(BaseWSGIFilter): """Compresses content using gzip.""" + gzip: typing.Optional[gzip.GzipFile] + sio: typing.Optional[io.BytesIO] + @classmethod - def creator(cls, flush=True): + def creator( + cls, flush: bool = True + ) -> typing.Callable[[], "GzipWSGIFilter"]: """ Returns a function creating L{GzipWSGIFilter}s. - @type flush: bool @param flush: whether or not the filter should always flush the buffer """ return lambda:cls(flush) - def __init__(self, flush=True): + def __init__(self, flush: bool = True): """ - @type flush: bool @param flush: whether or not the filter should always flush the buffer """ BaseWSGIFilter.__init__(self) @@ -380,10 +371,8 @@ class GzipWSGIFilter(BaseWSGIFilter): self.compress = False self.sio = None self.gzip = None - def filter_environ(self, environ): - """BaseWSGIFilter interface - @type environ: {str: str} - """ + def filter_environ(self, environ: Environ) -> Environ: + """BaseWSGIFilter interface""" assert isinstance(environ, dict) if "HTTP_ACCEPT_ENCODING" in environ: acceptenc = environ["HTTP_ACCEPT_ENCODING"].split(',') @@ -392,28 +381,25 @@ class GzipWSGIFilter(BaseWSGIFilter): self.sio = io.BytesIO() self.gzip = gzip.GzipFile(fileobj=self.sio, mode="w") return environ - def filter_header(self, headername, headervalue): - """ BaseWSGIFilter interface - @type headername: str - @type headervalue: str - @rtype: (str, str) or None - """ + def filter_header( + self, headername: str, headervalue: str + ) -> typing.Optional[typing.Tuple[str, str]]: + """BaseWSGIFilter interface""" if self.compress: if headername.lower() == "content-length": return None return (headername, headervalue) - def filter_headers(self, headers): - """BaseWSGIFilter interface - @type headers: [(str, str)] - @rtype: [(str, str)] - """ + def filter_headers(self, headers: HeaderList) -> HeaderList: + """BaseWSGIFilter interface""" assert isinstance(headers, list) if self.compress: headers.append(("Content-encoding", "gzip")) return headers - def filter_data(self, data): + def filter_data(self, data: bytes) -> bytes: if not self.compress: return data + assert self.gzip is not None + assert self.sio is not None self.gzip.write(data) if self.flush: self.gzip.flush() @@ -421,9 +407,11 @@ class GzipWSGIFilter(BaseWSGIFilter): self.sio.truncate(0) self.sio.seek(0) return data - def append_data(self): + def append_data(self) -> typing.List[bytes]: if not self.compress: return [] + assert self.gzip is not None + assert self.sio is not None self.gzip.close() data = self.sio.getvalue() return [data] @@ -435,30 +423,28 @@ class ReusableWSGIInputFilter(BaseWSGIFilter): C{BytesIO} instance which provides a C{seek} method. """ @classmethod - def creator(cls, maxrequestsize): + def creator( + cls, maxrequestsize: int + ) -> typing.Callable[[], "ReusableWSGIInputFilter"]: """ Returns a function creating L{ReusableWSGIInputFilter}s with desired maxrequestsize being set. If there is more data than maxrequestsize is available in C{wsgi.input} the rest will be ignored. (It is up to the adapter to eat this data.) - @type maxrequestsize: int @param maxrequestsize: is the maximum number of bytes to store in the C{BytesIO} """ return lambda:cls(maxrequestsize) - def __init__(self, maxrequestsize=65536): + def __init__(self, maxrequestsize: int = 65536): """ReusableWSGIInputFilters constructor. - @type maxrequestsize: int @param maxrequestsize: is the maximum number of bytes to store in the C{BytesIO}, see L{creator} """ BaseWSGIFilter.__init__(self) self.maxrequestsize = maxrequestsize - def filter_environ(self, environ): - """BaseWSGIFilter interface - @type environ: {str: str} - """ + def filter_environ(self, environ: Environ) -> Environ: + """BaseWSGIFilter interface""" if isinstance(environ["wsgi.input"], io.BytesIO): return environ # nothing to be done diff --git a/wsgitools/internal.py b/wsgitools/internal.py index 9bf7ded..86a9d5a 100644 --- a/wsgitools/internal.py +++ b/wsgitools/internal.py @@ -1,11 +1,42 @@ -def bytes2str(bstr): +import sys +import typing + +def bytes2str(bstr: bytes) -> str: assert isinstance(bstr, bytes) return bstr.decode("iso-8859-1") # always successful -def str2bytes(sstr): +def str2bytes(sstr: str) -> bytes: assert isinstance(sstr, str) return sstr.encode("iso-8859-1") # might fail, but spec says it doesn't -def textopen(filename, mode): +def textopen(filename: str, mode: str) -> typing.TextIO: # We use the same encoding as for all wsgi strings here. return open(filename, mode, encoding="iso-8859-1") + +Environ = typing.Dict[str, typing.Any] + +HeaderList = typing.List[typing.Tuple[str, str]] + +if sys.version_info >= (3, 11): + OptExcInfo = typing.Optional[ + typing.Tuple[type[BaseException], BaseException, typing.Any] + ] +else: + OptExcInfo = typing.Optional[ + typing.Tuple[typing.Any, BaseException, typing.Any] + ] + +WriteCallback = typing.Callable[[bytes], None] + +if sys.version_info >= (3, 11): + class StartResponse(typing.Protocol): + def __call__( + self, status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: + ... +else: + StartResponse = typing.Callable[ + [str, HeaderList, OptExcInfo], WriteCallback + ] + +WsgiApp = typing.Callable[[Environ, StartResponse], typing.Iterable[bytes]] diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py index ef9fe84..8577384 100644 --- a/wsgitools/middlewares.py +++ b/wsgitools/middlewares.py @@ -6,27 +6,28 @@ import sys import cgitb import collections import io +import typing from wsgitools.internal import bytes2str, str2bytes from wsgitools.filters import CloseableList, CloseableIterator from wsgitools.authentication import AuthenticationRequired, \ ProtocolViolation, AuthenticationMiddleware +from wsgitools.internal import ( + Environ, HeaderList, OptExcInfo, StartResponse, WriteCallback, WsgiApp +) __all__.append("SubdirMiddleware") class SubdirMiddleware: """Middleware choosing wsgi applications based on a dict.""" - def __init__(self, default, mapping={}): - """ - @type default: wsgi app - @type mapping: {str: wsgi app} - """ + def __init__( + self, default: WsgiApp, mapping: typing.Dict[str, WsgiApp] = {} + ): self.default = default self.mapping = mapping - def __call__(self, environ, start_response): - """wsgi interface - @type environ: {str: str} - @rtype: gen([bytes]) - """ + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: + """wsgi interface""" assert isinstance(environ, dict) app = None script = environ["PATH_INFO"] @@ -51,22 +52,24 @@ class NoWriteCallableMiddleware: C{start_response} function to a wsgi application that doesn't need one by writing the data to a C{BytesIO} and then making it be the first result element.""" - def __init__(self, app): + def __init__(self, app: WsgiApp): """Wraps wsgi application app.""" self.app = app - def __call__(self, environ, start_response): - """wsgi interface - @type environ: {str, str} - @rtype: gen([bytes]) - """ + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: + """wsgi interface""" assert isinstance(environ, dict) - todo = [None] + todo: typing.Optional[typing.Tuple[str, HeaderList]] = None sio = io.BytesIO() gotiterdata = False - def write_calleable(data): + def write_calleable(data: bytes) -> None: assert not gotiterdata sio.write(data) - def modified_start_response(status, headers, exc_info=None): + def modified_start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: + nonlocal todo try: if sio.tell() > 0 or gotiterdata: assert exc_info is not None @@ -75,7 +78,7 @@ class NoWriteCallableMiddleware: exc_info = None assert isinstance(status, str) assert isinstance(headers, list) - todo[0] = (status, headers) + todo = (status, headers) return write_calleable ret = self.app(environ, modified_start_response) @@ -96,8 +99,8 @@ class NoWriteCallableMiddleware: else: gotiterdata = True - assert todo[0] is not None - status, headers = todo[0] + assert todo is not None + status, headers = todo data = sio.getvalue() if isinstance(ret, list): @@ -117,27 +120,36 @@ class ContentLengthMiddleware: """Guesses the content length header if possible. @note: The application used must not use the C{write} callable returned by C{start_response}.""" - def __init__(self, app, maxstore=0): + maxstore: typing.Union[float, int] + def __init__( + self, app: WsgiApp, maxstore: typing.Union[int, typing.Tuple[()]] = 0 + ): """Wraps wsgi application app. If the application returns a list, the total length of strings is available and the content length header is set unless there already is one. For an iterator data is accumulated up to a total of maxstore bytes (where maxstore=() means infinity). If the iterator is exhaused within maxstore bytes a content length header is added unless already present. - @type maxstore: int or () @note: that setting maxstore to a value other than 0 will violate the wsgi standard """ self.app = app if maxstore == (): - maxstore = float("inf") - self.maxstore = maxstore - def __call__(self, environ, start_response): + self.maxstore = float("inf") + else: + assert isinstance(maxstore, int) + self.maxstore = maxstore + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: """wsgi interface""" assert isinstance(environ, dict) - todo = [] + todo: typing.Optional[typing.Tuple[str, HeaderList]] = None gotdata = False - def modified_start_response(status, headers, exc_info=None): + def modified_start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: + nonlocal todo try: if gotdata: assert exc_info is not None @@ -146,8 +158,8 @@ class ContentLengthMiddleware: exc_info = None assert isinstance(status, str) assert isinstance(headers, list) - todo[:] = ((status, headers),) - def raise_not_imp(*args): + todo = (status, headers) + def raise_not_imp(_: bytes) -> None: raise NotImplementedError return raise_not_imp @@ -156,8 +168,8 @@ class ContentLengthMiddleware: if isinstance(ret, list): gotdata = True - assert bool(todo) - status, headers = todo[0] + assert todo is not None + status, headers = todo if all(k.lower() != "content-length" for k, _ in headers): length = sum(map(len, ret)) headers.append(("Content-Length", str(length))) @@ -173,8 +185,8 @@ class ContentLengthMiddleware: except StopIteration: stopped = True gotdata = True - assert bool(todo) - status, headers = todo[0] + assert todo is not None + status, headers = todo data = CloseableList(getattr(ret, "close", None)) if first: data.append(first) @@ -197,12 +209,12 @@ class ContentLengthMiddleware: return CloseableIterator(getattr(ret, "close", None), data, ret) -def storable(environ): +def storable(environ: Environ) -> bool: if environ["REQUEST_METHOD"] != "GET": return False return True -def cacheable(environ): +def cacheable(environ: Environ) -> bool: if environ.get("HTTP_CACHE_CONTROL", "") == "max-age=0": return False return True @@ -213,16 +225,24 @@ class CachingMiddleware: C{QUERY_STRING}.""" class CachedRequest: - def __init__(self, timestamp): + def __init__(self, timestamp: float): self.timestamp = timestamp self.status = "" - self.headers = [] - self.body = [] - - def __init__(self, app, maxage=60, storable=storable, cacheable=cacheable): + self.headers: HeaderList = [] + self.body: typing.List[bytes] = [] + + cache: typing.Dict[str, CachedRequest] + lastcached: typing.Deque[typing.Tuple[str, float]] + + def __init__( + self, + app: WsgiApp, + maxage: int = 60, + storable: typing.Callable[[Environ], bool] = storable, + cacheable: typing.Callable[[Environ], bool] = cacheable, + ): """ @param app: is a wsgi application to be cached. - @type maxage: int @param maxage: is the number of seconds a reponse may be cached. @param storable: is a predicate that determines whether the response may be cached at all based on the C{environ} dict. @@ -235,13 +255,20 @@ class CachingMiddleware: self.cache = {} self.lastcached = collections.deque() - def insert_cache(self, key, obj, now=None): + def insert_cache( + self, + key: str, + obj: CachedRequest, + now: typing.Optional[float] = None, + ) -> None: if now is None: now = time.time() self.cache[key] = obj self.lastcached.append((key, now)) - def prune_cache(self, maxclean=16, now=None): + def prune_cache( + self, maxclean: int = 16, now: typing.Optional[float] = None + ) -> None: if now is None: now = time.time() old = now - self.maxage @@ -258,10 +285,10 @@ class CachingMiddleware: if obj.timestamp <= old: del self.cache[key] - def __call__(self, environ, start_response): - """wsgi interface - @type environ: {str: str} - """ + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: + """wsgi interface""" assert isinstance(environ, dict) now = time.time() self.prune_cache(now=now) @@ -279,7 +306,9 @@ class CachingMiddleware: else: del self.cache[path] cache_object = self.CachedRequest(now) - def modified_start_respesponse(status, headers, exc_info=None): + def modified_start_respesponse( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: try: if cache_object.body: assert exc_info is not None @@ -291,7 +320,7 @@ class CachingMiddleware: cache_object.status = status cache_object.headers = headers write = start_response(status, list(headers)) - def modified_write(data): + def modified_write(data: bytes) -> None: cache_object.body.append(data) write(data) return modified_write @@ -303,7 +332,7 @@ class CachingMiddleware: cache_object.body.extend(ret) self.insert_cache(path, cache_object, now) return ret - def pass_through(): + def pass_through() -> typing.Iterator[bytes]: for data in ret: cache_object.body.append(data) yield data @@ -313,18 +342,13 @@ class CachingMiddleware: __all__.append("DictAuthChecker") class DictAuthChecker: """Verifies usernames and passwords by looking them up in a dict.""" - def __init__(self, users): + def __init__(self, users: typing.Dict[str, str]): """ - @type users: {str: str} @param users: is a dict mapping usernames to password.""" self.users = users - def __call__(self, username, password, environ): + def __call__(self, username: str, password: str, environ: Environ) -> bool: """check_function interface taking username and password and resulting in a bool. - @type username: str - @type password: str - @type environ: {str: object} - @rtype: bool """ return username in self.users and self.users[username] == password @@ -334,7 +358,13 @@ class BasicAuthMiddleware(AuthenticationMiddleware): the warpped application the environ dictionary is augmented by a REMOTE_USER key.""" authorization_method = "basic" - def __init__(self, app, check_function, realm='www', app401=None): + def __init__( + self, + app: WsgiApp, + check_function: typing.Callable[[str, str, Environ], bool], + realm: str = 'www', + app401: typing.Optional[WsgiApp] = None, + ): """ @param app: is a WSGI application. @param check_function: is a function taking three arguments username, @@ -342,7 +372,6 @@ class BasicAuthMiddleware(AuthenticationMiddleware): request may is allowed. The older interface of taking only the first two arguments is still supported via catching a C{TypeError}. - @type realm: str @param app401: is an optional WSGI application to be used for error messages """ @@ -351,7 +380,9 @@ class BasicAuthMiddleware(AuthenticationMiddleware): self.realm = realm self.app401 = app401 - def authenticate(self, auth, environ): + def authenticate( + self, auth: str, environ: Environ + ) -> typing.Dict[str, str]: assert isinstance(auth, str) assert isinstance(environ, dict) authb = str2bytes(auth) @@ -372,10 +403,17 @@ class BasicAuthMiddleware(AuthenticationMiddleware): return dict(user=username) raise AuthenticationRequired("credentials not valid") - def www_authenticate(self, exception): + def www_authenticate( + self, exception: AuthenticationRequired + ) -> typing.Tuple[str, str]: return ("WWW-Authenticate", 'Basic realm="%s"' % self.realm) - def authorization_required(self, environ, start_response, exception): + def authorization_required( + self, + environ: Environ, + start_response: StartResponse, + exception: AuthenticationRequired, + ) -> typing.Iterable[bytes]: if self.app401 is not None: return self.app401(environ, start_response) return AuthenticationMiddleware.authorization_required( @@ -385,13 +423,13 @@ __all__.append("TracebackMiddleware") class TracebackMiddleware: """In case the application throws an exception this middleware will show an html-formatted traceback using C{cgitb}.""" - def __init__(self, app): + def __init__(self, app: WsgiApp): """app is the wsgi application to proxy.""" self.app = app - def __call__(self, environ, start_response): - """wsgi interface - @type environ: {str: str} - """ + def __call__( + self, environ: Environ, start_response: StartResponse + ) -> typing.Iterable[bytes]: + """wsgi interface""" try: assert isinstance(environ, dict) ret = self.app(environ, start_response) diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py index e2a68c2..677e3b5 100644 --- a/wsgitools/scgi/__init__.py +++ b/wsgitools/scgi/__init__.py @@ -1,5 +1,8 @@ __all__ = [] +import socket +import typing + try: import sendfile except ImportError: @@ -7,6 +10,8 @@ except ImportError: else: have_sendfile = True +from wsgitools.internal import Environ + class FileWrapper: """ @ivar offset: Initially 0. Becomes -1 when reading using next and @@ -14,18 +19,20 @@ class FileWrapper: counts the number of bytes sent. It also ensures that next and transfer are never mixed. """ - def __init__(self, filelike, blksize=8192): + def __init__(self, filelike, blksize: int = 8192): self.filelike = filelike self.blksize = blksize self.offset = 0 if hasattr(filelike, "close"): self.close = filelike.close - def can_transfer(self): + def can_transfer(self) -> bool: return have_sendfile and hasattr(self.filelike, "fileno") and \ self.offset >= 0 - def transfer(self, sock, blksize=None): + def transfer( + self, sock: socket.socket, blksize: typing.Optional[int] = None + ) -> int: assert self.offset >= 0 if blksize is None: blksize = self.blksize @@ -42,10 +49,10 @@ class FileWrapper: self.offset += sent return sent - def __iter__(self): + def __iter__(self) -> typing.Iterator[bytes]: return self - def __next__(self): + def __next__(self) -> bytes: assert self.offset <= 0 self.offset = -1 data = self.filelike.read(self.blksize) @@ -53,8 +60,12 @@ class FileWrapper: return data raise StopIteration -def _convert_environ(environ, multithread=False, multiprocess=False, - run_once=False): +def _convert_environ( + environ: Environ, + multithread: bool = False, + multiprocess: bool = False, + run_once: bool = False, +) -> None: environ.update({ "wsgi.version": (1, 0), "wsgi.url_scheme": "http", diff --git a/wsgitools/scgi/asynchronous.py b/wsgitools/scgi/asynchronous.py index 61bbc6b..264e43e 100644 --- a/wsgitools/scgi/asynchronous.py +++ b/wsgitools/scgi/asynchronous.py @@ -5,10 +5,24 @@ import io import socket import sys import errno +import typing -from wsgitools.internal import bytes2str, str2bytes +from wsgitools.internal import ( + bytes2str, + Environ, + HeaderList, + OptExcInfo, + str2bytes, + WriteCallback, + WsgiApp, +) from wsgitools.scgi import _convert_environ, FileWrapper +if sys.version_info >= (3, 8): + LiteralTrue = typing.Literal[True] +else: + LiteralTrue = bool + class SCGIConnection(asyncore.dispatcher): """SCGI connection class used by L{SCGIServer}.""" # connection states @@ -19,11 +33,27 @@ class SCGIConnection(asyncore.dispatcher): RESP = 3*4 | 2 # sending response, end state RESPH = 4*4 | 2 # buffered response headers, sending headers only, to TRANS TRANS = 5*4 | 2 # transferring using FileWrapper, end state - def __init__(self, server, connection, maxrequestsize=65536, - maxpostsize=8<<20, blocksize=4096, config={}): + + outheaders: typing.Union[ + typing.Tuple[()], # unset + typing.Tuple[str, HeaderList], # headers + LiteralTrue # sent + ] + wsgihandler: typing.Optional[typing.Iterable[bytes]] + wsgiiterator: typing.Optional[typing.Iterator[bytes]] + + def __init__( + self, + server: "SCGIServer", + connection: socket.socket, + maxrequestsize: int = 65536, + maxpostsize: int = 8<<20, + blocksize: int = 4096, + config: Environ = {}, + ): asyncore.dispatcher.__init__(self, connection) - self.server = server # WSGISCGIServer instance + self.server = server self.maxrequestsize = maxrequestsize self.maxpostsize = maxpostsize self.blocksize = blocksize @@ -35,10 +65,9 @@ class SCGIConnection(asyncore.dispatcher): self.wsgihandler = None # wsgi application self.wsgiiterator = None # wsgi application iterator self.outheaders = () # headers to be sent - # () -> unset, (..,..) -> set, True -> sent self.body = io.BytesIO() # request body - def _try_send_headers(self): + def _try_send_headers(self) -> None: if self.outheaders != True: assert not self.outbuff status, headers = self.outheaders @@ -47,7 +76,7 @@ class SCGIConnection(asyncore.dispatcher): self.outbuff = str2bytes(headdata) self.outheaders = True - def _wsgi_write(self, data): + def _wsgi_write(self, data: bytes) -> None: assert self.state >= SCGIConnection.RESP assert self.state < SCGIConnection.TRANS assert isinstance(data, bytes) @@ -55,15 +84,15 @@ class SCGIConnection(asyncore.dispatcher): self._try_send_headers() self.outbuff += data - def readable(self): + def readable(self) -> bool: """C{asyncore} interface""" return self.state & 1 == 1 - def writable(self): + def writable(self) -> bool: """C{asyncore} interface""" return self.state & 2 == 2 - def handle_read(self): + def handle_read(self) -> None: """C{asyncore} interface""" data = self.recv(self.blocksize) self.inbuff += data @@ -122,6 +151,7 @@ class SCGIConnection(asyncore.dispatcher): self.environ["wsgi.errors"] = self.server.error self.wsgihandler = self.server.wsgiapp(self.environ, self.start_response) + assert self.wsgihandler is not None if isinstance(self.wsgihandler, FileWrapper) and \ self.wsgihandler.can_transfer(): self._try_send_headers() @@ -134,7 +164,12 @@ class SCGIConnection(asyncore.dispatcher): self.reqlen -= len(self.inbuff) self.inbuff = b"" - def start_response(self, status, headers, exc_info=None): + def start_response( + self, + status: str, + headers: HeaderList, + exc_info: OptExcInfo = None, + ) -> WriteCallback: assert isinstance(status, str) assert isinstance(headers, list) if exc_info: @@ -147,7 +182,7 @@ class SCGIConnection(asyncore.dispatcher): self.outheaders = (status, headers) return self._wsgi_write - def send_buff(self): + def send_buff(self) -> None: try: sentbytes = self.send(self.outbuff[:self.blocksize]) except socket.error: @@ -155,11 +190,12 @@ class SCGIConnection(asyncore.dispatcher): else: self.outbuff = self.outbuff[sentbytes:] - def handle_write(self): + def handle_write(self) -> None: """C{asyncore} interface""" if self.state == SCGIConnection.RESP: if len(self.outbuff) < self.blocksize: self._try_send_headers() + assert self.wsgiiterator is not None for data in self.wsgiiterator: assert isinstance(data, bytes) if data: @@ -171,23 +207,26 @@ class SCGIConnection(asyncore.dispatcher): self.send_buff() elif self.state == SCGIConnection.RESPH: assert len(self.outbuff) > 0 + assert isinstance(self.wsgihandler, FileWrapper) self.send_buff() if not self.outbuff: self.state = SCGIConnection.TRANS else: assert self.state == SCGIConnection.TRANS + assert isinstance(self.wsgihandler, FileWrapper) assert self.wsgihandler.can_transfer() sent = self.wsgihandler.transfer(self.socket, self.blocksize) if sent <= 0: self.close() - def close(self): + def close(self) -> None: # None doesn't have a close attribute if hasattr(self.wsgihandler, "close"): + assert self.wsgihandler is not None self.wsgihandler.close() asyncore.dispatcher.close(self) - def handle_close(self): + def handle_close(self) -> None: """C{asyncore} interface""" self.close() @@ -195,33 +234,36 @@ __all__.append("SCGIServer") class SCGIServer(asyncore.dispatcher): """SCGI Server for WSGI applications. It does not use multiple processes or multiple threads.""" - def __init__(self, wsgiapp, port, interface="localhost", error=sys.stderr, - maxrequestsize=None, maxpostsize=None, blocksize=None, - config={}, reusesocket=None): + + def __init__( + self, + wsgiapp: WsgiApp, + port: int, + interface: str = "localhost", + error: typing.TextIO = sys.stderr, + maxrequestsize: typing.Optional[int] = None, + maxpostsize: typing.Optional[int] = None, + blocksize: typing.Optional[int] = None, + config: Environ = {}, + reusesocket: typing.Optional[socket.socket] = None + ): """ @param wsgiapp: is the wsgi application to be run. - @type port: int @param port: is an int representing the TCP port number to be used. - @type interface: str @param interface: is a string specifying the network interface to bind which defaults to C{"localhost"} making the server inaccessible over network. @param error: is a file-like object being passed as C{wsgi.error} in the environ parameter defaulting to stderr. - @type maxrequestsize: int @param maxrequestsize: limit the size of request blocks in scgi connections. Connections are dropped when this limit is hit. - @type maxpostsize: int @param maxpostsize: limit the size of post bodies that may be processed by this instance. Connections are dropped when this limit is hit. - @type blocksize: int @param blocksize: is amount of data to read or write from or to the network at once - @type config: {} @param config: the environ dictionary is updated using these values for each request. - @type reusesocket: None or socket.socket @param reusesocket: If a socket is passed, do not create a socket. Instead use given socket as listen socket. The passed socket must be set up for accepting tcp connections (i.e. C{AF_INET}, @@ -234,7 +276,7 @@ class SCGIServer(asyncore.dispatcher): self.wsgiapp = wsgiapp self.error = error - self.conf = {} + self.conf: Environ = {} if maxrequestsize is not None: self.conf["maxrequestsize"] = maxrequestsize if maxpostsize is not None: @@ -251,7 +293,7 @@ class SCGIServer(asyncore.dispatcher): else: self.accepting = True - def handle_accept(self): + def handle_accept(self) -> None: """asyncore interface""" try: ret = self.accept() @@ -264,7 +306,7 @@ class SCGIServer(asyncore.dispatcher): conn, _ = ret SCGIConnection(self, conn, **self.conf) - def run(self): + def run(self) -> None: """Runs the server. It will not return and you can invoke C{asyncore.loop()} instead achieving the same effect.""" asyncore.loop() diff --git a/wsgitools/scgi/forkpool.py b/wsgitools/scgi/forkpool.py index df8a92f..f864a6a 100644 --- a/wsgitools/scgi/forkpool.py +++ b/wsgitools/scgi/forkpool.py @@ -12,21 +12,24 @@ import select import sys import errno import signal +import typing -from wsgitools.internal import bytes2str, str2bytes +from wsgitools.internal import ( + bytes2str, HeaderList, OptExcInfo, str2bytes, WriteCallback, WsgiApp +) from wsgitools.scgi import _convert_environ, FileWrapper __all__ = [] class SocketFileWrapper: """Wraps a socket to a wsgi-compliant file-like object.""" - def __init__(self, sock, toread): + def __init__(self, sock: socket.socket, toread: int): """@param sock: is a C{socket.socket()}""" self.sock = sock self.buff = b"" self.toread = toread - def _recv(self, size=4096): + def _recv(self, size: int = 4096) -> bytes: """ internal method for receiving and counting incoming data @raises socket.error: @@ -44,7 +47,7 @@ class SocketFileWrapper: self.toread -= len(data) return data - def close(self): + def close(self) -> None: """Does not close the socket, because it might still be needed. It reads all data that should have been read as given by C{CONTENT_LENGTH}. """ @@ -55,7 +58,7 @@ class SocketFileWrapper: except socket.error: pass - def read(self, size=None): + def read(self, size: typing.Optional[int] = None) -> bytes: """ see pep333 @raises socket.error: @@ -92,7 +95,7 @@ class SocketFileWrapper: ret, self.buff = self.buff, b"" return ret - def readline(self, size=None): + def readline(self, size: typing.Optional[int] = None) -> bytes: """ see pep333 @raises socket.error: @@ -118,7 +121,7 @@ class SocketFileWrapper: return ret self.buff += data - def readlines(self): + def readlines(self) -> typing.Iterator[bytes]: """ see pep333 @raises socket.error: @@ -127,10 +130,10 @@ class SocketFileWrapper: while data: yield data data = self.readline() - def __iter__(self): + def __iter__(self) -> typing.Iterator[bytes]: """see pep333""" return self - def __next__(self): + def __next__(self) -> bytes: """ see pep333 @raises socket.error: @@ -139,9 +142,9 @@ class SocketFileWrapper: if not data: raise StopIteration return data - def flush(self): + def flush(self) -> None: """see pep333""" - def write(self, data): + def write(self, data: bytes) -> None: """see pep333""" assert isinstance(data, bytes) try: @@ -149,7 +152,7 @@ class SocketFileWrapper: except socket.error: # ignore all socket errors: there is no way to report return - def writelines(self, lines): + def writelines(self, lines: typing.List[bytes]) -> None: """see pep333""" for line in lines: self.write(line) @@ -161,49 +164,51 @@ class SCGIServer: class WorkerState: """state: 0 means idle and 1 means working. These values are also sent as strings '0' and '1' over the socket.""" - def __init__(self, pid, sock, state): - """ - @type pid: int - @type state: int - """ + def __init__(self, pid: int, sock: socket.socket, state: int): self.pid = pid self.sock = sock self.state = state - def __init__(self, wsgiapp, port, interface="localhost", error=sys.stderr, - minworkers=2, maxworkers=32, maxrequests=1000, config={}, - reusesocket=None, cpulimit=None, timelimit=None): + server: typing.Optional[socket.socket] + workers: typing.Dict[int, WorkerState] + sigpipe: typing.Optional[typing.Tuple[socket.socket, socket.socket]] + + def __init__( + self, + wsgiapp: WsgiApp, + port: int, + interface: str = "localhost", + error: typing.TextIO = sys.stderr, + minworkers: int = 2, + maxworkers: int = 32, + maxrequests: int = 1000, + config: typing.Dict[typing.Any, typing.Any] = {}, + reusesocket: typing.Optional[socket.socket] = None, + cpulimit: typing.Optional[typing.Tuple[int, int]] = None, + timelimit: typing.Optional[int] = None, + ): """ @param wsgiapp: is the WSGI application to be run. - @type port: int @param port: is the tcp port to listen on - @type interface: str @param interface: is the interface to bind to (default: C{"localhost"}) @param error: is a file-like object beeing passed as C{wsgi.errors} in environ - @type minworkers: int @param minworkers: is the number of worker processes to spawn - @type maxworkers: int @param maxworkers: is the maximum number of workers that can be spawned on demand - @type maxrequests: int @param maxrequests: is the number of requests a worker processes before dying - @type config: {} @param config: the environ dictionary is updated using these values for each request. - @type reusesocket: None or socket.socket @param reusesocket: If a socket is passed, do not create a socket. Instead use given socket as listen socket. The passed socket must be set up for accepting tcp connections (i.e. C{AF_INET}, C{SOCK_STREAM} with bind and listen called). - @type cpulimit: (int, int) @param cpulimit: a pair of soft and hard cpu time limit in seconds. This limit is installed for each worker using RLIMIT_CPU if resource limits are available to this platform. After reaching the soft limit workers will continue to process the current request and then cleanly terminate. - @type timelimit: int @param timelimit: The maximum number of wall clock seconds processing a request should take. If this is specified, an alarm timer is installed and the default action is to kill the worker. @@ -229,7 +234,7 @@ class SCGIServer: self.running = False self.ischild = False - def enable_sighandler(self, sig=signal.SIGTERM): + def enable_sighandler(self, sig: int = signal.SIGTERM) -> "SCGIServer": """ Changes the signal handler for the given signal to terminate the run() loop. @@ -239,7 +244,7 @@ class SCGIServer: signal.signal(sig, self.shutdownhandler) return self - def run(self): + def run(self) -> None: """ Serve the wsgi application. """ @@ -296,7 +301,7 @@ class SCGIServer: self.sigpipe = None self.killworkers() - def killworkers(self, sig=signal.SIGTERM): + def killworkers(self, sig: int = signal.SIGTERM) -> None: """ Kills all worker children. @param sig: is the signal used to kill the children @@ -307,7 +312,9 @@ class SCGIServer: os.kill(state.pid, sig) # TODO: handle working children with a timeout - def shutdownhandler(self, sig=None, stackframe=None): + def shutdownhandler( + self, sig: typing.Optional[int] = None, stackframe=None + ) -> None: """ Signal handler function for stopping the run() loop. It works by setting a variable that run() evaluates in each loop. As a signal @@ -319,10 +326,13 @@ class SCGIServer: if self.ischild: sys.exit() elif self.running: + assert self.sigpipe is not None self.running = False self.sigpipe[1].send(b' ') - def sigxcpuhandler(self, sig=None, stackframe=None): + def sigxcpuhandler( + self, sig: typing.Optional[int] = None, stackframe=None + ) -> None: """ Signal hanlder function for the SIGXCUP signal. It is sent to a worker when the soft RLIMIT_CPU is crossed. @@ -331,7 +341,7 @@ class SCGIServer: """ self.cpulimit = True - def spawnworker(self): + def spawnworker(self) -> None: """ internal! spawns a single worker """ @@ -366,11 +376,12 @@ class SCGIServer: else: raise RuntimeError("fork failed") - def work(self, worksock): + def work(self, worksock: socket.socket) -> None: """ internal! serves maxrequests times @raises socket.error: """ + assert self.server is not None for _ in range(self.maxrequests): (con, addr) = self.server.accept() # we cannot handle socket.errors here. @@ -384,7 +395,7 @@ class SCGIServer: if self.cpulimit: break - def process(self, con): + def process(self, con: socket.socket) -> None: """ internal! processes a single request on the connection con. """ @@ -439,9 +450,9 @@ class SCGIServer: # 0 -> False: set but unsent # 0 -> True: sent # 1 -> bytes of the complete header - response_head = [None, None] + response_head: typing.List[typing.Any] = [None, None] - def sendheaders(): + def sendheaders() -> None: assert response_head[0] is not None # headers set if response_head[0] != True: response_head[0] = True @@ -450,14 +461,16 @@ class SCGIServer: except socket.error: pass - def dumbsend(data): + def dumbsend(data: bytes) -> None: sendheaders() try: con.sendall(data) except socket.error: pass - def start_response(status, headers, exc_info=None): + def start_response( + status: str, headers: HeaderList, exc_info: OptExcInfo = None + ) -> WriteCallback: if exc_info and response_head[0]: try: raise exc_info[1].with_traceback(exc_info[2]) -- cgit v1.2.3