import bz2
import struct
import zlib

import lzma

class GzipDecompressor:
    """An interface to gzip which is similar to bz2.BZ2Decompressor and
    lzma.LZMADecompressor.

    The gzip member header is parsed by hand and the deflate payload is then
    handed to a raw zlib decompression object. A running CRC-32 and byte count
    of the produced output are kept so that unused_data can recognise the
    member trailer.
    """

    def __init__(self):
        # True once a complete gzip header has been parsed.
        self.sawheader = False
        # Input received while no zlib object is active: either a (partial)
        # header or whatever followed the end of the deflate stream.
        self.inbuffer = b""
        # Raw-deflate zlib object; None while parsing a header or after the
        # deflate stream has ended.
        self.decompressor = None
        # CRC-32 and length of all output produced so far.
        self.crc = 0
        self.size = 0

    def decompress(self, data):
        """Feed compressed bytes and return the output that becomes available.

        Input is buffered until a complete gzip header has been seen, so a
        call may return b"" even though it consumed data.

        @param data: next chunk of the compressed stream
        @returns: decompressed bytes, possibly empty
        @raises ValueError: if no gzip magic is found
        @raises zlib.error: from zlib invocations
        """
        while True:
            if self.decompressor:
                data = self.decompressor.decompress(data)
                self.crc = zlib.crc32(data, self.crc)
                self.size += len(data)
                unused_data = self.decompressor.unused_data
                if not unused_data:
                    return data
                # The deflate stream ended inside this chunk. The remaining
                # bytes are pushed back through the header parsing below.
                # NOTE(review): that remainder normally starts with the 8 byte
                # trailer, not a header. It only sits in inbuffer while fewer
                # than 10 bytes have accumulated; beyond that it must begin
                # with the gzip magic or ValueError is raised.
                self.decompressor = None
                return data + self.decompress(unused_data)
            self.inbuffer += data
            # Fixed-size part of the header (RFC 1952): magic, method, flags,
            # mtime, extra flags, OS.
            skip = 10
            if len(self.inbuffer) < skip:
                return b""
            if not self.inbuffer.startswith(b"\037\213\010"):
                raise ValueError("gzip magic not found")
            flag = ord(self.inbuffer[3:4])
            if flag & 4:
                # FEXTRA: 16 bit little-endian length followed by that many
                # bytes.
                if len(self.inbuffer) < skip + 2:
                    return b""
                length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
                skip += 2 + length
            for field in (8, 16):
                # FNAME and FCOMMENT: NUL terminated strings.
                if flag & field:
                    length = self.inbuffer.find(b"\0", skip)
                    if length < 0:
                        return b""
                    skip = length + 1
            if flag & 2:
                # FHCRC: 16 bit header checksum, skipped without verification.
                skip += 2
            if len(self.inbuffer) < skip:
                return b""
            # Header complete: everything after it is deflate data, which the
            # next loop iteration feeds to the new zlib object.
            data = self.inbuffer[skip:]
            self.inbuffer = b""
            self.sawheader = True
            self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)

    @property
    def unused_data(self):
        """Input that was received but not consumed by decompression.

        While a deflate stream is active this is the zlib object's
        unused_data. Before any header was parsed it is the buffered input.
        Afterwards it is the buffered input, except that a buffer consisting
        of exactly the expected trailer (CRC-32 and size, little-endian),
        optionally followed by NUL bytes, is reported as b"".
        """
        if self.decompressor:
            return self.decompressor.unused_data
        elif not self.sawheader:
            return self.inbuffer
        else:
            # The trailer stores the uncompressed size modulo 2**32
            # (RFC 1952 ISIZE). Without the mask struct.pack raises
            # struct.error once more than 4 GiB have been decompressed.
            expect = struct.pack("<LL", self.crc & 0xffffffff,
                                 self.size & 0xffffffff)
            if self.inbuffer.startswith(expect) and \
                    self.inbuffer[len(expect):].replace(b"\0", b"") == b"":
                return b""
            return self.inbuffer

    def flush(self):
        """Return the result of flushing the active zlib object, or b"" when
        none is active.

        @raises zlib.error: from zlib invocations
        """
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        """Return a new GzipDecompressor carrying the same buffered input,
        header flag, checksum state and a copy of the zlib object (if any)."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        new.sawheader = self.sawheader
        new.crc = self.crc
        new.size = self.size
        return new

class DecompressedStream:
    """Turn a readable file-like into a decompressed file-like. It supports
    read(optional length), readline, iteration over lines, tell,
    seek(forward only) and close."""
    # Number of compressed bytes requested from the underlying file per read.
    blocksize = 65536

    def __init__(self, fileobj, decompressor):
        """
        @param fileobj: a file-like object providing read(size)
        @param decompressor: a bz2.BZ2Decompressor or lzma.LZMADecompressor
            like object providing methods decompress and flush and an
            attribute unused_data
        """
        self.fileobj = fileobj
        self.decompressor = decompressor
        # Decompressed bytes that have not been handed to the caller yet.
        self.buff = bytearray()
        # Offset into the decompressed stream of the next byte to return.
        self.pos = 0

    def _fill_buff_until(self, predicate):
        """Read and decompress blocks until predicate(self.buff) is true or
        the underlying file is exhausted."""
        assert self.fileobj is not None
        while not predicate(self.buff):
            data = self.fileobj.read(self.blocksize)
            if data:
                self.buff += self.decompressor.decompress(data)
            else:
                # End of input. Not every decompressor offers flush.
                if hasattr(self.decompressor, "flush"):
                    self.buff += self.decompressor.flush()
                break

    def _read_from_buff(self, length):
        """Remove and return up to length bytes from the front of the buffer,
        advancing the position by the number of bytes actually returned."""
        ret = bytes(self.buff[:length])
        del self.buff[:length]
        # Count what was delivered, not what was asked for: at end of stream
        # fewer than length bytes may be available.
        self.pos += len(ret)
        return ret

    def read(self, length=None):
        """Read decompressed data.

        @param length: maximum number of bytes to return; None or a negative
            value reads until the end of the stream
        @returns: bytes, shorter than length (possibly empty) only at the end
            of the stream
        """
        if length is None or length < 0:
            self._fill_buff_until(lambda _: False)
            length = len(self.buff)
        else:
            self._fill_buff_until(lambda b, l=length: len(b) >= l)
        return self._read_from_buff(length)

    def readline(self):
        """Read up to and including the next newline, or the remainder of the
        stream if it contains no further newline."""
        self._fill_buff_until(lambda b: b'\n' in b)
        try:
            length = self.buff.index(b'\n') + 1
        except ValueError:
            length = len(self.buff)
        return self._read_from_buff(length)

    def __iter__(self):
        """Iterate over lines until readline returns b''."""
        return iter(self.readline, b'')

    def tell(self):
        """Return the current offset into the decompressed stream."""
        assert self.fileobj is not None
        return self.pos

    def seek(self, pos):
        """Forward seeks by absolute position only.

        @raises ValueError: if pos lies before the current position
        """
        assert self.fileobj is not None
        if pos < self.pos:
            raise ValueError("negative seek not allowed on decompressed stream")
        while self.pos < pos:
            left = pos - self.pos
            # Reading self.buff entirely avoids string concatenation.
            size = len(self.buff) or self.blocksize
            if not self.read(min(left, size)):
                # End of stream before reaching pos. As with regular files the
                # position is still set to the requested offset; this also
                # keeps the loop from spinning forever.
                self.pos = pos

    def close(self):
        """Close the underlying file and drop all state. Closing twice is a
        no-op."""
        if self.fileobj is not None:
            self.fileobj.close()
            self.fileobj = None
            self.decompressor = None
            self.buff = bytearray()

# Maps a file name extension to a callable that takes no arguments and returns
# a fresh decompressor object; decompress() below looks up and calls these.
decompressors = {
    '.gz':   GzipDecompressor,
    '.bz2':  bz2.BZ2Decompressor,
    '.lzma': lzma.LZMADecompressor,
    '.xz':   lzma.LZMADecompressor,
}

def decompress(filelike, extension):
    """Wrap a byte stream so that reading it yields decompressed data.

    @param filelike: a read-only byte-stream providing read(size) and close()
    @param extension: one of "", ".gz", ".bz2", ".lzma" or ".xz"
    @type extension: unicode
    @returns: filelike itself when extension is empty, otherwise a
              DecompressedStream reading from filelike through the
              decompressor registered for extension
    @raises ValueError: on unknown extensions
    """
    if extension:
        try:
            factory = decompressors[extension]
        except KeyError:
            raise ValueError("unknown compression format with extension %r" %
                             extension)
        filelike = DecompressedStream(filelike, factory())
    return filelike