import math
import pandas as pd

FEATHERS_DIR = "feathers"

import generate_graphs as gg
import helper_scripts.helper_functions as hf
import analyze_packets as ap


def main():
    """Script entry point: load the aggregated measurement data and run the analyses.

    Only ``loss_calculations`` is currently enabled; the remaining analyses are
    kept below as commented-out toggles and can be switched on as needed.
    """
    data_path = f"{FEATHERS_DIR}/data.feather"
    # Alternative input from an earlier run:
    # data_path = f"{FEATHERS_DIR}/data_run_20241028.feather"
    data = pd.read_feather(data_path)

    # Analyses that can be enabled on demand:
    # bandwith_calcs()
    # static_scenario_statistical_analysis(data)
    # median_of_all_static_runs_per_algorithm(data)
    # stats_of_qtl95_of_packetloss(data)
    # error_count_and_rate(data)
    # measurements_with_negative_skewness(data)
    # iqr_kurtosis_of_delay_data(data)
    # print_kem_ids()
    loss_calculations()


def calc_delay_cuz_of_bandwidth_in_ms(cic, sic, lsf, bandwidth, packetlength=1242):
    """
    Calculates the delay in milliseconds caused by limited bandwidth.

    The transmitted volume is ``cic + sic - 1`` full packets of ``packetlength``
    bytes plus the last server ethernet frame of ``lsf`` bytes. All arguments may
    be scalars or pandas Series; Series are combined element-wise.

    Args:
        cic: client initial count.
        sic: server initial count.
        lsf: last server ethernet frame length in bytes.
        bandwidth: bandwidth in bits per second.
        packetlength: length of a full packet in bytes.

    Returns:
        delay in milliseconds.
    """
    ms_per_second = 1000
    total_bytes = ((cic + (sic - 1)) * packetlength) + lsf
    return total_bytes * 8 / bandwidth * ms_per_second


def bandwith_calcs(bandwidths_mbit=(0.1, 0.25, 0.5, 1, 3, 5, 500)):
    """
    Adds one ``t_delay_<bw>`` column (delay in ms) per bandwidth to the cic/sic table.

    Args:
        bandwidths_mbit: bandwidths to evaluate, in Mbit/s. Each value produces
            a column named ``t_delay_<value>``.

    Returns:
        The table from ``get_cic_and_sic`` with the added delay columns. It is
        also printed.
    """
    df = get_cic_and_sic()

    for bw in bandwidths_mbit:
        # Vectorized over the whole frame. This avoids the slow row-wise apply,
        # and an empty frame yields an empty column instead of an error.
        df[f"t_delay_{bw}"] = calc_delay_cuz_of_bandwidth_in_ms(
            df["cic"], df["sic"], df["server_last_packet_length"], bw * 1000000
        )

    print(df)
    return df


def calc_p_no_loss(cic, sic, loss):
    """
    Calculates the probability p_noLoss.

    Works element-wise when ``cic``/``sic`` are pandas Series.

    Args:
        cic: client initial count.
        sic: server initial count.
        loss: loss probability.

    Returns:
        p_noLoss as defined in the thesis: ``(1 - loss) ** (cic + sic)``.
    """
    return (1 - loss) ** (cic + sic)


def calc_p_no_one_sec_delay(cic, sidc, loss):
    """
    Calculates the probability p_noOneSec.

    Args:
        cic: client initial count. It must be integral-valued, because it is
            used as a loop bound and in ``math.comb``. A float such as 3.0 is
            accepted and converted with ``int``.
        sidc: server initial decryptable count, i.e. sic without the last server
            packet if that packet is shorter than 1200 bytes.
        loss: loss probability.

    Returns:
        p_noOneSec as defined in the thesis.
    """
    cic = int(cic)

    term1 = (1 - loss) ** cic * (1 - loss ** (sidc + (cic - 1)))

    # range(1, cic) yields i = 1 .. cic-1. The summand for i = cic would be 0
    # anyway, because (1 - loss ** 0) == 0. The result is therefore the same
    # whether the thesis sum ends at cic - 1 or at cic.
    term2 = sum(
        math.comb(cic, i) * (1 - loss) ** (cic - i) * loss**i * (1 - loss ** (cic - i))
        for i in range(1, cic)
    )

    return term1 + term2


def calc_l_for_no_loss_p(cic, sic, p):
    """
    Calculates the loss probability l at which p_noLoss equals ``p``.

    This is the inverse of ``calc_p_no_loss`` with respect to the loss
    probability. It works element-wise on pandas Series.

    Args:
        cic: client initial count.
        sic: server initial count.
        p: target p_noLoss probability (e.g. 0.50 or 0.95).

    Returns:
        l as defined in the thesis: ``1 - p ** (1 / (cic + sic))``.
    """
    return 1 - (p ** (1 / (cic + sic)))


