Skip to content
Snippets Groups Projects
generate_graphs.py 49.1 KiB
Newer Older
#!/usr/bin/env python

from functools import cmp_to_key
import colorsys
import os
import sys

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import scipy

import helper_scripts.helper_functions as helper_functions

# Root directory that is walked recursively for result files. The last four
# path components of every file are used as
# <scenario>/<protocol>/<sec_level>/<kem_alg>.<ext>.
RESULTS_DIR = "saved/results-run-20241231-vm-p16"
# Directories whose path contains any of these substrings are skipped while
# collecting result files.
FILTER_RESULTS = []
# Output root for all generated plot PDFs.
PLOTS_DIR = "plots"
# Cache directory for the parsed results DataFrame (data.feather).
FEATHERS_DIR = "feathers"

# Colormap sampled by get_color_and_mode / get_color_and_mode_for_protocol.
cmap = plt.cm.hsv
# Alternative colormap:
# cmap = plt.colormaps.get_cmap("nipy_spectral")
    # generally they both only seconds for graphs when not generating for single algorithms
    # plot_general_plots()  # takes about 4 seconds
    # plot_lines(data)  # takes about 1:50 min
    # plot_static_data(data)  # takes about 4 min
    plot_distributions(data)


def load_data():
    """Return the results DataFrame, preferring the cached feather file.

    If ``<FEATHERS_DIR>/data.feather`` exists it is read directly; otherwise
    the raw result CSVs are parsed by ``read_data_into_pandas`` (which also
    writes that cache file).
    """
    feather_path = f"{FEATHERS_DIR}/data.feather"
    if not os.path.exists(feather_path):
        return read_data_into_pandas()
    return pd.read_feather(feather_path)


# Reading in the data takes about 11 seconds per scenario
def read_data_into_pandas():
    """Parse all result files below ``RESULTS_DIR`` into one DataFrame.

    Result files live at
    ``<RESULTS_DIR>/.../<scenario>/<protocol>/<sec_level>/<kem_alg>.<ext>``.
    Every row of a result file belongs to one network configuration: its
    first value is the error count, the remaining values are the
    measurements. The network parameters of row ``i`` come from row ``i`` of
    ``testscenarios/scenario_<scenario>.csv`` (its first column is dropped).

    One output row is produced per (result file, configuration). It holds the
    identifying labels, the ``srv_*``/``cli_*`` network parameters, the error
    count and rate, the raw ``measurements`` array and summary statistics of
    it. The frame is also written to ``<FEATHERS_DIR>/data.feather``.

    Raises:
        ValueError: if a result file and its scenario file do not contain the
            same number of configurations.
    """
    network_columns = [
        "srv_pkt_loss",
        "srv_delay",
        "srv_jitter",
        "srv_duplicate",
        "srv_corrupt",
        "srv_reorder",
        "srv_rate",
        "cli_pkt_loss",
        "cli_delay",
        "cli_jitter",
        "cli_duplicate",
        "cli_corrupt",
        "cli_reorder",
        "cli_rate",
    ]
    # Every key written per row must be listed here; columns missing from this
    # list would otherwise be silently dropped.
    columns = [
        "scenario",
        "protocol",
        "sec_level",
        "kem_alg",
        *network_columns,
        "error_count",
        "error_rate",
        "measurements",
        "mean",
        "std",
        "cv",
        "median",
        "qtl_01",
        "qtl_05",
        "qtl_25",
        "qtl_75",
        "qtl_95",
        "qtl_99",
        "iqr",
        "riqr",
        "skewness",
        "kurtosis",
    ]

    def get_all_result_files():
        """Return all file paths below RESULTS_DIR, skipping directories whose
        path contains an entry of FILTER_RESULTS."""
        return [
            os.path.join(dirpath, filename)
            for dirpath, _, filenames in os.walk(RESULTS_DIR)
            if not any(filter_value in dirpath for filter_value in FILTER_RESULTS)
            for filename in filenames
        ]

    scenario_files = {}

    def get_scenario_parameters(scenario):
        """Load the per-configuration network parameters, once per scenario."""
        if scenario not in scenario_files:
            df_scenariofile = pd.read_csv(f"testscenarios/scenario_{scenario}.csv")
            scenario_files[scenario] = df_scenariofile.drop(
                df_scenariofile.columns[0], axis="columns"
            )
        return scenario_files[scenario]

    def summarize(measurements):
        """Summary statistics of one measurement array, keyed by column name."""
        mean = np.mean(measurements)
        std = np.std(measurements)
        median = np.median(measurements)
        iqr = scipy.stats.iqr(measurements)
        qtl_01, qtl_05, qtl_25, qtl_75, qtl_95, qtl_99 = np.quantile(
            measurements, [0.01, 0.05, 0.25, 0.75, 0.95, 0.99]
        )
        return {
            "mean": mean,
            "std": std,
            "cv": std / mean,
            "median": median,
            "qtl_01": qtl_01,
            "qtl_05": qtl_05,
            "qtl_25": qtl_25,
            "qtl_75": qtl_75,
            "qtl_95": qtl_95,
            "qtl_99": qtl_99,
            "iqr": iqr,
            "riqr": iqr / median,
            "skewness": scipy.stats.skew(measurements),
            "kurtosis": scipy.stats.kurtosis(measurements),
        }

    # Collect plain dicts and build the frame once; growing a DataFrame row by
    # row with .loc copies it on every insertion.
    rows = []
    for result_file_name in get_all_result_files():
        *_, scenario, protocol, sec_level, kem_alg = os.path.normpath(
            result_file_name
        ).split(os.sep)
        kem_alg = kem_alg.split(".")[0]
        # After transposing, every column holds one configuration.
        result_file_data = pd.read_csv(result_file_name, header=None).T
        scenario_parameters = get_scenario_parameters(scenario)
        if len(result_file_data.columns) != len(scenario_parameters):
            raise ValueError(
                f"{result_file_name} has {len(result_file_data.columns)} "
                f"configurations, but the scenario file for {scenario!r} "
                f"has {len(scenario_parameters)}"
            )

        for (_, configuration), (_, network) in zip(
            result_file_data.items(), scenario_parameters.iterrows()
        ):
            values = configuration.tolist()
            error_count = values[0]
            measurements = np.array(values[1:])
            rows.append(
                {
                    "scenario": scenario,
                    "protocol": protocol,
                    "sec_level": sec_level,
                    "kem_alg": kem_alg,
                    **{name: network[name] for name in network_columns},
                    "error_count": error_count,
                    "error_rate": error_count / len(measurements),
                    "measurements": measurements,
                    **summarize(measurements),
                }
            )

    data = pd.DataFrame(rows, columns=columns)

    dtypes = {
        "scenario": "category",
        "protocol": "category",
        "sec_level": "category",
        "kem_alg": "category",
    }
    data = data.astype(dtypes)
    categories = [
        "secp256r1",
        "secp384r1",
        "secp521r1",
        "x25519",
        "x448",
        "mlkem512",
        "p256_mlkem512",
        "x25519_mlkem512",
        "mlkem768",
        "p384_mlkem768",
        "x448_mlkem768",
        "x25519_mlkem768",
        "p256_mlkem768",
        "mlkem1024",
        "p521_mlkem1024",
        "p384_mlkem1024",
        "bikel1",
        "p256_bikel1",
        "x25519_bikel1",
        "bikel3",
        "p384_bikel3",
        "x448_bikel3",
        "bikel5",
        "p521_bikel5",
        "hqc128",
        "p256_hqc128",
        "x25519_hqc128",
        "hqc192",
        "p384_hqc192",
        "x448_hqc192",
        "hqc256",
        "p521_hqc256",
        "frodo640aes",
        "p256_frodo640aes",
        "x25519_frodo640aes",
        "frodo640shake",
        "p256_frodo640shake",
        "x25519_frodo640shake",
        "frodo976aes",
        "p384_frodo976aes",
        "x448_frodo976aes",
        "frodo976shake",
        "p384_frodo976shake",
        "x448_frodo976shake",
        "frodo1344aes",
        "p521_frodo1344aes",
        "frodo1344shake",
        "p521_frodo1344shake",
    ]
    # Values not in `categories` would silently turn into NaN; report them.
    unknown_kem_algs = set(data["kem_alg"].dropna().astype(str)) - set(categories)
    if unknown_kem_algs:
        print(
            f"WARNING: unknown kem_alg values become NaN: {sorted(unknown_kem_algs)}",
            file=sys.stderr,
        )
    data["kem_alg"] = pd.Categorical(
        data["kem_alg"], categories=categories, ordered=True
    )

    print(data.head())
    print(data.describe())
    data.info()  # info() prints itself and returns None
    print()
    print("Scenarios read:", data["scenario"].unique())

    os.makedirs(FEATHERS_DIR, mode=0o777, exist_ok=True)
    data.to_feather(f"{FEATHERS_DIR}/data.feather")
    print("Data written to feather file")
    return data


