#!/usr/bin/python3

import argparse
import json
import os
import sys
import tempfile
import shutil

from samba import param
from samba.credentials import MUST_USE_KERBEROS, Credentials
from samba.dcerpc import preg

from vendor_samba.gp.gpclass import GPOStorage
from vendor_samba.gp import gp_cert_auto_enroll_ext as cae
from vendor_samba.gp.util.logging import logger_init, log

class adsys_cert_auto_enroll(cae.gp_cert_auto_enroll_ext):
    """Expose enroll/unenroll entry points on top of the vendored Samba
    certificate auto-enrollment extension."""

    def enroll(self, guid, entries, trust_dir, private_dir):
        """Forward an enrollment request to the base class.

        The attribute looked up here is the name-mangled spelling of a
        ``__enroll`` method on ``gp_cert_auto_enroll_ext``; all four
        arguments are passed through unchanged.
        """
        base_enroll = self._gp_cert_auto_enroll_ext__enroll
        base_enroll(guid, entries, trust_dir, private_dir)

    def unenroll(self, guid):
        """Clean up everything cached under ``guid``.

        Every attribute key currently stored in the cache for ``guid`` is
        handed to ``clean`` as the list of items to remove.
        """
        cached_attrs = self.cache_get_all_attribute_values(guid)
        to_remove = list(cached_attrs.keys())
        self.clean(guid, remove=to_remove)

def smb_config(realm, enable_debug):
    """Build the text of a minimal smb.conf.

    Parameters:
        realm (str): Value written to the ``realm`` setting of ``[global]``
        enable_debug (bool): When true, a ``log level = 10`` line is appended
    Returns:
        str: Configuration text, every line terminated by a newline
    """
    lines = ["[global]", f"realm = {realm}"]
    if enable_debug:
        lines.append("log level = 10")
    return "".join(f"{line}\n" for line in lines)

def main():
    """Parse the command line and enroll or unenroll certificates.

    A temporary smb.conf holding the requested realm (and, with --debug,
    a raised log level) is written and loaded for the duration of the run.
    If either cepces-submit or getcert cannot be located, a warning is
    logged and nothing else happens. Otherwise the state directories are
    created when missing and the requested action is executed through
    adsys_cert_auto_enroll using Kerberos-only credentials.

    Returns:
        None on every path.
    """
    parser = argparse.ArgumentParser(description='Certificate autoenrollment via Samba')

    # Positional arguments
    parser.add_argument('action', type=str,
                        choices=['enroll', 'unenroll'],
                        help='Action to perform (one of: enroll, unenroll)')
    parser.add_argument('object_name', type=str,
                        help='The computer name to enroll/unenroll, e.g. keypress')
    parser.add_argument('realm', type=str,
                        help='The realm of the domain, e.g. example.com')

    # Optional arguments
    parser.add_argument('--policy_servers_json', type=str,
                        help='GPO entries for advanced configuration of the policy servers. \
                        Must be in JSON format.')
    parser.add_argument('--state_dir', type=str,
                        default='/var/lib/adsys',
                        help='Directory to store all certificate-related files in.')
    parser.add_argument('--global_trust_dir', type=str,
                        default='/usr/local/share/ca-certificates',
                        help='Directory to symlink root CA certificates to.')
    parser.add_argument('--debug', action='store_true',
                        help='Enable samba debug output.')

    options = parser.parse_args()

    state_dir = options.state_dir
    global_trust_dir = options.global_trust_dir
    samba_cache_dir = os.path.join(state_dir, 'samba')
    trust_dir = os.path.join(state_dir, 'certs')
    private_dir = os.path.join(state_dir, 'private', 'certs')

    with tempfile.NamedTemporaryFile(prefix='smb_conf') as conf_file:
        conf_text = smb_config(options.realm, options.debug)
        conf_file.write(conf_text.encode('utf-8'))
        # Flush so LoadParm, which reads the file by name, sees the content.
        conf_file.flush()

        load_parm = param.LoadParm(conf_file.name)
        # Logging verbosity follows the log level of the loaded configuration.
        logger_init('cert-autoenroll', load_parm.log_level())

        # Both external tools must be available before anything is touched.
        if not (cepces_submit() and certmonger()):
            log.warning('certmonger and/or cepces not found, skipping certificate enrollment')
            return

        # Create whichever working directories are missing; only the
        # private certificate directory is restricted to its owner.
        for path in (samba_cache_dir, trust_dir, private_dir, global_trust_dir):
            if os.path.exists(path):
                continue
            if path == private_dir:
                os.makedirs(path, mode=0o700)
            else:
                os.makedirs(path, mode=0o755)

        creds = Credentials()
        creds.set_kerberos_state(MUST_USE_KERBEROS)
        creds.guess(load_parm)
        username = creds.get_username()

        tdb_name = f'cert_gpo_state_{options.object_name}.tdb'
        store = GPOStorage(os.path.join(samba_cache_dir, tdb_name))

        ext = adsys_cert_auto_enroll(load_parm, creds, username, store)
        guid = f'adsys-cert-autoenroll-{options.object_name}'

        if options.action != 'enroll':
            ext.unenroll(guid)
            # The Samba cache directory is removed after unenrolling.
            if os.path.exists(samba_cache_dir):
                shutil.rmtree(samba_cache_dir)
            return

        entries = gpo_entries(options.policy_servers_json)
        ext.enroll(guid, entries, trust_dir, private_dir)

