#!/usr/bin/python3
#
# Trax (network activity tracker) Version 1.0
#
# Copyright 2023 Brandyn Webb
# 
# This file is part of Trax (network activity tracker).
# 
# Trax is free software: you can redistribute it and/or modify it under the terms 
# of the GNU General Public License as published by the Free Software Foundation, 
# either version 3 of the License, or (at your option) any later version.
# 
# Trax is distributed in the hope that it will be useful, but WITHOUT ANY 
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along with 
# Trax. If not, see <https://www.gnu.org/licenses/>. 
# 


import pcapy, socket, select, pickle, subprocess

from struct   import pack, unpack
from datetime import datetime
from time     import time, sleep

from sys import argv, exit

#
# Defaults:
#
reporting_period = 1            # Seconds between 'sums' reports queued to clients (override with -r)
verbose          = 0            # Logging output verbosity level; each -v on the command line adds 1
ip               = ''           # IP address to accept connections on.  Default: ALL (override with -i).
port             = 3344         # Default TCP port to accept connections at (override with -p).
debug            = False        # If True, scan packets even if no clients are listening (set by -d).

#
# Logging and failure reporting:
#
def log(msg, verbose_level=0):
    """Print msg (flushed) unless the global verbosity is below verbose_level.

    Callers on busy paths should still test `verbose` themselves first, so
    that the message is not formatted when it would be discarded anyway.
    """
    if verbose < verbose_level:
        return
    print(msg, flush=True)

def err(msg, exit_code=1):
    """Print msg (flushed) and terminate the process with exit_code.

    Never returns: it always raises SystemExit.
    """
    # TODO: should send to stderr instead
    print(msg, flush=True)
    raise SystemExit(exit_code)

# These are just used for logging:
def mac_str(mac):
    """Render a raw MAC address (bytes) as colon-separated hex.

    Returns "N/A" when mac is None.
    """
    return "N/A" if mac is None else mac.hex(':')

def ip_str(ip):
    """Render a raw IPv4 address (bytes) in dotted-decimal form.

    Returns "N/A" when ip is None.  Addresses whose first octet has the high
    nibble 14 (224-239, i.e. multicast) get a "(M)" suffix.
    """
    if ip is None:
        return "N/A"
    dotted = '.'.join(map(str, ip))
    is_multicast = (ip[0] >> 4) == 14
    return dotted + "(M)" if is_multicast else dotted


#
# Available devs:
#
# Ask pcap which interfaces it can capture on.  The first one listed is the
# default unless the command line names another; with none at all we give up.
devs = pcapy.findalldevs()

if devs:
    net_dev = devs[0]
else:
    err("No devices available.", 2)

#
# Arg parsing:
#
# Usage text shown for 'help'/-h and for any unrecognized or incomplete argument.
# It lists every option the parser below accepts, including -d.
help_str         = f"Use: {argv[0]} [<network_interface>] [-v(erbose)] [-d(ebug)] [-r <reporting_period>] [-i <ip>] [-p <port>]\n" \
                   f"  Network interface (default {net_dev!r}) must be one of: {', '.join(devs)}\n" \
                   f"  Default ip (to listen at) is all.  Specify an IP addy of this host to limit.\n" \
                   f"  Default port is {port}\n" \
                   f"  Reporting period is in seconds (default {reporting_period}).\n" \
                   f"  -v may be repeated for more detail; -d captures packets even when no clients are connected."

# Remaining command-line arguments, consumed front to back by nextarg().
args = argv[1:]

def nextarg():
    """Remove and return the next pending command-line argument.

    If none remain, prints the usage text and exits with status 1 (err()
    does not return).
    """
    if not args:
        err(help_str, 1)
    return args.pop(0)

# Walk the argument list.  Options that take a value fetch it with nextarg(),
# which prints the usage text and exits if the value is missing.
while args:
    arg = nextarg()

    if arg == '-v':
        verbose += 1

    elif arg == '-r':
        # Fetch the value outside the try: a missing value must produce only the
        # usage text, not be swallowed and reported as a non-numeric period.
        value = nextarg()
        try:
            reporting_period = float(value)
        except ValueError:
            err("Reporting period must be numeric seconds.", 1)
        # Zero, negative, NaN or infinite periods would make the main loop spin
        # or break the pcap timeout computation, so reject them up front.
        if not (0 < reporting_period < float('inf')):
            err("Reporting period must be a positive, finite number of seconds.", 1)

    elif arg == '-i':
        ip = nextarg()

    elif arg == '-p':
        value = nextarg()
        try:
            port = int(value)
        except ValueError:
            err("Port must be an integer.", 1)
        # An out-of-range port can never be bound; fail now rather than retrying forever.
        if not (0 <= port <= 65535):
            err("Port must be in the range 0-65535.", 1)

    elif arg in devs:
        net_dev = arg

    elif arg in ('help', '-h'):
        err(help_str, 0)

    elif arg == '-d':
        debug = True

    else:
        err(help_str, 1)

