#!/usr/bin/python3

import os
import sys
import re
import subprocess
import threading
import argparse
import textwrap

class PreserveWhiteSpaceWrapRawTextHelpFormatter(argparse.RawTextHelpFormatter):
	"""Raw-text help formatter that additionally wraps long lines.

	The line breaks and blank lines of the help text are kept as written.
	Each non-blank line is passed through ``textwrap.wrap``; continuation
	lines are indented so they start below the text that follows the line's
	leading whitespace / list marker (digits or dashes, optional dot).
	"""

	def __add_whitespace(self, idx, iWSpace, text):
		"""Return ``text`` prefixed with ``iWSpace`` spaces, except for piece 0."""
		return text if idx == 0 else (' ' * iWSpace) + text

	def _split_lines(self, text, width):
		"""Split ``text`` into the list of lines to print.

		Blank (whitespace-only) lines become a single space. Every other
		line is wrapped to ``width`` and its continuation lines receive a
		hanging indent.
		"""
		wrapped = []
		for row in text.splitlines():
			if row.strip() == '':
				wrapped.append(' ')
				continue

			# Every part of the pattern is optional, so it always matches at
			# position 0; the match end is the width of the hanging indent.
			hang = re.search(r'\s*[\d\-]*\.?\s*', row).end()
			wrapped.extend(
				self.__add_whitespace(pos, hang, piece)
				for pos, piece in enumerate(textwrap.wrap(row, width))
			)
		return wrapped


# Command-line interface. The description is printed verbatim by --help;
# option help texts are formatted by the custom formatter class.
ap = argparse.ArgumentParser(
	description=textwrap.dedent('''\
		Wireguard Info
		==============

		This tool enhances the output of 'wg show' to include node names.
		Also it can ping the nodes (using the first ip in AllowedIPs)
		and indicate the online status via red/green color coding.

		It expects you to use wg-quick and reads the wg-quick config at
		/etc/wireguard/INTERFACE.conf

		The human readable peer names are expected in the wg-quick config 
		within a comment like this:
		[Peer]
		# Name = Very secret node in antarctica

	'''),
	formatter_class=PreserveWhiteSpaceWrapRawTextHelpFormatter,
)

# Output format switches.
ap.add_argument('--html', help="Format output as HTML", action='store_true')
ap.add_argument('--tty', help="Force terminal colors even when writing to pipe", action='store_true')

# What to inspect.
ap.add_argument('-p', '--ping', help="Ping all nodes (in parallel) and show online status", action='store_true')
ap.add_argument('-i', '--interface', help="Only show status for this interface")

args = ap.parse_args()

# Choose the output flavour and the markup used for colours / bold text:
#   html - inline-styled <span> elements (--html takes precedence)
#   tty  - ANSI escape sequences (stdout is a terminal, or --tty was given)
#   pipe - no markup at all
if args.html:
	output = 'html'
	redfmt, redbldfmt = '<span style="color: red;">', '<span style="color: red; font-weight: bold;">'
	greenfmt, greenbldfmt = '<span style="color: green;">', '<span style="color: green; font-weight: bold;">'
	yellowfmt, yellowbldfmt = '<span style="color: orange;">', '<span style="color: orange; font-weight: bold;">'
	bldfmt, endfmt = '<span style="font-weight: bold;">', '</span>'
elif sys.stdout.isatty() or args.tty:
	output = 'tty'
	redfmt, redbldfmt = '\033[0;31m', '\033[1;31m'
	greenfmt, greenbldfmt = '\033[0;32m', '\033[1;32m'
	yellowfmt, yellowbldfmt = '\033[0;33m', '\033[1;33m'
	bldfmt, endfmt = '\033[1m', '\033[0m'
else:
	output = 'pipe'
	redfmt = redbldfmt = greenfmt = greenbldfmt = ''
	yellowfmt = yellowbldfmt = bldfmt = endfmt = ''

# Peers collected from the wg-quick configs, keyed by public key.
# read_config() stores {'name': ..., 'ip': ...} for each peer; ping() later
# adds an 'online' boolean to the same entry.
peers = {}


