Arkm20 commited on
Commit
846d4d5
·
verified ·
1 Parent(s): d5f6160

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +39 -1
app/main.py CHANGED
@@ -1,8 +1,19 @@
1
- from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
3
  from gradio_client import Client
4
 
5
  app = FastAPI()
 
 
 
 
 
 
 
 
 
 
6
  client = Client("Efficient-Large-Model/SanaSprint")
7
 
8
  class GenerationRequest(BaseModel):
@@ -15,6 +26,33 @@ class GenerationRequest(BaseModel):
15
  guidance_scale: float = 4.5
16
  num_inference_steps: int = 2
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @app.post("/generate")
19
  async def generate_image(request: GenerationRequest):
20
  try:
 
1
+ from fastapi import FastAPI, HTTPException, Query
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from gradio_client import Client
5
 
6
  app = FastAPI()
7
+
8
+
9
+ app.add_middleware(
10
+ CORSMiddleware,
11
+ allow_origins=["*"],
12
+ allow_methods=["*"],
13
+ allow_headers=["*"],
14
+ )
15
+
16
+
17
  client = Client("Efficient-Large-Model/SanaSprint")
18
 
19
  class GenerationRequest(BaseModel):
 
26
  guidance_scale: float = 4.5
27
  num_inference_steps: int = 2
28
 
29
+ @app.get("/gen")
30
+ def generate_image_get(
31
+ prompt: str = Query(..., description="Prompt for image generation"),
32
+ model_size: str = "1.6B",
33
+ seed: int = 0,
34
+ randomize_seed: bool = True,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ guidance_scale: float = 4.5,
38
+ num_inference_steps: int = 2
39
+ ):
40
+ try:
41
+ result = client.predict(
42
+ prompt=prompt,
43
+ model_size=model_size,
44
+ seed=seed,
45
+ randomize_seed=randomize_seed,
46
+ width=width,
47
+ height=height,
48
+ guidance_scale=guidance_scale,
49
+ num_inference_steps=num_inference_steps,
50
+ api_name="/infer"
51
+ )
52
+ return {"result": result}
53
+ except Exception as e:
54
+ raise HTTPException(status_code=500, detail=str(e))
55
+
56
  @app.post("/generate")
57
  async def generate_image(request: GenerationRequest):
58
  try: