#! /usr/bin/env python

import json

def uuid_add(uuid,idx):
    """Return ``uuid`` with ``idx`` added to its last '-'-separated hex group.

    The sum wraps within 48 bits (the size of a standard UUID's 12-digit
    final group) and is written back in lowercase hex, zero-padded to the
    width the group had in the input, so the result keeps the UUID's shape.

    Args:
        uuid: UUID text such as "12345678-1234-5678-1234-000000000000".
        idx: integer offset to add to the final group.

    Returns:
        The adjusted UUID string.
    """
    uuid_pieces = uuid.split('-')
    last_piece = uuid_pieces[-1]
    last_one = 0xffffffffffff & (int(last_piece, base=16) + idx)
    # Pad to the original width: a bare "%0x" would render
    # "...-000000000002" as "...-2", which is no longer a valid UUID.
    uuid_pieces[-1] = "%0*x" % (len(last_piece), last_one)
    return '-'.join(uuid_pieces)

def collapse_subnets(full_list):
    """Reduce a list of CIDR strings to those not contained in another entry.

    Each entry is parsed with cidr_to_hex.  Any subnet whose address range
    lies entirely within another entry's range is dropped, and duplicates
    collapse to a single entry.  Subnets that share a base address but have
    different prefix lengths are handled correctly: the widest one is kept.

    Args:
        full_list: iterable of "a.b.c.d/len" strings.

    Returns:
        A list of the surviving CIDR strings, ordered by base address.
    """
    blocks = []
    for subnet in full_list:
        first, last = cidr_to_hex(subnet)
        blocks.append((first, last, subnet))
    # Ascending base and, for equal bases, widest block first. A /8 is then
    # seen before the /24 that shares its base address instead of being
    # overwritten by it.
    blocks.sort(key=lambda b: (b[0], -b[1]))

    subnets = []
    covered_to = -1
    for first, last, subnet in blocks:
        # Aligned CIDR blocks are either nested or disjoint. A block ending
        # at or before the furthest address already covered therefore lies
        # inside a block that was kept.
        if last <= covered_to:
            continue
        subnets.append(subnet)
        covered_to = last
    return subnets

def cidr_to_hex(cidr):
    """Return the numeric ``(first, last)`` address bounds of a CIDR block.

    ``cidr`` is "a.b.c.d/len".  Host bits in the address are cleared, so
    "10.1.2.3/16" and "10.1.0.0/16" both yield (0x0A010000, 0x0A01FFFF).
    Raises ValueError if the text does not split into exactly one address
    and one prefix, or the address into exactly four dotted parts, or if a
    part is not an integer.
    """
    address, prefix = cidr.split("/")
    host_bits = 32 - int(prefix)
    netmask = (0xffffffff << host_bits) & 0xffffffff

    first, second, third, fourth = address.split(".")
    network = 0
    for octet in (first, second, third, fourth):
        network = (network << 8) + int(octet)
    # Snap to the start of the block so an address inside it is accepted.
    network &= netmask
    return network, network + (1 << host_bits) - 1

def hex_to_quads(hex_addr):
    """Render an integer IPv4 address as dotted-quad text.

    Only the low 32 bits contribute; each byte, most significant first,
    becomes one decimal component (e.g. 0x0A000001 -> "10.0.0.1").
    """
    return ".".join(str((hex_addr >> shift) & 0xff) for shift in (24, 16, 8, 0))