def gpo_entries(entries_json):
    """
    Convert JSON string to list of GPO entries

    JSON must be an array of objects with the following keys:
        keyname (str): Registry key name
        valuename (str): Registry value name
        type (int): Registry value type
        data (any): Registry value data

    Parameters:
        entries_json (str): JSON string of GPO entries
    Returns:
        list: List of GPO entries, or empty list if entries_json is empty
              or decodes to an empty/falsy JSON value
    Raises:
        ValueError: If the string is not valid JSON (json.JSONDecodeError is
                    a ValueError subclass), if the decoded value is not an
                    array of objects, or if an entry lacks a required key
    """

    if not entries_json:
        return []

    decoded = json.loads(entries_json)
    if not decoded:
        return []

    # Reject anything that is not an array before iterating: a bare number
    # or boolean is not iterable and would otherwise escape as a raw
    # TypeError instead of the ValueError used for every other bad input.
    if not isinstance(decoded, list):
        raise ValueError('GPO data must be a JSON array of objects')

    entries = []
    for entry in decoded:
        try:
            e = preg.entry()
            e.keyname = entry['keyname']
            e.valuename = entry['valuename']
            e.type = entry['type']
            e.data = entry['data']
            entries.append(e)
        except KeyError as exc:
            raise ValueError(f'Could not find key {exc} in GPO entry') from exc
        except TypeError as exc:
            # Raised, among other cases, when an array element is not an
            # object and therefore cannot be indexed by key name.
            raise ValueError('GPO data must be a JSON array of objects') from exc
    return entries

def cepces_submit():
    """Locate the ``cepces-submit`` executable.

    The search covers the directories in the PATH environment variable
    followed by /usr/lib/certmonger and /usr/libexec/certmonger.

    Returns:
        str or None: Result of shutil.which, i.e. the path to the
        executable, or None when it is not found
    """
    # PATH may be unset (e.g. under a stripped-down service environment);
    # fall back to the interpreter's default search path instead of letting
    # a None value break the join below.
    search_dirs = [os.environ.get('PATH', os.defpath), '/usr/lib/certmonger',
                   '/usr/libexec/certmonger']
    # Drop empty components: an empty search-path entry makes shutil.which
    # probe the current working directory.
    search_path = os.pathsep.join(d for d in search_dirs if d)
    return shutil.which('cepces-submit', path=search_path)

def certmonger():
    """Look up the ``getcert`` executable on the default search path.

    Returns:
        str or None: Whatever shutil.which reports for ``getcert``
    """
    getcert_path = shutil.which('getcert')
    return getcert_path

# Run the CLI only when executed as a script; whatever main() returns is
# handed to sys.exit() as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
