# Captured from a Hugging Face Space page (status: Runtime error).
import os
import random
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# S3 configuration.
# Credentials are read from the environment instead of being hard-coded:
# secrets committed to source (especially in a public Space) are leaked to
# anyone who can read the repository.
# NOTE(review): the key pair that was previously hard-coded here must be
# treated as compromised and rotated.
S3_BUCKET = os.environ.get("S3_BUCKET", "afri")
S3_REGION = os.environ.get("S3_REGION", "eu-west-3")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")

# Set up the S3 client. When the credential variables are unset (None),
# boto3 treats them as not provided and uses its default credential chain
# (AWS_* environment variables, shared config files, instance role, ...).
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Model setup. `dtype` is the single source of truth for the weight precision
# (previously it was defined here but the literal was repeated below).
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): bfloat16 on CPU can be very slow or unsupported for some ops;
# confirm whether a float32 fallback is wanted when no GPU is available.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype
).to(device)

# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Intended upper bound, in pixels, for a requested image width/height.
MAX_IMAGE_SIZE = 2048

app = FastAPI()
class InferenceRequest(BaseModel):
    """Request body for the image-generation endpoint.

    The generation parameters are passed through to the diffusion pipeline by
    the ``infer`` handler; ``seed``/``randomize_seed`` select the RNG seed.
    """

    # Text prompt describing the image to generate.
    prompt: str
    # Seed used when ``randomize_seed`` is False.
    seed: int = 42
    # When True, the handler draws a random seed in [0, MAX_SEED] instead.
    randomize_seed: bool = False
    # Output size in pixels, bounded by MAX_IMAGE_SIZE so a single request
    # cannot ask for an arbitrarily large (memory-exhausting) image; invalid
    # values are rejected by request validation before reaching the model.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    # Classifier-free-guidance-style strength forwarded to the pipeline.
    guidance_scale: float = 5.0
    # Number of denoising steps; must be positive.
    num_inference_steps: int = Field(28, gt=0)
def save_image_to_s3(image):
    """Upload *image* to S3 as a PNG and return the object's URL.

    Args:
        image: Object exposing a PIL-style ``save(fp, format=...)`` method
            (here, the pipeline's output image).

    Returns:
        The virtual-hosted-style HTTPS URL of the uploaded object. Whether the
        URL is publicly readable depends on the bucket policy (not visible here).

    Raises:
        Whatever ``s3_client.put_object`` raises on upload failure (not caught).
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    # A random suffix keeps keys unique: a timestamp alone collides (and the
    # later upload silently overwrites the earlier one) when two images are
    # saved within the same second.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=buffer.getvalue(),
        ContentType="image/png",
    )
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
# NOTE(review): SOURCE registered no route for this handler, so it was
# unreachable; the "/infer" path is an assumption — confirm with API clients.
@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image for *request*, upload it to S3 and return its URL.

    Args:
        request: Validated request body with the prompt and generation
            parameters.

    Returns:
        ``{"image_url": <S3 URL>, "seed": <seed actually used>}``.

    Raises:
        HTTPException: status 500 with the underlying error message if
            generation or upload fails.
    """
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed
    generator = torch.Generator().manual_seed(seed)
    try:
        # NOTE(review): `pipe(...)` is a blocking call inside an `async def`,
        # so it stalls the event loop (and serializes requests) for the whole
        # generation. Offloading to a thread would allow concurrent GPU runs,
        # which may exhaust memory — decide deliberately before changing it.
        image = pipe(
            prompt=request.prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            generator=generator,
            guidance_scale=request.guidance_scale,
        ).images[0]
        image_url = save_image_to_s3(image)
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image_url": image_url, "seed": seed}
# NOTE(review): SOURCE registered no route for this handler; serving it at "/"
# is an assumption based on its name — confirm.
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the IG API"}