import os

import numpy as np
import pandas as pd
import pyshark

import helper_scripts.helper_functions as helper_functions

# NOTE: there is also a package called scapy which might work as an alternative to pyshark

# To create a capture and sslkeylogfile, do the following:
# First run setup.sh
# Then run experiment.py with scenario_analyze_packets.csv while setting the variables to POOL_SIZE = 1, MEASUREMENTS_PER_TIMER = 5, TIMERS = 1
# Then run teardown.sh

# Timestamp suffix selecting which saved capture/keylog pair to analyse:
# saved/captures/capture_<DATESTRING>.pcap and saved/captures/sslkeylogfile_<DATESTRING>.log
DATESTRING = "20240830153007"
# NOTE(review): not referenced by the code in this file; presumably the link
# delay configured for the experiment — confirm or remove.
EXPECTED_DELAY = 10  # ms
# Number of connections every kem_algo must have in the capture;
# get_packets_sent_by_node asserts this for both client and server counts.
EXPECTED_MEASUREMENTS_PER_CONFIG = 5


def main():
    """Load the per-packet table (parsing the capture on a cache miss) and summarise it.

    The parsed capture is cached as ``feathers/udp_packets.feather`` so that the
    slow pyshark pass only runs when no cached table exists yet.
    """
    kem_ids = helper_functions.get_kem_ids()

    cache_path = "feathers/udp_packets.feather"
    os.makedirs("feathers", exist_ok=True)

    cache_hit = os.path.exists(cache_path)
    if not cache_hit:
        # Cache miss: parse the pcap and persist the result for the next run.
        packets = analyze_udp_packets(kem_ids)
        packets.to_feather(cache_path)
    else:
        packets = pd.read_feather(cache_path)

    get_packets_sent_by_node(packets)
    get_packets_sent_by_node(udp_packets_df)


def get_packets_sent_by_node(udp_packets_df):
    """Count the handshake packets sent by client and server per QUIC connection.

    For every Wireshark QUIC connection (``wireshark_quic_cid``) only the packets
    up to and including the client's TLS ``Finished`` packet are considered.
    Connections without such a packet are skipped (they most likely failed and
    were retried). For each kept connection the packets are counted per
    ``Sender``, both in total and restricted to packets that carry CRYPTO data
    (``no_crypto`` is False).

    The per-connection counts are then validated per ``kem_algo``:
    there must be exactly ``EXPECTED_MEASUREMENTS_PER_CONFIG`` connections, and
    all of them must have the same client and server crypto-packet counts.
    The deduplicated crypto counts are merged with the KEM characteristics from
    ``helper_functions`` and printed.

    Args:
        udp_packets_df: Per-packet table as produced by ``analyze_udp_packets``.

    Returns:
        DataFrame of per-KEM crypto packet counts merged with the KEM
        characteristics (``claimed_nist_level``, ``claimed_security``,
        ``length_secret_key`` and ``length_shared_secret`` dropped).

    Raises:
        AssertionError: if any ``kem_algo`` fails the validation described above.
    """
    udp_packets_df = udp_packets_df.drop(columns=["srcport", "quic_cid"])

    rows = []
    for connection_id, g_df in udp_packets_df.groupby("wireshark_quic_cid"):
        is_client_finished = (g_df["Sender"] == "Client") & g_df[
            "tls_handshake_type"
        ].apply(lambda types: "Finished" in types)
        finished_row = g_df.loc[is_client_finished]
        if finished_row.empty:
            print(
                f"No finished row found for connection {connection_id}, probably cuz an error, throwing away this connection, since it was probably retried"
            )
            continue

        # Only count the handshake: everything up to and including the client's Finished.
        g_df = g_df.loc[g_df["ID"] <= finished_row["ID"].iloc[0]]

        packets = g_df.groupby("Sender").size()
        # Compare explicitly with False: `~` on an object column of Python bools yields ints.
        packets_with_crypto = (
            g_df.loc[g_df["no_crypto"].eq(False)].groupby("Sender").size()
        )

        # A side without (crypto) packets counts as 0 instead of raising KeyError;
        # the consistency checks below then report it per kem_algo.
        rows.append(
            {
                "wireshark_quic_cid": connection_id,
                "kem_algo": g_df["kem_algo"].iloc[0],
                "client_sent_packets_count": packets.get("Client", 0),
                "server_sent_packets_count": packets.get("Server", 0),
                "client_sent_packets_with_crypto_count": packets_with_crypto.get(
                    "Client", 0
                ),
                "server_sent_packets_with_crypto_count": packets_with_crypto.get(
                    "Server", 0
                ),
            }
        )

    # Build the frame once instead of concatenating inside the loop (quadratic).
    packets_per_node = pd.DataFrame(
        rows,
        columns=[
            "wireshark_quic_cid",
            "kem_algo",
            "client_sent_packets_count",
            "server_sent_packets_count",
            "client_sent_packets_with_crypto_count",
            "server_sent_packets_with_crypto_count",
        ],
    )

    per_kem = packets_per_node.groupby("kem_algo").agg(
        client_connections=("client_sent_packets_count", "count"),
        server_connections=("server_sent_packets_count", "count"),
        client_crypto_variants=("client_sent_packets_with_crypto_count", "nunique"),
        server_crypto_variants=("server_sent_packets_with_crypto_count", "nunique"),
    )
    wrong_connection_count = per_kem.index[
        (per_kem["client_connections"] != EXPECTED_MEASUREMENTS_PER_CONFIG)
        | (per_kem["server_connections"] != EXPECTED_MEASUREMENTS_PER_CONFIG)
    ]
    assert wrong_connection_count.empty, (
        f"Expected {EXPECTED_MEASUREMENTS_PER_CONFIG} connections per kem_algo, "
        f"got a different number for: {list(wrong_connection_count)}"
    )
    # The crypto packet counts must be deterministic per KEM, otherwise the
    # deduplication below would not yield one row per kem_algo.
    inconsistent_crypto_counts = per_kem.index[
        (per_kem["client_crypto_variants"] != 1)
        | (per_kem["server_crypto_variants"] != 1)
    ]
    assert inconsistent_crypto_counts.empty, (
        "Crypto packet counts differ between connections for: "
        f"{list(inconsistent_crypto_counts)}"
    )

    crypto_columns = [
        "client_sent_packets_with_crypto_count",
        "server_sent_packets_with_crypto_count",
    ]
    packets_per_node_with_crypto = (
        packets_per_node[["kem_algo", *crypto_columns]]
        .drop_duplicates()
        .sort_values(by=crypto_columns)
    )
    print(packets_per_node_with_crypto)

    kem_characteristics_df = helper_functions.get_kem_characteristics()
    df = pd.merge(
        packets_per_node_with_crypto, kem_characteristics_df, on="kem_algo", how="left"
    )
    df = helper_functions.fill_in_kem_characteristics_for_hybrid_kems(df)
    df = df.drop(
        columns=[
            "claimed_nist_level",
            "claimed_security",
            "length_secret_key",
            "length_shared_secret",
        ]
    )

    print(df)
    return df


