#!/usr/bin/python3
"""This scrip takes a directory or a http base url to a mirror and imports all
packages contained. It has rather strong assumptions on the working directory.
"""

import argparse
import concurrent.futures
import errno
import multiprocessing
import pathlib
import sqlite3
import subprocess
import sys
import tempfile
import typing
import urllib.parse

from debian.debian_support import version_compare

from dedup.utils import iterate_packages

from readyaml import readyaml


PkgDict = typing.Dict[str, str]


def process_http(
    pkgs: typing.Dict[str, PkgDict], url: str, addhash: bool = True
) -> None:
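    """Collect the newest version of every package from the amd64 index below
    url into pkgs, recording its download URL and, when addhash is set, its
    SHA256 hash."""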
    for pkg in iterate_packages(url, "amd64"):
        name = pkg["Package"]
        if name in pkgs and \
                version_compare(pkgs[name]["version"], pkg["Version"]) > 0:
            continue
        inst = dict(version=pkg["Version"],
                    filename="%s/%s" % (url, pkg["Filename"]))
        if addhash:
            inst["sha256hash"] = pkg["SHA256"]
        pkgs[name] = inst


def process_file(
    pkgs: typing.Dict[str, PkgDict], filename: pathlib.Path
) -> None:
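    """Record the .deb file at filename in pkgs unless a newer version of the
    same package is already known. Raise ValueError if the filename does not
    follow the name_version_arch.deb scheme."""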
    if filename.suffix != ".deb":
        raise ValueError("filename does not end in .deb")
    parts = filename.name.split("_")
    if len(parts) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name, version, _ = parts
    version = urllib.parse.unquote(version)
    if name in pkgs and version_compare(pkgs[name]["version"], version) > 0:
        return
    pkgs[name] = dict(version=version, filename=str(filename))


def process_dir(pkgs: typing.Dict[str, PkgDict], d: pathlib.Path) -> None:
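    """Run process_file on every entry of the directory d, silently skipping
    entries that do not look like binary packages."""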
    for entry in d.iterdir():
        try:
            process_file(pkgs, entry)
        except ValueError:
            pass


def process_pkg(name: str, pkgdict: PkgDict, outpath: pathlib.Path) -> None:
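    """Run importpkg.py on the package described by pkgdict and write its
    output to outpath. Remote packages are passed to importpkg.py as a URL;
    local files are fed via stdin."""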
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = [sys.executable, "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith(("http://", "https://", "ftp://", "file://")):
        importcmd.append(filename)
        with outpath.open("w") as outp:
            subprocess.check_call(importcmd, stdout=outp, close_fds=True)
    else:
        with open(filename, "rb") as inp:
            with outpath.open("w") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s" % name)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--new", action="store_true",
                        help="avoid reimporting same versions")
    parser.add_argument("-p", "--prune", action="store_true",
                        help="prune packages old packages")
    parser.add_argument("-d", "--database", action="store",
                        default="test.sqlite3",
                        help="path to the sqlite3 database file")
    parser.add_argument("--noverify", action="store_true",
                        help="do not verify binary package hashes")
    parser.add_argument("files", nargs='+',
                        help="files or directories or repository urls")
    args = parser.parse_args()
    tmpdir = pathlib.Path(tempfile.mkdtemp(prefix="debian-dedup"))
    db = sqlite3.connect(args.database)
    cur = db.cursor()
    cur.execute("PRAGMA foreign_keys = ON;")
    e = concurrent.futures.ThreadPoolExecutor(multiprocessing.cpu_count())
    pkgs: typing.Dict[str, PkgDict] = {}
    for d in args.files:
        print("processing %s" % d)
        if d.startswith(("http://", "https://", "ftp://", "file://")):
            process_http(pkgs, d, not args.noverify)
        else:
            dp = pathlib.Path(d)
            if dp.is_dir():
                process_dir(pkgs, dp)
            else:
                process_file(pkgs, dp)

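    # determine which package versions the database already contains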
    print("reading database")
    cur.execute("SELECT name, version FROM package;")
    knownpkgvers = dict(cur.fetchall())
    distpkgs = set(pkgs.keys())
    if args.new:
        for name in distpkgs:
            if name in knownpkgvers and \
               version_compare(pkgs[name]["version"], knownpkgvers[name]) <= 0:
                del pkgs[name]
    knownpkgs = set(knownpkgvers)
    del knownpkgvers

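    # run importpkg.py for every selected package in the thread pool and feed
    # the finished results into sqlite sequentially as they complete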
    with e:
        fs = {}
        for name, pkg in pkgs.items():
            fs[e.submit(process_pkg, name, pkg, tmpdir / name)] = name

        for f in concurrent.futures.as_completed(fs.keys()):
            name = fs[f]
            if f.exception():
                print("%s failed to import: %r" % (name, f.exception()))
                continue
            inf = tmpdir / name
            print("sqlimporting %s" % name)
            with inf.open() as inp:
                try:
                    readyaml(db, inp)
                except Exception as exc:
                    print("%s failed sql with exception %r" % (name, exc))
                else:
                    inf.unlink()

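    # with --prune, drop database packages that are absent from the given sources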
    if args.prune:
        delpkgs = knownpkgs - distpkgs
        print("clearing packages %s" % " ".join(delpkgs))
        cur.executemany("DELETE FROM package WHERE name = ?;",
                        ((pkg,) for pkg in delpkgs))
        # Tables content, dependency and sharing will also be pruned
        # due to ON DELETE CASCADE clauses.
        db.commit()
    try:
        tmpdir.rmdir()
    except OSError as err:
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(map(str, tmpdir.iterdir()))))

if __name__ == "__main__":
    main()