File size: 1,896 Bytes
f8f14a5
 
 
 
dff3f4b
 
f8f14a5
dff3f4b
 
f8f14a5
 
dff3f4b
f8f14a5
dff3f4b
f8f14a5
 
 
 
dff3f4b
f8f14a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dff3f4b
f8f14a5
 
 
 
 
dff3f4b
 
63c2086
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import random
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import uvicorn

# bfloat16 halves memory versus float32; FLUX.1-schnell publishes bf16 weights.
dtype = torch.bfloat16
# Prefer GPU when present; everything below follows this device choice.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the text-to-image pipeline once at module import.
# NOTE(review): this downloads/loads multi-GB weights at import time — expect a
# long cold start and significant RAM/VRAM use before the server accepts requests.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)

# Largest accepted seed (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound in pixels for requested width/height (lower bound 256 is enforced in the handler).
MAX_IMAGE_SIZE = 2048

class InferenceRequest(BaseModel):
    """Request body for POST /infer."""

    prompt: str                   # text prompt to render
    seed: int = 42                # RNG seed; ignored when randomize_seed is True
    randomize_seed: bool = False  # draw a fresh random seed for this request
    width: int = 1024             # output width in pixels (handler enforces 256..MAX_IMAGE_SIZE)
    height: int = 1024            # output height in pixels (handler enforces 256..MAX_IMAGE_SIZE)
    num_inference_steps: int = 4  # denoising steps; default 4

class InferenceResponse(BaseModel):
    """Response body for POST /infer."""

    image: str  # base64-encoded PNG bytes
    seed: int   # seed actually used (differs from the request when randomized)

# FastAPI application instance; routes are registered via decorators below.
app = FastAPI()

@app.post("/infer", response_model=InferenceResponse)
async def infer(request: InferenceRequest):
    """Generate one image from a text prompt and return it base64-encoded.

    Validates the requested dimensions and step count, resolves the seed
    (optionally randomized), runs the diffusion pipeline, and returns the
    PNG as a base64 string together with the seed actually used.

    Raises:
        HTTPException: 400 when width/height fall outside [256, MAX_IMAGE_SIZE]
            or num_inference_steps is less than 1.
    """
    # Validate before doing any work — including seed resolution, so invalid
    # requests don't consume RNG state.
    if not (256 <= request.width <= MAX_IMAGE_SIZE) or not (256 <= request.height <= MAX_IMAGE_SIZE):
        raise HTTPException(status_code=400, detail="Width and height must be between 256 and 2048.")
    # Guard against 0/negative steps, which would otherwise crash inside the
    # pipeline and surface as an opaque 500 instead of a client error.
    if request.num_inference_steps < 1:
        raise HTTPException(status_code=400, detail="num_inference_steps must be at least 1.")

    # Resolve the seed: client-supplied value, or a fresh random one on request.
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed

    # CPU generator (no device arg) keeps seeded results stable regardless of device.
    generator = torch.Generator().manual_seed(seed)

    # NOTE(review): this is a blocking, potentially multi-second model call
    # inside an `async def` handler — it stalls the event loop for all clients.
    # Consider a plain `def` handler (FastAPI runs those in a threadpool) or
    # loop.run_in_executor; left unchanged here to keep behavior identical.
    image = pipe(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
        generator=generator,
        guidance_scale=0.0,  # classifier-free guidance disabled
    ).images[0]

    # Serialize to base64 so the image travels inside the JSON response.
    return InferenceResponse(image=image_to_base64(image), seed=seed)

def image_to_base64(image):
    """Serialize a PIL-style image to a base64-encoded PNG string.

    Args:
        image: any object exposing ``save(buffer, format=...)`` (e.g. a
            ``PIL.Image.Image``).

    Returns:
        str: base64 text of the PNG bytes, safe to embed in JSON.
    """
    # Local imports keep these off the module's import-time path; the unused
    # `from PIL import Image` the original carried has been dropped.
    import io
    import base64

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    # b64encode returns bytes; decode to a JSON-friendly str.
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

# Script entry point: serve the API on all interfaces at port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)