Skip to content
Snippets Groups Projects
experiment.py 4.74 KiB
Newer Older
  • Learn to ignore specific revisions
  • #!/usr/bin/env python
    
    
    Johanna Henrich's avatar
    Johanna Henrich committed
    from multiprocessing import Pool
    
    import networkmgmt
    
    Johanna Henrich's avatar
    Johanna Henrich committed
    import os
    import subprocess
    import sys
    
    
    import csv
    import pandas as pd
    
    # Number of worker processes in the multiprocessing pool created by main().
    POOL_SIZE = 7
    # Measurement count handed to each ./s_timer invocation (see time_handshake()).
    MEASUREMENTS_PER_TIMER = 20
    # Number of time_handshake() tasks submitted to the pool per run_timers() call.
    TIMERS = 10
    
    Johanna Henrich's avatar
    Johanna Henrich committed
    
    
    # Scenario file used when no scenario paths are given on the command line.
    DEFAULT_SCENARIO_TO_TEST = "testscenarios/scenario_static.csv"
    # CSV read by read_algorithms(): the first cell of a row is the algorithm
    # class, the remaining cells are the algorithms of that class.
    ALGORITHMS_CSV_FILE = "algorithms.csv"
    # Root of the result tree:
    # <MAIN_DIR>/<scenario name>/<algorithm class>/<algorithm>.csv
    MAIN_DIR = "results"
    
    Johanna Henrich's avatar
    Johanna Henrich committed
    
    
    # First two arguments handed to networkmgmt.change_qdisc() for the server
    # and the client side. Presumably the network namespace (*_NS) and the
    # veth device (*_VE) names -- TODO confirm against networkmgmt.
    SRV_NS = "srv_ns"
    SRV_VE = "srv_ve"
    CLI_NS = "cli_ns"
    CLI_VE = "cli_ve"
    
    def main():
        """Run the handshake benchmark for every scenario file and algorithm.

        For each scenario file, the delay and rate of its first row are applied
        (every other impairment set to 0) and the emulated RTT is printed. Then
        each row of the scenario is applied in turn and, for every algorithm
        returned by read_algorithms(), one row of timing results is appended to
        ``<MAIN_DIR>/<scenario name>/<algorithm class>/<algorithm>.csv``.

        The worker pool is always shut down, also when a step raises.
        """
        scenariofiles = parse_scenariofiles_to_bench()
        algorithms_to_bench_dict = read_algorithms()

        timer_pool = Pool(processes=POOL_SIZE)
        try:
            for scenariofile in scenariofiles:
                scenarios = pd.read_csv(scenariofile)
                # The header of the first column doubles as the scenario name.
                testscenario_name = scenarios.columns[0]
                make_dirs(testscenario_name, algorithms_to_bench_dict)

                # To get the actual (emulated) RTT, apply only delay and rate of
                # the first row. Positional order matches the per-row calls
                # below: loss, delay, jitter, duplicate, corrupt, reorder, rate.
                baseline = scenarios.iloc[0]
                networkmgmt.change_qdisc(
                    SRV_NS,
                    SRV_VE,
                    0,
                    baseline["srv_delay"],
                    0,
                    0,
                    0,
                    0,
                    baseline["srv_rate"],
                )
                networkmgmt.change_qdisc(
                    CLI_NS,
                    CLI_VE,
                    0,
                    baseline["cli_delay"],
                    0,
                    0,
                    0,
                    0,
                    baseline["cli_rate"],
                )
                rtt_str = networkmgmt.get_rtt_ms()
                print(
                    f"Emulated average RTT: {rtt_str} ms",
                )

                # Set the network parameters of each scenario row in turn.
                for _, parameters in scenarios.iterrows():
                    networkmgmt.change_qdisc(
                        SRV_NS,
                        SRV_VE,
                        parameters["srv_pkt_loss"],
                        parameters["srv_delay"],
                        parameters["srv_jitter"],
                        parameters["srv_duplicate"],
                        parameters["srv_corrupt"],
                        parameters["srv_reorder"],
                        parameters["srv_rate"],
                    )
                    networkmgmt.change_qdisc(
                        CLI_NS,
                        CLI_VE,
                        parameters["cli_pkt_loss"],
                        parameters["cli_delay"],
                        parameters["cli_jitter"],
                        parameters["cli_duplicate"],
                        parameters["cli_corrupt"],
                        parameters["cli_reorder"],
                        parameters["cli_rate"],
                    )

                    for algorithm_class, algorithms in algorithms_to_bench_dict.items():
                        for kem_alg in algorithms:
                            path_to_results_csv_file = os.path.join(
                                MAIN_DIR,
                                testscenario_name,
                                algorithm_class,
                                f"{kem_alg}.csv",
                            )
                            print(path_to_results_csv_file)
                            # Measure first so a failed run does not leave a
                            # touched/empty result file behind.
                            result = run_timers(kem_alg, timer_pool)
                            # TODO (from the original author): optionally prefix
                            # each row with its row name; change everywhere else
                            # also.
                            with open(path_to_results_csv_file, "a", newline="") as out:
                                csv.writer(out).writerow(result)
        except BaseException:
            # Do not leave worker processes behind when the benchmark aborts.
            timer_pool.terminate()
            raise
        else:
            timer_pool.close()
        finally:
            timer_pool.join()
    
    
    def parse_scenariofiles_to_bench():
        """Return the list of scenario files to benchmark.

        Every command-line argument after the program name is taken as a
        scenario file path and echoed to stdout. Without any argument the
        list contains only DEFAULT_SCENARIO_TO_TEST.
        """
        handed_scenarios = sys.argv[1:]  # slice -> fresh list, argv untouched
        if not handed_scenarios:
            return [DEFAULT_SCENARIO_TO_TEST]
        for scenariofile in handed_scenarios:
            print(scenariofile)
        return handed_scenarios
    
    
    def read_algorithms(csv_path=None):
        """Read the algorithms to benchmark, grouped by algorithm class.

        Args:
            csv_path: CSV file to read; defaults to ALGORITHMS_CSV_FILE.

        Returns:
            A dict mapping the first cell of each row (the algorithm class) to
            the list of the remaining cells (the algorithms of that class).
            Empty lines in the file are skipped.
        """
        if csv_path is None:
            csv_path = ALGORITHMS_CSV_FILE
        # newline="" lets the csv module do its own newline handling.
        with open(csv_path, "r", newline="") as csv_file:
            # csv.reader yields [] for a blank line; row[0] would raise there.
            return {row[0]: row[1:] for row in csv.reader(csv_file) if row}
    
    
    def make_dirs(testscenario_name, algorithms_to_bench_dict):
        """Create one result directory per algorithm class for a scenario.

        The directories are ``<MAIN_DIR>/<testscenario_name>/<algorithm class>``;
        directories that already exist are left as they are.
        """
        scenario_root = os.path.join(MAIN_DIR, testscenario_name)
        for class_name in algorithms_to_bench_dict:
            class_dir = os.path.join(scenario_root, class_name)
            os.makedirs(class_dir, exist_ok=True)
    
    
    def run_timers(kex_alg, timer_pool):
        """Time handshakes for *kex_alg* in parallel and return one flat list.

        TIMERS time_handshake() tasks, each asked for MEASUREMENTS_PER_TIMER
        measurements, are run on *timer_pool*; their per-task result lists are
        concatenated in task order.
        """
        task_args = [(kex_alg, MEASUREMENTS_PER_TIMER) for _ in range(TIMERS)]
        measurements = []
        for batch in timer_pool.starmap(time_handshake, task_args):
            measurements.extend(batch)
        return measurements
    
    
    def time_handshake(kex_alg, measurements) -> list[float]:
        """Do TLS handshakes with ./s_timer (s_timer.c) and return its timings.

        Runs ``./s_timer <kex_alg> <measurements>`` via ``ip netns exec`` in the
        CLI_NS namespace and parses its comma-separated standard output.

        Args:
            kex_alg: Algorithm name passed to s_timer as first argument.
            measurements: Count passed to s_timer as second argument.

        Returns:
            The values printed by s_timer, converted to float.

        Raises:
            ValueError: If a field of the output is not a valid float.
        """
        # Use the shared CLI_NS constant instead of repeating the literal.
        command = ["ip", "netns", "exec", CLI_NS, "./s_timer", kex_alg, str(measurements)]
        result = run_subprocess(command)
        return [float(i) for i in result.strip().split(",")]
    
    
    def run_subprocess(command, working_dir=".", expected_returncode=0) -> str:
        """Run *command* and return its standard output decoded as UTF-8.

        The child inherits the current environment with LD_LIBRARY_PATH
        replaced by the local OpenSSL build directory. Anything the child
        writes to stderr is echoed to this process' stdout.

        Args:
            command: Argument list handed to subprocess.run().
            working_dir: Working directory of the child process.
            expected_returncode: Exit status the child must finish with.

        Returns:
            The child's stdout as text.

        Raises:
            AssertionError: If the exit status differs from
                *expected_returncode*.
        """
        env = os.environ.copy()
        # NOTE(review): relative path, so it depends on the child's working
        # directory -- confirm this is intended when working_dir is not ".".
        env["LD_LIBRARY_PATH"] = "../tmp/.local/openssl/lib64"
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,  # puts stdout in result.stdout
            stderr=subprocess.PIPE,
            cwd=working_dir,
            env=env,
            check=False,  # the return code is compared explicitly below
        )
        if result.stderr:
            # Decode so the message is readable instead of a bytes repr.
            print(result.stderr.decode("utf-8", errors="replace"))
        # Raised explicitly: a bare `assert` is stripped when running with -O,
        # which would let a failed command pass silently.
        if result.returncode != expected_returncode:
            raise AssertionError(
                f"{command!r} exited with {result.returncode}, "
                f"expected {expected_returncode}"
            )
        return result.stdout.decode("utf-8")
    
    
    # Guard the entry point: multiprocessing may re-import this module in its
    # worker processes (spawn/forkserver start methods), which must not start
    # the benchmark again.
    if __name__ == "__main__":
        main()