# Defensive check; net_dev is only ever assigned from devs above.
if net_dev not in devs:
    err(f"Can't find network device {net_dev!r}.\nAvailable devices: {', '.join(devs)}", 2)

#
# We'll parse "arp -a" output to prime our MAC->IP mapping, which we'll
#  then try to keep up to date by watching ARP packets go by.
#
# We don't actually use this table ourselves, but rather just pass it
#  on to the clients.
#
# Since we sleep when no clients are connected, we'll have to call this
#  on demand each time the capture wakes up.
#
def load_arp():
    """Build MAC->IP and IP->hostname tables from the system ARP cache.

    Runs ``arp -a -i <net_dev>`` and parses lines of the form
    ``<host> (<a.b.c.d>) at <mac> <rest...>``.  Lines that don't fit that
    shape are logged and skipped.  If the arp command can't be started or
    exits with an error, that is logged and two empty tables are returned.

    Returns:
        (mac2ip, ip2host): mac2ip maps MAC (bytes) to IPv4 address (4 bytes);
        ip2host maps IPv4 address (4 bytes) to hostname (str).  Entries whose
        MAC is ``<incomplete>`` contribute only to ip2host; hostnames of a
        single character (e.g. '?') are not recorded.
    """
    try:
        arp_out = subprocess.run(['arp', '-a', '-i', net_dev], capture_output=True, check=True, text=True).stdout
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError: arp binary missing / not executable.  CalledProcessError: non-zero exit.
        log(f"ARP: unable to read the ARP cache ({e}); continuing with empty tables.")
        return ({}, {})

    mac2ip = {}
    ip2host= {}

    for line in arp_out.split('\n'):
        if not line:
            continue
        try:
            # Exactly five fields are required, so a line with nothing after the MAC is rejected too.
            host, ip, at, mac, _rest = line.split(maxsplit=4)

            if at != 'at':
                raise ValueError(f"Got {at!r} when expecting 'at'")
            if len(ip) < 2 or ip[0] != '(' or ip[-1] != ')':
                raise ValueError("Missing parens around IP addy")
            if mac == '<incomplete>':
                mac = None
            elif ':' not in mac:
                raise ValueError("Invalid MAC addy")

            octets = [int(w) for w in ip[1:-1].split('.')]
            if len(octets) != 4:
                raise ValueError("Expected 4 octets in IP addy")
            ip = bytes(octets)      # Also rejects octets outside 0..255

            if mac:
                mac = bytes.fromhex(mac.replace(':',''))

        except ValueError as e:
            log(f"ARP: {e} in {line!r}")
            continue

        if mac:
            mac2ip[mac] = ip

        if len(host) > 1:
            ip2host[ip] = host

    if verbose > 2:
        log("Arp tables:")
        log(" MAC->IP")
        for mac, ip in mac2ip.items():
            log(f"  {mac_str(mac)} -> {ip_str(ip)}")
        log(" IP->HOST")
        for ip, host in ip2host.items():
            log(f"  {ip_str(ip)} -> {host}")

    return (mac2ip, ip2host)

# Prime the ARP caches once at startup.
arp_mac2ip, arp_ip2host = load_arp()

# Our own hardware address, read from sysfs: my_mac is the text form as found
# in the file, my_macb the same address as raw bytes for comparing with frames.
with open(f"/sys/class/net/{net_dev}/address") as addr_file:
    my_mac = addr_file.read().strip()

_spaced_hex = my_mac.replace(':', ' ')      # bytes.fromhex() skips the spaces
my_macb = bytes.fromhex(_spaced_hex)

log(f"My MAC: {my_mac!r} aka {my_macb!r}")

#
# Tally function:
#

# Packet types: values compared against the 16-bit type field of the Ethernet
# header (and, for ARP, against the ARP protocol-type field):
pt_ARP  = 0x0806
pt_IPV4 = 0x0800
pt_IPV6 = 0x86DD

# IPV4 protocol numbers (compared against the protocol byte of the IPv4 header):
TCP = 6
UDP = 17

