Skip to content
Snippets Groups Projects
Image.py 7.19 KiB
Newer Older
  • Learn to ignore specific revisions
  • import logging as log
    from typing import List
    from datetime import datetime, timezone
    from zoneinfo import ZoneInfo
    from pydantic import BaseModel
    from sqlalchemy import inspect
    from sqlalchemy.orm import Session, class_mapper
    from pilab.crud.Util import from_json, update_attrs
    
    
    
    from pilab.schemas.Image import ImageType, ScriptType, UserDataType
    from pilab.events import meta
    
    # Keep SQLAlchemy's engine logger quiet: only WARNING and above get through.
    log.getLogger(name='sqlalchemy.engine').setLevel(level=log.WARNING)
    
    
    class Image(object):
        """CRUD helpers for ``ImageType`` rows.

        Every method is static and works on the caller's SQLAlchemy session.
        Writes are flushed with ``db.flush()`` but never committed here.
        """

        @staticmethod
        def _by_id(db: Session, image_id: int):
            """Return the ``ImageType`` row whose primary key is ``image_id``, or None."""
            return db.query(ImageType).filter(ImageType.id == image_id).first()

        @staticmethod
        def create(db: Session, **kwargs):
            """Insert an image built from ``kwargs`` via ``from_json`` and return its id.

            ``change`` is stamped with the current UTC time. The session is flushed
            (not committed) so that the generated primary key is available.
            """
            db_image = from_json(ImageType, kwargs, change=datetime.now(timezone.utc))
            db.add(db_image)
            db.flush()
            return db_image.id

        @staticmethod
        def delete(db: Session, image_id: int):
            """Delete the image with ``image_id`` if it exists; unknown ids are ignored."""
            db_image = Image._by_id(db, image_id)
            if db_image:
                db.delete(db_image)
            db.flush()

        @staticmethod
        def get(db: Session, _id: int = None, name: str = None, version: str = None, username: str = None):
            """Return images from the database as ``meta.Image`` objects.

            Lookup modes, checked in this order:
              * ``_id`` (truthy): a single image, or None.
              * ``name`` without ``version``: a list of all images with that name.
              * ``name`` and ``version``: a single image, or None.
              * otherwise: a list of all images (``version`` alone is ignored).

            When ``username`` is given, only public images and images owned by that
            user are returned. A single-image lookup that is not visible yields None.
            When ``username`` is None, no access filtering is applied.

            The returned ``change`` timestamp is always UTC-aware. The ORM instances
            themselves are not modified.
            """

            def visible(image) -> bool:
                return username is None or image.public or image.owner == username

            def to_meta(image):
                # Attribute access loads ``change`` if it is expired or deferred.
                # The original code relied on this before calling vars().
                change = image.change
                # Build the payload from a copy instead of assigning to image.change.
                # Assigning would mark the ORM object dirty, and the next flush
                # would write the adjusted timestamp back to the database.
                fields = dict(vars(image))
                if change is not None:
                    if change.tzinfo is None:
                        # Naive values are taken to be UTC. create() and the update_*
                        # methods store datetime.now(timezone.utc).
                        change = change.replace(tzinfo=ZoneInfo('UTC'))
                    else:
                        # Aware values must be converted, not relabelled, or the
                        # represented instant would shift.
                        change = change.astimezone(ZoneInfo('UTC'))
                    fields['change'] = change
                return meta.Image(**fields)

            def process_images(data):
                if isinstance(data, list):
                    return [to_meta(image) for image in data if visible(image)]
                if data is None or not visible(data):
                    return None
                return to_meta(data)

            if _id:
                return process_images(Image._by_id(db, _id))
            elif name and not version:
                db_images = db.query(ImageType).filter(ImageType.name == name).all()
                return process_images(db_images)
            elif name and version:
                db_image = db.query(ImageType).filter(
                    ImageType.name == name,
                    ImageType.version == version
                ).first()
                return process_images(db_image)
            else:
                return process_images(db.query(ImageType).all())

        @staticmethod
        def update_size(db: Session, image_id: int, size: int):
            """Set ``size`` and refresh ``change`` for the image with ``image_id``.

            Raises AttributeError if no image has that id, because the lookup
            returns None.
            """
            db_image = Image._by_id(db, image_id)
            db_image.size = size
            db_image.change = datetime.now(timezone.utc)
            db.flush()

        @staticmethod
        def update_script(db: Session, image_id: int, script_id: int):
            """Point the image with ``image_id`` at ``script_id`` and refresh ``change``.

            Raises AttributeError if no image has that id, because the lookup
            returns None.
            """
            db_image = Image._by_id(db, image_id)
            db_image.script_id = script_id
            db_image.change = datetime.now(timezone.utc)
            db.flush()

        @staticmethod
        def get_id(db: Session, name: str, version: str):
            """Return the id of the image with ``name`` and ``version``, or None."""
            db_image = db.query(ImageType).filter(
                ImageType.name == name,
                ImageType.version == version
            ).first()
            return db_image.id if db_image else None

        @staticmethod
        def is_valid_id(db: Session, image_id: int):
            """Return True if an image with ``image_id`` exists."""
            return Image._by_id(db, image_id) is not None
    
    class UserData(object):
        """CRUD helpers for ``UserDataType`` rows, returned as ``meta.UserData``.

        Writes are flushed on the caller's session but never committed here.
        """

        @staticmethod
        def _by_id(db: Session, _id: int):
            """Return the ``UserDataType`` row whose primary key is ``_id``, or None."""
            return db.query(UserDataType).filter(UserDataType.id == _id).first()

        @staticmethod
        def create(db: Session, data: str, owner: str):
            """Insert a user-data row owned by ``owner`` and return it as ``meta.UserData``."""
            db_data = UserDataType(
                data=data,
                owner=owner
            )
            db.add(db_data)
            db.flush()  # populate the generated id before building the payload
            return meta.UserData(**vars(db_data))

        @staticmethod
        def update(db: Session, data: str, _id: int):
            """Replace the ``data`` of the row with ``_id``.

            Returns the updated ``meta.UserData``, or None when no row has that id.
            This matches how the ``get`` lookups report a missing row, and avoids
            the TypeError that ``vars(None)`` would raise.
            """
            db_data = UserData._by_id(db, _id)
            if db_data is not None:
                db_data.data = data
            db.flush()
            if db_data is None:
                return None
            return meta.UserData(**vars(db_data))

        @staticmethod
        def delete(db: Session, _id: int):
            """Delete the row with ``_id`` if it exists; unknown ids are ignored."""
            db_data = UserData._by_id(db, _id)
            if db_data:
                db.delete(db_data)
            db.flush()

        @staticmethod
        def get(db: Session, _id: int = None, owner: str = None):
            """Return user data as ``meta.UserData``.

            * ``_id`` (truthy): the single matching item, or None if it does not exist.
            * ``owner`` (truthy): a list of that owner's items.
            * otherwise: a list of all items.
            """
            if _id:
                db_data = UserData._by_id(db, _id)
                return meta.UserData(**vars(db_data)) if db_data is not None else None
            query = db.query(UserDataType)
            if owner:
                query = query.filter(UserDataType.owner == owner)
            return [meta.UserData(**vars(row)) for row in query.all()]

        @staticmethod
        def is_valid_id(db: Session, _id: int):
            """Return True if a user-data row with ``_id`` exists."""
            return UserData._by_id(db, _id) is not None
    
    
    class Script(object):
        """CRUD helpers for ``ScriptType`` rows, returned as ``meta.Script``.

        Writes are flushed on the caller's session but never committed here.
        """

        @staticmethod
        def _by_id(db: Session, script_id: int):
            """Return the ``ScriptType`` row whose primary key is ``script_id``, or None."""
            return db.query(ScriptType).filter(ScriptType.id == script_id).first()

        @staticmethod
        def create(db: Session, script: str, script_chroot: str, owner: str, name: str, read_only: bool):
            """Insert a new script and return it as ``meta.Script``."""
            db_script = ScriptType(
                owner=owner,
                name=name,
                read_only=read_only,
                script=script,
                script_chroot=script_chroot
            )
            db.add(db_script)
            db.flush()  # populate the generated id before building the payload
            return meta.Script(**vars(db_script))

        @staticmethod
        def update(db: Session, script_id: int, script: str = None, script_chroot: str = None, name: str = None):
            """Update selected fields of the script with ``script_id``.

            Only truthy arguments are applied; None or an empty string leaves that
            field unchanged. Returns the updated ``meta.Script``, or None when no
            script has that id, mirroring ``get``.
            """
            db_script = Script._by_id(db, script_id)
            if db_script is None:
                return None
            if script:
                db_script.script = script
            if script_chroot:
                db_script.script_chroot = script_chroot
            if name:
                db_script.name = name
            db.flush()
            return meta.Script(**vars(db_script))

        @staticmethod
        def delete(db: Session, script_id: int):
            """Delete the script with ``script_id`` if it exists; unknown ids are ignored."""
            db_script = Script._by_id(db, script_id)
            if db_script:
                db.delete(db_script)
            db.flush()

        @staticmethod
        def get(db: Session, script_id: int = None, name: str = None):
            """Return scripts as ``meta.Script``.

            * ``script_id`` (truthy): the matching script, or None.
            * ``name`` (truthy): the first script with that name, or None.
            * otherwise: a list of all scripts.
            """
            if script_id:
                db_script = Script._by_id(db, script_id)
                return meta.Script(**vars(db_script)) if db_script else None
            elif name:
                db_script = db.query(ScriptType).filter(ScriptType.name == name).first()
                return meta.Script(**vars(db_script)) if db_script else None
            else:
                db_scripts = db.query(ScriptType).all()
                return [meta.Script(**vars(p)) for p in db_scripts]

        @staticmethod
        def is_valid_id(db: Session, script_id: int):
            """Return True if a script with ``script_id`` exists."""
            return Script._by_id(db, script_id) is not None