From 472144ac68188056eb41c9cb198df04b454a1da2 Mon Sep 17 00:00:00 2001
From: Helmut Grohne <helmut@subdivi.de>
Date: Fri, 29 Jun 2012 08:47:51 +0200
Subject: fix hashlib, base64 and other bytes issues

 * hashlib.md5 wants bytes now.
 * string.decode("base64") is now base64.b64decode and works on bytes
 * binascii.unhexlify is now base64.b16decode and also works on bytes
 * str.isalnum accepts umlauts, use bytes.isalnum instead
---
 wsgitools/digest.py      | 59 +++++++++++++++++++++++++++++-------------------
 wsgitools/middlewares.py | 14 +++++++-----
 2 files changed, 44 insertions(+), 29 deletions(-)

(limited to 'wsgitools')

diff --git a/wsgitools/digest.py b/wsgitools/digest.py
index 4b5f8fb..532b371 100644
--- a/wsgitools/digest.py
+++ b/wsgitools/digest.py
@@ -14,32 +14,39 @@ database using C{DBAPI2NonceStore}.
 __all__ = []
 
 import random
-try:
-    from hashlib import md5
-except ImportError:
-    from md5 import md5
-import binascii
 import base64
+import hashlib
 import time
 import os
 
+from wsgitools.internal import bytes2str, str2bytes
 from wsgitools.authentication import AuthenticationRequired, \
         ProtocolViolation, AuthenticationMiddleware
 
 sysrand = random.SystemRandom()
 
-def gen_rand_str(bytes=33):
+def md5hex(data):
+    """
+    @type data: str
+    @rtype: str
+    """
+    return hashlib.md5(str2bytes(data)).hexdigest()
+
+def gen_rand_str(bytesentropy=33):
     """
     Generates a string of random base64 characters.
-    @param bytes: is the number of random 8bit values to be used
+    @param bytesentropy: is the number of random 8bit values to be used
+    @rtype: str
 
     >>> gen_rand_str() != gen_rand_str()
     True
     """
-    randnum = sysrand.getrandbits(bytes*8)
-    randstr = ("%%0%dX" % (2*bytes)) % randnum
-    randstr = binascii.unhexlify(randstr)
-    randstr = base64.encodestring(randstr).strip()
+    randnum = sysrand.getrandbits(bytesentropy*8)
+    randstr = ("%%0%dX" % (2*bytesentropy)) % randnum
+    randbytes = str2bytes(randstr)
+    randbytes = base64.b16decode(randbytes)
+    randbytes = base64.b64encode(randbytes)
+    randstr = bytes2str(randbytes)
     return randstr
 
 def parse_digest_response(data):
@@ -120,6 +127,8 @@ def format_digest(mapping):
     assert isinstance(mapping, dict)
     result = []
     for key, (value, needsquoting) in mapping.items():
+        assert isinstance(key, str)
+        assert isinstance(value, str)
         if needsquoting:
             value = '"%s"' % value.replace('\\', '\\\\').replace('"', '\\"')
         else:
