import collections
import functools
import json
import logging
import threading
import time
import traceback
from datetime import datetime
from enum import Enum
from typing import List, Callable, Optional, Dict

import pika
import pika.exceptions
from pydantic import BaseModel, validator

from pilab.events import utils

# Module logger for this events module.
logger = logging.getLogger(__name__)
# Only let WARNING and above through from pika's loggers. NOTE: this runs at
# import time and changes the level of the process-wide "pika" logger, i.e. it
# affects every user of pika in the process, not just this module.
logging.getLogger("pika").setLevel(logging.WARNING)


class ActionType(str, Enum):
    """Kinds of action that can be reported in an Event's ``action`` field.

    Members mix in ``str``, so ``json.dumps`` (used by ``send_event``)
    serializes them as their plain string value.
    """
    POWER_ON = "POWER_ON"
    POWER_OFF = "POWER_OFF"
    START_KERNEL = "START_KERNEL"
    FETCH_WORKFLOW = "FETCH_WORKFLOW"
    DOWNLOAD_IMAGE = "DOWNLOAD_IMAGE"
    # NOTE: the value differs from the member name. The value is what gets
    # serialized into published events, so changing it changes the wire format.
    PRE_BOOT_CONFIG = "PRE_BOOT_CONFIGURATION"
    BOOT_IMAGE = "BOOT_IMAGE"
    CLOUD_INIT_CONFIG = "CLOUD_INIT_CONFIG"
    ANSIBLE_CONFIG = "ANSIBLE_CONFIG"


class ActionState(str, Enum):
    """Progress state of a single Action (see ``Action.state``).

    ``str`` mix-in: serialized as the plain string value.
    """
    ACTION_SUCCESS = "ACTION_SUCCESS"
    ACTION_IN_PROGRESS = "ACTION_IN_PROGRESS"
    ACTION_FAILED = "ACTION_FAILED"
    ACTION_PENDING = "ACTION_PENDING"


class Properties(BaseModel):
    """Metadata attached to every Event."""

    # Unix epoch seconds. Filled in with the current time by validate_date when
    # missing or falsy (None, 0, "").
    timestamp: Optional[float] = None
    configured: Optional[bool] = False

    @validator("timestamp", pre=True, always=True)
    def validate_date(cls, value, values):
        """Default a missing/falsy timestamp to the current Unix epoch time.

        Runs before type coercion (``pre=True``) and also when the field was not
        supplied at all (``always=True``).
        """
        # time.time() is true epoch seconds regardless of the host time zone.
        # The previous datetime.utcnow().timestamp() was wrong: utcnow() returns
        # a *naive* datetime and .timestamp() treats naive datetimes as local
        # time, skewing the value by the host's UTC offset.
        return value or time.time()


class Action(BaseModel):
    """An action (see ActionType) together with its state, progress and message."""

    type: ActionType
    state: ActionState
    # Explicit None defaults: pydantic v1 already treats a bare Optional[...]
    # annotation as "optional, default None", but pydantic v2 makes it a
    # *required* field. Spelling the default out behaves the same under both.
    progress: Optional[int] = None
    message: Optional[str] = None


class State(str, Enum):
    """Overall host state carried in ``Event.state``.

    ``str`` mix-in: serialized as the plain string value.
    """
    INACTIVE = "INACTIVE"
    PENDING = "PENDING"
    ACTIVE = "ACTIVE"
    FAILED = "FAILED"


class Event(BaseModel):
    """A host event, published to and consumed from ``host.<type>.<mac>`` streams."""

    mac: str
    state: State
    # Optional with an explicit None default: controller and switch events are
    # built without an action. Under pydantic v2 a bare Optional[...] would be
    # required and those events would fail validation; under v1 this is unchanged.
    action: Optional[Action] = None
    properties: Properties


class Type(str, Enum):
    """Host types.

    The value is used as the middle segment of queue / routing-key names of
    the form ``host.<type>.<mac>``.
    """
    controller = "controller"
    pi = "pi"
    switch = "switch"


