#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

extern "C"
{
#include <sched.h>
#include <unistd.h>
}

// Number of midpoint-rule strips used to integrate 4 / (1 + x^2) over [0, 1] (~4.4e9).
// constexpr forces compile-time evaluation, so on a platform whose `long` is only 32 bits
// the overflowing product is a compile error instead of silent undefined behaviour.
// (Uppercase `L` suffix: a lowercase `l` is easily misread as the digit 1.)
constexpr long n = 42L * 1024 * 1024 * 100;
// Width of a single strip.
constexpr double h = 1.0 / static_cast<double>(n);

// Every CPU id sampled by print_thread_stats(). Worker threads insert into the set
// concurrently, so writers must hold used_cpu_ids_mtx while modifying it.
std::mutex used_cpu_ids_mtx{};
std::set<int> used_cpu_ids{};


// Mutex that serialises writes to std::cout across threads.
std::mutex iomtx;

// IO_SYNC(statements...) runs `statements` while holding iomtx, so output produced by
// concurrently running threads is not interleaved. The lock_guard releases iomtx when the
// block ends.
// - Wrapped in do { } while (0) so that `IO_SYNC(...);` is exactly one statement and is safe
//   inside an unbraced if/else. A bare { } block followed by ';' would orphan the 'else'.
// - Variadic, so the statements may contain unparenthesised commas, e.g. brace
//   initialisers or template argument lists.
// Note: a `break`/`continue` passed in the statements would bind to the internal do-while.
#define IO_SYNC(...)                                   \
    do                                                 \
    {                                                  \
        std::lock_guard<std::mutex> iolock(iomtx);     \
        __VA_ARGS__                                    \
    } while (0)

// Return a monotonic timestamp in microseconds, meant for measuring elapsed time by
// subtracting two readings. std::chrono::steady_clock is used because it never goes
// backwards. high_resolution_clock may instead be an alias of system_clock, which jumps
// when the wall clock is adjusted. The epoch is unspecified, so only differences between
// two readings are meaningful.
uint64_t now_micro()
{
    using namespace std::chrono;
    const auto since_epoch = steady_clock::now().time_since_epoch();
    return static_cast<uint64_t>(duration_cast<microseconds>(since_epoch).count());
}

// Print stats about the calling thread to std::cout, without a trailing newline:
// process id, thread id and the CPU id the thread is currently executing on.
// A valid CPU id is also recorded in used_cpu_ids (under used_cpu_ids_mtx) so main() can
// later report every core that was observed.
//
// NOTE: std::cout itself is written without taking iomtx. Callers that may run concurrently
// must wrap the call in IO_SYNC.
void print_thread_stats()
{
    // sched_getcpu() returns -1 on failure. Record only real CPU ids, otherwise main()
    // would later try to look up the siblings of a nonexistent "cpu-1".
    const int cpu = sched_getcpu();
    if (cpu >= 0)
    {
        std::lock_guard<std::mutex> lock(used_cpu_ids_mtx);
        used_cpu_ids.insert(cpu);
    }

    std::cout
        // Process ID of this process
        << "pid = " << getpid()
        // Thread ID of this thread
        << ", thread_id = " << std::this_thread::get_id()
        // CPU id this thread is currently executing on (-1 if it could not be determined)
        << ", cpu_id = " << cpu;
}

// Linux lists every hardware thread as its own CPU. On a machine with SMT (Hyper-Threading),
// each physical core therefore shows up more than once in the CPU list. For example, with
// 4 cores and 2-way SMT the list might be
//   0, 1, 2, 3, 4, 5, 6, 7
// i.e. 8 logical CPUs backed by only 4 physical cores. Logical CPUs that share one physical
// core are called siblings; here the sibling pairs could be, for example, 0,4  1,5  2,6  3,7.
//
// Returns the sibling list of `cpu_id` as a string: the first whitespace-delimited token of
// /sys/devices/system/cpu/cpu<cpu_id>/topology/thread_siblings_list, or an empty string if
// that file cannot be opened or read.
std::string cpu_siblings(int cpu_id)
{
    const std::string sysfs_path = "/sys/devices/system/cpu/cpu"
                                 + std::to_string(cpu_id)
                                 + "/topology/thread_siblings_list";

    std::string result;
    std::ifstream sysfs_file(sysfs_path);
    sysfs_file >> result;
    return result;
}

