#!/usr/bin/python3
#
# Trax (network activity tracker) Version 1.0
#
# Copyright 2023 Brandyn Webb
# 
# This file is part of Trax (network activity tracker).
# 
# Trax is free software: you can redistribute it and/or modify it under the terms 
# of the GNU General Public License as published by the Free Software Foundation, 
# either version 3 of the License, or (at your option) any later version.
# 
# Trax is distributed in the hope that it will be useful, but WITHOUT ANY 
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along with 
# Trax. If not, see <https://www.gnu.org/licenses/>. 
# 


from Procs     import ProcessManager, Process
from TextWin2  import Window
from time      import time
from datetime  import datetime
from trio      import socket, sleep
from TrioUtils import CloseableQueue, Event, Queue
from Logger    import Logger
from struct    import unpack
from pickle    import loads
from pprint    import pformat
from os        import getenv

# IP protocol number -> short protocol name (e.g. 6 -> 'TCP'), built from the
# IPPROTO_* constants exposed by the socket module.  When several constants
# share a number, the one enumerated last wins.
prefix = "IPPROTO_"
proto_names = dict(
    (number, attr[len(prefix):])
    for attr, number in vars(socket).items()
    if attr.startswith(prefix)
)

# EtherType field value -> protocol label.
ethertype_names = dict([
    (0x0800, 'IPV4'),
    (0x0806, 'ARP'),
    (0x86DD, 'IPV6'),
])

# Navigation key codes.  The numbers are the curses KEY_DOWN, KEY_UP, KEY_LEFT,
# KEY_RIGHT, KEY_BACKSPACE and KEY_DC codes; the letters are vi-style aliases.
down_keys  = (258, ord('j'))
up_keys    = (259, ord('k'))
left_keys  = (260, ord('h'))
right_keys = (261, ord('l'))
del_keys   = (263, 330)

class TraxPrefs(object):
    """Settings for one Trax client session.

    host    -- traxd server.  The default is $TRAXD, falling back to 'localhost'.
               Because it is a default argument, it is read once, when this
               class is defined.
    port    -- traxd TCP port.
    local   -- True when running on the same host as traxd (reverses in/out).
    units   -- display units; a key of Trax.unitmap.
    verbose -- logging verbosity level.
    indent  -- string repeated once per hierarchy level.
    color   -- use a different color for each indentation level.
    percent -- show percentages of total bandwidth instead of absolute quantities.
    graph   -- draw the up/down bar graph.
    """
    def __init__(self, host=getenv("TRAXD") or 'localhost', port=3344, local=False, units='MB', verbose=0, indent='  ', color=True, percent=True, graph=True):
        # Where to connect:
        self.host = host
        self.port = port
        self.local = local
        # How to present the data:
        self.verbose = verbose
        self.units = units
        self.indent = indent
        self.color = color
        self.percent = percent
        self.graph = graph

