"""FastAPI service that renders a text prompt to a PNG with FLUX.1-schnell."""

import logging
import threading
from io import BytesIO

import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

app = FastAPI()


class Prompt(BaseModel):
    """Request body for ``POST /generate-image/``."""

    # The text prompt passed verbatim to the pipeline.
    text: str
    # RNG seed for the CPU generator. Defaults to 0 so existing clients that
    # only send ``text`` get exactly the same images as before; range-checked
    # so out-of-range values are rejected with 422 instead of failing in torch.
    seed: int = Field(default=0, ge=0, lt=2**63)


# Load the FLUX model once at import time; every request reuses this pipeline.
model_id = "black-forest-labs/FLUX.1-schnell"
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# The endpoint below runs on FastAPI's threadpool, so several requests can be
# in flight at once. A diffusers pipeline is not documented as thread-safe
# (the scheduler and offload hooks carry per-call state), so serialize calls.
_pipe_lock = threading.Lock()


@app.post("/generate-image/")
def generate_image(prompt: Prompt):
    """Generate one image for ``prompt.text`` and return it as a PNG stream.

    Declared as a plain ``def`` (not ``async def``) on purpose: inference is
    blocking, and FastAPI runs sync endpoints in a worker thread, so the event
    loop stays free to serve other requests while an image is generated.

    Raises:
        HTTPException: 500 if generation or PNG encoding fails. The underlying
            error is logged server-side rather than echoed to the client.
    """
    try:
        with _pipe_lock:
            image = pipe(
                prompt.text,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
                generator=torch.Generator("cpu").manual_seed(prompt.seed),
            ).images[0]

        # Encode into an in-memory buffer and rewind it so the response
        # streams from the first byte.
        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log,
        # but don't leak internal paths / device errors to the caller.
        logger.exception("Image generation failed")
        raise HTTPException(status_code=500, detail="Image generation failed") from e

    return StreamingResponse(img_byte_arr, media_type="image/png")