import functools
import json
from json import JSONEncoder
import logging
import threading
import time
import traceback
from collections import Counter
from datetime import datetime, date
from zoneinfo import ZoneInfo
from typing import List, Callable, Optional, Dict

import pika
import pika.exceptions
from pydantic import BaseModel

from pilab.events import utils

logger = logging.getLogger(__name__)
logging.getLogger("pika").setLevel(logging.WARNING)


class Reservation(BaseModel):
    owner: str
    cube_id: int
    starttime: datetime
    endtime: datetime
    extraUsers: Optional[List[str]] = None
    reason: str


class DateTimeEncoder(JSONEncoder):
    # Override the default method
    def default(self, obj):
        if isinstance(obj, (date, datetime)):
            return obj.isoformat()
        return super().default(obj)
    

def send_event(event: Reservation):
    routing_key = f"res.cube.{event.cube_id}"
    try:
        connection = utils.get_blocking_connection()
        channel = connection.channel()

        channel.queue_declare(
            queue=routing_key,
            durable=True,
            exclusive=False,
            auto_delete=False,
            arguments={"x-queue-type": "stream"}
        )

        event_serialized = json.dumps(event.dict(), cls=DateTimeEncoder).encode('utf-8')

        channel.basic_publish(
            exchange='',
            routing_key=routing_key,
            body=event_serialized)

        connection.close()
    except pika.exceptions.UnroutableError as e:
        logger.error(f'Could not send event to broker; message: {e}')
        raise
    except Exception as e:
        logger.error(f'Error sending reservation event; message: {e}')
        raise
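

# Example (sketch, not executed): publishing a reservation event. This assumes
# a broker reachable via pilab.events.utils and uses purely illustrative
# field values.
#
#   from datetime import timedelta
#
#   now = datetime.now(ZoneInfo("UTC"))
#   reservation = Reservation(
#       owner="alice",
#       cube_id=3,
#       starttime=now,
#       endtime=now + timedelta(hours=2),
#       extraUsers=["bob"],
#       reason="firmware testing",
#   )
#   send_event(reservation)  # publishes to the stream queue "res.cube.3"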


class ReservationConsumer(threading.Thread):

    def __init__(self, cubes: List[int], callback: Callable, amqp_url: str = utils.BROKER_URL,
                 arguments: Optional[Dict] = None):
        """Create a new instance of the consumer class, passing in the AMQP
        URL used to connect to RabbitMQ.

        :param list cubes: Ids of the cubes whose reservation streams to consume
        :param callable callback: Invoked with every parsed Reservation
        :param str amqp_url: The AMQP url to connect with
        :param dict arguments: Consumer arguments; defaults to
            {"x-stream-offset": "last"}

        """
        super().__init__()
        self.should_reconnect = False
        self.was_consuming = False

        self._connection = None
        self._channel = None
        self._closing = False
        self._consumer_tags = []
        self._url = amqp_url
        self._consuming = False
        # In production, experiment with higher prefetch values
        # for higher consumer throughput
        self._prefetch_count = 10

        self.callback = callback
        self.cubes = cubes
        self.arguments = arguments if arguments is not None else {"x-stream-offset": "last"}

    def connect(self):
        """This method connects to RabbitMQ, returning the connection handle.
        When the connection is established, the on_connection_open method
        will be invoked by pika.

        :rtype: pika.SelectConnection

        """
        logger.info('Connecting to %s', self._url)

        return pika.SelectConnection(
            parameters=pika.URLParameters(self._url),
            on_open_callback=self.on_connection_open,
            on_open_error_callback=self.on_connection_open_error,
            on_close_callback=self.on_connection_closed)

    def close_connection(self):
        self._consuming = False
        if self._connection.is_closing or self._connection.is_closed:
            logger.info('Connection is closing or already closed')
        else:
            logger.info('Closing connection')
            self._connection.close()

    def on_connection_open(self, _unused_connection):
        logger.debug('Connection opened')
        self.open_channel()

    def on_connection_open_error(self, _unused_connection, err):
        logger.error('Connection open failed: %s', err)
        self.reconnect()

    def on_connection_closed(self, _unused_connection, reason):
        self._channel = None
        if self._closing:
            self._connection.ioloop.stop()
        else:
            logger.warning('Connection closed, reconnect necessary: %s', reason)
            self.reconnect()

    def reconnect(self):
        """Will be invoked if the connection can't be opened or is
        closed. Indicates that a reconnect is necessary then stops the
        ioloop.

        """
        self.should_reconnect = True
        self.stop()

    def open_channel(self):
        logger.debug('Creating a new channel')
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel):
        logger.debug('Channel opened')
        self._channel = channel
        self._channel.add_on_close_callback(self.on_channel_closed)
        self._channel.add_on_cancel_callback(self.on_consumer_cancelled)
        for cube in self.cubes:
            queue_name = f'res.cube.{str(cube)}'
            self.setup_queue(queue_name)

    def on_channel_closed(self, channel, reason):
        logger.warning('Channel %s was closed: %s', channel, reason)
        self.close_connection()

    def setup_queue(self, queue_name):
        """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
        command. When it is complete, the on_queue_declareok method will
        be invoked by pika.

        :param str|unicode queue_name: The name of the queue to declare.

        """
        logger.debug('Declaring queue %s', queue_name)
        cb = functools.partial(self.on_queue_declareok, queue_name=queue_name)
        self._channel.queue_declare(queue=queue_name,
                                    durable=True,
                                    exclusive=False,
                                    auto_delete=False,
                                    arguments={"x-queue-type": "stream"},
                                    callback=cb)

    def on_queue_declareok(self, _unused_frame, queue_name):
        logger.info('Queue declared: %s', queue_name)
        self.set_qos(queue_name)

    def set_qos(self, queue_name):
        cb = functools.partial(self.on_basic_qos_ok, queue_name=queue_name)
        self._channel.basic_qos(
            prefetch_count=self._prefetch_count, callback=cb)

    def on_basic_qos_ok(self, _unused_frame, queue_name):
        logger.debug('QOS set to: %d', self._prefetch_count)
        self.start_consuming(queue_name)

    def start_consuming(self, queue_name):
        logger.debug('Issuing consumer related RPC commands')
        _consumer_tag = self._channel.basic_consume(
            queue_name,
            self.on_message,
            auto_ack=False,
            arguments=self.arguments)
        self._consumer_tags.append(_consumer_tag)
        self.was_consuming = True
        self._consuming = True

    def on_consumer_cancelled(self, method_frame):
        """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
        receiving messages.

        :param pika.frame.Method method_frame: The Basic.Cancel frame

        """
        logger.info('Consumer was cancelled remotely, shutting down: %r',
                    method_frame)
        if self._channel:
            self._channel.close()

    def on_message(self, _unused_channel, basic_deliver, properties, body):
        try:
            reservation = Reservation.parse_obj(json.loads(body.decode('utf-8')))
            self.callback(reservation)
            self._channel.basic_ack(basic_deliver.delivery_tag)
            logger.debug("Event processed successfully")
        except (TypeError, ValueError) as e:
            # Ack malformed events so they do not block the consumer's prefetch window.
            traceback.print_exc()
            self._channel.basic_ack(basic_deliver.delivery_tag)
            logger.warning(f"Event could not be parsed; message: {e}")
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Failed to process event; message: {e}")

    def stop_consuming(self):
        """Tell RabbitMQ that you would like to stop consuming by sending the
        Basic.Cancel RPC command.

        """
        if self._channel:
            logger.debug('Sending a Basic.Cancel RPC command to RabbitMQ')
            for tag in self._consumer_tags:
                cb = functools.partial(self.on_cancelok, userdata=tag)
                self._channel.basic_cancel(tag, cb)

    def on_cancelok(self, _unused_frame, userdata):
        self._consumer_tags.remove(userdata)
        if not self._consumer_tags:
            self._consuming = False
            logger.info('RabbitMQ acknowledged the cancellation of all consumers')
            self.close_channel()

    def close_channel(self):
        """Call to close the channel with RabbitMQ cleanly by issuing the
        Channel.Close RPC command.

        """
        logger.info('Closing the channel')
        self._channel.close()

    def run(self):
        while True:
            try:
                self._connection = self.connect()
                self._connection.ioloop.start()
            except (KeyboardInterrupt, SystemExit):
                self.stop()
                break
            if not self.should_reconnect:
                break

            # Reset state before the reconnect attempt.
            self.should_reconnect = False
            self._consuming = False
            self._closing = False
            time.sleep(1)

    def stop(self):
        """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
        with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok
        will be invoked by pika, which will then closing the channel and
        connection. The IOLoop is started again because this method is invoked
        when CTRL-C is pressed raising a KeyboardInterrupt exception. This
        exception stops the IOLoop which needs to be running for pika to
        communicate with RabbitMQ. All of the commands issued prior to starting
        the IOLoop will be buffered but not processed.

        """
        if not self._closing:
            self._closing = True
            logger.debug('Stopping')
            if self._consuming:
                self.stop_consuming()
            else:
                self._connection.ioloop.stop()
            logger.info('ReservationConsumer stopped')
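

# Example (sketch, not executed): consuming reservation streams directly with a
# custom stream offset. "first" replays the whole stream instead of starting at
# the newest message; the callback below is purely illustrative.
#
#   def on_reservation(res: Reservation):
#       print(f"cube {res.cube_id} reserved by {res.owner}")
#
#   consumer = ReservationConsumer(cubes=[1, 2], callback=on_reservation,
#                                  arguments={"x-stream-offset": "first"})
#   consumer.start()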


cube_reservations: Dict[int, Reservation] = {}
res_thread: Optional[ReservationConsumer] = None
custom_callback: Optional[Callable] = None


def get_active_users(cube_id: int):
    if not res_thread:
        raise RuntimeError("Reservation consumer is not initialized")

    res: Optional[Reservation] = cube_reservations.get(cube_id)
    if res and res.endtime > datetime.now(ZoneInfo('UTC')):
        if res.extraUsers is None:
            return [res.owner]
        return [*res.extraUsers, res.owner]
    return []


def get_active_reservation(cube_id: int):
    if not res_thread:
        raise RuntimeError("Reservation consumer is not initialized")

    res: Optional[Reservation] = cube_reservations.get(cube_id)
    if res and res.endtime > datetime.now(ZoneInfo('UTC')):
        return res
    return None


def restart_reservation_consumer(ids: List[int], callback: Optional[Callable] = None):
    global res_thread
    global custom_callback

    custom_callback = callback

    def res_event_callback(res: Reservation):
        logger.info(f'Received Reservation for cube {res.cube_id}; {res}')
        cube_reservations[res.cube_id] = res
        if custom_callback:
            custom_callback(res)
        

    if not res_thread:
        res_thread = ReservationConsumer(cubes=ids, callback=res_event_callback)
        res_thread.start()
    elif Counter(ids) != Counter(res_thread.cubes):
        res_thread.stop()
        res_thread.join()
        res_thread = ReservationConsumer(cubes=ids, callback=res_event_callback)
        res_thread.start()

    return res_thread
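

if __name__ == "__main__":
    # Minimal usage sketch, assuming a RabbitMQ broker is reachable at
    # utils.BROKER_URL and that reservation events are published for cube 1.
    logging.basicConfig(level=logging.INFO)

    def log_reservation(res: Reservation):
        # Illustrative callback: just log every reservation event we receive.
        logger.info("New reservation received: %s", res)

    consumer = restart_reservation_consumer([1], callback=log_reservation)

    try:
        while True:
            logger.info("Active users on cube 1: %s", get_active_users(1))
            time.sleep(30)
    except KeyboardInterrupt:
        consumer.stop()
        consumer.join()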