#!/usr/bin/python3
#
# Trax (network activity tracker) Version 1.4
#
# Copyright 2023 Brandyn Webb
# 
# This file is part of Trax (network activity tracker).
# 
# Trax is free software: you can redistribute it and/or modify it under the terms 
# of the GNU General Public License as published by the Free Software Foundation, 
# either version 3 of the License, or (at your option) any later version.
# 
# Trax is distributed in the hope that it will be useful, but WITHOUT ANY 
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along with 
# Trax. If not, see <https://www.gnu.org/licenses/>. 
# 


# pcapy stopped working in Ubuntu '22 LTS (python3.10, etc), so by default we call
#  the libpcap C library directly via ctypes.  Need to install libpcap-dev
#  (e.g., apt install libpcap-dev) for this to work.
# Set this to False to use pcapy instead (which must then be installed and working...).
use_ctypes = True

import socket, select, pickle, subprocess

from struct   import pack, unpack
from datetime import datetime
from time     import time, sleep

from sys import argv, exit

#
# Defaults:
#
# Seconds between the periodic 'sums' reports sent to clients.
reporting_period = 1
# Logging verbosity (each -v on the command line adds one).
verbose = 0
# IP address to accept client connections on; '' means all addresses.
ip = ''
# TCP port to accept client connections on.
port = 3344
# If True, capture packets even when no clients are connected.
debug = False

#
# Logging and failure reporting:
#
def log(msg, verbose_level=0):
    """Print msg (flushed immediately) if the global verbosity is at least verbose_level.

    Hot-path callers check `verbose` themselves before calling, so they don't pay
    for formatting a message that would be discarded.
    """
    if verbose < verbose_level:
        return
    print(msg, flush=True)

def err(msg, exit_code=1):
    """Report msg and terminate the process with exit_code.

    Failures (non-zero exit_code) are written to stderr.  exit_code 0 is used for
    the requested help text, which goes to stdout like normal output.
    """
    import sys  # Looked up at call time so redirected/replaced streams are honored.
    print(msg, file=sys.stderr if exit_code else sys.stdout, flush=True)
    exit(exit_code)

# These are just used for logging:
def mac_str(mac):
    """Render a 6-byte MAC address as colon-separated lowercase hex ("aa:bb:cc:dd:ee:ff").

    None (unknown MAC) renders as "N/A".
    """
    return "N/A" if mac is None else mac.hex(':')

def ip_str(ip):
    """Render a 4-byte IPv4 address in dotted-decimal form.

    Multicast addresses (first octet 224-239) get a "(M)" suffix; None renders as "N/A".
    """
    if ip is None:
        return "N/A"
    dotted = ".".join(map(str, ip))
    is_multicast = ip[0] >> 4 == 14
    return dotted + "(M)" if is_multicast else dotted


