import pandas as pd

FEATHERS_DIR = "feathers"

import generate_graphs as gg
import helper_scripts.helper_functions as hf


def main():
    """Load the measurement data set and run the currently selected analysis.

    Only the static-scenario analysis is enabled; the other analysis
    functions in this module are listed below and can be switched on by
    uncommenting their calls.
    """
    feather_path = f"{FEATHERS_DIR}/data.feather"
    # Alternative data set: f"{FEATHERS_DIR}/data_run_20241028.feather"
    measurement_data = pd.read_feather(feather_path)

    static_scenario_statistical_analysis(measurement_data)

    # Further analyses, enable as needed:
    # median_of_all_static_runs_per_algorithm(measurement_data)
    # stats_of_qtl95_of_packetloss(measurement_data)
    # error_count_and_rate(measurement_data)
    # measurements_with_negative_skewness(measurement_data)
    # iqr_kurtosis_of_delay_data(measurement_data)
    # print_kem_ids()


def static_scenario_statistical_analysis(data):
    """Print how much the per-run medians vary within each KEM algorithm.

    Restricts *data* (via ``gg.filter_data``) to the static scenario over
    QUIC at security levels ``secLevel1`` and ``secLevel1_hybrid``. For every
    ``kem_alg`` it computes the standard deviation of that algorithm's
    ``median`` values, then prints the mean and the standard deviation of
    those per-algorithm stdevs.

    Args:
        data: DataFrame with at least ``kem_alg`` and ``median`` columns,
            plus whatever columns ``gg.filter_data`` filters on.
    """
    print("Static scenario statistical analysis")
    ldata = gg.filter_data(
        data,
        scenario="static",
        protocol="quic",
        sec_level=["secLevel1", "secLevel1_hybrid"],
    )

    # One stdev per algorithm in a single pass. Grouping avoids building a
    # query string from the algorithm name (which breaks on names containing
    # quotes) and avoids rescanning the frame once per algorithm.
    # sort=False / observed=True keep exactly the algorithms present, in
    # order of appearance, like ``unique()`` did. Algorithms with a single row
    # yield NaN, which mean()/std() below skip.
    stdevs_of_medians = ldata.groupby("kem_alg", sort=False, observed=True)[
        "median"
    ].std()

    print("Mean of stdevs of medians")
    print(stdevs_of_medians.mean())
    print("Stdev of stdevs of medians")
    print(stdevs_of_medians.std())


def median_of_all_static_runs_per_algorithm(data):
    """Print, per KEM algorithm, the median of all its pooled static QUIC measurements.

    Restricts *data* (via ``gg.filter_data``) to the static scenario over
    QUIC. For every ``kem_alg``, the individual values in the
    ``measurements`` column of all its rows are pooled into one sample, and
    the median of that pooled sample is printed. This is not a median of the
    per-run medians. Rows with a missing ``kem_alg`` are skipped.

    Args:
        data: DataFrame with ``kem_alg`` and ``measurements`` columns, where
            each ``measurements`` cell is presumably an iterable of numeric
            values (it is flattened below) — TODO confirm against the
            feather writer.
    """
    print("Median of all static runs per algorithm")
    ldata = gg.filter_data(data, scenario="static", protocol="quic")

    # Iterate groups instead of querying per name: no query-string injection
    # from algorithm names, and no full rescan of the frame per algorithm.
    for kem_alg_name, kem_alg_data in ldata.groupby(
        "kem_alg", sort=False, observed=True
    ):
        pooled_measurements = [
            value
            for run_measurements in kem_alg_data["measurements"]
            for value in run_measurements
        ]
        print(f"Median of {kem_alg_name}")
        print(pd.Series(pooled_measurements).median())
        print()
        print()


