# image_services/app.py
# FLUX.1-schnell text-to-image API (FastAPI + diffusers).
# (Header reconstructed from Hugging Face page residue; original commit dff3f4b.)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from multiprocessing import Process, Queue
from diffusers import FluxPipeline
import torch
import io
from fastapi.responses import StreamingResponse
import uvicorn
# FastAPI application object served by uvicorn when this file is run directly.
app = FastAPI()
# Load FLUX.1-schnell once at import time; bfloat16 halves the weight memory
# relative to float32.  NOTE(review): this download/load blocks server startup.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="main")
# Keep weights in CPU RAM and move each sub-module to the GPU only while it
# runs — trades throughput for a much smaller VRAM footprint.
pipe.enable_model_cpu_offload()
class ImageRequest(BaseModel):
    """Request body for ``POST /generate_image``."""
    # Text prompt forwarded verbatim to the FLUX pipeline.
    prompt: str
def generate_image_response(request, queue):
    """Worker-process entry point: render ``request.prompt`` with the
    module-level FLUX pipeline and hand the result back via *queue*.

    On success the queue receives the encoded PNG as raw ``bytes``; on
    failure it receives a ``str`` beginning with ``"Error:"`` so the
    parent process can tell the two outcomes apart.
    """
    try:
        # Fixed CPU seed plus schnell-style settings (4 steps, no
        # classifier-free guidance) keep the output deterministic.
        result = pipe(
            request.prompt,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator("cpu").manual_seed(0),
        )
        buffer = io.BytesIO()
        result.images[0].save(buffer, 'PNG')
        queue.put(buffer.getvalue())
    except Exception as exc:
        # Ship the failure across the process boundary as a marker string.
        queue.put(f"Error: {str(exc)}")
@app.post("/generate_image")
async def generate_image(request: ImageRequest):
    """Generate a PNG for ``request.prompt`` and stream it back.

    The pipeline call runs in a child process so a hard crash (e.g. an
    out-of-memory kill) cannot take down the server.  The worker puts
    raw PNG ``bytes`` on success or an ``"Error: ..."`` string on failure.

    Raises:
        HTTPException: 500 when the worker reports an error or exits
            without producing any result.
    """
    queue = Queue()
    worker = Process(target=generate_image_response, args=(request, queue))
    worker.start()
    worker.join()
    # Guard against a worker that died before enqueuing anything — a bare
    # queue.get() would block this request forever.
    if queue.empty():
        raise HTTPException(
            status_code=500,
            detail=f"Image worker exited without a result (exitcode={worker.exitcode})",
        )
    response = queue.get()
    # BUG FIX: the success payload is bytes, so the original check
    # `"Error" in response` raised TypeError on every successful request
    # (str-in-bytes is not allowed).  Distinguish outcomes by type instead:
    # the worker only ever puts str for errors and bytes for images.
    if isinstance(response, str):
        raise HTTPException(status_code=500, detail=response)
    return StreamingResponse(io.BytesIO(response), media_type="image/png")
# Script entry point: serve the app with uvicorn.  Binds 0.0.0.0 so the
# service is reachable from outside a container; port 8002 is hard-coded.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8002)