Skip to content
Snippets Groups Projects
power.py 12.8 KiB
Newer Older
  • Learn to ignore specific revisions
    import collections
    import functools
    import json
    import logging
    import threading
    import time
    import traceback
    from datetime import datetime, timezone
    from enum import Enum
    from typing import List, Callable, Optional, Dict

    import pika
    from pydantic import BaseModel, validator

    from pilab.events import utils
    
    # Drop pika's DEBUG/INFO records; only WARNING and above get through.
    logging.getLogger("pika").setLevel(logging.WARNING)
    # Logger used by every function and class in this module.
    logger = logging.getLogger(__name__)
    
    
    
    class ActionType(str, Enum):
        """Kind of step reported in ``Action.type``.

        Members derive from ``str``, so they compare equal to and JSON-serialize
        as their string values.
        """
        POWER_ON = "POWER_ON"
        POWER_OFF = "POWER_OFF"
        START_KERNEL = "START_KERNEL"
        FETCH_WORKFLOW = "FETCH_WORKFLOW"
        DOWNLOAD_IMAGE = "DOWNLOAD_IMAGE"
        # NOTE(review): unlike the other members, the value differs from the
        # name ("...CONFIGURATION" vs "..._CONFIG"); presumably this is the wire
        # value other services expect — confirm before changing either side.
        PRE_BOOT_CONFIG = "PRE_BOOT_CONFIGURATION"
        BOOT_IMAGE = "BOOT_IMAGE"
        CLOUD_INIT_CONFIG = "CLOUD_INIT_CONFIG"
        ANSIBLE_CONFIG = "ANSIBLE_CONFIG"
    
    class ActionState(str, Enum):
        """Progress state of a single ``Action`` (see ``Action.state``).

        Members derive from ``str`` and serialize as their string values.
        """
        ACTION_SUCCESS = "ACTION_SUCCESS"
        ACTION_IN_PROGRESS = "ACTION_IN_PROGRESS"
        ACTION_FAILED = "ACTION_FAILED"
        ACTION_PENDING = "ACTION_PENDING"
    
    
    class Properties(BaseModel):
        """Metadata attached to every host event.

        ``timestamp`` is a POSIX timestamp in seconds. When the caller passes no
        timestamp (or a falsy one), the validator fills in the current time at
        construction.
        """

        timestamp: Optional[float] = None
        configured: Optional[bool] = False

        @validator("timestamp", pre=True, always=True)
        def validate_date(cls, value, values):
            # datetime.utcnow() returns a *naive* datetime, and .timestamp()
            # interprets naive datetimes as local time, so the old code was off
            # by the host's UTC offset. An aware UTC "now" gives the true epoch.
            return value or datetime.now(timezone.utc).timestamp()
    
    
    class Action(BaseModel):
        """A step of type ``type`` and its current ``state``.

        ``progress`` and ``message`` are optional extras. ``progress`` is
        presumably a percentage — TODO confirm with producers.
        """

        type: ActionType
        state: ActionState
        # Explicit defaults: pydantic v1 already treats these as optional, but
        # v2 would make an un-defaulted Optional[...] field required.
        progress: Optional[int] = None
        message: Optional[str] = None
    
    
    class State(str, Enum):
        """Overall state of a host, carried in ``Event.state``.

        Members derive from ``str`` and serialize as their string values.
        """
        INACTIVE = "INACTIVE"
        PENDING = "PENDING"
        ACTIVE = "ACTIVE"
        FAILED = "FAILED"
    
    class Event(BaseModel):
        """Host status event, published as JSON on a ``host.<type>.<mac>`` stream.

        ``action`` is optional: controller and switch events are sent without one.
        """

        mac: str
        state: State

        # Explicit default keeps the field optional under pydantic v2 as well
        # (v1 already defaulted un-annotated Optional fields to None).
        action: Optional[Action] = None
        properties: Properties
    
    
    class Type(str, Enum):
        """Host type; its value is the middle segment of ``host.<type>.<mac>``.

        PowerConsumer builds queue names from ``type.value``; the send_*_event
        helpers hard-code the same literals ("pi", "controller", "switch").
        """
        controller = "controller"
        pi = "pi"
        switch = "switch"
    
    
    
    def send_event(routing_key: str, event: Event):
        """Publish ``event`` as JSON to the stream queue named ``routing_key``.

        A fresh connection is opened per call. The queue is declared first (a
        durable stream queue with the same arguments the consumer declares), and
        the message is published through the default exchange using the queue
        name as routing key. The connection is closed even when a step fails.

        :param routing_key: Queue name, e.g. ``host.pi.<mac>``.
        :param event: Event serialized via ``json.dumps(event.dict())``.
        :raises pika.exceptions.UnroutableError: logged and re-raised.
        :raises Exception: any other error is logged and re-raised.
        """
        connection = None
        try:
            connection = utils.get_blocking_connection()
            channel = connection.channel()

            channel.queue_declare(
                queue=routing_key,
                durable=True,
                exclusive=False,
                auto_delete=False,
                arguments={"x-queue-type": "stream"}
            )

            event_serialized = json.dumps(event.dict()).encode('utf-8')

            channel.basic_publish(
                exchange='',
                routing_key=routing_key,
                body=event_serialized)
        except pika.exceptions.UnroutableError as e:
            # Report through the module logger like the generic branch (was print()).
            logger.error(f'Could not send event to Broker; message {e}')
            raise
        except Exception as e:
            logger.error(f'Error sending host event; message: {e}')
            raise
        finally:
            # Previously the connection leaked whenever any step above raised.
            # Assumes utils returns a pika BlockingConnection (has is_open).
            if connection is not None and connection.is_open:
                connection.close()
    
    
    
    def send_pi_event(mac: str, state: State, action: Action, props: Properties):
        """Publish a Pi host event on the ``host.pi.<mac>`` stream.

        :param mac: MAC address of the Pi; also the queue-name suffix.
        :param state: Overall host state to report.
        :param action: Action attached to the event.
        :param props: Event properties.
        """
        send_event(
            routing_key=f"host.pi.{mac}",
            event=Event(mac=mac, state=state, action=action, properties=props),
        )
    
    
    def send_controller_event(mac: str, state: State):
        """Publish a controller host event (no action) on ``host.controller.<mac>``.

        The event carries freshly created Properties (timestamp = now).
        """
        queue = f"host.controller.{mac}"
        payload = Event(mac=mac, state=state, properties=Properties())
        send_event(routing_key=queue, event=payload)
    
    
    def send_switch_event(mac: str, state: State):
        """Publish a switch host event (no action) on ``host.switch.<mac>``.

        The event carries freshly created Properties (timestamp = now).
        """
        send_event(
            routing_key=f"host.switch.{mac}",
            event=Event(mac=mac, state=state, properties=Properties()),
        )
    
    class PowerConsumer(threading.Thread):
        """Background thread consuming host events from RabbitMQ stream queues.

        For every entry ``q`` in ``queues`` the stream queue ``host.<type>.<q>``
        is declared and consumed with manual acknowledgements. Each message body
        is parsed as an :class:`Event` and handed to ``callback``. The connection
        is a pika ``SelectConnection`` whose ioloop runs inside :meth:`run`. When
        the connection is lost or cannot be opened, :meth:`run` waits one second
        and connects again. Call :meth:`stop` (from any thread), then ``join()``.
        """

        def __init__(self, queues: List[str], callback: Callable, type: Type, amqp_url: str = utils.BROKER_URL,
                     arguments: Optional[Dict] = None):
            """Create a new consumer; nothing connects until the thread is started.

            :param queues: Queue-name suffixes to consume (``host.<type>.<suffix>``).
            :param callback: Called with every successfully parsed Event.
            :param type: Host type whose ``value`` is the middle part of each
                queue name.
            :param str amqp_url: The AMQP url to connect with.
            :param arguments: Arguments passed to ``basic_consume``; defaults to
                ``{"x-stream-offset": "last"}`` when omitted.
            """
            super().__init__()

            self.should_reconnect = False
            self.was_consuming = False

            self._connection = None
            self._channel = None
            self._closing = False
            # Set by stop(); makes run() exit instead of reconnecting.
            self._stop_requested = threading.Event()
            self._consumer_tags = []
            self._url = amqp_url
            self._consuming = False
            # In production, experiment with higher prefetch values
            # for higher consumer throughput
            self._prefetch_count = 10

            self.callback = callback
            self.queues = queues
            # Needed by on_channel_open to build queue names (was never set).
            self.type = type
            # Fresh dict per instance instead of a shared mutable default.
            self.arguments = {"x-stream-offset": "last"} if arguments is None else arguments

        def connect(self):
            """Create the SelectConnection; pika invokes on_connection_open once it is up.

            :rtype: pika.SelectConnection
            """
            logger.info('Connecting to %s', self._url)

            return pika.SelectConnection(
                parameters=pika.URLParameters(self._url),
                on_open_callback=self.on_connection_open,
                on_open_error_callback=self.on_connection_open_error,
                on_close_callback=self.on_connection_closed)

        def close_connection(self):
            """Close the connection unless it is already closing or closed."""
            self._consuming = False
            if self._connection.is_closing or self._connection.is_closed:
                logger.info('Connection is closing or already closed')
            else:
                logger.info('Closing connection')
                self._connection.close()

        def on_connection_open(self, _unused_connection):
            """Open a channel, or shut down if stop() was requested meanwhile."""
            logger.debug('Connection opened')
            if self._stop_requested.is_set():
                # stop() arrived before this connection existed (e.g. during a
                # reconnect); close it instead of starting to consume.
                self._closing = True
                self.close_connection()
                return
            self.open_channel()

        def on_connection_open_error(self, _unused_connection, err):
            """Log the failed connection attempt and schedule a reconnect."""
            logger.error('Connection open failed: %s', err)
            self.reconnect()

        def on_connection_closed(self, _unused_connection, reason):
            """Stop the ioloop on a requested close; otherwise request a reconnect."""
            self._channel = None
            if self._closing:
                self._connection.ioloop.stop()
            else:
                logger.warning('Connection closed, reconnect necessary: %s', reason)
                self.reconnect()

        def reconnect(self):
            """Mark that a reconnect is necessary and stop the current ioloop.

            Deliberately does not go through stop(), which would also request a
            permanent shutdown.
            """
            self.should_reconnect = True
            self._begin_shutdown()

        def open_channel(self):
            """Ask pika for a new channel; on_channel_open runs when it is ready."""
            logger.debug('Creating a new channel')

            self._connection.channel(on_open_callback=self.on_channel_open)

        def on_channel_open(self, channel):
            """Register channel callbacks and start setting up every queue."""
            logger.debug('Channel opened')

            self._channel = channel

            self._channel.add_on_close_callback(self.on_channel_closed)
            self._channel.add_on_cancel_callback(self.on_consumer_cancelled)

            for q in self.queues:
                queue_name = f'host.{self.type.value}.{q}'
                # Previously the name was computed but never declared/consumed.
                self.setup_queue(queue_name)

        def on_channel_closed(self, channel, reason):
            """A closed channel ends this connection as well."""
            logger.warning('Channel %i was closed: %s', channel, reason)
            self.close_connection()

        def setup_queue(self, queue_name):
            """Declare ``queue_name`` as a durable stream queue.

            on_queue_declareok is invoked by pika when the declaration completes.

            :param str queue_name: The name of the queue to declare.
            """
            logger.debug('Declaring queue %s', queue_name)

            cb = functools.partial(self.on_queue_declareok, queue_name=queue_name)
            self._channel.queue_declare(queue=queue_name,
                                        durable=True,
                                        exclusive=False,
                                        auto_delete=False,
                                        arguments={"x-queue-type": "stream"},
                                        callback=cb)

        def on_queue_declareok(self, _unused_frame, queue_name):
            """Queue exists: continue with the prefetch (QoS) setting."""
            logger.info('Queue bound: %s', queue_name)
            self.set_qos(queue_name)

        def set_qos(self, queue_name):
            """Set the channel prefetch count, then continue with on_basic_qos_ok."""
            cb = functools.partial(self.on_basic_qos_ok, queue_name=queue_name)
            self._channel.basic_qos(
                prefetch_count=self._prefetch_count, callback=cb)

        def on_basic_qos_ok(self, _unused_frame, queue_name):
            """QoS applied: start consuming ``queue_name``."""
            logger.debug('QOS set to: %d', self._prefetch_count)

            self.start_consuming(queue_name)

        def start_consuming(self, queue_name):
            """Start a manual-ack consumer on ``queue_name`` and remember its tag."""
            logger.debug('Issuing consumer related RPC commands')

            consumer_tag = self._channel.basic_consume(
                queue_name,
                self.on_message,
                auto_ack=False,
                arguments=self.arguments)

            self._consumer_tags.append(consumer_tag)
            self.was_consuming = True
            self._consuming = True

        def on_consumer_cancelled(self, method_frame):
            """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer.

            :param pika.frame.Method method_frame: The Basic.Cancel frame
            """
            logger.info('Consumer was cancelled remotely, shutting down: %r',
                        method_frame)
            if self._channel:
                self._channel.close()

        def on_message(self, _unused_channel, basic_deliver, properties, body):
            """Parse the body as an Event, pass it to the callback, then ack.

            Failed deliveries are logged and acked as well. The queues are
            streams, where an ack does not delete anything; it only frees a
            prefetch slot. Leaving failures unacked would stall the consumer once
            ``_prefetch_count`` of them accumulated.
            """
            try:
                event = Event.parse_obj(json.loads(body.decode('utf-8')))
                self.callback(event)
                logger.debug("Event processed successful")
            except Exception as e:
                logger.exception(f"Failed to process Event; message {e}")
            finally:
                self._channel.basic_ack(basic_deliver.delivery_tag)

        def stop_consuming(self):
            """Send Basic.Cancel for every consumer; on_cancelok runs per confirmation."""
            if self._channel:
                logger.debug('Sending a Basic.Cancel RPC command to RabbitMQ')

                for tag in list(self._consumer_tags):
                    cb = functools.partial(self.on_cancelok, userdata=tag)
                    self._channel.basic_cancel(tag, cb)

        def on_cancelok(self, _unused_frame, userdata):
            """Forget the cancelled tag; once none remain, close the channel."""
            # Without removing the tag the list never empties and the channel
            # (and therefore the connection and ioloop) is never closed.
            if userdata in self._consumer_tags:
                self._consumer_tags.remove(userdata)

            if not self._consumer_tags:
                self._consuming = False
                logger.info('RabbitMQ acknowledged the cancellation of all consumers')
                self.close_channel()

        def close_channel(self):
            """Close the channel cleanly by issuing the Channel.Close RPC command."""
            if self._channel is not None:
                logger.info('Closing the channel')
                self._channel.close()

        def run(self):
            """Thread body: connect and run the ioloop, reconnecting after 1 s when needed."""
            while not self._stop_requested.is_set():
                try:
                    self._connection = self.connect()
                    self._connection.ioloop.start()
                except (KeyboardInterrupt, SystemExit):
                    self.stop()
                    break
                if not self.should_reconnect or self._stop_requested.is_set():
                    break

                # Reset per-connection state before the next attempt; otherwise a
                # single reconnect would keep should_reconnect True forever and
                # stale consumer tags would block on_cancelok during shutdown.
                self.should_reconnect = False
                self._consuming = False
                self._closing = False
                self._consumer_tags = []
                time.sleep(1)

        def stop(self):
            """Request a clean, permanent shutdown; safe to call from any thread.

            pika connections are not thread-safe. When called from another
            thread, the shutdown is therefore scheduled on this consumer's ioloop
            via ``add_callback_threadsafe``. The shutdown cancels the consumers,
            then closes channel and connection, which stops the ioloop so that
            :meth:`run` returns; ``join()`` the thread to wait for it. A request
            made before a connection exists is honoured by run() and
            on_connection_open.
            """
            self._stop_requested.set()
            if threading.current_thread() is self:
                self._begin_shutdown()
                return
            connection = self._connection
            if connection is not None and self.is_alive():
                connection.ioloop.add_callback_threadsafe(self._begin_shutdown)

        def _begin_shutdown(self):
            """Start tearing down the current connection; must run on the ioloop thread."""
            if self._closing:
                return
            self._closing = True

            logger.debug('Stopping')

            connection = self._connection
            if connection is None:
                pass
            elif self._consuming and self._channel is not None and self._channel.is_open:
                # Cancel consumers; on_cancelok -> close_channel ->
                # on_channel_closed -> close_connection -> ioloop.stop().
                self.stop_consuming()
            elif connection.is_open:
                # Not consuming (yet): close the connection so it is not leaked;
                # on_connection_closed stops the ioloop because _closing is set.
                self.close_connection()
            else:
                connection.ioloop.stop()

            logger.info('PowerConsumer stopped')
    
    # Latest Event received per MAC address; written by event_callback.
    state: Dict[str, Event] = dict()

    # Optional user hook called by event_callback; set by run_restart_consumer.
    custom_callback: Optional[Callable] = None

    # One consumer thread per host type, managed by run_restart_consumer.
    pi_consumer: Optional[PowerConsumer] = None
    controller_consumer: Optional[PowerConsumer] = None
    switch_consumer: Optional[PowerConsumer] = None
    
    
    def event_callback(event: Event):
        """Handle one Event delivered by any PowerConsumer.

        The event is logged and passed to ``custom_callback`` when one is set.
        It is then stored in ``state`` as the latest event for its MAC. If the
        hook raises, the exception propagates and ``state`` is left unchanged.
        """
        logger.debug(f"Received Power event: {event}")

        hook = custom_callback
        if hook:
            hook(event)

        state[event.mac] = event
    
    
    
    def run_restart_consumer(macs: List[str], _type: Type, callback: Callable = None):
        global custom_callback, pi_consumer, controller_consumer, switch_consumer
        custom_callback = callback
    
        def restart_consumer(consumer: PowerConsumer):
            if not consumer:
                consumer = PowerConsumer(queues=macs, callback=event_callback, type=_type)
                consumer.start()
                logger.info(f'Restart power consumer for type {_type.value}')
                return consumer
            elif collections.Counter(consumer.queues) != collections.Counter(macs):
                consumer.stop()
                consumer.join()
                consumer = PowerConsumer(queues=macs, callback=event_callback, type=_type)
                consumer.start()
                logger.info(f'Restart power consumer for type {_type.value}')
                return consumer
            else:
                return consumer
    
        match _type:
            case Type.pi:
                pi_consumer = restart_consumer(pi_consumer)
            case Type.controller:
                controller_consumer = restart_consumer(controller_consumer)
            case Type.switch:
                switch_consumer = restart_consumer(switch_consumer)
            case _:
                raise ValueError("Type is not known")
    
    
    
    def get_latest_event(mac: str):
        """Return the most recent Event recorded for ``mac``, or None if none was seen."""
        try:
            return state[mac]
        except KeyError:
            return None
    
    def get_latest_state(mac: str):
        """Return the ``State`` of the latest Event recorded for ``mac``, or None."""
        latest = state.get(mac)
        if not latest:
            return None
        return latest.state