def analyze_udp_packets(kem_id_df):
    cap = pyshark.FileCapture(
        os.path.join("saved", "captures", f"capture_{DATESTRING}.pcap"),
        override_prefs={
            "tls.keylog_file": os.path.join(
                "saved", "captures", f"sslkeylogfile_{DATESTRING}.log"
            )
        },
        display_filter="udp",
    )
    # print(cap)
    df = pd.DataFrame()

    for idx, packet in enumerate(cap):
        # icmp messages with pings that contain quic, ignore them
        if "udp" not in packet:
            # print(packet)
            # print(packet.layers)
            continue

        # if idx >= 2000:
        # if idx >= 6:
        # break

        # print(packet.number)
        # print(packet.layers)
        # print(packet.ip.field_names)  # ['version', 'hdr_len', 'dsfield', 'dsfield_dscp', 'dsfield_ecn', 'len', 'id', 'flags', 'flags_rb', 'flags_df', 'flags_mf', 'frag_offset', 'ttl', 'proto', 'checksum', 'checksum_status', 'src', 'addr', 'src_host', 'host', 'dst', 'dst_host']
        # print(packet.eth.field_names)  # ['dst', 'dst_resolved', 'dst_oui', 'dst_oui_resolved', 'addr', 'addr_resolved', 'addr_oui', 'addr_oui_resolved', 'dst_lg', 'lg', 'dst_ig', 'ig', 'src', 'src_resolved', 'src_oui', 'src_oui_resolved', 'src_lg', 'src_ig', 'type']
        # print(packet.udp.field_names)  # ['srcport', 'dstport', 'port', 'length', 'checksum', 'checksum_status', 'stream', '', 'time_relative', 'time_delta', 'payload']
        # if packet.number == "695" or packet.number == "696":
        #     for quic_layer in packet.get_multiple_layers("quic"):
        #         print(packet.number, quic_layer.field_names)

        match ("scid" in packet.quic.field_names, "dcid" in packet.quic.field_names):
            case (True, True):
                assert False, "Both scid and dcid are present"
            case (False, False):
                cid = np.nan
            case (True, False):
                cid = packet.quic.scid
            case (False, True):
                cid = packet.quic.dcid

        # A packet can have multiple quic layers, the layers can have multiple fields with the same name, but they are hidden behind the all_fields attribute
        tls_handshake_types = []
        for quic_layer in packet.get_multiple_layers("quic"):
            if "tls_handshake_type" in quic_layer.field_names:
                for field in quic_layer.tls_handshake_type.all_fields:
                    tls_handshake_types.append(field.show)
        tls_handshake_types = map_tls_handshake_types(tls_handshake_types)

        # The naming inside of wireshark of the kem algos is not correct all the time
        supported_group = np.nan
        if "Client Hello" in tls_handshake_types:
            for quic_layer in packet.get_multiple_layers("quic"):
                if "tls_handshake_extensions_supported_group" in quic_layer.field_names:
                    # only shows the first of the supported groups, but fine in our context, when only looking at the client hello
                    supported_group = (
                        quic_layer.tls_handshake_extensions_supported_group
                    )

        # no_crypto is only correct for the quic packets sent in the handshake, not for the packets sent after the handshake
        no_crypto = []
        for quic_layer in packet.get_multiple_layers("quic"):
            if "crypto_offset" in quic_layer.field_names:
                no_crypto.append(False)
            else:
                no_crypto.append(True)
        assert len(no_crypto) > 0, "No quic layer"
        no_crypto = all(no_crypto)

        df = pd.concat(
            [
                df,
                pd.DataFrame(
                    {
                        "ID": [packet.number],
                        "Sender": [
                            (
                                "Server"
                                if packet.eth.src == "00:00:00:00:00:01"
                                else "Client"
                            )
                        ],
                        "srcport": [packet.udp.srcport],
                        "time_relative": [packet.udp.time_relative],
                        "time_delta": [packet.udp.time_delta],
                        "frame_length": [packet.length],
                        "ip_length": [packet.ip.len],
                        "udp_length": [packet.udp.length],
                        "quic_length": [packet.quic.packet_length],
                        "wireshark_quic_cid": [packet.quic.connection_number],
                        "quic_cid": [cid],
                        "supported_group": [supported_group],
                        "tls_handshake_type": [tls_handshake_types],
                        "no_crypto": [no_crypto],
                    }
                ),
            ],
            ignore_index=True,
        )

    # change type from str to int
    df["ID"] = df["ID"].astype(int)
    df["srcport"] = df["srcport"].astype(int)
    df["time_relative"] = df["time_relative"].astype(float)
    df["time_delta"] = df["time_delta"].astype(float)
    df["frame_length"] = df["frame_length"].astype(int)
    df["ip_length"] = df["ip_length"].astype(int)
    df["udp_length"] = df["udp_length"].astype(int)
    df["quic_length"] = df["quic_length"].astype(int)
    df["wireshark_quic_cid"] = df["wireshark_quic_cid"].astype(int)

    # supported groups do have hex string values, but with lowercase letters, so keep the x lowercase and transform the rest to uppercase
    df["supported_group"] = df["supported_group"].apply(
        lambda x: x[0:2] + x[2:].upper() if pd.notna(x) else np.nan
    )
    df["kem_algo"] = df["supported_group"].apply(
        lambda x: (
            kem_id_df.loc[kem_id_df["nid"] == x, "kem_name"].values[0]
            if pd.notna(x)
            else np.nan
        )
    )
    df["kem_algo"] = df.groupby("wireshark_quic_cid")["kem_algo"].transform(
        lambda x: x.ffill().bfill()
    )

    printdf = df.drop(columns=["srcport", "quic_cid"])
    # print(printdf.head())
    # print(printdf.query("ID >= 689 and ID <= 699"))
    # print()
    # print(printdf.query("ID >= 1657 and ID <= 1680"))
    return df


def map_tls_handshake_types(handshake_types):
    """Translate TLS handshake type codes (as strings) into readable names.

    Args:
        handshake_types: Iterable of numeric handshake type codes as strings,
            e.g. the ``show`` values of Wireshark's ``tls_handshake_type`` field.

    Returns:
        A new list with one name per input code. Codes without a known name
        become ``"Unknown tls_handshake_type <code>"``.
    """
    names_by_code = {
        "1": "Client Hello",
        "2": "Server Hello",
        "4": "New Session Ticket",
        "8": "Encrypted Extensions",
        "11": "Certificate",
        "12": "Server Key Exchange",
        "13": "Certificate Request",
        "14": "Server Hello Done",
        "15": "Certificate Verify",
        "16": "Client Key Exchange",
        "20": "Finished",
    }
    readable = []
    for code in handshake_types:
        name = names_by_code.get(code)
        if name is None:
            name = f"Unknown tls_handshake_type {code}"
        readable.append(name)
    return readable


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()