Arkm20 commited on
Commit
3dc398f
·
verified ·
1 Parent(s): bc6468a

Create app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +34 -0
app/main.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from gradio_client import Client
4
+
5
+ app = FastAPI()
6
+ client = Client("Efficient-Large-Model/SanaSprint")
7
+
8
+ class GenerationRequest(BaseModel):
9
+ prompt: str
10
+ model_size: str = "1.6B"
11
+ seed: int = 0
12
+ randomize_seed: bool = True
13
+ width: int = 1024
14
+ height: int = 1024
15
+ guidance_scale: float = 4.5
16
+ num_inference_steps: int = 2
17
+
18
+ @app.post("/generate")
19
+ async def generate_image(request: GenerationRequest):
20
+ try:
21
+ result = client.predict(
22
+ prompt=request.prompt,
23
+ model_size=request.model_size,
24
+ seed=request.seed,
25
+ randomize_seed=request.randomize_seed,
26
+ width=request.width,
27
+ height=request.height,
28
+ guidance_scale=request.guidance_scale,
29
+ num_inference_steps=request.num_inference_steps,
30
+ api_name="/infer"
31
+ )
32
+ return {"result": result}
33
+ except Exception as e:
34
+ raise HTTPException(status_code=500, detail=str(e))