#!/usr/bin/python3
#
# Trax (network activity tracker) Version 1.0
#
# Copyright 2023 Brandyn Webb
# 
# This file is part of Trax (network activity tracker).
# 
# Trax is free software: you can redistribute it and/or modify it under the terms 
# of the GNU General Public License as published by the Free Software Foundation, 
# either version 3 of the License, or (at your option) any later version.
# 
# Trax is distributed in the hope that it will be useful, but WITHOUT ANY 
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along with 
# Trax. If not, see <https://www.gnu.org/licenses/>. 
# 


import pcapy, socket, select, pickle, subprocess

from struct   import pack, unpack
from datetime import datetime
from time     import time, sleep

from sys import argv, exit

#
# Defaults:
#
reporting_period = 1            # Seconds
verbose          = 0            # Logging output verbosity level
ip               = ''           # IP addy to accept connections on.  Default: ALL.
port             = 3344         # Default port to accept connections at.

#
# Logging and failure reporting:
#
def log(msg, verbose_level=0):
    """Print `msg` (flushed immediately) when `verbose` is at least `verbose_level`.

    Pre-filtering on verbosity at the call site is preferred on hot paths, so the
    message string isn't even formatted when it would be discarded.
    """
    if verbose >= verbose_level:
        print(msg, flush=True)

def err(msg, exit_code=1):
    """Report `msg` and terminate the process with `exit_code`.

    Failures (non-zero exit codes) are written to stderr so they don't mix with
    normal output; exit code 0 (an explicit help request) is written to stdout.
    """
    import sys      # Looked up at call time so redirected streams are honoured.
    print(msg, file=sys.stdout if exit_code == 0 else sys.stderr, flush=True)
    exit(exit_code)

#
# Available devs:
#
try:
    devs = pcapy.findalldevs()
except pcapy.PcapError as e:
    # NOTE(review): pcapy signals "no usable interfaces" / enumeration failure by
    # raising PcapError rather than returning an empty list; report it cleanly
    # instead of dying with a traceback.
    err(f"Can't list network devices: {e}", 2)

if not devs:
    err("No devices available.", 2)

net_dev = devs[0]   # Default capture interface unless another is named on the command line.

#
# Arg parsing:
#
help_str = "\n".join((
    f"Use: {argv[0]} [<network_interface>] [-v(erbose)] [-r <reporting_period>] [-i <ip>] [-p <port>]",
    f"  Network interface (default {net_dev!r}) must be one of: {', '.join(devs)}",
    "  Default ip (to listen at) is all.  Specify an IP addy of this host to limit.",
    f"  Default port is {port}",
    f"  Reporting period is in seconds (default {reporting_period}).",
))

args = argv[1:]     # Command-line arguments still waiting to be consumed, front first.

def nextarg():
    """Consume and return the next command-line argument.

    When no arguments remain, the usage text is shown and the process exits
    with status 1 (via err, which never returns).
    """
    if not args:
        err(help_str, 1)
    return args.pop(0)

# Consume the command line left to right.  Every error path goes through err(),
# which exits the process.
while args:
    arg = nextarg()

    if arg == '-v':
        verbose += 1

    elif arg == '-r':
        value = nextarg()       # Exits with the usage text if the value is missing.
        try:
            reporting_period = float(value)
        except ValueError:      # Not a bare except: it must not swallow SystemExit from nextarg().
            err("Reporting period must be numeric seconds.", 1)
        # 0 or a negative period would make the report loop spin without waiting,
        # and inf/NaN break the pcap timeout computation.  NaN fails both
        # comparisons, so this check rejects it too.
        if not 0 < reporting_period < float('inf'):
            err("Reporting period must be a positive, finite number of seconds.", 1)

    elif arg == '-i':
        ip = nextarg()

    elif arg == '-p':
        value = nextarg()       # Exits with the usage text if the value is missing.
        try:
            port = int(value)
        except ValueError:
            err("Port must be an integer.", 1)
        if not 1 <= port <= 65535:
            err("Port must be between 1 and 65535.", 1)

    elif arg in devs:
        net_dev = arg

    elif arg in ('help', '-h', '--help'):
        err(help_str, 0)

    else:
        err(help_str, 1)

if net_dev not in devs:
    err(f"Can't find network device {net_dev!r}.\nAvailable devices: {', '.join(devs)}", 2)

#
# We'll send clients "arp -a" output along with our own mac addy, for
#  their use.  We don't use either of these otherwise:
#
# For now we'll run arp once at startup to cache the IP/MAC mapping (with hostnames).
# TODO: Eventually we should update this, say when we see a new MAC we haven't
#       seen before, and/or periodically to catch IP changes.
#
try:
    arp_out = subprocess.run(['arp', '-a', '-i', net_dev], capture_output=True, check=True, text=True).stdout