def loss_calculations(loss_probabilities=(0.01, 0.05, 0.10, 0.20)):
    """
    Adds loss-probability columns to the cic/sic table, then prints and returns it.

    Args:
        loss_probabilities: loss probabilities l. Each one produces the columns
            ``p_noLoss_<l>`` and ``p_noOneSec_<l>``.

    Returns:
        The filtered table with ``sidc``, the per-l probability columns, and
        ``l_for_noLoss_p50`` / ``l_for_noLoss_p95``.
    """
    df = get_cic_and_sic()

    # p_noOneSec does not make sense if cic or sic is bigger than 10 -> see thesis.
    # .copy() makes the frame independent of the unfiltered one. Without it,
    # the column assignments below trigger SettingWithCopyWarning.
    df = df.query("cic <= 10 and sic <= 10").copy()

    # The last server packet only counts as decryptable if it is >= 1200 bytes long.
    df["sidc"] = df["sic"] - (df["server_last_packet_length"] < 1200).astype(int)

    for loss in loss_probabilities:
        df[f"p_noLoss_{loss}"] = calc_p_no_loss(df["cic"], df["sic"], loss)
        # Computed per row, because the formula contains a sum whose length depends on cic.
        df[f"p_noOneSec_{loss}"] = [
            calc_p_no_one_sec_delay(cic, sidc, loss)
            for cic, sidc in zip(df["cic"], df["sidc"])
        ]

    df["l_for_noLoss_p50"] = calc_l_for_no_loss_p(df["cic"], df["sic"], 0.50)
    df["l_for_noLoss_p95"] = calc_l_for_no_loss_p(df["cic"], df["sic"], 0.95)

    print(df)

    return df


def std_of_medians_per_kem_alg(ldata):
    """
    Returns the standard deviation (ddof=1) of the ``median`` column per ``kem_alg``.

    Groups appear in order of first appearance. A group with a single row gets NaN.
    """
    return ldata.groupby("kem_alg", sort=False)["median"].std()


def static_scenario_statistical_analysis(data):
    """
    Prints how much the per-run medians vary within each KEM algorithm.

    The data is first filtered to static-scenario QUIC runs at secLevel1 and
    secLevel1_hybrid. For each ``kem_alg``, the standard deviation of the
    ``median`` column is computed. The mean and the standard deviation of those
    per-algorithm values are then printed.
    """
    print("Static scenario statistical analysis")
    ldata = gg.filter_data(
        data,
        scenario="static",
        protocol="quic",
        sec_level=["secLevel1", "secLevel1_hybrid"],
    )

    # One groupby pass instead of an f-string query per algorithm. It is O(n)
    # and unaffected by quotes in algorithm names.
    stdevs_of_medians = std_of_medians_per_kem_alg(ldata)

    print("Mean of stdevs of medians")
    print(stdevs_of_medians.mean())
    print("Stdev of stdevs of medians")
    print(stdevs_of_medians.std())


def pooled_median_per_kem_alg(ldata):
    """
    Returns ``{kem_alg: median}``, where all runs of an algorithm are pooled.

    Each row's ``measurements`` entry is iterated and its values are collected
    per ``kem_alg``. The median of the pooled values is returned, with keys in
    order of first appearance.
    """
    medians = {}
    for kem_alg_name, group in ldata.groupby("kem_alg", sort=False):
        pooled = [value for run in group["measurements"] for value in run]
        medians[kem_alg_name] = pd.Series(pooled).median()
    return medians


def median_of_all_static_runs_per_algorithm(data):
    """
    Prints, per KEM algorithm, the median over all measurements of all static QUIC runs.
    """
    print("Median of all static runs per algorithm")
    ldata = gg.filter_data(data, scenario="static", protocol="quic")

    # Pool the measurements per algorithm, then take the median of the pool.
    for kem_alg_name, median in pooled_median_per_kem_alg(ldata).items():
        print(f"Median of {kem_alg_name}")
        print(median)
        print()


def stats_of_qtl95_of_packetloss(data):
    """Print the packetloss/QUIC rows of x25519 and frodo640aes, hiding setup columns.

    No statistics are computed here. The filtered rows are printed without the
    scenario/protocol/sec_level, the client netem settings and the raw
    measurements.
    """
    print("Stats of qtl95")

    filtered = gg.filter_data(data, scenario="packetloss", protocol="quic")
    algorithm_filter = "kem_alg == 'x25519' or kem_alg == 'frodo640aes'"
    # Alternative pair: "kem_alg == 'mlkem1024' or kem_alg == 'frodo1344aes'"
    filtered = filtered.query(algorithm_filter)

    print("Showing data of packetloss quic")
    hidden_columns = [
        "scenario",
        "protocol",
        "sec_level",
        "cli_pkt_loss",
        "cli_delay",
        "cli_rate",
        "measurements",
    ]
    visible = filtered.drop(columns=hidden_columns)
    print(visible)


