diff options
Diffstat (limited to 'wsgitools')
-rw-r--r-- | wsgitools/adapters.py | 6 | ||||
-rw-r--r-- | wsgitools/applications.py | 14 | ||||
-rw-r--r-- | wsgitools/authentication.py | 6 | ||||
-rw-r--r-- | wsgitools/digest.py | 95 | ||||
-rw-r--r-- | wsgitools/filters.py | 71 | ||||
-rw-r--r-- | wsgitools/internal.py | 19 | ||||
-rw-r--r-- | wsgitools/middlewares.py | 38 | ||||
-rw-r--r-- | wsgitools/scgi/__init__.py | 4 | ||||
-rw-r--r-- | wsgitools/scgi/asynchronous.py | 50 | ||||
-rw-r--r-- | wsgitools/scgi/forkpool.py | 90 |
10 files changed, 202 insertions, 191 deletions
diff --git a/wsgitools/adapters.py b/wsgitools/adapters.py index 2c7615a..050c00a 100644 --- a/wsgitools/adapters.py +++ b/wsgitools/adapters.py @@ -9,12 +9,6 @@ __all__ = [] from wsgitools.filters import CloseableIterator, CloseableList -try: - next -except NameError: - def next(it): - return it.next() - __all__.append("WSGI2to1Adapter") class WSGI2to1Adapter(object): """Adapts an application with an interface that might somewhen be known as diff --git a/wsgitools/applications.py b/wsgitools/applications.py index 6b6601b..df304db 100644 --- a/wsgitools/applications.py +++ b/wsgitools/applications.py @@ -21,7 +21,7 @@ class StaticContent(object): @type headers: list @param headers: is a list of C{(header, value)} pairs being delivered as HTTP headers - @type content: basestring + @type content: bytes @param content: contains the data to be delivered to the client. It is either a string or some kind of iterable yielding strings. @type anymethod: boolean @@ -30,12 +30,12 @@ class StaticContent(object): """ assert isinstance(status, str) assert isinstance(headers, list) - assert isinstance(content, basestring) or hasattr(content, "__iter__") + assert isinstance(content, bytes) or hasattr(content, "__iter__") self.status = status self.headers = headers self.anymethod = anymethod length = -1 - if isinstance(content, basestring): + if isinstance(content, bytes): self.content = [content] length = len(content) else: @@ -50,7 +50,7 @@ class StaticContent(object): assert isinstance(environ, dict) if environ["REQUEST_METHOD"].upper() not in ["GET", "HEAD"] and \ not self.anymethod: - resp = "Request method not implemented" + resp = b"Request method not implemented" start_response("501 Not Implemented", [("Content-length", str(len(resp)))]) return [resp] @@ -102,7 +102,7 @@ class StaticFile(object): assert isinstance(environ, dict) if environ["REQUEST_METHOD"].upper() not in ["GET", "HEAD"]: - resp = "Request method not implemented" + resp = b"Request method not implemented" start_response("501 Not Implemented", [("Content-length", str(len(resp)))]) return [resp] @@ -112,7 +112,7 @@ class StaticFile(object): try: if isinstance(self.filelike, basestring): # raises IOError - stream = file(self.filelike) + stream = open(self.filelike, "rb") size = os.path.getsize(self.filelike) else: stream = self.filelike @@ -121,7 +121,7 @@ class StaticFile(object): size = stream.tell() stream.seek(0) except IOError: - resp = "File not found" + resp = b"File not found" start_response("404 File not found", [("Content-length", str(len(resp)))]) return [resp] diff --git a/wsgitools/authentication.py b/wsgitools/authentication.py index c39c018..59747e0 100644 --- a/wsgitools/authentication.py +++ b/wsgitools/authentication.py @@ -64,7 +64,7 @@ class AuthenticationMiddleware(object): raise AuthenticationRequired( "authorization method not implemented: %r" % method) result = self.authenticate(rest, environ) - except AuthenticationRequired, exc: + except AuthenticationRequired as exc: return self.authorization_required(environ, start_response, exc) assert isinstance(result, dict) assert "user" in result @@ -97,8 +97,8 @@ class AuthenticationMiddleware(object): @param exception: reason for the authentication failure """ status = "401 Authorization required" - html = "<html><head><title>401 Authorization required</title></head>" \ - "<body><h1>401 Authorization required</h1></body></html>" + html = b"<html><head><title>401 Authorization required</title></head>" \ + b"<body><h1>401 Authorization required</h1></body></html>" headers = [("Content-Type", "text/html"), self.www_authenticate(exception), ("Content-Length", str(len(html)))] diff --git a/wsgitools/digest.py b/wsgitools/digest.py index 5ed05f8..2f49ff7 100644 --- a/wsgitools/digest.py +++ b/wsgitools/digest.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python2.5 """ This module contains an C{AuthDigestMiddleware} for authenticating HTTP requests using the method described in RFC2617. The credentials are to be @@ -15,32 +14,39 @@ database using C{DBAPI2NonceStore}. __all__ = [] import random -try: - from hashlib import md5 -except ImportError: - from md5 import md5 -import binascii import base64 +import hashlib import time import os +from wsgitools.internal import bytes2str, str2bytes, textopen from wsgitools.authentication import AuthenticationRequired, \ ProtocolViolation, AuthenticationMiddleware sysrand = random.SystemRandom() -def gen_rand_str(bytes=33): +def md5hex(data): + """ + @type data: str + @rtype: str + """ + return hashlib.md5(str2bytes(data)).hexdigest() + +def gen_rand_str(bytesentropy=33): """ Generates a string of random base64 characters. - @param bytes: is the number of random 8bit values to be used + @param bytesentropy: is the number of random 8bit values to be used + @rtype: str >>> gen_rand_str() != gen_rand_str() True """ - randnum = sysrand.getrandbits(bytes*8) - randstr = ("%%0%dX" % (2*bytes)) % randnum - randstr = binascii.unhexlify(randstr) - randstr = base64.encodestring(randstr).strip() + randnum = sysrand.getrandbits(bytesentropy*8) + randstr = ("%%0%dX" % (2*bytesentropy)) % randnum + randbytes = str2bytes(randstr) + randbytes = base64.b16decode(randbytes) + randbytes = base64.b64encode(randbytes) + randstr = bytes2str(randbytes) return randstr def parse_digest_response(data): @@ -121,6 +127,8 @@ def format_digest(mapping): assert isinstance(mapping, dict) result = [] for key, (value, needsquoting) in mapping.items(): + assert isinstance(key, str) + assert isinstance(value, str) if needsquoting: value = '"%s"' % value.replace('\\', '\\\\').replace('"', '\\"') else: @@ -173,8 +181,8 @@ class AbstractTokenGenerator(object): """ assert isinstance(username, str) assert isinstance(password, str) - token = md5("%s:%s:%s" % (username, self.realm, password)).hexdigest() - return token == self(username) + token = "%s:%s:%s" % (username, self.realm, password) + return md5hex(token) == self(username) __all__.append("AuthTokenGenerator") class AuthTokenGenerator(AbstractTokenGenerator): @@ -200,7 +208,7 @@ class AuthTokenGenerator(AbstractTokenGenerator): if password is None: return None a1 = "%s:%s:%s" % (username, self.realm, password) - return md5(a1).hexdigest() + return md5hex(a1) __all__.append("HtdigestTokenGenerator") class HtdigestTokenGenerator(AbstractTokenGenerator): @@ -231,18 +239,19 @@ class HtdigestTokenGenerator(AbstractTokenGenerator): """ assert isinstance(htdigestfile, str) self.users = {} - for line in file(htdigestfile): - parts = line.rstrip("\n").split(":") - if len(parts) != 3: - if ignoreparseerrors: + with textopen(htdigestfile, "r") as htdigest: + for line in htdigest: + parts = line.rstrip("\n").split(":") + if len(parts) != 3: + if ignoreparseerrors: + continue + raise ValueError("invalid number of colons in htdigest file") + user, realm, token = parts + if realm != self.realm: continue - raise ValueError("invalid number of colons in htdigest file") - user, realm, token = parts - if realm != self.realm: - continue - if user in self.users and not ignoreparseerrors: - raise ValueError("duplicate user in htdigest file") - self.users[user] = token + if user in self.users and not ignoreparseerrors: + raise ValueError("duplicate user in htdigest file") + self.users[user] = token def __call__(self, user, algo="md5"): assert algo.lower() in ["md5", "md5-sess"] @@ -259,7 +268,7 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator): # modifications. try: self.statcache = os.stat(htdigestfile) - except OSError, err: + except OSError as err: raise IOError(str(err)) HtdigestTokenGenerator.__init__(self, realm, htdigestfile, ignoreparseerrors) @@ -276,7 +285,9 @@ class UpdatingHtdigestTokenGenerator(HtdigestTokenGenerator): if self.statcache != statcache: try: self.readhtdigest(self.htdigestfile, self.ignoreparseerrors) - except (IOError, ValueError): + except IOError: + return None + except ValueError: return None return HtdigestTokenGenerator.__call__(self, user, algo) @@ -366,7 +377,7 @@ class StatelessNonceStore(NonceStoreBase): token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret) if ident is not None: token = "%s:%s" % (token, ident) - token = md5(token).hexdigest() + token = md5hex(token) return "%s:%s:%s" % (nonce_time, nonce_value, token) def checknonce(self, nonce, count=1, ident=None): @@ -386,7 +397,7 @@ class StatelessNonceStore(NonceStoreBase): token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret) if ident is not None: token = "%s:%s" % (token, ident) - token = md5(token).hexdigest() + token = md5hex(token) if token != nonce_hash: return False @@ -448,7 +459,7 @@ class MemoryNonceStore(NonceStoreBase): token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret) if ident is not None: token = "%s:%s" % (token, ident) - token = md5(token).hexdigest() + token = md5hex(token) return "%s:%s:%s" % (nonce_time, nonce_value, token) def checknonce(self, nonce, count=1, ident=None): @@ -467,7 +478,7 @@ class MemoryNonceStore(NonceStoreBase): token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret) if ident is not None: token = "%s:%s" % (token, ident) - token = md5(token).hexdigest() + token = md5hex(token) if token != nonce_hash: return False @@ -594,7 +605,7 @@ class DBAPI2NonceStore(NonceStoreBase): token = "%s:%s" % (dbkey, self.server_secret) if ident is not None: token = "%s:%s" % (token, ident) - token = md5(token).hexdigest() + token = md5hex(token) return "%s:%s:%s" % (nonce_time, nonce_value, token) def checknonce(self, nonce, count=1, ident=None): @@ -603,19 +614,22 @@ class DBAPI2NonceStore(NonceStoreBase): count on returning True. @type nonce: str @type count: int + @type ident: str or None @rtype: bool """ try: nonce_time, nonce_value, nonce_hash = nonce.split(':') except ValueError: return False - if not nonce_time.isalnum() or not nonce_value.replace("+", ""). \ - replace("/", "").replace("=", "").isalnum(): + # use bytes.isalnum to avoid locale specific interpretation + if not str2bytes(nonce_time).isalnum() or \ + not str2bytes(nonce_value.replace("+", "").replace("/", "") \ + .replace("=", "")).isalnum(): return False token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret) if ident is not None: token = "%s:%s" % (token, ident) - token = md5(token).hexdigest() + token = md5hex(token) if token != nonce_hash: return False @@ -680,7 +694,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware): by a REMOTE_USER key before being passed to the wrapped application.""" authorization_method = "digest" - algorithms = {"md5": lambda data: md5(data).hexdigest()} + algorithms = {"md5": md5hex} def __init__(self, app, gentoken, maxage=300, maxuses=5, store=None): """ @param app: is the wsgi application to be served with authentication. @@ -707,6 +721,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware): self.noncestore = store def authenticate(self, auth, environ): + assert isinstance(auth, str) try: credentials = parse_digest_response(auth) except ValueError: @@ -724,7 +739,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware): try: nonce = credentials["nonce"] credresponse = credentials["response"] - except KeyError, err: + except KeyError as err: raise ProtocolViolation("%s missing in credentials" % err.args[0]) noncecount = 1 @@ -765,7 +780,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware): username = credentials["username"] algo = credentials["algorithm"] uri = credentials["uri"] - except KeyError, err: + except KeyError as err: raise ProtocolViolation("%s missing in credentials" % err.args[0]) try: dig = [credentials["nonce"]] @@ -778,7 +793,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware): try: dig.append(credentials["nc"]) dig.append(credentials["cnonce"]) - except KeyError, err: + except KeyError as err: raise ProtocolViolation( "missing %s in credentials with qop=auth" % err.args[0]) dig.append(qop) diff --git a/wsgitools/filters.py b/wsgitools/filters.py index d691f74..ed976a2 100644 --- a/wsgitools/filters.py +++ b/wsgitools/filters.py @@ -10,18 +10,9 @@ __all__ = [] import sys import time import gzip -# Cannot use io module as it is broken in 2.6. -# Writing a str to a io.StringIO results in an exception. -try: - import cStringIO as io -except ImportError: - import StringIO as io +import io -try: - next -except NameError: - def next(it): - return it.next() +from wsgitools.internal import str2bytes __all__.append("CloseableIterator") class CloseableIterator(object): @@ -40,7 +31,7 @@ class CloseableIterator(object): @rtype: gen() """ return self - def next(self): + def __next__(self): """iterator interface""" if not self.iterators: raise StopIteration @@ -49,6 +40,8 @@ class CloseableIterator(object): except StopIteration: self.iterators.pop(0) return next(self) + def next(self): + return self.__next__() __all__.append("CloseableList") class CloseableList(list): @@ -131,15 +124,15 @@ class BaseWSGIFilter(object): """For each string that is either written by the C{write} callable or returned from the wrapped wsgi application this method is invoked. It must return a string. - @type data: str - @rtype: str + @type data: bytes + @rtype: bytes """ return data def append_data(self): """This function can be used to append data to the response. A list of strings or some kind of iterable yielding strings has to be returned. The default is to return an empty list. - @rtype: gen([str]) + @rtype: gen([bytes]) """ return [] def handle_close(self): @@ -161,7 +154,7 @@ class WSGIFilterMiddleware(object): def __call__(self, environ, start_response): """wsgi interface @type environ: {str, str} - @rtype: gen([str]) + @rtype: gen([bytes]) """ assert isinstance(environ, dict) reqfilter = self.filterclass() @@ -205,7 +198,7 @@ class WSGIFilterMiddleware(object): # default arguments. Also note that neither ' nor " are considered printable. # For escape_string to be reversible \ is also not considered printable. def escape_string(string, replacer=list(map( - lambda i: chr(i) if chr(i).isalnum() or + lambda i: chr(i) if str2bytes(chr(i)).isalnum() or chr(i) in '!#$%&()*+,-./:;<=>?@[]^_`{|}~ ' else r"\x%2.2x" % i, range(256)))): @@ -224,6 +217,9 @@ class RequestLogWSGIFilter(BaseWSGIFilter): """Returns a function creating L{RequestLogWSGIFilter}s on given log file. log has to be a file-like object. @type log: file-like + @param log: elements of type str are written to the log. That means in + Py3.X the contents are decoded and in Py2.X the log is assumed + to be encoded in latin1. This follows the spirit of WSGI. @type flush: bool @param flush: if True, invoke the flush method on log after each write invocation @@ -276,11 +272,7 @@ class RequestLogWSGIFilter(BaseWSGIFilter): self.status = status.split()[0] return status def filter_data(self, data): - """BaseWSGIFilter interface - @type data: str - @rtype: str - """ - assert isinstance(data, str) + assert isinstance(data, bytes) self.length += len(data) return data def handle_close(self): @@ -315,30 +307,31 @@ class TimerWSGIFilter(BaseWSGIFilter): def creator(cls, pattern): """Returns a function creating L{TimerWSGIFilter}s with a given pattern beeing a string of exactly eight bytes. - @type pattern: str + @type pattern: bytes """ return lambda:cls(pattern) - def __init__(self, pattern="?GenTime"): + def __init__(self, pattern=b"?GenTime"): """ @type pattern: str """ BaseWSGIFilter.__init__(self) + assert isinstance(pattern, bytes) self.pattern = pattern self.start = time.time() def filter_data(self, data): """BaseWSGIFilter interface - @type data: str - @rtype: str + @type data: bytes + @rtype: bytes """ if data == self.pattern: - return "%8.3g" % (time.time() - self.start) + return str2bytes("%8.3g" % (time.time() - self.start)) return data __all__.append("EncodeWSGIFilter") class EncodeWSGIFilter(BaseWSGIFilter): """Encodes all body data (no headers) with given charset. @note: This violates the wsgi standard as it requires unicode objects - whereas wsgi mandates the use of str. + whereas wsgi mandates the use of bytes. """ @classmethod def creator(cls, charset): @@ -356,7 +349,7 @@ class EncodeWSGIFilter(BaseWSGIFilter): def filter_data(self, data): """BaseWSGIFilter interface @type data: str - @rtype: str + @rtype: bytes """ return data.encode(self.charset) def filter_header(self, header, value): @@ -400,7 +393,7 @@ class GzipWSGIFilter(BaseWSGIFilter): acceptenc = map(str.strip, acceptenc) if "gzip" in acceptenc: self.compress = True - self.sio = io.StringIO() + self.sio = io.BytesIO() self.gzip = gzip.GzipFile(fileobj=self.sio, mode="w") return environ def filter_header(self, headername, headervalue): @@ -423,10 +416,6 @@ class GzipWSGIFilter(BaseWSGIFilter): headers.append(("Content-encoding", "gzip")) return headers def filter_data(self, data): - """BaseWSGIFilter interface - @type data: str - @rtype: str - """ if not self.compress: return data self.gzip.write(data) @@ -434,11 +423,9 @@ class GzipWSGIFilter(BaseWSGIFilter): self.gzip.flush() data = self.sio.getvalue() self.sio.truncate(0) + self.sio.seek(0) return data def append_data(self): - """BaseWSGIFilter interface - @rtype: [str] - """ if not self.compress: return [] self.gzip.close() @@ -449,7 +436,7 @@ class ReusableWSGIInputFilter(BaseWSGIFilter): """Make C{environ["wsgi.input"]} readable multiple times. Although this is not required by the standard it is sometimes desirable to read C{wsgi.input} multiple times. This filter will therefore replace that variable with a - C{StringIO} instance which provides a C{seek} method. + C{BytesIO} instance which provides a C{seek} method. """ @classmethod def creator(cls, maxrequestsize): @@ -460,14 +447,14 @@ class ReusableWSGIInputFilter(BaseWSGIFilter): adapter to eat this data.) @type maxrequestsize: int @param maxrequestsize: is the maximum number of bytes to store in the - C{StringIO} + C{BytesIO} """ return lambda:cls(maxrequestsize) def __init__(self, maxrequestsize=65536): """ReusableWSGIInputFilters constructor. @type maxrequestsize: int @param maxrequestsize: is the maximum number of bytes to store in the - C{StringIO}, see L{creator} + C{BytesIO}, see L{creator} """ BaseWSGIFilter.__init__(self) self.maxrequestsize = maxrequestsize @@ -477,12 +464,12 @@ class ReusableWSGIInputFilter(BaseWSGIFilter): @type environ: {str: str} """ - if isinstance(environ["wsgi.input"], io.StringIO): + if isinstance(environ["wsgi.input"], io.BytesIO): return environ # nothing to be done # XXX: is this really a good idea? use with care environ["wsgitools.oldinput"] = environ["wsgi.input"] - data = io.StringIO(environ["wsgi.input"].read(self.maxrequestsize)) + data = io.BytesIO(environ["wsgi.input"].read(self.maxrequestsize)) environ["wsgi.input"] = data return environ diff --git a/wsgitools/internal.py b/wsgitools/internal.py new file mode 100644 index 0000000..c4f1da1 --- /dev/null +++ b/wsgitools/internal.py @@ -0,0 +1,19 @@ +if bytes is str: + def bytes2str(bstr): + assert isinstance(bstr, bytes) + return bstr + def str2bytes(sstr): + assert isinstance(sstr, str) + return sstr + def textopen(filename, mode): + return open(filename, mode) +else: + def bytes2str(bstr): + assert isinstance(bstr, bytes) + return bstr.decode("iso-8859-1") # always successful + def str2bytes(sstr): + assert isinstance(sstr, str) + return sstr.encode("iso-8859-1") # might fail, but spec says it doesn't + def textopen(filename, mode): + # We use the same encoding as for all wsgi strings here. + return open(filename, mode, encoding="iso-8859-1") diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py index 0f5e416..b37b130 100644 --- a/wsgitools/middlewares.py +++ b/wsgitools/middlewares.py @@ -1,21 +1,13 @@ __all__ = [] +import base64 import time import sys import cgitb -import binascii import collections -# Cannot use io module as it is broken in 2.6. -# Writing a str to a io.StringIO results in an exception. -try: - import cStringIO as io -except ImportError: - import StringIO as io -try: - next -except NameError: - def next(iterator): - return iterator.next() +import io + +from wsgitools.internal import bytes2str, str2bytes if sys.version_info[0] >= 3: def exc_info_for_raise(exc_info): @@ -41,7 +33,7 @@ class SubdirMiddleware(object): def __call__(self, environ, start_response): """wsgi interface @type environ: {str: str} - @rtype: gen([str]) + @rtype: gen([bytes]) """ assert isinstance(environ, dict) app = None @@ -65,7 +57,7 @@ __all__.append("NoWriteCallableMiddleware") class NoWriteCallableMiddleware(object): """This middleware wraps a wsgi application that needs the return value of C{start_response} function to a wsgi application that doesn't need one by - writing the data to a C{StringIO} and then making it be the first result + writing the data to a C{BytesIO} and then making it be the first result element.""" def __init__(self, app): """Wraps wsgi application app.""" @@ -73,11 +65,11 @@ class NoWriteCallableMiddleware(object): def __call__(self, environ, start_response): """wsgi interface @type environ: {str, str} - @rtype: gen([str]) + @rtype: gen([bytes]) """ assert isinstance(environ, dict) todo = [None] - sio = io.StringIO() + sio = io.BytesIO() gotiterdata = False def write_calleable(data): assert not gotiterdata @@ -97,7 +89,7 @@ class NoWriteCallableMiddleware(object): ret = self.app(environ, modified_start_response) assert hasattr(ret, "__iter__") - first = "" + first = b"" if not isinstance(ret, list): ret = iter(ret) stopped = False @@ -179,7 +171,7 @@ class ContentLengthMiddleware(object): return ret ret = iter(ret) - first = "" + first = b"" stopped = False while not (first or stopped): try: @@ -357,14 +349,14 @@ class BasicAuthMiddleware(AuthenticationMiddleware): self.app401 = app401 def authenticate(self, auth, environ): - """ - @type environ: {str: object} - """ + assert isinstance(auth, str) assert isinstance(environ, dict) + auth = str2bytes(auth) try: - auth_info = auth.decode("base64") - except binascii.Error: + auth_info = base64.b64decode(auth) + except TypeError: raise ProtocolViolation("failed to base64 decode auth_info") + auth_info = bytes2str(auth_info) try: username, password = auth_info.split(':', 1) except ValueError: diff --git a/wsgitools/scgi/__init__.py b/wsgitools/scgi/__init__.py index 898fd61..f651264 100644 --- a/wsgitools/scgi/__init__.py +++ b/wsgitools/scgi/__init__.py @@ -45,13 +45,15 @@ class FileWrapper(object): def __iter__(self): return self - def next(self): + def __next__(self): assert self.offset <= 0 self.offset = -1 data = self.filelike.read(self.blksize) if data: return data raise StopIteration + def next(self): + return self.__next__() def _convert_environ(environ, multithread=False, multiprocess=False, run_once=False): diff --git a/wsgitools/scgi/asynchronous.py b/wsgitools/scgi/asynchronous.py index 386e1d0..51c1d55 100644 --- a/wsgitools/scgi/asynchronous.py +++ b/wsgitools/scgi/asynchronous.py @@ -1,16 +1,12 @@ __all__ = [] import asyncore +import io import socket import sys -# Cannot use io module as it is broken in 2.6. -# Writing a str to a io.StringIO results in an exception. -try: - import cStringIO as io -except ImportError: - import StringIO as io import errno +from wsgitools.internal import bytes2str, str2bytes from wsgitools.scgi import _convert_environ, FileWrapper if sys.version_info[0] >= 3: @@ -42,20 +38,21 @@ class SCGIConnection(asyncore.dispatcher): self.state = SCGIConnection.NEW # internal state self.environ = config.copy() # environment passed to wsgi app self.reqlen = -1 # request length used in two different meanings - self.inbuff = "" # input buffer - self.outbuff = "" # output buffer + self.inbuff = b"" # input buffer + self.outbuff = b"" # output buffer self.wsgihandler = None # wsgi application self.wsgiiterator = None # wsgi application iterator self.outheaders = () # headers to be sent # () -> unset, (..,..) -> set, True -> sent - self.body = io.StringIO() # request body + self.body = io.BytesIO() # request body def _try_send_headers(self): if self.outheaders != True: assert not self.outbuff status, headers = self.outheaders headdata = "".join(map("%s: %s\r\n".__mod__, headers)) - self.outbuff = "Status: %s\r\n%s\r\n" % (status, headdata) + headdata = "Status: %s\r\n%s\r\n" % (status, headdata) + self.outbuff = str2bytes(headdata) self.outheaders = True def _wsgi_write(self, data): @@ -79,12 +76,13 @@ class SCGIConnection(asyncore.dispatcher): data = self.recv(self.blocksize) self.inbuff += data if self.state == SCGIConnection.NEW: - if ':' in self.inbuff: - reqlen, self.inbuff = self.inbuff.split(':', 1) - if not reqlen.isdigit(): + if b':' in self.inbuff: + reqlen, self.inbuff = self.inbuff.split(b':', 1) + try: + reqlen = int(reqlen) + except ValueError: # invalid request format self.close() - return # invalid request format - reqlen = int(reqlen) + return if reqlen > self.maxrequestsize: self.close() return # request too long @@ -98,20 +96,21 @@ class SCGIConnection(asyncore.dispatcher): buff = self.inbuff[:self.reqlen] remainder = self.inbuff[self.reqlen:] - while buff.count('\0') >= 2: - key, value, buff = buff.split('\0', 2) - self.environ[key] = value + while buff.count(b'\0') >= 2: + key, value, buff = buff.split(b'\0', 2) + self.environ[bytes2str(key)] = bytes2str(value) self.reqlen -= len(key) + len(value) + 2 self.inbuff = buff + remainder if self.reqlen == 0: - if self.inbuff.startswith(','): + if self.inbuff.startswith(b','): self.inbuff = self.inbuff[1:] - if not self.environ.get("CONTENT_LENGTH", "bad").isdigit(): + try: + self.reqlen = int(self.environ["CONTENT_LENGTH"]) + except ValueError: self.close() return - self.reqlen = int(self.environ["CONTENT_LENGTH"]) if self.reqlen > self.maxpostsize: self.close() return @@ -124,7 +123,7 @@ class SCGIConnection(asyncore.dispatcher): if len(self.inbuff) >= self.reqlen: self.body.write(self.inbuff[:self.reqlen]) self.body.seek(0) - self.inbuff = "" + self.inbuff = b"" self.reqlen = 0 _convert_environ(self.environ) self.environ["wsgi.input"] = self.body @@ -141,7 +140,7 @@ class SCGIConnection(asyncore.dispatcher): else: self.body.write(self.inbuff) self.reqlen -= len(self.inbuff) - self.inbuff = "" + self.inbuff = b"" def start_response(self, status, headers, exc_info=None): assert isinstance(status, str) @@ -170,7 +169,7 @@ class SCGIConnection(asyncore.dispatcher): if len(self.outbuff) < self.blocksize: self._try_send_headers() for data in self.wsgiiterator: - assert isinstance(data, str) + assert isinstance(data, bytes) if data: self.outbuff += data break @@ -262,7 +261,7 @@ class SCGIServer(asyncore.dispatcher): """asyncore interface""" try: ret = self.accept() - except socket.error, err: + except socket.error as err: # See http://bugs.python.org/issue6706 if err.args[0] not in (errno.ECONNABORTED, errno.EAGAIN): raise @@ -275,4 +274,3 @@ class SCGIServer(asyncore.dispatcher): """Runs the server. It will not return and you can invoke C{asyncore.loop()} instead achieving the same effect.""" asyncore.loop() - diff --git a/wsgitools/scgi/forkpool.py b/wsgitools/scgi/forkpool.py index a49d1ec..7df1575 100644 --- a/wsgitools/scgi/forkpool.py +++ b/wsgitools/scgi/forkpool.py @@ -16,6 +16,7 @@ import sys import errno import signal +from wsgitools.internal import bytes2str, str2bytes from wsgitools.scgi import _convert_environ, FileWrapper if sys.version_info[0] >= 3: @@ -32,22 +33,22 @@ class SocketFileWrapper(object): def __init__(self, sock, toread): """@param sock: is a C{socket.socket()}""" self.sock = sock - self.buff = "" + self.buff = b"" self.toread = toread def _recv(self, size=4096): """ internal method for receiving and counting incoming data - @raise socket.error: + @raises socket.error: """ toread = min(size, self.toread) if not toread: - return "" + return b"" try: data = self.sock.recv(toread) - except socket.error, why: + except socket.error as why: if why[0] in (errno.ECONNRESET, errno.ENOTCONN, errno.ESHUTDOWN): - data = "" + data = b"" else: raise self.toread -= len(data) @@ -67,12 +68,12 @@ class SocketFileWrapper(object): def read(self, size=None): """ see pep333 - @raise socket.error: + @raises socket.error: """ if size is None: retl = [] data = self.buff - self.buff = "" + self.buff = b"" while True: retl.append(data) try: @@ -81,7 +82,7 @@ class SocketFileWrapper(object): break if not data: break - return "".join(retl) + return b"".join(retl) datalist = [self.buff] datalen = len(self.buff) while datalen < size: @@ -93,22 +94,22 @@ class SocketFileWrapper(object): break datalist.append(data) datalen += len(data) - self.buff = "".join(datalist) + self.buff = b"".join(datalist) if size <= len(self.buff): ret, self.buff = self.buff[:size], self.buff[size:] return ret - ret, self.buff = self.buff, "" + ret, self.buff = self.buff, b"" return ret def readline(self, size=None): """ see pep333 - @raise socket.error: + @raises socket.error: """ while True: try: - split = self.buff.index('\n') + 1 + split = self.buff.index(b'\n') + 1 if size is not None and split > size: split = size ret, self.buff = self.buff[:split], self.buff[split:] @@ -123,14 +124,14 @@ class SocketFileWrapper(object): else: data = self._recv(4096) if not data: - ret, self.buff = self.buff, "" + ret, self.buff = self.buff, b"" return ret self.buff += data def readlines(self): """ see pep333 - @raise socket.error: + @raises socket.error: """ data = self.readline() while data: @@ -139,21 +140,23 @@ class SocketFileWrapper(object): def __iter__(self): """see pep333""" return self - def next(self): + def __next__(self): """ see pep333 - @raise socket.error: + @raises socket.error: """ data = self.read(4096) if not data: raise StopIteration return data + def next(self): + return self.__next__() def flush(self): """see pep333""" pass def write(self, data): """see pep333""" - assert isinstance(data, str) + assert isinstance(data, bytes) try: self.sock.sendall(data) except socket.error: @@ -268,7 +271,7 @@ class SCGIServer(object): self.spawnworker() try: rs, _, _ = select.select(self.workers.keys(), [], []) - except select.error, e: + except select.error as e: if e[0] != errno.EINTR: raise rs = [] @@ -277,11 +280,11 @@ class SCGIServer(object): data = self.workers[s].sock.recv(1) except socket.error: # we cannot handle errors here, so drop the connection. - data = '' - if data == '': + data = b'' + if data == b'': self.workers[s].sock.close() del self.workers[s] - elif data in ('0', '1'): + elif data in (b'0', b'1'): self.workers[s].state = int(data) else: raise RuntimeError("unexpected data from worker") @@ -368,18 +371,18 @@ class SCGIServer(object): def work(self, worksock): """ internal! serves maxrequests times - @raise socket.error: + @raises socket.error: """ for _ in range(self.maxrequests): (con, addr) = self.server.accept() # we cannot handle socket.errors here. - worksock.sendall('1') # tell server we're working + worksock.sendall(b'1') # tell server we're working if self.timelimit: signal.alarm(self.timelimit) self.process(con) if self.timelimit: signal.alarm(0) - worksock.sendall('0') # tell server we've finished + worksock.sendall(b'0') # tell server we've finished if self.cpulimit: break @@ -398,14 +401,15 @@ class SCGIServer(object): except socket.error: con.close() return - if not ':' in data: + if not b':' in data: con.close() return - length, data = data.split(':', 1) - if not length.isdigit(): # clear protocol violation + length, data = data.split(b':', 1) + try: + length = int(length) + except ValueError: # clear protocol violation con.close() return - length = int(length) while len(data) != length + 1: # read one byte beyond try: @@ -419,35 +423,32 @@ class SCGIServer(object): data += t # netstrings! - data = data.split('\0') + data = data.split(b'\0') # the byte beyond has to be a ','. # and the number of netstrings excluding the final ',' has to be even - if data.pop() != ',' or len(data) % 2 != 0: + if data.pop() != b',' or len(data) % 2 != 0: con.close() return environ = self.config.copy() while data: - key = data.pop(0) - value = data.pop(0) + key = bytes2str(data.pop(0)) + value = bytes2str(data.pop(0)) environ[key] = value # elements: # 0 -> None: no headers set # 0 -> False: set but unsent # 0 -> True: sent - # 1 -> status string - # 2 -> header list - response_head = [None, None, None] + # 1 -> bytes of the complete header + response_head = [None, None] def sendheaders(): assert response_head[0] is not None # headers set if response_head[0] != True: response_head[0] = True try: - con.sendall('Status: %s\r\n%s\r\n\r\n' % (response_head[1], - '\r\n'.join(map("%s: %s".__mod__, - response_head[2])))) + con.sendall(response_head[1]) except socket.error: pass @@ -465,17 +466,20 @@ class SCGIServer(object): finally: exc_info = None assert not response_head[0] # unset or not sent + headers = "".join(map("%s: %s\r\n".__mod__, headers)) + full_header = "Status: %s\r\n%s\r\n" % (status, headers) + response_head[1] = str2bytes(full_header) response_head[0] = False # set but nothing sent - response_head[1] = status - response_head[2] = headers return dumbsend - if not environ.get("CONTENT_LENGTH", "bad").isdigit(): + try: + content_length = int(environ["CONTENT_LENGTH"]) + except ValueError: con.close() return _convert_environ(environ, multiprocess=True) - sfw = SocketFileWrapper(con, int(environ["CONTENT_LENGTH"])) + sfw = SocketFileWrapper(con, content_length) environ["wsgi.input"] = sfw result = self.wsgiapp(environ, start_response) @@ -490,7 +494,7 @@ class SCGIServer(object): result_iter = iter(result) for data in result_iter: assert response_head[0] is not None - assert isinstance(data, str) + assert isinstance(data, bytes) dumbsend(data) if response_head[0] != True: sendheaders() |