#
# We'll wrap libpcap directly, via ctypes, in a mock pcapy class to keep the rest of the code the same.
#
# We could/should pull this out into a separate file, but keeping traxd all one file makes
#  it a lot easier to install...
#
if use_ctypes:
    import ctypes
    import ctypes.util

    # "libpcap.so" is the unversioned development symlink (installed by libpcap-dev).
    #  If it's missing, fall back to whatever runtime library the system can locate
    #  (e.g. libpcap.so.0.8), so the -dev package isn't strictly required.
    try:
        libpcap = ctypes.cdll.LoadLibrary("libpcap.so")
    except OSError:
        _libpcap_name = ctypes.util.find_library("pcap")
        if not _libpcap_name:
            raise                   # Nothing found either way: report the original failure.
        libpcap = ctypes.cdll.LoadLibrary(_libpcap_name)

    #
    # Much of this is adapted from https://github.com/alagoa/libpcapy
    # (but that didn't support various things I needed like nonblock
    #  and dispatch..)
    #
    #
    # ctypes mirrors of the C structs we read from libpcap.  ctypes derives the
    #  memory layout from _fields_, so field order and types must match the C
    #  declarations exactly.
    #
    class t_sockaddr(ctypes.Structure):
        # struct sockaddr: address family followed by 14 bytes of address data.
        _fields_ = [("sa_family", ctypes.c_ushort),
                    ("sa_data"  , ctypes.c_char * 14)]

    class t_pcap_addr(ctypes.Structure):
        pass
    # struct pcap_addr: one entry in a device's address list.  _fields_ is assigned
    #  after the class statement because 'next' points to this same type.
    t_pcap_addr._fields_ = [(     'next', ctypes.POINTER(t_pcap_addr)),
                            (     'addr', ctypes.POINTER(t_sockaddr)),
                            (  'netmask', ctypes.POINTER(t_sockaddr)),
                            ('broadaddr', ctypes.POINTER(t_sockaddr)),
                            (  'dstaddr', ctypes.POINTER(t_sockaddr))]

    class t_pcap_if(ctypes.Structure):
        pass
    # struct pcap_if: one node of the device list filled in by pcap_findalldevs();
    #  self-referential via 'next', hence the deferred _fields_ assignment.
    t_pcap_if._fields_ = [(       'next', ctypes.POINTER(t_pcap_if)),
                          (       'name', ctypes.c_char_p),
                          ('description', ctypes.c_char_p),
                          (  'addresses', ctypes.POINTER(t_pcap_addr)),
                          (      'flags', ctypes.c_uint)]

    class t_timeval(ctypes.Structure):
        # struct timeval.  Declares both members as C long; that matches Linux/glibc,
        #  which is what this script targets -- NOTE(review): confirm before porting elsewhere.
        _fields_ = [('tv_sec' , ctypes.c_long),
                    ('tv_usec', ctypes.c_long)]


    class t_pcap_pkthdr(ctypes.Structure):
        # struct pcap_pkthdr: capture timestamp, number of bytes actually captured
        #  (caplen), and the packet's original length (len).
        _fields_ = [(    'ts', t_timeval),
                    ('caplen', ctypes.c_uint),
                    (   'len', ctypes.c_uint)]

    class Pcapy(object):
        """Minimal ctypes-backed stand-in for the `pcapy` module.

        Provides only the entry points this script uses: findalldevs() and open_live().
        """

        errbuf = ctypes.create_string_buffer(256)   # Receives libpcap's error text (PCAP_ERRBUF_SIZE).

        f_findalldevs          = libpcap.pcap_findalldevs
        f_findalldevs.restype  = ctypes.c_int
        f_findalldevs.argtypes = [ctypes.POINTER(ctypes.POINTER(t_pcap_if)), ctypes.c_char_p]

        f_freealldevs          = libpcap.pcap_freealldevs
        f_freealldevs.restype  = None
        f_freealldevs.argtypes = [ctypes.POINTER(t_pcap_if)]

        f_open_live            = libpcap.pcap_open_live
        f_open_live.restype    = ctypes.POINTER(ctypes.c_void_p)
        f_open_live.argtypes   = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p]

        def findalldevs(self):
            """Return the names (str) of all capture devices libpcap reports.

            Returns an empty list when the call succeeds but finds no devices.
            Raises Exception carrying libpcap's error text if pcap_findalldevs() fails.
            """
            alldevs = ctypes.POINTER(t_pcap_if)()
            if self.f_findalldevs(ctypes.byref(alldevs), self.errbuf):
                raise Exception(self.errbuf.value.decode(errors='replace'))

            devices = []
            if not alldevs:
                # Success with a NULL list means "no devices"; touching .contents would
                #  raise ValueError (NULL pointer access).
                return devices

            try:
                node = alldevs
                while node:                 # A NULL 'next' pointer is falsy and ends the list.
                    device = node.contents
                    devices.append(device.name.decode())
                    node = device.next
            finally:
                self.f_freealldevs(alldevs) # Release libpcap's list even if decoding fails.

            return devices

        def open_live(self, dev, snaplen, promisc, ms):
            """Open device `dev` (str) for live capture and return a PcapyCap.

            snaplen -- max bytes captured per packet
            promisc -- promiscuous-mode flag
            ms      -- read timeout in milliseconds
            Raises Exception with libpcap's error text if the device can't be opened.
            """
            cap = self.f_open_live(dev.encode(), snaplen, promisc, ms, self.errbuf)
            if not cap:
                raise Exception(f"Error opening {dev}: {self.errbuf.value.decode(errors='replace')}")
            return PcapyCap(cap)

    class PcapyCap(object):
        """Wraps an open pcap_t handle with the subset of pcapy's reader API this script uses."""

        errbuf = ctypes.create_string_buffer(256)   # Receives libpcap's error text (PCAP_ERRBUF_SIZE).

        # C signature of a pcap_handler: (user, const struct pcap_pkthdr *, const u_char *).
        t_handler = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.py_object), ctypes.POINTER(t_pcap_pkthdr), ctypes.POINTER(ctypes.c_ubyte))

        f_setnonblock          = libpcap.pcap_setnonblock
        f_setnonblock.restype  = ctypes.c_int
        f_setnonblock.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_char_p]

        f_getfd                = libpcap.pcap_get_selectable_fd
        f_getfd.restype        = ctypes.c_int
        f_getfd.argtypes       = [ctypes.POINTER(ctypes.c_void_p)]

        f_dispatch             = libpcap.pcap_dispatch
        f_dispatch.restype     = ctypes.c_int
        f_dispatch.argtypes    = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, t_handler, ctypes.POINTER(ctypes.py_object)]

        f_close                = libpcap.pcap_close
        f_close.restype        = None
        f_close.argtypes       = [ctypes.POINTER(ctypes.c_void_p)]

        f_geterr               = libpcap.pcap_geterr
        f_geterr.restype       = ctypes.c_char_p
        f_geterr.argtypes      = [ctypes.POINTER(ctypes.c_void_p)]

        def __init__(self, cap):
            self.cap       = cap
            # Keep the C callback wrapper alive as long as this object: if it were
            #  garbage-collected, libpcap would call into freed memory.
            self.f_handler = self.t_handler(self.handler)

        def getfd(self):
            """Return a file descriptor usable with select() to wait for packets.

            Raises Exception if the device has none (libpcap returns -1).
            """
            fd = self.f_getfd(self.cap)
            if fd < 0:
                raise Exception("Capture device has no selectable file descriptor.")
            return fd

        def setnonblock(self, nonblocking=True):
            """Switch the handle into (or out of) non-blocking mode; raises Exception on failure."""

            nonblocking = int(nonblocking)

            rv = self.f_setnonblock(self.cap, nonblocking, self.errbuf)

            if rv < 0:
                if rv == -1:
                    raise Exception(f"Setnonblock({nonblocking}) failed (code {rv}): {self.errbuf.value.decode()}")
                raise Exception(f"Setnonblock({nonblocking}) failed (code {rv}).")

            return rv

        def dispatch(self, count, callback):
            """Process up to `count` packets (-1 = all buffered), calling callback(header, payload).

            Returns the number of packets processed (or -2 if pcap_breakloop() was used).
            Raises Exception with libpcap's error text on a capture error (-1), rather than
            returning it and leaving the caller polling a failed handle forever.
            """
            rv = self.f_dispatch(self.cap, count, self.f_handler, ctypes.pointer(ctypes.py_object(callback)))
            if rv == -1:
                raise Exception(f"Packet dispatch failed: {self._geterr()}")
            return rv

        def close(self):
            """Close the capture handle.  Safe to call more than once."""
            if self.cap:
                self.f_close(self.cap)
                self.cap = None     # Prevent a second pcap_close() on the freed handle.

        #--- Internal ---

        def _geterr(self):
            """libpcap's most recent error text for this handle."""
            msg = self.f_geterr(self.cap)
            return (msg or b'').decode(errors='replace')

        def handler(self, callback, header, payload):
            """C-level pcap_handler: unwraps the C arguments and calls the Python callback.

            callback -- pointer to the py_object passed as pcap_dispatch()'s user argument
            header   -- pointer to the packet's pcap_pkthdr
            payload  -- pointer to the captured bytes (header.caplen of them)
            """
            callback = callback.contents.value  # Pointer to python value
            header   = header.contents          # Views libpcap's struct in place (not a copy),
                                                #  so presumably only valid during this callback.
            payload  = ctypes.string_at(payload, header.caplen)

            callback(PcapyHeader(header), payload)

    class PcapyHeader(object):
        """Read-only view of a pcap_pkthdr, exposing the pcapy header accessors this script uses."""

        __slots__ = ('hdr',)

        def __init__(self, hdr):
            self.hdr = hdr

        def getlen(self):
            """Return the header's `len` field: the packet's original length (may exceed caplen)."""
            wrapped = self.hdr
            return wrapped.len

        def getts(self):
            """Return the capture timestamp as a (seconds, microseconds) tuple."""
            stamp = self.hdr.ts
            return stamp.tv_sec, stamp.tv_usec

    # Module-level stand-in so the rest of the script can call pcapy.findalldevs()
    #  and pcapy.open_live() regardless of which backend was chosen.
    pcapy = Pcapy()
