Skip to content
Snippets Groups Projects
Pi.py 4.96 KiB
Newer Older
  • Learn to ignore specific revisions
  • import logging as log
    from typing import List
    from sqlalchemy.orm import Session
    from pilab.crud.Util import from_json, update_attrs
    
    from pilab.crud.Host import Host
    from pilab.crud.Pi import Pi
    from pilab.schemas.Host import HostType
    from pilab.schemas.Pi import PiType
    from pilab.schemas.Host import HostType
    from pilab.schemas.Cube import CubesType
    from pilab.events import meta
    
    # Raise the 'sqlalchemy.engine' logger's threshold to WARNING so that its
    # INFO-level SQL echo does not flood the application logs.
    log.getLogger(name='sqlalchemy.engine').setLevel(level=log.WARNING)
    
    class Pi(object):
        """CRUD helpers for Pi records.

        A Pi row (``PiType``) is identified by its ``host_id`` and is paired
        with a host row managed through ``Host``; the read/write helpers merge
        both into a single ``meta.Pi`` model. Every method works on the
        caller's ``Session`` and only calls ``flush()`` -- committing the
        transaction is left to the caller.
        """

        # ------------------------------------------------------------------
        # private helpers
        # ------------------------------------------------------------------
        @staticmethod
        def _find(db: Session, pi_id: int):
            """Return the ``PiType`` row whose ``host_id`` is *pi_id*, or ``None``."""
            return db.query(PiType).filter(PiType.host_id == pi_id).first()

        @staticmethod
        def _require(db: Session, pi_id: int):
            """Return the ``PiType`` row for *pi_id*.

            Raises:
                ValueError: if no Pi with that id exists.
            """
            db_pi = Pi._find(db, pi_id)
            if db_pi is None:
                raise ValueError(f"Invalid Pi ID: {pi_id}")
            return db_pi

        @staticmethod
        def _to_meta(db_pi, host):
            """Merge a ``PiType`` row's attributes and its host model into a ``meta.Pi``."""
            return meta.Pi(**vars(db_pi), **host.dict())

        @staticmethod
        def _ids_in_first_cube(db: Session, cube_filter):
            """Return ``(host_id,)`` rows of the Pis in the first cube matching *cube_filter*.

            Returns an empty list when no cube matches.
            NOTE(review): only one cube is considered even if several match
            (and without an ORDER BY, which one is unspecified) -- confirm
            that a switch/controller can only ever own one cube.
            """
            db_cube = db.query(CubesType).filter(cube_filter).first()
            if db_cube is None:
                return []
            return db.query(PiType.host_id).filter(
                PiType.cube_id == db_cube.id).all()

        # ------------------------------------------------------------------
        # create / update
        # ------------------------------------------------------------------
        @staticmethod
        def create(db: Session, pi: meta.Pi):
            """Create a host and its Pi row from *pi* and return the merged model.

            The host is created first so that its ``id`` can be stored as the
            new Pi's ``host_id``.
            """
            host = Host.create(db, pi)
            db_pi = from_json(PiType, pi.dict(), host_id=host.id)
            db.add(db_pi)
            db.flush()
            return Pi._to_meta(db_pi, host)

        @staticmethod
        def update(db: Session, pi: meta.PiUpdate, pi_id: int):
            """Apply *pi* to the host and the Pi row of *pi_id*; return the merged model.

            Raises:
                ValueError: if *pi_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, pi_id)
            host = Host.update(db, pi, db_pi.host_id)
            update_attrs(pi, db_pi)
            db.flush()
            return Pi._to_meta(db_pi, host)

        @staticmethod
        def update_image(db: Session, pi_id: int, image_id: int, user_data_id: int):
            """Set both ``image_id`` and ``user_data_id`` of Pi *pi_id*.

            Raises:
                ValueError: if *pi_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, pi_id)
            db_pi.image_id = image_id
            db_pi.user_data_id = user_data_id
            db.flush()

        @staticmethod
        def update_playbook(db: Session, pi_id: int, playbook_id: int):
            """Set ``playbook_id`` of Pi *pi_id*.

            Raises:
                ValueError: if *pi_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, pi_id)
            db_pi.playbook_id = playbook_id
            db.flush()

        @staticmethod
        def assign_image(db: Session, pi_id: int, image_id: int):
            """Set ``image_id`` of Pi *pi_id*.

            Raises:
                ValueError: if *pi_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, pi_id)
            db_pi.image_id = image_id
            db.flush()

        @staticmethod
        def assign_host_key(db: Session, pi_id: int, key_id: int):
            """Set ``key_id`` of Pi *pi_id*.

            Raises:
                ValueError: if *pi_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, pi_id)
            db_pi.key_id = key_id
            db.flush()

        @staticmethod
        def assign_data(db: Session, pi_id: int, user_data_id: int):
            """Set ``user_data_id`` of Pi *pi_id*.

            Raises:
                ValueError: if *pi_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, pi_id)
            db_pi.user_data_id = user_data_id
            db.flush()

        # ------------------------------------------------------------------
        # read / delete
        # ------------------------------------------------------------------
        @staticmethod
        def get(db: Session, id: int = None, serial : bytes = None, mac: int = None):
            """Look up a single Pi by *id*, *serial* or *mac* and return the merged model.

            The first selector that is not ``None`` is used, in the order
            id -> serial -> mac; *mac* is resolved through the host table.

            Raises:
                ValueError: if no selector is given or no matching Pi exists.
            """
            if id is not None:
                db_pi = Pi._find(db, id)
            elif serial is not None:
                db_pi = db.query(PiType).filter(PiType.serial == serial).first()
            elif mac is not None:
                db_host = db.query(HostType).filter(HostType.mac == mac).first()
                db_pi = Pi._find(db, db_host.id) if db_host is not None else None
            else:
                raise ValueError("One of id, serial or mac must be given")
            if db_pi is None:
                raise ValueError(
                    f"No Pi found for id={id!r}, serial={serial!r}, mac={mac!r}")
            host = Host.get(db, db_pi.host_id)
            return Pi._to_meta(db_pi, host)

        @staticmethod
        def delete(db: Session, host_id: int):
            """Delete the Pi row for *host_id*, then delete its host.

            Raises:
                ValueError: if *host_id* does not identify a Pi.
            """
            db_pi = Pi._require(db, host_id)
            db.delete(db_pi)
            db.flush()
            # The row was selected by host_id, so reuse the argument instead of
            # reading an attribute off an instance already marked deleted.
            Host.delete(db, host_id)

        @staticmethod
        def is_valid_id(db: Session, host_id: int):
            """Return ``True`` if a Pi with this ``host_id`` exists."""
            return Pi._find(db, host_id) is not None

        @staticmethod
        def get_cube_id(db: Session, pi_id: int):
            """Return the id row of the cube that Pi *pi_id* belongs to.

            The result is what ``Query.first()`` yields for a single-column
            query (a one-element row, not a bare int), or ``None`` when the
            Pi's ``cube_id`` matches no cube.
            """
            cube_id = Pi.get(db, pi_id).cube_id
            return db.query(CubesType.id).filter(CubesType.id == cube_id).first()

        @staticmethod
        def getAll(db: Session, _ids: List[int] = None, cube_id: int = None, switch_id: int = None, controller_id: int = None, image_id: int= None):
            """Return ``meta.Pi`` models filtered by at most one criterion.

            Only the first truthy filter is applied, in the order *_ids*,
            *cube_id*, *switch_id*, *controller_id*, *image_id*; with none of
            them set, every Pi is returned. *switch_id* / *controller_id* are
            resolved to the first matching cube; if no cube matches, the result
            is empty.

            Raises:
                ValueError: if *_ids* contains ids that do not identify a Pi.
            """
            if _ids:
                rows = db.query(PiType.host_id).filter(
                    PiType.host_id.in_(_ids)).all()
                # Compare as sets so duplicate ids in _ids are not mistaken
                # for invalid ones.
                found = {host_id for (host_id,) in rows}
                invalid_ids = set(_ids) - found
                if invalid_ids:
                    raise ValueError(f"Invalid ID(s) in the list: {invalid_ids}")
            elif cube_id:
                rows = db.query(PiType.host_id).filter(
                    PiType.cube_id == cube_id).all()
            elif switch_id:
                rows = Pi._ids_in_first_cube(db, CubesType.switch_id == switch_id)
            elif controller_id:
                rows = Pi._ids_in_first_cube(
                    db, CubesType.controller_id == controller_id)
            elif image_id:
                rows = db.query(PiType.host_id).filter(
                    PiType.image_id == image_id).all()
            else:
                rows = db.query(PiType.host_id).all()

            return [Pi.get(db, host_id) for (host_id,) in rows]

        @staticmethod
        def get_ids(db: Session):
            """Return the ``host_id`` of every Pi as a plain list of ints."""
            return [host_id for (host_id,) in db.query(PiType.host_id).all()]