import base64
import io
import random

import numpy as np
import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Load the FLUX.1-schnell pipeline once at startup; bfloat16 halves memory
# relative to float32 and is well supported on recent GPUs.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)

MAX_SEED = np.iinfo(np.int32).max
MIN_IMAGE_SIZE = 256
MAX_IMAGE_SIZE = 2048


class InferenceRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = False
    width: int = 1024
    height: int = 1024
    num_inference_steps: int = 4


class InferenceResponse(BaseModel):
    image: str  # base64-encoded PNG
    seed: int


app = FastAPI()


def image_to_base64(image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


# A plain (non-async) endpoint lets FastAPI run the blocking GPU call in its
# worker threadpool instead of stalling the event loop for the whole inference.
@app.post("/infer", response_model=InferenceResponse)
def infer(request: InferenceRequest):
    # Pick a fresh seed when requested so repeated calls produce varied images.
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed

    if not (MIN_IMAGE_SIZE <= request.width <= MAX_IMAGE_SIZE) or not (
        MIN_IMAGE_SIZE <= request.height <= MAX_IMAGE_SIZE
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Width and height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}.",
        )

    # A CPU generator keeps seeding reproducible regardless of the target device.
    generator = torch.Generator().manual_seed(seed)

    # FLUX.1-schnell is timestep-distilled: it needs only a few steps and runs
    # without classifier-free guidance (guidance_scale=0.0).
    image = pipe(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
        generator=generator,
        guidance_scale=0.0,
    ).images[0]

    return InferenceResponse(image=image_to_base64(image), seed=seed)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
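
# Example usage (a sketch; assumes the server is running locally with the
# default host/port above and that the model fits on the available device):
#
#   curl -X POST http://localhost:7860/infer \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "a watercolor fox in a forest", "randomize_seed": true}'
#
# The JSON response carries the base64-encoded PNG in "image" plus the seed
# actually used, so a client can decode and save the result, e.g.:
#
#   import base64, requests  # hypothetical client-side snippet
#   r = requests.post("http://localhost:7860/infer",
#                     json={"prompt": "a watercolor fox in a forest"})
#   with open("out.png", "wb") as f:
#       f.write(base64.b64decode(r.json()["image"]))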