# ig1 / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import random
import torch
import boto3
from io import BytesIO
import time
import os
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline
# S3 configuration
# Credentials are read from the environment instead of being hardcoded in
# source; adjust the variable names to match your deployment.
S3_BUCKET = "afri"
S3_REGION = "eu-west-3"
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")

# Set up S3 client
s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
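# Alternative (sketch, not part of the original code): boto3 can also resolve
# credentials on its own via its default chain (AWS_ACCESS_KEY_ID /
# AWS_SECRET_ACCESS_KEY environment variables, the shared credentials file, or
# an attached IAM role), in which case the explicit keys above can be dropped:
#     s3_client = boto3.client('s3', region_name=S3_REGION)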
# Set up cache path
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)
# Set up CUDA and model
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)
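# Optional tweaks (assumptions, not part of the original setup):
# - On GPUs with limited VRAM, model CPU offload can replace the pipe.to(...)
#   call above so submodules are moved to the GPU only while they run:
#       pipe.enable_model_cpu_offload()
# - A short warm-up generation at startup pays one-time initialization costs
#   before the first real request arrives:
#       _ = pipe("warm-up", num_inference_steps=1, width=256, height=256).images[0]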
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
app = FastAPI()
class InferenceRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = True
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 3.5
    num_inference_steps: int = 8
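# Example request body for POST /infer (illustrative values, not taken from
# the original code):
#   {
#     "prompt": "a lion at sunset, photorealistic",
#     "randomize_seed": true,
#     "width": 1024,
#     "height": 1024,
#     "guidance_scale": 3.5,
#     "num_inference_steps": 8
#   }
# Note: MAX_IMAGE_SIZE above is defined but never enforced; width/height could
# be bounded with pydantic constraints, e.g.
#     width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
# (Field would need to be imported from pydantic.)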
class Timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")
def save_image_to_s3(image):
    """Upload a PIL image to S3 as a PNG and return its public URL."""
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()
    filename = f"generated_image_{int(time.time())}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=img_byte_arr,
        ContentType='image/png',
    )
    url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
    return url
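# Note: the URL returned above is only readable if the bucket (or object)
# allows public GETs. A hedged alternative is a time-limited presigned URL:
#     url = s3_client.generate_presigned_url(
#         "get_object",
#         Params={"Bucket": S3_BUCKET, "Key": filename},
#         ExpiresIn=3600,
#     )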
@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image with FLUX, upload it to S3, and return its URL."""
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed
    generator = torch.Generator().manual_seed(seed)
    try:
        with Timer("Image generation"):
            image = pipe(
                prompt=request.prompt,
                width=request.width,
                height=request.height,
                num_inference_steps=request.num_inference_steps,
                generator=generator,
                guidance_scale=request.guidance_scale,
            ).images[0]
        image_url = save_image_to_s3(image)
        return {"image_url": image_url, "seed": seed}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
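# Example client call (sketch; assumes the API is reachable at localhost:8000):
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/infer",
#         json={"prompt": "a red bicycle on a cobblestone street"},
#     )
#     print(resp.json())  # -> {"image_url": "...", "seed": ...}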
@app.get("/")
async def root():
    return {"message": "Welcome to the IG API"}