# NOTE: The lines below are residue from a scraped Hugging Face Spaces page
# (status header, file size, commit hashes, line-number gutter). They are not
# Python and are preserved here as a comment so the file remains parseable.
# Spaces: Sleeping | File size: 1,896 Bytes | commits: f8f14a5 / dff3f4b / 63c2086
import numpy as np
import random
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import uvicorn
# --- Model / inference configuration (executed at import time) ---
dtype = torch.bfloat16  # half-precision weights; FLUX.1-schnell supports bf16
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
# NOTE(review): loads the full FLUX.1-schnell pipeline at import time —
# process startup blocks until the weights are downloaded and moved to `device`.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit int; upper bound for random seeds
MAX_IMAGE_SIZE = 2048  # maximum width/height (pixels) accepted by /infer
class InferenceRequest(BaseModel):
    """Request body for POST /infer."""
    prompt: str  # text prompt to render
    seed: int = 42  # RNG seed; ignored when randomize_seed is True
    randomize_seed: bool = False  # draw a fresh seed in [0, MAX_SEED] instead
    width: int = 1024  # output width in pixels (validated in the handler: 256..MAX_IMAGE_SIZE)
    height: int = 1024  # output height in pixels (validated in the handler: 256..MAX_IMAGE_SIZE)
    num_inference_steps: int = 4  # denoising steps; schnell is tuned for very few steps
class InferenceResponse(BaseModel):
    """Response body for POST /infer."""
    image: str  # base64-encoded PNG bytes
    seed: int  # the seed actually used (useful when randomize_seed was True)
# FastAPI application instance; routes are registered on it via decorators below.
app = FastAPI()
@app.post("/infer", response_model=InferenceResponse)
def infer(request: InferenceRequest):
    """Generate one image from a text prompt with FLUX.1-schnell.

    Declared as a plain ``def`` (not ``async def``) on purpose: the pipeline
    call is a long, blocking GPU/CPU computation, and FastAPI runs sync
    handlers in its threadpool.  The original ``async def`` version blocked
    the event loop for the entire inference, stalling all other requests.

    Args:
        request: Prompt, seed settings, output size, and step count.

    Returns:
        InferenceResponse with the base64-encoded PNG and the seed used.

    Raises:
        HTTPException: 400 if width/height or num_inference_steps are out of range.
    """
    # Validate cheap preconditions before touching the RNG or the model.
    if not (256 <= request.width <= MAX_IMAGE_SIZE) or not (256 <= request.height <= MAX_IMAGE_SIZE):
        raise HTTPException(status_code=400, detail="Width and height must be between 256 and 2048.")
    if request.num_inference_steps < 1:
        raise HTTPException(status_code=400, detail="num_inference_steps must be at least 1.")
    # Use the caller's seed unless they asked for a random one; report it back
    # so random results are reproducible.
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
        generator=generator,
        guidance_scale=0.0,  # schnell is a distilled model; guidance is disabled
    ).images[0]
    # Serialize the PIL image to a base64 PNG string for the JSON response.
    image_base64 = image_to_base64(image)
    return InferenceResponse(image=image_base64, seed=seed)
def image_to_base64(image) -> str:
    """Encode an image as a base64 PNG string.

    Args:
        image: Any object exposing ``save(buffer, format=...)``
            (e.g. a ``PIL.Image.Image``).

    Returns:
        The PNG bytes base64-encoded and decoded to an ASCII ``str``.
    """
    # Imports kept local so the module imports even in minimal environments;
    # the unused `from PIL import Image` in the original has been removed.
    import base64
    import io

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Serve the API with uvicorn on all interfaces, port 7860, when run directly.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)