import math
import os
from datetime import datetime

class IOOperations:
    """Reads and writes the different file formats used by the analysis.

    Every export function writes into a ``results`` sub-folder of the given
    output folder. The sub-folder is created on demand.
    """

    def __init__(self):
        pass

    def _ensureResultsDir(self, pathToFolder):
        """Creates ``<pathToFolder>/results`` if necessary and returns its path.

        Parameters
        ------
        pathToFolder : str
        Path to the output folder

        Returns
        ------
        results_dir : str
        Path of the results folder
        """

        results_dir = pathToFolder + "/results"
        # makedirs(exist_ok=True) avoids the check-then-create race of
        # "exists() + mkdir()" and also creates missing parent folders.
        os.makedirs(results_dir, exist_ok=True)
        return results_dir

    def exportDetectedDefects(self, defects_dict, pathToFolder):
        """Exports the found defects into a dated text report.

        The report holds the date, the time, the total number of defects and
        one line per ASIC of the form ``<asic> : [ch,ch,...]`` listing the even
        channels first and the odd channels afterwards.
        The file name only contains the date, so a later export on the same
        day overwrites the earlier one.

        Parameters
        ------
        defects_dict : dict
        A dictionary mapping the ASIC number to a dictionary with the keys
        "even" and "odd", each holding the defective channel indices

        pathToFolder : str
        Path to the output folder
        """

        results_dir = self._ensureResultsDir(pathToFolder)

        date = datetime.now()
        date_str = date.strftime("%d%m%Y")
        time_str = date.strftime("%H%M")

        filename = results_dir + "/quality_result_" + date_str + ".txt"

        asic_lines = []
        numOfDefectsTotal = 0
        for asic, data in defects_dict.items():
            channels = list(data.get("even")) + list(data.get("odd"))
            numOfDefectsTotal += len(channels)
            channel_str = ",".join("%d"%(ch) for ch in channels)
            asic_lines.append("%d : [%s]\n"%(asic, channel_str))

        with open(filename, "w") as file:
            file.write("Date : %s\n"%(date_str))
            file.write("Time : %s\n"%(time_str))
            file.write("Total number of defects : %d\n"%(numOfDefectsTotal))
            file.write("".join(asic_lines))
        print("Data successfully exported to file \"%s\""%(filename))


    def exportNumOfDefects(self, defects_list, thld, asic_nr, pathToFolder, additionalFilename=""):
        """Exports the number of defects found and the respective threshold used.

        Each line holds the threshold (element index times ``thld``) and the
        number of defects, separated by a tab.

        Parameters
        ------
        defects_list : list
        A list with the different numbers of defects found

        thld : float
        The threshold step used

        asic_nr : int
        The number of the ASIC used to generate the data

        pathToFolder : str
        Path to the output folder

        additionalFilename : str
        Optional string to be added to the filename
        """

        results_dir = self._ensureResultsDir(pathToFolder)

        filename = results_dir + "/defects_" + str(asic_nr) + "_" + additionalFilename + ".csv"

        with open(filename, "w") as file:
            for i, numOfDefects in enumerate(defects_list):
                file.write(str(float(i*thld)) + "\t" + str(int(numOfDefects)) + "\n")


    def exportFilterThresholds(self, data, start, step, pathToFolder, asic_nr, additionalFilename=""):
        """Exports the number of filtered channels with the respective threshold.

        Each line holds the threshold (``start`` plus element index times
        ``step``) and the number of filtered channels, separated by a tab.

        Parameters
        ------
        data : list
        List with the number of filtered channels

        start : float
        The initial threshold

        step : float
        The step the threshold was increased for every element

        pathToFolder : str
        Path to the output folder

        asic_nr : int
        The number of the ASIC used to generate the data

        additionalFilename : str
        Optional string to be added to the filename
        """

        results_dir = self._ensureResultsDir(pathToFolder)

        filename = results_dir + "/" + additionalFilename + "_" + str(asic_nr) + ".csv"

        with open(filename, "w") as file:
            for i, numOfFiltered in enumerate(data):
                file.write(str(i*step + start) + "\t" + str(int(numOfFiltered)) + "\n")


    def exportCountComparison(self, data, isEven, asic_nr, pathToFolder, additionalFilename=""):
        """Exports several count lists side by side, e.g. pre and post filtering data.

        Each line starts with the channel index followed by the value of every
        list for that channel. All fields are terminated by a tab.

        Parameters
        ------
        data : list
        A list with lists of the same length to be printed parallel

        isEven : bool
        Is true if even channels are processed; Used to display correct channel index

        asic_nr : int
        The number of the ASIC used to generate the data

        pathToFolder : str
        Path to the output folder

        additionalFilename : str
        Optional string to be added to the filename
        """

        results_dir = self._ensureResultsDir(pathToFolder)

        filename = results_dir + "/" + additionalFilename + "_" + str(asic_nr) + ".csv"

        # Even data belongs to channels 0, 2, 4, ...; odd data to 1, 3, 5, ...
        offset = 0 if isEven else 1

        with open(filename, "w") as file:
            for ch_index in range(len(data[0])):
                line = str(ch_index*2 + offset) + "\t"
                for elem in data:
                    line += str(int(elem[ch_index])) + "\t"
                file.write(line + "\n")


    def exportChannelCounts(self, channel_counts, asic_nr, pathToFolder, additionalFilename = ""):
        """Exports the calculated channel counts, one "index<TAB>count" line per channel.

        Parameters
        ------
        channel_counts : list
        A list with the average counts for each channel

        asic_nr : int
        The number of the ASIC this data was generated

        pathToFolder : str
        A path to the output folder

        additionalFilename : str
        Optional string to be added to the filename
        """

        results_dir = self._ensureResultsDir(pathToFolder)

        filename = results_dir + "/counts_" + str(asic_nr) + additionalFilename + ".csv"
        with open(filename, "w") as file:
            for ch_index, count in enumerate(channel_counts):
                file.write(str(ch_index) + "\t" + str(int(count)) + "\n")

    def exportSCurves(self, sCurveDict, asic_nr, pathToFolder):
        """Exports the calculated s curves into a tab separated csv file.

        Each line holds the discriminator, its average and its sigma, the
        latter two truncated to integers. An entry without a "sigma" value is
        written with a sigma of 0 instead of aborting the export.

        Parameters
        ------
        sCurveDict : dict
        A dictionary with the averages and sigmas for each discriminator

        asic_nr : int
        The number of the ASIC this data was generated

        pathToFolder : str
        A path to the output folder
        """

        results_dir = self._ensureResultsDir(pathToFolder)

        filename = results_dir + "/scurve_" + str(asic_nr) + ".csv"
        with open(filename, "w") as file:
            for disc, data in sCurveDict.items():
                sigma = data.get("sigma")
                if sigma is None:
                    sigma = 0
                file.write(str(disc) + "\t" + str(int(data.get("avg"))) + "\t" + str(int(sigma)) + "\n")

    def readConnResults(self, pathToFile):
        """Reads the results from the connection test stored in the referenced file.

        The ASIC number is taken from the fifth underscore separated field of
        the file name and the number of loops from the last field in front of
        the file extension. From every non-blank line the first four
        whitespace separated fields are dropped and the remaining ones are
        kept as strings.
        A message is printed if the number of loops does not match the number
        of data lines divided by 128.

        Parameters
        ------
        pathToFile : str
        The path to the file with the raw data

        Returns
        ------
        raw_data : dict
        A dictionary with the data lines ("data"), the ASIC number ("asic_nr")
        and the number of loops ("nloops")

        Raises
        ------
        IndexError, ValueError
        If the file name does not follow the expected naming scheme
        """

        filename = os.path.basename(pathToFile)
        name_parts = filename.split("_")
        asic_nr = int(name_parts[4])
        loop_nr = int(name_parts[-1].split(".")[0])

        lines_list = list()
        with open(pathToFile, "r") as file:
            for raw_line in file:
                fields = raw_line.split()
                if not fields:
                    # Blank lines carry no data and would shift the line to
                    # channel assignment of all following lines.
                    continue
                lines_list.append(fields[4:])

        if loop_nr != (len(lines_list) / 128):
            print("nloops not matching")

        raw_data = dict()
        raw_data["data"] = lines_list
        raw_data["asic_nr"] = asic_nr
        raw_data["nloops"] = loop_nr

        return raw_data

