import shlex, subprocess, logging
from netaddr import IPNetwork
import netifaces, sys, getopt, re, os

# Module-wide switches; main() sets them from -d/--dry-run and -v/--verbose,
# and configure_routing() can also turn dry_run on.
# dry_run is passed to exec_local_cmd() so that state-changing commands are
# printed instead of executed.
dry_run = False
# When true, exec_local_cmd() prints every command before running it.
verbose = False

def exec_local_cmd(cmd, dr=False):
	"""Run a command line locally and capture its combined output.

	Args:
		cmd: the command as a single string; it is tokenised with
			``shlex.split`` and executed without a shell.
		dr: dry-run flag.  When true the command is only printed.

	Returns:
		``True`` in dry-run mode, ``None`` when the process could not be
		started or exited with a non-zero status, otherwise the tuple
		``(output_bytes, 0)`` where stderr is merged into the output.
	"""
	global verbose

	# Dry run: echo the command and pretend it succeeded.
	if dr:
		print(cmd)
		return True
	if verbose:
		print(cmd)

	argv = shlex.split(cmd)
	try:
		proc = subprocess.Popen(argv, stderr=subprocess.STDOUT,
			stdout=subprocess.PIPE)
	except Exception as err:
		# e.g. the executable does not exist
		logging.critical(err)
		return None

	output = proc.communicate()[0]
	status = proc.returncode
	if status != 0:
		return None
	return output, status

# sysctl attribute -> value applied under net.ipv4.conf.all.<attribute>
# (see build_arp_info()).
arp_info = {'accept_local': 1, 'arp_announce': 2, 'rp_filter': 0}
# sysctl attribute -> value applied under net.ipv4.conf.<interface>.<attribute>
# for every interface handed to build_arp_info().
intf_arp_info =  {'arp_ignore': 1, 'arp_filter': 0, 'arp_announce': 2, 'rp_filter': 0}

def get_attr(attr_path, value, def_arp, new_arp):
	"""Record the current and the desired value of one sysctl attribute.

	Runs ``sysctl <attr_path>`` to read the live value, stores it in
	``def_arp`` and stores ``value`` in ``new_arp``, both keyed by
	``attr_path``.

	Args:
		attr_path: full sysctl key, e.g. ``net.ipv4.conf.all.rp_filter``.
		value: value to apply later for this key.
		def_arp: dict updated in place with the current value (a string).
		new_arp: dict updated in place with the desired value.

	Returns:
		The ``(def_arp, new_arp)`` pair that was passed in.

	Raises:
		SystemExit: with status 1 when the attribute cannot be read.
	"""
	rc = exec_local_cmd("sysctl "+attr_path)
	if rc is None:
		print("Failed No " + attr_path)
		# Use sys.exit rather than the site-provided exit() helper, and a
		# non-zero status so the failure is visible to the calling shell.
		sys.exit(1)
	# sysctl prints "<key> = <value>\n": split only once so the value is kept
	# whole, and drop the trailing newline so it can be replayed verbatim.
	def_arp[attr_path] = rc[0].decode("utf-8").split(' = ', 1)[1].strip()
	new_arp[attr_path] = value
	return def_arp, new_arp

def build_arp_info(intf_list):
	"""Collect current and target sysctl values for the ARP/rp_filter knobs.

	The ``arp_info`` attributes are looked up under ``net.ipv4.conf.all.``
	first, followed by the ``intf_arp_info`` attributes under
	``net.ipv4.conf.<interface>.`` for each entry of ``intf_list``.

	Args:
		intf_list: iterable of interface names.

	Returns:
		``(current, wanted)``: two dicts keyed by the full sysctl path, as
		filled in by ``get_attr``.
	"""
	def _targets():
		# Lazily yield (sysctl path, desired value), global keys first.
		for name, wanted_value in arp_info.items():
			yield 'net.ipv4.conf.all.' + name, wanted_value
		for intf in intf_list:
			for name, wanted_value in intf_arp_info.items():
				yield 'net.ipv4.conf.' + intf + '.' + name, wanted_value

	current = {}
	wanted = {}
	for path, wanted_value in _targets():
		current, wanted = get_attr(path, wanted_value, current, wanted)
	return current, wanted

