    #!/usr/bin/env python
    
    
    from multiprocessing import Pool
    
    import networkmgmt
    
    import os
    import subprocess
    import sys
    
    
    import csv
    import pandas as pd
    
    # Number of worker processes in the multiprocessing pool created in main().
    POOL_SIZE = 7
    # Measurement count handed to the timer program on every invocation.
    MEASUREMENTS_PER_TIMER = 20
    # Number of timer invocations submitted to the pool per run_timers() call.
    TIMERS = 10
    
    
    
    # Protocols iterated by main(); "quic" selects QUIC_S_TIMER, any other value S_TIMER.
    PROTOCOLS_TO_BENCH = ["quic", "tlstcp"]
    # Scenario file used when no scenario files are given on the command line.
    DEFAULT_SCENARIO_TO_BENCH = "testscenarios/scenario_static.csv"
    
    # CSV read by read_algorithms(): the first field of a row is the key (used as
    # the algorithm class in main()), the remaining fields are its algorithms.
    ALGORITHMS_CSV_FILE = "algorithms.csv"
    # Root directory under which the result CSV files are written.
    MAIN_DIR = "results"
    
    # Timer executables launched by time_handshake().
    S_TIMER = "./s_timer"
    QUIC_S_TIMER = "./quic_s_timer"
    # Value exported as LD_LIBRARY_PATH to the timer executables.
    LD_LIBRARY_PATH = "../tmp/.local/openssl/lib64"
    
    
    
    # Identifiers passed as the first two arguments of networkmgmt.change_qdisc()
    # for the server and the client side.
    # NOTE(review): presumably network-namespace (*_NS) and veth-device (*_VE)
    # names -- confirm against networkmgmt.
    SRV_NS = "srv_ns"
    SRV_VE = "srv_ve"
    CLI_NS = "cli_ns"
    CLI_VE = "cli_ve"
    
    def _apply_qdisc(namespace, device, prefix, parameters):
        """Apply one side's network settings from a scenario row.

        Args:
            namespace: First argument forwarded to networkmgmt.change_qdisc().
            device: Second argument forwarded to networkmgmt.change_qdisc().
            prefix: Column prefix of the side to configure ("srv" or "cli").
            parameters: Scenario row providing the "<prefix>_pkt_loss",
                "<prefix>_delay", "<prefix>_jitter", "<prefix>_duplicate",
                "<prefix>_corrupt", "<prefix>_reorder" and "<prefix>_rate"
                columns.
        """
        networkmgmt.change_qdisc(
            namespace,
            device,
            parameters[f"{prefix}_pkt_loss"],
            parameters[f"{prefix}_delay"],
            parameters[f"{prefix}_jitter"],
            parameters[f"{prefix}_duplicate"],
            parameters[f"{prefix}_corrupt"],
            parameters[f"{prefix}_reorder"],
            parameters[f"{prefix}_rate"],
        )


    def main():
        """Run every scenario/protocol/algorithm combination and record timings.

        For each scenario file and each protocol in PROTOCOLS_TO_BENCH, every
        scenario row is applied to both sides via networkmgmt.change_qdisc().
        Then each algorithm from the algorithms CSV is measured with
        run_timers() and the flattened measurements are appended as one CSV row
        to MAIN_DIR/<scenario>/<protocol>/<algorithm class>/<algorithm>.csv.
        The scenario name is the first column header of the scenario file.
        """
        scenariofiles = parse_scenariofiles_to_bench()
        algorithms_to_bench_dict = read_algorithms()

        # The context manager terminates the workers if anything below raises,
        # so a failed run does not leave pool processes behind.
        with Pool(processes=POOL_SIZE) as timer_pool:
            for scenariofile in scenariofiles:
                # The scenario table does not depend on the protocol: read once.
                scenarios = pd.read_csv(scenariofile)
                testscenario_name = scenarios.columns[0]

                for protocol in PROTOCOLS_TO_BENCH:
                    make_dirs(testscenario_name, protocol, algorithms_to_bench_dict)
                    # get_emulated_rtt(scenarios)

                    for _, parameters in scenarios.iterrows():
                        # set network parameters of scenario
                        _apply_qdisc(SRV_NS, SRV_VE, "srv", parameters)
                        _apply_qdisc(CLI_NS, CLI_VE, "cli", parameters)

                        for algorithm_class, algorithms in algorithms_to_bench_dict.items():
                            for kem_alg in algorithms:
                                path_to_results_csv_file = os.path.join(
                                    MAIN_DIR,
                                    testscenario_name,
                                    protocol,
                                    algorithm_class,
                                    f"{kem_alg}.csv",
                                )
                                print(path_to_results_csv_file)
                                # Measure first so a failing run does not touch
                                # the results file.
                                result = run_timers(timer_pool, protocol, kem_alg)
                                # newline="" is what the csv module expects for
                                # files handed to csv.writer.
                                with open(
                                    path_to_results_csv_file, "a", newline=""
                                ) as out:
                                    csv.writer(out).writerow(result)

            # Orderly shutdown on success: no more work, wait for the workers.
            timer_pool.close()
            timer_pool.join()
    
    
    def parse_scenariofiles_to_bench():
        """Return the scenario files to benchmark.

        Every command-line argument after the program name is treated as a
        scenario file and echoed to stdout. Without arguments the list holds
        only DEFAULT_SCENARIO_TO_BENCH.
        """
        handed_scenarios = sys.argv[1:]
        if not handed_scenarios:
            return [DEFAULT_SCENARIO_TO_BENCH]

        for handed_scenario in handed_scenarios:
            print(handed_scenario)
        return handed_scenarios
    
    
    def read_algorithms(csv_path=None):
        """Read the algorithms to benchmark from a CSV file.

        Args:
            csv_path: File to read; defaults to ALGORITHMS_CSV_FILE when None.

        Returns:
            A dict mapping the first field of each row to the list of the
            remaining fields of that row. Blank lines are skipped.
        """
        if csv_path is None:
            csv_path = ALGORITHMS_CSV_FILE
        # newline="" lets the csv module handle line endings itself.
        with open(csv_path, "r", newline="") as csv_file:
            # csv.reader yields [] for blank lines; row[0] would raise on those.
            return {row[0]: row[1:] for row in csv.reader(csv_file) if row}
    
    
    
    def make_dirs(testscenario_name, protocol, algorithms_to_bench_dict):
        """Create the result directory of every algorithm class.

        One directory MAIN_DIR/<testscenario_name>/<protocol>/<algorithm class>
        is created per key of algorithms_to_bench_dict, including missing
        parents.

        Args:
            testscenario_name: Scenario name used as the first path component.
            protocol: Protocol name used as the second path component.
            algorithms_to_bench_dict: Mapping whose keys are algorithm classes.
        """
        for algorithm_class in algorithms_to_bench_dict:
            os.makedirs(
                os.path.join(MAIN_DIR, testscenario_name, protocol, algorithm_class),
                # Results are appended across runs, so existing directories
                # are expected and must not raise.
                exist_ok=True,
            )
    
    def run_timers(timer_pool, protocol, kem_alg):
        """Run TIMERS handshake timers on the pool and flatten their results.

        Each pool task calls time_handshake(protocol, kem_alg,
        MEASUREMENTS_PER_TIMER); the per-task result lists are concatenated in
        task order into one flat list.
        """
        timer_args = (protocol, kem_alg, MEASUREMENTS_PER_TIMER)
        per_timer_results = timer_pool.starmap(time_handshake, [timer_args] * TIMERS)

        flattened = []
        for measurements in per_timer_results:
            flattened.extend(measurements)
        return flattened
    
    
    # do TLS handshake (s_timer.c)
    
    def time_handshake(protocol, kem_alg, measurements) -> list[float]:
        """Run the timer program in the client namespace and parse its output.

        Args:
            protocol: "quic" selects QUIC_S_TIMER; any other value S_TIMER.
            kem_alg: Algorithm name passed to the timer program.
            measurements: Measurement count passed to the timer program.

        Returns:
            The comma-separated values printed by the timer program, as floats.

        Raises:
            ValueError: If a field of the program output is not a float.
        """
        program = QUIC_S_TIMER if protocol == "quic" else S_TIMER
        command = [
            "ip",
            "netns",
            "exec",
            # Same namespace that main() configures via change_qdisc(); use the
            # shared constant instead of repeating the literal.
            CLI_NS,
            "env",
            f"LD_LIBRARY_PATH={LD_LIBRARY_PATH}",
            program,
            kem_alg,
            str(measurements),
        ]

        result = run_subprocess(command)
        return [float(i) for i in result.strip().split(",")]
    
    
    def run_subprocess(command, working_dir=".", expected_returncode=0) -> str:
        """Run a command and return its standard output as text.

        Args:
            command: Argument list passed to subprocess.run().
            working_dir: Working directory of the child process.
            expected_returncode: Return code the command must exit with.

        Returns:
            The captured stdout decoded as UTF-8.

        Raises:
            AssertionError: If the return code differs from expected_returncode.
        """
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,  # puts stdout in result.stdout
            stderr=subprocess.PIPE,
            cwd=working_dir,
        )
        if result.stderr:
            # Decode so the diagnostics are readable instead of a bytes repr.
            print(result.stderr.decode("utf-8", errors="replace"))
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # the exception type stays AssertionError for existing callers.
        if result.returncode != expected_returncode:
            raise AssertionError(
                f"command {command!r} exited with return code "
                f"{result.returncode}, expected {expected_returncode}"
            )
        return result.stdout.decode("utf-8")
    
    
    
    # TODO think about what to do with this rtt calculation
    # maybe put it below the setting of the network parameters and save to a different file
    # To get actual (emulated) RTT
    def get_emulated_rtt(scenarios):
        """Print the RTT measured with only delay and rate of the first scenario row.

        Both sides are reconfigured through networkmgmt.change_qdisc() using the
        first row's delay and rate, with every other numeric argument set to 0.
        The value returned by networkmgmt.get_rtt_ms() is then printed.
        """
        first_row = scenarios.iloc[0]

        # Server side first, then client side -- same order as the main loop.
        srv_delay = first_row["srv_delay"]
        srv_rate = first_row["srv_rate"]
        networkmgmt.change_qdisc(SRV_NS, SRV_VE, 0, srv_delay, 0, 0, 0, 0, srv_rate)

        cli_delay = first_row["cli_delay"]
        cli_rate = first_row["cli_rate"]
        networkmgmt.change_qdisc(CLI_NS, CLI_VE, 0, cli_delay, 0, 0, 0, 0, cli_rate)

        rtt_str = networkmgmt.get_rtt_ms()
        print(f"Emulated average RTT: {rtt_str} ms")