else:
    import pcapy


#
# Available devs:
#
try:
    devs = pcapy.findalldevs()
except Exception as e:
    err(f"Can't list network devices: {e}", 2)

if not devs:
    err("No devices available.", 2)

net_dev = devs[0]   # Default interface: the first one libpcap reports.

#
# Arg parsing:
#
help_str = (
    f"Use: {argv[0]} [<network_interface>] [-v(erbose)] [-r <reporting_period>] [-i <ip>] [-p <port>] [-d(ebug)] [-h(elp)]\n"
    f"  Network interface (default {net_dev!r}) must be one of: {', '.join(devs)}\n"
    f"  Default ip (to listen at) is all.  Specify an IP addy of this host to limit.\n"
    f"  Default port is {port}\n"
    f"  Reporting period is in seconds (default {reporting_period}).\n"
    f"  -d (debug) captures packets even when no clients are connected."
)

args = argv[1:]     # Consumed front-to-back by nextarg().

def nextarg():
    """Remove and return the next command-line argument from `args`.

    If none remain, prints the usage text and exits with status 1.
    """
    if not args:
        err(help_str, 1)
    return args.pop(0)

while args:
    arg = nextarg()

    if arg == '-v':
        verbose += 1

    elif arg == '-r':
        # Fetch the value outside the try: if it's missing, nextarg() exits with the usage
        #  text, and that SystemExit must not be caught and re-reported as a format error.
        value = nextarg()
        try:
            reporting_period = float(value)
        except ValueError:
            err("Reporting period must be numeric seconds.", 1)
        # Zero/negative would make the main loop emit reports back-to-back; inf/NaN
        #  can't be used as a select() timeout.  (NaN fails every comparison.)
        if not 0 < reporting_period < float('inf'):
            err("Reporting period must be a positive, finite number of seconds.", 1)

    elif arg == '-i':
        ip = nextarg()

    elif arg == '-p':
        value = nextarg()
        try:
            port = int(value)
        except ValueError:
            err("Port must be an integer.", 1)
        if not 0 <= port <= 65535:
            err("Port must be in the range 0-65535.", 1)

    elif arg in devs:
        net_dev = arg

    elif arg in ('help', '-h'):
        err(help_str, 0)

    elif arg == '-d':
        debug = True

    else:
        err(help_str, 1)