class Trax(Process):
    """Interactive traxd client.

    Connects to a traxd server and accumulates the per-connection packet and
    byte tallies that it streams.  Renders them in a terminal Window as a
    collapsible hierarchy, sorted by traffic, above a scrolling log pane.
    """

    # Factor converting a byte count into each display unit
    # ('B' = bytes, 'b' = bits; decimal K/M/G/T prefixes).
    unitmap = {
        'B':1/1, 'KB':1/1000, 'MB':1/1000000, 'GB':1/1000000000, 'TB':1/1000000000000,
        'b':8/1, 'Kb':8/1000, 'Mb':8/1000000, 'Gb':8/1000000000, 'Tb':8/1000000000000,
        }

    # Key codes of the unit letters ('B', 'K', 'M', 'G', 'T', 'b') that select units.
    unitkeys = {ord(u[0]) for u in unitmap}

    def __init__(self, pm, log_name=None, prefs=None):
        """pm       -- ProcessManager that will run this process.
        log_name -- logger name; defaults to "trax/<current time>".
        prefs    -- TraxPrefs.  When omitted, each Trax gets its own default
                    instance.  A default-argument instance would be shared, and
                    it is mutated by set_units() and the UI toggles.
        """
        if log_name is None:
            log_name = f"trax/{time()}"
        if prefs is None:
            prefs = TraxPrefs()

        Process.__init__(self, pm, log_name)

        self.prefs = prefs
        self.host  = prefs.host
        self.port  = prefs.port

        self.set_units(prefs.units)

        #
        # Tallies:
        #
        # MAC and IP addys here are 6 and 4 byte bytes objects respectively (not human-friendly strings).
        #
        # Maps (mac, etype, protocol, lan_ip, lan_port, wan_ip, wan_port) to (inpackets, inbytes, outpackets, outbytes)
        #
        self.totals     = {}
        self.start_time = None  # Will be the 'when' of the first (discarded) update.
        self.end_time   = None  # Will be the 'when' of the most recently received update.
        self.elapsed    = None  # Will be the difference between the two.

        #
        # Name mappings:
        #
        self.rootmac  = None    # Will be the mac addy of the server (presumably router) running traxd.
        self.mac2ip   = {}      # Maps text MAC addy to text IP addy.
        self.ip2host  = {}      # Maps text IP addy to text hostname.
        self.port2service = {}  # Maps (integer_port_number, is_udp) tuple to name of service at that port.
        self.resolveQ = Queue() # queue of IP addys or ports to reverse map into ip2host and port2service

        self.port2service[prefs.port, False] = 'trax'
        self.port2service[       443,  True] = 'http(s)'    # getnameinfo() doesn't recognize UDP http...

        #
        # UI:
        #
        self.win        = None  # The Window; created in run_().  Redraw methods are no-ops until then.
        self.text_attrs = {}    # Indent level -> Window attrs; filled in run_() via text_config().
        self.width      = None
        self.height     = None
        self.log_lines  = 20    # Number of log lines at the bottom.  On initialization, this is a percentage.  After that, it's the count.
        self.log_hist   = []    # Last log_lines of log history

        self.items   = {}            # self.totals reformatted into a hierarchy of dicts (see reformat_key).  At each level, a special '#' entry holds totals.
        self.cursor  = ('/',)        # The prefix describing the currently selected/highlighted item.  (See render_items_)
        self.is_open = set([('/',)]) # Set of item key prefixes that are "open" (as in expanded) in the UI.  '/' is special token for the root
        self.lines   = []            # self.items rendered as (text, attr, prefix) lines suitable to the current window size.

        self.is_dirty = Event()      # Needs re-render.

    def set_units(self, units):
        """Select display units (a key of unitmap).  Invalid units are logged and ignored."""
        if units not in self.unitmap:
            self.log(f"Invalid units {units}.  Must be one of {', '.join(self.unitmap)}")
            return

        self.prefs.units = units
        self.scale       = self.unitmap[units]

    def reset(self):
        """Discard all accumulated tallies.  The next traxd update starts a new baseline.

        cursor and is_open are deliberately kept.  The cursor reappears once an
        item with the same prefix shows up again, and expanded branches stay
        expanded.
        """
        self.totals     = {}
        self.start_time = None
        self.end_time   = None
        self.elapsed    = None
        self.items      = {}
        self.lines      = []
        self.dirty()

    def reformat_key(self, key):
        """Here key is a connection specification tuple as sent in a traxd sums message.
        This reformats it into a hierarchy-ordered tuple of human-friendly names:
        (mac_label, etype, protocol, wan_service, domain, wan_host, lan_service).

        NOTE that this is NOT deterministic (so beware caching) since it
            changes as the various name mappings update (ip2host and such).
        Unresolved IPs and ports are queued on resolveQ as a side effect.
        """
        if self.prefs.local:
            mac, etype, protocol, wan_ip, wan_port, lan_ip, lan_port = key
        else:
            mac, etype, protocol, lan_ip, lan_port, wan_ip, wan_port = key

        mac      = mac.hex(':')
        mac      = self.ip2host.get(self.mac2ip.get(mac), mac)    # Prefer the host name behind the MAC.

        lan_ip   = self.ip2host_('.'.join(str(i) for i in lan_ip)) if lan_ip else None
        wan_ip   = self.ip2host_('.'.join(str(i) for i in wan_ip)) if wan_ip else None

        protocol = (proto_names.get(protocol) or str(protocol)) if protocol is not None else None
        etype    = ethertype_names.get(etype) or str(etype)

        lan_port = self.port2service_(lan_port, protocol) if lan_port is not None else None
        wan_port = self.port2service_(wan_port, protocol) if wan_port is not None else None

        #
        # At this point wan_ip is either None, an ip addy, or ideally a full host name.
        #
        # We'll break it into two parts to compact the hierarchy more usefully -- the
        #  main domain name (e.g., google.com) and then the actual exact host (e.g.,
        #   foo123.x95.blarp.google.com).  The domain is simply the last two labels,
        #   so multi-label suffixes such as co.uk are not special-cased.
        #
        if wan_ip == '*':
            domain = 'Multicast'
        elif wan_ip:
            labels = wan_ip.split('.')
            if labels[-1].isalpha():
                domain = '.'.join(labels[-2:])
            else:
                domain = "Unnamed"      # Still a numeric address (not resolved).
        else:
            domain = None
        host = wan_ip

        if lan_ip and lan_ip != mac:
            mac = f"{mac}/{lan_ip}"
        return (mac, etype, protocol, wan_port, domain, host, lan_port)

    def render_items(self, num, width):
        """This formats the first num lines of a human-friendly display
            of self.items.  Returns (text, attrs, prefix) tuples where
            text and attrs are suitable for rendering with Window.show()
            and prefix indicates which item is visible on each line.

        Width is the visible number of characters per line.
        """
        if self.prefs.percent or self.prefs.graph:
            # Caching scales based on top level totals, used by format_line in percent mode:
            _, rinb, _, routb = self.items.get('/',{}).get('#', (0, 0, 0, 0))
            self.ri_scale = (100/(rinb+routb+1), 100/(rinb+1), 100/(routb+1))   # +1 avoids both divide by zero and 100% (the latter blows the formatting budget)

        if self.prefs.graph:
            # Pre-pass: the graph width is the smallest free space over all visible lines.
            # It is split between up (left) and down (right) in proportion to the root totals.
            lines = []
            self.render_items_(lines, num, width, self.items, 0, (), True)
            minspace = min((space for space, _attrs, _prefix in lines), default=None)
            minspace   = max(1, minspace or 1)
            leftspace  = max(0, int(minspace*rinb/(rinb+routb+1)))
            rightspace = max(0, minspace - leftspace - 1)    # -1 for the '|' separator
            self.ri_gspace = (leftspace, rightspace)

        lines = []
        self.render_items_(lines, num, width, self.items, 0, ())
        return lines

    def render_items_(self, lines, num, width, items, indent, prefix, calc_space=False):
        """Recursive helper for render_items: appends lines for the children of
        `items`, largest total bytes first, descending into expanded (is_open)
        children.  Stops once `lines` holds num entries.
        """
        order = []
        for name, d in items.items():
            if name == '#':
                continue
            vals = d.get('#', (0, 0, 0, 0))
            sumb = vals[1] + vals[3]
            order.append((sumb, vals, name, d))
        # Explicit key: on ties, plain tuple comparison would compare names, which may
        # mix None and str (e.g. a missing domain), raising TypeError.
        order.sort(key=lambda o: (o[0], o[1], str(o[2])), reverse=True)
        for sumb, vals, name, d in order:

            p2    = prefix+(name,)
            attrs = self.text_attrs.get(indent, 0)

            if p2 == self.cursor:
                attrs |= Window.Reverse            # Highlight the current cursor position

            lines.append((self.format_line(width, indent, name, sumb, vals, calc_space), attrs, p2))

            if len(lines) >= num:
                return

            if p2 in self.is_open:
                self.render_items_(lines, num, width, d, indent+1, p2, calc_space)

    def format_line(self, width, indent, name, sumb, vals, calc_space=False):
        """Returns the formatted line, unless calc_space is True in which
            case it returns how much space would be left for a bargraph.

        width  -- characters available for the whole line.
        indent -- hierarchy depth (0 for the root).
        name   -- item label; the root '/' is shown as the elapsed time.
        sumb   -- inbytes + outbytes for the item.
        vals   -- (inpackets, inbytes, outpackets, outbytes).
        """
        if name == '/':
            name = f"Elapsed Time {int(self.elapsed or 0):,d} seconds"

        elapsed = self.elapsed or 1
        inp, inb, outp, outb = vals

        units = self.prefs.units
        scale = self.scale

        if self.prefs.percent and indent:
            ssum, sin, sout = self.ri_scale
            rhs = f" {sumb*scale:11,.3f}{units}, {sumb*ssum:10.6f}% ({inb*sin:10.6f}% up + {outb*sout:10.6f}% dn){self.prefs.indent*(7-indent)}"
        else:
            rhs = f" {sumb*scale:11,.3f}{units}, {sumb*scale/elapsed:10.6f}{units}/s ({inb*scale/elapsed:10.6f}{units}/s up + {outb*scale/elapsed:10.6f}{units}/s dn){self.prefs.indent*(7-indent)}"

        width -= len(rhs)
        # Clamp: a negative slice bound would cut from the end instead of giving ''.
        lhs = f"{self.prefs.indent*indent}{name}"[:max(0, width)]

        if calc_space:
            return width-len(lhs)

        if self.prefs.graph:
            ssum, sin, sout       = self.ri_scale
            leftspace, rightspace = self.ri_gspace
            lbar  = min( leftspace, int( leftspace* inb* sin/100))
            rbar  = min(rightspace, int(rightspace*outb*sout/100))
            graph = ' '*(leftspace-lbar) + '^'*lbar + '|' + 'v'*rbar + ' '*(rightspace-rbar)
            rhs   = graph + rhs
            width -= len(graph)

        if len(lhs) < width:
            return lhs + (' '*(width-len(lhs))) + rhs

        return lhs + rhs

    def text_config(self, win):
        """Returns a dict mapping integer indent level to the Window
            attributes for embellishing a line at that level.  Level 1 has no
            entry, so it uses the default attributes.  Empty when color is off.
        """
        if self.prefs.color:
            return {
                0: win.colorPairAttribute(Window.Yellow),
                2: win.colorPairAttribute(Window.Yellow),
                3: win.colorPairAttribute(Window.Magenta),
                4: win.colorPairAttribute(Window.Cyan),
                5: win.colorPairAttribute(Window.Green),
                6: win.colorPairAttribute(Window.Red),
                7: win.colorPairAttribute(Window.Blue),
            }
        else:
            return {}

    #
    # TODO: We could make this more efficient (for long runs with large totals dict)
    #       by considering the dirty chain from traxd updates, and updates to the
    #       port and service mappings, and then editing a cached version of the
    #       key-reformatted list.
    #
    def update_items(self):
        """This (currently from scratch) generates self.items from self.totals.

        Each raw key becomes the path ('/',) + reformat_key(key) through nested
        dicts.  Every node's '#' entry holds the summed
        (inpackets, inbytes, outpackets, outbytes) of all keys below it.
        """
        self.items = {}
        for key, vals in self.totals.items():
            inp1, inb1, outp1, outb1 = vals
            key2 = self.reformat_key(key)               # Translate from raw traxd to our human-friendly names and order
            d    = self.items
            for k in ('/',) + key2:                     # Traverse the hierarchy...
                if k in d:
                    d = d[k]
                    inp2, inb2, outp2, outb2 = d.get('#', (0, 0, 0, 0))
                    d['#'] = (inp1+inp2, inb1+inb2, outp1+outp2, outb1+outb2)
                else:
                    child  = {'#': vals}
                    d[k]   = child
                    d      = child

    def handle_update(self, u):
        """This processes a single update (a dict with a 'kind' key) received from traxd.

        The first 'sums' update only sets the time baseline; later ones are
        added into self.totals.  'arp' updates refresh the name maps.  Any
        other kind is logged.
        """
        kind = u.get('kind')

        if kind == 'sums':

            if self.start_time is None:
                self.start_time = self.end_time = u['when'] # Let this throw an error if it's missing...
                self.elapsed = 0
                self.log(f"Start time: {datetime.fromtimestamp(self.start_time)}")
                return

            self.end_time = u['when']
            self.elapsed = self.end_time - self.start_time

            d = u.get('vals')
            if d:
                for key, (inp, inb, outp, outb) in d.items():
                    inp2, inb2, outp2, outb2 = self.totals.get(key, (0, 0, 0, 0))
                    self.totals[key] = (inp+inp2, inb+inb2, outp+outp2, outb+outb2)
                    if self.prefs.verbose > 1:
                        self.log(f"{self.reformat_key(key)}: {inb}+{outb}")
                if self.prefs.verbose:
                    self.log(f"------------ elapsed {int(self.elapsed)} seconds; {len(d)} items updated ----------")
            else:
                if self.prefs.verbose:
                    self.log(f"------------ elapsed {int(self.elapsed)} seconds; [Empty update] ----------")

            self.update_items() # TODO: Would be better if this was done on a dirty/clean sweep; name mapping updates should also mark dirty!
            self.redraw_items() # TODO: Likewise..

        elif kind == 'arp':
            self.handle_arp(u)
        else:
            self.log(pformat(u))

    def handle_arp(self, u):
        """This parses raw 'arp -a' output (as run on the traxd server) and
            establishes a mac->IP mapping as well as priming the IP->hostname
            mapping.

        Expected line shape: "<host> (<ip>) at <mac> <rest...>".  Lines that do
        not match are logged and skipped.  Also, the server informs us of its
        own MAC addy here.
        """
        self.rootmac = u.get('my_mac')
        arp          = u.get('raw') or ''   # Tolerate a missing/empty payload.

        for line in arp.split('\n'):
            if not line:
                continue
            try:
                host, ip, at, mac, rest = line.split(maxsplit=4)

                if at != 'at':
                    raise Exception(f"Got {at!r} when expecting 'at'")
                if len(ip) < 2 or ip[0] != '(' or ip[-1] != ')':
                    raise Exception("Missing parens around IP addy")
                if mac == '<incomplete>':
                    mac = None
                elif ':' not in mac:
                    raise Exception("Invalid MAC addy")

            except Exception as e:
                self.log(f"ARP: {e} in {line!r}")
                continue

            ip = ip[1:-1]

            if mac:
                self.mac2ip[mac.lower()] = ip

            if len(host) > 1:       # Single-char names (presumably arp's '?' placeholder) are skipped.
                self.ip2host[ip] = host

        if self.prefs.verbose > 2:
            self.log(pformat(self.mac2ip))
            self.log(pformat(self.ip2host))

    async def run_(self):
        """Main task: set up the Window and the background monitors, then
            handle keystrokes until quit.  Always restores stdout logging.
        """
        self.log_queue = CloseableQueue()
        self.log("Redirecting log to Window....")
        Logger.log_channel_handler('raw-log', self.log_queue.push)         # If we want to redirect log output...

        self.win = Window(color=self.prefs.color)

        self.text_attrs = self.text_config(self.win)

        self.spawn(self.win.run())
        self.spawn(self.monitor_resizes())
        self.spawn(self.monitor_log())
        self.spawn(self.monitor_refresh())
        self.spawn(self.monitor_network())
        self.spawn(self.monitor_resolveQ())

        #
        # Really any thread could go here...
        #
        try:
            await self.monitor_keystrokes()
        finally:
            self.win.close()
            self.win = None
            Logger.log_channel_handler('raw-log', Logger.log_channel_rawlog)
            self.log("Redirecting log to stdout....")

    def ip2host_(self, ip):
        """Return the known host name for ip.  If none is known, queue a
            reverse lookup and return ip unchanged.
        """
        if ip is None:
            return None
        if ip in self.ip2host:
            return self.ip2host[ip]
        self.resolveQ.push(ip)
        return ip

    def port2service_(self, port, protocol):
        """Return the service name for port/protocol.  If none is known, queue a
            lookup and return "<port>/<protocol>".  Port 0 is '<userport>'.
        """
        if port == 0:
            return '<userport>'
        udp = protocol == 'UDP'
        key = (port, udp)
        if key in self.port2service:
            return self.port2service[key]
        self.resolveQ.push(key)
        return f"{port}/{protocol}"

    async def monitor_resolveQ(self):
        """Drain resolveQ.  IP strings are reverse-resolved into ip2host, and
            (port, is_udp) tuples into port2service.  A failed lookup stores the
            numeric form, so it is not retried.
        """
        async for item in self.resolveQ:
            if isinstance(item, str):
                ip = item
                if ip in self.ip2host:
                    continue
                if self.prefs.verbose:
                    self.log(f"Hostname lookup: {ip}")
                try:
                    host, _ = await socket.getnameinfo((ip, 0), 0)
                except Exception as e:
                    host = ip
                    self.log(f"{ip}: {e}")
                self.ip2host[ip] = host

            else:   # actually a port number and udp flag
                port, udp = key = item
                if key in self.port2service:
                    continue
                if self.prefs.verbose:
                    self.log(f"Portname lookup: {port}{' (UDP)' if udp else ''}")
                try:
                    _, name = await socket.getnameinfo(('127.0.0.1', port), socket.NI_NUMERICHOST | (socket.NI_DGRAM if udp else 0))
                except Exception as e:
                    self.log(str(e))
                    name = str(port)
                else:
                    if name in ('http', 'https'):
                        name = 'http(s)'
                self.port2service[key] = name

    async def monitor_network(self):
        """Keep a connection to traxd and feed its messages to handle_update().

        Connect attempts are retried every 10 seconds.  Each message is a 2-byte
        big-endian length followed by that many bytes of pickled data.  Any read
        or decode error drops the connection, which is re-established after 1s.
        """

        async def read_exactly(num):
            """Reads exactly num bytes from sock (raises EOFError if the peer closes first).
            """
            nonlocal inq
            while len(inq) < num:
                data = await sock.recv(4096)
                if not data:
                    raise EOFError()
                inq.extend(data)
            if len(inq) == num:
                data = inq
                inq  = bytearray()
            else:
                data = inq[:num]
                del    inq[:num]
            return data

        while True:

            inq = bytearray()

            while True:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    await sock.connect((self.host, self.port))
                except Exception as e:
                    sock.close()
                    self.log(f"Socket connect failed.  Retrying in 10 seconds.  ({e})")
                    await sleep(10)
                    continue
                break

            self.log("Connection established.")

            try:
                while True:
                    try:
                        size, = unpack("!H", await read_exactly(2))
                        blob  = await read_exactly(size)
                        # SECURITY: pickle.loads can execute arbitrary code embedded in
                        # the payload -- only connect to a traxd server you trust.
                        msg   = loads(blob)
                    except Exception as e:
                        self.log(f"Lost connection.  ({e})")
                        break
                    self.handle_update(msg)
            finally:
                sock.close()    # Also on cancellation, or if handle_update raises.
            await sleep(1)

    def current_line(self):
        """Index in self.lines of the cursor's item, or 0 if it is not visible."""
        for i, (_text, _attrs, prefix) in enumerate(self.lines):
            if prefix == self.cursor:
                return i
        return 0

    def move_cursor(self, n):
        """Move the cursor down n lines (negative for up), clamped to the visible lines.
        """
        i = min(len(self.lines)-1, max(0, self.current_line() + n))
        if i >= 0:
            self.cursor = self.lines[i][2]  # [2] = prefix
        self.redraw_items() # TODO: could just redraw the two lines?

    async def monitor_keystrokes(self):
        """Dispatch key presses (see the '?' help text) until the Window goes away."""

        async for key in self.win.keystrokes:

            if self.win is None:
                return

            if key == ord('q'):
                self.quit()

            elif key in up_keys:
                self.move_cursor(-1)

            elif key in down_keys:
                self.move_cursor(1)

            elif key in right_keys:
                self.is_open.add(self.cursor)
                self.redraw_items()

            elif key in left_keys:
                if self.cursor in self.is_open:
                    # Close the item the cursor is over, if it's open:
                    self.is_open.discard(self.cursor)
                elif len(self.cursor) > 1:
                    # Close (and move to) the parent item if this item is already closed.
                    # Never go above the root, or the cursor would vanish.
                    self.cursor = self.cursor[:-1]
                    self.is_open.discard(self.cursor)
                self.redraw_items()

            elif key == ord('-'):
                self.log_lines = max(0, self.log_lines-1)
                self.redraw()

            elif key == ord('_'):
                # Leave room for at least the divider line.
                if self.height is None or self.log_lines < self.height - 1:
                    self.log_lines += 1
                self.redraw()

            elif key == ord('r'):
                self.log("Resetting totals...")
                self.reset()
                self.redraw_items()

            elif key in (ord('p'), ord('%')):
                self.prefs.percent ^= True
                self.redraw_items()

            elif key == ord('g'):
                self.prefs.graph ^= True
                self.redraw_items()

            elif key == ord('b'):
                u = list(self.prefs.units)
                u[-1] = 'b' if u[-1] == 'B' else 'B'
                self.set_units(''.join(u))
                self.redraw_items()

            elif key in self.unitkeys:
                if key == ord('B'):
                    self.set_units(self.prefs.units[-1])
                else:
                    self.set_units(chr(key) + self.prefs.units[-1])
                self.redraw_items()

            elif key == ord('?'):
                self.log("""q           - quit
hjkl/arrows - navigate up and down, open and close items
r           - reset the totals to 0
p/%         - toggle percentages
g           - toggle graph
b           - toggle bits/bytes
MTGBK       - set units
-/_         - move the log bar up and down
?           - show help""")

    async def monitor_resizes(self):
        """Track the window size and redraw everything on each resize.

        On the first resize, log_lines is converted from a percentage of the
        height into a line count.
        """
        async for size in self.win.resizes:
            if self.win is None:
                return

            if self.width is None:
                self.log_lines = size[1]*self.log_lines//100
            self.width, self.height = size
            self.redraw()

    async def monitor_log(self):
        """Append redirected log messages to log_hist and redraw the log pane."""
        async for msg in self.log_queue:
            self.log_hist.extend(msg.split('\n'))
            # Keep only what fits.  Note [-0:] would keep *everything*, so a
            # zero-height pane must be handled explicitly.
            self.log_hist = self.log_hist[-self.log_lines:] if self.log_lines > 0 else []
            self.redraw_log()

    async def monitor_refresh(self):
        """Flush the Window whenever something has been marked dirty."""
        while True:
            await self.is_dirty.wait()
            if self.win is None:
                return
            self.is_dirty = Event()
            self.win.flush()
            await sleep(0.01)   # Caps flushes at roughly one per 10ms; later redraws wait for the next pass.
            if self.win is None:
                return

    def dirty(self):
        """Request a Window flush from monitor_refresh."""
        self.is_dirty.set()

    def redraw(self):
        """Redraw entire screen.  Typically after a resize.
        """
        self.redraw_items()
        self.redraw_divider()
        self.redraw_log()

    def redraw_items(self):
        """Re-render the item hierarchy into the area above the divider."""
        if self.width is None or self.win is None:
            return      # Not sized yet; monitor_resizes() redraws once it is.
        height     = self.height - self.log_lines - 1
        self.lines = self.render_items(height, self.width)  # TODO: This can be skipped if prereqs haven't changed since last call...
        blank      = ' '*self.width
        for y in range(height):
            if y < len(self.lines):
                item, attrs, prefix = self.lines[y]
                item = item[:self.width].ljust(self.width)  # Truncate/pad the text to the window width.
                self.win.show(0, y, item, attrs)
            else:
                self.win.show(0, y, blank)
        self.dirty()

    def redraw_divider(self):
        """Draw the dashed line separating the items from the log pane."""
        if self.width is None or self.win is None:
            return
        self.win.show(0, self.height-self.log_lines-1, '-'*self.width)
        self.dirty()

    def redraw_log(self):
        """Draw the most recent log_lines lines of log history at the bottom."""
        if self.width is None or self.win is None:
            return
        # Show the newest lines; log_hist can be longer than the pane after it shrinks.
        visible = self.log_hist[-self.log_lines:] if self.log_lines > 0 else []
        top     = self.height - self.log_lines
        for i in range(self.log_lines):
            s = visible[i] if i < len(visible) else ''
            self.win.show(0, top+i, s[:self.width].ljust(self.width))
        self.dirty()