def setup_arp(attr_dict):
	"""Write every sysctl key/value pair of ``attr_dict`` with ``sysctl -w``.

	Honours the module-level ``dry_run`` flag.  Stops at the first key that
	cannot be written.

	Returns:
		``True`` when all writes succeeded (or the dict is empty), ``False``
		as soon as one write fails.
	"""
	global dry_run
	for attr in attr_dict:
		write_cmd = "sysctl -w " + attr + '=' + str(attr_dict[attr])
		outcome = exec_local_cmd(write_cmd, dry_run)
		if outcome is not None:
			continue
		# first failure aborts the remaining writes
		print("Failed. No " + attr)
		return False
	return True

def add_routes(intf_list):
	"""Create one source-routed table per interface.

	For every interface in ``intf_list`` this flushes an existing routing
	table named after the interface, makes sure the name is registered in
	``/etc/iproute2/rt_tables``, installs the connected-subnet route in that
	table and adds an ``ip rule`` that selects the table for traffic sourced
	from the interface address.  With the module-level ``dry_run`` flag set,
	the state-changing commands are printed instead of executed.

	Args:
		intf_list: list of interface names; each name doubles as the name of
			its routing table.

	Raises:
		ValueError: if the routing tables cannot be listed, if
			``/etc/iproute2/rt_tables`` is missing or holds an entry whose id
			is not an integer, or if adding a route fails.
	"""
	global dry_run

	rt_tables_path = '/etc/iproute2/rt_tables'

	# find all the tables
	rc = exec_local_cmd("ip route show table all")
	if rc is None:
		raise ValueError("No routing information")

	# remove any existing tables (each table only needs one flush, however
	# many of its routes are listed)
	flushed = set()
	for rt in rc[0].decode("utf-8").strip().split("\n"):
		tokens = rt.split()
		if 'table' not in tokens:
			continue
		name_pos = tokens.index('table') + 1
		if name_pos >= len(tokens):
			continue
		table = tokens[name_pos]
		if table in intf_list and table not in flushed:
			flushed.add(table)
			exec_local_cmd('ip route flush table ' + table, dry_run)

	# add new entries for routing tables if they don't exist
	if not os.path.isfile(rt_tables_path):
		raise ValueError('/etc/iproute2/rt_tables not found')
	with open(rt_tables_path, 'r') as f:
		lines = f.readlines()

	max_table_num = 0
	known_tables = set()
	for line in lines:
		# ignore comments (whole-line or trailing) and blank lines
		fields = line.split('#', 1)[0].split()
		if not fields:
			continue
		try:
			table_num = int(fields[0])
		except ValueError:
			raise ValueError('Bad entry in /etc/iproute2/rt_tables: ' + line.strip())
		if max_table_num < table_num:
			max_table_num = table_num
		# compare whole table names so that "eth1" is not mistaken for "eth10"
		if len(fields) > 1:
			known_tables.add(fields[1])

	# interfaces still lacking a table entry, in order and without duplicates
	missing = []
	for intf in intf_list:
		if intf not in known_tables and intf not in missing:
			missing.append(intf)

	# an appended entry must start on its own line
	needs_newline = bool(lines) and not lines[-1].endswith('\n')
	for intf in missing:
		max_table_num += 1
		entry = str(max_table_num) + ' ' + intf
		if dry_run:
			print('echo "' + entry + '" >> /etc/iproute2/rt_tables')
			continue
		with open(rt_tables_path, 'a') as f:
			if needs_newline:
				f.write('\n')
				needs_newline = False
			f.write(entry + '\n')

	# add the routing entries and rules
	for intf in intf_list:
		# NOTE(review): an interface without an IPv4 address presumably makes
		# this lookup raise KeyError - confirm whether that should be handled.
		inet = netifaces.ifaddresses(intf)[netifaces.AF_INET][0]
		addr = inet['addr']
		netmask = inet['netmask']
		cidr = str(IPNetwork(addr+'/'+netmask).cidr)
		cmd = 'ip route add '+cidr+' dev '+ intf + ' proto kernel scope link src '+addr+' table '+intf
		if not exec_local_cmd(cmd, dry_run):
			raise ValueError('Failed: '+cmd)
		# drop a stale rule first so the add below does not create a duplicate;
		# the result of the delete is deliberately ignored
		cmd = 'ip rule del from '+addr+' table '+intf
		exec_local_cmd(cmd, dry_run)
		# add rule
		cmd = 'ip rule add from '+addr+' table '+intf
		exec_local_cmd(cmd, dry_run)
	exec_local_cmd('ip route flush cache', dry_run)

def flush_neigh(intf_list):
	"""Flush the neighbour table of every interface in ``intf_list``.

	Runs ``ip neigh flush dev <interface>`` (printed only when the
	module-level ``dry_run`` flag is set) and reports each interface for
	which the command failed; failures do not stop the loop.
	"""
	for intf in intf_list:
		outcome = exec_local_cmd("ip neigh flush dev "+intf, dry_run)
		if outcome is not None:
			continue
		print("Failed. No " + intf)

def configure_routing(intf_list, dr=False):
	"""Apply the ARP sysctls, routing tables and rules for ``intf_list``.

	Args:
		intf_list: list of interface names to configure.
		dr: when truthy, switches the module-level ``dry_run`` flag on; a
			falsy value leaves the flag as it is.

	The previous sysctl values are re-applied if writing the new ones fails.
	Routes are added and neighbour tables flushed afterwards in either case.
	"""
	global dry_run
	if dr:
		dry_run = dr

	saved_arp, wanted_arp = build_arp_info(intf_list)
	applied = setup_arp(wanted_arp)
	if not applied:
		# roll back to the values read before the change
		setup_arp(saved_arp)

	add_routes(intf_list)
	flush_neigh(intf_list)


def main(argv):
	"""Command-line entry point.

	Options:
		-h, --help         print the usage line and exit
		-d, --dry-run      print the commands instead of running them
		-v, --verbose      print every command before running it
		-i, --if <list>    interfaces to configure, separated by commas
		                   and/or spaces

	Args:
		argv: the argument list without the program name.

	Raises:
		SystemExit: with status 2 on an unknown option; with the default
			status after printing the usage line (help requested, no
			arguments, or no interface name given).
	"""
	global dry_run
	global verbose

	usage = "python %s [--dry-run] -i <comma separated interface list>" % __file__
	intf_list = ''

	try:
		opts, args = getopt.getopt(argv,"hdvi:",["help", "dry-run", "verbose", "if="])
	except getopt.GetoptError as err:
		# say what was wrong instead of exiting silently
		print(err)
		print(usage)
		sys.exit(2)

	if len(argv) < 1:
		print(usage)
		sys.exit()

	for opt, arg in opts:
		# getopt accepts --help, so it must be handled like -h
		if opt in ("-h", "--help"):
			print(usage)
			sys.exit()
		elif opt in ("-i", "--if"):
			intf_list = arg
		elif opt in ("-d", "--dry-run"):
			dry_run = True
		elif opt in ("-v", "--verbose"):
			verbose = True

	# "eth0, eth1" or a trailing comma would otherwise produce empty names
	interfaces = [name for name in re.split(',| ', intf_list) if name]
	if not interfaces:
		print(usage)
		sys.exit()

	configure_routing(interfaces)

# Run as a script: hand the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
	main(sys.argv[1:])