# Old run (without the bigger crypto buffer): grep finds 83996 CRYPTO_BUFFER_EXCEEDED errors, while the total error count is only slightly higher at 84186.
# New run (with the fix): 187.0 other errors remain, probably caused on the server side, because the client reports 'Shutdown before completion' while waiting for the handshake to complete -> b'808B57C2E1760000:error:0A0000CF:SSL routines:quic_do_handshake:protocol is shutdown:ssl/quic/quic_impl.c:1717:\n'
def error_count_and_rate(data, high_error_threshold=12):
    """
    Prints an overview of the recorded errors.

    Prints the row count and the total ``error_count``. For rows with
    ``error_count > 0``, it then prints their count, a ``describe()`` summary,
    and their distribution over scenario, protocol, scenario/protocol and
    kem_alg. Finally it lists the rows above ``high_error_threshold``.

    Args:
        data: measurement frame. It needs the columns ``error_count``,
            ``error_rate``, ``scenario``, ``protocol``, ``sec_level`` and
            ``kem_alg``.
        high_error_threshold: rows with ``error_count`` strictly above this value
            are listed in detail.
    """
    print("Error count and rate")
    print("Total index length")
    print(len(data.index))
    print("Total error count")
    print(data["error_count"].sum())

    erroneous = data[data["error_count"] > 0]
    print("Total index length with error count > 0")
    print(len(erroneous.index))
    print("Data with error count > 0 describe error_count")
    print(erroneous["error_count"].describe())
    print("How much each scenario has error count > 0")
    print(erroneous["scenario"].value_counts())
    print("How much each protocol has error count > 0")
    print(erroneous["protocol"].value_counts())
    print("How much each scenario protocol combination has error count > 0")
    print(erroneous.groupby(["scenario", "protocol"]).size())
    print("How much each kem_alg has error count > 0")
    print(erroneous["kem_alg"].value_counts())

    # The label is derived from the threshold, so it always matches the filter.
    # The old code printed "> 3" while filtering on "> 12".
    print(f"With error count > {high_error_threshold}")
    high_errors = erroneous[erroneous["error_count"] > high_error_threshold]
    print(
        high_errors[
            [
                "scenario",
                "protocol",
                "sec_level",
                "kem_alg",
                "error_count",
                "error_rate",
            ]
        ]
    )


def measurements_with_negative_skewness(data):
    """Print summary statistics of ``skewness`` and how many rows are negatively skewed.

    The negatively skewed rows are also broken down per ``scenario``.
    """
    print("Measurements with negative skewness")
    skewness = data["skewness"]
    print("Skewness of data")
    print(skewness.describe())

    print("Amount of data with negative skewness")
    negatively_skewed = data[skewness < 0]
    print(len(negatively_skewed.index))

    # On an earlier run these were mostly reorder and jitter, with a few rate rows.
    print("Per scenario numbers of measurements with negative skewness")
    print(negatively_skewed["scenario"].value_counts())


def iqr_kurtosis_of_delay_data(data):
    """Describe the ``iqr`` and ``kurtosis`` columns for all rows, then for delay rows only."""
    print("Kurtosis of data, Fisher's definition, so 0 is normal distribution")
    columns = ["iqr", "kurtosis"]
    delay_rows = data[data["scenario"] == "delay"]
    for subset in (data, delay_rows):
        print(subset[columns].describe())


def print_kem_ids():
    """Print whatever ``hf.get_kem_ids()`` returns."""
    print(hf.get_kem_ids())


def derive_cic_and_sic(packets_by_node):
    """
    Derives the ``cic``/``sic`` columns from per-node packet counts.

    Returns a new frame (the input is not modified):
    - drops ``length_public_key`` and ``length_ciphertext``;
    - sets ``cic`` to ``client_sent_packets_with_crypto_count - 1``;
    - sets ``sic`` to ``server_sent_packets_with_crypto_count``;
    - drops both raw count columns.
    """
    df = packets_by_node.drop(columns=["length_public_key", "length_ciphertext"])
    # NOTE(review): the client count is reduced by one. The reason is not
    # visible here; presumably the thesis excludes one client packet from the
    # initial count. Confirm.
    df["cic"] = df["client_sent_packets_with_crypto_count"] - 1
    df["sic"] = df["server_sent_packets_with_crypto_count"]
    return df.drop(
        columns=[
            "client_sent_packets_with_crypto_count",
            "server_sent_packets_with_crypto_count",
        ]
    )


def get_cic_and_sic(udp_packets_path=None):
    """
    Loads the UDP packet capture and returns the client/server initial counts.

    Args:
        udp_packets_path: feather file with the UDP packets. It defaults to
            ``udp_packets.feather`` inside ``FEATHERS_DIR``.

    Returns:
        The frame from ``ap.get_packets_sent_by_node``, reduced and extended as
        described in ``derive_cic_and_sic``.
    """
    if udp_packets_path is None:
        udp_packets_path = f"{FEATHERS_DIR}/udp_packets.feather"
    udp_packets_df = pd.read_feather(udp_packets_path)
    packets_by_node = ap.get_packets_sent_by_node(udp_packets_df)

    # NOTE(review): this header is printed for every caller, including bandwith_calcs.
    print("\n\n Loss calculations")
    return derive_cic_and_sic(packets_by_node)


# Run the analyses only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()