File size: 1,778 Bytes
284cb2b bfedca6 284cb2b bfedca6 284cb2b bfedca6 284cb2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import argparse
from typing import Optional, Sequence

import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Extra

from model import ChallengePromptGenerator
class Prompt(BaseModel, extra="allow"):
    """Request body for the prompt-completion endpoint.

    Unknown fields are accepted and kept on the model (``extra="allow"``)
    instead of being rejected. The string literal is used rather than
    ``Extra.allow``, which is deprecated in pydantic v2; both spellings
    configure the same behavior.
    """

    # Text to be completed by the generator.
    prompt: str
    # Optional seed supplied by the client; defaults to 0.
    seed: Optional[int] = 0
    # Client-requested generation length; defaults to 77.
    max_length: Optional[int] = 77
def get_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the challenge-prompt server.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit
            list lets callers and tests parse arguments without touching
            the process command line.

    Returns:
        The parsed ``argparse.Namespace`` with attributes ``port``,
        ``netuid``, ``min_stake``, ``chain_endpoint`` and ``disable_secure``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=10001)
    # int, matching the default, so a value given on the command line has
    # the same type as the default one.
    parser.add_argument("--netuid", type=int, default=23)
    parser.add_argument("--min_stake", type=int, default=100)
    parser.add_argument(
        "--chain_endpoint",
        type=str,
        default="finney",
    )
    parser.add_argument("--disable_secure", action="store_true", default=False)
    return parser.parse_args(argv)
class ChallengeImage:
    """FastAPI application that completes image-generation prompts.

    ``POST /`` returns a completed prompt produced by
    ``ChallengePromptGenerator``; ``GET /`` serves ``index.html`` resolved
    against the current working directory.
    """

    # Default and upper bound for the generation length passed to the model.
    # It equals the ``max_length`` default on the request model, and the cap
    # keeps the worst case no longer than the previously hard-coded value.
    MAX_GENERATION_LENGTH = 77
    # Seed text used when the client sends an empty prompt.
    DEFAULT_PROMPT = "an image of "

    def __init__(self):
        self.challenge_prompt = ChallengePromptGenerator()
        self.app = FastAPI(title="Challenge Prompt")
        # The same path serves both methods: POST -> completion, GET -> page.
        self.app.add_api_route("/", self.__call__, methods=["POST"])
        self.app.add_api_route("/", self.serve_index, methods=["GET"])

    @classmethod
    def _resolve_max_length(cls, requested: Optional[int]) -> int:
        """Return the generation length to use for a requested ``max_length``.

        ``None`` or non-positive values fall back to ``MAX_GENERATION_LENGTH``.
        Larger values are capped at it so a client cannot trigger longer,
        more expensive generations than the server allows.
        """
        if requested is None or requested <= 0:
            return cls.MAX_GENERATION_LENGTH
        return min(requested, cls.MAX_GENERATION_LENGTH)

    async def __call__(
        self,
        data: Prompt,
    ):
        """Complete ``data.prompt`` and return the stripped completion.

        An empty prompt is replaced by ``DEFAULT_PROMPT``. The request's
        ``max_length`` is honored, bounded as described in
        ``_resolve_max_length``.

        NOTE(review): ``data.seed`` is accepted by the request model but is
        not forwarded to the generator here; confirm whether ``infer_prompt``
        supports seeding.
        NOTE(review): ``infer_prompt`` is called synchronously inside an async
        handler; if it is a long blocking call it will stall the event loop
        for its duration — confirm.
        """
        prompt = data.prompt or self.DEFAULT_PROMPT
        # Presumably infer_prompt returns one completion per input prompt;
        # only a single prompt is sent, so take the first result.
        complete_prompt = self.challenge_prompt.infer_prompt(
            [prompt],
            max_generation_length=self._resolve_max_length(data.max_length),
            sampling_topk=100,
        )[0].strip()
        return complete_prompt

    async def serve_index(self):
        """Return ``index.html`` from the working directory as an HTML response."""
        # Explicit encoding so decoding does not depend on the platform's
        # locale default (e.g. cp1252 on Windows).
        with open("index.html", "r", encoding="utf-8") as file:
            return HTMLResponse(content=file.read(), status_code=200)
if __name__ == "__main__":
    # Script entry point: read CLI options, echo them, then serve the
    # FastAPI app with uvicorn on all interfaces at the configured port.
    cli_args = get_args()
    print("Args: ", cli_args)
    challenge_server = ChallengeImage()
    uvicorn.run(challenge_server.app, host="0.0.0.0", port=cli_args.port)
|