def stats_of_qtl95_of_packetloss(data, kem_algs=("x25519", "frodo640aes")):
    """Print the packet-loss QUIC rows for a chosen set of KEM algorithms.

    Restricts *data* (via ``gg.filter_data``) to the packet-loss scenario over
    QUIC. It keeps only rows whose ``kem_alg`` is in *kem_algs*, drops the
    configuration columns and the raw ``measurements`` lists, and prints the
    remaining frame.

    Args:
        data: Measurement DataFrame; must contain the columns dropped below.
        kem_algs: KEM algorithm names to keep. The default is the pair this
            analysis was originally hard-coded to. Pass, for example,
            ``("mlkem1024", "frodo1344aes")`` for the other comparison.

    Raises:
        KeyError: if one of the dropped columns is missing.
    """
    print("Stats of qtl95")

    ldata = gg.filter_data(data, scenario="packetloss", protocol="quic")
    ldata = ldata[ldata["kem_alg"].isin(kem_algs)]

    print("Showing data of packetloss quic")
    # Hide configuration columns and the bulky per-run measurement lists so
    # the remaining (summary) columns fit in the printout.
    ldata = ldata.drop(
        columns=[
            "scenario",
            "protocol",
            "sec_level",
            "cli_pkt_loss",
            "cli_delay",
            "cli_rate",
            "measurements",
        ]
    )
    print(ldata)


# For the old run without the bigger crypto buffer: grep finds 83996 CRYPTO_BUFFER_EXCEEDED errors, while the total error count is only slightly higher at 84186.
# For the new run with the fix: 187.0 other errors remain, probably from the server side, because the client reports 'Shutdown before completion' while waiting for the handshake to complete -> b'808B57C2E1760000:error:0A0000CF:SSL routines:quic_do_handshake:protocol is shutdown:ssl/quic/quic_impl.c:1717:\n'
def error_count_and_rate(data):
    """Print an overview of the recorded errors in *data*.

    Output, in order:
      * the total number of rows and the summed ``error_count``;
      * the number of rows with ``error_count > 0``, a ``describe()`` of their
        ``error_count`` and their distribution over ``scenario``;
      * the identifying columns plus ``error_count``/``error_rate`` of every
        row with ``error_count > 1``.
    """
    print("Error count and rate")
    print("Total index length")
    print(len(data))
    print("Total error count")
    print(data["error_count"].sum())

    rows_with_errors = data[data["error_count"] > 0]
    print("Total index length with error count > 0")
    print(len(rows_with_errors))
    print("Error count describe")
    print(rows_with_errors["error_count"].describe())
    print(rows_with_errors["scenario"].value_counts())

    print("With error count > 1")
    rows_with_repeated_errors = rows_with_errors[rows_with_errors["error_count"] > 1]
    shown_columns = [
        "scenario",
        "protocol",
        "sec_level",
        "kem_alg",
        "error_count",
        "error_rate",
    ]
    print(rows_with_repeated_errors[shown_columns])


def measurements_with_negative_skewness(data):
    """Summarise the ``skewness`` column and count negatively skewed rows.

    Prints ``describe()`` of ``skewness`` over all rows, then the number of
    rows with ``skewness < 0`` and how those rows are spread over ``scenario``.
    """
    print("Measurements with negative skewness")
    print("Skewness of data")
    print(data["skewness"].describe())

    print("Amount of data with negative skewness")
    negatively_skewed = data[data["skewness"] < 0]
    print(len(negatively_skewed))
    # Variant worth checking: exclude scenario 'reorder' before counting.

    print("Per scenario numbers of measurements with negative skewness")
    # Author's note on the analysed run: mostly reorder and jitter, some rate.
    print(negatively_skewed["scenario"].value_counts())


def iqr_kurtosis_of_delay_data(data):
    """Print ``describe()`` of the ``iqr`` and ``kurtosis`` columns.

    The summary is printed first over all rows, then over the ``delay``
    scenario only. The kurtosis values are read as stored in the column;
    the header assumes they use Fisher's definition (0 = normal), which is
    decided wherever that column is produced.
    """
    print("Kurtosis of data, Fisher's definition, so 0 is normal distribution")
    summary_columns = ["iqr", "kurtosis"]
    print(data[summary_columns].describe())

    delay_rows = data[data["scenario"] == "delay"]
    print(delay_rows[summary_columns].describe())


def print_kem_ids():
    """Print the value returned by ``hf.get_kem_ids()``."""
    print(hf.get_kem_ids())


# Run the analysis only when executed as a script, so importing this module
# (e.g. to reuse one of the analysis functions) does not load the feather
# file and print output as a side effect.
if __name__ == "__main__":
    main()