File size: 2,390 Bytes
4559e89
 
8ccf632
 
 
3b76ae1
0a30c9c
 
 
3b76ae1
0a30c9c
 
 
 
 
 
 
 
 
 
 
 
06f0278
 
8ccf632
5c4c947
8ccf632
06f0278
8ccf632
4559e89
 
 
 
 
 
 
 
 
 
 
0a30c9c
 
 
 
 
 
 
8291039
 
 
3b76ae1
8291039
 
0a30c9c
27495d6
4559e89
 
 
54192f0
4559e89
 
 
8ccf632
0a30c9c
4559e89
 
 
 
 
 
 
 
 
8ccf632
4559e89
8ccf632
4559e89
 
 
8ccf632
4559e89
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import logging
import os
import random
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# S3 Configuration
S3_BUCKET = "afri"
S3_REGION = "eu-west-3"
# Credentials must never live in source control. They are read from the
# standard AWS environment variables; when those are unset the values are
# None and boto3 falls back to its default credential chain (shared
# credentials file, instance/container role, ...).
# SECURITY: the key pair that was previously committed in this file must be
# treated as compromised and rotated in AWS IAM.
S3_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Set up S3 client
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)

# Limits used by the request handlers.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load the diffusion pipeline once at import time, on the GPU when one is
# available and on the CPU otherwise.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
).to(device)

app = FastAPI()

class InferenceRequest(BaseModel):
    """Request body accepted by the ``/infer`` endpoint."""
    # Text prompt forwarded to the diffusion pipeline (required field).
    prompt: str
    # Seed used for the torch generator when ``randomize_seed`` is false.
    seed: int = 42
    # When true, the handler ignores ``seed`` and draws a random one instead.
    randomize_seed: bool = False
    # Requested output width, passed to the pipeline as ``width``.
    width: int = 1024
    # Requested output height, passed to the pipeline as ``height``.
    height: int = 1024
    # Passed to the pipeline as ``guidance_scale``.
    guidance_scale: float = 5.0
    # Passed to the pipeline as ``num_inference_steps``.
    num_inference_steps: int = 28

def save_image_to_s3(image):
    """Upload an image to the configured S3 bucket as a PNG and return its URL.

    Args:
        image: Object exposing ``save(fp, format=...)``; presumably the PIL
            image returned by the pipeline (confirm against the caller).

    Returns:
        The virtual-hosted-style HTTPS URL of the uploaded object. Whether
        that URL is publicly readable depends on the bucket's policy, which
        this function does not change.

    Raises:
        Any exception raised by ``image.save`` or by the S3 client's
        ``put_object`` call is propagated unchanged.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")

    # A timestamp with one-second resolution is not unique: two images
    # finished within the same second would share a key and the later upload
    # would silently overwrite the earlier one. The random suffix makes every
    # key distinct while keeping the timestamp prefix for ordering.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"

    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=buffer.getvalue(),
        ContentType="image/png",
    )

    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"

@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate an image for the request, upload it to S3 and return its URL.

    Args:
        request: Validated request body with the prompt and sampling options.

    Returns:
        A dict with ``image_url`` (the URL of the uploaded PNG) and ``seed``
        (the seed actually used, which matters when ``randomize_seed`` is set).

    Raises:
        HTTPException: 422 when ``width``/``height`` fall outside
            ``1..MAX_IMAGE_SIZE`` or ``num_inference_steps`` is below 1;
            500 when seeding, generation or the upload fails.
    """
    # Enforce the size limit before touching the model: the dimensions come
    # straight from the client and an oversized request can exhaust memory.
    for field_name, value in (("width", request.width), ("height", request.height)):
        if not 0 < value <= MAX_IMAGE_SIZE:
            raise HTTPException(
                status_code=422,
                detail=f"{field_name} must be between 1 and {MAX_IMAGE_SIZE}",
            )
    if request.num_inference_steps < 1:
        raise HTTPException(
            status_code=422,
            detail="num_inference_steps must be at least 1",
        )

    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed

    # NOTE(review): the pipeline call and the S3 upload are blocking calls
    # inside an async handler, so the event loop is stalled while a request
    # is processed. Left as-is: off-loading to a thread would allow several
    # pipeline calls to run at once, which needs a deliberate decision.
    try:
        # Seeding is inside the try so an out-of-range seed is reported the
        # same way as any other generation failure.
        generator = torch.Generator().manual_seed(seed)
        image = pipe(
            prompt=request.prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            generator=generator,
            guidance_scale=request.guidance_scale,
        ).images[0]

        image_url = save_image_to_s3(image)
    except Exception as e:
        # HTTPException responses are not logged with a traceback by the
        # framework, so record the failure here before converting it.
        logging.getLogger(__name__).exception("Image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"image_url": image_url, "seed": seed}

@app.get("/")
async def root():
    """Return the static greeting served at the API root."""
    greeting = "Welcome to the IG API"
    return {"message": greeting}