def tally_packet(header, payload):
    """Classify one captured Ethernet frame and add it to the running `totals`.

    Args:
        header:  pcap packet header; getts() supplies the timestamp (used only
                 for logging) and getlen() the packet size that is tallied.
        payload: the captured bytes, starting at the Ethernet header.  This may
                 be shorter than the packet itself, so every field is
                 length-checked before it is read.

    Frames that are neither sent by us, addressed to us, nor broadcast/multicast
    are ignored.  Everything else is counted in `totals` under the key
    (peer_mac, etype, protocol, peer_ip, peer_port, our_ip, our_port), whose
    value is (inpackets, inbytes, outpackets, outbytes).  Fields that don't
    apply to the frame type are None.  For TCP/UDP, ports above 9999, and ports
    that can't be read (non-first fragments, truncated capture), are stored as 0.

    ARP frames carrying an IPv4 mapping also update `arp_mac2ip`, and any change
    is pushed to all clients.
    """
    if len(payload) < 14:
        # Not even a complete Ethernet header; nothing we can classify.
        return

    if verbose:
        when, usec = header.getts()     # Second element is treated as microseconds
        when += usec/1000000

    dstmac = payload[0: 6]
    srcmac = payload[6:12]
    etype, = unpack("!H", payload[12:14])

    size = header.getlen()

    outbound  = srcmac == my_macb
    broadcast = dstmac[0]&1             # Low bit of first octet: group (broadcast/multicast) address

    if not (outbound or broadcast) and dstmac != my_macb:
        # This might happen routinely if the interface is in promiscuous mode for any reason?
        if verbose > 1:
            log(f"WARNING: got packet not for us?  {mac_str(srcmac)} -> {mac_str(dstmac)}", 2)
        return

    #
    # Protocol-aware tracking:
    #
    src_ip = dst_ip = src_port = dst_port = protocol = None
    if etype == pt_IPV4:

        if len(payload) >= 14+20:       # Need the full fixed part of the IPv4 header
            protocol = payload[14+ 9]
            src_ip   = payload[14+12:14+16]
            dst_ip   = payload[14+16:14+20]

            if protocol in (TCP, UDP):
                # Header length is variable (options), and only the first fragment
                #  of a datagram carries the TCP/UDP header, so locate the ports
                #  properly rather than assuming a fixed offset.
                ihl          = (payload[14] & 0x0F) * 4
                frag_offset  = unpack('!H', payload[14+6:14+8])[0] & 0x1FFF
                ports_at     = 14 + ihl

                # Ports we can't read are reported as 0, same as transient ports below.
                src_port = dst_port = 0

                if ihl >= 20 and frag_offset == 0 and len(payload) >= ports_at + 4:
                    src_port, dst_port = unpack('!HH', payload[ports_at:ports_at+4])

                    #
                    # For tallying connections, treat all 5+ digit port numbers as transient.
                    #
                    # This is somewhat arbitrary but works well enough in practice to
                    #   limit the proliferation of entries.
                    #
                    if src_port > 9999:
                        src_port = 0
                    if dst_port > 9999:
                        dst_port = 0

            else:
                # Protocols other than UDP/TCP
                if verbose > 1:
                    log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} P:{protocol}(!=UDP/TCP) {ip_str(src_ip)}->{ip_str(dst_ip)}", 2)

    elif etype == pt_ARP:

        if len(payload) >= 14+28:       # Full ARP body for Ethernet/IPv4; shorter frames are only tallied
            ptyp, = unpack("!H", payload[14+ 2:14+ 4])

            if ptyp == pt_IPV4:
                sha   = payload[14+ 8:14+14]    # Sender hardware address
                spa   = payload[14+14:14+18]    # Sender protocol (IPv4) address

                if sum(spa) and sum(sha):   # Make sure neither are 0's aka unknown
                    if spa != arp_mac2ip.get(sha):
                        arp_mac2ip[sha] = spa               # Update our cache
                        send_msg(**arp_msg({sha:spa}))      # Notify all clients of the change
                        if verbose:
                            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP CHANGE: {mac_str(sha)} -> {ip_str(spa)}")
                    elif verbose > 1:
                        log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP (same): {mac_str(sha)} -> {ip_str(spa)}")
                elif verbose > 1:
                    log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP (null): {mac_str(sha)} -> {ip_str(spa)}")
            else:
                if verbose:
                    log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)}: Unhandled ARP type: {ptyp:04X}")

    elif etype == pt_IPV6:
        # TODO...
        if verbose > 1:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} IPV6", 2)

    else:
        if verbose:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} UNKNOWN etype (0x{etype:04x})", 1)

    # Key on the remote end first so both directions of a conversation share one entry.
    if outbound:
        key = (dstmac, etype, protocol, dst_ip, dst_port, src_ip, src_port)
    else:
        key = (srcmac, etype, protocol, src_ip, src_port, dst_ip, dst_port)

    inpackets, inbytes, outpackets, outbytes = totals.get(key, (0, 0, 0, 0))

    if outbound:
        outpackets += 1
        outbytes   += size
    else:
        inpackets += 1
        inbytes   += size

    totals[key] = (inpackets, inbytes, outpackets, outbytes)


