File size: 1,778 Bytes
284cb2b
bfedca6
284cb2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfedca6
284cb2b
 
 
 
 
 
 
 
 
 
 
 
 
 
bfedca6
 
 
 
284cb2b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Extra
import argparse
from typing import Optional
import uvicorn
from model import ChallengePromptGenerator


class Prompt(BaseModel, extra=Extra.allow):
    """Request body for the POST "/" challenge-prompt endpoint.

    ``extra=Extra.allow`` makes the model accept fields beyond the ones
    declared here instead of rejecting them; the request handler in this
    file does not read any such extra fields.
    """

    # Seed text to be completed by the generator (required; may be empty).
    prompt: str
    # Declared but not read by the request handler in this file.
    # NOTE(review): confirm whether it was meant to be forwarded to the generator.
    seed: Optional[int] = 0
    # Requested maximum generation length; defaults to 77.
    max_length: Optional[int] = 77


def get_args(argv=None):
    """Parse command-line options for the challenge-prompt server.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before.

    Returns:
        argparse.Namespace with ``port``, ``netuid``, ``min_stake``,
        ``chain_endpoint`` and ``disable_secure``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=10001)
    # type=int to agree with the integer default: argparse only runs ``type``
    # over *string* defaults, so the previous ``type=str`` produced the int 23
    # by default but a str when the flag was given on the command line.
    parser.add_argument("--netuid", type=int, default=23)
    parser.add_argument("--min_stake", type=int, default=100)
    parser.add_argument(
        "--chain_endpoint",
        type=str,
        default="finney",
    )
    parser.add_argument("--disable_secure", action="store_true", default=False)
    return parser.parse_args(argv)


class ChallengeImage:
    """FastAPI wrapper that completes challenge prompts via ChallengePromptGenerator.

    Routes registered on ``self.app``:
        POST "/" -> ``__call__``: complete the submitted prompt.
        GET  "/" -> ``serve_index``: return the contents of ``index.html``.
    """

    # Generation length used when a request sends a null or non-positive
    # ``max_length``; equal to the default declared on the ``Prompt`` model.
    DEFAULT_MAX_LENGTH = 77

    def __init__(self):
        self.challenge_prompt = ChallengePromptGenerator()
        self.app = FastAPI(title="Challenge Prompt")
        # Same path, two methods: POST generates a prompt, GET serves the page.
        self.app.add_api_route("/", self.__call__, methods=["POST"])
        self.app.add_api_route("/", self.serve_index, methods=["GET"])

    async def __call__(
        self,
        data: Prompt,
    ):
        """Complete ``data.prompt`` and return the generated text, stripped.

        An empty prompt is replaced with ``"an image of "``. The generation
        length is taken from ``data.max_length`` (previously ignored), falling
        back to ``DEFAULT_MAX_LENGTH`` when it is ``None`` or less than 1.
        """
        prompt = data.prompt or "an image of "
        max_length = data.max_length
        if max_length is None or max_length < 1:
            max_length = self.DEFAULT_MAX_LENGTH
        # NOTE(review): no upper bound is enforced on client-supplied
        # max_length; confirm the generator's limit and cap here if needed.
        # NOTE(review): infer_prompt is invoked synchronously inside an async
        # handler; if it is slow it blocks the event loop for all requests.
        completions = self.challenge_prompt.infer_prompt(
            [prompt], max_generation_length=max_length, sampling_topk=100
        )
        return completions[0].strip()

    async def serve_index(self):
        """Return ``index.html`` (resolved against the current working directory) as HTML."""
        # Explicit encoding so the page does not depend on the host locale.
        with open("index.html", "r", encoding="utf-8") as file:
            return HTMLResponse(content=file.read(), status_code=200)


if __name__ == "__main__":
    # Parse (and echo) the CLI options first, so bad flags fail before the
    # generator is constructed; only --port is consumed here. The server
    # then listens on all interfaces.
    cli_options = get_args()
    print("Args: ", cli_options)
    service = ChallengeImage()
    uvicorn.run(service.app, host="0.0.0.0", port=cli_options.port)