#!/usr/bin/python
"""This script takes a directory or an HTTP base URL of a mirror and imports
all packages contained. It makes rather strong assumptions about the working
directory.
"""

import gzip
import errno
import io
import multiprocessing
import optparse
import os
import sqlite3
import subprocess
import tempfile
import urllib

import concurrent.futures
from debian import deb822
from debian.debian_support import version_compare

from readyaml import readyaml

def process_http(pkgs, url):
    """Merge the packages listed by a mirror's package index into pkgs.

    Downloads ``/dists/sid/main/binary-amd64/Packages.gz`` below *url*,
    decompresses it and records every paragraph in *pkgs*.

    @param pkgs: dict mapping (name, architecture) tuples to dicts with the
        keys "version", "filename" and "sha256hash"; updated in place
    @param url: base url of the mirror, without a trailing slash
    """
    # Function-scope imports keep this block self-contained and let it run on
    # both Python 3 (urllib.request) and Python 2 (urllib).
    from contextlib import closing
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib import urlopen

    index_url = url + "/dists/sid/main/binary-amd64/Packages.gz"
    # closing() releases the connection deterministically; the Python 2
    # response object is not a context manager on its own.
    with closing(urlopen(index_url)) as response:
        compressed = response.read()
    pkglist = gzip.GzipFile(fileobj=io.BytesIO(compressed)).read()
    pkglist = io.BytesIO(pkglist)
    for pkg in deb822.Packages.iter_paragraphs(pkglist):
        key = (pkg["Package"], pkg["Architecture"])
        # Keep an already recorded entry only if it is strictly newer.
        if key in pkgs and \
                version_compare(pkgs[key]["version"], pkg["Version"]) > 0:
            continue
        pkgs[key] = dict(version=pkg["Version"],
                         filename="%s/%s" % (url, pkg["Filename"]),
                         sha256hash=pkg["SHA256"])

def process_file(pkgs, filename):
    """Record a single .deb file in pkgs, keyed by (name, architecture).

    The package name, version and architecture are taken from the file name,
    which must have the form ``name_version_arch.deb``. The version part is
    url-unquoted before use.

    @param pkgs: dict mapping (name, architecture) tuples to dicts with the
        keys "version" and "filename"; updated in place
    @param filename: path of the .deb file
    @raises ValueError: if the file name does not have the expected form
    """
    # urllib.unquote only exists on Python 2; prefer the Python 3 location and
    # fall back so the function works on either interpreter.
    try:
        from urllib.parse import unquote
    except ImportError:
        from urllib import unquote

    base = os.path.basename(filename)
    if not base.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    parts = base[:-4].split("_")
    if len(parts) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name, version, architecture = parts
    key = (name, architecture)
    version = unquote(version)
    # Keep an already recorded entry only if it is strictly newer.
    if key in pkgs and version_compare(pkgs[key]["version"], version) > 0:
        return
    pkgs[key] = dict(version=version, filename=filename)

def process_dir(pkgs, d):
    """Feed every entry of directory d to process_file.

    Entries that process_file rejects with a ValueError are skipped silently.

    @param pkgs: package dict passed through to process_file
    @param d: path of the directory to scan (not recursive)
    """
    candidates = (os.path.join(d, name) for name in os.listdir(d))
    for path in candidates:
        try:
            process_file(pkgs, path)
        except ValueError:
            continue

