File size: 3,426 Bytes
4559e89
 
8ccf632
 
 
0a30c9c
 
 
3b76ae1
9237c74
 
 
0a30c9c
 
 
 
 
 
 
 
 
 
 
 
06f0278
9237c74
 
 
 
 
 
 
 
 
 
 
8ccf632
9237c74
 
 
 
 
 
 
8ccf632
06f0278
8ccf632
4559e89
 
 
 
 
9237c74
4559e89
 
9237c74
 
 
 
 
 
 
 
 
 
 
 
 
 
4559e89
0a30c9c
 
 
 
 
8291039
 
 
3b76ae1
8291039
0a30c9c
27495d6
4559e89
 
 
54192f0
4559e89
 
8ccf632
0a30c9c
4559e89
9237c74
 
 
 
 
 
 
 
 
8ccf632
4559e89
8ccf632
4559e89
 
 
8ccf632
4559e89
 
9237c74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
from safetensors.torch import load_file

# S3 Configuration
# Credentials must never live in source control. They are read from the
# environment here; when the key variables are unset (None), boto3 falls back
# to its default credential chain (shared credentials file, instance/role
# credentials, ...). Any key that was ever committed to this file must be
# treated as compromised and rotated.
S3_BUCKET = os.environ.get("S3_BUCKET", "afri")
S3_REGION = os.environ.get("S3_REGION", "eu-west-3")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")

# Set up S3 client (shared by every request handler in this module).
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)

# Set up cache path: model files are meant to be cached in a "models"
# directory next to this source file.
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")

# NOTE(review): huggingface_hub (which diffusers builds on) resolves HF_HOME /
# HF_HUB_CACHE when it is first imported, which happens at the top of this
# module, before these assignments. They are kept for any library imported
# later, but downloads should also be given cache_dir=cache_path explicitly.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# exist_ok=True already tolerates an existing directory, so a separate
# existence check is redundant (and racy if several workers start at once).
os.makedirs(cache_path, exist_ok=True)

# Set up CUDA and model
# Permit TF32 tensor-core matmuls on CUDA (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize FluxPipeline.
# cache_dir is passed explicitly: the HF_* environment variables above are
# assigned after huggingface_hub/diffusers were imported, so on their own they
# are not guaranteed to redirect downloads into cache_path.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
# Load the Hyper-SD 8-step LoRA and fuse it into the base weights at scale
# 0.125, so no separate adapter is applied at inference time.
lora_path = hf_hub_download(
    "ByteDance/Hyper-SD",
    "Hyper-FLUX.1-dev-8steps-lora.safetensors",
    cache_dir=cache_path,
)
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=device, dtype=torch.bfloat16)

# Inclusive upper bound for randomly drawn seeds: the largest signed 32-bit int.
_INT32_LIMITS = np.iinfo(np.int32)
MAX_SEED = _INT32_LIMITS.max
# Largest image dimension a client may ask for (presumably pixels — confirm).
MAX_IMAGE_SIZE = 2_048

# The ASGI application; the route handlers below are attached to it with the
# @app.get / @app.post decorators.
app = FastAPI()

class InferenceRequest(BaseModel):
    """Request body for ``POST /infer``.

    Image dimensions and the step count are range-checked, so malformed or
    oversized requests get a 422 validation error from FastAPI. Without the
    checks they would fail (or potentially exhaust GPU memory) inside the
    pipeline.
    """

    prompt: str
    seed: int = 42  # only used when randomize_seed is False
    randomize_seed: bool = True
    # Capped at MAX_IMAGE_SIZE to bound per-request memory use.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    guidance_scale: float = 3.5
    # At least one denoising step is required to produce an image.
    num_inference_steps: int = Field(8, ge=1)

class Timer:
    """Context manager that prints how long its ``with`` body took.

    On entry it prints "<method_name> starts". On exit it prints
    "<method_name> took <seconds>s", with seconds rounded to two decimals, and
    stores the raw duration in ``elapsed``. Exceptions raised in the body are
    not suppressed.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so durations are unaffected by system
        # clock adjustments (e.g. NTP corrections) that can skew time.time().
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(self.elapsed, 2)}s")
        return False  # never swallow exceptions from the timed block

def save_image_to_s3(image):
    """Upload *image* to S3 as a PNG and return the object's HTTPS URL.

    Args:
        image: An object with a ``save(fp, format=...)`` method — presumably
            the PIL image produced by the pipeline (TODO confirm).

    Returns:
        The virtual-hosted-style URL of the uploaded object. Whether it is
        publicly readable depends on the bucket policy, which is not set here.
    """
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    png_bytes = buffer.getvalue()
    # A whole-second timestamp alone collides when two images are saved in the
    # same second, and the later upload silently overwrites the earlier
    # object. The random uuid suffix makes every key unique.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"
    s3_client.put_object(Bucket=S3_BUCKET,
                         Key=filename,
                         Body=png_bytes,
                         ContentType='image/png')
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"

# Serializes access to the shared `pipe`. Inference now runs in worker
# threads, so without this lock concurrent requests would drive the same
# pipeline (and GPU) at the same time. Requests wait here instead of
# blocking the event loop.
_pipe_lock = threading.Lock()


def _generate_image(request, generator):
    """Run the diffusion pipeline for *request* and return the first image.

    Blocking; meant to be called through ``run_in_threadpool``.
    """
    with _pipe_lock:
        with Timer("Image generation"):
            return pipe(
                prompt=request.prompt,
                width=request.width,
                height=request.height,
                num_inference_steps=request.num_inference_steps,
                generator=generator,
                guidance_scale=request.guidance_scale,
            ).images[0]


@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image for *request*, upload it to S3 and return its URL.

    Returns:
        ``{"image_url": str, "seed": int}``. The seed actually used is echoed
        back so the caller can resend it with ``randomize_seed=False``.

    Raises:
        HTTPException: status 500 carrying the error text if generation or
            upload fails.
    """
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed
    generator = torch.Generator().manual_seed(seed)

    try:
        # The diffusion run and the S3 upload are both blocking calls. Running
        # them in the threadpool keeps the event loop free to serve other
        # requests (e.g. GET /) while an image is being generated.
        image = await run_in_threadpool(_generate_image, request, generator)
        image_url = await run_in_threadpool(save_image_to_s3, image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"image_url": image_url, "seed": seed}

@app.get("/")
async def root():
    """Answer the API root with a fixed welcome message."""
    greeting = "Welcome to the IG API"
    return {"message": greeting}