def read_config(interface, config_dir='/etc/wireguard'):
	"""Load peer names and addresses from a wg-quick config into ``peers``.

	Reads ``<config_dir>/<interface>.conf``. For every ``[Peer]`` section
	that contains a ``PublicKey`` line, an entry is stored in the
	module-level ``peers`` dict under that public key:

	* ``'name'``: value of a ``Name = ...`` line, normally written as the
	  comment ``# Name = ...``; ``"*nameless*"`` when the section has none.
	* ``'ip'``: first address of the first ``AllowedIPs`` line, without the
	  prefix length; ``""`` when the section has none.

	Each ``[Peer]`` section starts from fresh defaults, so a section without
	a ``PublicKey`` is dropped and cannot pass its name or address on to the
	following peer. Any other section header ends the current peer section.

	Args:
		interface: Interface name; selects the ``.conf`` file to read.
		config_dir: Directory that holds the wg-quick config files.

	Raises:
		OSError: If the config file cannot be opened.
	"""

	def commit(entry):
		# A section without PublicKey cannot be keyed, so it is skipped.
		if entry is not None and entry['pubkey']:
			peers[entry['pubkey']] = {
				'name': entry['name'],
				'ip': entry['ip'],
			}

	# Fields of the [Peer] section being read; None while outside of one.
	current = None

	with open(os.path.join(config_dir, '%s.conf' % interface)) as cfg:
		for line in cfg:
			line = line.strip()

			# Any section header closes the section that was open before it.
			if line.startswith('[') and line.endswith(']'):
				commit(current)
				if line == "[Peer]":
					current = {'pubkey': "", 'name': "*nameless*", 'ip': ""}
				else:
					current = None
				continue

			if current is None:
				continue

			if line.startswith("PublicKey"):
				current['pubkey'] = line.split('=', 1)[-1].strip()
			elif re.match(r"#?\s*Name", line):
				current['name'] = line.split('=', 1)[-1].strip()
			elif line.startswith("AllowedIPs") and not current['ip']:
				# Only the first address is kept; it is the one used for pinging.
				current['ip'] = line.split('=', 1)[-1].strip().split(',')[0].split('/')[0]

	# The last section has no following header to close it.
	commit(current)


def show_info(interface):
	"""Print the ``wg show`` status of one interface, with peer names added.

	Runs ``wg show <interface>`` and re-prints its output line by line:

	* ``interface:`` lines are highlighted in yellow.
	* ``peer:`` lines show the peer's name from ``peers`` next to its public
	  key, in green, or in red when the peer's ``'online'`` flag is False.
	  A peer that has no entry in ``peers`` is shown as ``*nameless*``
	  instead of raising ``KeyError``.
	* ``private key:`` and ``preshared key:`` lines are left out.
	* every other ``key: value`` line is printed with the key in bold.

	In HTML mode all text taken from the config or from ``wg`` is escaped so
	it cannot be interpreted as markup.

	Args:
		interface: Name of the WireGuard interface to show.

	Raises:
		subprocess.CalledProcessError: If ``wg show`` exits with an error.
	"""
	import html  # only needed here, for escaping in HTML mode

	def esc(text):
		return html.escape(text) if output == 'html' else text

	peer_section = False
	listing = subprocess.check_output(['wg', 'show', interface]).decode('utf-8')
	for line in listing.splitlines():
		line = line.strip()

		if line.startswith('peer:'):
			peer_section = True
			peer_pubkey = line.split(':', 1)[1].strip()
			# `wg show` can list a peer that has no entry in `peers`.
			peer = peers.get(peer_pubkey, {})

			# Without a ping result the peer is shown as online (green).
			if peer.get('online', True):
				colorfmt = greenfmt
				colorbldfmt = greenbldfmt
			else:
				colorfmt = redfmt
				colorbldfmt = redbldfmt
			label = peer.get('name', '*nameless*') + ' (' + peer_pubkey + ')'
			print('  ' + colorbldfmt + 'peer' + endfmt + ': ' + colorfmt +
				  esc(label) + endfmt)
		elif line.startswith('interface:'):
			peer_section = False
			shown_name = line.split(':', 1)[1].strip()
			print(yellowbldfmt + 'interface' + endfmt + ': ' + yellowfmt + esc(shown_name) + endfmt)
		elif line.startswith(('preshared key:', 'private key:')):
			# Key material is not echoed.
			continue
		elif line:
			key, colon, value = line.partition(':')
			indent = '	' if peer_section else '  '
			if colon:
				print(indent + bldfmt + esc(key.strip()) + endfmt + ': ' + esc(value.strip()))
			else:
				# Not a "key: value" line; pass it through unchanged.
				print(indent + esc(line))
		else:
			print(line)


