Skip to content
Snippets Groups Projects
transformer.py 4.58 KiB
Newer Older
  • Learn to ignore specific revisions
  • import datetime
    import ipaddress
    
    import httpx
    from apscheduler.schedulers.asyncio import AsyncIOScheduler
    import argparse
    import asyncio
    from contextlib import asynccontextmanager
    from fastapi import APIRouter, FastAPI
    import jsonpickle
    import uvicorn
    from apscheduler.jobstores.memory import MemoryJobStore
    from util import get_pub_ip
    import logging
    
    from connection import Connection
    
    from remotetransformer import RemoteTransformer, SignedRemoteTransformer
    
    from remoteline import RemoteLine
    
    logger = logging.getLogger(__name__)
    
    
    class Transformer(Connection):
        """A transformer node that exposes its state and a trust-chain signing endpoint via FastAPI.

        Attributes read but not assigned in this class (``publicIP``, ``publicKey``,
        ``availableCapacity``, ``usedCapacity``, ``loss``, ``remoteIPs``, ``dil``,
        ``fastRouter``, ``scheduler``) are presumably initialised by ``Connection``
        — confirm against connection.py.
        """

        def __init__(self, cap: float, conips: set[ipaddress.IPv4Address], sched: AsyncIOScheduler):
            """Register the HTTP routes and schedule a one-off neighbour discovery.

            :param cap: capacity, forwarded to ``Connection``.
            :param conips: addresses of adjacent nodes, forwarded to ``Connection``
                (presumably stored there as ``remoteIPs``).
            :param sched: scheduler, forwarded to ``Connection``; ``self.scheduler`` is
                then used to run ``retrieveConnections`` shortly after startup.
            """
            super().__init__(cap, conips, sched)
            # Lines discovered by retrieveConnections().
            self.adjacentLines: set[RemoteLine] = set()

            # Most recent signed link per announcement time; each new link chains onto the previous one.
            self.lastSignedTransformer: dict[datetime.datetime, SignedRemoteTransformer] = dict()

            # adding fastapi endpoints
            self.fastRouter.add_api_route("/asRemoteJSON", self.asRemoteTransformerJSON, methods=["GET"])
            self.fastRouter.add_api_route("/sign/{time}", self.sign, methods=["POST"])

            # setting up scheduling: run retrieveConnections once, ~5 s from now.
            # APScheduler interprets a naive datetime in the scheduler's own timezone, so a
            # fixed offset on a naive now() is wrong whenever the host's local time differs
            # from that zone by any other amount (e.g. across daylight-saving changes).
            # A timezone-aware timestamp is converted correctly by APScheduler itself.
            run_date = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=5)
            self.scheduler.add_job(self.retrieveConnections, 'date', run_date=run_date, id='1')

        def _asRemoteTransformer(self) -> RemoteTransformer:
            """Build a RemoteTransformer describing this node's current public state.

            ``adjacentLines`` is passed by reference, not copied.
            """
            return RemoteTransformer(self.publicIP, self.publicKey, self.availableCapacity,
                                     self.usedCapacity, self.adjacentLines, self.loss)

        def _signedLink(self, rt, previous, isOrigin: bool = False) -> SignedRemoteTransformer:
            """Wrap ``rt`` in a SignedRemoteTransformer chained to ``previous`` and sign it.

            :param rt: the RemoteTransformer to sign.
            :param previous: the preceding link in the trust chain, or ``None`` for an origin.
            :param isOrigin: when true, ``isOrigin`` is set on the link *before* signing,
                so it is part of the signed ``str(link)`` (if that string includes it).
            :returns: the signed link.
            """
            link = SignedRemoteTransformer(rt, previous)
            if isOrigin:
                link.isOrigin = True
            # NOTE(review): writing `self.__secretKey` in this class would be name-mangled to
            # `self._Transformer__secretKey`, which Transformer never assigns (AttributeError).
            # Assumes Connection stores the key as `self.__secretKey`, i.e. as
            # `_Connection__secretKey` — confirm in connection.py.
            link.signature = self.dil.sign_with_input(self._Connection__secretKey, str(link))
            return link

        async def asRemoteTransformerJSON(self):
            """GET /asRemoteJSON: return this node as a jsonpickle-encoded RemoteTransformer."""
            return jsonpickle.encode(self._asRemoteTransformer())

        async def retrieveConnections(self):
            """Fetch every neighbour's remote representation and add it to ``adjacentLines``.

            Queries ``GET http://<ip>:8000/asRemoteJSON`` for each address in
            ``self.remoteIPs``. A neighbour that is unreachable, answers with an HTTP
            error status, or returns non-JSON is logged and skipped, so it does not
            prevent the remaining neighbours from being recorded.
            """
            result: set[RemoteLine] = set()
            # One client for all requests so connections can be pooled.
            async with httpx.AsyncClient() as client:
                for ip in self.remoteIPs:
                    try:
                        response = await client.get(f"http://{ip}:8000/asRemoteJSON")
                        response.raise_for_status()
                        # Assumes neighbours return a jsonpickle string wrapped in a JSON string,
                        # as this class's own /asRemoteJSON endpoint does — confirm for line nodes.
                        # SECURITY NOTE(review): jsonpickle.decode can instantiate arbitrary Python
                        # objects; this is only safe if every neighbour is trusted.
                        result.add(jsonpickle.decode(response.json()))
                    except (httpx.HTTPError, ValueError):
                        logger.warning("Failed to retrieve connection from %s", ip, exc_info=True)

            self.adjacentLines.update(result)
            logger.info("===> Transformer: %s retrieved connections: %d", self.publicIP, len(result))

        async def sign(self, time: datetime.datetime, rtjson):
            """POST /sign/{time}: append ``rtjson`` to this node's trust chain for ``time`` and sign it.

            ``rtjson`` is a jsonpickle-encoded RemoteTransformer. It is only signed if its
            ``publicKey`` equals this node's. The first request for a given ``time`` starts
            the chain with a signed origin link describing this node. Later requests chain
            onto the last signed link stored for that ``time``.

            :returns: the jsonpickle-encoded signed link, or the string ``"Unauthorized"``
                if the payload does not carry this node's public key.
            """
            # SECURITY NOTE(review): jsonpickle.decode can instantiate arbitrary Python objects
            # from request data; this endpoint is unsafe if reachable by untrusted clients.
            rt = jsonpickle.decode(rtjson)
            # check if the rt actually is me; getattr rejects payloads lacking publicKey
            # instead of failing with an AttributeError (HTTP 500).
            if getattr(rt, "publicKey", None) != self.publicKey:
                return "Unauthorized"  # better handling here would be nice

            previous = self.lastSignedTransformer.get(time)
            if previous is None:
                # has there been no route announced before this one? then create an origin node for the trust chain
                previous = self._signedLink(self._asRemoteTransformer(), None, isOrigin=True)
            result = self._signedLink(rt, previous)
            self.lastSignedTransformer[time] = result
            return jsonpickle.encode(result)
    
    
    
    def main():
        """Parse CLI arguments and serve a Transformer on port 8000 of this host's public IP.

        ``--cap`` is the available capacity; ``--cons`` lists the IPv4 addresses of
        adjacent connections. The APScheduler instance runs only for the app's lifetime.
        """
        parser = argparse.ArgumentParser(description='Connection service')
        parser.add_argument('--cap', type=float, required=True, help='Available capacity')
        # type=IPv4Address makes argparse report a malformed address as a usage error
        # rather than crashing later with a traceback.
        parser.add_argument('--cons', type=ipaddress.IPv4Address, nargs='+', required=True,
                            help='List of IP addresses of adjacent connections')
        args = parser.parse_args()

        argcons = set(args.cons)

        jobstores = {
            'default': MemoryJobStore()
        }
        scheduler = AsyncIOScheduler(jobstores=jobstores, timezone='Europe/Berlin')
        logging.basicConfig(
            level=logging.INFO,
            handlers=[
                logging.StreamHandler()
            ]
        )

        logger.info(argcons)

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            """Start the scheduler with the app and stop it when the app shuts down."""
            scheduler.start()
            try:
                yield
            finally:
                # finally: also stop the scheduler if the app exits with an exception.
                scheduler.shutdown()

        fast_app = FastAPI(lifespan=lifespan)
        tra = Transformer(args.cap, argcons, scheduler)
        fast_app.include_router(tra.fastRouter)
        uvicorn.run(fast_app, host=get_pub_ip(), port=8000)


    if __name__ == "__main__":
        main()