except (OSError, subprocess.CalledProcessError) as e:
    # The arp table is only a convenience passed through to clients, so a missing
    # `arp` executable or a failed run shouldn't stop the tracker: carry on and
    # hand clients an empty table instead.
    log(f"WARNING: Couldn't get arp table ({e}); clients will receive an empty one.")
    arp_out = ''

with open(f"/sys/class/net/{net_dev}/address") as fl: # cheap and dirty way to get our own mac addy
    my_mac = fl.read().strip()
# "aa:bb:cc:dd:ee:ff" -> 6 raw bytes, for direct comparison with captured frame headers.
my_macb = bytes.fromhex(my_mac.replace(':', ' '))
log(f"My MAC: {my_mac!r} aka {my_macb!r}")

#
# Tally function:
#

# IPV4 protocol numbers:
TCP = 6
UDP = 17

# Formatting helpers for log messages.  They are defined unconditionally: they
# are cheap, and a log call made without a `verbose` guard would otherwise fail
# with a NameError.

def mac_str(mac):
    """MAC address bytes -> printable upper-case hex form, e.g. "00:1A:2B:3C:4D:5E".
    """
    return ':'.join(f"{octet:02X}" for octet in mac)

def ip_str(ip):
    """IPV4 IP 4-bytes -> printable dotted-quad.

    Returns "N/A" for None and "*" for multicast addresses (224.0.0.0/4).
    """
    if ip is None:
        return "N/A"
    if (ip[0] >> 4) == 14:      # First nibble 1110: 224.0.0.0 - 239.255.255.255
        return "*"
    return '.'.join(str(i) for i in ip)

def tally_packet(header, payload):
    """pcap dispatch callback: add one captured Ethernet frame to the global `totals`.

    header  -- pcapy packet header.  getlen() gives the frame's original length,
               which is what gets counted.  getts() gives its (seconds,
               microseconds) timestamp, which is used only for verbose logging.
    payload -- captured bytes of the frame, starting at the Ethernet header; may
               be shorter than the original frame (capture snap length).

    Frames too short to hold an Ethernet header are ignored.  So are frames that
    were neither sent by us (source MAC == my_macb) nor addressed to us or to a
    broadcast/multicast MAC.  Every other frame is accumulated in `totals` under
    a key oriented so the other party comes first:
        (other MAC, ethertype, IP protocol, other IP, other port, our IP, our port)
    The value is (inpackets, inbytes, outpackets, outbytes).  Fields that don't
    apply are None.  For TCP/UDP, ports above 9999 (treated as transient) are
    reported as 0, as are ports that can't be read: non-first IP fragments,
    malformed IHL, or headers cut short by the capture length.
    """
    if len(payload) < 14:
        # No complete Ethernet header (dst MAC, src MAC, ethertype) to classify.
        # Unpacking it below would raise inside pcap's dispatch loop.
        if verbose > 1:
            log(f"WARNING: ignoring {len(payload)}-byte runt frame", 2)
        return

    if verbose:
        when, usec = header.getts()
        when += usec/1000000

    dstmac = payload[0: 6]
    srcmac = payload[6:12]
    etype, = unpack("!H", payload[12:14])

    size = header.getlen()

    outbound  = srcmac == my_macb
    broadcast = dstmac[0]&1     # Group bit: set for broadcast and multicast destinations

    if not (outbound or broadcast) and dstmac != my_macb:
        # This might happen routinely if the interface is in promiscuous mode for any reason?
        if verbose > 1:
            log(f"WARNING: got packet not for us?  {mac_str(srcmac)} -> {mac_str(dstmac)}", 2)
        return

    #
    # Protocol-aware tracking:
    #
    src_ip = dst_ip = src_port = dst_port = protocol = None
    if etype == 0x0800: # IPV4

        if len(payload) < 14 + 20:
            # Truncated IPv4 header: still tally the frame, with IP fields left as None.
            if verbose > 1:
                log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} IPV4 (truncated header)", 2)

        else:
            protocol = payload[14+ 9]
            src_ip   = payload[14+12:14+16]
            dst_ip   = payload[14+16:14+20]

            if protocol in (TCP, UDP):
                src_port = dst_port = 0     # Unless readable below

                # The transport header starts after the IPv4 header, whose length
                # (IHL, low nibble of the first byte) is in 32-bit words; options
                # make it longer than 20 bytes.  Only the first fragment of a
                # fragmented datagram (fragment offset 0) carries the ports.
                ip_hdr_len  = (payload[14] & 0x0F) * 4
                frag_offset = unpack("!H", payload[14+6:14+8])[0] & 0x1FFF
                l4 = 14 + ip_hdr_len

                if ip_hdr_len >= 20 and frag_offset == 0 and len(payload) >= l4 + 4:
                    src_port, dst_port = unpack('!HH', payload[l4:l4+4])

                    #
                    # For tallying connections, treat all 5+ digit port numbers as transient.
                    #
                    # This is somewhat arbitrary but works well enough in practice to
                    #   limit the proliferation of entries.
                    #
                    if src_port > 9999:
                        src_port = 0
                    if dst_port > 9999:
                        dst_port = 0

            else:
                # Protocols other than UDP/TCP
                if verbose > 1:
                    log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} P:{protocol}(!=UDP/TCP) {ip_str(src_ip)}->{ip_str(dst_ip)}", 2)

    elif etype == 0x0806: # ARP
        if verbose > 1:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} ARP", 2)

    elif etype == 0x86DD: # IPV6
        # TODO...
        if verbose > 1:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} IPV6", 2)

    else:
        if verbose:
            log(f"{datetime.fromtimestamp(when)} {size:5}B {mac_str(srcmac)}->{mac_str(dstmac)} UNKNOWN etype (0x{etype:04x})", 1)

    # Orient the key so the other party's MAC/IP/port always come first.
    if outbound:
        key = (dstmac, etype, protocol, dst_ip, dst_port, src_ip, src_port)
    else:
        key = (srcmac, etype, protocol, src_ip, src_port, dst_ip, dst_port)

    inpackets, inbytes, outpackets, outbytes = totals.get(key, (0, 0, 0, 0))

    if outbound:
        outpackets += 1
        outbytes   += size
    else:
        inpackets += 1
        inbytes   += size

    totals[key] = (inpackets, inbytes, outpackets, outbytes)


