Skip to content
Snippets Groups Projects
experiment.py 9.89 KiB
Newer Older
  • Learn to ignore specific revisions
  • #!/usr/bin/env python
    
    
    import csv
    
    import multiprocessing as mp
    
    Johanna Henrich's avatar
    Johanna Henrich committed
    import os
    import subprocess
    import sys
    
    
    import pandas as pd
    
    Johanna Henrich's avatar
    Johanna Henrich committed
    
    
    # Degree of parallelism: the number of worker processes that run timers at
    # the same time. The CPU should have at least this many cores.
    POOL_SIZE = 16

    # Number of network namespaces, each with its own nginx, created by setup.sh
    # (its first argument must equal this value). It should be at least
    # POOL_SIZE so that every worker can hold a namespace concurrently.
    NETWORK_NAMESPACES = 16

    # Every measurement point collects TIMERS * MEASUREMENTS_PER_TIMER samples;
    # keep that product at 200 or more for good results. TIMERS should be a
    # multiple of POOL_SIZE, ideally equal to it.
    TIMERS = 16
    MEASUREMENTS_PER_TIMER = 15

    # Timer executables invoked by time_handshake, selected by protocol.
    S_TIMER = "./s_timer"
    QUIC_S_TIMER = "./quic_s_timer"
    CQUICHE_S_TIMER = "./cquiche_s_timer"

    # Protocol identifiers for cquiche_s_timer, one per congestion-control algorithm.
    CQUICHE_RENO = "cquiche-reno"
    CQUICHE_CUBIC = "cquiche-cubic"
    CQUICHE_BBR = "cquiche-bbr"
    CQUICHE_BBR2 = "cquiche-bbr2"

    # Protocols benchmarked by main(), in this order.
    PROTOCOLS_TO_BENCH = [
        "quic",
        "tlstcp",
        CQUICHE_RENO,
        CQUICHE_CUBIC,
        CQUICHE_BBR,
        CQUICHE_BBR2,
    ]

    # Used when no scenario files are given on the command line.
    DEFAULT_SCENARIO_TO_BENCH = "testscenarios/scenario_static.csv"

    ALGORITHMS_CSV_FILE = "algorithms.csv"
    MAIN_DIR = "results"

    
    Johanna Henrich's avatar
    Johanna Henrich committed
    
    
    # Prefixes of the per-worker network namespaces; set_network_parameters
    # appends _1 .. _NETWORK_NAMESPACES to them.
    SRV_NS = "srv_ns"
    CLI_NS = "cli_ns"

    # Devices inside those namespaces whose root netem qdisc is reconfigured.
    # NOTE(review): SRV_VE is used by set_network_parameters but was not defined
    # anywhere visible; the value follows the CLI_VE naming — confirm against setup.sh.
    SRV_VE = "srv_ve"
    CLI_VE = "cli_ve"
    
    def main():
        """Run the complete benchmark and append the results to CSV files.

        For every scenario file, every protocol in PROTOCOLS_TO_BENCH, every row
        of network parameters in the scenario and every KEM algorithm listed in
        ALGORITHMS_CSV_FILE, the timers are run once and the row
        ``[error_count, *measurements]`` is appended to
        ``MAIN_DIR/<scenario>/<protocol>/<algorithm_class>/<kem_alg>.csv``.
        """
        # Names from the algorithms file that OpenSSL knows under another identifier.
        openssl_names = {
            "x25519_mlkem768": "X25519MLKEM768",
            "p256_mlkem768": "SecP256r1MLKEM768",
        }
        # These cquiche variants are only benchmarked in the scenarios listed
        # here; every other protocol (including cquiche-cubic) runs in all of them.
        restricted_protocols = (CQUICHE_RENO, CQUICHE_BBR, CQUICHE_BBR2)
        restricted_scenarios = ("static", "corrupt", "packetloss")

        shared_condition = mp.Condition()
        # One "in use" flag per network namespace, shared by all pool workers.
        shared_array = mp.Array("b", [False] * NETWORK_NAMESPACES)
        # Leaving the with-block terminates the workers, so they are cleaned up
        # even if something below raises; on success they are closed and joined.
        with mp.Pool(
            processes=POOL_SIZE,
            initializer=init_worker,
            initargs=(shared_condition, shared_array),
        ) as timer_pool:
            scenariofiles = parse_scenariofiles_to_bench()
            algorithms_to_bench_dict = read_algorithms()

            for scenariofile in scenariofiles:
                # The scenario does not depend on the protocol: read it once.
                scenarios = pd.read_csv(scenariofile)
                # The header of the first column names the test scenario.
                testscenario_name = scenarios.columns[0]

                for protocol in PROTOCOLS_TO_BENCH:
                    if (
                        protocol in restricted_protocols
                        and testscenario_name not in restricted_scenarios
                    ):
                        continue

                    make_dirs(testscenario_name, protocol, algorithms_to_bench_dict)
                    # get_emulated_rtt(scenarios)

                    for _, parameters in scenarios.iterrows():
                        # Apply this row's network conditions to all namespaces.
                        set_network_parameters(parameters)

                        for algorithm_class, algorithms in algorithms_to_bench_dict.items():
                            for kem_alg in algorithms:
                                path_to_results_csv_file = os.path.join(
                                    MAIN_DIR,
                                    testscenario_name,
                                    protocol,
                                    algorithm_class,
                                    f"{kem_alg}.csv",
                                )
                                print(path_to_results_csv_file)
                                error_count, result = run_timers(
                                    timer_pool,
                                    protocol,
                                    openssl_names.get(kem_alg, kem_alg),
                                )
                                # Open only after the (long) measurement finished;
                                # newline="" is what the csv module requires.
                                with open(path_to_results_csv_file, "a", newline="") as out:
                                    csv.writer(out).writerow([error_count, *result])

            timer_pool.close()
            timer_pool.join()
    
    
    
    # Pool initializer: publishes the shared synchronisation objects as module
    # globals so that time_handshake, running inside the workers, can use them.
    def init_worker(condition, shared_array):
        """Store *condition* and *shared_array* as this worker process's globals.

        They become ``namespace_condition`` and ``acquired_network_namespaces``.
        """
        global namespace_condition, acquired_network_namespaces
        namespace_condition, acquired_network_namespaces = condition, shared_array
    
    
    
    def parse_scenariofiles_to_bench():
        """Return the list of scenario CSV files to benchmark.

        Every command-line argument is taken as a scenario file path and echoed
        to stdout; without arguments only DEFAULT_SCENARIO_TO_BENCH is used.
        """
        handed_scenarios = sys.argv[1:]
        for scenario_path in handed_scenarios:
            print(scenario_path)
        return handed_scenarios or [DEFAULT_SCENARIO_TO_BENCH]
    
    
    def read_algorithms(path=None):
        """Read the KEM algorithms to benchmark, grouped by algorithm class.

        Each non-empty row of the CSV file has the form
        ``<algorithm_class>,<alg1>,<alg2>,...``.

        Args:
            path: CSV file to read; defaults to ALGORITHMS_CSV_FILE.

        Returns:
            dict mapping each algorithm class (first column) to the list of
            algorithm names in the remaining columns, in file order.
        """
        if path is None:
            path = ALGORITHMS_CSV_FILE
        # newline="" is what the csv module expects for files it parses.
        with open(path, "r", newline="", encoding="utf-8") as csv_file:
            # csv.reader yields [] for blank lines (e.g. a trailing empty line);
            # skip them instead of failing on row[0].
            return {row[0]: row[1:] for row in csv.reader(csv_file) if row}
    
    
    
    def make_dirs(testscenario_name, protocol, algorithms_to_bench_dict, base_dir=None):
        """Create the result directories for one scenario/protocol combination.

        One directory ``<base_dir>/<testscenario_name>/<protocol>/<algorithm_class>``
        is created per key of *algorithms_to_bench_dict*. Existing directories are
        accepted, because results are appended to the same files across runs.

        Args:
            testscenario_name: name of the test scenario.
            protocol: protocol identifier.
            algorithms_to_bench_dict: mapping whose keys are the algorithm classes.
            base_dir: root results directory; defaults to MAIN_DIR.
        """
        if base_dir is None:
            base_dir = MAIN_DIR
        for algorithm_class in algorithms_to_bench_dict:
            os.makedirs(
                os.path.join(base_dir, testscenario_name, protocol, algorithm_class),
                exist_ok=True,
            )


    def set_network_parameters(parameters):
        """Apply one scenario row's netem settings to every namespace pair.

        For each i in 1..NETWORK_NAMESPACES, the server namespace ``<SRV_NS>_i``
        (device SRV_VE) is configured from the ``srv_*`` values of *parameters*,
        and then the client namespace ``<CLI_NS>_i`` (device CLI_VE) is
        configured from the ``cli_*`` values.
        """
        server_columns = (
            "srv_rate",
            "srv_delay",
            "srv_jitter",
            "srv_pkt_loss",
            "srv_duplicate",
            "srv_corrupt",
            "srv_reorder",
        )
        client_columns = (
            "cli_rate",
            "cli_delay",
            "cli_jitter",
            "cli_pkt_loss",
            "cli_duplicate",
            "cli_corrupt",
            "cli_reorder",
        )
        for ns_index in range(1, NETWORK_NAMESPACES + 1):
            # Server side first, then client side; the column order matches
            # change_qdisc's rate/delay/jitter/loss/duplicate/corrupt/reorder.
            sides = (
                (SRV_NS, SRV_VE, server_columns),
                (CLI_NS, CLI_VE, client_columns),
            )
            for ns_prefix, device, columns in sides:
                change_qdisc(
                    f"{ns_prefix}_{ns_index}",
                    device,
                    *(parameters[column] for column in columns),
                )
    
    
    # TODO maybe add slot configuration for WLAN emulation
    def change_qdisc(ns, dev, rate, delay, jitter, pkt_loss, duplicate, corrupt, reorder):
        """Reconfigure the root netem qdisc of *dev* inside network namespace *ns*.

        *rate* is given in mbit, *delay* and *jitter* in ms, and *pkt_loss*,
        *duplicate*, *corrupt* and *reorder* as percentages. The assembled
        command is echoed before it is executed.
        """
        netem_options = [
            # limit 1000: the default number of packets netem may queue, set explicitly.
            "limit", "1000",
            "rate", f"{rate}mbit",
            "delay", f"{delay}ms", f"{jitter}ms",
            "loss", f"{pkt_loss}%",
            "duplicate", f"{duplicate}%",
            "corrupt", f"{corrupt}%",
            "reorder", f"{reorder}%",
        ]
        in_namespace = ["ip", "netns", "exec", ns]
        tc_change = ["tc", "qdisc", "change", "dev", dev, "root", "netem"]
        command = in_namespace + tc_change + netem_options

        print(" > " + " ".join(command))
        run_subprocess(command)
    
    
    
    def run_timers(timer_pool, protocol, kem_alg):
        """Run TIMERS timer tasks for one protocol/algorithm on *timer_pool*.

        Every task calls ``time_handshake(protocol, kem_alg, MEASUREMENTS_PER_TIMER)``,
        which returns ``(error_count, measurements)``.

        Returns:
            Tuple of the total error count over all tasks and all measurements
            flattened into one list, in task order.
        """
        results_nested = timer_pool.starmap(
            time_handshake, [(protocol, kem_alg, MEASUREMENTS_PER_TIMER)] * TIMERS
        )
        error_count_aggregated = sum(error_count for error_count, _ in results_nested)
        measurements = [
            measurement
            for _, timer_measurements in results_nested
            for measurement in timer_measurements
        ]
        return error_count_aggregated, measurements
    
    
    
    # Pool task: time handshakes with the timer program selected by protocol (s_timer.c).

    def time_handshake(protocol, kem_alg, measurements) -> tuple[int, list[float]]:
        """Run one timer process inside a free client network namespace.

        Relies on the globals namespace_condition and acquired_network_namespaces,
        which init_worker installs in each pool worker.

        Args:
            protocol: one of PROTOCOLS_TO_BENCH; selects the timer program and, for
                the cquiche-* protocols, the congestion-control algorithm that is
                appended to the command.
            kem_alg: key-exchange algorithm identifier for the timer.
            measurements: presumably the number of handshakes to time — TODO
                confirm; the command construction that uses it is not visible here.

        Returns:
            (error_count, measurements) parsed from the timer's stdout, which is
            expected to look like "<error_count>;<m1>,<m2>,...".

        Raises:
            ValueError: if protocol is not recognised (or the output cannot be parsed).
        """

        def aquire_network_namespace():
            # Claim the first free namespace and return its 1-based index,
            # blocking on the shared condition while all namespaces are taken.
            with namespace_condition:
                while True:
                    for i in range(1, len(acquired_network_namespaces) + 1):
                        if not acquired_network_namespaces[i - 1]:
                            acquired_network_namespaces[i - 1] = True
                            return i
                    # Sleep until another worker releases a namespace and notifies.
                    namespace_condition.wait()

        def release_network_namespace(i):
            # Mark namespace i (1-based) as free again.
            with namespace_condition:
                acquired_network_namespaces[i - 1] = False
                # Wake one worker that is waiting for a free namespace.
                namespace_condition.notify()

        network_namespace = aquire_network_namespace()

    Bartolomeo Berend Müller's avatar
    Bartolomeo Berend Müller committed


        # Pick the timer binary; the cquiche variants also select a congestion-control algorithm.
        cc_algo = None
        match protocol:
            case "tlstcp":
                program = S_TIMER
            case "quic":
                program = QUIC_S_TIMER
            case "cquiche-reno":
                cc_algo = "reno"
                program = CQUICHE_S_TIMER
            case "cquiche-cubic":
                cc_algo = "cubic"
                program = CQUICHE_S_TIMER
            case "cquiche-bbr":
                cc_algo = "bbr"
                program = CQUICHE_S_TIMER
            case "cquiche-bbr2":
                cc_algo = "bbr2"
                program = CQUICHE_S_TIMER
            case _:
                raise ValueError("Invalid protocol")

        # NOTE(review): the lines that build `command` appear to be missing from this
        # copy; only the client-namespace element below survives. It presumably
        # belongs to an "ip netns exec <ns> <program> ..." argument list — restore it.
            f"{CLI_NS}_{network_namespace}",

        if cc_algo is not None:
            command.append(cc_algo)


        # NOTE(review): if run_subprocess raises, release_network_namespace below is
        # never reached and this namespace stays marked as taken, so other workers can
        # wait for it forever — consider a try/finally around the call.
        result = run_subprocess(command)

        release_network_namespace(network_namespace)

        # Output format: "<error_count>;<comma-separated measurements>".
        error_count, result = result.split(";")
        error_count = int(error_count)
        return error_count, [float(i) for i in result.strip().split(",")]
    
    
    
    def run_subprocess(command, working_dir=".", expected_returncode=0) -> str:
        """Run *command* (an argument list, executed without a shell) and return its stdout.

        Anything the command writes to stderr is echoed, decoded, to this
        process's stderr.

        Args:
            command: program and arguments for subprocess.run.
            working_dir: directory in which to run the command.
            expected_returncode: exit status that counts as success.

        Returns:
            The command's stdout decoded as UTF-8.

        Raises:
            subprocess.CalledProcessError: if the exit status differs from
                *expected_returncode*. An explicit exception is used instead of
                ``assert`` so the check still runs under ``python -O``.
        """
        result = subprocess.run(command, capture_output=True, cwd=working_dir)
        if result.stderr:
            sys.stderr.write(result.stderr.decode("utf-8", errors="replace"))
        if result.returncode != expected_returncode:
            raise subprocess.CalledProcessError(
                result.returncode, command, output=result.stdout, stderr=result.stderr
            )
        return result.stdout.decode("utf-8")
    
    
    
    # Measure the round-trip time actually observed over the emulated link.
    # TODO: decide whether to keep this RTT measurement or delete it; if kept,
    # call it after the network parameters are set and save the result to a
    # separate file.
    # NOTE(review): part of this function's body appears to be missing in this
    # copy (the srv_rate/cli_rate lines below are fragments of a larger
    # expression), and rtt_str is computed but neither returned nor stored.
    def get_emulated_rtt(scenarios):
    
            scenarios.iloc[0]["srv_rate"],
    
            scenarios.iloc[0]["cli_rate"],
    
        rtt_str = get_rtt_ms()
    
    def get_rtt_ms(namespace="cli_ns", destination="10.0.0.1", count=10):
        """Ping *destination* from inside *namespace* and return the average RTT.

        Args:
            namespace: network namespace to ping from.
                NOTE(review): the per-worker client namespaces are named
                "cli_ns_<i>" elsewhere in this script; confirm that a namespace
                called plain "cli_ns" exists.
            destination: address to ping.
            count: number of echo requests to send.

        Returns:
            The average round-trip time in milliseconds, as the string printed by
            ping (e.g. "6.107").

        Raises:
            ValueError: if ping's output contains no min/avg/max summary line.
        """
        command = ["ip", "netns", "exec", namespace, "ping", destination, "-c", str(count)]

        print(" > " + " ".join(command))
        result = run_subprocess(command)
        return _parse_ping_avg_rtt_ms(result)


    def _parse_ping_avg_rtt_ms(ping_output):
        """Return the avg field of ping's "min/avg/max[/mdev] = a/b/c[/d] ms" summary line.

        Handles both iputils ("rtt min/avg/max/mdev = ...") and BusyBox
        ("round-trip min/avg/max = ...") formats.
        """
        for line in reversed(ping_output.splitlines()):
            if "min/avg/max" in line:
                _, _, values = line.partition("=")
                return values.strip().split("/")[1]
        raise ValueError(f"no RTT summary line in ping output: {ping_output!r}")