if __name__ == "__main__":

    def run():
        """Parse the command line into TraxPrefs and run the Trax UI.

        Exits with status 1 on bad arguments and 0 after -h.
        """
        #
        # Arg parsing:
        #
        from sys import argv, exit
        help_str = f"Use: {argv[0]} [traxd_host(localhost)] [traxd_port(3344)] [options..]\n" \
                    "  Options:\n" \
                    "   -u <units>  -- [%][TGMK][Bb] ; default %MB (total in megabytes, and percentages below)\n" \
                    "   -c          -- toggle color mode (default: On)\n" \
                    "   -g          -- toggle bar graph (default: On)\n" \
                    "   -i          -- turn off indentation\n" \
                    "   -I <str>    -- use <str> instead of '  ' to indent\n" \
                    "   -l          -- local mode (inverts the hierarchy)\n" \
                    "   -v          -- increase verbosity\n" \
                    "   -h          -- show this help\n" \
                    ""

        prefs = TraxPrefs() # Starts with reasonable defaults
        args  = argv[1:]

        hostset = portset = False

        def nextarg():
            """Pop the next argument, or print help and exit if there is none."""
            if args:
                return args.pop(0)
            print(help_str)
            exit(1)

        while args:
            arg = nextarg()

            if arg == '-h':
                print(help_str)
                exit(0)

            elif arg == '-u':
                u = nextarg()
                # Per the help text, a leading '%' selects percentages below the
                # top level; without it absolute rates are shown.
                prefs.percent = u.startswith("%")
                if prefs.percent:
                    u = u[1:]
                if len(u) > 1:
                    u = u[0].upper() + u[1]     # e.g. 'mb' -> 'Mb', 'kB' -> 'KB'
                prefs.units = u

            elif arg == '-v':
                prefs.verbose += 1

            elif arg == '-c':
                prefs.color ^= True

            elif arg == '-i':
                prefs.indent = ''

            elif arg == '-I':
                prefs.indent = nextarg()

            elif arg == '-l':
                prefs.local ^= True

            elif arg == '-g':
                prefs.graph ^= True

            elif not hostset:
                hostset    = True
                prefs.host = arg

            elif not portset:
                portset = True
                try:
                    prefs.port = int(arg)
                except ValueError:
                    print("Port must be an integer.")
                    exit(1)
            else:
                print(help_str)
                exit(1)

        if prefs.units not in Trax.unitmap:
            print(f"Bad units specification ({prefs.units!r}).")
            exit(1)

        pm = ProcessManager()
        Trax(pm, "trax", prefs=prefs)
        pm.run()

    run()

