Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
experiment.py 10.01 KiB
#!/usr/bin/env python

import csv
import multiprocessing as mp
import os
import subprocess
import sys

import pandas as pd

# --- Parallelism and sample size ---------------------------------------------
# Number of worker processes that run timers concurrently. The CPU should have
# at least this many cores.
POOL_SIZE = 16
# Number of network namespace pairs, each with its own nginx running. This must
# match the count given as setup.sh's first argument and be >= POOL_SIZE.
NETWORK_NAMESPACES = 16
# Timer jobs per measurement point, and handshakes timed by each job.
# TIMERS * MEASUREMENTS_PER_TIMER should be at least 200 for good results;
# TIMERS should be a multiple of POOL_SIZE, ideally equal to it.
TIMERS = 16
MEASUREMENTS_PER_TIMER = 15

# --- Protocols ----------------------------------------------------------------
# cquiche variants, one per congestion-control algorithm.
CQUICHE_RENO = "cquiche-reno"
CQUICHE_CUBIC = "cquiche-cubic"
CQUICHE_BBR = "cquiche-bbr"
CQUICHE_BBR2 = "cquiche-bbr2"

PROTOCOLS_TO_BENCH = [
    "quic",
    "tlstcp",
    CQUICHE_RENO,
    CQUICHE_CUBIC,
    CQUICHE_BBR,
    CQUICHE_BBR2,
]

# --- Input/output files and timer binaries ------------------------------------
DEFAULT_SCENARIO_TO_BENCH = "testscenarios/scenario_static.csv"
ALGORITHMS_CSV_FILE = "algorithms.csv"
MAIN_DIR = "results"
S_TIMER = "./s_timer"
QUIC_S_TIMER = "./quic_s_timer"
CQUICHE_S_TIMER = "./cquiche_s_timer"

# --- Namespace and device naming ----------------------------------------------
# Namespace prefixes; the actual namespaces carry a "_<i>" suffix, i = 1..N.
SRV_NS = "srv_ns"
CLI_NS = "cli_ns"
# Device names inside the server / client namespaces.
SRV_VE = "srv_ve"
CLI_VE = "cli_ve"


def main():
    """Benchmark every protocol/KEM combination under every scenario row.

    Scenario CSV files come from the command line (default
    DEFAULT_SCENARIO_TO_BENCH) and the KEMs from ALGORITHMS_CSV_FILE. For
    each scenario row the netem settings are applied to all namespaces.
    Each KEM is then timed by TIMERS parallel jobs, and one row
    ``[error_count, *measurements]`` is appended to
    ``MAIN_DIR/<scenario>/<protocol>/<algorithm class>/<kem>.csv``.
    """
    # Read the inputs first so bad arguments fail before any worker is spawned.
    scenariofiles = parse_scenariofiles_to_bench()
    algorithms_to_bench_dict = read_algorithms()

    shared_condition = mp.Condition()
    # One flag per network namespace: True while a worker is using it.
    shared_array = mp.Array("b", [False] * NETWORK_NAMESPACES)
    timer_pool = mp.Pool(
        processes=POOL_SIZE,
        initializer=init_worker,
        initargs=(shared_condition, shared_array),
    )
    try:
        for scenariofile in scenariofiles:
            _bench_scenario_file(timer_pool, scenariofile, algorithms_to_bench_dict)
    except BaseException:
        # On an error or Ctrl-C, stop the workers instead of waiting for them.
        timer_pool.terminate()
        timer_pool.join()
        raise
    timer_pool.close()
    timer_pool.join()


# Names in algorithms.csv whose identifier for the timer binaries (OpenSSL)
# differs; every other name is passed through unchanged.
_OPENSSL_KEM_NAMES = {
    "x25519_mlkem768": "X25519MLKEM768",
    "p256_mlkem768": "SecP256r1MLKEM768",
}


def _bench_scenario_file(timer_pool, scenariofile, algorithms_to_bench_dict):
    """Benchmark all protocols against every row of one scenario CSV file."""
    # The scenario file does not depend on the protocol, so read it only once.
    scenarios = pd.read_csv(scenariofile)
    # The header of the first column is used as the scenario's name.
    testscenario_name = scenarios.columns[0]
    for protocol in PROTOCOLS_TO_BENCH:
        # The reno/bbr/bbr2 cquiche variants only run on these scenarios.
        # NOTE(review): cquiche-cubic is not in this list and therefore runs on
        # every scenario -- confirm that is intended.
        if protocol in (CQUICHE_RENO, CQUICHE_BBR, CQUICHE_BBR2) and (
            testscenario_name not in ("static", "corrupt", "packetloss")
        ):
            continue
        make_dirs(testscenario_name, protocol, algorithms_to_bench_dict)

        for _, parameters in scenarios.iterrows():
            # Apply this row's network conditions to all namespaces.
            set_network_parameters(parameters)
            print(protocol)
            _bench_algorithms(
                timer_pool, testscenario_name, protocol, algorithms_to_bench_dict
            )


