# api-image / app.py
# Author: Arkm20
# Last commit: "Rename main.py to app.py" (b79a2ef, verified)
import logging
import threading
from io import BytesIO

import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
# The ASGI application object; HTTP routes are registered on it via decorators.
app: FastAPI = FastAPI()
class Prompt(BaseModel):
    # JSON request body accepted by the POST /generate-image/ route.
    # (Documented with comments rather than a docstring: pydantic would copy a
    # class docstring into the generated JSON/OpenAPI schema.)

    # Required prompt string; passed unchanged as the pipeline's prompt.
    text: str
# Load the FLUX.1 [schnell] text-to-image pipeline once, at import time, so every
# request reuses the same weights.
model_id = "black-forest-labs/FLUX.1-schnell"
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# enable_model_cpu_offload() keeps sub-models on the CPU and installs hooks that
# move each one to the accelerator (CUDA by default, per the diffusers docs) only
# while it runs. On a host without CUDA those per-call moves fail, so every
# request would error out. Only enable offloading when a GPU is actually present;
# otherwise the pipeline stays on the CPU (slow, but functional).
if torch.cuda.is_available():
    pipe.enable_model_cpu_offload()
logger = logging.getLogger(__name__)

# The original handler ran the pipeline directly on the event loop, which
# implicitly serialized all generations. Generation now runs in a worker thread,
# so this lock keeps that one-at-a-time behaviour. Assumption to confirm: `pipe`
# (with its CPU-offload hooks) is not safe to call from several threads at once.
_generation_lock = threading.Lock()


def _render_png(text: str) -> bytes:
    """Run the FLUX pipeline on *text* and return the first image as PNG bytes.

    This is blocking and CPU/GPU heavy; call it from a worker thread, not from
    the event loop.
    """
    with _generation_lock:
        image = pipe(
            text,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            # Fixed seed: the same prompt always yields the same image.
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/generate-image/")
async def generate_image(prompt: Prompt):
    """Generate an image from ``prompt.text`` and return it as ``image/png``.

    Raises:
        HTTPException: status 500, with the underlying error message as the
            detail, if generation or PNG encoding fails.
    """
    try:
        # Offload the blocking pipeline call so other requests (and health
        # checks) are still served while an image is being generated.
        png_bytes = await run_in_threadpool(_render_png, prompt.text)
    except Exception as e:
        logger.exception("Image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    # The image is already fully in memory. A plain Response sends it in one
    # piece with a Content-Length header, rather than iterating a BytesIO
    # line by line.
    return Response(content=png_bytes, media_type="image/png")