ig / app.py
Afrinetwork7's picture
Update app.py
4559e89 verified
raw
history blame
2.39 kB
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
# S3 configuration.
# Credentials are read from the environment instead of being hard-coded:
# secrets committed to source control are readable by anyone with access to
# the repository (or a public Space). When S3_ACCESS_KEY_ID /
# S3_SECRET_ACCESS_KEY are unset they are None, and boto3 then falls back to
# its default credential chain (AWS_* environment variables, shared config
# files, instance/IAM roles).
# NOTE(review): the key pair that used to be committed here must be treated
# as compromised -- rotate/revoke it in IAM.
S3_BUCKET = os.environ.get("S3_BUCKET", "afri")
S3_REGION = os.environ.get("S3_REGION", "eu-west-3")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")

# Module-level S3 client, reused by every upload.
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Model / runtime configuration.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time; this single pipeline object serves all requests.
# Uses `dtype` rather than repeating the literal so the two cannot drift apart.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype
).to(device)
# Upper bound (inclusive) for seeds drawn when a request sets randomize_seed.
MAX_SEED = np.iinfo(np.int32).max
# Maximum image width/height in pixels.
MAX_IMAGE_SIZE = 2048
app = FastAPI()
class InferenceRequest(BaseModel):
    """JSON body accepted by ``POST /infer``.

    Width and height are bounded by ``MAX_IMAGE_SIZE`` and must be positive,
    and the step count must be positive. This stops a single request from
    asking the pipeline for an arbitrarily large image. Out-of-range values
    are rejected during request validation (FastAPI answers 422) instead of
    failing inside the pipeline.
    """

    # Text prompt forwarded to the diffusion pipeline.
    prompt: str
    # Seed for the torch generator; ignored when randomize_seed is true.
    seed: int = 42
    # When true, the /infer handler draws a fresh random seed instead.
    randomize_seed: bool = False
    # Output size in pixels, 1..MAX_IMAGE_SIZE inclusive.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    # Guidance strength forwarded to the pipeline unchanged.
    guidance_scale: float = 5.0
    # Number of inference steps forwarded to the pipeline; must be positive.
    num_inference_steps: int = Field(28, gt=0)
def save_image_to_s3(image):
    """Upload *image* to S3 as a PNG and return the object's URL.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method
            (presumably the ``PIL.Image.Image`` produced by the pipeline).

    Returns:
        The virtual-hosted-style HTTPS URL of the uploaded object. Whether
        that URL is publicly readable depends on the bucket policy, which is
        not configured here.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    png_bytes = buffer.getvalue()

    # A second-resolution timestamp alone is not unique: two images finishing
    # in the same second would get the same key, and the second PUT would
    # replace the first object. A random uuid4 suffix keeps keys distinct
    # while still sorting by creation time.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=png_bytes,
        ContentType="image/png",
    )
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
# The pipeline object is shared by every request, and its scheduler keeps
# per-run state, so concurrent calls are presumably unsafe. Generation is
# serialized with this lock.
_pipe_lock = threading.Lock()


def _run_pipeline(request: InferenceRequest, seed: int):
    """Generate one image for *request* with *seed*; blocking, run off-loop.

    Holds ``_pipe_lock`` for the whole pipeline call so at most one
    generation runs at a time (the original blocking handler effectively
    behaved this way too).
    """
    generator = torch.Generator().manual_seed(seed)
    with _pipe_lock:
        result = pipe(
            prompt=request.prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            generator=generator,
            guidance_scale=request.guidance_scale,
        )
    return result.images[0]


@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image from *request*, upload it, and return its URL.

    The diffusion call and the S3 upload are both blocking. They are run in
    FastAPI's threadpool because calling them directly inside this ``async``
    handler would stall the event loop, and with it every other request
    (including ``/``), for the full generation time.

    Returns:
        ``{"image_url": str, "seed": int}``. The seed actually used is
        returned so that a randomized result can be reproduced.

    Raises:
        HTTPException: status 500 carrying the underlying error text if
            generation or upload fails.
    """
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed
    try:
        image = await run_in_threadpool(_run_pipeline, request, seed)
        image_url = await run_in_threadpool(save_image_to_s3, image)
    except Exception as e:
        # Top-level boundary: surface any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image_url": image_url, "seed": seed}
@app.get("/")
async def root():
    """Landing endpoint: return a fixed welcome payload."""
    payload = {"message": "Welcome to the IG API"}
    return payload