class DetectDefects:
    """Detects defective channels from per-channel deviation data."""

    def findDefectiveChannels(self, deviation_list:list, asic_nr:int, isEven:bool)-> list:
        """Reports every channel whose deviation exceeds the fixed threshold of 0.90.

        For even data a banner with the ASIC number is printed first. Each
        defective channel is printed as it is found.

        Parameters
        ------
        deviation_list : list
        A list with the percentual deviations

        asic_nr : int
        The number of the ASIC the data was generated with

        isEven : bool
        Is the data from even or odd channels (used for numbering)

        Returns
        ------
        defective_channels : list
        A list with the indices of the defective channels
        """

        thld = 0.90

        # Even data maps to channels 0, 2, 4, ...; odd data to 1, 3, 5, ...
        first_channel = 0 if isEven else 1

        if isEven:
            print("====================\n||\tASIC %d\t  ||\n===================="%(asic_nr))

        defective_channels = []
        for position, deviation in enumerate(deviation_list):
            if deviation > thld:
                channel = position*2 + first_channel
                defective_channels.append(channel)
                print("\tCh: %d"%(channel))

        return defective_channels

class ExtremeValues:
    """This class provides functionality to filter extreme values."""

    def __init__(self):
        self.calc = Calculus()

    def calculateNumberOfFilteredValues(self, raw:list, filtered:list)-> int:
        """Compares two lists and counts the positions holding different values.

        Parameters
        ------
        raw : list
        A list with numerical values for the unfiltered counts

        filtered : list
        A list with the filtered count values

        Returns
        ------
        numOfFiltered : int
        Number of different (=filtered) values; 0 if the lists differ in length
        """

        if len(raw) != len(filtered):
            return 0

        return sum(1 for before, after in zip(raw, filtered) if before != after)

    def _findNextGoodIndex(self, data:list, start:int, lowerLimit:float, upperLimit:float)-> int:
        """Returns the first index >= start whose value lies strictly between the limits.

        Returns len(data) if no such value exists.
        """

        for position in range(start, len(data)):
            if (data[position] < upperLimit) and (data[position] > lowerLimit):
                return position
        return len(data)

    def filterExtremeValues(self, data:list, thld:float)->list:
        """Filters extreme values based on the percentual deviation between two channels.

        A jump is detected where the relative deviation between neighbouring
        channels exceeds the threshold in either direction. To prevent
        division by 0 errors, the absolute difference is used as deviation
        when the reference channel count is 0.
        The values after a jump are replaced up to the next value lying
        strictly between the lower and upper border. The replacement is the
        average of the two bordering values, or the left border value if no
        acceptable value follows.
        If the first element already lies outside the borders, the leading
        values are replaced with the first acceptable value.

        Parameters
        ------
        data : list
        The channel counts

        thld : float
        The threshold the deviation may not exceed

        Returns
        ------
        filtered_list : list
        A list with the filtered and replaced values. Lists with less than
        two elements and lists without any acceptable replacement for an
        extreme first element are returned as an unchanged copy.
        """

        numOfValues = len(data)
        filtered_list = list(data)

        # Neither borders nor deviations are defined for fewer than two values.
        if numOfValues < 2:
            return filtered_list

        #Define borders for good values
        upperLimit = self.calc.calculateUpperBorder(data)
        lowerLimit = self.calc.calculateLowerBorder(data)

        #Calculate percentual deviation between neighbouring channels
        deviation_list = []
        for index in range(numOfValues-1):
            difference = data[index+1] - data[index]
            if data[index] == 0:
                deviation_list.append(difference)
            else:
                deviation_list.append(difference/data[index])

        index = 0

        #Is the first element outside the limits
        if (data[0] < lowerLimit) or (data[0] > upperLimit):
            rightBorderIndex = self._findNextGoodIndex(data, 1, lowerLimit, upperLimit)
            if rightBorderIndex == numOfValues:
                # There is no acceptable value to replace the leading ones with.
                return filtered_list
            for i in range(rightBorderIndex):
                filtered_list[i] = data[rightBorderIndex]
            index = rightBorderIndex + 1

        while index < (numOfValues-1):
            if (deviation_list[index] > thld) or (deviation_list[index] < -thld):
                leftBorderIndex = index
                # The element right after the jump is always treated as extreme,
                # therefore the search for a good value starts one element later.
                rightBorderIndex = self._findNextGoodIndex(data, index+2, lowerLimit, upperLimit)

                if rightBorderIndex == numOfValues:
                    #Last element is extreme - replace complete with left border value
                    replacement = data[leftBorderIndex]
                else:
                    replacement = (data[leftBorderIndex] + data[rightBorderIndex])/2

                for i in range(leftBorderIndex+1, rightBorderIndex):
                    filtered_list[i] = replacement

                # Continue behind the right border (together with the
                # increment below this skips one further element).
                index = rightBorderIndex + 1
            index += 1

        return filtered_list

class ReferenceValues:
    """Calculates the different reference values."""

    def __init__(self):
        self.calc = Calculus()

    def calculateReferenceAverages(self, data):
        """Calculates several running averages on the same data for comparison.

        One averaged list is produced per window size, using the odd window
        sizes 3, 5, 7, 9, 11 and 13.

        Parameters
        ------
        data : list
        List with the counts per channel

        Returns
        ------
        average_list : list
        A list with one averaged list per window size
        """

        window_sizes = range(3, 14, 2)           #every odd number from 3 to 13

        return [self.calc.calculateNAverageOnList(data, window) for window in window_sizes]

class ChannelCounts:
    """This class contains different functions to perform different analysis on
    the ASIC level.
    """

    # Only channels below this index are split into the even and odd lists.
    NUM_CHANNELS = 128

    def __init__(self):
        self.calc = Calculus()

    def performAllAnalysis(self, channel_array:list)-> None:
        """Performs all analysis for channel based counts.

        The result of calculateAvgCounts is stored in ``self.counts``.

        Parameters
        ------
        channel_array : list
        A two dimensional array with the average discriminator counts for each channel
        """

        self.counts = self.calculateAvgCounts(channel_array)


    def calculateAvgCounts(self, data:list) -> tuple[list, list, list]:
        """This function takes the per channel discriminator counts and
        calculates the channel count as the sum of the discriminator counts.

        Parameters
        ------
        data : list
        A two dimensional array with the average discriminator counts for each channel

        Returns
        ------
        channel_counts : tuple
        A tuple of three lists: the counts of all channels, the counts of the
        even channels and the counts of the odd channels. The even and odd
        lists only cover the first 128 channels. Fewer channels are accepted
        and simply lead to shorter lists.
        """

        channel_counts = [sum(channel) for channel in data]

        # Slicing (instead of indexing 0..127) tolerates fewer than 128 channels.
        counts_even = channel_counts[0:self.NUM_CHANNELS:2]
        counts_odd = channel_counts[1:self.NUM_CHANNELS:2]

        return (channel_counts, counts_even, counts_odd)