def filter_data(
    data,
    scenario: str | None = None,
    protocol: str | None = None,
    sec_level: str | list[str] | None = None,
    kem_alg: str | None = None,
    drop_zero_columns: bool = True,
):
    filtered_data = data
    # print(filtered_data["kem_alg"] == "x25519") # is a boolean series
    if scenario is not None:
        filtered_data = filtered_data[filtered_data["scenario"] == scenario]
    if protocol is not None:
        filtered_data = filtered_data[filtered_data["protocol"] == protocol]
    if sec_level is not None:
        if type(sec_level) == list:
            filtered_data = filtered_data[filtered_data["sec_level"].isin(sec_level)]
        else:
            filtered_data = filtered_data[filtered_data["sec_level"] == sec_level]
    if kem_alg is not None:
        filtered_data = filtered_data[filtered_data["kem_alg"] == kem_alg]

    def drop_columns_with_only_zero_values(data):
        # this complicated way is necessary, because measurements is a list of values
        filtered_data_without_measurements = data.drop(columns=["measurements"])
        zero_columns_to_drop = (filtered_data_without_measurements != 0).any()
        zero_columns_to_drop = [
            col
            for col in filtered_data_without_measurements.columns
            if not zero_columns_to_drop[col]
        ]
        return data.drop(columns=zero_columns_to_drop)

    if drop_zero_columns and "measurements" in data.columns:
        filtered_data = drop_columns_with_only_zero_values(filtered_data)

        #     if drop_zero_columns:
        # filtered_data = drop_columns_with_only_zero_values(filtered_data)

    # print(filtered_data["measurements"].head())
    # print(filtered_data)
    return filtered_data


def get_x_axis_column_name(scenario: str) -> str:
    """Return the name of the column a *scenario* varies (its x-axis).

    The ``static`` scenario has no varying parameter and fails an assertion.
    An unknown scenario prints an error to stderr and exits with status 1.
    """
    varied_column = {
        "duplicate": "srv_duplicate",
        "packetloss": "srv_pkt_loss",
        "delay": "srv_delay",
        "jitter_delay20ms": "srv_jitter",
        "corrupt": "srv_corrupt",
        "reorder": "srv_reorder",
        "rate_both": "srv_rate",
        "rate_client": "cli_rate",
        "rate_server": "srv_rate",
    }
    if scenario == "static":
        assert False, "static scenario has no x-axis"
    elif scenario in varied_column:
        return varied_column[scenario]
    else:
        print(f"NO MATCH FOUND FOR {scenario}", file=sys.stderr)
        sys.exit(1)