def send_event(routing_key: str, event: Event):
    """Publish ``event`` as UTF-8 JSON to the stream queue ``routing_key``.

    A new blocking connection is opened for each call. The durable stream
    queue named ``routing_key`` is declared, and the event is then published
    through the default exchange. The connection is always released afterwards,
    including when declaring or publishing fails.

    :param routing_key: Queue name, used as the routing key on the default exchange.
    :param event: The event to serialize and publish.
    :raises pika.exceptions.UnroutableError: logged and re-raised.
    :raises Exception: any other failure is logged and re-raised.
    """
    connection = None
    try:
        connection = utils.get_blocking_connection()
        channel = connection.channel()

        channel.queue_declare(
            queue=routing_key,
            durable=True,
            exclusive=False,
            auto_delete=False,
            arguments={"x-queue-type": "stream"}
        )

        event_serialized = json.dumps(event.dict()).encode('utf-8')

        channel.basic_publish(
            exchange='',
            routing_key=routing_key,
            body=event_serialized)
    except pika.exceptions.UnroutableError as e:
        # NOTE(review): pika raises UnroutableError only for publishes on a
        # channel with confirm_delivery() enabled; as written this branch is
        # likely unreachable. Kept so callers' exception handling is unchanged.
        logger.error(f'Could not send event to Broker; message {e}')
        raise
    except Exception as e:
        logger.error(f'Error sending host event; message: {e}')
        raise
    finally:
        # Previously the connection was only closed on success and leaked on
        # any error. A failing close must not mask the original exception.
        if connection is not None and connection.is_open:
            try:
                connection.close()
            except Exception as close_error:
                logger.warning(f'Failed to close broker connection; message: {close_error}')


def send_pi_event(mac: str, state: State, action: Action, props: Properties):
    """Publish a Pi host event to the ``host.pi.<mac>`` stream.

    :param mac: MAC address of the Pi; also the last routing-key segment.
    :param state: Host state to report.
    :param action: The action this event reports on.
    :param props: Event properties.
    """
    send_event(
        routing_key=f"host.pi.{mac}",
        event=Event(mac=mac, state=state, action=action, properties=props),
    )


def send_controller_event(mac: str, state: State):
    """Publish a controller host event (no action) to ``host.controller.<mac>``.

    :param mac: MAC address of the controller; also the last routing-key segment.
    :param state: Host state to report.
    """
    send_event(
        routing_key=f"host.controller.{mac}",
        event=Event(mac=mac, state=state, properties=Properties()),
    )


def send_switch_event(mac: str, state: State):
    """Publish a switch host event (no action) to ``host.switch.<mac>``.

    :param mac: MAC address of the switch; also the last routing-key segment.
    :param state: Host state to report.
    """
    send_event(
        routing_key=f"host.switch.{mac}",
        event=Event(mac=mac, state=state, properties=Properties()),
    )