if net_dev not in devs:
    err(f"Can't find network device {net_dev!r}.\nAvailable devices: {', '.join(devs)}", 2)

#
# We'll parse "arp -a" output to prime our MAC->IP mapping, which we'll
#  then try to keep up to date by watching ARP packets go by.
#
# We don't actually use this table ourselves, but rather just pass it
#  on to the clients.
#
# Since we sleep when no clients are connected, we'll have to call this
#  on demand each time the capture wakes up.
#
def load_arp():
    """Read the system ARP cache for net_dev by parsing `arp -a -i <net_dev>` output.

    Returns (mac2ip, ip2host): a dict mapping MAC (6 bytes) to IPv4 address (4 bytes),
    and a dict mapping IPv4 address (4 bytes) to hostname (str).  Entries whose MAC is
    '<incomplete>' contribute only a hostname; single-character host fields (presumably
    arp's '?' placeholder) are skipped.  Lines that don't parse are logged and ignored.

    These tables are only relayed to clients, so if `arp` can't be run or exits with an
    error, the problem is logged and empty tables are returned instead of raising.
    """
    mac2ip  = {}
    ip2host = {}

    try:
        arp_out = subprocess.run(['arp', '-a', '-i', net_dev], capture_output=True, check=True, text=True).stdout
    except (OSError, subprocess.CalledProcessError) as e:
        log(f"ARP: can't read the ARP cache ({e}); continuing without it.")
        return (mac2ip, ip2host)

    for line in arp_out.splitlines():
        if not line:
            continue
        try:
            # Expected shape: "<host> (<a.b.c.d>) at <mac|<incomplete>> <rest...>"
            host, addr, at, mac, rest = line.split(maxsplit=4)

            if at != 'at':
                raise ValueError(f"Got {at!r} when expecting 'at'")
            if len(addr) < 2 or addr[0] != '(' or addr[-1] != ')':
                raise ValueError("Missing parens around IP addy")
            if mac == '<incomplete>':
                mac = None
            elif ':' not in mac:
                raise ValueError("Invalid MAC addy")

            addr = bytes([int(w) for w in addr[1:-1].split('.')])
            if len(addr) != 4:
                raise ValueError("Invalid IP addy")

            if mac:
                mac = bytes.fromhex(mac.replace(':',''))
                if len(mac) != 6:
                    raise ValueError("Invalid MAC addy")

        except ValueError as e:     # Unpacking, int(), bytes() and fromhex() all raise ValueError.
            log(f"ARP: {e} in {line!r}")
            continue

        if mac:
            mac2ip[mac] = addr

        if len(host) > 1:
            ip2host[addr] = host

    if verbose > 2:
        log("Arp tables:")
        print(" MAC->IP")
        for mac, addr in mac2ip.items():
            print(f"  {mac_str(mac)} -> {ip_str(addr)}")
        print(" IP->HOST")
        for addr, host in ip2host.items():
            print(f"  {ip_str(addr)} -> {host}")

    return (mac2ip, ip2host)