def extract(ntwk,cidr,root,root_bridge,root_address,uplist=None):
    """Build output records for a network description and all its children.

    The description is a dict that must contain 'name' and 'subnet'.  It may
    also contain:
      * 'nodes': a dict with a required 'num' (node count), an optional
        'uuid' (base UUID, offset per node via uuid_add) and any other keys,
        which are copied onto every generated node;
      * 'networks': a list of child descriptions, processed recursively with
        this network's gateway as their upstream.

    The gateway of a network is named ``name + "0"`` and takes the subnet's
    base address.  Nodes are named ``name + "1"``, ``name + "2"``, ... and
    take the addresses that follow.

    Args:
        ntwk: the network description dict.
        cidr: top-level CIDR, stored as "top" and always included in "upward".
        root: name of the upstream gateway.
        root_bridge: bridge device of the upstream gateway.
        root_address: address of the upstream gateway.
        uplist: subnets the parent routes upward (its "upward" list).
            Defaults to an empty list.

    Returns:
        A tuple ``(networks, current)``. ``current`` is the record for
        ``ntwk``. ``networks`` is a flat list holding the records of every
        descendant followed by ``current``.

    Raises:
        ValueError: if 'name', 'subnet' or the nodes' 'num' is missing, or if
            the subnet cannot hold the requested number of nodes.  Malformed
            CIDR strings also raise ValueError from cidr_to_hex.
    """
    if uplist is None:
        uplist = []
    # `module` only exists inside main(), so report problems by raising
    # instead of calling module.fail_json (which raised NameError here).
    if 'name' not in ntwk:
        raise ValueError("Missing 'name' attribute")
    name = ntwk['name']
    if 'subnet' not in ntwk:
        raise ValueError("Missing 'subnet' attribute")

    networks = []
    gw_name = name + "0"
    current = { "name": name }
    current["top"] = cidr
    current["subnet"] = ntwk['subnet']
    # Routes towards the root: the top-level block, this subnet and whatever
    # the parent already routes upward, with nested subnets collapsed away.
    current["upward"] = collapse_subnets( [ cidr, current["subnet"] ] + list(uplist) )
    addr_base, max_addr = cidr_to_hex(current["subnet"])
    gw_addr = addr_base
    gw_quads = hex_to_quads(gw_addr)
    bridge_dev = "ctr_b%0X" % gw_addr

    if 'nodes' in ntwk:
        nodes = ntwk['nodes']
        if 'num' not in nodes:
            raise ValueError("Missing 'num' attribute")
        num_nodes = nodes['num']
        # Nodes occupy addr_base+1 .. addr_base+num_nodes.
        if addr_base + num_nodes > max_addr:
            raise ValueError(
                "Requested %d nodes, but address space is too small" % num_nodes
            )
        # Every other key is copied onto each node and overrides a computed
        # field of the same name.
        passthrough = dict(
            (k, v) for k, v in nodes.items() if k not in ('num', 'uuid')
        )
        current["nodes"] = []
        for node_idx in range(1, num_nodes + 1):
            addr = addr_base + node_idx
            node_data = {
                "name" : name + str(node_idx),
                "address" : hex_to_quads(addr),
                "internal_device" : "ctr_i%0X" % addr,
                "external_device" : "ctr_e%0X" % addr,
                "route_up" : current["upward"],
                "gateway" : gw_name,
                "bridge" : bridge_dev,
                "gateway_address" : gw_quads,
                "subnet" : current["subnet"]
            }
            if 'uuid' in nodes:
                node_data["uuid"] = uuid_add(nodes["uuid"],node_idx)
            node_data.update(passthrough)
            current["nodes"].append(node_data)

    if 'networks' in ntwk:
        current["downstreams"] = []
        for sub_desc in ntwk['networks']:
            sub_networks, sub = extract(sub_desc, cidr,
                                        gw_name, bridge_dev, gw_quads,
                                        uplist=current["upward"])
            networks.extend(sub_networks)
            # The child's "reachable" is already its own subnet plus
            # everything below it, collapsed. That is exactly what its
            # downstream entry needs, so reuse it instead of recomputing.
            current["downstreams"].append({
                "name": sub["name"],
                "reachable": list(sub["reachable"]),
            })

    # This network reaches its own subnet plus anything a downstream reaches.
    if 'downstreams' in current:
        ds_ntwks = [ current["subnet"] ]
        for ds in current["downstreams"]:
            ds_ntwks.extend(ds["reachable"])
        current["reachable"] = collapse_subnets( ds_ntwks )
    else:
        current["reachable"] = [ current["subnet"] ]

    current["upstream"] = root
    current["upstream_bridge"] = root_bridge
    current["upstream_address"] = root_address
    current["gateway"] = {
        "name" : gw_name,
        "address" : gw_quads,
        "bridge" : bridge_dev,
        "internal_device" : "ctr_i%0X" % gw_addr,
        "external_device" : "ctr_e%0X" % gw_addr,
        "upstream" : root,
        "upstream_bridge" : root_bridge,
        "upstream_address" : root_address,
        "route_down" : current["reachable"],
        "route_up" : current["upward"],
    }
    networks.append(current)
    return (networks,current)

def main():
    """Ansible entry point: expand a network description into facts.

    Reads the ``network`` description (JSON text), the top-level ``cidr`` and
    the root gateway's name, bridge and address. It flattens the tree with
    extract() and exits with the records keyed by network name, the list of
    all gateways, the list of all nodes, a summary of the top network and the
    subnets the top network can reach.  Never reports a change.
    """
    module = AnsibleModule(argument_spec=dict(
        network=dict(required=True),
        cidr=dict(required=True),
        root=dict(required=True),
        root_bridge=dict(required=True),
        root_address=dict(required=True),
    ))

    changed = False
    params = module.params
    try:
        # The description may use single quotes (e.g. a Python-style dict),
        # so swap them for double quotes before parsing as JSON.
        # NOTE(review): this also rewrites any apostrophe inside a value.
        description = json.loads(params['network'].replace('\'', '"'))
        netlist, top = extract(description, params['cidr'], params['root'],
                               params['root_bridge'], params['root_address'])
    except ValueError as err:
        # Malformed JSON or CIDR text (and any ValueError raised by
        # extract) is reported through Ansible rather than as a traceback.
        module.fail_json(msg=str(err))
        return

    netmap = dict((net["name"], net) for net in netlist)
    gateways = []
    nodes = []
    for ntwk in netmap.values():
        gateways.append(ntwk["gateway"])
        nodes.extend(ntwk.get("nodes", []))

    gateway = top["gateway"]
    top_dict = {
        "name" : top["name"],
        "gateway" : gateway["name"],
        "bridge" : gateway["bridge"],
        "address" : gateway["address"],
        "external_device" : gateway["external_device"],
        "internal_device" : gateway["internal_device"]
    }
    reachable = top["reachable"]

    module.exit_json(changed=changed,networks=netmap,top=top_dict,gateways=gateways,nodes=nodes,reachable=reachable)

# Supplies AnsibleModule, which main() uses.  NOTE(review): older Ansible
# module tooling looked for this exact wildcard import line, so keep it
# verbatim rather than converting it to an explicit import.
from ansible.module_utils.basic import *

# Run only when executed as a script, so importing this file (e.g. for tests)
# has no side effects.
if __name__ == '__main__':
    main()