class PowerConsumer(threading.Thread):
    """Thread that consumes host events from RabbitMQ stream queues.

    For every entry ``q`` in ``queues``, the durable stream queue
    ``host.<type>.<q>`` is declared and consumed with manual acks. Each message
    body is parsed as JSON into an :class:`Event` and passed to ``callback``.

    The thread runs a ``pika.SelectConnection`` ioloop. When the connection
    cannot be opened or is lost unexpectedly, the thread waits one second and
    reconnects, until :meth:`stop` is called. pika connections and channels
    are not thread-safe, so all broker interaction happens on this thread's
    ioloop. :meth:`stop` is the only method meant to be called from other
    threads.
    """

    def __init__(self, queues: List[str], callback: Callable, type: Type, amqp_url: str = utils.BROKER_URL,
                 arguments: Optional[Dict] = None):
        """Create a new (not yet started) consumer thread.

        :param queues: Queue suffixes; each is consumed as ``host.<type>.<suffix>``.
        :param callback: Called on this thread with every parsed Event.
        :param type: Host type, used as the middle segment of the queue names.
        :param str amqp_url: The AMQP url to connect with.
        :param arguments: ``basic_consume`` arguments. Defaults to a fresh
            ``{"x-stream-offset": "last"}`` dict per instance, instead of one
            mutable default dict shared by every instance.
        """
        super().__init__()
        self.should_reconnect = False
        self.was_consuming = False

        self._connection = None
        self._channel = None
        self._closing = False
        self._consumer_tags = []
        self._url = amqp_url
        self._consuming = False
        # Set by stop() (from any thread). Once set, run() never reconnects.
        self._shutdown_requested = False
        # Serializes stop() against run() installing a new self._connection, so
        # a stop request cannot miss the connection it has to shut down.
        self._lock = threading.Lock()
        # In production, experiment with higher prefetch values
        # for higher consumer throughput
        self._prefetch_count = 10

        self.type = type
        self.callback = callback
        # Own copy: if a caller later mutates its list in place, self.queues
        # must still describe what this consumer actually subscribed to.
        self.queues = list(queues)
        if arguments is None:
            arguments = {"x-stream-offset": "last"}
        self.arguments = arguments

    def connect(self):
        """Create the asynchronous connection to RabbitMQ.

        When the connection is established, pika invokes on_connection_open.
        On failure it invokes on_connection_open_error, and after an open
        connection closes it invokes on_connection_closed.

        :rtype: pika.SelectConnection
        """
        logger.info('Connecting to %s', self._url)

        return pika.SelectConnection(
            parameters=pika.URLParameters(self._url),
            on_open_callback=self.on_connection_open,
            on_open_error_callback=self.on_connection_open_error,
            on_close_callback=self.on_connection_closed)

    def close_connection(self):
        """Close the connection unless it is already closing or closed."""
        self._consuming = False
        if self._connection.is_closing or self._connection.is_closed:
            logger.info('Connection is closing or already closed')
        else:
            logger.info('Closing connection')
            self._connection.close()

    def on_connection_open(self, _unused_connection):
        """pika callback: the connection is open, so open a channel on it."""
        logger.debug('Connection opened')
        self.open_channel()

    def on_connection_open_error(self, _unused_connection, err):
        """pika callback: the connection could not be opened; schedule a reconnect."""
        logger.error('Connection open failed: %s', err)
        self.reconnect()

    def on_connection_closed(self, _unused_connection, reason):
        """pika callback: the connection closed.

        On a requested shutdown (``_closing``) the ioloop is stopped.
        Otherwise the close was unexpected and a reconnect is scheduled.
        """
        self._channel = None
        if self._closing:
            self._connection.ioloop.stop()
        else:
            logger.warning('Connection closed, reconnect necessary: %s', reason)
            self.reconnect()

    def reconnect(self):
        """Mark that a reconnect is needed and stop this connection's ioloop.

        Runs on the ioloop thread. run() then creates a new connection, unless
        stop() has been requested in the meantime.
        """
        self.should_reconnect = True
        self._stop()

    def open_channel(self):
        """Open a channel; pika invokes on_channel_open once it is ready."""
        logger.debug('Creating a new channel')
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel):
        """pika callback: register channel callbacks and declare every queue."""
        logger.debug('Channel opened')
        self._channel = channel
        self._channel.add_on_close_callback(self.on_channel_closed)
        self._channel.add_on_cancel_callback(self.on_consumer_cancelled)
        for q in self.queues:
            queue_name = f'host.{self.type.value}.{q}'
            self.setup_queue(queue_name)

    def on_channel_closed(self, channel, reason):
        """pika callback: the channel closed, so close the connection too.

        The connection's close callback (on_connection_closed) then decides
        between stopping and reconnecting.
        """
        logger.warning('Channel %i was closed: %s', channel, reason)
        self.close_connection()

    def setup_queue(self, queue_name):
        """Declare ``queue_name`` as a durable stream queue.

        When the declaration completes, pika invokes on_queue_declareok.

        :param str|unicode queue_name: The name of the queue to declare.
        """
        logger.debug('Declaring queue %s', queue_name)
        cb = functools.partial(self.on_queue_declareok, queue_name=queue_name)
        self._channel.queue_declare(queue=queue_name,
                                    durable=True,
                                    exclusive=False,
                                    auto_delete=False,
                                    arguments={"x-queue-type": "stream"},
                                    callback=cb)

    def on_queue_declareok(self, _unused_frame, queue_name):
        """pika callback: ``queue_name`` is declared, so set QoS next."""
        logger.info('Queue bound: %s', queue_name)
        self.set_qos(queue_name)

    def set_qos(self, queue_name):
        """Set the channel prefetch count before consuming ``queue_name``."""
        cb = functools.partial(self.on_basic_qos_ok, queue_name=queue_name)
        self._channel.basic_qos(
            prefetch_count=self._prefetch_count, callback=cb)

    def on_basic_qos_ok(self, _unused_frame, queue_name):
        """pika callback: QoS is set, so start consuming ``queue_name``."""
        logger.debug('QOS set to: %d', self._prefetch_count)
        self.start_consuming(queue_name)

    def start_consuming(self, queue_name):
        """Start a manual-ack consumer on ``queue_name`` and record its tag."""
        if self._closing:
            # Shutdown has already begun and cancels may be in flight. A consumer
            # added now would never be cancelled, and the channel would never
            # be closed.
            logger.debug('Not consuming %s: consumer is stopping', queue_name)
            return
        logger.debug('Issuing consumer related RPC commands')
        _consumer_tag = self._channel.basic_consume(
            queue_name,
            self.on_message,
            auto_ack=False,
            arguments=self.arguments)
        self._consumer_tags.append(_consumer_tag)
        self.was_consuming = True
        self._consuming = True

    def on_consumer_cancelled(self, method_frame):
        """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
        receiving messages.

        :param pika.frame.Method method_frame: The Basic.Cancel frame

        """
        logger.info('Consumer was cancelled remotely, shutting down: %r',
                    method_frame)
        if self._channel:
            self._channel.close()

    def on_message(self, _unused_channel, basic_deliver, properties, body):
        """Handle one delivery: parse it into an Event, run the callback, ack it.

        The delivery is acked even when parsing or the callback fails. With
        manual acks and a prefetch of ``self._prefetch_count``, every unacked
        delivery occupies one prefetch slot. Leaving failures unacked would
        therefore stall this consumer permanently after that many bad messages.
        Stream queues keep messages regardless of acks, so acking a failure
        only means this consumer skips it.
        """
        try:
            event = Event.parse_obj(json.loads(body.decode('utf-8')))
            self.callback(event)
        except Exception as e:
            logger.exception(f"Failed to process Event; message {e}")
        else:
            logger.debug("Event processed successful")
        try:
            self._channel.basic_ack(basic_deliver.delivery_tag)
        except Exception as e:
            # Never let an ack failure escape into pika's dispatch loop.
            logger.error(f"Failed to acknowledge Event; message {e}")

    def stop_consuming(self):
        """Tell RabbitMQ that you would like to stop consuming by sending the
        Basic.Cancel RPC command for every active consumer.

        """
        if self._channel:
            logger.debug('Sending a Basic.Cancel RPC command to RabbitMQ')
            # Iterate over a snapshot: on_cancelok removes tags from the list.
            for tag in list(self._consumer_tags):
                cb = functools.partial(self.on_cancelok, userdata=tag)
                self._channel.basic_cancel(tag, cb)

    def on_cancelok(self, _unused_frame, userdata):
        """pika callback: consumer ``userdata`` is cancelled.

        Once every consumer is cancelled, close the channel.
        """
        self._consumer_tags.remove(userdata)
        if not self._consumer_tags:
            self._consuming = False
            logger.info('RabbitMQ acknowledged the cancellation of all consumers')
            self.close_channel()

    def close_channel(self):
        """Call to close the channel with RabbitMQ cleanly by issuing the
        Channel.Close RPC command.

        """
        logger.info('Closing the channel')
        self._channel.close()

    def run(self):
        """Thread body: connect and run the ioloop, reconnecting until stopped.

        Per-connection state is reset before every attempt. Otherwise leftovers
        from a previous connection could prevent a later shutdown from
        finishing: a stale ``should_reconnect`` would trigger another reconnect,
        and stale consumer tags never receive a CancelOk.
        """
        while True:
            with self._lock:
                if self._shutdown_requested:
                    break
                self.should_reconnect = False
                self._consuming = False
                self._closing = False
                self._consumer_tags = []
                self._channel = None
                self._connection = self.connect()
            try:
                self._connection.ioloop.start()
            except (KeyboardInterrupt, SystemExit):
                self._stop()
                break
            if not self.should_reconnect or self._shutdown_requested:
                break
            time.sleep(1)

    def stop(self):
        """Request a clean shutdown of this consumer; safe to call from any thread.

        Marks the consumer as shut down, so run() will not reconnect. The actual
        teardown (_stop) is scheduled on the connection's ioloop via
        ``add_callback_threadsafe``, because pika objects must only be used
        from the ioloop thread. Returns immediately; call ``join()`` to wait for
        the thread to exit.
        """
        with self._lock:
            self._shutdown_requested = True
            self.should_reconnect = False
            connection = self._connection
        if connection is None:
            # run() has not connected yet; it checks the flag and exits.
            return
        try:
            connection.ioloop.add_callback_threadsafe(self._stop)
        except Exception as e:
            # The ioloop is no longer usable. If it has already exited, run()
            # sees the flag and exits.
            logger.warning('Could not schedule PowerConsumer shutdown: %s', e)

    def _stop(self):
        """Begin a clean shutdown. Must run on the ioloop thread.

        If consumers are active, they are cancelled. on_cancelok then closes the
        channel, on_channel_closed closes the connection, and
        on_connection_closed stops the ioloop. Otherwise the ioloop is stopped
        directly.
        """
        if self._closing:
            return
        self._closing = True
        logger.debug('Stopping')
        if self._consuming and self._channel is not None and self._channel.is_open:
            self.stop_consuming()
        else:
            self._connection.ioloop.stop()
        logger.info('PowerConsumer stopped')