arp_mac2ip, arp_ip2host = load_arp()

# Our own MAC address (used to tell outbound from inbound traffic).  Cheap and
#  dirty: read the colon-separated hex text the kernel exposes in sysfs.
mac_path = f"/sys/class/net/{net_dev}/address"
with open(mac_path) as mac_file:
    my_mac = mac_file.read().strip()
my_macb = bytes.fromhex(my_mac.replace(':', ' '))   # fromhex() skips the spaces.
log(f"My MAC: {my_mac!r} aka {my_macb!r}")

#
# Tally function:
#

# packet types:
# Ethernet frame types (the EtherType field at offset 12):
pt_IPV4   = 0x0800
pt_ARP    = 0x0806
pt_80211r = 0x890D
pt_IPV6   = 0x86DD

# IPv4 protocol numbers (IP header byte 9):
TCP = 6
UDP = 17

def _parse_ipv4(payload):
    """Extract (protocol, src_ip, dst_ip, src_port, dst_port) from an Ethernet frame carrying IPv4.

    Ports are ints for TCP/UDP (0 = transient, or not readable from this frame) and None
    for other protocols.  Returns all None if the captured IPv4 header is truncated.
    """
    ip_hdr = 14                                 # IPv4 header follows the 14-byte Ethernet header.
    if len(payload) < ip_hdr + 20:
        return None, None, None, None, None

    ihl      = (payload[ip_hdr] & 0x0F) * 4     # Header length in bytes; > 20 when options are present.
    protocol = payload[ip_hdr+ 9]
    src_ip   = payload[ip_hdr+12:ip_hdr+16]
    dst_ip   = payload[ip_hdr+16:ip_hdr+20]

    if protocol not in (TCP, UDP):
        return protocol, src_ip, dst_ip, None, None

    frag_offset = unpack("!H", payload[ip_hdr+6:ip_hdr+8])[0] & 0x1FFF
    l4 = ip_hdr + ihl
    if frag_offset or ihl < 20 or len(payload) < l4 + 4:
        # Non-first fragment (carries no TCP/UDP header), bogus IHL, or ports beyond the
        #  snap length: lump it in with the transient-port bucket.
        return protocol, src_ip, dst_ip, 0, 0

    src_port, dst_port = unpack('!HH', payload[l4:l4+4])

    #
    # For tallying connections, treat all 5+ digit port numbers as transient.
    #
    # This is somewhat arbitrary but works well enough in practice to
    #   limit the proliferation of entries.
    #
    if src_port > 9999:
        src_port = 0
    if dst_port > 9999:
        dst_port = 0

    return protocol, src_ip, dst_ip, src_port, dst_port