def get_x_axis(scenario, data, length):
    match scenario:
        case "duplicate":
            return data["srv_duplicate"]
        case "packetloss":
            return data["srv_pkt_loss"]
        case "delay":
            return data["srv_delay"]
            return data["srv_jitter"]
        case "corrupt":
            return data["srv_corrupt"]
        case "reorder":
            return data["srv_reorder"]
        case "rate_both":
            return data["srv_rate"]
        case "rate_client":
            return data["cli_rate"]
        case "rate_server":
            return data["srv_rate"]
        case "static":
            return pd.Series(range(length))
        case _:
            print(f"NO MATCH FOUND FOR {scenario}", file=sys.stderr)
            sys.exit(1)


def map_security_level_hybrid_together(sec_level: str):
    match sec_level:
        case "secLevel1":
            return ["secLevel1", "secLevel1_hybrid"]
        case "secLevel3":
            return ["secLevel3", "secLevel3_hybrid"]
        case "secLevel5":
            return ["secLevel5", "secLevel5_hybrid"]
        case "miscLevel":
            return "miscLevel"
        case _:
            return None
def get_color_and_mode(kem_alg: str, combined_with_hybrids: bool = False):
    """Return the ``(color, linestyle)`` pair used to draw *kem_alg*.

    Every algorithm family owns a base position on ``cmap`` (in units of
    ``1 / 8``).

    * Pure algorithms use the raw colormap color with a solid line;
      ``x25519`` and ``x448`` use a dashed line instead.
    * Hybrid variants reuse their family's base color, passed through
      ``transform_cmap_color``. Apart from the three special cases below,
      p256/p384/p521 hybrids form the "secondary" tier and x25519/x448
      hybrids form the "tertiary" tier.
    * The line style and lightness factor of each tier depend on
      *combined_with_hybrids*. When that flag is set, the tiers are also
      drawn with alpha factor 0.8.
    * ``x25519_mlkem768``, ``p256_mlkem768`` and ``p384_mlkem1024`` get their
      own positions near the end of the colormap and a solid line.

    An unknown *kem_alg* prints an error to stderr and exits with status 1.
    """
    # NOTE maybe just use hle colors directly from the start
    no_algos = 8
    alpha = 0.8 if combined_with_hybrids else 1
    secondary_lightness = 0.9 if combined_with_hybrids else 1
    tertiary_lightness = 0.8 if combined_with_hybrids else 0.9
    # hybrid tier -> (line style, lightness factor)
    hybrid_styles = {
        "secondary": ("--" if combined_with_hybrids else "-", secondary_lightness),
        "tertiary": (":" if combined_with_hybrids else "--", tertiary_lightness),
    }

    # kem_alg -> (position on cmap in units of 1/no_algos, drawing tier)
    palette = {
        # classical groups
        "secp256r1": (0, "plain"),
        "secp384r1": (0.3, "plain"),
        "secp521r1": (0.6, "plain"),
        "x25519": (0, "dashed"),
        "x448": (0.3, "dashed"),
        # ML-KEM
        "mlkem512": (1, "plain"),
        "p256_mlkem512": (1, "secondary"),
        "x25519_mlkem512": (1, "tertiary"),
        "mlkem768": (1.3, "plain"),
        "p384_mlkem768": (1.3, "secondary"),
        "x448_mlkem768": (1.3, "tertiary"),
        "mlkem1024": (1.6, "plain"),
        "p521_mlkem1024": (1.6, "secondary"),
        # BIKE
        "bikel1": (2, "plain"),
        "p256_bikel1": (2, "secondary"),
        "x25519_bikel1": (2, "tertiary"),
        "bikel3": (2.3, "plain"),
        "p384_bikel3": (2.3, "secondary"),
        "x448_bikel3": (2.3, "tertiary"),
        "bikel5": (2.6, "plain"),
        "p521_bikel5": (2.6, "secondary"),
        # HQC
        "hqc128": (4, "plain"),
        "p256_hqc128": (4, "secondary"),
        "x25519_hqc128": (4, "tertiary"),
        "hqc192": (4.3, "plain"),
        "p384_hqc192": (4.3, "secondary"),
        "x448_hqc192": (4.3, "tertiary"),
        "hqc256": (4.6, "plain"),
        "p521_hqc256": (4.6, "secondary"),
        # FrodoKEM, AES variants
        "frodo640aes": (5, "plain"),
        "p256_frodo640aes": (5, "secondary"),
        "x25519_frodo640aes": (5, "tertiary"),
        "frodo976aes": (5.3, "plain"),
        "p384_frodo976aes": (5.3, "secondary"),
        "x448_frodo976aes": (5.3, "tertiary"),
        "frodo1344aes": (5.6, "plain"),
        "p521_frodo1344aes": (5.6, "secondary"),
        # FrodoKEM, SHAKE variants
        "frodo640shake": (6, "plain"),
        "p256_frodo640shake": (6, "secondary"),
        "x25519_frodo640shake": (6, "tertiary"),
        "frodo976shake": (6.3, "plain"),
        "p384_frodo976shake": (6.3, "secondary"),
        "x448_frodo976shake": (6.3, "tertiary"),
        "frodo1344shake": (6.6, "plain"),
        "p521_frodo1344shake": (6.6, "secondary"),
        # ML-KEM hybrids with their own colors
        "x25519_mlkem768": (7, "alpha_only"),
        "p256_mlkem768": (7.5, "solid_secondary"),
        "p384_mlkem1024": (7.99, "solid_tertiary"),
    }

    if kem_alg not in palette:
        print(f"NO COLOR MATCH FOUND FOR {kem_alg}", file=sys.stderr)
        sys.exit(1)

    position, tier = palette[kem_alg]
    base_color = cmap(position / no_algos)

    if tier == "plain":
        return base_color, "-"
    if tier == "dashed":
        return base_color, "--"
    if tier == "alpha_only":
        return transform_cmap_color(base_color, alpha_factor=alpha), "-"
    if tier == "solid_secondary":
        return (
            transform_cmap_color(base_color, lightness_factor=secondary_lightness),
            "-",
        )
    if tier == "solid_tertiary":
        return (
            transform_cmap_color(base_color, lightness_factor=tertiary_lightness),
            "-",
        )
    # "secondary" / "tertiary": shaded hybrid variant of the family color
    linestyle, lightness = hybrid_styles[tier]
    return (
        transform_cmap_color(
            base_color, alpha_factor=alpha, lightness_factor=lightness
        ),
        linestyle,
    )


def get_color_and_mode_for_protocol(protocol: str):
    """Return the ``(color, linestyle)`` pair for a transport *protocol*.

    Each known protocol takes an evenly spaced position (1/6 apart) on
    ``cmap`` and is drawn with a solid line. An unknown protocol prints an
    error to stderr and exits with status 1.
    """
    known_protocols = (
        "quic",
        "tlstcp",
        "cquiche-reno",
        "cquiche-cubic",
        "cquiche-bbr",
        "cquiche-bbr2",
    )
    if protocol not in known_protocols:
        print(f"NO COLOR MATCH FOUND FOR Protocol {protocol}", file=sys.stderr)
        sys.exit(1)
    slot = known_protocols.index(protocol)
    return cmap(slot / len(known_protocols)), "-"


def transform_cmap_color(
    color, hue_shift=0, saturation_factor=1, lightness_factor=1, alpha_factor=1
):
    """Return *color* adjusted in HLS space as an ``(r, g, b, a)`` tuple.

    Args:
        color: RGBA tuple of floats in [0, 1], e.g. the result of calling a
            colormap. An RGB 3-tuple is accepted and treated as opaque.
        hue_shift: added to the hue before converting back to RGB.
        saturation_factor, lightness_factor: multiplied into saturation and
            lightness; the results are clamped to [0, 1].
        alpha_factor: multiplied into alpha; the result is clamped to [0, 1]
            so the returned color stays a valid RGBA value.
    """

    def clamp(value):
        return max(0.0, min(value, 1.0))

    if len(color) == 3:
        r, g, b = color
        alpha = 1.0
    else:
        r, g, b, alpha = color
    hue, lightness, saturation = colorsys.rgb_to_hls(r, g, b)
    r, g, b = colorsys.hls_to_rgb(
        hue + hue_shift,
        clamp(lightness * lightness_factor),
        clamp(saturation * saturation_factor),
    )
    return r, g, b, clamp(alpha * alpha_factor)


# plots lines of different statistical values
def plot_lines(data):
    """Write line plots of per-configuration statistics below ``PLOTS_DIR/lines``.

    For each statistic in ``statistical_measurements`` three sets of PDFs are
    written:

    * ``per-protocol/<stat>s-of-sec-level``: one plot per scenario, protocol
      and security level, with one line per KEM algorithm;
    * the same below ``combined-with-hybrids``, where each security level is
      merged with its ``_hybrid`` counterpart;
    * ``between-protocols/comparison-of-<stat>s``: one plot per scenario and
      KEM algorithm, with one line per protocol.

    The x-axis is the parameter varied by the scenario (see ``get_x_axis``).
    """

    def y_label(line_type):
        """Axis label; skewness, kurtosis, riqr and cv are unitless ratios."""
        if line_type in ("skewness", "kurtosis", "riqr", "cv"):
            return f"{line_type} of time-to-first-byte"
        return "Time-to-first-byte (ms)"

    def plot_lines_for_sec_level(
        data, line_type="median", combined_with_hybrids: bool = False
    ):
        """Plot *line_type* per (scenario, protocol, sec_level), one line per KEM."""
        # Creating the combined-with-hybrids directory also creates its
        # parent, which receives the non-combined plots.
        os.makedirs(
            f"{PLOTS_DIR}/lines/per-protocol/{line_type}s-of-sec-level/combined-with-hybrids",
            mode=0o777,
            exist_ok=True,
        )
        unique_combinations = data[
            ["scenario", "protocol", "sec_level"]
        ].drop_duplicates()
        for _, row in unique_combinations.iterrows():
            sec_level = row["sec_level"]
            if combined_with_hybrids:
                # Levels that map to None (e.g. the *_hybrid ones) are already
                # covered by their pure level's plot.
                sec_level = map_security_level_hybrid_together(row["sec_level"])
                if sec_level is None:
                    continue
            filtered_data = filter_data(
                data,
                scenario=row["scenario"],
                protocol=row["protocol"],
                sec_level=sec_level,
            )
            plt.figure()
            for kem_alg in filtered_data["kem_alg"].unique().sort_values():
                color, mode = get_color_and_mode(
                    kem_alg, combined_with_hybrids=combined_with_hybrids
                )
                filtered_data_single_kem_alg = filter_data(
                    filtered_data, kem_alg=kem_alg
                )
                y = filtered_data_single_kem_alg[line_type]
                x = get_x_axis(row["scenario"], filtered_data_single_kem_alg, len(y))
                plt.plot(x, y, linestyle=mode, marker=".", color=color, label=kem_alg)

            plt.ylim(bottom=0)
            plt.xlim(left=0)
            plt.xlabel(row["scenario"])
            plt.ylabel(y_label(line_type))
            plt.legend(
                bbox_to_anchor=(0.5, 1), loc="lower center", ncol=3, fontsize="small"
            )
            plt.tight_layout()
            subdir = ""
            appendix = ""
            if combined_with_hybrids:
                subdir = "combined-with-hybrids/"
                appendix = "-combined-with-hybrids"
            plt.savefig(
                f"{PLOTS_DIR}/lines/per-protocol/{line_type}s-of-sec-level/{subdir}{line_type}-{row['scenario']}-{row['protocol']}-{row['sec_level']}{appendix}.pdf"
            )
            plt.close()

    def plot_lines_for_comparisons_between_protocols(data, line_type="median"):
        """Plot *line_type* per (scenario, kem_alg), one line per protocol."""
        os.makedirs(
            f"{PLOTS_DIR}/lines/between-protocols/comparison-of-{line_type}s",
            mode=0o777,
            exist_ok=True,
        )

        # sec_level is only needed for the filename
        unique_combinations = data[
            ["scenario", "kem_alg", "sec_level"]
        ].drop_duplicates()

        for _, row in unique_combinations.iterrows():
            filtered_data = filter_data(
                data,
                scenario=row["scenario"],
                kem_alg=row["kem_alg"],
            )

            plt.figure()
            for protocol in filtered_data["protocol"].unique().sort_values():
                color, mode = get_color_and_mode_for_protocol(protocol)

                filtered_data_single_protocol = filter_data(
                    filtered_data, protocol=protocol
                )
                y = filtered_data_single_protocol[line_type]
                x = get_x_axis(row["scenario"], filtered_data_single_protocol, len(y))

                plt.plot(x, y, linestyle=mode, marker=".", color=color, label=protocol)

            plt.ylim(bottom=0)
            plt.xlim(left=0)
            plt.xlabel(row["scenario"])
            plt.ylabel(y_label(line_type))
            plt.legend(
                bbox_to_anchor=(0.5, 1), loc="lower center", ncol=3, fontsize="small"
            )
            plt.tight_layout()
            plt.savefig(
                f"{PLOTS_DIR}/lines/between-protocols/comparison-of-{line_type}s/{line_type}-{row['scenario']}-{row['sec_level']}-{row['kem_alg']}.pdf"
            )
            plt.close()

    def plot_median_of_single_algorithm(data):
        """One plot per (scenario, protocol, sec_level, kem_alg): the median
        with the 25%-75% quantile band shaded."""
        os.makedirs(
            f"{PLOTS_DIR}/median-of-single-algorithm", mode=0o777, exist_ok=True
        )
        unique_combinations = data[
            ["scenario", "protocol", "sec_level", "kem_alg"]
        ].drop_duplicates()
        for _, row in unique_combinations.iterrows():
            filtered_data = filter_data(
                data,
                scenario=row["scenario"],
                protocol=row["protocol"],
                sec_level=row["sec_level"],
                kem_alg=row["kem_alg"],
            )
            y = filtered_data["median"]
            x = get_x_axis(row["scenario"], filtered_data, len(y))
            plt.figure()
            plt.fill_between(
                x, filtered_data["qtl_25"], filtered_data["qtl_75"], alpha=0.5
            )
            plt.plot(x, y, linestyle="-", marker=".")
            plt.ylim(bottom=0)
            plt.xlim(left=0)
            plt.xlabel(row["scenario"])
            plt.ylabel("Time-to-first-byte (ms)")
            plt.title(
                f"Median of {row['scenario']} in {row['protocol']} in {row['sec_level']} with {row['kem_alg']}"
            )
            plt.savefig(
                f"{PLOTS_DIR}/median-of-single-algorithm/median-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}.pdf"
            )
            plt.close()

    statistical_measurements = [
        "qtl_01",
        "qtl_05",
        "qtl_25",
        "qtl_75",
        "qtl_95",
        "qtl_99",
        "iqr",
        "riqr",
        "skewness",
        "kurtosis",
    ]
    for statistical_measurement in statistical_measurements:
        print(f"Generating graphs for {statistical_measurement}")
        plot_lines_for_sec_level(
            data, line_type=statistical_measurement, combined_with_hybrids=False
        )
        plot_lines_for_sec_level(
            data, line_type=statistical_measurement, combined_with_hybrids=True
        )
        plot_lines_for_comparisons_between_protocols(
            data, line_type=statistical_measurement
        )
    # Optional, one PDF per algorithm configuration:
    # plot_median_of_single_algorithm(data)
    # plot_median_against_iqr(data)


# plots distributions of the individual data points
def plot_distributions(data):
    os.makedirs(
        f"{PLOTS_DIR}/distributions/single",
        mode=0o777,
        exist_ok=True,
    )

    def plot_multiple_violin_plots(data, filtered: bool = False):
        os.makedirs(
            f"{PLOTS_DIR}/distributions/filtered",
            mode=0o777,
            exist_ok=True,
        )

        unique_combinations = data[
            ["scenario", "protocol", "sec_level", "kem_alg"]
        ].drop_duplicates()
        # print(unique_combinations)
        for _, row in unique_combinations.iterrows():
            filtered_data = filter_data(
                data,
                scenario=row["scenario"],
                protocol=row["protocol"],
                sec_level=row["sec_level"],
                kem_alg=row["kem_alg"],
            )
            if row["scenario"] == "static":
                continue
            # print(
            #     f"scenario: {row['scenario']}, protocol: {row['protocol']}, sec_level: {row['sec_level']}, kem_alg: {row['kem_alg']}, len: {len(filtered_data)}"
            # )
            # return the data with removed rows and the ticks for the x-axis
            def remove_rows_of_data(data):
                scenario = data.iloc[0]["scenario"]
                scenario_column_name = get_x_axis_column_name(scenario)
                match scenario:
                    case "packetloss" | "duplicate" | "reorder" | "corrupt":
                        # return data where src_packetloss is 0, 4, 8, 12, 16 or 20
                        ldata = data.query(f"{scenario_column_name} % 4 == 0")
                    case "jitter_delay20ms":
                        ldata = data[
                            data[scenario_column_name].isin([0, 3, 7, 12, 15, 20])
                        ]
                    case "rate_both" | "rate_client" | "rate_server":
                        ldata = data[
                            data[scenario_column_name].isin(
                                [
                                    0.1,
                                    5,
                                    10,
                                    100,
                                ]
                            )
                        ]
                    case "delay":
                        ldata = data[
                            data[scenario_column_name].isin([1, 20, 40, 80, 100, 190])
                        ]
                    case _:
                        print("No case for this scenario:", scenario)
                        ldata = pd.concat(
                            [
                                data.iloc[[0]],
                                data.iloc[3:-4:4],
                                data.iloc[[-1]],
                            ]
                        )
                return ldata, ldata[scenario_column_name].to_list()
            if filtered:
                filtered_data, tick_values = remove_rows_of_data(filtered_data)
                # if filtered_data.iloc[0]["scenario"] == "jitter_delay20ms":
                #     exit()

            plt.figure()
            x = get_x_axis(row["scenario"], filtered_data, len(filtered_data))
            x = x.to_list()
            # print(x)
            width = 0.5 if not filtered else 2.5
            vplots = plt.violinplot(
                filtered_data["measurements"],
                positions=x,
                showmedians=False,
                showextrema=False,
                widths=width,
            )
            for pc in vplots["bodies"]:
                pc.set_facecolor("blue")
                pc.set_edgecolor("darkblue")
            # make the median line transparent
            # for pc in plt.gca().collections:
            #     pc.set_alpha(0.5)

            plt.ylim(bottom=0)
            if filtered:
                plt.ylim(bottom=0, top=1.5 * filtered_data["qtl_95"].max())
                plt.xticks(tick_values)
            plt.xlim(left=-1.5)
            plt.xlabel(row["scenario"])
            plt.ylabel(f"Time-to-first-byte (ms)")

            subdir = "filtered/" if filtered else ""
            appendix = "-filtered" if filtered else ""
            plt.savefig(
                f"{PLOTS_DIR}/distributions/{subdir}multiple-violin-plots-for-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}{appendix}.pdf"
            )
            plt.close()

    def plot_single_violin_plot(data):
        unique_combinations = data[
            ["scenario", "protocol", "sec_level", "kem_alg"]
        ].drop_duplicates()
        for _, row in unique_combinations.iterrows():
            filtered_data = filter_data(
                data,
                scenario=row["scenario"],
                protocol=row["protocol"],
                sec_level=row["sec_level"],
                kem_alg=row["kem_alg"],
            )
            # print(
            #     f"scenario: {row['scenario']}, protocol: {row['protocol']}, sec_level: {row['sec_level']}, kem_alg: {row['kem_alg']}"
            # )
            for _, row in filtered_data.iterrows():
                if row["scenario"] == "static":
                    continue
                value = get_x_axis(row["scenario"], row, 1)

                plt.figure()
                plt.violinplot(row["measurements"], showmedians=True)
                # plt.ylim(bottom=0)
                # plt.xlim(left=0)
                plt.xlabel("Dichte")
                plt.ylabel(f"Time-to-first-byte (ms)")

                plt.savefig(
                    f"{PLOTS_DIR}/distributions/single/single-violin-plot-for-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}-{value}.pdf"
                )
                plt.close()
                # return

    plot_multiple_violin_plots(data, filtered=False)
    plot_multiple_violin_plots(data, filtered=True)
    # plot_single_violin_plot(data)  # takes an age


# TODO make a violinplot/eventplot for many algos in static scenario
def plot_static_data(data):
    """Plot the results of the ``static`` scenario.

    Writes two kinds of PDFs:

    * ``{PLOTS_DIR}/static/boxplots-of-medians-for-static-<protocol>-<sec_level>.pdf``:
      for each protocol and security-level group, one boxplot per KEM
      algorithm over the medians of all static runs of that algorithm.
    * ``{PLOTS_DIR}/static/single/...``: for each (protocol, sec_level,
      kem_alg), a boxplot of the run medians plus a violin plot, a histogram
      and a KDE plot of the measurements of all runs pooled together.

    Args:
        data: results DataFrame with at least the columns ``scenario``,
            ``protocol``, ``sec_level``, ``kem_alg``, ``median`` and
            ``measurements``.
    """
    os.makedirs(f"{PLOTS_DIR}/static/single", mode=0o777, exist_ok=True)

    def plot_static_data_for_multiple_algorithms(data):
        """Write one boxplot-of-medians PDF per protocol and security-level group."""
        unique_combinations = data[["protocol", "sec_level"]].drop_duplicates()
        for _, row in unique_combinations.iterrows():
            # Either a single level (str) or a group of levels (joined with "-"
            # for the file name below); levels mapped to None get no own plot.
            sec_level = map_security_level_hybrid_together(row["sec_level"])
            if sec_level is None:
                continue

            filtered_data = filter_data(
                data,
                scenario="static",
                protocol=row["protocol"],
                sec_level=sec_level,
            )

            plt.figure()
            kem_algs = []
            iqr_records = []
            for position, kem_alg in enumerate(
                filtered_data["kem_alg"].unique().sort_values(),
            ):
                kem_algs.append(kem_alg)
                medians = filter_data(filtered_data, kem_alg=kem_alg)["median"]
                plt.boxplot(medians, positions=[position], widths=0.6)
                iqr_records.append(
                    {"kem_alg": kem_alg, "iqr": scipy.stats.iqr(medians)}
                )

            # IQR of the run medians per algorithm, built once instead of
            # concatenating inside the loop; only used for manual inspection:
            iqrs = pd.DataFrame(iqr_records, columns=["kem_alg", "iqr"])
            # print(row["protocol"], row["sec_level"])
            # print(iqrs["iqr"].describe())
            # print("Median:", iqrs["iqr"].median())
            # print("IQR:", scipy.stats.iqr(iqrs["iqr"]))

            plt.xticks(
                range(len(kem_algs)),
                kem_algs,
                rotation=45,
                ha="right",
            )
            plt.xlabel("KEM Algorithms")
            plt.ylabel("Time-to-first-byte (ms)")

            # isinstance also accepts str subclasses such as numpy.str_, which
            # "-".join would otherwise split into single characters
            sec_level_string = (
                sec_level if isinstance(sec_level, str) else "-".join(sec_level)
            )
            plt.tight_layout()
            plt.savefig(
                os.path.join(
                    PLOTS_DIR,
                    "static",
                    f"boxplots-of-medians-for-static-{row['protocol']}-{sec_level_string}.pdf",
                )
            )
            plt.close()

    def plot_static_data_for_single_algorithms(data):
        """Write the per-algorithm static plots into ``{PLOTS_DIR}/static/single``."""

        def boxplot_of_medians_for_configuration(filtered_data, row):
            plt.figure()
            plt.boxplot(filtered_data["median"])
            plt.savefig(
                os.path.join(
                    PLOTS_DIR,
                    "static",
                    "single",
                    f"boxplot-of-medians-for-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}.pdf",
                )
            )
            plt.close()

        # Open question: the violin plot and the KDE plot seem to use the same
        # Scott's-rule KDE (just rotated), yet their densities look different.
        # NOTE(review): matplotlib's violinplot scales each violin so its widest
        # point equals `widths` (0.5 by default), so its horizontal extent is not
        # the absolute density — likely the reason, confirm.
        def condensed_violin_plot_for_configuration(measurements, row):
            plt.figure()
            plt.violinplot(measurements, showmedians=True)
            plt.ylabel("Time-to-first-byte (ms)")
            plt.xlabel("Dichte")
            plt.savefig(
                os.path.join(
                    PLOTS_DIR,
                    "static",
                    "single",
                    f"condensed-violin-plot-for-{len(measurements)}-measurements-of-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}.pdf",
                )
            )
            plt.close()

        def condensed_histogram_plot_for_configuration(measurements, row):
            plt.figure()
            plt.hist(measurements, bins=100, density=True)
            plt.xlabel("Time-to-first-byte (ms)")
            plt.ylabel("Dichte")
            plt.savefig(
                os.path.join(
                    PLOTS_DIR,
                    "static",
                    "single",
                    f"condensed-histogram-plot-for-{len(measurements)}-measurements-of-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}.pdf",
                )
            )
            plt.close()

        def condensed_kernel_density_estimate_plot_for_configuration(
            measurements, row
        ):
            plt.figure()

            kde = scipy.stats.gaussian_kde(measurements)
            xmin = min(measurements) - 0.2
            xmax = max(measurements) + 0.1
            x = np.linspace(xmin, xmax, 1000)
            kde_values = kde(x)

            plt.plot(x, kde_values)
            plt.fill_between(x, kde_values, alpha=0.5)

            plt.xlabel("Time-to-first-byte (ms)")
            plt.ylabel("Dichte")
            plt.xlim([xmin, xmax])
            plt.ylim([0, max(kde_values) + 0.1])

            plt.savefig(
                os.path.join(
                    PLOTS_DIR,
                    "static",
                    "single",
                    f"condensed-kde-plot-for-{len(measurements)}-measurements-of-{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}.pdf",
                )
            )
            plt.close()

        unique_combinations = data[
            ["scenario", "protocol", "sec_level", "kem_alg"]
        ].drop_duplicates()
        unique_combinations = filter_data(unique_combinations, scenario="static")
        for _, row in unique_combinations.iterrows():
            filtered_data = filter_data(
                data,
                scenario="static",
                protocol=row["protocol"],
                sec_level=row["sec_level"],
                kem_alg=row["kem_alg"],
            )
            # measurements of all static runs of this configuration pooled
            # together; flattened once and shared by the three condensed plots
            measurements = filtered_data["measurements"].explode().tolist()

            boxplot_of_medians_for_configuration(filtered_data, row)
            condensed_violin_plot_for_configuration(measurements, row)
            condensed_histogram_plot_for_configuration(measurements, row)
            condensed_kernel_density_estimate_plot_for_configuration(
                measurements, row
            )

    plot_static_data_for_multiple_algorithms(data)
    plot_static_data_for_single_algorithms(data)