def _bench_algorithms(timer_pool, testscenario_name, protocol, algorithms_to_bench_dict):
    """Time every KEM once under the currently applied network parameters."""
    for algorithm_class, algorithms in algorithms_to_bench_dict.items():
        for kem_alg in algorithms:
            path_to_results_csv_file = os.path.join(
                MAIN_DIR,
                testscenario_name,
                protocol,
                algorithm_class,
                f"{kem_alg}.csv",
            )
            print(path_to_results_csv_file)
            error_count, result = run_timers(
                timer_pool, protocol, _OPENSSL_KEM_NAMES.get(kem_alg, kem_alg)
            )
            # Open only after the run succeeded, so a failed run does not
            # create an empty results file. newline="" as the csv module requires.
            with open(path_to_results_csv_file, "a", newline="") as out:
                csv.writer(out).writerow([error_count, *result])


def init_worker(condition, shared_array):
    """Pool initializer: publish the shared namespace-allocation state.

    Stores *condition* and *shared_array* in the module globals
    ``namespace_condition`` and ``acquired_network_namespaces``. Each pool
    worker process reads them in time_handshake() to acquire and release
    network namespaces.
    """
    global namespace_condition, acquired_network_namespaces
    namespace_condition, acquired_network_namespaces = condition, shared_array


def parse_scenariofiles_to_bench():
    """Return the list of scenario CSV paths to benchmark.

    Every command-line argument is treated as a scenario file and echoed to
    stdout. Without arguments, only DEFAULT_SCENARIO_TO_BENCH is used.
    """
    handed_files = sys.argv[1:]
    if not handed_files:
        return [DEFAULT_SCENARIO_TO_BENCH]
    for scenariofile in handed_files:
        print(scenariofile)
    return handed_files


def read_algorithms(path=None):
    """Load the algorithms to benchmark from a CSV file.

    Each non-blank row has the form ``<algorithm class>,<kem 1>,<kem 2>,...``.
    Cells are stripped of surrounding whitespace. Empty cells, such as those
    produced by a trailing comma, are ignored. Blank lines are skipped; before
    this fix they raised IndexError.

    Args:
        path: CSV file to read. Defaults to ALGORITHMS_CSV_FILE.

    Returns:
        A dict mapping each algorithm class to its list of KEM names, in file
        order. If a class appears on several rows, the last row wins.
    """
    if path is None:
        path = ALGORITHMS_CSV_FILE
    algorithms = {}
    # newline="" is what the csv module expects for files it reads.
    with open(path, "r", newline="") as csv_file:
        for row in csv.reader(csv_file):
            cells = [cell.strip() for cell in row]
            if not any(cells):
                continue  # blank line
            algorithms[cells[0]] = [cell for cell in cells[1:] if cell]
    return algorithms


def make_dirs(testscenario_name, protocol, algorithms_to_bench_dict):
    """Create ``MAIN_DIR/<scenario>/<protocol>/<class>`` for every algorithm class.

    Directories that already exist are left untouched.
    """
    protocol_dir = os.path.join(MAIN_DIR, testscenario_name, protocol)
    for algorithm_class in algorithms_to_bench_dict:
        os.makedirs(os.path.join(protocol_dir, algorithm_class), exist_ok=True)


def set_network_parameters(parameters):
    """Apply one scenario row's netem settings to every namespace pair.

    *parameters* is indexed by ``srv_<field>`` and ``cli_<field>`` keys, where
    <field> is rate, delay, jitter, pkt_loss, duplicate, corrupt or reorder.
    For i = 1..NETWORK_NAMESPACES, the ``srv_*`` values go to SRV_VE in
    ``srv_ns_<i>`` and then the ``cli_*`` values go to CLI_VE in ``cli_ns_<i>``.
    """
    netem_fields = ("rate", "delay", "jitter", "pkt_loss", "duplicate", "corrupt", "reorder")
    sides = ((SRV_NS, SRV_VE, "srv"), (CLI_NS, CLI_VE, "cli"))
    for namespace_index in range(1, NETWORK_NAMESPACES + 1):
        for ns_prefix, device, key_prefix in sides:
            settings = [parameters[f"{key_prefix}_{field}"] for field in netem_fields]
            change_qdisc(f"{ns_prefix}_{namespace_index}", device, *settings)