def tally_packet(header, payload):
    """pcap dispatch callback: classify one captured Ethernet frame and add it to `totals`.

    header  -- pcapy-style header (getts() -> (seconds, microseconds), getlen() -> wire length)
    payload -- the captured bytes, starting with the Ethernet header

    Frames neither from us, to us, nor broadcast/multicast are ignored.  Others are keyed
    with the other party's fields first -- (peer_mac, etype, protocol, peer_ip, peer_port,
    our_ip, our_port) -- and accumulate (inpackets, inbytes, outpackets, outbytes).
    IPv4 ARP traffic also updates arp_mac2ip and notifies clients of changes.
    """
    if verbose:
        when, usec = header.getts()
        when += usec/1000000

    if len(payload) < 14:
        # Too short to hold an Ethernet header; unpacking it would raise inside the capture callback.
        if verbose > 1:
            log(f"WARNING: runt frame ({len(payload)} bytes captured)", 2)
        return

    dstmac = payload[0: 6]
    srcmac = payload[6:12]
    etype, = unpack("!H", payload[12:14])

    size = header.getlen()

    outbound  = srcmac == my_macb
    broadcast = dstmac[0]&1         # Group bit: broadcast or multicast destination.

    if not (outbound or broadcast) and dstmac != my_macb:
        # This might happen routinely if the interface is in promiscuous mode for any reason?
        if verbose > 1:
            log(f"WARNING: got packet not for us?  {mac_str(srcmac)} -> {mac_str(dstmac)}", 2)
        return

    #
    # Protocol-aware tracking:
    #
    src_ip = dst_ip = src_port = dst_port = protocol = None
    if etype == pt_IPV4:

        protocol, src_ip, dst_ip, src_port, dst_port = _parse_ipv4(payload)

        if protocol is None:
            if verbose > 1:
                log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} IPV4 (truncated header)", 2)
        elif protocol not in (TCP, UDP):
            # Protocols other than UDP/TCP
            if verbose > 1:
                log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} P:{protocol}(!=UDP/TCP) {ip_str(src_ip)}->{ip_str(dst_ip)}", 2)

    elif etype == pt_ARP:

        if len(payload) < 14+28:
            # Too short for an Ethernet/IPv4 ARP body: tally it, but don't parse it.
            if verbose > 1:
                log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP (truncated)", 2)
        else:
            ptyp, = unpack("!H", payload[14+ 2:14+ 4])

            if ptyp == pt_IPV4:
                oper  = payload[14+ 7]          # 1 = ask / 2 = reply
                sha   = payload[14+ 8:14+14]    # Sender MAC
                spa   = payload[14+14:14+18]    # Sender IP
                tha   = payload[14+18:14+24]    # Target MAC
                tpa   = payload[14+24:14+28]    # Target IP

                if sum(spa) and sum(sha):   # Make sure neither are 0's aka unknown
                    if spa != arp_mac2ip.get(sha):
                        arp_mac2ip[sha] = spa               # Update our cache
                        send_msg(**arp_msg({sha:spa}))      # Notify all clients of the change
                        if verbose:
                            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP CHANGE: {mac_str(sha)} -> {ip_str(spa)}")
                    elif verbose > 1:
                        log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP (same): {mac_str(sha)} -> {ip_str(spa)}")
                elif verbose > 1:
                    log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP (null): {mac_str(sha)} -> {ip_str(spa)}")
            else:
                if verbose:
                    log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)}: Unhandled ARP type: {ptyp:04X}")

    elif etype == pt_IPV6:
        # TODO...
        if verbose > 1:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} IPV6", 2)

    elif etype == pt_80211r:
        # TODO...
        if verbose > 1:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} 802.11r", 2)

    else:
        if verbose:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} UNKNOWN etype (0x{etype:04x})", 1)

    # Orient the key so the other party's fields always come first.
    if outbound:
        key = (dstmac, etype, protocol, dst_ip, dst_port, src_ip, src_port)
    else:
        key = (srcmac, etype, protocol, src_ip, src_port, dst_ip, dst_port)

    inpackets, inbytes, outpackets, outbytes = totals.get(key, (0, 0, 0, 0))

    if outbound:
        outpackets += 1
        outbytes   += size
    else:
        inpackets += 1
        inbytes   += size

    totals[key] = (inpackets, inbytes, outpackets, outbytes)


#
# Listen for connections at the specified ip/port:
#
log(f"Listening at {ip or 'ALL'}:{port}")

