import json
import logging
import threading
import time
import traceback
from enum import Enum
from ipaddress import IPv4Address
from typing import List, Optional, Callable, Dict

import pika
import pika.exceptions
from pydantic import BaseModel

from pilab.events import utils

logger = logging.getLogger(__name__)


# Default RabbitMQ queue consumed by MetaListener; it is declared as a stream
# queue (``x-queue-type: stream``) when the listener connects.
STREAM_NAME = "meta.cube"


def convert_to_optional(schema):
    """
    Build an annotations mapping in which every field of ``schema`` is optional.

    The mapping is taken from ``schema.__annotations__`` in declaration order;
    inherited annotations are not merged in. Each type ``T`` becomes
    ``Optional[T]``.

    :param schema: a class whose ``__annotations__`` are read.
    :return: a new dict mapping field name to ``Optional[<original type>]``.
    """
    optional_fields = {}
    for field_name, field_type in schema.__annotations__.items():
        optional_fields[field_name] = Optional[field_type]
    return optional_fields


class Host(BaseModel):
    """
    Representation of a host object together with its id.

    Used directly for the ``controller`` and ``switch`` of a :class:`Cube`
    and as the base model of :class:`Pi`.
    """
    id: int
    mac: str  # kept as a plain string; no MAC-format validation is declared here
    ipv4_address: IPv4Address
    hostname: str


class HostUpdate(Host):
    """
    Representation of a host for updates.

    Every field declared on :class:`Host` is re-annotated as ``Optional[...]``
    by replacing ``__annotations__`` in the class body, presumably so an update
    payload may omit fields that are not being changed.
    """
    # NOTE(review): this relies on pydantic reading ``__annotations__`` from the
    # class namespace. In pydantic v1 an ``Optional`` field without a default
    # becomes non-required with default ``None``; in pydantic v2 it would stay
    # required — confirm before upgrading pydantic.
    __annotations__ = convert_to_optional(Host)


class Pi(Host):
    """
    Representation of a pi: the :class:`Host` fields plus pi-specific attributes.

    Used for the ``head`` and ``workers`` of a :class:`Cube`.
    """
    serial: bytes
    display: bool
    # NOTE(review): presumably the pi's slot/index within its cube — confirm.
    position: int
    # NOTE(review): presumably the pi's ed25519 SSH host public key; may be None.
    ssh_host_ed25519_key: Optional[str]


class PiUpdate(HostUpdate):
    """
    Representation of a pi object for updates.

    The host fields are inherited already optional from :class:`HostUpdate`;
    the fields declared in the :class:`Pi` class body are re-annotated here as
    ``Optional[...]``.
    """
    # convert_to_optional reads only ``Pi.__annotations__`` (the annotations in
    # Pi's own class body), which is why this class derives from HostUpdate
    # rather than Pi: the optional host fields come from the base class.
    __annotations__ = convert_to_optional(Pi)


class Cube(BaseModel):
    """
    Representation of a cube object with its nested hosts.

    This is the payload type carried by :class:`Event`.
    """
    id: int
    controller: Host
    switch: Host
    head: Pi
    workers: List[Pi]


class EventType(str, Enum):
    """
    Kind of change reported by an :class:`Event`.

    Because the enum also subclasses ``str``, members compare equal to their
    string values (e.g. ``EventType.CREATE == 'create'``).
    """
    CREATE = 'create'
    UPDATE = 'update'
    DELETE = 'delete'


class Event(BaseModel):
    """
    A single cube event as read from the stream by ``MetaListener``.

    The JSON body of each stream message is validated into this model.
    """
    type: EventType
    # NOTE(review): presumably a Unix timestamp in seconds — confirm with the producer.
    timestamp: float
    payload: Cube