#
# Listen for connections at the specified ip/port:
#
log(f"Listening at {ip or 'ALL'}:{port}")

# Keep retrying until the listening socket can be set up (e.g. the address may
# be temporarily in use or not yet assigned to this host).
while True:
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Allow rebinding right after a restart while connections from the previous
        # run linger in TIME_WAIT; without this, bind() fails with
        # "Address already in use" for a while.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.setblocking(0)
        server.bind((ip, port))
        server.listen(5)
        break
    except OSError as e:
        # Only socket/network errors are worth retrying; anything else propagates.
        server.close()
        log(e)
        log("Retrying in 10 seconds...")
        sleep(10)

clients = {}    # Maps client socket to output queue.

def new_connection():
    """Accept a pending connection on `server` and register it as a write-only client.

    The client gets an empty output queue in `clients`, and the cached arp output
    plus our MAC are queued to it first thing.  A connection that fails during
    setup is logged and discarded rather than crashing the server.
    """
    try:
        client, client_address = server.accept()
    except OSError as e:
        # The listener is non-blocking: the connection select() reported may have
        # been reset/aborted before we got to accept it.
        log(f"Failed to accept connection: {e}", 1)
        return

    log(f"New connection from {client_address}", 1)
    try:
        client.setblocking(0)
        client.shutdown(socket.SHUT_RD)  # Write-only socket; don't queue up incoming data because we'll never read it.
    except OSError as e:
        # e.g. the peer already went away; don't leak the socket or register it.
        log(f"Dropping connection from {client_address}: {e}", 1)
        client.close()
        return

    clients[client] = []

    send_msg_to(client, kind='arp', raw=arp_out, my_mac=my_mac)  # Send cached arp output first thing.

    send_msg_to(client, kind='arp', raw=arp_out, my_mac=my_mac)  # Send cached arp output first thing.

def close_connection(con):
    """Unregister client `con` (discarding any output still queued for it) and close its socket.

    Closing a socket that isn't registered is tolerated, but logged as a warning.
    """
    pending = clients.pop(con, None)     # Queues are always lists, so None means "not registered".
    if pending is None:
        log("WARNING: Closing connection that's not in clients list.")
    con.close()
    log("Client closed.", 1)

MAX_MSG_LEN = 0xFFFF    # Largest pickle whose length fits the 2-byte ('!H') frame prefix.

def pack_obj(o):
    """Frame `o` for the wire: 2-byte big-endian length prefix followed by pickle.dumps(o).

    Raises ValueError if the pickle is longer than MAX_MSG_LEN bytes, because its
    length can't be represented in the prefix.
    """
    b = pickle.dumps(o)
    if len(b) > MAX_MSG_LEN:
        raise ValueError(f"pickled message is {len(b)} bytes; frame length prefix allows at most {MAX_MSG_LEN}")
    return pack('!H', len(b)) + b

DROP = pack_obj({'kind':'drop'}) # Special message meaning outgoing frames were dropped

