Arkm20 commited on
Commit
a59eb8d
·
verified ·
1 Parent(s): fc7d0e5

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +21 -43
app/main.py CHANGED
@@ -1,60 +1,30 @@
1
  from fastapi import FastAPI, HTTPException, Query
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
  from gradio_client import Client
 
 
5
 
6
  app = FastAPI()
7
 
8
- # Enable CORS for browser access
9
  app.add_middleware(
10
  CORSMiddleware,
11
- allow_origins=["*"], # Adjust as needed for security
12
  allow_credentials=True,
13
  allow_methods=["*"],
14
  allow_headers=["*"],
15
  )
16
 
17
- # Initialize the Gradio Client
18
  client = Client("K00B404/flux_666")
19
 
20
- # Request body schema for POST requests
21
- class GenerationRequest(BaseModel):
22
- prompt: str
23
- basemodel: str = "black-forest-labs/FLUX.1-schnell"
24
- width: int = 1280
25
- height: int = 768
26
- scales: int = 8
27
- steps: int = 8
28
- seed: int = -1
29
- upscale_factor: str = "2"
30
- process_upscale: bool = False
31
- lora_model: str = "XLabs-AI/flux-RealismLora"
32
- process_lora: bool = False
33
-
34
- @app.post("/generate")
35
- async def generate_image(request: GenerationRequest):
36
- try:
37
- result = client.predict(
38
- prompt=request.prompt,
39
- basemodel=request.basemodel,
40
- width=request.width,
41
- height=request.height,
42
- scales=request.scales,
43
- steps=request.steps,
44
- seed=request.seed,
45
- upscale_factor=request.upscale_factor,
46
- process_upscale=request.process_upscale,
47
- lora_model=request.lora_model,
48
- process_lora=request.process_lora,
49
- api_name="/gen"
50
- )
51
- return {"result": result}
52
- except Exception as e:
53
- raise HTTPException(status_code=500, detail=str(e))
54
 
55
- # Optional: GET endpoint for browser address bar access
56
- @app.get("/gen")
57
- async def generate_image_get(
58
  prompt: str = Query(..., description="Prompt for image generation"),
59
  basemodel: str = "black-forest-labs/FLUX.1-schnell",
60
  width: int = 1280,
@@ -68,7 +38,8 @@ async def generate_image_get(
68
  process_lora: bool = False
69
  ):
70
  try:
71
- result = client.predict(
 
72
  prompt=prompt,
73
  basemodel=basemodel,
74
  width=width,
@@ -82,6 +53,13 @@ async def generate_image_get(
82
  process_lora=process_lora,
83
  api_name="/gen"
84
  )
85
- return {"result": result}
 
 
 
 
 
 
 
86
  except Exception as e:
87
  raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, HTTPException, Query
2
+ from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  from gradio_client import Client
5
+ import requests
6
+ import io
7
 
8
  app = FastAPI()
9
 
10
+ # Allow CORS for all origins (you can restrict this in production)
11
  app.add_middleware(
12
  CORSMiddleware,
13
+ allow_origins=["*"],
14
  allow_credentials=True,
15
  allow_methods=["*"],
16
  allow_headers=["*"],
17
  )
18
 
19
+ # Initialize Gradio client
20
  client = Client("K00B404/flux_666")
21
 
22
+ @app.get("/")
23
+ def root():
24
+ return {"message": "Welcome to the Flux 666 Image Generator API!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ @app.get("/generate_image")
27
+ def generate_image(
 
28
  prompt: str = Query(..., description="Prompt for image generation"),
29
  basemodel: str = "black-forest-labs/FLUX.1-schnell",
30
  width: int = 1280,
 
38
  process_lora: bool = False
39
  ):
40
  try:
41
+ # Call the Gradio prediction API
42
+ image_url = client.predict(
43
  prompt=prompt,
44
  basemodel=basemodel,
45
  width=width,
 
53
  process_lora=process_lora,
54
  api_name="/gen"
55
  )
56
+
57
+ # Download the image
58
+ response = requests.get(image_url)
59
+ response.raise_for_status()
60
+
61
+ # Return the image stream to the browser
62
+ return StreamingResponse(io.BytesIO(response.content), media_type="image/png")
63
+
64
  except Exception as e:
65
  raise HTTPException(status_code=500, detail=str(e))