from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from multiprocessing import Process, Queue
from diffusers import FluxPipeline
import torch
import io
import uvicorn

app = FastAPI()

# Load FLUX.1-schnell once at startup; CPU offload keeps GPU memory usage down.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    revision="main",
)
pipe.enable_model_cpu_offload()


class ImageRequest(BaseModel):
    prompt: str


def generate_image_response(request, queue):
    """Run generation in a worker process and put PNG bytes (or an error string) on the queue."""
    try:
        image = pipe(
            request.prompt,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
        img_io = io.BytesIO()
        image.save(img_io, "PNG")
        img_io.seek(0)
        queue.put(img_io.getvalue())
    except Exception as e:
        queue.put(f"Error: {str(e)}")


@app.post("/generate_image")
async def generate_image(request: ImageRequest):
    queue = Queue()
    p = Process(target=generate_image_response, args=(request, queue))
    p.start()
    # Read the result before joining: joining a process that still has unflushed
    # queue data (a multi-megabyte PNG) can deadlock.
    response = queue.get()
    p.join()
    # The worker puts PNG bytes on success and an "Error: ..." string on failure.
    if isinstance(response, str) and response.startswith("Error"):
        raise HTTPException(status_code=500, detail=response)
    return StreamingResponse(io.BytesIO(response), media_type="image/png")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
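
# Usage sketch (an assumption for illustration, not part of the original script):
# with the server running locally on port 7860 as configured above, the endpoint
# can be exercised with curl; the prompt and output filename are placeholders.
#
#   curl -X POST http://localhost:7860/generate_image \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a cat holding a sign that says hello world"}' \
#        --output out.png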