#!/usr/bin/python

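# Sample dhcpd.conf(5) stanzas of the kinds this script generates
# (host, subnet and shared-network declarations):
#
#host fantasia {
#  option dhcp-client-identifier ...;
#  hardware ethernet 08:00:07:26:c0:a5;
#  fixed-address fantasia.fugue.com;
#}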

#subnet 1.2.3.0 netmask 255.255.255.0 {
#  option routers 1.2.3.4;
#  range 1.2.3.100 1.2.3.200;
#  option domain-name "foo.bar.example.com";
#}

#shared-network "foo" {
#}

import sets
import sys
from socket import inet_aton, inet_ntoa

from ldaptor import usage, ldapfilter, config
from ldaptor.protocols import pureber, pureldap
from ldaptor.protocols.ldap import ldapclient, ldapconnector, ldapsyntax
from twisted.internet import reactor


def my_aton_octets(ip):
    """Convert a dotted-quad IPv4 string to an integer, first octet highest.

    For example '1.2.3.4' becomes 0x01020304.
    """
    value = 0
    for octet in bytearray(inet_aton(ip)):
        value = (value << 8) | octet
    return value

def my_aton_numbits(num):
    """Return the 32-bit netmask integer for an integer prefix length *num*.

    The prefix length is clamped to 0..32: 24 gives 0xFFFFFF00, zero or a
    negative value gives 0, and anything above 32 gives 0xFFFFFFFF.
    """
    bits = min(max(num, 0), 32)
    return (0xFFFFFFFF << (32 - bits)) & 0xFFFFFFFF

def my_aton(ip):
    """Convert an address or netmask string to a 32-bit integer.

    A string that parses as an int is treated as a prefix length (so '24'
    means 255.255.255.0); anything else is parsed as a dotted quad.
    """
    try:
        prefixLength = int(ip)
    except ValueError:
        prefixLength = None

    if prefixLength is not None:
        return my_aton_numbits(prefixLength)
    return my_aton_octets(ip)

def my_ntoa(n):
    """Convert the low 32 bits of integer *n* to a dotted-quad string.

    Inverse of my_aton_octets: 0xFFFFFF00 becomes '255.255.255.0'.
    """
    packed = bytearray((n >> shift) & 0xFF for shift in (24, 16, 8, 0))
    return inet_ntoa(bytes(packed))

class HostIPAddress:
    """One IP address of a Host, printable as a dhcpd ``host`` declaration."""

    def __init__(self, host, ipAddress):
        self.host = host            # owning Host: supplies dn, name, MACs, bootFile
        self.ipAddress = ipAddress  # emitted as the fixed-address

    def _quoteString(self, value):
        """Return *value* escaped for use inside a double-quoted dhcpd.conf string.

        Backslashes and double quotes are backslash-escaped so a value taken
        from the directory cannot end the quoted string early.
        """
        text = '%s' % (value,)
        return text.replace('\\', '\\\\').replace('"', '\\"')

    def printDHCP(self, domain, prefix=''):
        """Print this address as ``host <name>.<domain> { ... }`` on stdout.

        Every emitted line is prepended with *prefix*, which callers use to
        indent the declaration inside an enclosing block.
        """
        def output():
            yield '# %s' % self.host.dn
            yield 'host %s.%s {' % (self.host.name, domain)
            for mac in self.host.macAddresses:
                yield '\thardware ethernet %s;' % mac
            yield '\tfixed-address %s;' % self.ipAddress
            if self.host.bootFile is not None:
                yield '\tfilename "%s";' % self._quoteString(self.host.bootFile)
            yield '}'

        # Single-argument print: same output as a py2 print statement.
        print('\n'.join([prefix + line for line in output()]))

    def __repr__(self):
        # The host is shown by id() only: Host.__repr__ includes its
        # ipAddresses, so repr(self.host) here would recurse forever.
        return ('%s(host=%s, ipAddress=%r)'
                % (self.__class__.__name__, id(self.host), self.ipAddress))

class Group:
    """A directory group whose member hosts share a boot filename.

    Printed as a dhcpd ``group { ... }`` block wrapping the ``host``
    declarations of its members.
    """

    def __init__(self, dn, bootFile=None):
        self.dn = dn
        self.bootFile = bootFile
        # Host objects attached through addHost.
        self.hosts = set()

    def __repr__(self):
        return '%s(dn=%r, bootFile=%r)' % (self.__class__.__name__,
                                           self.dn, self.bootFile)

    def _quoteString(self, value):
        """Return *value* escaped for use inside a double-quoted dhcpd.conf string.

        Backslashes and double quotes are backslash-escaped so a value taken
        from the directory cannot end the quoted string early.
        """
        text = '%s' % (value,)
        return text.replace('\\', '\\\\').replace('"', '\\"')

    def addHost(self, host):
        """Make *host* a member of this group.

        A host may belong to only one group: if it already has one, the
        conflict is reported on stderr and this membership is ignored.
        """
        if host.group is not None:
            sys.stderr.write('Host %s is in two groups: %r and %r\n'
                             % (host.dn, host.group, self))
            return
        host.group = self
        self.hosts.add(host)

    def printDHCP(self, domain, addrs, prefix=''):
        """Print a ``group`` block for this group's addresses still in *addrs*.

        *addrs* is the caller's set of not-yet-printed address objects. The
        addresses printed here are removed from it in place, so the caller
        does not print them a second time. Nothing is printed when none of
        this group's addresses are pending.
        """
        pending = set([addr
                       for host in self.hosts
                       for addr in host.ipAddresses])
        pending.intersection_update(addrs)
        addrs.difference_update(pending)
        if not pending:
            return

        print(prefix + '# ' + str(self.dn))
        print(prefix + 'group {')
        if self.bootFile is not None:
            print(prefix + '\tfilename "%s";' % self._quoteString(self.bootFile))
        for addr in pending:
            addr.printDHCP(domain, prefix=prefix + '\t')
        print(prefix + '}')

class Host:
    """A host entry from the directory: its names, addresses and boot file."""

    # Replaced by Group.addHost when the host joins a group.
    group = None

    def __init__(self, dn, name, ipAddresses, macAddresses=(),
                 bootFile=None):
        self.dn = dn
        self.name = name
        # Wrap every address so it can be placed in a Net and printed.
        self.ipAddresses = [HostIPAddress(self, address)
                            for address in ipAddresses]
        self.macAddresses = macAddresses
        self.bootFile = bootFile

    def __repr__(self):
        template = ('%s(dn=%r, name=%r, ipAddresses=%r, '
                    'macAddresses=%r, bootFile=%r)')
        return template % (self.__class__.__name__,
                           self.dn,
                           self.name,
                           self.ipAddresses,
                           self.macAddresses,
                           self.bootFile)

class Net:
    """An IP subnet from the directory, printable as a dhcpd ``subnet`` block."""

    def __init__(self, dn, name, address, mask,
                 routers=(),
                 dhcpRanges=(),
                 winsServers=(),
                 domainNameServers=(),
                 ):
        self.dn = dn
        self.name = name
        self.address = address  # network number string
        self.mask = mask        # dotted quad or prefix length; see my_aton
        self.routers = routers
        self.dhcpRanges = dhcpRanges
        self.winsServers = winsServers
        self.domainNameServers = domainNameServers
        self.hosts = []         # HostIPAddress objects attached via addHost

    def isInNet(self, ipAddress):
        """Return True when *ipAddress* falls inside this subnet.

        The network number is masked as well, so a network recorded with
        stray host bits set still matches its own addresses.
        """
        mask = my_aton(self.mask)
        return (my_aton(ipAddress) & mask) == (my_aton(self.address) & mask)

    def addHost(self, host):
        """Attach *host* (a HostIPAddress) to this subnet.

        Raises ValueError if its address lies outside the subnet.
        """
        if not self.isInNet(host.ipAddress):
            raise ValueError('address %s is not in net %s'
                             % (host.ipAddress, self.dn))
        self.hosts.append(host)

    def printDHCP(self, domain, prefix=''):
        """Print this subnet and its hosts to stdout, each line prefixed by *prefix*.

        Host names are qualified with ``<net name>.<domain>``. Addresses
        whose host belongs to a Group are printed inside that group's block
        (each group once); the remaining addresses are printed as bare
        ``host`` declarations.
        """
        netmask = my_ntoa(my_aton(self.mask))
        lines = ['# %s' % self.dn,
                 'subnet %s netmask %s {' % (self.address, netmask),
                 '\toption domain-name "%s.%s";' % (self.name, domain)]
        if self.routers:
            lines.append('\toption routers %s;' % ', '.join(self.routers))
        for dhcpRange in self.dhcpRanges:
            lines.append('\trange %s;' % dhcpRange)
        if self.winsServers:
            lines.append('\toption netbios-name-servers %s;'
                         % ', '.join(self.winsServers))
        if self.domainNameServers:
            lines.append('\toption domain-name-servers %s;'
                         % ', '.join(self.domainNameServers))
        lines.append('}')

        print('\n'.join([prefix + line for line in lines]))

        hostDomain = self.name + '.' + domain
        # Addresses not yet printed; Group.printDHCP removes the ones it prints.
        pending = set(self.hosts)

        printedGroups = set()
        for addr in self.hosts:
            group = addr.host.group
            if group is None or group in printedGroups:
                continue
            printedGroups.add(group)
            group.printDHCP(hostDomain, pending, prefix=prefix)

        while pending:
            pending.pop().printDHCP(hostDomain, prefix=prefix)

    def __repr__(self):
        return ('%s(dn=%r, name=%r, address=%r, mask=%r)'
                % (self.__class__.__name__,
                   self.dn, self.name, self.address, self.mask))

class SharedNet:
    """A dhcpd ``shared-network`` wrapping several Net subnets."""

    def __init__(self, name):
        self.name = name
        self.nets = []  # member Nets, printed in insertion order

    def addNet(self, net):
        """Append *net* to this shared network."""
        self.nets.append(net)

    def printDHCP(self, domain):
        """Print the shared-network block, its member subnets, then a blank line."""
        header = 'shared-network "%s" {' % self.name
        print(header)
        for member in self.nets:
            member.printDHCP(domain, prefix='\t')
        print('}')
        print('')
        print

def _cbGetGroups(entries, hosts):
    """Search callback: attach each host to the group entry that lists it.

    Every entry in *entries* becomes a Group. Its ``member`` DNs are matched
    against the DNs of *hosts*, and matching hosts are added with
    Group.addHost. Returns *hosts* so the callback chain continues with the
    host list.
    """
    byDN = {}
    for host in hosts:
        # Host DNs are expected to be unique.
        assert host.dn not in byDN
        byDN[host.dn] = host

    for entry in entries:
        group = Group(dn=entry.dn,
                      bootFile=only(entry, 'bootFile', None))
        members = [byDN[dn] for dn in entry.get('member', []) if dn in byDN]
        for host in members:
            group.addHost(host)

    return hosts

def getGroups(hosts, e, filter):
    """Add group info to hosts.

    Searches through *e* for entries that carry a ``bootFile`` and list at
    least one of *hosts* as a ``member``. *filter*, when given, is ANDed
    with that query. Returns a Deferred that fires with *hosts* after
    _cbGetGroups has set each member host's ``group``.
    """
    if not hosts:
        # With no hosts the member clause would be an empty OR "(|)", the
        # RFC 4526 absolute-false filter, which clients should not send
        # unless the server advertises support. Nothing can match anyway,
        # so skip the round-trip.
        from twisted.internet import defer
        return defer.succeed(hosts)

    memberFilters = [
        pureldap.LDAPFilter_equalityMatch(
            attributeDesc=pureldap.LDAPAttributeDescription('member'),
            assertionValue=pureber.BEROctetString(str(host.dn)))
        for host in hosts]

    filt = pureldap.LDAPFilter_and(value=(
        # the only reason we do groups is for the bootFile,
        # so require one to be present
        pureldap.LDAPFilter_present('bootFile'),

        pureldap.LDAPFilter_or(value=memberFilters),
        ))
    if filter:
        filt = pureldap.LDAPFilter_and(value=(filter, filt))

    d = e.search(filterObject=filt,
                 attributes=['member',
                             'bootFile'])

    d.addCallback(_cbGetGroups, hosts)
    return d

def haveHosts(hosts, e, filt, nets, sharedNets, dnsDomain):
    """Callback once hosts are known: resolve their groups, then print the config.

    Returns the Deferred from getGroups with haveGroups chained onto it.
    """
    return getGroups(hosts, e, filt).addCallback(
        haveGroups, nets, sharedNets, dnsDomain)

def haveGroups(hosts, nets, sharedNets, dnsDomain):
    """Place every host address into its subnet, then print the whole config.

    Each address is given to the first net whose isInNet() accepts it. Plain
    *nets* are tried before the member nets of *sharedNets* (a dict of
    SharedNet). Addresses that fit no net are reported on stderr and
    dropped. Shared networks are printed first, then the plain nets.
    """
    # The candidate list does not change while placing hosts, so build it once.
    allNets = list(nets)
    for sharedNet in sharedNets.values():
        allNets.extend(sharedNet.nets)

    for host in hosts:
        for hostIP in host.ipAddresses:
            parent = None
            for net in allNets:
                if net.isInNet(hostIP.ipAddress):
                    parent = net
                    break

            if parent is not None:
                parent.addHost(hostIP)
            else:
                sys.stderr.write("IP address %s is in no net, discarding.\n"
                                 % hostIP.ipAddress)

    for sharedNet in sharedNets.values():
        sharedNet.printDHCP(dnsDomain)
    for net in nets:
        net.printDHCP(dnsDomain)

class _NO_DEFAULT(object):
    """Sentinel that tells "no default given" apart from ``default=None``."""


def only(e, attr, default=_NO_DEFAULT):
    """Return the single value of attribute *attr* of entry *e*.

    If *e* lacks *attr*, return *default* when one was supplied, otherwise
    raise RuntimeError. RuntimeError is also raised when the attribute is
    present but holds zero or several values.
    """
    values = e.get(attr, _NO_DEFAULT)
    if values is _NO_DEFAULT:
        if default is not _NO_DEFAULT:
            return default
        raise RuntimeError("object %s does not have attribute %r."
                           % (e.dn, attr))

    if len(values) == 0:
        raise RuntimeError("object %s attribute %r has no values."
                           % (e.dn, attr))
    if len(values) > 1:
        raise RuntimeError("object %s attribute %r has multiple values: %s"
                           % (e.dn, attr, values))

    # Unpack rather than index so any iterable container (e.g. a set) works.
    (value,) = values
    return value

def _cbGetHosts(entries):
    """Search callback: build one Host per host entry returned by getHosts.

    DN, cn, addresses and MAC addresses are converted to plain strings; the
    optional ``bootFile`` is passed through as-is (None when absent).
    """
    def toHost(entry):
        cn = only(entry, 'cn')
        return Host(str(entry.dn),
                    str(cn),
                    [str(ip) for ip in entry['ipHostNumber']],
                    [str(mac) for mac in entry.get('macAddress', ())],
                    bootFile=only(entry, 'bootFile', default=None),
                    )

    return [toHost(entry) for entry in entries]

def getHosts(e, filter):
    """Search through *e* for host entries that have a cn and an ipHostNumber.

    *filter*, when given, is ANDed with those presence tests. Returns a
    Deferred firing with the list of Host objects built by _cbGetHosts.
    """
    required = [pureldap.LDAPFilter_present(attr)
                for attr in ('cn', 'ipHostNumber')]
    filt = pureldap.LDAPFilter_and(value=tuple(required))
    if filter:
        filt = pureldap.LDAPFilter_and(value=(filter, filt))

    wanted = ['cn', 'ipHostNumber', 'macAddress', 'bootFile']
    return e.search(filterObject=filt,
                    attributes=wanted).addCallback(_cbGetHosts)

def haveNets(data, e, baseDN, filt, dnsDomain):
    """Callback once nets are known: look up hosts next.

    *data* is the (nets, sharedNets) pair produced by _cbGetNets. *baseDN*
    is accepted to match the callback arguments but is not used here.
    """
    nets, sharedNets = data
    return getHosts(e, filt).addCallback(
        haveHosts, e, filt, nets, sharedNets, dnsDomain)

def _cbGetNets(entries):
    """Search callback: turn network entries into Net objects.

    Returns ``(nets, sharedNetworks)``. *nets* is a list of the Nets without
    a sharedNetworkName. *sharedNetworks* is a dict mapping each
    sharedNetworkName to a SharedNet holding its member Nets.
    """
    standalone = []
    shared = {}

    for entry in entries:
        cn = only(entry, 'cn')
        networkNumber = only(entry, 'ipNetworkNumber')
        netmaskNumber = only(entry, 'ipNetmaskNumber')
        net = Net(entry.dn, cn, networkNumber, netmaskNumber,
                  routers=entry.get('router', ()),
                  dhcpRanges=entry.get('dhcpRange', ()),
                  winsServers=entry.get('winsServer', ()),
                  domainNameServers=entry.get('domainNameServer', ()),
                  )

        if not entry.has_key('sharedNetworkName'):
            standalone.append(net)
            continue

        name = only(entry, 'sharedNetworkName')
        if name not in shared:
            shared[name] = SharedNet(name)
        shared[name].addNet(net)

    return (standalone, shared)

def getNets(e, filter):
    """Search through *e* for network entries (cn, ipNetworkNumber, ipNetmaskNumber).

    *filter*, when given, is ANDed with those presence tests. Returns a
    Deferred firing with the ``(nets, sharedNetworks)`` pair from _cbGetNets.
    """
    required = [pureldap.LDAPFilter_present(attr)
                for attr in ('cn', 'ipNetworkNumber', 'ipNetmaskNumber')]
    filt = pureldap.LDAPFilter_and(value=tuple(required))
    if filter:
        filt = pureldap.LDAPFilter_and(value=(filter, filt))

    wanted = ['cn',
              'ipNetworkNumber',
              'ipNetmaskNumber',
              'router',
              'dhcpRange',
              'winsServer',
              'domainNameServer',
              'sharedNetworkName']
    return e.search(filterObject=filt,
                    attributes=wanted).addCallback(_cbGetNets)

def search(client, baseDN, filter, dnsDomain):
    """Run the export over a connected *client*: nets, then hosts, then groups.

    Returns a Deferred whose chain ends with haveGroups printing the config.
    """
    root = ldapsyntax.LDAPEntry(client=client, dn=baseDN)
    return getNets(root, filter).addCallback(
        haveNets, root, baseDN, filter, dnsDomain)


# Process exit status handed to sys.exit() by main(); error() sets it to 1.
exitStatus = 0

def error(fail):
    """Errback: report the failure message on stderr and mark the run failed."""
    global exitStatus
    sys.stderr.write('fail: %s\n' % (fail.getErrorMessage(),))
    exitStatus = 1

def main(cfg, filter_text, dnsDomain):
    """Print the directory's DHCP data as dhcpd.conf on stdout, then exit.

    *cfg* supplies the base DN and service-location overrides. *filter_text*
    is an optional LDAP filter string ANDed into every search. *dnsDomain*
    is the DNS domain used to build subnet and host domain names. Never
    returns: exits with 1 when no base DN is configured, otherwise with
    ``exitStatus`` once the reactor stops.
    """
    try:
        baseDN = cfg.getBaseDN()
    except config.MissingBaseDNError as err:
        sys.stderr.write('%s: %s.\n' % (sys.argv[0], err))
        sys.exit(1)

    from twisted.python import log
    log.startLogging(sys.stderr, setStdout=0)

    filt = None
    if filter_text is not None:
        filt = ldapfilter.parseFilter(filter_text)

    creator = ldapconnector.LDAPClientCreator(reactor, ldapclient.LDAPClient)
    connected = creator.connectAnonymously(
        dn=baseDN, overrides=cfg.getServiceLocationOverrides())
    connected.addCallback(search, baseDN, filt, dnsDomain)
    connected.addErrback(error)
    # Stop the reactor whether the export succeeded or failed.
    connected.addBoth(lambda _ignored: reactor.stop())

    reactor.run()
    sys.exit(exitStatus)

class MyOptions(usage.Options,
                usage.Options_service_location,
                usage.Options_base_optional):
    """LDAPtor dhcpd config file exporter"""
    # The mixins presumably contribute the 'service-location' and 'base'
    # options read in the __main__ block -- confirm against ldaptor.usage.

    # Twisted option table entries: (long name, short name, default, help).
    optParameters = (
        ('dns-domain', None, 'example.com',
         "DNS domain name"),
        )

    def parseArgs(self, filter=None):
        # Optional positional argument: an LDAP filter string. It is passed
        # to main() as filter_text, which parses it with
        # ldapfilter.parseFilter when it is not None.
        self.opts['filter'] = filter

if __name__ == "__main__":
    # Also binds the module-level name ``sys`` used by the functions above.
    import sys

    try:
        opts = MyOptions()
        opts.parseOptions()
    except usage.UsageError as ue:
        sys.stderr.write('%s: %s\n' % (sys.argv[0], ue))
        sys.exit(1)

    main(config.LDAPConfig(baseDN=opts['base'],
                           serviceLocationOverrides=opts['service-location']),
         opts['filter'],
         opts['dns-domain'],
         )