class MetaListener(threading.Thread):
    """
    Thread that consumes cube :class:`Event` messages from a RabbitMQ stream queue.

    Each message is parsed into an :class:`Event` and passed to ``callback``
    together with its stream offset. On broker/connection/channel errors, or
    when the callback fails, the current connection is closed and the listener
    reconnects after ``_retry_delay`` seconds, resubscribing at ``self.offset``.
    """

    def __init__(self, callback: Callable, queue: str = STREAM_NAME, offset="first"):
        """
        :param callback: called as ``callback(event, offset)`` for every message.
            If it raises, the message is not acknowledged and the listener
            reconnects.
        :param queue: name of the stream queue to consume from.
        :param offset: value sent as the ``x-stream-offset`` consumer argument on
            the first subscription (e.g. ``"first"``); replaced by the offset of
            each successfully processed message.
        """
        super().__init__()
        self.queue: str = queue
        self.callback: Callable = callback
        self.offset = offset
        self.channel = None
        self.connection = None
        self._prefetch_count = 1
        # Seconds to wait before reconnecting after any error.
        self._retry_delay = 1

    def default_callback(self, ch, method, properties, body):
        """
        pika ``on_message_callback``: decode, dispatch and acknowledge one message.

        The body is decoded as UTF-8 JSON and validated into an :class:`Event`;
        the offset is read from the ``x-stream-offset`` message header. The
        message is acknowledged only after ``self.callback`` returns, and
        ``self.offset`` is advanced only after that.

        Any error is logged and re-raised, which makes ``run`` reconnect. Since
        ``self.offset`` was not advanced, a message that keeps failing is
        delivered again on every reconnect.
        """
        try:
            event = Event.parse_obj(json.loads(body.decode('utf-8')))
            offset = properties.headers['x-stream-offset']
            self.callback(event, offset)
            # Acknowledge on the channel the delivery arrived on.
            ch.basic_ack(method.delivery_tag)
            self.offset = offset
            logger.info("Event processed succesful")
        except Exception as e:
            # logger.exception routes the traceback through logging instead of
            # printing it straight to stderr.
            logger.exception(f"Failed to process Meta event; message {e}")
            raise

    def run(self):
        """
        Consume the stream until ``SystemExit`` is raised while consuming.

        Each (re)connection declares the durable stream queue, limits unacked
        deliveries to ``self._prefetch_count`` and subscribes at ``self.offset``.
        After any error, or if consuming stops on its own, the current
        connection is closed before retrying so connections are not leaked.
        """
        while True:
            try:
                self.connection = utils.get_blocking_connection()
                self.channel = self.connection.channel()
                self.channel.queue_declare(queue=self.queue, durable=True, exclusive=False, auto_delete=False,
                                           arguments={"x-queue-type": "stream"})
                self.channel.basic_qos(prefetch_count=self._prefetch_count)
                # NOTE(review): a numeric x-stream-offset appears to attach *at* that offset,
                # so after a reconnect the last successfully processed event would be
                # delivered again — confirm the callback tolerates duplicates.
                self.channel.basic_consume(queue=self.queue, on_message_callback=self.default_callback, auto_ack=False,
                                           arguments={"x-stream-offset": self.offset})

                try:
                    self.channel.start_consuming()
                except SystemExit:
                    logger.info("System exit. Close Broker Connection...")
                    self.channel.stop_consuming()
                    self.connection.close()
                    break
            except pika.exceptions.ConnectionClosedByBroker:
                logger.error("Connection was closed, retrying...")
            # Channel errors are retried too, with the same back-off, so a
            # persistent failure cannot spin the loop without pausing.
            except pika.exceptions.AMQPChannelError as err:
                logger.error(f"Caught a channel error: {err}, retrying...")
            # Recover on all other connection errors
            except pika.exceptions.AMQPConnectionError:
                logger.error("Connection closed unexpected, retrying...")
            except Exception:
                logger.exception("Unexpected error occured, retrying...")
            self._close_connection()
            time.sleep(self._retry_delay)

    def _close_connection(self):
        """Best-effort close of the current broker connection before reconnecting."""
        connection = self.connection
        if connection is None or connection.is_closed:
            return
        try:
            connection.close()
        except Exception:
            # The connection is being abandoned anyway; failing to close it
            # cleanly must not stop the reconnect loop.
            logger.debug("Failed to close broker connection", exc_info=True)