#! /usr/bin/python3

import hashlib
import argparse
import os
import re
import sys
import tarfile
from urllib import request
from urllib.error import HTTPError
from urllib.parse import (
    urlparse,
    urlunparse,
    )

import apt

# package_name: package containing the objects we signed
# package_version: package version containing the objects we signed
# src_package: source package name in dists
# signed_type: 'signed' or 'uefi' schema in the url

# Command-line interface.  The positional names are the same as the keyword
# arguments of SignedDownloader.__init__, so the parsed namespace can be
# expanded directly into the constructor.
parser = argparse.ArgumentParser()
parser.add_argument(
    "package_name",
    help="package containing the objects we signed")
parser.add_argument(
    "package_version",
    help="package version containing the objects we signed, or 'current'")
parser.add_argument(
    "src_package",
    help="source package name in dists")
# Optional trailing positional: falls back to 'signed' when omitted.
parser.add_argument(
    "signed_type",
    nargs='?',
    default='signed',
    help="subdirectory type in the url, 'signed' or 'uefi'")
args = parser.parse_args()


class SignedDownloader:
    """Download a block of signed information from dists.

    Find a block of signed information as published in dists/*/signed
    and download the contents.  Use the contained checksum files to
    identify the members and to validate them once downloaded.
    """

    def __init__(self, package_name, package_version, src_package, signed_type='signed'):
        """Locate the binary package via apt and derive the dists location.

        package_name: binary package containing the objects we signed
        package_version: version of that package, or 'current' to use the
            candidate version reported by apt
        src_package: source package name in dists
        signed_type: subdirectory type in the url, 'signed' or 'uefi'

        Raises KeyError when the requested version is not among the versions
        apt reports for the package (the cache lookup itself may also raise).

        If the package URI embeds "user:password@host" credentials they are
        removed from the stored URL and a basic-auth opener is installed
        globally for urllib.request instead.
        """
        self.package_name = package_name
        self.package_version = package_version
        self.src_package = src_package

        # Find the package in the available archive repositories.  Use a _binary_
        # package name and version to locate the appropriate archive.  Then use the
        # URI there to look for and find the appropriate binary.
        cache = apt.Cache()

        self.package = None
        if self.package_version == "current":
            self.package = cache[package_name].candidate
        else:
            for version in cache[package_name].versions:
                if version.version == self.package_version:
                    self.package = version
                    break

        if not self.package:
            raise KeyError("{0}: package version not found".format(self.package_name))

        origin = self.package.origins[0]
        pool_parsed = urlparse(self.package.uri)
        # Path below "dists/" that holds the signed results for this package.
        self.package_dir = "%s/%s/%s/%s-%s/%s/" % (
            origin.archive, 'main', signed_type,
            self.src_package, self.package.architecture, self.package_version)

        # Prepare the master url stem and pull out any username/password.  If present
        # replace the default opener with one which offers that password.
        dists_parsed_master = list(pool_parsed)
        if '@' in dists_parsed_master[1]:
            # Split on the LAST '@' so a password containing '@' stays intact,
            # and tolerate a username supplied without any password.
            (username_password, _, host) = pool_parsed[1].rpartition('@')
            (username, _, password) = username_password.partition(':')

            dists_parsed_master[1] = host

            # Work out the authentication domain (scheme://host/).
            auth_uri = urlunparse([dists_parsed_master[0], host, '/', '', '', ''])

            # Create a password manager and add the username and password.
            # If we knew the realm, we could use it instead of None.
            password_mgr = request.HTTPPasswordMgrWithDefaultRealm()
            password_mgr.add_password(None, auth_uri, username, password)

            # Now all calls to urllib.request.urlopen use our opener.
            handler = request.HTTPBasicAuthHandler(password_mgr)
            request.install_opener(request.build_opener(handler))

        self.dists_parsed = dists_parsed_master

    @staticmethod
    def _is_within(path, root):
        """Return True when path is root itself or lies beneath it.

        Both paths are normalised with os.path.abspath first.  The comparison
        is made on whole path components, so a sibling such as "/out-evil" is
        not treated as being inside "/out".
        """
        path = os.path.abspath(path)
        root = os.path.abspath(root)
        # os.path.join(root, '') appends exactly one trailing separator.
        return path == root or path.startswith(os.path.join(root, ''))

    def download_one(self, member, filename, hash_factory=None):
        """Download one member of the signed directory to a local file.

        member: path of the file relative to the signed directory in dists
        filename: local path the downloaded data is written to
        hash_factory: optional callable returning a hashlib-style object;
            when supplied the data is hashed while it is being written

        Returns None when the server answers 404.  Otherwise returns the hex
        digest of the data, or True when no hash_factory was supplied.  Any
        other HTTPError propagates to the caller.
        """
        directory = os.path.dirname(filename)
        if directory:
            # os.makedirs('') raises, so only create a real directory part.
            os.makedirs(directory, exist_ok=True)

        # Swap the "/pool/..." tail of the package URL for the dists location.
        # A callable replacement keeps any backslash in the member name
        # literal instead of being read as a regex escape.
        dists_path = "/dists/" + self.package_dir + member
        dists_parsed = list(self.dists_parsed)
        dists_parsed[2] = re.sub(r"/pool/.*", lambda _match: dists_path, dists_parsed[2])
        dists_uri = urlunparse(dists_parsed)

        print("Downloading %s ... " % dists_uri, end='')
        sys.stdout.flush()
        try:
            with request.urlopen(dists_uri) as dists, open(filename, "wb") as out:
                hashobj = hash_factory() if hash_factory else None
                for chunk in iter(lambda: dists.read(256 * 1024), b''):
                    if hashobj:
                        hashobj.update(chunk)
                    out.write(chunk)
                checksum = hashobj.hexdigest() if hashobj else True
        except HTTPError as e:
            if e.code != 404:
                raise
            print("not found")
            return None

        print("found")
        return checksum

    def download(self, base):
        """Download an entire signed result from dists.

        base: output directory; files are stored below base/<package_version>

        Fetches SHA256SUMS, then downloads every member it lists and verifies
        its SHA-256 digest.  If a 'signed.tar.gz' is present afterwards it is
        extracted relative to base.  On any failure (missing checksum file,
        malformed checksum line, path escaping the output directory, or
        checksum mismatch) a message is printed and the process exits with
        status 1 via sys.exit.
        """
        here = os.path.abspath(base)

        # Download the checksums and use that to download the contents.
        sums = 'SHA256SUMS'
        sums_local = os.path.join(base, self.package_version, sums)
        if not self.download_one(sums, sums_local):
            print('download-signed: {0}: not found'.format(sums))
            sys.exit(1)

        # Read the checksum file and download the files it mentions.
        with open(sums_local) as sfd:
            for line in sfd:
                line = line.strip()
                if not line:
                    # Blank lines carry no member; skip rather than try to
                    # download onto the version directory itself.
                    continue
                # 64 hex digits, a two character separator, then the name.
                (checksum_expected, member) = (line[0:64], line[66:])
                if not member:
                    print('download-signed: {0}: malformed checksum line'.format(line))
                    sys.exit(1)
                filename = os.path.abspath(os.path.join(base, self.package_version, member))
                if not self._is_within(filename, here):
                    print('download-signed: {0}: member outside output directory'.format(member))
                    sys.exit(1)

                # Download and checksum this member.
                checksum_actual = self.download_one(member, filename, hashlib.sha256)
                if checksum_expected != checksum_actual:
                    print('download-signed: {0}: member checksum invalid'.format(member))
                    sys.exit(1)

        # If this is a tarball result then extract it.
        tarball_filename = os.path.join(base, self.package_version, 'signed.tar.gz')
        if os.path.exists(tarball_filename):
            with tarfile.open(tarball_filename) as tarball:
                members = tarball.getmembers()
                # Validate every member before extracting anything.  Members
                # are extracted relative to base, so that is the directory
                # they must stay inside.  (The local directory may be called
                # 'current', which presumably need not match the top-level
                # directory name inside the archive -- hence base rather than
                # base/<package_version>.)
                # NOTE: only member names are checked; link targets are not.
                for tarinfo in members:
                    if not self._is_within(os.path.join(base, tarinfo.name), here):
                        print('download-signed: {0}: tarball member outside output directory'.format(tarinfo.name))
                        sys.exit(1)
                for tarinfo in members:
                    print('Extracting {0} ...'.format(tarinfo.name))
                    tarball.extract(tarinfo, base)


# Script entry: build the downloader from the parsed command-line arguments
# and fetch the signed result into the current working directory.
downloader = SignedDownloader(
    package_name=args.package_name,
    package_version=args.package_version,
    src_package=args.src_package,
    signed_type=args.signed_type,
)
downloader.download('.')