# TODO maybe add slot configuration for WLAN emulation
def change_qdisc(ns, dev, rate, delay, jitter, pkt_loss, duplicate, corrupt, reorder):
    """Update the root netem qdisc of *dev* inside network namespace *ns*.

    Runs ``ip netns exec <ns> tc qdisc change dev <dev> root netem ...`` and
    echoes the command first. ``change`` edits an existing qdisc, so a netem
    root qdisc is presumably installed beforehand (e.g. by setup.sh -- TODO
    confirm).

    Units as passed to netem: *rate* in mbit, *delay*/*jitter* in ms, and
    *pkt_loss*/*duplicate*/*corrupt*/*reorder* in percent.
    """
    netem_options = [
        # limit 1000 is the default value for the number of packets that can be queued
        ("limit", "1000"),
        ("rate", f"{rate}mbit"),
        ("delay", f"{delay}ms", f"{jitter}ms"),
        ("loss", f"{pkt_loss}%"),
        ("duplicate", f"{duplicate}%"),
        ("corrupt", f"{corrupt}%"),
        ("reorder", f"{reorder}%"),
    ]
    command = ["ip", "netns", "exec", ns, "tc", "qdisc", "change", "dev", dev, "root", "netem"]
    for option in netem_options:
        command.extend(option)

    print(f" > {' '.join(command)}")
    run_subprocess(command)


def run_timers(timer_pool, protocol, kem_alg):
    """Run TIMERS time_handshake jobs on *timer_pool* and merge their results.

    Each job times MEASUREMENTS_PER_TIMER handshakes. Returns a tuple of the
    summed error count and one flat list of all measurements, in job order.
    """
    job_args = [(protocol, kem_alg, MEASUREMENTS_PER_TIMER) for _ in range(TIMERS)]
    per_job = timer_pool.starmap(time_handshake, job_args)

    total_errors = 0
    all_measurements = []
    for job_errors, job_measurements in per_job:
        total_errors += job_errors
        all_measurements.extend(job_measurements)
    return total_errors, all_measurements


# Time handshakes with one of the timer binaries (s_timer.c and its variants).
def time_handshake(protocol, kem_alg, measurements) -> tuple[int, list[float]]:
    """Run one timer binary inside a free client network namespace.

    Runs in a pool worker and relies on the globals set by init_worker().

    Args:
        protocol: One of PROTOCOLS_TO_BENCH. Selects the timer binary and, for
            the cquiche variants, the congestion-control argument.
        kem_alg: Key-exchange identifier, passed verbatim to the binary.
        measurements: Number of handshakes the binary should time.

    Returns:
        ``(error_count, timings)`` parsed from the binary's stdout, which must
        look like ``"<errors>;<t1>,<t2>,..."``. The timings list may be empty.

    Raises:
        ValueError: If *protocol* is unknown or the output cannot be parsed.
    """

    def acquire_network_namespace():
        """Block until a namespace is free, mark it used, return its 1-based index."""
        with namespace_condition:
            while True:
                for slot in range(len(acquired_network_namespaces)):
                    if not acquired_network_namespaces[slot]:
                        acquired_network_namespaces[slot] = True
                        return slot + 1
                # Sleep until another worker releases a namespace.
                namespace_condition.wait()

    def release_network_namespace(index):
        """Mark namespace *index* free and wake one waiting worker."""
        with namespace_condition:
            acquired_network_namespaces[index - 1] = False
            namespace_condition.notify()

    # Resolve the protocol *before* taking a namespace, so that an invalid
    # protocol cannot leave a namespace marked as busy.
    timer_by_protocol = {
        "tlstcp": (S_TIMER, None),
        "quic": (QUIC_S_TIMER, None),
        CQUICHE_RENO: (CQUICHE_S_TIMER, "reno"),
        CQUICHE_CUBIC: (CQUICHE_S_TIMER, "cubic"),
        CQUICHE_BBR: (CQUICHE_S_TIMER, "bbr"),
        CQUICHE_BBR2: (CQUICHE_S_TIMER, "bbr2"),
    }
    if protocol not in timer_by_protocol:
        raise ValueError("Invalid protocol")
    program, cc_algo = timer_by_protocol[protocol]

    network_namespace = acquire_network_namespace()
    try:
        command = [
            "ip",
            "netns",
            "exec",
            f"{CLI_NS}_{network_namespace}",
            program,
            kem_alg,
            str(measurements),
        ]
        if cc_algo is not None:
            command.append(cc_algo)
        output = run_subprocess(command)
    finally:
        # Always give the namespace back. Otherwise a failed run leaves it
        # marked busy, and other workers can block forever in wait().
        release_network_namespace(network_namespace)

    error_field, separator, timings_field = output.partition(";")
    if not separator:
        raise ValueError(f"Unexpected timer output: {output!r}")
    error_count = int(error_field)
    # Skip empty entries so an output with no timings yields [] instead of
    # failing on float("").
    timings = [float(value) for value in timings_field.strip().split(",") if value.strip()]
    return error_count, timings


