# image_services/app.py
import base64
import io
import random

import numpy as np
import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Load the FLUX.1-schnell pipeline once at startup and move it to the GPU if one is available.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
class InferenceRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = False
    width: int = 1024
    height: int = 1024
    num_inference_steps: int = 4


class InferenceResponse(BaseModel):
    image: str  # base64-encoded PNG
    seed: int

app = FastAPI()
@app.post("/infer", response_model=InferenceResponse)
async def infer(request: InferenceRequest):
    # Use the client-supplied seed, or draw a random one if requested.
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed

    # Reject image sizes outside the supported range.
    if not (256 <= request.width <= MAX_IMAGE_SIZE) or not (256 <= request.height <= MAX_IMAGE_SIZE):
        raise HTTPException(
            status_code=400,
            detail=f"Width and height must be between 256 and {MAX_IMAGE_SIZE}.",
        )

    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
        generator=generator,
        guidance_scale=0.0,  # FLUX.1-schnell is distilled for guidance-free sampling
    ).images[0]

    # Encode the generated PIL image as base64 so it can be returned in the JSON response.
    image_base64 = image_to_base64(image)
    return InferenceResponse(image=image_base64, seed=seed)

def image_to_base64(image):
    """Encode a PIL image as a base64 PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
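
# Example client (a minimal sketch, not part of the service itself): it assumes the
# server above is running locally on port 7860 and that the `requests` package is
# installed. It posts a prompt to /infer and writes the returned base64 PNG to disk.
#
#   import base64
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/infer",
#       json={"prompt": "a watercolor painting of a fox", "randomize_seed": True},
#   )
#   resp.raise_for_status()
#   payload = resp.json()
#   with open("output.png", "wb") as f:
#       f.write(base64.b64decode(payload["image"]))
#   print("seed used:", payload["seed"])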