def process_pkg(key, pkgdict, outpath):
    """Run importpkg.py on one package and store its output in outpath.

    The package is fed to ``python importpkg.py`` on stdin. Remote packages
    (http, https, ftp and file urls) are fetched with curl and piped into the
    importer; local files are passed directly.

    @param key: (name, architecture) tuple, only used for progress output
    @param pkgdict: dict with a "filename" entry and optionally a
        "sha256hash" entry, which is handed to the importer via -H
    @param outpath: path of the file receiving the importer's stdout
    @raises ValueError: if the importer or curl fails for a remote package
    @raises subprocess.CalledProcessError: if the importer fails for a local
        file
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = ["python", "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith(("http://", "https://", "ftp://", "file://")):
        with open(outpath, "w") as outp:
            dl = subprocess.Popen(["curl", "-s", filename],
                                  stdout=subprocess.PIPE, close_fds=True)
            try:
                try:
                    imp = subprocess.Popen(importcmd, stdin=dl.stdout,
                                           stdout=outp, close_fds=True)
                finally:
                    # Drop our copy of the pipe's read end. Otherwise curl
                    # never notices when the importer exits early and may
                    # block forever on a full pipe.
                    dl.stdout.close()
                imp_status = imp.wait()
            finally:
                # Always reap curl, even when the importer failed to start.
                dl_status = dl.wait()
            # An importer failure takes precedence over a curl failure.
            if imp_status:
                raise ValueError("importpkg failed")
            if dl_status:
                raise ValueError("curl failed")
    else:
        # A .deb is binary data; only its file descriptor is handed on.
        with open(filename, "rb") as inp:
            with open(outpath, "w") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s:%s" % key)

def main():
    """Import the packages named on the command line into the database.

    Every positional argument is either a mirror base url (http, https, ftp
    or file), a directory containing .deb files or a single .deb file. The
    packages are preprocessed in parallel by process_pkg into a temporary
    directory and then loaded into the sqlite3 database with readyaml.

    Options:
      -n/--new       skip packages whose version is not newer than the one
                     already in the database
      -p/--prune     delete packages from the database that are absent from
                     the given sources
      -d/--database  path to the sqlite3 database file
    """
    parser = optparse.OptionParser()
    parser.add_option("-n", "--new", action="store_true",
                      help="avoid reimporting same versions")
    parser.add_option("-p", "--prune", action="store_true",
                      help="prune old packages")
    parser.add_option("-d", "--database", action="store",
                      default="test.sqlite3",
                      help="path to the sqlite3 database file")
    options, args = parser.parse_args()
    # A native str prefix keeps tmpdir a str on every Python version, so it
    # can be joined with the str file names below. A bytes prefix makes
    # Python 3 return a bytes path.
    tmpdir = tempfile.mkdtemp(prefix="debian-dedup")
    db = sqlite3.connect(options.database)
    cur = db.cursor()
    cur.execute("PRAGMA foreign_keys = ON;")
    e = concurrent.futures.ThreadPoolExecutor(multiprocessing.cpu_count())

    # Collect the candidate packages from all sources.
    pkgs = {}
    for d in args:
        print("processing %s" % d)
        if d.startswith(("http://", "https://", "ftp://", "file://")):
            process_http(pkgs, d)
        elif os.path.isdir(d):
            process_dir(pkgs, d)
        else:
            process_file(pkgs, d)

    print("reading database")
    cur.execute("SELECT name, architecture, version FROM package;")
    knownpkgs = dict(((row[0], row[1]), row[2]) for row in cur.fetchall())
    # Snapshot the keys before filtering: pruning below must consider every
    # package seen in the sources, including those skipped by --new.
    distpkgs = set(pkgs.keys())
    if options.new:
        for key in distpkgs:
            if key in knownpkgs and version_compare(pkgs[key]["version"],
                    knownpkgs[key]) <= 0:
                del pkgs[key]
    knownpkgs = set(knownpkgs)

    with e:
        # Map each future to its package key for reporting.
        fs = {}
        for key, pkg in pkgs.items():
            outpath = os.path.join(tmpdir, "%s_%s" % key)
            fs[e.submit(process_pkg, key, pkg, outpath)] = key

        # Preprocessing runs in the pool; the database import happens here,
        # one package at a time, as results become available.
        for f in concurrent.futures.as_completed(fs):
            key = fs[f]
            if f.exception():
                print("%s:%s failed to import: %r" %
                      (key[0], key[1], f.exception()))
                continue
            inf = os.path.join(tmpdir, "%s_%s" % key)
            print("sqlimporting %s:%s" % key)
            with open(inf) as inp:
                try:
                    readyaml(db, inp)
                except Exception as exc:
                    print("%s:%s failed sql with exception %r" %
                          (key[0], key[1], exc))
                else:
                    # Only successfully imported files are removed; the rest
                    # stay behind for inspection.
                    os.unlink(inf)

    if options.prune:
        delpkgs = knownpkgs - distpkgs
        print("clearing packages %s" % " ".join(map("%s:%s".__mod__, delpkgs)))
        cur.executemany("DELETE FROM package WHERE name = ? AND architecture = ?;",
                        delpkgs)
        # Tables content, dependency and sharing will also be pruned
        # due to ON DELETE CASCADE clauses.
        db.commit()
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        # A non-empty directory means some packages failed; keep their files.
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))

# Run the importer only when this file is executed as a script.
if __name__ == "__main__":
    main()