def run_subprocess(command, working_dir=".", expected_returncode=0) -> str:
    """Run *command* (an argv list, no shell) and return its stdout as text.

    Anything the command writes to stderr is echoed, decoded, to our stderr.

    Args:
        command: Argument list passed to subprocess.run.
        working_dir: Working directory for the child process.
        expected_returncode: Exit status that counts as success.

    Returns:
        The command's stdout, decoded as UTF-8.

    Raises:
        subprocess.CalledProcessError: If the exit status differs from
            *expected_returncode*. This replaces a bare ``assert``, which was
            stripped under ``python -O`` and gave no diagnostics.
    """
    result = subprocess.run(command, capture_output=True, cwd=working_dir)
    if result.stderr:
        # Decode so the message is readable rather than a bytes repr.
        print(result.stderr.decode("utf-8", errors="replace").rstrip("\n"), file=sys.stderr)
    if result.returncode != expected_returncode:
        raise subprocess.CalledProcessError(
            result.returncode, command, output=result.stdout, stderr=result.stderr
        )
    return result.stdout.decode("utf-8")


# TODO think about what to do with this RTT calculation, maybe just delete it.
# Maybe call it after the network parameters are set and save the result to a
# separate file.
def get_emulated_rtt(scenarios):
    """Print the average RTT produced by the first row of *scenarios*.

    Only that row's rate and delay are applied. Jitter, loss, duplication,
    corruption and reordering are set to 0. The settings go to the first
    namespace pair (``srv_ns_1`` / ``cli_ns_1``), and the RTT is then measured
    with get_rtt_ms(). Those namespaces' netem settings are not restored
    afterwards.

    The previous version targeted ``srv_ns`` / ``cli_ns`` without the
    ``_<i>`` suffix that the rest of this file uses for namespace names.
    """
    first_row = scenarios.iloc[0]
    change_qdisc(
        f"{SRV_NS}_1",
        SRV_VE,
        first_row["srv_rate"],
        first_row["srv_delay"],
        0,
        0,
        0,
        0,
        0,
    )
    change_qdisc(
        f"{CLI_NS}_1",
        CLI_VE,
        first_row["cli_rate"],
        first_row["cli_delay"],
        0,
        0,
        0,
        0,
        0,
    )
    # NOTE(review): get_rtt_ms() must ping from this same namespace pair; confirm.
    rtt_str = get_rtt_ms()
    print(f"Emulated average RTT: {rtt_str} ms")


def get_rtt_ms(namespace=None, target="10.0.0.1", count=10):
    """Ping *target* from *namespace* and return the average RTT in ms as a string.

    Args:
        namespace: Network namespace to ping from. Defaults to the first
            client namespace ``cli_ns_1``. The previous hard-coded ``cli_ns``
            lacked the ``_<i>`` suffix that the rest of this file uses.
        target: Address to ping. NOTE(review): 10.0.0.1 is assumed to be the
            server end of the namespace pair -- confirm against setup.sh.
        count: Number of echo requests to send.

    Raises:
        ValueError: If the output has no ``rtt``/``round-trip`` summary line.
    """
    if namespace is None:
        namespace = f"{CLI_NS}_1"
    command = ["ip", "netns", "exec", namespace, "ping", target, "-c", str(count)]

    print(" > " + " ".join(command))
    result = run_subprocess(command)

    # The summary line looks like
    # "rtt min/avg/max/mdev = 5.978/6.107/6.277/0.093 ms".
    # Search for it instead of assuming it is the last line.
    for line in reversed(result.splitlines()):
        if line.startswith(("rtt ", "round-trip ")):
            _, _, values = line.partition("=")
            return values.split("/")[1].strip()
    raise ValueError(f"No RTT summary in ping output: {result!r}")


# Guard the entry point. multiprocessing's "spawn"/"forkserver" start methods
# re-import this module in every worker, and without the guard each worker
# would call main() again. The guard also makes importing the module side-effect free.
if __name__ == "__main__":
    main()