def send_msg(**msg):
    """Sends msg to all connected clients.

    The message is framed once and appended to each client's queue only while
    that queue holds fewer than 4 entries.  Otherwise a DROP marker is appended
    while it holds fewer than 10 entries, and beyond that nothing is queued.  A
    message too large to frame is replaced by a DROP marker (same 10-entry
    limit) and, when verbose, a warning is logged.

    Returns number of bytes queued to send (per client): the framed message
    length, or 0 when there are no clients (and verbose < 2) or the message was
    too large to frame.
    """
    if not clients and verbose < 2: # Verbose>=2 logs the message length.
        return 0
    try:
        frame = pack_obj(msg)
    except ValueError as e:
        # e.g. a 'sums' report covering very many flows; tell clients output was lost
        # instead of crashing the server.
        if verbose:
            log(f"WARNING: dropping {msg.get('kind')!r} message: {e}", 1)
        for queue in clients.values():
            if len(queue) < 10:
                queue.append(DROP)
        return 0
    for queue in clients.values():
        if len(queue) < 4:
            queue.append(frame)
        elif len(queue) < 10:
            queue.append(DROP)
    return len(frame)

def send_msg_to(client, **msg):
    """Queue msg for one client only; does nothing if `client` isn't registered.

    The message is appended only while the client's queue holds fewer than 4
    entries.  Otherwise a DROP marker is appended while it holds fewer than 10,
    and beyond that nothing is queued.
    """
    queue = clients.get(client)     # Queues are always lists, so None means "not registered".
    if queue is None:
        return

    frame   = pack_obj(msg)
    backlog = len(queue)

    if backlog < 4:
        queue.append(frame)
    elif backlog < 10:
        queue.append(DROP)

#
# Do it:
#

cap = cap_fd = None

try:
    totals = {}                     # Current period's tallies, filled in by tally_packet().
    t1 = time() + reporting_period  # When to issue first report

    while True:

        #
        # Start and stop packet capture on-demand, so we only capture while at
        # least one client is connected:
        #
        if cap is None:
            if clients:
                log(f"Opening network interface {net_dev}...", 1)
                try:
                    # device, bytes to capture per packet, promiscuous mode, timeout (ms)
                    cap    = pcapy.open_live(net_dev, 64, False, min(1000, int(reporting_period*1000)))
                    cap_fd = cap.getfd()
                    cap.setnonblock(True)
                except Exception as e:
                    cap    = None
                    cap_fd = None
                    err(e, 3)       # TODO: should retry after a while instead of exiting?
                t1 = time() + reporting_period  # When to issue next report
        else:
            if not clients:
                log(f"Closing network interface {net_dev}...", 1)
                cap.close()
                cap = cap_fd = None
                # Discard the partial tally so the next client to connect doesn't
                # receive stale counts from this capture session.
                totals = {}

        #
        # Wait for anything we can do: a new connection, captured packets, a
        # client ready for more output, or the next report falling due.  With no
        # capture and no clients, just wait (up to 9999 s) for a connection.
        #
        read_fds = [server]

        if cap_fd is not None:
            read_fds.append(cap_fd)

        write_fds = [sock for sock, queue in clients.items() if queue]

        read_ready, write_ready, err_ready = select.select(read_fds, write_fds, (), max(0, t1-time()) if (cap is not None or clients) else 9999)

        #
        # Capture packets, and accept new connections:
        #
        for fd in read_ready:

            if fd == cap_fd:
                cap.dispatch(-1, tally_packet)  # -1: no limit on packets handled per call

            elif fd == server:
                new_connection()

        #
        # Send queued output to any clients that aren't backed up:
        #
        for fd in write_ready:

            if fd in clients:
                queue = clients[fd]
                while queue:
                    try:
                        msg = queue[0]
                        n = fd.send(msg)
                        if n >= len(msg):
                            queue.pop(0)
                        elif n > 0:
                            queue[0] = msg[n:]      # Partial send: keep the unsent tail at the front.
                        else:
                            break
                    except Exception as e:
                        log(f"Exception during send: {e}", 1)
                        close_connection(fd)
                        break

        #
        # Generate periodic reports (to client queues):
        #
        when = time()
        if when >= t1:

            n = send_msg(
                    kind='sums',
                    when=when,
                    vals=totals,
                )

            if verbose > 1:
                log(f"---- {datetime.fromtimestamp(int(when))}: {len(totals)} = {n} bytes  ----", 2)

            totals = {}
            t1 += reporting_period
            if t1 <= when:
                # We fell more than a period behind (e.g. the host was suspended or
                # the loop stalled): resynchronize instead of emitting a burst of
                # back-to-back catch-up reports.
                t1 = when + reporting_period

finally:
    log(f"Closing pcap and {len(clients)} clients.")
    for client in clients:
        client.close()
    if cap is not None:
        cap.close()
    server.close()
    log("Done.")

