Arkm20 commited on
Commit
fc7d0e5
·
verified ·
1 Parent(s): 846d4d5

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +57 -42
app/main.py CHANGED
@@ -5,67 +5,82 @@ from gradio_client import Client
5
 
6
  app = FastAPI()
7
 
8
-
9
  app.add_middleware(
10
  CORSMiddleware,
11
- allow_origins=["*"],
 
12
  allow_methods=["*"],
13
  allow_headers=["*"],
14
  )
15
 
 
 
16
 
17
- client = Client("Efficient-Large-Model/SanaSprint")
18
-
19
  class GenerationRequest(BaseModel):
20
  prompt: str
21
- model_size: str = "1.6B"
22
- seed: int = 0
23
- randomize_seed: bool = True
24
- width: int = 1024
25
- height: int = 1024
26
- guidance_scale: float = 4.5
27
- num_inference_steps: int = 2
 
 
 
28
 
29
- @app.get("/gen")
30
- def generate_image_get(
31
- prompt: str = Query(..., description="Prompt for image generation"),
32
- model_size: str = "1.6B",
33
- seed: int = 0,
34
- randomize_seed: bool = True,
35
- width: int = 1024,
36
- height: int = 1024,
37
- guidance_scale: float = 4.5,
38
- num_inference_steps: int = 2
39
- ):
40
  try:
41
  result = client.predict(
42
- prompt=prompt,
43
- model_size=model_size,
44
- seed=seed,
45
- randomize_seed=randomize_seed,
46
- width=width,
47
- height=height,
48
- guidance_scale=guidance_scale,
49
- num_inference_steps=num_inference_steps,
50
- api_name="/infer"
 
 
 
51
  )
52
  return {"result": result}
53
  except Exception as e:
54
  raise HTTPException(status_code=500, detail=str(e))
55
 
56
- @app.post("/generate")
57
- async def generate_image(request: GenerationRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  try:
59
  result = client.predict(
60
- prompt=request.prompt,
61
- model_size=request.model_size,
62
- seed=request.seed,
63
- randomize_seed=request.randomize_seed,
64
- width=request.width,
65
- height=request.height,
66
- guidance_scale=request.guidance_scale,
67
- num_inference_steps=request.num_inference_steps,
68
- api_name="/infer"
 
 
 
69
  )
70
  return {"result": result}
71
  except Exception as e:
 
5
 
6
  app = FastAPI()
7
 
8
+ # Enable CORS for browser access
9
  app.add_middleware(
10
  CORSMiddleware,
11
+ allow_origins=["*"], # Adjust as needed for security
12
+ allow_credentials=True,
13
  allow_methods=["*"],
14
  allow_headers=["*"],
15
  )
16
 
17
+ # Initialize the Gradio Client
18
+ client = Client("K00B404/flux_666")
19
 
20
+ # Request body schema for POST requests
 
21
  class GenerationRequest(BaseModel):
22
  prompt: str
23
+ basemodel: str = "black-forest-labs/FLUX.1-schnell"
24
+ width: int = 1280
25
+ height: int = 768
26
+ scales: int = 8
27
+ steps: int = 8
28
+ seed: int = -1
29
+ upscale_factor: str = "2"
30
+ process_upscale: bool = False
31
+ lora_model: str = "XLabs-AI/flux-RealismLora"
32
+ process_lora: bool = False
33
 
34
+ @app.post("/generate")
35
+ async def generate_image(request: GenerationRequest):
 
 
 
 
 
 
 
 
 
36
  try:
37
  result = client.predict(
38
+ prompt=request.prompt,
39
+ basemodel=request.basemodel,
40
+ width=request.width,
41
+ height=request.height,
42
+ scales=request.scales,
43
+ steps=request.steps,
44
+ seed=request.seed,
45
+ upscale_factor=request.upscale_factor,
46
+ process_upscale=request.process_upscale,
47
+ lora_model=request.lora_model,
48
+ process_lora=request.process_lora,
49
+ api_name="/gen"
50
  )
51
  return {"result": result}
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))
54
 
55
+ # Optional: GET endpoint for browser address bar access
56
+ @app.get("/gen")
57
+ async def generate_image_get(
58
+ prompt: str = Query(..., description="Prompt for image generation"),
59
+ basemodel: str = "black-forest-labs/FLUX.1-schnell",
60
+ width: int = 1280,
61
+ height: int = 768,
62
+ scales: int = 8,
63
+ steps: int = 8,
64
+ seed: int = -1,
65
+ upscale_factor: str = "2",
66
+ process_upscale: bool = False,
67
+ lora_model: str = "XLabs-AI/flux-RealismLora",
68
+ process_lora: bool = False
69
+ ):
70
  try:
71
  result = client.predict(
72
+ prompt=prompt,
73
+ basemodel=basemodel,
74
+ width=width,
75
+ height=height,
76
+ scales=scales,
77
+ steps=steps,
78
+ seed=seed,
79
+ upscale_factor=upscale_factor,
80
+ process_upscale=process_upscale,
81
+ lora_model=lora_model,
82
+ process_lora=process_lora,
83
+ api_name="/gen"
84
  )
85
  return {"result": result}
86
  except Exception as e: