Skip to content
Snippets Groups Projects
generate_graphs.py 22.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • #!/usr/bin/env python
    
    from functools import cmp_to_key
    import os
    import sys
    
    import matplotlib.pyplot as plt
    
    import matplotlib.ticker as ticker
    
    import numpy as np
    import pandas as pd
    import scipy
    
    
    import helper_scripts.helper_functions as helper_functions
    
    # Root directory of the raw measurement files; read_data_into_pandas expects
    # them at <RESULTS_DIR>/<scenario>/<protocol>/<sec_level>/<kem_alg>.<ext>.
    RESULTS_DIR = "results-run-20240921-vm-p16"

    # Result directories whose path contains any of these substrings are skipped
    # when collecting result files (empty list: read everything).
    FILTER_RESULTS = []
    # Output directory for all generated figures.
    PLOTS_DIR = "plots"
    # Cache directory for the parsed DataFrame (data.feather).
    FEATHERS_DIR = "feathers"

    # Colormap used by plot_median_for_sec_level to give each KEM algorithm its own line color.
    cmap = plt.cm.hsv
    
    
    def main():
        """Script entry point: prepare output folders, load the data and draw plots.

        Only the general KEM plots are currently enabled; the data-driven plots are
        kept as commented-out calls so they can be switched back on easily.
        """
        create_directories()
        measurement_data = load_data()  # consumed by the (currently disabled) plots below

        # plot_median(measurement_data)
        # plot_static_data(measurement_data)
        # plot_median_against_iqr(measurement_data)

        plot_general_plots()
    
    
    
    def create_directories():
        """Create the feather cache directory and all plot output directories.

        Existing directories are left untouched (``exist_ok=True``).
        """
        required_directories = (
            FEATHERS_DIR,
            PLOTS_DIR,
            f"{PLOTS_DIR}/median-of-single-algorithm",
            f"{PLOTS_DIR}/medians-of-sec-level",
            f"{PLOTS_DIR}/static",
        )
        for directory in required_directories:
            os.makedirs(directory, mode=0o777, exist_ok=True)
    
    
    def load_data():
        """Return the measurement DataFrame, preferring the feather cache.

        If ``{FEATHERS_DIR}/data.feather`` exists it is read directly; otherwise
        the raw result files are parsed by read_data_into_pandas (which also
        writes that cache file).
        """
        cache_path = f"{FEATHERS_DIR}/data.feather"
        if not os.path.exists(cache_path):
            return read_data_into_pandas()
        return pd.read_feather(cache_path)
    
    
    # Reading in the data takes about 10 seconds per scenario
    def read_data_into_pandas():
        """Parse all result files below RESULTS_DIR into one DataFrame and cache it.

        Result files are expected at
        ``<RESULTS_DIR>/<scenario>/<protocol>/<sec_level>/<kem_alg>.<ext>``;
        directories whose path contains an entry of FILTER_RESULTS are skipped.
        After transposing, column ``i`` of a result file holds the measurements
        belonging to row ``i`` of ``testscenarios/scenario_<scenario>.csv``. For
        every such pair one output row is produced with the identifying labels,
        the scenario's network parameters, the raw measurements and their
        summary statistics.

        The DataFrame is written to ``{FEATHERS_DIR}/data.feather`` and returned.

        Raises:
            ValueError: if a result file is not exactly four levels below
                RESULTS_DIR, or if a result file and its scenario file disagree
                on the number of scenario rows.
        """
        # Collect plain dicts and build the DataFrame once at the end; appending
        # via ``data.loc[len(data)] = ...`` reallocates the frame for every row.
        rows = []
        for csv_result_file_name in _get_all_result_files():
            rows.extend(_rows_from_result_file(csv_result_file_name))

        data = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
        data = data.astype(
            {"scenario": "category", "protocol": "category", "sec_level": "category"}
        )

        # Names missing from _KEM_ALG_ORDER would silently turn into NaN in the
        # ordered categorical below, so make them visible.
        unknown_kem_algs = sorted(set(data["kem_alg"].dropna()) - set(_KEM_ALG_ORDER))
        if unknown_kem_algs:
            print(
                f"WARNING: KEM algorithms without a defined order become NaN: {unknown_kem_algs}",
                file=sys.stderr,
            )
        data["kem_alg"] = pd.Categorical(
            data["kem_alg"], categories=_KEM_ALG_ORDER, ordered=True
        )

        print(data.head())
        print(data.describe())
        data.info()  # info() prints on its own and returns None
        print()
        print("Scenarios read:", data["scenario"].unique())

        data.to_feather(f"{FEATHERS_DIR}/data.feather")
        print("Data written to feather file")
        return data


    # Network-emulation parameters copied from the scenario file into every row.
    _SCENARIO_PARAMETER_COLUMNS = [
        "srv_pkt_loss",
        "srv_delay",
        "srv_jitter",
        "srv_duplicate",
        "srv_corrupt",
        "srv_reorder",
        "srv_rate",
        "cli_pkt_loss",
        "cli_delay",
        "cli_jitter",
        "cli_duplicate",
        "cli_corrupt",
        "cli_reorder",
        "cli_rate",
    ]

    # Column order of the DataFrame built by read_data_into_pandas.
    _RESULT_COLUMNS = [
        "scenario",
        "protocol",
        "sec_level",
        "kem_alg",
        *_SCENARIO_PARAMETER_COLUMNS,
        "measurements",
        "mean",
        "std",
        "cv",
        "median",
        "qtl_25",
        "qtl_75",
        "qtl_95",
        "qtl_99",
        "iqr",
        "skewness",
        "kurtosis",
    ]

    # Order of the ordered ``kem_alg`` categorical (used e.g. when plots sort by kem_alg).
    _KEM_ALG_ORDER = [
        "secp256r1",
        "secp384r1",
        "secp521r1",
        "x25519",
        "x448",
        "mlkem512",
        "p256_mlkem512",
        "x25519_mlkem512",
        "mlkem768",
        "p384_mlkem768",
        "x448_mlkem768",
        "x25519_mlkem768",
        "p256_mlkem768",
        "mlkem1024",
        "p521_mlkem1024",
        "p384_mlkem1024",
        "bikel1",
        "p256_bikel1",
        "x25519_bikel1",
        "bikel3",
        "p384_bikel3",
        "x448_bikel3",
        "bikel5",
        "p521_bikel5",
        "hqc128",
        "p256_hqc128",
        "x25519_hqc128",
        "hqc192",
        "p384_hqc192",
        "x448_hqc192",
        "hqc256",
        "p521_hqc256",
        "frodo640aes",
        "p256_frodo640aes",
        "x25519_frodo640aes",
        "frodo640shake",
        "p256_frodo640shake",
        "x25519_frodo640shake",
        "frodo976aes",
        "p384_frodo976aes",
        "x448_frodo976aes",
        "frodo976shake",
        "p384_frodo976shake",
        "x448_frodo976shake",
        "frodo1344aes",
        "p521_frodo1344aes",
        "frodo1344shake",
        "p521_frodo1344shake",
    ]


    def _get_all_result_files():
        """Return all file paths below RESULTS_DIR not excluded by FILTER_RESULTS."""
        result_files = []
        for dirpath, dirnames, filenames in os.walk(RESULTS_DIR):
            dirnames.sort()  # deterministic traversal -> deterministic row order
            if any(filter_value in dirpath for filter_value in FILTER_RESULTS):
                continue
            result_files.extend(os.path.join(dirpath, name) for name in sorted(filenames))
        return result_files


    def _rows_from_result_file(csv_result_file_name):
        """Return one output row (a dict) per scenario row covered by a result file."""
        path_parts = os.path.relpath(csv_result_file_name, RESULTS_DIR).split(os.sep)
        if len(path_parts) != 4:
            raise ValueError(
                f"Unexpected result file location {csv_result_file_name!r}: expected "
                "<scenario>/<protocol>/<sec_level>/<kem_alg> below RESULTS_DIR"
            )
        scenario, protocol, sec_level, kem_file_name = path_parts
        kem_alg = kem_file_name.split(".")[0]

        # Transpose so that column i holds the measurements of scenario row i.
        result_file_data = pd.read_csv(csv_result_file_name, header=None).T
        df_scenariofile = pd.read_csv(f"testscenarios/scenario_{scenario}.csv")
        # The scenario file's first column is not used; only named parameters are read.
        df_scenariofile = df_scenariofile.drop(columns=df_scenariofile.columns[0])

        if len(result_file_data.columns) != len(df_scenariofile):
            raise ValueError(
                f"{csv_result_file_name}: {len(result_file_data.columns)} measurement "
                f"series but {len(df_scenariofile)} rows in the scenario file for {scenario!r}"
            )

        rows = []
        for (_, scenario_row), (_, measurement_series) in zip(
            df_scenariofile.iterrows(), result_file_data.items()
        ):
            measurements = np.array(measurement_series.tolist())
            row = {
                "scenario": scenario,
                "protocol": protocol,
                "sec_level": sec_level,
                "kem_alg": kem_alg,
            }
            for column in _SCENARIO_PARAMETER_COLUMNS:
                row[column] = scenario_row[column]
            row["measurements"] = measurements
            row.update(_summarize_measurements(measurements))
            rows.append(row)
        return rows


    def _summarize_measurements(measurements):
        """Return the summary statistics stored next to each measurement array."""
        mean = np.mean(measurements)
        std = np.std(measurements)
        return {
            "mean": mean,
            "std": std,
            "cv": std / mean,
            "median": np.median(measurements),
            "qtl_25": np.quantile(measurements, 0.25),
            "qtl_75": np.quantile(measurements, 0.75),
            "qtl_95": np.quantile(measurements, 0.95),
            "qtl_99": np.quantile(measurements, 0.99),
            "iqr": scipy.stats.iqr(measurements),
            "skewness": scipy.stats.skew(measurements),
            "kurtosis": scipy.stats.kurtosis(measurements),
        }
    
    
    def filter_data(
        data,
        scenario: str | None = None,
        protocol: str | None = None,
        sec_level: str | list[str] | None = None,
        kem_alg: str | None = None,
    ):
        filtered_data = data
        # print(filtered_data["kem_alg"] == "x25519") # is a boolean series
        if scenario is not None:
            filtered_data = filtered_data[filtered_data["scenario"] == scenario]
        if protocol is not None:
            filtered_data = filtered_data[filtered_data["protocol"] == protocol]
        if sec_level is not None:
            if type(sec_level) == list:
                filtered_data = filtered_data[filtered_data["sec_level"].isin(sec_level)]
            else:
                filtered_data = filtered_data[filtered_data["sec_level"] == sec_level]
        if kem_alg is not None:
            filtered_data = filtered_data[filtered_data["kem_alg"] == kem_alg]
    
        def drop_columns_with_only_zero_values(data):
            # this complicated way is necessary, because measurements is a list of values
            filtered_data_without_measurements = data.drop(columns=["measurements"])
            zero_columns_to_drop = (filtered_data_without_measurements != 0).any()
            zero_columns_to_drop = [
                col
                for col in filtered_data_without_measurements.columns
                if not zero_columns_to_drop[col]
            ]
            return data.drop(columns=zero_columns_to_drop)
    
        filtered_data = drop_columns_with_only_zero_values(filtered_data)
    
        # print(filtered_data["measurements"].head())
        # print(filtered_data)
        return filtered_data
    
    
    def plot_median(data):
        """Draw the median plots grouped by security level.

        The per-algorithm variant (plot_median_of_single_algorithm) is currently
        disabled.
        """
        # plot_median_of_single_algorithm(data)
        plot_median_for_sec_level(data)
    
    
    
    def get_x_axis(scenario, data, length):
        """Return the x-axis values used when plotting *scenario*.

        For the network scenarios this is the column of *data* named in the table
        below; for ``"static"`` it is ``[0, 1, ..., length - 1]``. An unknown
        scenario is reported on stderr and terminates the program with status 1.
        """
        if scenario == "static":
            return list(range(length))

        column_for_scenario = {
            "duplicate": "srv_duplicate",
            "packetloss": "srv_pkt_loss",
            "delay": "srv_delay",
            "jitter_delay20": "srv_jitter",
            "corrupt": "srv_corrupt",
            "reorder": "srv_reorder",
            "rate_both": "srv_rate",
            "rate_client": "cli_rate",
            "rate_server": "srv_rate",
        }
        column = column_for_scenario.get(scenario)
        if column is None:
            print(f"NO MATCH FOUND FOR {scenario}", file=sys.stderr)
            sys.exit(1)
        return data[column]
    
    
    
    def plot_median_for_sec_level(data):
        """Plot median time-to-first-byte curves per (scenario, protocol, sec_level).

        Every combination gets one figure in which each KEM algorithm is drawn as a
        line of its medians against the scenario's x-axis values (see get_x_axis).
        Figures are saved as
        ``{PLOTS_DIR}/medians-of-sec-level/median-<scenario>-<protocol>-<sec_level>.png``.
        """
        combos = data[["scenario", "protocol", "sec_level"]].drop_duplicates()
        for _, combo in combos.iterrows():
            scenario = combo["scenario"]
            protocol = combo["protocol"]
            sec_level = combo["sec_level"]
            subset = filter_data(
                data, scenario=scenario, protocol=protocol, sec_level=sec_level
            )

            plt.figure()
            algorithms = subset["kem_alg"].unique().sort_values()
            algorithm_count = len(algorithms)  # spreads the colormap over all lines
            for position, kem_alg in enumerate(algorithms):
                single_alg = filter_data(subset, kem_alg=kem_alg)
                medians = single_alg["median"]
                x_values = get_x_axis(scenario, single_alg, len(medians))
                plt.plot(
                    x_values,
                    medians,
                    linestyle="-",
                    marker=".",
                    color=cmap(position / algorithm_count),
                    label=kem_alg,
                )

            plt.ylim(bottom=0)
            plt.xlim(left=0)
            plt.xlabel(scenario)
            plt.ylabel("Time-to-first-byte (ms)")
            # Legend sits above the axes so it never hides a curve.
            plt.legend(
                bbox_to_anchor=(0.5, 1), loc="lower center", ncol=3, fontsize="small"
            )
            plt.tight_layout()

            file_name = f"median-{scenario}-{protocol}-{sec_level}.png"
            plt.savefig(f"{PLOTS_DIR}/medians-of-sec-level/{file_name}")
            plt.close()
    
    
    def plot_median_of_single_algorithm(data):
        """Plot median and interquartile band for every single KEM algorithm.

        One figure per (scenario, protocol, sec_level, kem_alg) combination is
        written to ``{PLOTS_DIR}/median-of-single-algorithm/``. Static runs are
        skipped. Scenarios without a known x-axis column, or whose x-axis column
        is absent after filtering, are reported on stderr and skipped.
        """
        # Parameter column plotted on the x-axis for each scenario; mirrors
        # get_x_axis and additionally keeps the older "jitter" name.
        x_axis_column_by_scenario = {
            "duplicate": "srv_duplicate",
            "packetloss": "srv_pkt_loss",
            "delay": "srv_delay",
            "jitter": "srv_jitter",
            "jitter_delay20": "srv_jitter",
            "corrupt": "srv_corrupt",
            "reorder": "srv_reorder",
            "rate_both": "srv_rate",
            "rate_client": "cli_rate",
            "rate_server": "srv_rate",
        }

        unique_combinations = data[
            ["scenario", "protocol", "sec_level", "kem_alg"]
        ].drop_duplicates()
        for _, row in unique_combinations.iterrows():
            scenario = row["scenario"]
            if scenario == "static":
                continue
            x_column = x_axis_column_by_scenario.get(scenario)
            if x_column is None:
                # Skip instead of continuing with an undefined or stale ``x``.
                print(f"NO MATCH FOUND FOR {scenario}", file=sys.stderr)
                continue

            filtered_data = filter_data(
                data,
                scenario=scenario,
                protocol=row["protocol"],
                sec_level=row["sec_level"],
                kem_alg=row["kem_alg"],
            )
            if x_column not in filtered_data.columns:
                # filter_data drops columns that are zero in every selected row.
                print(f"NO {x_column} VALUES FOUND FOR {scenario}", file=sys.stderr)
                continue

            x = filtered_data[x_column]
            y = filtered_data["median"]

            plt.figure()
            plt.fill_between(x, filtered_data["qtl_25"], filtered_data["qtl_75"], alpha=0.5)
            plt.plot(x, y, linestyle="-", marker=".")
            plt.ylim(bottom=0)
            plt.xlim(left=0)
            plt.xlabel(scenario)
            plt.ylabel("Time-to-first-byte (ms)")
            plt.title(
                f"Median of {scenario} in {row['protocol']} in {row['sec_level']} with {row['kem_alg']}"
            )

            plt.savefig(
                f"{PLOTS_DIR}/median-of-single-algorithm/median-{scenario}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}.png"
            )
            plt.close()
    
    
    # This does not yet seem like a good idea
    def plot_median_against_iqr(data):
        """Draw median-vs-IQR density plots (hexbin and 2D histogram) of all rows.

        Also prints summary statistics of both columns and the row with the largest
        median. Figures are written to ``{PLOTS_DIR}/median_against_iqr_*.png``.
        """
        plt.figure()
        plt.hexbin(data["median"], data["iqr"], gridsize=50)
        print(data["iqr"].describe())
        print(data["median"].describe())
        # idxmax() returns an index *label*, so it must be looked up with .loc;
        # .iloc would pick the wrong row whenever the index is not 0..n-1.
        max_median_label = data["median"].idxmax()
        print(data.loc[max_median_label])
        plt.savefig(f"{PLOTS_DIR}/median_against_iqr_hexbin.png")
        plt.close()

        plt.figure()
        plt.hist2d(data["median"], data["iqr"], bins=50)
        plt.savefig(f"{PLOTS_DIR}/median_against_iqr_hist2d.png")
        plt.close()
    
    
    # TODO make a violinplot/eventplot for many algos in static scenario
    def plot_static_data(data):
        """Draw plots for the ``static`` scenario into ``{PLOTS_DIR}/static``.

        Only the side-by-side comparison of all KEM algorithms per (protocol,
        security level) is currently enabled; the per-algorithm plots are defined
        but not called.
        """

        def plot_static_data_for_single_algorithms(data):
            """Per (protocol, sec_level, kem_alg): boxplot of the run medians, one
            violin per run, and one violin over the pooled measurements."""
            # Take combinations from static rows only; otherwise the same static
            # data would be redrawn and saved under every other scenario's name.
            static_rows = data[data["scenario"] == "static"]
            unique_combinations = static_rows[
                ["scenario", "protocol", "sec_level", "kem_alg"]
            ].drop_duplicates()
            for _, row in unique_combinations.iterrows():
                filtered_data = filter_data(
                    data,
                    scenario="static",
                    protocol=row["protocol"],
                    sec_level=row["sec_level"],
                    kem_alg=row["kem_alg"],
                )
                name_suffix = (
                    f"{row['scenario']}-{row['protocol']}-{row['sec_level']}-{row['kem_alg']}"
                )

                plt.figure()
                plt.boxplot(filtered_data["median"])
                plt.savefig(
                    os.path.join(
                        PLOTS_DIR, "static", f"boxplot-of-medians-for-{name_suffix}.png"
                    )
                )
                plt.close()

                plt.figure()
                plt.violinplot(filtered_data["measurements"], showmedians=True)
                plt.savefig(
                    os.path.join(
                        PLOTS_DIR, "static", f"multiple-violin-plots-for-{name_suffix}.png"
                    )
                )
                plt.close()

                # Pool the measurements of all runs of this static configuration.
                measurements_flattened = filtered_data["measurements"].explode().tolist()
                plt.figure()
                plt.violinplot(measurements_flattened, showmedians=True)
                plt.savefig(
                    os.path.join(
                        PLOTS_DIR,
                        "static",
                        f"condensed-violin-plot-for-{len(measurements_flattened)}-measurements-of-{name_suffix}.png",
                    )
                )
                plt.close()

        def plot_static_data_for_multiple_algorithms(data):
            """Per (protocol, security level incl. its hybrids): boxplots of the
            static-run medians of every KEM algorithm, side by side."""
            unique_combinations = data[["protocol", "sec_level"]].drop_duplicates()
            for _, row in unique_combinations.iterrows():
                # Hybrid levels map to None and are covered by their plain level.
                sec_level = map_security_level_hybrid_together(row["sec_level"])
                if sec_level is None:
                    continue

                filtered_data = filter_data(
                    data,
                    scenario="static",
                    protocol=row["protocol"],
                    sec_level=sec_level,
                )
                if filtered_data.empty:
                    # No static runs for this protocol/level: nothing to plot.
                    continue

                plt.figure()
                kem_algs = sorted(
                    filtered_data["kem_alg"].unique(),
                    key=cmp_to_key(sort_kem_algorithms),
                )
                for position, kem_alg in enumerate(kem_algs):
                    filtered_data_single_kem_alg = filter_data(
                        filtered_data, kem_alg=kem_alg
                    )
                    plt.boxplot(
                        filtered_data_single_kem_alg["median"],
                        positions=[position],
                        widths=0.6,
                    )

                plt.xticks(range(len(kem_algs)), kem_algs, rotation=45, ha="right")
                plt.xlabel("KEM Algorithms")
                plt.ylabel("Time-to-first-byte (ms)")

                sec_level_string = (
                    sec_level if isinstance(sec_level, str) else "-".join(sec_level)
                )
                plt.tight_layout()
                plt.savefig(
                    os.path.join(
                        PLOTS_DIR,
                        "static",
                        f"boxplots-of-medians-for-static-{row['protocol']}-{sec_level_string}.png",
                    )
                )
                plt.close()

        plot_static_data_for_multiple_algorithms(data)
        # plot_static_data_for_single_algorithms(data)
    
    
    def map_security_level_hybrid_together(sec_level: str):
        """Map a plain security level to itself plus its hybrid counterpart.

        ``"secLevelN"`` (N in 1, 3, 5) yields ``["secLevelN", "secLevelN_hybrid"]``,
        ``"miscLevel"`` is returned unchanged, and anything else (including the
        ``*_hybrid`` levels themselves) yields ``None``.
        """
        if sec_level in ("secLevel1", "secLevel3", "secLevel5"):
            return [sec_level, f"{sec_level}_hybrid"]
        if sec_level == "miscLevel":
            return "miscLevel"
        return None
    
    
    
    def plot_general_plots():
        """Draw the general KEM comparison scatter plots into ``{PLOTS_DIR}/general``.

        Uses the KEM performance table from helper_functions (independent of the
        time-to-first-byte data) and writes bytes-sent-vs-performance and
        public-key-vs-ciphertext-length plots, each with and without hybrids.
        """

        def get_color_for_kem_algo(kem_algo):
            """Return the plot color of the KEM family named in *kem_algo*."""
            if "mlkem" in kem_algo:
                return "blue"
            if "bikel" in kem_algo:
                return "red"
            if "hqc" in kem_algo:
                return "green"
            if "frodo" in kem_algo:
                return "orange"
            return "grey"

        os.makedirs(f"{PLOTS_DIR}/general", mode=0o777, exist_ok=True)

        df = helper_functions.prepare_kem_performance_data_for_plotting(
            helper_functions.get_kem_performance_data()
        )
        print(df)

        def plot_send_bytes_against_kem_performance(df, with_hybrids: bool):
            """Log-log scatter of bytes sent vs. KEM performance (µs)."""
            if not with_hybrids:
                # hybrid names contain "_"; dropping them keeps the plot readable
                df = df[~df["kem_algo"].str.contains("_")]

            plt.figure()
            # One scatter call per algorithm so each keeps its own label; iterating
            # unique names avoids drawing duplicated names several times.
            for kem_algo in df["kem_algo"].unique():
                subset = df[df["kem_algo"] == kem_algo]
                plt.scatter(
                    subset["bytes_sent"],
                    subset["performance_us"],
                    color=get_color_for_kem_algo(kem_algo),
                    label=kem_algo,
                    alpha=0.7,
                )

            for txt, x, y in zip(df["kem_algo"], df["bytes_sent"], df["performance_us"]):
                plt.annotate(
                    txt,
                    (x, y),
                    xytext=(0, -3),
                    textcoords="offset points",
                    fontsize=8,
                    ha="center",
                    va="bottom",
                )

            plt.xscale("log")
            plt.yscale("log")
            plt.xlim(30)
            plt.ylim(30)

            def custom_ticks(start, end):
                # start value followed by every power of ten above it up to end
                return [start] + [
                    10**i for i in range(int(np.log10(start)) + 1, int(np.log10(end)) + 1)
                ]

            x_ticks = custom_ticks(30, df["bytes_sent"].max())
            y_ticks = custom_ticks(30, df["performance_us"].max())
            plt.xticks(x_ticks, [f"{int(x):,}" for x in x_ticks])
            plt.yticks(y_ticks, [f"{int(y):,}" for y in y_ticks])

            plt.xlabel("Bytes sent")
            plt.ylabel("Performance (µs)")

            name = (
                "scatter-of-bytes-sent-against-kem-performance-with-hybrids.png"
                if with_hybrids
                else "scatter-of-bytes-sent-against-kem-performance.png"
            )
            plt.savefig(
                os.path.join(PLOTS_DIR, "general", name),
                dpi=300,
            )
            plt.close()

        def plot_public_key_length_against_ciphertext_length(df, with_hybrids: bool):
            """Log-log scatter of public key length vs. ciphertext length in bytes."""
            if not with_hybrids:
                # hybrid names contain "_"; dropping them keeps the plot readable
                df = df[~df["kem_algo"].str.contains("_")]

            # Frodo's aes and shake variants have the same lengths: keep one and
            # drop the "aes" suffix. assign() returns a new frame instead of writing
            # into a filtered slice (which triggers SettingWithCopyWarning).
            df = df[~df["kem_algo"].str.contains("shake")]
            df = df.assign(kem_algo=df["kem_algo"].str.replace("aes", "", regex=False))

            plt.figure()
            for kem_algo in df["kem_algo"].unique():
                subset = df[df["kem_algo"] == kem_algo]
                plt.scatter(
                    subset["length_public_key"],
                    subset["length_ciphertext"],
                    color=get_color_for_kem_algo(kem_algo),
                    label=kem_algo,
                    alpha=0.7,
                )

            for txt, x, y in zip(
                df["kem_algo"], df["length_public_key"], df["length_ciphertext"]
            ):
                # hand-tuned label offsets to avoid overlapping annotations
                annotate_offset = (0, -3)
                if "bikel1" in txt:
                    annotate_offset = (0, 0)
                if "mlkem1024" in txt:
                    annotate_offset = (0, -7)
                plt.annotate(
                    txt,
                    (x, y),
                    xytext=annotate_offset,
                    textcoords="offset points",
                    fontsize=8,
                    ha="center",
                    va="bottom",
                )

            # reason for these magic numbers in obsidian note [[Packet lengths in QUIC over Ethernet]]
            plt.axvline(x=940, color="purple", linestyle="--", label="1 Paket Grenze")
            plt.axhline(y=277, color="purple", linestyle="--", label="1 Paket Grenze")
            plt.axvline(
                x=940 + 1157, color="purple", linestyle="--", label="2 Paket Grenze"
            )
            plt.axhline(
                y=277 + 1100, color="purple", linestyle="--", label="2 Paket Grenze"
            )

            plt.xscale("log")
            plt.yscale("log")
            plt.xlim(10)
            plt.ylim(10)

            plt.xlabel("Public Key Länge in Bytes")
            plt.ylabel("Ciphertext Länge in Bytes")

            # plain numbers instead of 10^n tick labels on the log axes
            plt.gca().xaxis.set_major_formatter(ticker.ScalarFormatter())
            plt.gca().yaxis.set_major_formatter(ticker.ScalarFormatter())

            # NOTE: "agains" (sic) kept so existing output file names stay stable.
            name = (
                "scatter-of-public-key-agains-ciphertext-length-with-hybrids.png"
                if with_hybrids
                else "scatter-of-public-key-agains-ciphertext-length.png"
            )
            plt.savefig(
                os.path.join(PLOTS_DIR, "general", name),
                dpi=300,
            )
            plt.close()

        plot_send_bytes_against_kem_performance(df, with_hybrids=False)
        plot_send_bytes_against_kem_performance(df, with_hybrids=True)

        plot_public_key_length_against_ciphertext_length(df, with_hybrids=False)
        plot_public_key_length_against_ciphertext_length(df, with_hybrids=True)