# Keep retrying until we can bind (e.g. the port may still be held by a previous instance).
while True:
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Allow rebinding right after a restart, even while the previous instance's
        #  connections linger in TIME_WAIT.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.setblocking(0)
        server.bind((ip, port))
        server.listen(5)
        break
    except Exception as e:
        server.close()
        log(e)
        log("Retrying in 10 seconds...")
        sleep(10)

clients = {}    # Maps client socket to output queue (list of framed messages).

def arp_msg(mac2ip, ip2host=None):
    """Build the keyword arguments of an 'arp' message for send_msg()/send_msg_to().

    mac2ip  -- dict MAC (bytes) -> IPv4 address (bytes) to report
    ip2host -- optional dict IPv4 address (bytes) -> hostname (str); defaults to a fresh {}
    """
    if ip2host is None:
        ip2host = {}    # Fresh dict per call, rather than one shared mutable default.
    return dict(kind='arp', my_mac=my_mac, mac2ip=mac2ip, ip2host=ip2host)

def new_connection():
    """Accept a pending client connection, register it, and queue the cached ARP tables for it.

    The listening socket is non-blocking, so a connection that vanished between select()
    and accept() (or before shutdown()) surfaces as an OSError.  It is logged and skipped
    rather than taking the daemon down.
    """
    try:
        client, client_address = server.accept()
    except OSError as e:
        log(f"Accept failed: {e}", 1)
        return

    log(f"New connection from {client_address}", 1)
    try:
        client.setblocking(0)
        client.shutdown(socket.SHUT_RD)  # Write-only socket; don't queue up incoming data because we'll never read it.
    except OSError as e:
        log(f"Dropping connection from {client_address}: {e}", 1)
        client.close()
        return

    clients[client] = []

    send_msg_to(client, **arp_msg(arp_mac2ip, arp_ip2host))  # Send cached arp output first thing.

def close_connection(con):
    """Unregister a client socket (warning if it wasn't registered) and close it."""
    # Registered entries map to a list (never None), so None means "wasn't there".
    if clients.pop(con, None) is None:
        log("WARNING: Closing connection that's not in clients list.")
    con.close()
    log("Client closed.", 1)

def pack_obj(o):
    """Pickle o and frame it for the wire: 2-byte big-endian length, then the pickle bytes.

    Raises ValueError if the pickle is longer than 65535 bytes, which the 16-bit
    length prefix cannot represent.
    """
    b = pickle.dumps(o)
    if len(b) > 0xFFFF:
        raise ValueError(f"Message too large to frame ({len(b)} bytes > 65535)")
    return pack('!H', len(b)) + b

DROP = pack_obj({'kind':'drop'}) # Special message meaning outgoing frames were dropped

def send_msg(**msg):
    """Queue msg (the keyword arguments, as a dict) for every connected client.

    Per-client back-pressure: with fewer than 4 frames pending, the message is queued;
    with 4-9 pending, a DROP marker is queued instead; at 10 or more, nothing is queued.
    A message too large to frame is logged and replaced by DROP rather than raising.

    Returns number of bytes queued to send (per client), or 0 when there are no clients
    (unless verbose >= 2, which logs the message length).
    """
    if not clients and verbose < 2: # Verbose>=2 logs the message length.
        return 0

    from struct import error as StructError
    try:
        framed = pack_obj(msg)
    except (ValueError, StructError) as e:
        log(f"WARNING: dropping {msg.get('kind')!r} message: {e}")
        framed = DROP

    for queue in clients.values():
        if len(queue) < 4:
            queue.append(framed)
        elif len(queue) < 10:
            queue.append(DROP)
    return len(framed)

def send_msg_to(client, **msg):
    """Queue msg (the keyword arguments, as a dict) for one client; no-op if it isn't connected.

    Uses the same back-pressure rule as send_msg: queue the message with fewer than 4 frames
    pending, a DROP marker with 4-9 pending, nothing at 10 or more.  A message too large
    to frame is logged and replaced by DROP rather than raising.
    """
    queue = clients.get(client)
    if queue is None:
        return

    from struct import error as StructError
    try:
        framed = pack_obj(msg)
    except (ValueError, StructError) as e:
        log(f"WARNING: dropping {msg.get('kind')!r} message: {e}")
        framed = DROP

    if len(queue) < 4:
        queue.append(framed)
    elif len(queue) < 10:
        queue.append(DROP)

