#lustre:
#   operation: <build|config|all>
#   build_src: <path/to/lustre_src>
#   build_script: <path/to/build_script>
#   config_name: lnet
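#   lnet:
#      net:
#        - net type: <network>
#          interfaces: <intf list>
#          tunables:
#             peer_credits: <>
#             peer_credits_hiw: <>
#             concurrent_sends: <>
#             <other lnet or LND module params>: <>
#      global:
#         <lnet module params>: <>
#      route:
#        - net: <>
#          gateway: <>
#          hop: <>
#          priority: <>
#          health_sensitivity: <>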
#   lustre:
#      <lustre parameters>: <>
#

import shlex, subprocess, logging
import sys, getopt, re, os
import yaml
import mrrouting

# Module-wide flags; main() sets them from the command line.
# verbose: exec_local_cmd() echoes each command before running it (-v/--verbose).
verbose = False
# dry_run: print the actions instead of performing them (-d/--dry-run).
dry_run = False

class yamlDumper(yaml.Dumper):
	"""yaml.Dumper variant that never emits "indentless" blocks.

	The indentless value supplied by the emitter is discarded and False
	is forwarded in its place; flow is passed through untouched.
	"""
	def increase_indent(self, flow=False, indentless=False):
		# indentless is intentionally ignored, see the class docstring
		return super().increase_indent(flow, False)

def exec_local_cmd(cmd, dr=False):
	"""Run a command line locally and collect its output.

	cmd is tokenised with shlex and executed without a shell; stderr is
	merged into stdout.

	Returns:
		True when dr (dry run) is set: the command is only printed.
		None when the process cannot be started (the error is logged as
		critical) or when it exits with a non-zero status.
		Otherwise the tuple (captured output as bytes, return code).
	"""
	global verbose

	# Echo the command for dry runs and in verbose mode
	if dr or verbose:
		print(cmd)
	if dr:
		return True

	argv = shlex.split(cmd)
	try:
		proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
			stderr=subprocess.STDOUT)
	except Exception as err:
		logging.critical(err)
		return None

	output = proc.communicate()[0]
	status = proc.returncode
	if status != 0:
		return None
	return output, status

def grab_mod_info(module):
	"""Run 'modinfo <module>' and return its output as a dictionary.

	Every output line is split into a leading "key:" token and the rest
	of the line. Plain fields are stored as infod[key] = value (a later
	line with the same key overwrites an earlier one). 'parm' lines are
	collected into the nested dictionary infod['parm'], mapping the
	parameter name to the text that follows its first ':'.

	Returns None when the modinfo command fails.
	"""
	rc = exec_local_cmd('modinfo '+module)
	if not rc:
		return None
	infod = {}
	for line in rc[0].decode('utf-8', errors='replace').split('\n'):
		desc = line.split(maxsplit=1)
		if not desc:
			continue

		# drop the trailing ':' of the field name
		key = desc[0][:-1]
		# a field can be printed with an empty value, leaving only the
		# "key:" token on the line
		value = desc[1].strip() if len(desc) > 1 else ''

		if key == 'parm':
			# "name:description"; tolerate a missing description
			name, _, text = value.partition(':')
			infod.setdefault('parm', {})[name] = text
		else:
			infod[key] = value
	return infod

def translate2lnet_yaml_net(nets):
	"""Convert the configuration's net list into the LNet YAML net list.

	nets is a list of dictionaries, each with a 'net type', an
	'interfaces' string (names separated by spaces and/or commas) and
	optionally a 'tunables' dictionary.

	Entries sharing a 'net type' are merged into one output entry, in
	order of first appearance. Each distinct interface becomes one
	element of that entry's 'local NI(s)' list, carrying a copy of the
	tunables of the input entry that first named it. Empty interface
	names, which the split produces for input such as "eth0, eth1", are
	skipped.
	"""
	networks = []
	by_type = {}
	for e in nets:
		net_type = e['net type']
		entry = by_type.get(net_type)
		if entry is None:
			entry = {'net type': net_type, 'local NI(s)': []}
			by_type[net_type] = entry
			networks.append(entry)
		local_nis = entry['local NI(s)']
		known = {ni['interfaces'][0] for ni in local_nis}
		for intf in re.split(' |,', e['interfaces']):
			# adjacent separators yield '' - not an interface
			if not intf or intf in known:
				continue
			known.add(intf)
			intf_entry = {'interfaces': {0: intf}}
			if 'tunables' in e:
				intf_entry['tunables'] = e['tunables'].copy()
			local_nis.append(intf_entry)
	return networks

def translate2lnet_yaml_global(glob):
	"""Map lnet module parameter names onto LNet YAML 'global' keys.

	Only the parameters listed below are carried over; anything else in
	glob is ignored. Values are copied unchanged, except for
	lnet_peer_discovery_disabled, which is inverted: a value equal to 0
	yields discovery 1, any other value yields discovery 0.
	"""
	# (module parameter, YAML key), in the order the keys are emitted
	renames = (
		('lnet_numa_range', 'numa_range'),
		('lnet_peer_discovery_disabled', 'discovery'),
		('lnet_transaction_timeout', 'transaction_timeout'),
		('lnet_health_sensitivity', 'health_sensitivity'),
		('lnet_recovery_interval', 'recovery_interval'),
		('lnet_router_sensitivity', 'router_sensitivity'),
		('lnet_drop_asym_route', 'drop_asym_route'),
		('lnet_retry_count', 'retry_count'),
	)
	glob_yaml = {}
	for param, yaml_key in renames:
		if param not in glob:
			continue
		setting = glob[param]
		if yaml_key == 'discovery':
			# "disabled" flag becomes an "enabled" flag
			setting = 1 if setting == 0 else 0
		glob_yaml[yaml_key] = setting
	return glob_yaml

def translate2lnet_yaml(lnetcfg):
	"""Build the LNet YAML configuration from the 'lnet' config section.

	Each of the 'net', 'route' and 'global' sections is optional and is
	only emitted when present in lnetcfg: 'net' and 'global' are
	translated, 'route' is passed through unchanged. A configuration
	without a 'net' section no longer raises KeyError.
	"""
	config = {}
	if 'net' in lnetcfg:
		config['net'] = translate2lnet_yaml_net(lnetcfg['net'])
	if 'route' in lnetcfg:
		config['route'] = lnetcfg['route']
	if 'global' in lnetcfg:
		config['global'] = translate2lnet_yaml_global(lnetcfg['global'])
	return config

def build_lustre(script, src, pybin):
	"""Run the Lustre build script as '<pybin> <script> -v -s <src>'.

	script: path of the build script; must be an existing file.
	src:    Lustre source; must be an existing file or directory.
	pybin:  interpreter command used to run the script. It is tokenised
	        with shlex, so it may carry its own arguments.

	Exits with status 2 when script or src is missing. In dry-run mode
	the command is only printed. script and src are passed to the child
	process as single arguments, so paths containing spaces or quotes
	are handled correctly. A non-zero exit status of the build script is
	reported but does not abort the caller.
	"""
	global dry_run
	if not os.path.isfile(script):
		print("build script '%s' missing" % script)
		sys.exit(2)
	if not os.path.isfile(src) and \
	   not os.path.isdir(src):
		print("Lustre source '%s' missing" % src)
		sys.exit(2)

	if dry_run:
		print(pybin+' '+script+' '+'-v -s '+src)
		return

	# Only the interpreter command is split; the paths go in verbatim
	args = shlex.split(pybin) + [script, '-v', '-s', src]
	result = subprocess.run(args)
	if result.returncode != 0:
		print("build script exited with status %d" % result.returncode)
	print("building Lustre script finished")

def config_lustre(params):
	"""Placeholder for Lustre configuration: params is ignored and a
	notice is printed."""
	notice = "This functionality is currently not supported"
	print(notice)

def is_param_valid(key, value, *dicts):
	"""Return True when key is present in at least one of dicts.

	value is accepted but not inspected yet.
	"""
	# TODO: validate parameter
	return any(key in known for known in dicts)

def _lnd_module_for(net_type):
	"""Return the name of the LND kernel module that serves net_type."""
	if 'tcp' in net_type:
		return 'ksocklnd'
	if 'o2ib' in net_type:
		return 'ko2iblnd'
	raise ValueError("unknown net type %s" % net_type)

def _split_interfaces(spec):
	"""Split a space/comma separated interface list, dropping empty names."""
	return [name for name in re.split(' |,', spec) if name]

def config_lnet(params, config_name):
	"""Validate the 'lnet' configuration section and write it out.

	params may hold the optional sections 'net' (list of entries with
	'net type', 'interfaces' and optional 'tunables'), 'global' (lnet
	module parameters) and 'route' (list of route dictionaries).

	Tunables and globals are checked against the parameters reported by
	modinfo for lnet, ksocklnd and ko2iblnd. Two files named after
	config_name are then produced: module options under
	/etc/modprobe.d and the YAML form of the configuration under /etc.
	In dry-run mode both are printed instead of written. Finally the
	collected interface list is handed to mrrouting.configure_routing().

	Exits with status 2 when one of the three modules is not installed.
	Prints a message and returns early when a net entry lacks
	'net type' or 'interfaces'.

	Raises:
		ValueError: unknown net type, or a tunable/global/route key
		that is not recognized.
	"""
	# run modinfo to grab all the parameters
	modules = {}
	for name in ('lnet', 'ksocklnd', 'ko2iblnd'):
		modules[name] = grab_mod_info(name)
	# If a module is not installed then output an error
	for name in ('lnet', 'ksocklnd', 'ko2iblnd'):
		if not modules[name]:
			print('%s is not installed' % name)
			sys.exit(2)
	lnet_parms = modules['lnet'].get('parm', {})

	# A missing or empty section simply means "nothing to configure"
	nets = params.get('net') or []
	glob = params.get('global') or {}
	routes = params.get('route') or []

	# Verify that all parameters defined in the YAML config are valid
	for e in nets:
		if 'net type' not in e or 'interfaces' not in e:
			print("malformed lnet configuration")
			return
		lnd_parms = modules[_lnd_module_for(e['net type'])].get('parm', {})
		for k, v in (e.get('tunables') or {}).items():
			if not is_param_valid(k, v, lnet_parms, lnd_parms):
				raise ValueError("unrecognized parameter: '%s'" % k)
	for k, v in glob.items():
		if not is_param_valid(k, v, lnet_parms):
			raise ValueError("unrecognized parameter: '%s'" % k)
	route_params = ['gateway', 'health_sensitivity', 'hop', 'net', 'priority']
	for r in routes:
		for k in r:
			if k not in route_params:
				raise ValueError("unrecognized parameter: '%s'" % k)

	# Build the modprobe options
	intf_list = []
	net_specs = []
	tunable_lines = []
	# (module, parameter) pairs already emitted; a parameter must not be
	# repeated for the same module in the modprobe configuration
	emitted = set()
	for e in nets:
		intfs = _split_interfaces(e['interfaces'])
		for i in intfs:
			if i not in intf_list:
				intf_list.append(i)
		net_specs.append(e['net type']+'('+','.join(intfs)+')')
		# tunables that are not lnet parameters belong to the LND of
		# this particular net
		lnd_name = _lnd_module_for(e['net type'])
		for k, v in (e.get('tunables') or {}).items():
			mod = 'lnet' if k in lnet_parms else lnd_name
			if (mod, k) not in emitted:
				emitted.add((mod, k))
				tunable_lines.append("options " + mod + " "+k+"=" + str(v)+"\n")
	for k, v in glob.items():
		if ('lnet', k) not in emitted:
			emitted.add(('lnet', k))
			tunable_lines.append("options lnet "+k+"=" + str(v)+"\n")

	network_options = ''
	if net_specs:
		network_options = 'options lnet networks="' + ','.join(net_specs) + '"\n'
	tunable_options = ''.join(tunable_lines)

	# NOTE(review): routes are validated above and reach LNet through the
	# YAML file only. The previous code also assembled an
	# 'options lnet routes="..."' line but never wrote it; confirm the
	# expected field order of the routes module parameter before adding
	# it to the modprobe file.

	# translate2lnet_yaml() reads the 'net' section, so always provide
	# one and drop it again from the result when it is empty
	yaml_cfg = translate2lnet_yaml(dict(params, net=nets))
	if not nets:
		yaml_cfg.pop('net', None)

	# put the parameters in /etc/modprobe.d/<config_name>.cfg
	# NOTE(review): modprobe.d(5) documents that only files ending in
	# .conf are read from /etc/modprobe.d; confirm whether the .cfg
	# suffix is intended.
	base = os.path.splitext(config_name)[0]
	cfgpath = os.path.join(os.sep, 'etc', 'modprobe.d', base+'.cfg')
	yamlpath = os.path.join(os.sep, 'etc', base+'.yaml')
	yaml_text = yaml.dump(yaml_cfg, Dumper=yamlDumper, indent=2, sort_keys=False)
	if not dry_run:
		with open(cfgpath, 'w') as f:
			f.write(network_options)
			f.write(tunable_options)
		with open(yamlpath, 'w') as f:
			f.write(yaml_text)
	else:
		print("#> cat %s" % cfgpath)
		print(network_options)
		print(tunable_options)
		print("----------")
		print("#> cat %s" % yamlpath)
		print(yaml_text)
		print("----------")

	# configure the interfaces
	mrrouting.configure_routing(intf_list, dry_run)

def _config_error(msg):
	"""Print msg and raise ValueError carrying the same text."""
	print(msg)
	raise ValueError(msg)

def process_config(config):
	"""Load the YAML configuration file and run the requested operation.

	The file must contain a top-level 'lustre' mapping whose 'operation'
	is one of:
		build  - run the build script ('build_src' and 'build_script'
		         are required, 'python_bin' defaults to python3)
		config - configure LNet and/or Lustre from the 'lnet' and
		         'lustre' sections ('config_name' is required)
		all    - build, then config

	All required keys are checked before anything is executed.

	Exits with status 2 when the file is missing, cannot be parsed or
	has no 'lustre' mapping.

	Raises:
		ValueError: unknown or missing operation, or a required key is
		absent for the requested operation.
	"""
	if not os.path.isfile(config):
		print("%s not found" % config)
		sys.exit(2)

	with open(config, 'r') as f:
		try:
			# NOTE(review): FullLoader accepts more than plain data;
			# yaml.safe_load would do if the file only ever holds
			# mappings, lists and scalars - confirm before switching.
			cfg_y = yaml.load(f, Loader=yaml.FullLoader)
		except (yaml.YAMLError, UnicodeDecodeError):
			print("couldn't parse %s" % config)
			sys.exit(2)

	# An empty file loads as None and a scalar document is not a mapping
	if not isinstance(cfg_y, dict) or \
	   not isinstance(cfg_y.get('lustre'), dict):
		print('malformed configuration %s' % config)
		sys.exit(2)

	lustre_cfg = cfg_y['lustre']
	operation = lustre_cfg.get('operation')
	if operation not in ('all', 'build', 'config'):
		_config_error("operation '%s' not recognized" % operation)

	do_build = operation in ('all', 'build')
	do_config = operation in ('all', 'config')

	# Validate everything up front so a bad config does not fail half-way
	if do_build and ('build_src' not in lustre_cfg or
			 'build_script' not in lustre_cfg):
		_config_error("build source or script are not defined")
	if do_config and 'config_name' not in lustre_cfg:
		_config_error("No configuration file name specified. Please specify one.")

	if do_build:
		pybin = lustre_cfg.get('python_bin', 'python3')
		build_lustre(lustre_cfg['build_script'],
			     lustre_cfg['build_src'], pybin)
	if do_config:
		if 'lnet' in lustre_cfg:
			config_lnet(lustre_cfg['lnet'], lustre_cfg['config_name'])
		if 'lustre' in lustre_cfg:
			config_lustre(lustre_cfg['lustre'])

def main(argv):
	"""Parse the command line and process the given configuration file.

	argv is the argument list without the program name. Options:
		-h, --help     print usage and exit
		-c, --cfg=     YAML configuration file (required)
		-d, --dry-run  print actions instead of performing them
		-v, --verbose  echo executed commands

	An invalid option prints the error and the usage and exits with
	status 2. When no configuration file is given the usage is printed
	and the program exits.
	"""
	global verbose
	global dry_run

	usage = "python %s [--dry-run] --cfg=<YAML configuration file>" % __file__

	try:
		opts, args = getopt.getopt(argv,"hdvc:",["help", "verbose", "dry-run", "cfg="])
	except getopt.GetoptError as e:
		# tell the user what was wrong instead of exiting silently
		print(e)
		print(usage)
		sys.exit(2)

	config = ''
	for opt, arg in opts:
		if opt in ("-h", "--help"):
			print(usage)
			sys.exit()
		elif opt in ("-c", "--cfg"):
			config = arg
		elif opt in ("-d", "--dry-run"):
			dry_run = True
		elif opt in ("-v", "--verbose"):
			verbose = True

	# also covers an empty argument list
	if not config:
		print(usage)
		sys.exit()

	process_config(config)

# Script entry point: pass the command-line arguments (without the
# program name) to main().
if __name__ == "__main__":
	main(sys.argv[1:])