# Latest Event received per MAC address. Written by event_callback, which runs
# on the consumer threads, and read by get_latest_event / get_latest_state.
state: Dict[str, Event] = {}
# Optional hook set by run_restart_consumer; event_callback calls it for every
# received event, whatever the consumer type.
custom_callback: Optional[Callable] = None

# At most one consumer thread per host type, (re)created by run_restart_consumer.
pi_consumer: Optional[PowerConsumer] = None
controller_consumer: Optional[PowerConsumer] = None
switch_consumer: Optional[PowerConsumer] = None


def event_callback(event: Event):
    """Forward ``event`` to the user hook, if one is set, then record it in ``state``.

    The hook runs first. While it runs, ``state`` still holds the previous event
    for this MAC. If the hook raises, the new event is not recorded.
    """
    logger.debug(f"Received Power event: {event}")
    hook = custom_callback
    if hook:
        hook(event)
    state[event.mac] = event


def run_restart_consumer(macs: List[str], _type: Type, callback: Callable = None):
    global custom_callback, pi_consumer, controller_consumer, switch_consumer
    custom_callback = callback

    def restart_consumer(consumer: PowerConsumer):
        if not consumer:
            consumer = PowerConsumer(queues=macs, callback=event_callback, type=_type)
            consumer.start()
            logger.info(f'Restart power consumer for type {_type.value}')
            return consumer
        elif collections.Counter(consumer.queues) != collections.Counter(macs):
            consumer.stop()
            consumer.join()
            consumer = PowerConsumer(queues=macs, callback=event_callback, type=_type)
            consumer.start()
            logger.info(f'Restart power consumer for type {_type.value}')
            return consumer
        else:
            return consumer

    match _type:
        case Type.pi:
            pi_consumer = restart_consumer(pi_consumer)
        case Type.controller:
            controller_consumer = restart_consumer(controller_consumer)
        case Type.switch:
            switch_consumer = restart_consumer(switch_consumer)
        case _:
            raise ValueError("Type is not known")


def get_latest_event(mac: str):
    """Return the most recent Event recorded for ``mac``, or None if there is none yet."""
    try:
        return state[mac]
    except KeyError:
        return None

def get_latest_state(mac: str):
    """Return the State of the most recent Event for ``mac``, or None if there is none yet."""
    latest = state.get(mac)
    if not latest:
        return None
    return latest.state