def refresh_arp():
    """Re-read the system ARP cache (via load_arp) and merge it into arp_mac2ip / arp_ip2host.

    The cached dicts are updated in place.  Entries are added or changed, never removed.
    Returns True if anything has changed.  (Shouldn't happen often?)
    """
    dirty = False

    mac2ip, ip2host = load_arp()

    for mac, ip in mac2ip.items():
        if arp_mac2ip.get(mac) != ip:
            arp_mac2ip[mac] = ip
            dirty = True
            if verbose:
                log(f"ARP update: {mac_str(mac)} -> {ip_str(ip)}")

    for ip, host in ip2host.items():
        if arp_ip2host.get(ip) != host:
            arp_ip2host[ip] = host
            dirty = True
            if verbose:
                log(f"ARP update: {ip_str(ip)} -> {host}")

    if dirty:
        if verbose:
            log("ARP tables have changed.")
    elif verbose > 1:
        log("ARP tables unchanged.")

    return dirty

#
# Do it:
#

cap = cap_fd = None

try:
    totals = {}
    t1 = time() + reporting_period  # When to issue first report

    while True:

        #
        # Start and stop packet capture on-demand:
        #
        if cap is None:
            if clients or debug:

                log(f"Opening network interface {net_dev}...", 1)
                try:
                    # device, bytes to capture per packet, promiscuous mode, timeout (ms).
                    #  The timeout is kept >= 1ms: libpcap treats 0 as "no timeout".
                    cap    = pcapy.open_live(net_dev, 64, False, max(1, min(1000, int(reporting_period*1000))))
                    cap_fd = cap.getfd()
                    cap.setnonblock(True)
                except Exception as e:
                    if cap is not None:     # Opened, but couldn't be configured: release it.
                        cap.close()
                    cap    = None
                    cap_fd = None
                    err(e, 3)       # TODO: should retry after a while instead of exiting?
                t1 = time() + reporting_period  # When to issue next report

                #
                # Refresh our arp tables (we've been asleep, so things may have changed...)
                #
                if refresh_arp():
                    send_msg(**arp_msg(arp_mac2ip, arp_ip2host))  # Notify all clients of updated

        elif not (clients or debug):
            log(f"Closing network interface {net_dev}...", 1)
            cap.close()
            cap = cap_fd = None

        #
        # Wait for anything we can do (effectively forever when fully idle)...
        #
        read_fds = [server]

        if cap_fd is not None:
            read_fds.append(cap_fd)

        write_fds = [sock for sock, queue in clients.items() if queue]

        timeout = max(0, t1-time()) if (cap is not None or clients) else 9999
        read_ready, write_ready, _ = select.select(read_fds, write_fds, (), timeout)

        #
        # Capture packets, and accept new connections:
        #
        for fd in read_ready:

            if fd == cap_fd:
                cap.dispatch(-1, tally_packet)

            elif fd == server:
                new_connection()

        #
        # Send queued output to any clients that aren't backed up:
        #
        for fd in write_ready:

            queue = clients.get(fd)     # None if the client was closed earlier in this pass.
            while queue:
                msg = queue[0]
                try:
                    n = fd.send(msg)
                except BlockingIOError:
                    # Send buffer filled up partway through the queue; that's normal
                    #  back-pressure, not a dead client.  Resume when select() says writable.
                    break
                except Exception as e:
                    log(f"Exception during send: {e}", 1)
                    close_connection(fd)
                    break

                if n >= len(msg):
                    queue.pop(0)
                elif n > 0:
                    queue[0] = msg[n:]  # Partial send: keep the unsent tail at the front.
                else:
                    break

        #
        # Generate periodic reports (to client queues):
        #
        when = time()
        if when >= t1:

            n = send_msg(
                    kind='sums',
                    when=when,
                    vals=totals,
                )

            if verbose > 1:
                log(f"---- {datetime.fromtimestamp(int(when))}: {len(totals)} = {n} bytes  ----", 2)

            totals = {}
            t1 += reporting_period
            if t1 <= when:
                # More than a whole period behind (e.g. after a suspend or long stall):
                #  resync rather than firing a burst of back-to-back catch-up reports.
                t1 = when + reporting_period

finally:
    log(f"Closing pcap and {len(clients)} clients.")
    for client in clients:
        client.close()
    if cap is not None:
        cap.close()
    log("Done.")