def ping(pubkey):
	"""Ping one peer once and record the result in ``peers``.

	Sends a single echo request (``ping -c1 -W1``) to the peer's stored
	``'ip'`` and sets ``peers[pubkey]['online']`` to True when ``ping`` exits
	with status 0, otherwise to False. All output of ``ping`` is discarded.

	Args:
		pubkey: Public key identifying the peer entry in ``peers``.
	"""
	# subprocess.DEVNULL avoids opening (and leaking) a file handle per call.
	retcode = subprocess.call(
		['ping', '-c1', '-W1', peers[pubkey]['ip']],
		stdout=subprocess.DEVNULL,
		stderr=subprocess.DEVNULL,
	)
	peers[pubkey]['online'] = (retcode == 0)


def lookahead(iterable):
	"""Pass through all values from the given iterable, augmented by the
	information if there are more values to come after the current one
	(True), or if it is the last value (False).

	Yields ``(value, has_more)`` tuples. An empty iterable yields nothing.
	"""

	# Get an iterator and pull the first value.
	it = iter(iterable)
	try:
		last = next(it)
	except StopIteration:
		# A StopIteration escaping a generator body is turned into a
		# RuntimeError (PEP 479), so finish explicitly on empty input.
		return

	# Run the iterator to exhaustion (starting from the second value).
	for val in it:
		# Report the *previous* value (more to come).
		yield last, True
		last = val

	# Report the last value.
	yield last, False


def get_up_interfaces(interfaces):
	"""Return the names from ``interfaces`` that are listed by ``ip link``.

	The result follows the order of the ``ip link`` output. Only the
	presence of a link is checked; its state flags are not inspected.
	"""
	link_listing = subprocess.check_output(["ip", "link"])

	# Header rows look like "3: wg0: <FLAGS> ..."; the name is the second
	# colon-separated field.
	listed_names = (
		row.split(b':')[1].decode('utf-8').strip()
		for row in re.findall(rb"\s?\d+: .*:[^\n]+", link_listing)
	)
	return [name for name in listed_names if name in interfaces]


# Everything below needs root privileges; stop early with a readable message.
if os.getuid() != 0:
	print(redbldfmt + "\nERROR: " + yellowfmt + "The script must run as root\n" + endfmt)
	# sys.exit() instead of the site-provided exit() helper, which is meant
	# for interactive use and is missing when the site module is not loaded.
	sys.exit(-1)

# Candidate interfaces: the one named on the command line, otherwise every
# wg-quick config found in /etc/wireguard/.
if args.interface:
	interfaces = [args.interface]
else:
	# Cut off only the trailing ".conf"; str.replace() would also remove the
	# same text from the middle of a file name.
	interfaces = [
		name[:-len('.conf')]
		for name in os.listdir('/etc/wireguard/')
		if name.endswith('.conf')
	]

# Keep only interfaces that currently exist, then load their peer names.
interfaces = get_up_interfaces(interfaces)
for interface in interfaces:
	read_config(interface)


# Ping every peer from its own thread so the whole check takes about as long
# as a single ping (roughly one second) instead of one second per peer.
if args.ping:
	threads = [
		threading.Thread(target=ping, args=(pubkey,), daemon=True)
		for pubkey in peers
	]
	for worker in threads:
		worker.start()

	# Wait until every ping has recorded its result.
	for worker in threads:
		worker.join()

# In HTML mode the report is wrapped in <pre> to keep its line layout.
if output == 'html':
	print('<pre>')

# One status section per interface, separated by blank lines.
if interfaces:
	for iface, more_to_come in lookahead(interfaces):
		show_info(iface)
		if more_to_come:
			print("\n")

if output == 'html':
	print('</pre>')
