Skip to content
Snippets Groups Projects
auth.py 3.51 KiB
Newer Older
  • Learn to ignore specific revisions
  • import logging
    
    from collections import Counter
    from typing import Optional, List, Dict, Callable
    
    from fastapi import Header, HTTPException
    
    from pilab.events import reservation
    from datetime import datetime
    
    
    logger = logging.getLogger(__name__)

    # Group names that grant admin rights; consulted by is_admin().
    ADMIN_GROUPS = ["/admin"]


    # Most recent reservation received per cube, keyed by res.cube_id
    # (written by the event callback set up in restart_reservation_consumer).
    cube_reservations: Dict[int, reservation.Reservation] = {}
    # Consumer created by restart_reservation_consumer(); None until one is started.
    res_thread: Optional[reservation.ReservationConsumer] = None
    # Optional hook invoked for every received reservation; None disables it.
    custom_callback: Optional[Callable] = None
    
    
    
    def is_admin(x_forwarded_groups: Optional[str], *,
                 admin_groups: Optional[List[str]] = None) -> bool:
        """
        Return True when the X-Forwarded-Groups header contains an admin group.

        The header is split on commas. Each entry is whitespace-stripped and compared
        for exact equality, so groups that merely contain an admin group name as a
        substring (e.g. "/administrators" or "/team/admin" versus "/admin") do not
        grant admin rights.

        NOTE(review): assumes oauth2-proxy joins multiple groups with "," — confirm
        against the deployed proxy configuration.

        :param x_forwarded_groups: raw header value; None or "" means "no groups".
        :param admin_groups: group names that count as admin; defaults to ADMIN_GROUPS.
        :return: True if at least one admin group is present, otherwise False.
        """
        if not x_forwarded_groups:
            return False
        candidates = ADMIN_GROUPS if admin_groups is None else admin_groups
        member_of = {group.strip() for group in x_forwarded_groups.split(",")}
        return any(group in member_of for group in candidates)
    
    
    def get_username(usernames: List[str]):
        """
        Pick the first usable username from an ordered list of candidates.

        Entries that are None or otherwise falsy (such as an empty string) are skipped.
        Returns None when no candidate qualifies.
        """
        return next((candidate for candidate in usernames if candidate), None)
    
    
    
    def get_active_users(cube_id: int):
        """
        List the users allowed on a cube under its current reservation.

        Looks up the stored reservation for cube_id. If its endtime is still after
        datetime.utcnow(), returns its extraUsers followed by its owner. Returns an
        empty list when nothing is stored for the cube or the reservation has ended.

        NOTE(review): compares against a naive utcnow(); this presumes res.endtime is
        a naive UTC datetime — confirm.
        """
        current: reservation.Reservation = cube_reservations.get(cube_id)
        if not current:
            return []
        if current.endtime > datetime.utcnow():
            return [*current.extraUsers, current.owner]
        return []
    
    
    def get_active_reservation(cube_id: int):
        """
        Return the stored reservation for cube_id if it has not ended yet, else None.

        A reservation counts as active while its endtime is later than
        datetime.utcnow(). NOTE(review): presumes endtime is a naive UTC datetime.
        """
        candidate: reservation.Reservation = cube_reservations.get(cube_id)
        still_running = bool(candidate) and candidate.endtime > datetime.utcnow()
        return candidate if still_running else None
    
    
    
    async def get_user(x_forwarded_user: Optional[str] = Header(None),
                       x_forwarded_preferred_username: Optional[str] = Header(None),
                       x_forwarded_groups: Optional[str] = Header(None)):
        """
        Resolve the caller's identity from the headers oauth2-proxy provides.

        The preferred-username header wins over X-Forwarded-User when both are set.
        Admin status is derived from X-Forwarded-Groups via is_admin().

        :return: (username, admin) where username may be None if neither header is set.
        """
        # Debug dump: one entry per header, "" for headers that are missing/empty.
        header_dump = [
            prefix + value if value else ""
            for prefix, value in (
                ("X-Forwarded-Preferred-Username: ", x_forwarded_preferred_username),
                ("X-Forwarded-User: ", x_forwarded_user),
                ("X-Forwarded-Groups: ", x_forwarded_groups),
            )
        ]
        logger.debug(header_dump)

        admin_flag = is_admin(x_forwarded_groups)
        return get_username([x_forwarded_preferred_username, x_forwarded_user]), admin_flag
    
    
    async def verify_user(cube_id: int, x_forwarded_preferred_username: Optional[str] = Header(None),
                          x_forwarded_user: Optional[str] = Header(None),
                          x_forwarded_groups: Optional[str] = Header(None)):
        """
        Request guard: allow only admins or current users of the given cube.

        The identity is taken from the oauth2-proxy headers. The caller passes when
        they are in an admin group or are among the active users of cube_id's
        reservation. The decision itself is delegated to verify_user_pi().

        :return: True when access is granted.
        :raises HTTPException: status 401 when access is denied.
        """
        caller_is_admin = is_admin(x_forwarded_groups)
        caller = get_username([x_forwarded_preferred_username, x_forwarded_user])
        if verify_user_pi(cube_id, caller, caller_is_admin):
            return True
        raise HTTPException(status_code=401, detail="Unauthorized")
    
    
    
    def verify_user_pi(cube_id: int, username: Optional[str], admin: bool) -> bool:
        """
        Boolean authorization check for cube_id, given an already-resolved identity.

        Applies the same rule as verify_user: admins always pass; otherwise username
        must be among the cube's active users. Unlike verify_user, it returns False
        instead of raising, and it always returns a bool (never an implicit None).

        :param cube_id: cube whose current reservation is consulted.
        :param username: resolved username, or None if unknown.
        :param admin: whether the caller has admin rights.
        :return: True if access is granted, otherwise False.
        """
        if admin:
            return True
        # Short-circuits on a missing username so no reservation lookup is needed.
        return username is not None and username in get_active_users(cube_id)
    
    
    
    
    def restart_reservation_consumer(ids: List[int], callback: Callable = None):
        """
        Ensure a ReservationConsumer is running for exactly the given cube ids.

        Always stores `callback` in the module-level custom_callback, even when the
        consumer is kept. The global is read on every event, so the new hook applies
        to an already-running consumer as well.

        A new consumer is created and started when none exists yet. It is also
        recreated when the multiset of `ids` differs from the current consumer's
        `cubes`; in that case the old consumer is stopped and joined first.
        Otherwise the existing consumer is kept.

        Each received reservation is logged, passed to custom_callback (if set), and
        then stored in cube_reservations under res.cube_id.
        NOTE(review): if custom_callback raises, the reservation is not stored.

        :return: the consumer that is now current (also stored in res_thread).
        """
        global res_thread, custom_callback

        custom_callback = callback

        def handle_reservation_event(res: reservation.Reservation):
            logger.info(f'Received Reservation for cube {res.cube_id}; {res}')
            if custom_callback:
                custom_callback(res)
            cube_reservations[res.cube_id] = res

        if res_thread:
            # Same cubes (counting duplicates): keep the running consumer.
            if Counter(ids) == Counter(res_thread.cubes):
                return res_thread
            res_thread.stop()
            res_thread.join()

        res_thread = reservation.ReservationConsumer(cubes=ids, callback=handle_reservation_event)
        res_thread.start()
        return res_thread