class SCurves:
    """This class performs different calculation on the discriminator level of the data."""

    # Dimensions of the data: data line n belongs to channel n modulo 128 and
    # holds up to 31 discriminator values.
    NUM_CHANNELS = 128
    NUM_DISCRIMINATORS = 31

    def __init__(self):
        self.calc = Calculus()

    def performAllAnalysis(self, raw_data_lines:dict)-> None:
        """Performs all discriminator based analysis.

        The averaged channel discriminators are stored in ``self.channel_array``
        and the resulting s curve in ``self.sCurve``.

        Parameters
        ------
        raw_data_lines : dict
        A dictionary with a list of all data lines from the data file
        """

        self.channel_array = self.calculateAvgDiscriminator(raw_data_lines)
        self.sCurve = self.calculateAvgSCurve(self.channel_array)

    def calculateAvgDiscriminator(self, data:dict)-> list:
        """This function calculates the average for each channel discriminator over all nloops.

        Parameters
        ------
        data : dict
        A dictionary with a list of all data lines from the data file under
        the key "data"

        Returns
        ------
        channel_array : list
        A two-dimensional list (128 channels with 31 discriminators each)
        with the averaged channel discriminators
        """

        # One list of collected values per channel and discriminator
        channel_list = [[list() for _ in range(self.NUM_DISCRIMINATORS)]
                        for _ in range(self.NUM_CHANNELS)]

        for line_index, line in enumerate(data.get("data")):
            # The lines cycle through the channels once per loop
            channel = channel_list[line_index % self.NUM_CHANNELS]
            for disc_index, value in enumerate(line):
                channel[disc_index].append(int(value))

        channel_array = list()
        for channel in channel_list:
            channel_array.append([self.calc.calculateAverageOnList(values) for values in channel])

        return channel_array

    def calculateAvgSCurve(self, data:list)-> dict:
        """This function calculates the average and standard deviation for each
        discriminator over all channels.

        The values are converted to integers before they are averaged.

        Parameters
        ------
        data : list
        A list with the discriminator count for each channels

        Returns
        ------
        sCurve : dict
        A dictionary mapping each of the 31 discriminator indices to a
        dictionary with the average counts ("avg") and the respective
        sigma ("sigma")
        """

        disc_list = [list() for _ in range(self.NUM_DISCRIMINATORS)]

        for line in data:
            for disc_index, value in enumerate(line):
                disc_list[disc_index].append(int(value))

        sCurve = dict()
        for disc_index, values in enumerate(disc_list):
            avg = self.calc.calculateAverageOnList(values)
            # A discriminator without any values gets a sigma of 0 instead of
            # running the standard deviation on an empty list.
            sigma = self.calc.calculateStdDeviationOnList(values) if values else 0.0
            sCurve[disc_index] = {"avg" : avg, "sigma" : sigma}

        return sCurve

