Arkm20 commited on
Commit
5482fea
·
verified ·
1 Parent(s): 867afa8

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +17 -12
app/main.py CHANGED
@@ -2,12 +2,13 @@ from fastapi import FastAPI, HTTPException, Query
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from gradio_client import Client
5
- import requests
6
  import io
 
7
 
8
  app = FastAPI()
9
 
10
- # Allow CORS for all origins (you can restrict this in production)
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
@@ -16,14 +17,14 @@ app.add_middleware(
16
  allow_headers=["*"],
17
  )
18
 
19
- # Initialize Gradio client
20
  client = Client("K00B404/flux_666")
21
 
22
  @app.get("/")
23
  def root():
24
- return {"message": "Welcome to the Image Generator API!"}
25
 
26
- @app.get("/generate_image")
27
  def generate_image(
28
  prompt: str = Query(..., description="Prompt for image generation"),
29
  basemodel: str = "black-forest-labs/FLUX.1-schnell",
@@ -38,8 +39,7 @@ def generate_image(
38
  process_lora: bool = False
39
  ):
40
  try:
41
- # Call the Gradio prediction API
42
- image_url = client.predict(
43
  prompt=prompt,
44
  basemodel=basemodel,
45
  width=width,
@@ -54,12 +54,17 @@ def generate_image(
54
  api_name="/gen"
55
  )
56
 
57
- # Download the image
58
- response = requests.get(image_url)
59
- response.raise_for_status()
 
 
60
 
61
- # Return the image stream to the browser
62
- return StreamingResponse(io.BytesIO(response.content), media_type="image/png")
 
 
 
63
 
64
  except Exception as e:
65
  raise HTTPException(status_code=500, detail=str(e))
 
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from gradio_client import Client
5
+ from PIL import Image
6
  import io
7
+ import os
8
 
9
  app = FastAPI()
10
 
11
+ # Allow CORS
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
 
17
  allow_headers=["*"],
18
  )
19
 
20
+ # Connect to the model
21
  client = Client("K00B404/flux_666")
22
 
23
  @app.get("/")
24
  def root():
25
+ return {"message": "Welcome to the Flux 666 Image Generator API!"}
26
 
27
+ @app.get("/gen")
28
  def generate_image(
29
  prompt: str = Query(..., description="Prompt for image generation"),
30
  basemodel: str = "black-forest-labs/FLUX.1-schnell",
 
39
  process_lora: bool = False
40
  ):
41
  try:
42
+ result = client.predict(
 
43
  prompt=prompt,
44
  basemodel=basemodel,
45
  width=width,
 
54
  api_name="/gen"
55
  )
56
 
57
+ # result is a list of local file paths
58
+ if isinstance(result, list) and result:
59
+ image_path = result[0]
60
+ else:
61
+ raise ValueError("No image returned from model.")
62
 
63
+ # Read the image from disk
64
+ with open(image_path, "rb") as img_file:
65
+ img_bytes = img_file.read()
66
+
67
+ return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
68
 
69
  except Exception as e:
70
  raise HTTPException(status_code=500, detail=str(e))