From 6455d0aa18a923bff621dd961990063fe9d9038b Mon Sep 17 00:00:00 2001
From: Helmut Grohne <helmut@subdivi.de>
Date: Tue, 8 Jul 2008 19:40:51 +0200
Subject: refactor scgi.forkpool socket handling

---
 wsgitools/scgi/forkpool.py | 90 +++++++++++++++++++++++++++++++++++++---------
 1 file changed, 73 insertions(+), 17 deletions(-)

diff --git a/wsgitools/scgi/forkpool.py b/wsgitools/scgi/forkpool.py
index 6d6f5c5..055dbd1 100644
--- a/wsgitools/scgi/forkpool.py
+++ b/wsgitools/scgi/forkpool.py
@@ -2,6 +2,7 @@ import socket
 import os
 import select
 import sys
+import errno
 
 class SocketFileWrapper:
     """Wraps a socket to a wsgi-compliant file-like object."""
@@ -12,12 +13,28 @@ class SocketFileWrapper:
         self.toread = toread
 
     def _recv(self, size=None):
-        """internal method for receiving and counting incoming data"""
+        """
+        internal method for receiving and counting incoming data
+        @raise: socket.error
+        """
         if size is None:
-            data = self.sock.recv()
+            try:
+                data = self.sock.recv()
+            except socket.error, why:
+                if why[0] in (errno.ECONNRESET, errno.ENOTCONN,
+                              errno.ESHUTDOWN):
+                    data = ""
+                else:
+                    raise
             self.toread -= len(data)
             return data
-        data = self.sock.recv(size)
+        try:
+            data = self.sock.recv(size)
+        except socket.error, why:
+            if why[0] in (errno.ECONNRESET, errno.ENOTCONN, errno.ESHUTDOWN):
+                data = ""
+            else:
+                raise
         self.toread -= len(data)
         return data
 
@@ -34,7 +51,10 @@ class SocketFileWrapper:
             pass
 
     def read(self, size=None):
-        """see pep333"""
+        """
+        see pep333
+        @raise: socket.error
+        """
         if size is None:
             try:
                 data = self._recv()
@@ -57,7 +77,10 @@ class SocketFileWrapper:
         return self.buff + data
 
     def readline(self, size=None):
-        """see pep333"""
+        """
+        see pep333
+        @raise: socket.error
+        """
         while True:
             try:
                 split = self.buff.index('\n') + 1
@@ -80,7 +103,10 @@ class SocketFileWrapper:
                 self.buff += data
 
     def readlines(self):
-        """see pep333"""
+        """
+        see pep333
+        @raise: socket.error
+        """
         data = self.readline()
         while data:
             yield data
@@ -89,7 +115,10 @@ class SocketFileWrapper:
         """see pep333"""
         return self
     def next(self):
-        """see pep333"""
+        """
+        see pep333
+        @raise: socket.error
+        """
         data = self.read(4096)
         if not data:
             raise StopIteration
@@ -100,7 +129,11 @@ class SocketFileWrapper:
     def write(self, data):
         """see pep333"""
         assert isinstance(data, str)
-        self.sock.send(data)
+        try:
+            self.sock.sendall(data)
+        except socket.error:
+            # ignore all socket errors: there is no way to report
+            return
     def writelines(self, lines):
         """see pep333"""
         map(self.write, lines)
@@ -163,7 +196,11 @@ class SCGIServer:
                 self.spawnworker()
             rs, _, _ = select.select(self.workers.keys(), [], [])
             for s in rs:
-                data = self.workers[s].sock.recv(1)
+                try:
+                    data = self.workers[s].sock.recv(1)
+                except socket.error:
+                    # we cannot handle errors here, so drop the connection.
+                    data = ''
                 if data == '':
                     self.workers[s].sock.close()
                     del self.workers[s]
@@ -192,7 +229,10 @@ class SCGIServer:
                 worker.sock.close()
             del self.workers
 
-            self.work(worksock)
+            try:
+                self.work(worksock)
+            except socket.error:
+                pass
 
             sys.exit()
         elif pid > 0:
@@ -207,12 +247,14 @@ class SCGIServer:
     def work(self, worksock):
         """
         internal! serves maxrequests times
+        @raise: socket.error
         """
         for _ in range(self.maxrequests):
             (con, addr) = self.server.accept()
-            worksock.send('1') # tell server we're working
+            # we cannot handle socket.errors here.
+            worksock.sendall('1') # tell server we're working
             self.process(con)
-            worksock.send('0') # tell server we've finished
+            worksock.sendall('0') # tell server we've finished
 
     def process(self, con):
         """
@@ -224,7 +266,11 @@ class SCGIServer:
         # 2. the packet isn't fragmented.
         # Furthermore 1 implies that the request isn't longer than 999999 bytes.
         # This method however works. :-)
-        data = con.recv(7)
+        try:
+            data = con.recv(7)
+        except socket.error:
+            con.close()
+            return
         if not ':' in data:
             con.close()
             return
@@ -235,7 +281,11 @@ class SCGIServer:
         length = long(length)
 
         while len(data) != length + 1: # read one byte beyond
-            t = con.recv(length + 1 - len(data))
+            try:
+                t = con.recv(min(4096, length + 1 - len(data)))
+            except socket.error:
+                con.close()
+                return
             if not t: # request too short
                 con.close()
                 return
@@ -255,10 +305,16 @@ class SCGIServer:
             value = data.pop(0)
             environ[key] = value
 
+        def dumbsend(data):
+            try:
+                con.sendall(data)
+            except socket.error:
+                pass
+
         def start_response(status, headers, exc_info=None):
-            con.send('Status: %s\r\n%s\r\n\r\n' % (status,
+            dumbsend('Status: %s\r\n%s\r\n\r\n' % (status,
                       '\r\n'.join(map("%s: %s".__mod__, headers))))
-            return con.send
+            return dumbsend
 
         sfw = SocketFileWrapper(con, long(environ["CONTENT_LENGTH"]))
         environ.update({
@@ -281,7 +337,7 @@ class SCGIServer:
 
         for data in result:
             assert isinstance(data, str)
-            con.send(data)
+            dumbsend(data)
         if hasattr(result, "close"):
             result.close()
         sfw.close()
-- 
cgit v1.2.3