@@ -172,8 +181,8 @@ class AbstractTokenGenerator:
         """
         assert isinstance(username, str)
         assert isinstance(password, str)
-        token = md5("%s:%s:%s" % (username, self.realm, password)).hexdigest()
-        return token == self(username)
+        token = "%s:%s:%s" % (username, self.realm, password)
+        return md5hex(token) == self(username)
 
 __all__.append("AuthTokenGenerator")
 class AuthTokenGenerator(AbstractTokenGenerator):
@@ -199,7 +208,7 @@ class AuthTokenGenerator(AbstractTokenGenerator):
         if password is None:
             return None
         a1 = "%s:%s:%s" % (username, self.realm, password)
-        return md5(a1).hexdigest()
+        return md5hex(a1)
 
 __all__.append("HtdigestTokenGenerator")
 class HtdigestTokenGenerator(AbstractTokenGenerator):
@@ -367,7 +376,7 @@ class StatelessNonceStore(NonceStoreBase):
         token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret)
         if ident is not None:
             token = "%s:%s" % (token, ident)
-        token = md5(token).hexdigest()
+        token = md5hex(token)
         return "%s:%s:%s" % (nonce_time, nonce_value, token)
 
     def checknonce(self, nonce, count=1, ident=None):
@@ -387,7 +396,7 @@ class StatelessNonceStore(NonceStoreBase):
         token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret)
         if ident is not None:
             token = "%s:%s" % (token, ident)
-        token = md5(token).hexdigest()
+        token = md5hex(token)
         if token != nonce_hash:
             return False
 
@@ -449,7 +458,7 @@ class MemoryNonceStore(NonceStoreBase):
         token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret)
         if ident is not None:
             token = "%s:%s" % (token, ident)
-        token = md5(token).hexdigest()
+        token = md5hex(token)
         return "%s:%s:%s" % (nonce_time, nonce_value, token)
 
     def checknonce(self, nonce, count=1, ident=None):
@@ -468,7 +477,7 @@ class MemoryNonceStore(NonceStoreBase):
         token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret)
         if ident is not None:
             token = "%s:%s" % (token, ident)
-        token = md5(token).hexdigest()
+        token = md5hex(token)
         if token != nonce_hash:
             return False
 
@@ -595,7 +604,7 @@ class DBAPI2NonceStore(NonceStoreBase):
         token = "%s:%s" % (dbkey, self.server_secret)
         if ident is not None:
             token = "%s:%s" % (token, ident)
-        token = md5(token).hexdigest()
+        token = md5hex(token)
         return "%s:%s:%s" % (nonce_time, nonce_value, token)
 
     def checknonce(self, nonce, count=1, ident=None):
@@ -604,19 +613,22 @@ class DBAPI2NonceStore(NonceStoreBase):
         count on returning True.
         @type nonce: str
         @type count: int
+        @type ident: str or None
         @rtype: bool
         """
         try:
             nonce_time, nonce_value, nonce_hash = nonce.split(':')
         except ValueError:
             return False
-        if not nonce_time.isalnum() or not nonce_value.replace("+", ""). \
-           replace("/", "").replace("=", "").isalnum():
+        # use bytes.isalnum to avoid locale specific interpretation
+        if not str2bytes(nonce_time).isalnum() or \
+                not str2bytes(nonce_value.replace("+", "").replace("/", "") \
+                              .replace("=", "")).isalnum():
             return False
         token = "%s:%s:%s" % (nonce_time, nonce_value, self.server_secret)
         if ident is not None:
             token = "%s:%s" % (token, ident)
-        token = md5(token).hexdigest()
+        token = md5hex(token)
         if token != nonce_hash:
             return False
 
@@ -681,7 +693,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware):
     by a REMOTE_USER key before being passed to the wrapped
     application."""
     authorization_method = "digest"
-    algorithms = {"md5": lambda data: md5(data).hexdigest()}
+    algorithms = {"md5": md5hex}
     def __init__(self, app, gentoken, maxage=300, maxuses=5, store=None):
         """
         @param app: is the wsgi application to be served with authentication.
@@ -708,6 +720,7 @@ class AuthDigestMiddleware(AuthenticationMiddleware):
             self.noncestore = store
 
     def authenticate(self, auth, environ):
+        assert isinstance(auth, str)
         try:
             credentials = parse_digest_response(auth)
         except ValueError:
diff --git a/wsgitools/middlewares.py b/wsgitools/middlewares.py
index e6ede9d..725deb1 100644
--- a/wsgitools/middlewares.py
+++ b/wsgitools/middlewares.py
@@ -1,12 +1,14 @@
 __all__ = []
 
+import base64
 import time
 import sys
 import cgitb
-import binascii
 import collections
 import io
 
+from wsgitools.internal import bytes2str, str2bytes
+
 if sys.version_info[0] >= 3:
     def exc_info_for_raise(exc_info):
         return exc_info[0](exc_info[1]).with_traceback(exc_info[2])
@@ -347,14 +349,14 @@ class BasicAuthMiddleware(AuthenticationMiddleware):
         self.app401 = app401
 
     def authenticate(self, auth, environ):
-        """
-        @type environ: {str: object}
-        """
+        assert isinstance(auth, str)
         assert isinstance(environ, dict)
+        auth = str2bytes(auth)
         try:
-            auth_info = auth.decode("base64")
-        except binascii.Error:
+            auth_info = base64.b64decode(auth)
+        except TypeError:
             raise ProtocolViolation("failed to base64 decode auth_info")
+        auth_info = bytes2str(auth_info)
         try:
             username, password = auth_info.split(':', 1)
         except ValueError:
-- 
cgit v1.2.3