import collections import functools import json import logging import threading import time import traceback from datetime import datetime from enum import Enum from typing import List, Callable, Optional, Dict import pika import pika.exceptions from pydantic import BaseModel, validator from pilab.events import utils logger = logging.getLogger(__name__) logging.getLogger("pika").setLevel(logging.WARNING) class ActionType(str, Enum): POWER_ON = "POWER_ON" POWER_OFF = "POWER_OFF" START_KERNEL = "START_KERNEL" FETCH_WORKFLOW = "FETCH_WORKFLOW" DOWNLOAD_IMAGE = "DOWNLOAD_IMAGE" PRE_BOOT_CONFIG = "PRE_BOOT_CONFIGURATION" BOOT_IMAGE = "BOOT_IMAGE" CLOUD_INIT_CONFIG = "CLOUD_INIT_CONFIG" ANSIBLE_CONFIG = "ANSIBLE_CONFIG" class ActionState(str, Enum): ACTION_SUCCESS = "ACTION_SUCCESS" ACTION_IN_PROGRESS = "ACTION_IN_PROGRESS" ACTION_FAILED = "ACTION_FAILED" ACTION_PENDING = "ACTION_PENDING" class Properties(BaseModel): timestamp: float = None configured: Optional[bool] = False @validator("timestamp", pre=True, always=True) def validate_date(cls, value, values): return value or datetime.utcnow().timestamp() class Action(BaseModel): type: ActionType state: ActionState progress: Optional[int] message: Optional[str] class State(str, Enum): INACTIVE = "INACTIVE" PENDING = "PENDING" ACTIVE = "ACTIVE" FAILED = "FAILED" class Event(BaseModel): mac: str state: State action: Optional[Action] properties: Properties class Type(str, Enum): controller = "controller" pi = "pi" switch = "switch" def send_event(routing_key: str, event: Event): try: connection = utils.get_blocking_connection() channel = connection.channel() channel.queue_declare( queue=routing_key, durable=True, exclusive=False, auto_delete=False, arguments={"x-queue-type": "stream"} ) event_serialized = json.dumps(event.dict()).encode('utf-8') channel.basic_publish( exchange='', routing_key=routing_key, body=event_serialized) connection.close() except pika.exceptions.UnroutableError as e: print(f'Could not send event to Broker; message {e}') raise except Exception as e: logger.error(f'Error sending host event; message: {e}') raise def send_pi_event(mac: str, state: State, action: Action, props: Properties): routing_key = f"host.pi.{mac}" event = Event( mac=mac, state=state, action=action, properties=props ) send_event(routing_key=routing_key, event=event) def send_controller_event(mac: str, state: State): routing_key = f"host.controller.{mac}" event = Event( mac=mac, state=state, properties=Properties() ) send_event(routing_key=routing_key, event=event) def send_switch_event(mac: str, state: State): routing_key = f"host.switch.{mac}" event = Event( mac=mac, state=state, properties=Properties() ) send_event(routing_key=routing_key, event=event) class PowerConsumer(threading.Thread): def __init__(self, queues: List[str], callback: Callable, type: Type, amqp_url: str = utils.BROKER_URL, arguments: Dict = {"x-stream-offset": "last"}): """Create a new instance of the consumer class, passing in the AMQP URL used to connect to RabbitMQ. :param str amqp_url: The AMQP url to connect with """ super(PowerConsumer, self).__init__() self.should_reconnect = False self.was_consuming = False self._connection = None self._channel = None self._closing = False self._consumer_tags = [] self._url = amqp_url self._consuming = False # In production, experiment with higher prefetch values # for higher consumer throughput self._prefetch_count = 10 self.type = type self.callback = callback self.queues = queues self.arguments = arguments def connect(self): """This method connects to RabbitMQ, returning the connection handle. When the connection is established, the on_connection_open method will be invoked by pika. :rtype: pika.SelectConnection """ logger.info('Connecting to %s', self._url) return pika.SelectConnection( parameters=pika.URLParameters(self._url), on_open_callback=self.on_connection_open, on_open_error_callback=self.on_connection_open_error, on_close_callback=self.on_connection_closed) def close_connection(self): self._consuming = False if self._connection.is_closing or self._connection.is_closed: logger.info('Connection is closing or already closed') else: logger.info('Closing connection') self._connection.close() def on_connection_open(self, _unused_connection): logger.debug('Connection opened') self.open_channel() def on_connection_open_error(self, _unused_connection, err): logger.error('Connection open failed: %s', err) self.reconnect() def on_connection_closed(self, _unused_connection, reason): self._channel = None if self._closing: self._connection.ioloop.stop() else: logger.warning('Connection closed, reconnect necessary: %s', reason) self.reconnect() def reconnect(self): """Will be invoked if the connection can't be opened or is closed. Indicates that a reconnect is necessary then stops the ioloop. """ self.should_reconnect = True self.stop() def open_channel(self): logger.debug('Creating a new channel') self._connection.channel(on_open_callback=self.on_channel_open) def on_channel_open(self, channel): logger.debug('Channel opened') self._channel = channel self._channel.add_on_close_callback(self.on_channel_closed) self._channel.add_on_cancel_callback(self.on_consumer_cancelled) for q in self.queues: queue_name = f'host.{self.type.value}.{q}' self.setup_queue(queue_name) def on_channel_closed(self, channel, reason): logger.warning('Channel %i was closed: %s', channel, reason) self.close_connection() def setup_queue(self, queue_name): """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC command. When it is complete, the on_queue_declareok method will be invoked by pika. :param str|unicode queue_name: The name of the queue to declare. """ logger.debug('Declaring queue %s', queue_name) cb = functools.partial(self.on_queue_declareok, queue_name=queue_name) self._channel.queue_declare(queue=queue_name, durable=True, exclusive=False, auto_delete=False, arguments={"x-queue-type": "stream"}, callback=cb) def on_queue_declareok(self, _unused_frame, queue_name): logger.info('Queue bound: %s', queue_name) self.set_qos(queue_name) def set_qos(self, queue_name): cb = functools.partial(self.on_basic_qos_ok, queue_name=queue_name) self._channel.basic_qos( prefetch_count=self._prefetch_count, callback=cb) def on_basic_qos_ok(self, _unused_frame, queue_name): logger.debug('QOS set to: %d', self._prefetch_count) self.start_consuming(queue_name) def start_consuming(self, queue_name): logger.debug('Issuing consumer related RPC commands') _consumer_tag = self._channel.basic_consume( queue_name, self.on_message, auto_ack=False, arguments=self.arguments) self._consumer_tags.append(_consumer_tag) self.was_consuming = True self._consuming = True def on_consumer_cancelled(self, method_frame): """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer receiving messages. :param pika.frame.Method method_frame: The Basic.Cancel frame """ logger.info('Consumer was cancelled remotely, shutting down: %r', method_frame) if self._channel: self._channel.close() def on_message(self, _unused_channel, basic_deliver, properties, body): try: event = Event.parse_obj(json.loads(body.decode('utf-8'))) self.callback(event) self._channel.basic_ack(basic_deliver.delivery_tag) logger.debug(f"Event processed successful") except Exception as e: traceback.print_exc() logger.error(f"Failed to process Event; message {e}") def stop_consuming(self): """Tell RabbitMQ that you would like to stop consuming by sending the Basic.Cancel RPC command. """ if self._channel: logger.debug('Sending a Basic.Cancel RPC command to RabbitMQ') for tag in self._consumer_tags: cb = functools.partial(self.on_cancelok, userdata=tag) self._channel.basic_cancel(tag, cb) def on_cancelok(self, _unused_frame, userdata): self._consumer_tags.remove(userdata) if not self._consumer_tags: self._consuming = False logger.info('RabbitMQ acknowledged the cancellation of all consumers') self.close_channel() def close_channel(self): """Call to close the channel with RabbitMQ cleanly by issuing the Channel.Close RPC command. """ logger.info('Closing the channel') self._channel.close() def run(self): while True: try: self._connection = self.connect() self._connection.ioloop.start() except (KeyboardInterrupt, SystemExit): self.stop() break if not self.should_reconnect: break self._consuming = False self._closing = False time.sleep(1) def stop(self): """Cleanly shutdown the connection to RabbitMQ by stopping the consumer with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok will be invoked by pika, which will then closing the channel and connection. The IOLoop is started again because this method is invoked when CTRL-C is pressed raising a KeyboardInterrupt exception. This exception stops the IOLoop which needs to be running for pika to communicate with RabbitMQ. All of the commands issued prior to starting the IOLoop will be buffered but not processed. """ if not self._closing: self._closing = True logger.debug('Stopping') if self._consuming: self.stop_consuming() else: self._connection.ioloop.stop() logger.info('PowerConsumer stopped') state: Dict[str, Event] = {} custom_callback: Callable = None pi_consumer: PowerConsumer = None controller_consumer: PowerConsumer = None switch_consumer: PowerConsumer = None def event_callback(event: Event): logger.debug(f"Received Power event: {event}") if custom_callback: custom_callback(event) state[event.mac] = event def run_restart_consumer(macs: List[str], _type: Type, callback: Callable = None): global custom_callback, pi_consumer, controller_consumer, switch_consumer custom_callback = callback def restart_consumer(consumer: PowerConsumer): if not consumer: consumer = PowerConsumer(queues=macs, callback=event_callback, type=_type) consumer.start() logger.info(f'Restart power consumer for type {_type.value}') return consumer elif collections.Counter(consumer.queues) != collections.Counter(macs): consumer.stop() consumer.join() consumer = PowerConsumer(queues=macs, callback=event_callback, type=_type) consumer.start() logger.info(f'Restart power consumer for type {_type.value}') return consumer else: return consumer match _type: case Type.pi: pi_consumer = restart_consumer(pi_consumer) case Type.controller: controller_consumer = restart_consumer(controller_consumer) case Type.switch: switch_consumer = restart_consumer(switch_consumer) case _: raise ValueError("Type is not known") def get_latest_event(mac: str): return state.get(mac) def get_latest_state(mac: str): event = state.get(mac) return event.state if event else None