class Calculus:
    """A class for calculating different metrics."""

    def __init__(self):
        pass

    def calculateDefectiveChannels(self, deviation:list, thld:float)-> int:
        """Calculates the number of defective channels based on the deviation lists and a threshold.

        Parameters
        ------
        deviation : list
        A list with the percentual deviations

        thld : float
        The threshold for the maximum allowed deviation

        Returns
        ------
        numOfDefects : int
        The number of deviations greater than the threshold
        """

        return sum(1 for dev in deviation if dev > thld)

    def calculateDeviationBetweenLists(self, data1:list, data2:list)-> list:
        """Calculates the percentual deviation between the n-th elements of the two lists.

        The data1 list represents the reference value the deviation is relative to.
        If a reference value is 0 the absolute difference is used instead to
        prevent a division by 0.

        Parameters
        ------
        data1 : list
        List with numerical data. Acts as reference for the calculation

        data2 : list
        List with numerical data

        Returns
        ------
        deviation_list : list
        List with the percentual deviation between the elements of both lists

        Raises
        ------
        ValueError
        If the two lists differ in length
        """

        if len(data1) != len(data2):
            raise ValueError("Both lists need the same length (got %d and %d)"%(len(data1), len(data2)))

        deviation_list = []
        for reference, value in zip(data1, data2):
            difference = value - reference
            if reference == 0:
                deviation_list.append(difference)
            else:
                deviation_list.append(difference / reference)

        return deviation_list

    def calculatePercentualDeviation(self, data:list)-> list:
        """Calculates the percentual deviation between the neighbouring elements of the given list.

        Each element acts as reference for its right neighbour. If the
        reference is 0 the absolute difference is used instead to prevent a
        division by 0.

        Parameters
        ------
        data : list
        The list with the data

        Returns
        ------
        deviation_list : list
        The list with the percentual deviation between the elements; it is one
        element shorter than the given list
        """

        deviation_list = []

        for index in range(len(data)-1):
            difference = data[index+1] - data[index]
            if data[index] == 0:
                _dev = difference
            else:
                _dev = difference/data[index]
            deviation_list.append(_dev)

        return deviation_list


    def calculateNAverageOnList(self, _list:list, size:int)-> list:
        """Calculate a reference average over a defined subset of elements from a given list
        for each element in the list.
        The used elements are defined by the size parameter and the index of the
        referenced element: the element itself plus int(size / 2) neighbours on
        each side. The window is cut off at the borders of the list.

        Parameters
        ------
        _list : list
        A list of numerical elements

        size : int
        The number of elements to use - needs to be an odd number

        Returns
        ------
        avg_list : list
        A list with numerical values
        """

        numberPerSide = int(size / 2)
        numOfElements = len(_list)
        avg_list = []

        for index in range(numOfElements):
            leftBorderIndex = max(index - numberPerSide, 0)
            # The end of a slice is exclusive, hence +1 to include the
            # outermost right neighbour.
            rightBorderIndex = min(index + numberPerSide + 1, numOfElements)

            avg_list.append(self.calculateAverageOnList(_list[leftBorderIndex:rightBorderIndex]))

        return avg_list



    def calculateAverageOnList(self, _list:list)-> float:
        """Calculates the average over a list of numerical values.

        Parameters
        ------
        _list : list
        A list with numerical values

        Returns
        ------
        avg : float
        The calculated average; 0 for an empty list
        """

        if len(_list) == 0:
            return 0

        return float(sum(_list, 0.0) / len(_list))

    def calculateStdDeviationOnList(self, _list:list)->float:
        """Calculates the (population) standard deviation over all elements of a given list.

        Parameters
        ------
        _list : list
        A list of numerical values

        Returns
        ------
        sigma : float
        The sigma of the list; 0.0 for an empty list
        """

        if len(_list) == 0:
            return 0.0

        _avg = self.calculateAverageOnList(_list)
        _sigma_sum = sum((pow((elem - _avg), 2) for elem in _list), 0.0)

        return math.sqrt(_sigma_sum / len(_list))

    def calculateUpperBorder(self, _list:list)->float:
        """Calculates the upper border by calculating the average over all local maxima.

        An inner element is a local maximum if it is greater than both
        neighbours; the first and last element only need to be greater than
        their single neighbour.

        Parameters
        ------
        _list : list
        List with numerical values

        Returns
        ------
        upperBorder : float
        The value of the upper border; for lists with less than two elements
        the average of the list is returned
        """

        # Without a neighbour there is nothing to compare against.
        if len(_list) < 2:
            return self.calculateAverageOnList(_list)

        localMaxima = []

        for index in range(1, len(_list)-1):
            if _list[index] > _list[index-1] and _list[index] > _list[index+1]:
                localMaxima.append(_list[index])

        if _list[0] > _list[1]:
            localMaxima.append(_list[0])
        if _list[-2] < _list[-1]:
            localMaxima.append(_list[-1])

        return self.calculateAverageOnList(localMaxima)

    def calculateLowerBorder(self, _list:list)-> float:
        """Calculates the lower border by calculating the average over all local minima.

        An inner element is a local minimum if it is smaller than both
        neighbours; the first and last element only need to be smaller than
        their single neighbour.

        Parameters
        ------
        _list : list
        List with numerical values

        Returns
        ------
        lowerBorder : float
        The value of the lower border; for lists with less than two elements
        the average of the list is returned
        """

        # Without a neighbour there is nothing to compare against.
        if len(_list) < 2:
            return self.calculateAverageOnList(_list)

        localMinima = []

        for index in range(1, len(_list)-1):
            if _list[index] < _list[index-1] and _list[index] < _list[index+1]:
                localMinima.append(_list[index])

        if _list[0] < _list[1]:
            localMinima.append(_list[0])
        if _list[-2] > _list[-1]:
            localMinima.append(_list[-1])

        return self.calculateAverageOnList(localMinima)