# Spaces: Build error
import base64
import io
import random

import numpy as np
import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Largest value of a signed 32-bit integer; upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (inclusive) accepted for the requested width and height.
MAX_IMAGE_SIZE = 2048

# Weights are loaded in bfloat16 and moved to the GPU when one is available,
# otherwise the pipeline stays on the CPU.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time and shared by every request.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
).to(device)
class InferenceRequest(BaseModel):
    """Request body accepted by ``infer``."""

    # Text prompt forwarded to the pipeline.
    prompt: str
    # Seed used when ``randomize_seed`` is false.
    seed: int = 42
    # When true, ``seed`` is ignored and a random seed is drawn instead.
    randomize_seed: bool = False
    # Requested output size; ``infer`` rejects values outside 256..2048.
    width: int = 1024
    height: int = 1024
    # Step count forwarded to the pipeline.
    num_inference_steps: int = 4
class InferenceResponse(BaseModel):
    """Response body returned by ``infer``."""

    # Base64 text of the PNG-encoded image.
    image: str
    # Seed that was actually used for generation.
    seed: int
# ASGI application object; served by uvicorn when the module is run directly.
app = FastAPI()
# NOTE(review): the handler was never registered on ``app``; the "/infer"
# path is taken from the function name -- confirm it matches what clients call.
@app.post("/infer", response_model=InferenceResponse)
async def infer(request: InferenceRequest):
    """Generate one image for a text prompt and return it base64-encoded.

    Args:
        request: Prompt, seed options, output size and step count.

    Returns:
        An ``InferenceResponse`` holding the PNG image as base64 text and
        the seed that was actually used.

    Raises:
        HTTPException: 400 when width or height is outside
            256..MAX_IMAGE_SIZE, or when num_inference_steps is below 1.
    """
    min_size = 256

    # Validate everything up front so bad input never reaches the model.
    width_ok = min_size <= request.width <= MAX_IMAGE_SIZE
    height_ok = min_size <= request.height <= MAX_IMAGE_SIZE
    if not (width_ok and height_ok):
        raise HTTPException(
            status_code=400,
            detail=f"Width and height must be between {min_size} and {MAX_IMAGE_SIZE}.",
        )
    if request.num_inference_steps < 1:
        raise HTTPException(
            status_code=400,
            detail="num_inference_steps must be at least 1.",
        )

    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed

    # Default-device generator seeded explicitly so a given seed can be replayed.
    generator = torch.Generator().manual_seed(seed)

    # NOTE(review): this call is synchronous, so the event loop is blocked for
    # the duration of generation and requests are handled one at a time.
    result = pipe(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
        generator=generator,
        guidance_scale=0.0,
    )
    image = result.images[0]

    return InferenceResponse(image=image_to_base64(image), seed=seed)
def image_to_base64(image) -> str:
    """Return the PNG encoding of *image* as a base64 text string.

    Args:
        image: Any object exposing ``save(fp, format=...)``; ``infer`` passes
            the first entry of the pipeline's ``images`` result.

    Returns:
        The base64 representation of the bytes written by ``image.save``.
    """
    # Serialize into memory instead of touching the filesystem.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
if __name__ == "__main__":
    # Serve the app on all interfaces, port 8000, when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)