// Worker thread body. Approximates this thread's share of
//     pi = integral from 0 to 1 of 4 / (1 + x^2) dx
// with the midpoint rule over n strips of width h. Strip indices 1..n are dealt out
// round-robin: thread `thread_num` handles thread_num + 1, thread_num + 1 + numThreads, ...
// The partial sum, already multiplied by h, is stored in *partial_pi.
// Thread stats are printed under iomtx once before the computation, and once after it
// together with the computation time in whole milliseconds.
void pi_thread(int thread_num, int numThreads, double *partial_pi)
{
    IO_SYNC(
        print_thread_stats();
        std::cout << '\n';
    );

    const auto calc_start_us = now_micro();

    double strip_sum = 0.0;
    long idx = thread_num + 1;
    while (idx <= n)
    {
        // midpoint of strip `idx`
        const double mid = h * (static_cast<double>(idx) - 0.5);
        strip_sum += 4.0 / (1.0 + mid * mid);
        idx += numThreads;
    }
    *partial_pi = h * strip_sum;

    const auto calc_ms = (now_micro() - calc_start_us) / 1000;

    IO_SYNC(
        print_thread_stats();
        std::cout << ", thread_calc_time = " << calc_ms << " ms" << '\n';
    );
}

// Parse the <number-of-threads> argument. Returns true and stores the value in `out` only
// when `text` is entirely a base-10 integer in [1, INT_MAX]. std::stoi alone would accept
// trailing garbage ("4abc") and throws on non-numeric input.
static bool parse_thread_count(const char *text, int &out)
{
    try
    {
        const std::string arg(text);
        std::size_t consumed = 0;
        const int value = std::stoi(arg, &consumed);
        if (consumed != arg.size() || value < 1)
        {
            return false;
        }
        out = value;
        return true;
    }
    catch (const std::exception &)
    {
        // std::invalid_argument (not a number) or std::out_of_range (does not fit in int)
        return false;
    }
}

// Entry point: approximates pi with <number-of-threads> worker threads (pi_thread), then
// prints the absolute error against a reference value, the total wall time, and the CPU ids
// (with their SMT siblings) observed while running.
int main(int argc, char *argv[])
{
    int numThreads = 0;
    if (argc < 2 || !parse_thread_count(argv[1], numThreads))
    {
        // argv[0] may be null when argc == 0
        const char *prog = (argc > 0 && argv[0] != nullptr) ? argv[0] : "";
        std::cerr << "Usage: " << prog << " <number-of-threads>" << std::endl;
        return -1;
    }

    // gethostname() is not required to NUL-terminate a truncated name. Passing one byte less
    // than the zero-initialised buffer keeps the last byte as '\0'.
    char hostname[256] = {};
    const char *host =
        (gethostname(hostname, sizeof(hostname) - 1) == 0) ? hostname : "<unknown>";

    const auto hwc = std::thread::hardware_concurrency();

    std::cout << "Running on node: " << host << '\n';
    std::cout << "CPP detected hardware concurrency: " << hwc << '\n';
    std::cout << "Main thread: ";
    print_thread_stats();
    std::cout << "\n--------------------\n";

    const auto tstart = now_micro();

    // Standard containers rather than variable-length arrays, which are a compiler
    // extension and not C++. `partials` is sized once and never resized, so the element
    // pointers handed to the workers stay valid until the threads are joined.
    const auto count = static_cast<std::size_t>(numThreads);
    std::vector<double> partials(count, 0.0);
    std::vector<std::thread> threads;
    threads.reserve(count);
    for (int thread_num = 0; thread_num < numThreads; ++thread_num)
    {
        threads.emplace_back(pi_thread, thread_num, numThreads,
                             &partials[static_cast<std::size_t>(thread_num)]);
    }
    for (auto &worker : threads)
    {
        worker.join();
    }

    // Add the partial results in thread-index order.
    double pi = 0;
    for (const double partial : partials)
    {
        pi += partial;
    }

    const auto elapsed_ms = (now_micro() - tstart) / 1000;

    // Reference value for the error report. It is spelled out because M_PI is a POSIX
    // extension, not standard C++ (same digits as glibc's M_PI).
    constexpr double reference_pi = 3.14159265358979323846;

    std::cout << "--------------------\n";
    std::cout << std::setprecision(16) << "Error is " << std::fabs(pi - reference_pi) << '\n';
    std::cout << "Calculation took " << elapsed_ms << " ms\n";
    std::cout << "Num Threads = " << numThreads << '\n';

    // Report every CPU id observed over the run, along with its SMT siblings. Ids are only
    // sampled when print_thread_stats() runs (program start, and start/end of each worker),
    // so cores a thread migrated through in between may be missing. All workers have been
    // joined, so used_cpu_ids can be read here without taking its mutex.
    std::cout << "Utilized CPU ids: \n";
    for (const int cpu : used_cpu_ids)
    {
        std::cout << "  " << cpu << ", siblings: " << cpu_siblings(cpu) << '\n';
    }

    return 0;
}