def plot_general_plots():
    """Plot KEM-level comparisons that do not depend on the measurement results.

    Writes log-log scatter plots into ``{PLOTS_DIR}/general``:

    * bytes sent against KEM performance, with and without hybrids, and
    * public-key length against ciphertext length, with and without hybrids,
      each with and without the 1-/2-packet boundary lines.

    The KEM data is obtained from ``helper_functions``.
    """
    os.makedirs(f"{PLOTS_DIR}/general", mode=0o777, exist_ok=True)

    df = helper_functions.prepare_kem_performance_data_for_plotting(
        helper_functions.get_kem_performance_data()
    )

    def plot_send_bytes_against_kem_performance(df, with_hybrids: bool):
        if not with_hybrids:
            # hybrids have "_" in their name; drop them, otherwise the plot is too cluttered
            df = df[~df["kem_algo"].str.contains("_")]

        plt.figure()
        # one scatter call per algorithm so each gets its family color
        for kem_algo in df["kem_algo"].unique():
            subset = df[df["kem_algo"] == kem_algo]
            plt.scatter(
                subset["bytes_sent"],
                subset["performance_us"],
                color=get_kem_family_color(kem_algo),
                label=kem_algo,
                alpha=0.7,
            )

        for txt, x, y in zip(df["kem_algo"], df["bytes_sent"], df["performance_us"]):
            plt.annotate(
                txt,
                (x, y),
                xytext=(0, -3),
                textcoords="offset points",
                fontsize=8,
                ha="center",
                va="bottom",
            )

        plt.xscale("log")
        plt.yscale("log")
        plt.xlim(30)
        plt.ylim(30)

        # ticks at the axis start and at every power of ten above it
        x_ticks = get_log_ticks(30, df["bytes_sent"].max())
        y_ticks = get_log_ticks(30, df["performance_us"].max())
        plt.xticks(x_ticks, [f"{int(x):,}" for x in x_ticks])
        plt.yticks(y_ticks, [f"{int(y):,}" for y in y_ticks])

        plt.xlabel("Bytes sent")
        plt.ylabel("Performance (µs)")

        name = (
            "scatter-of-bytes-sent-against-kem-performance-with-hybrids.pdf"
            if with_hybrids
            else "scatter-of-bytes-sent-against-kem-performance.pdf"
        )
        plt.savefig(
            os.path.join(PLOTS_DIR, "general", name),
            dpi=300,
        )
        plt.close()

    def plot_public_key_length_against_ciphertext_length(
        df, with_hybrids: bool, with_lines: bool
    ):
        if not with_hybrids:
            # hybrids have "_" in their name; drop them, otherwise the plot is too cluttered
            df = df[~df["kem_algo"].str.contains("_")]

        # Only keep one FrodoKEM variant (the aes and shake variants have the same
        # lengths) and drop the hash algorithm from its name. The explicit copy
        # avoids pandas' SettingWithCopyWarning when assigning to the filtered frame.
        df = df[~df["kem_algo"].str.contains("shake")].copy()
        df["kem_algo"] = df["kem_algo"].str.replace("aes", "", regex=False)

        plt.figure()
        for kem_algo in df["kem_algo"].unique():
            subset = df[df["kem_algo"] == kem_algo]
            plt.scatter(
                subset["length_public_key"],
                subset["length_ciphertext"],
                color=get_kem_family_color(kem_algo),
                label=kem_algo,
                alpha=0.7,
            )

        for txt, x, y in zip(
            df["kem_algo"], df["length_public_key"], df["length_ciphertext"]
        ):
            # hand-tuned label offsets, presumably to avoid overlapping labels
            annotate_offset = (0, -3)
            if "bikel1" in txt:
                annotate_offset = (0, 0)
            if "mlkem1024" in txt:
                annotate_offset = (0, -7)
            plt.annotate(
                txt,
                (x, y),
                xytext=annotate_offset,
                textcoords="offset points",
                fontsize=8,
                ha="center",
                va="bottom",
            )

        if with_lines:
            # reason for these magic numbers in obsidian note [[Packet lengths in QUIC over Ethernet]]
            plt.axvline(x=940, color="purple", linestyle="--", label="1 Paket Grenze")
            plt.axhline(y=277, color="purple", linestyle="--", label="1 Paket Grenze")
            plt.axvline(
                x=940 + 1157, color="purple", linestyle="--", label="2 Paket Grenze"
            )
            plt.axhline(
                y=277 + 1100, color="purple", linestyle="--", label="2 Paket Grenze"
            )

        plt.xscale("log")
        plt.yscale("log")
        plt.xlim(10)
        plt.ylim(10)

        plt.xlabel("Public Key Länge in Bytes")
        plt.ylabel("Ciphertext Länge in Bytes")

        # plain numbers instead of 10^x labels on the log axes
        plt.gca().xaxis.set_major_formatter(ticker.ScalarFormatter())
        plt.gca().yaxis.set_major_formatter(ticker.ScalarFormatter())

        with_hybrids_string = "-with-hybrids" if with_hybrids else ""
        with_lines_string = "-with-lines" if with_lines else ""
        name = f"scatter-of-public-key-against-ciphertext-length{with_hybrids_string}{with_lines_string}.pdf"
        plt.savefig(
            os.path.join(PLOTS_DIR, "general", name),
            dpi=300,
        )
        plt.close()

    # Not called by this function; kept as an experiment.
    # This does not yet seem like a good idea
    # TODO use the riqr to make some graphs
    def plot_median_against_iqr(data):
        plt.figure()
        plt.hexbin(data["median"], data["iqr"], gridsize=50)
        print(data["iqr"].describe())
        print(data["median"].describe())
        # idxmax returns an index label, so look it up by label, not by position
        max_median = data["median"].idxmax()
        print(data.loc[max_median])
        plt.savefig(f"{PLOTS_DIR}/general/median_against_iqr_hexbin.pdf")
        plt.close()

        plt.figure()
        plt.hist2d(data["median"], data["iqr"], bins=50)
        plt.savefig(f"{PLOTS_DIR}/general/median_against_iqr_hist2d.pdf")
        plt.close()

    plot_send_bytes_against_kem_performance(df, with_hybrids=False)
    plot_send_bytes_against_kem_performance(df, with_hybrids=True)

    for with_hybrids in (False, True):
        for with_lines in (True, False):
            plot_public_key_length_against_ciphertext_length(
                df, with_hybrids=with_hybrids, with_lines=with_lines
            )


def get_kem_family_color(kem_algo):
    """Return the plot color of the first KEM family whose name occurs in *kem_algo*.

    Families are checked in the order mlkem (blue), bikel (red), hqc (green),
    frodo (orange); names matching none of them are grey.
    """
    for family, color in (
        ("mlkem", "blue"),
        ("bikel", "red"),
        ("hqc", "green"),
        ("frodo", "orange"),
    ):
        if family in kem_algo:
            return color
    return "grey"


def get_log_ticks(start, end):
    """Return *start* followed by every power of ten above it up to *end*.

    >>> get_log_ticks(30, 5000)
    [30, 100, 1000]
    """
    first_exponent = int(np.log10(start)) + 1
    last_exponent = int(np.log10(end))
    return [start] + [10**i for i in range(first_exponent, last_exponent + 1)]
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()