Uhhy committed on
Commit
f8f14a5
1 Parent(s): 40b5dee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -37
app.py CHANGED
@@ -1,47 +1,69 @@
 
 
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from multiprocessing import Process, Queue
4
- from diffusers import FluxPipeline
5
- import torch
6
- import io
7
- from fastapi.responses import StreamingResponse
8
  import uvicorn
9
 
10
- app = FastAPI()
 
11
 
12
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="main")
13
- pipe.enable_model_cpu_offload()
14
 
15
- class ImageRequest(BaseModel):
 
 
 
16
  prompt: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def generate_image_response(request, queue):
19
- try:
20
- image = pipe(
21
- request.prompt,
22
- guidance_scale=0.0,
23
- num_inference_steps=4,
24
- max_sequence_length=256,
25
- generator=torch.Generator("cpu").manual_seed(0)
26
- ).images[0]
27
-
28
- img_io = io.BytesIO()
29
- image.save(img_io, 'PNG')
30
- img_io.seek(0)
31
- queue.put(img_io.getvalue())
32
- except Exception as e:
33
- queue.put(f"Error: {str(e)}")
34
-
35
- @app.post("/generate_image")
36
- async def generate_image(request: ImageRequest):
37
- queue = Queue()
38
- p = Process(target=generate_image_response, args=(request, queue))
39
- p.start()
40
- p.join()
41
- response = queue.get()
42
- if "Error" in response:
43
- raise HTTPException(status_code=500, detail=response)
44
- return StreamingResponse(io.BytesIO(response), media_type="image/png")
45
 
46
  if __name__ == "__main__":
47
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from diffusers import DiffusionPipeline
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
+ from fastapi.responses import JSONResponse
 
 
 
 
8
  import uvicorn
9
 
10
+ dtype = torch.bfloat16
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
14
 
15
+ MAX_SEED = np.iinfo(np.int32).max
16
+ MAX_IMAGE_SIZE = 2048
17
+
18
+ class InferenceRequest(BaseModel):
19
  prompt: str
20
+ seed: int = 42
21
+ randomize_seed: bool = False
22
+ width: int = 1024
23
+ height: int = 1024
24
+ num_inference_steps: int = 4
25
+
26
+ class InferenceResponse(BaseModel):
27
+ image: str
28
+ seed: int
29
+
30
+ app = FastAPI()
31
+
32
+ @app.post("/infer", response_model=InferenceResponse)
33
+ async def infer(request: InferenceRequest):
34
+ if request.randomize_seed:
35
+ seed = random.randint(0, MAX_SEED)
36
+ else:
37
+ seed = request.seed
38
+
39
+ if not (256 <= request.width <= MAX_IMAGE_SIZE) or not (256 <= request.height <= MAX_IMAGE_SIZE):
40
+ raise HTTPException(status_code=400, detail="Width and height must be between 256 and 2048.")
41
+
42
+ generator = torch.Generator().manual_seed(seed)
43
+ image = pipe(
44
+ prompt=request.prompt,
45
+ width=request.width,
46
+ height=request.height,
47
+ num_inference_steps=request.num_inference_steps,
48
+ generator=generator,
49
+ guidance_scale=0.0
50
+ ).images[0]
51
+
52
+ # Convert image to base64
53
+ image_base64 = image_to_base64(image)
54
+
55
+ return InferenceResponse(image=image_base64, seed=seed)
56
+
57
+ def image_to_base64(image):
58
+ import io
59
+ import base64
60
+ from PIL import Image
61
 
62
+ buffered = io.BytesIO()
63
+ image.save(buffered, format="PNG")
64
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
65
+
66
+ return img_str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  if __name__ == "__main__":
69
+ uvicorn.run(app, host="0.0.0.0", port=8000)