Uhhy commited on
Commit
dff3f4b
·
verified ·
1 Parent(s): 2d0a464

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from multiprocessing import Process, Queue
4
+ from diffusers import FluxPipeline
5
+ import torch
6
+ import io
7
+ from fastapi.responses import StreamingResponse
8
+ import uvicorn
9
+
10
+ app = FastAPI()
11
+
12
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="main")
13
+ pipe.enable_model_cpu_offload()
14
+
15
+ class ImageRequest(BaseModel):
16
+ prompt: str
17
+
18
+ def generate_image_response(request, queue):
19
+ try:
20
+ image = pipe(
21
+ request.prompt,
22
+ guidance_scale=0.0,
23
+ num_inference_steps=4,
24
+ max_sequence_length=256,
25
+ generator=torch.Generator("cpu").manual_seed(0)
26
+ ).images[0]
27
+
28
+ img_io = io.BytesIO()
29
+ image.save(img_io, 'PNG')
30
+ img_io.seek(0)
31
+ queue.put(img_io.getvalue())
32
+ except Exception as e:
33
+ queue.put(f"Error: {str(e)}")
34
+
35
+ @app.post("/generate_image")
36
+ async def generate_image(request: ImageRequest):
37
+ queue = Queue()
38
+ p = Process(target=generate_image_response, args=(request, queue))
39
+ p.start()
40
+ p.join()
41
+ response = queue.get()
42
+ if "Error" in response:
43
+ raise HTTPException(status_code=500, detail=response)
44
+ return StreamingResponse(io.BytesIO(response), media_type="image/png")
45
+
46
+ if __name__ == "__main__":
47
+ uvicorn.run(app, host="0.0.0.0", port=8002)