"""FastAPI service that exposes ``ChallengePromptGenerator.infer_prompt`` over HTTP."""

import argparse
from typing import Optional

import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Extra

from model import ChallengePromptGenerator

# Prefix substituted when a request carries an empty prompt.
_DEFAULT_PROMPT = "an image of "


class Prompt(BaseModel, extra=Extra.allow):
    """Request body for the prompt-completion endpoint.

    Fields not declared here are accepted (``extra=Extra.allow``).
    """

    # Text to be completed; an empty string is replaced by a default prefix.
    prompt: str
    # NOTE(review): ``seed`` and ``max_length`` are accepted but not read by
    # the handler in this file (generation length is fixed at 77 there).
    seed: Optional[int] = 0
    max_length: Optional[int] = 77


def get_args() -> argparse.Namespace:
    """Parse the command-line options for the service.

    Returns:
        The parsed ``argparse.Namespace`` with ``port``, ``netuid``,
        ``min_stake``, ``chain_endpoint`` and ``disable_secure``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=10001)
    # The default is an int; parse CLI values as int as well so that
    # ``args.netuid`` has one consistent type (it was ``type=str`` before,
    # which yielded int 23 by default but a str when passed explicitly).
    parser.add_argument("--netuid", type=int, default=23)
    parser.add_argument("--min_stake", type=int, default=100)
    parser.add_argument(
        "--chain_endpoint",
        type=str,
        default="finney",
    )
    parser.add_argument("--disable_secure", action="store_true", default=False)
    return parser.parse_args()


class ChallengeImage:
    """Wraps a ``ChallengePromptGenerator`` in a FastAPI application.

    ``POST /`` completes a prompt; ``GET /`` serves ``index.html``.
    """

    def __init__(self):
        """Create the generator and register both routes on ``self.app``."""
        self.challenge_prompt = ChallengePromptGenerator()
        self.app = FastAPI(title="Challenge Prompt")
        self.app.add_api_route("/", self.__call__, methods=["POST"])
        self.app.add_api_route("/", self.serve_index, methods=["GET"])

    async def __call__(
        self,
        data: Prompt,
    ):
        """Complete the prompt in ``data`` and return the stripped result.

        Args:
            data: Parsed request body. Only ``data.prompt`` is used; an empty
                prompt falls back to ``"an image of "``.

        Returns:
            The first completion returned by ``infer_prompt``, with
            surrounding whitespace stripped.
        """
        prompt = data.prompt or _DEFAULT_PROMPT
        # NOTE(review): this call is synchronous inside an ``async`` handler,
        # so it runs on the event loop — confirm that is acceptable.
        completions = self.challenge_prompt.infer_prompt(
            [prompt], max_generation_length=77, sampling_topk=100
        )
        return completions[0].strip()

    async def serve_index(self):
        """Return the contents of ``index.html`` as an HTML response.

        The path is relative to the process's current working directory.
        """
        # Explicit encoding so the page decodes the same on every platform.
        with open("index.html", "r", encoding="utf-8") as file:
            return HTMLResponse(content=file.read(), status_code=200)


if __name__ == "__main__":
    args = get_args()
    print("Args: ", args)
    app = ChallengeImage()
    uvicorn.run(app.app, host="0.0.0.0", port=args.port)