#
# Listen for connections at the specified ip/port:
#
log(f"Listening at {ip or 'ALL'}:{port}")

# Keep trying until the listening socket is up; each failed attempt is logged
# and retried after 10 seconds with a fresh socket.
while True:
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Allow an immediate rebind after a restart even while connections from
        # the previous run are still in TIME_WAIT.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.setblocking(0)
        server.bind((ip, port))
        server.listen(5)
        break
    except Exception as e:
        server.close()
        log(e)
        log("Retrying in 10 seconds...")
        sleep(10)

clients = {}    # Maps client socket to output queue.

def arp_msg(mac2ip, ip2host=None):
    """Build the keyword arguments for an 'arp' message.

    Args:
        mac2ip:  dict mapping MAC (bytes) to IP (bytes) to announce.
        ip2host: optional dict mapping IP (bytes) to hostname; omitted or None
                 means no hostname entries (a fresh empty dict is used, so no
                 dict is shared between calls).

    Returns:
        A dict with keys 'kind' (always 'arp'), 'my_mac', 'mac2ip' and
        'ip2host', suitable for send_msg(**...) / send_msg_to(client, **...).
    """
    if ip2host is None:
        ip2host = {}
    return dict(kind='arp', my_mac=my_mac, mac2ip=mac2ip, ip2host=ip2host)

def new_connection():
    """Accept one pending client on the listening socket and register it.

    The new socket is made non-blocking and write-only, given an empty output
    queue in `clients`, and sent the cached ARP tables as its first message.

    A client that goes away between select() reporting the listener readable
    and the accept()/shutdown() calls is logged and skipped rather than being
    allowed to raise out of the main loop.
    """
    try:
        client, client_address = server.accept()
    except OSError as e:
        # Non-blocking listener: the pending connection may already be gone.
        log(f"Accept failed: {e}", 1)
        return

    log(f"New connection from {client_address}", 1)

    try:
        client.setblocking(0)
        client.shutdown(socket.SHUT_RD)  # Write-only socket; don't queue up incoming data because we'll never read it.
    except OSError as e:
        # E.g. the peer already reset the connection.
        log(f"Dropping connection from {client_address}: {e}", 1)
        client.close()
        return

    clients[client] = []

    send_msg_to(client, **arp_msg(arp_mac2ip, arp_ip2host))  # Send cached arp output first thing.

def close_connection(con):
    """Forget a client socket and close it.

    Removes con from `clients` (logging a warning if it wasn't registered),
    closes the socket either way, and logs the closure at verbosity 1.
    """
    try:
        del clients[con]
    except KeyError:
        log("WARNING: Closing connection that's not in clients list.")
    con.close()
    log("Client closed.", 1)

def pack_obj(o):
    """Serialize o into one wire frame: a 2-byte length prefix plus its pickle.

    The prefix is the pickle's length as a big-endian unsigned 16-bit integer,
    so struct raises an error for pickles longer than 65535 bytes.
    """
    body = pickle.dumps(o)
    prefix = pack('!H', len(body))
    return prefix + body

DROP = pack_obj({'kind':'drop'}) # Special message meaning outgoing frames were dropped
def send_msg(**msg):
    """Queue msg (given as keyword arguments) for every connected client.

    Each client queue takes real messages while it holds fewer than 4 entries;
    from 4 up to 9 entries a DROP marker is queued instead, and beyond that
    nothing is added.

    A message whose pickled form doesn't fit the 16-bit length prefix can't be
    framed at all; it is logged and replaced by a DROP marker for every client
    instead of raising out of the caller.

    Returns:
        Number of bytes queued to send (per client); 0 if there was nothing to
        do or the message had to be dropped.
    """
    from struct import error as struct_error    # Only `pack`/`unpack` are imported at file level

    if not clients and verbose < 2: # Verbose>=2 logs the message length.
        return 0

    try:
        packed = pack_obj(msg)
    except struct_error as e:
        log(f"WARNING: dropping oversized {msg.get('kind')!r} message: {e}")
        packed = None

    for queue in clients.values():
        if packed is not None and len(queue) < 4:
            queue.append(packed)
        elif len(queue) < 10:
            queue.append(DROP)

    return 0 if packed is None else len(packed)

