Arkm20 commited on
Commit
a2ae7dd
·
verified ·
1 Parent(s): 9960385

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -0
main.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from diffusers import FluxPipeline
4
+ import torch
5
+ from io import BytesIO
6
+ from fastapi.responses import StreamingResponse
7
+
8
+ app = FastAPI()
9
+
10
+ class Prompt(BaseModel):
11
+ text: str
12
+
13
+ # Load the FLUX model
14
+ model_id = "black-forest-labs/FLUX.1-schnell"
15
+ pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
16
+ pipe.enable_model_cpu_offload()
17
+
18
+ @app.post("/generate-image/")
19
+ async def generate_image(prompt: Prompt):
20
+ try:
21
+ # Generate the image
22
+ image = pipe(
23
+ prompt.text,
24
+ guidance_scale=0.0,
25
+ num_inference_steps=4,
26
+ max_sequence_length=256,
27
+ generator=torch.Generator("cpu").manual_seed(0)
28
+ ).images[0]
29
+
30
+ # Save image to a BytesIO object
31
+ img_byte_arr = BytesIO()
32
+ image.save(img_byte_arr, format='PNG')
33
+ img_byte_arr.seek(0)
34
+
35
+ return StreamingResponse(img_byte_arr, media_type="image/png")
36
+ except Exception as e:
37
+ raise HTTPException(status_code=500, detail=str(e))