def send_msg_to(client, **msg):
    """Queue msg (given as keyword arguments) for one specific client.

    Does nothing if client is not registered in `clients`.  The queue takes the
    message while it holds fewer than 4 entries, a DROP marker from 4 up to 9
    entries, and nothing beyond that.
    """
    queue = clients.get(client)
    if queue is None:
        return

    packed = pack_obj(msg)
    backlog = len(queue)

    if backlog < 4:
        queue.append(packed)
    elif backlog < 10:
        queue.append(DROP)

def refresh_arp():
    """Re-read the system ARP cache and merge it into our cached tables.

    Entries from load_arp() are added to, or overwrite, `arp_mac2ip` and
    `arp_ip2host`; existing entries that no longer appear are left in place.

    Returns:
        True if anything has changed.  (Shouldn't happen often?)
    """
    dirty = False

    mac2ip, ip2host = load_arp()

    for mac, ip in mac2ip.items():
        if arp_mac2ip.get(mac) != ip:
            arp_mac2ip[mac] = ip
            dirty = True

    for ip, host in ip2host.items():
        if arp_ip2host.get(ip) != host:
            arp_ip2host[ip] = host      # Key by the IP being iterated
            dirty = True

    if dirty:
        if verbose:
            log("ARP tables have changed.")
    elif verbose > 1:
        log("ARP tables unchanged.")

    return dirty

#
# Do it:
#

# Main event loop.  A single select() call multiplexes three things: the
# listening socket (new clients), the pcap handle (captured packets, only while
# capture is open) and client sockets with queued output.  Once per
# reporting_period the accumulated totals are queued to all clients and reset.
cap = cap_fd = None

try:
    totals = {}
    t1 = time() + reporting_period  # When to issue first report

    while True:

        #
        # Start and stop packet capture on-demand:
        #
        if cap is None:
            if clients or debug:

                log(f"Opening network interface {net_dev}...", 1)
                try:
                    # device, bytes to capture per packet, promiscuous mode, timeout (ms)
                    cap    = pcapy.open_live(net_dev, 64, False, min(1000, int(reporting_period*1000)))
                    cap_fd = cap.getfd()
                    cap.setnonblock(True)
                except Exception as e:
                    cap    = None
                    cap_fd = None
                    err(e, 3)       # TODO: should retry after a while instead of exiting?
                t1 = time() + reporting_period  # When to issue next report

                #
                # Refresh our arp tables (we've been asleep, so things may have changed...)
                #
                if refresh_arp():
                    send_msg(**arp_msg(arp_mac2ip, arp_ip2host))  # Notify all clients of updated

        else:
            if not (clients or debug):
                log(f"Closing network interface {net_dev}...", 1)
                cap.close()
                cap = cap_fd = None

        #
        # Wait for anything we can do...
        #
        read_fds = [server]

        if cap_fd is not None:
            read_fds.append(cap_fd)

        write_fds = [sock for sock, queue in clients.items() if queue]

        # Wake in time for the next report while there is capture or a client; otherwise just idle.
        read_ready, write_ready, err_ready = select.select(read_fds, write_fds, (), max(0, t1-time()) if (cap is not None or clients) else 9999)

        #
        # Capture packets, and accept new connections:
        #
        for fd in read_ready:

            if fd == cap_fd:
                cap.dispatch(-1, tally_packet)

            elif fd == server:
                new_connection()

        #
        # Send queued output to any clients that aren't backed up:
        #
        for fd in write_ready:

            if fd in clients:
                queue = clients[fd]
                while queue:
                    msg = queue[0]
                    try:
                        n = fd.send(msg)
                    except BlockingIOError:
                        # The socket's send buffer filled up part-way through the queue.
                        #  That's not an error: keep what's left and resume when
                        #  select() reports this socket writable again.
                        break
                    except Exception as e:
                        log(f"Exception during send: {e}", 1)
                        close_connection(fd)
                        break

                    if n >= len(msg):
                        queue.pop(0)
                    elif n > 0:
                        queue[0] = msg[n:]      # Partial send; keep the unsent tail at the head
                    else:
                        break

        #
        # Generate periodic reports (to client queues):
        #
        when = time()
        if when >= t1:

            n = send_msg(
                    kind='sums',
                    when=when,
                    vals=totals,
                )

            if verbose > 1:
                log(f"---- {datetime.fromtimestamp(int(when))}: {len(totals)} = {n} bytes  ----", 2)

            totals = {}
            t1 += reporting_period

finally:
    # Runs on any exit from the loop (error, err() -> SystemExit, Ctrl-C).
    log(f"Closing pcap and {len(clients)} clients.")
    for client in clients:
        client.close()
    if cap is not None:
        cap.close()
    server.close()
    log("Done.")

