File size: 2,057 Bytes
3b3a783
 
 
 
 
 
610afda
3b3a783
 
 
 
610afda
 
 
 
 
 
 
 
 
 
 
 
 
3b3a783
 
 
 
610afda
3b3a783
 
 
610afda
 
 
 
 
3b3a783
 
 
 
 
 
 
 
610afda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b3a783
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import io
import logging
import os
from typing import Optional

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

from diffusion import DiffusionTransformer, LTDConfig

# Everything this service caches lives beside this script, so resolve the
# script's own folder first (realpath follows symlinks).
script_directory = os.path.dirname(os.path.realpath(__file__))
cache_directory = os.path.join(script_directory, "cache")
home_directory = os.path.join(script_directory, "home")

# Both folders must exist up front; exist_ok=True makes restarts harmless.
os.makedirs(cache_directory, exist_ok=True)
os.makedirs(home_directory, exist_ok=True)

# Publish the two folders through the TRANSFORMERS_CACHE and HF_HOME
# environment variables.
# NOTE(review): these are assigned after the imports at the top of the file
# have already executed -- confirm that whatever reads them does so lazily.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": cache_directory,
        "HF_HOME": home_directory,
    }
)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

to_pil = transforms.ToPILImage()
ltdconfig = LTDConfig()
# Per the original author's note, constructing the transformer downloads the
# model, so this line runs once at import time.
diffusion_transformer = DiffusionTransformer(ltdconfig)
app = FastAPI()

class ImageRequest(BaseModel):
  # Request body for POST /generate-image/. Each field is forwarded as the
  # same-named keyword argument to
  # diffusion_transformer.generate_image_from_text.
  #
  # NOTE(review): the Optional[...] annotations mean a client may send an
  # explicit null, which is then forwarded as None -- confirm the generator
  # accepts None for these arguments.

  # Text prompt to generate from; the only required field.
  prompt: str
  # Defaults to 6. Presumably the guidance strength -- confirm against the
  # diffusion module.
  class_guidance: Optional[int] = 6
  # Defaults to 11. Presumably the RNG seed for reproducible output -- confirm.
  seed: Optional[int] = 11
  # Defaults to 1. Presumably the number of images to generate -- confirm.
  num_imgs: Optional[int] = 1
  # Defaults to 32. Presumably the image side length; units not visible
  # here -- confirm.
  img_size: Optional[int] = 32


@app.get("/")
def read_root():
    """Return the static welcome payload served at the API root."""
    welcome = {"message": "Welcome to Image Generator"}
    return welcome


@app.post("/generate-image/")
async def generate_image(request: ImageRequest):
    """Generate an image from a text prompt and return it as a JPEG.

    Args:
        request: Validated request body; its fields are forwarded unchanged
            to ``diffusion_transformer.generate_image_from_text``.

    Returns:
        A ``StreamingResponse`` with media type ``image/jpeg`` whose body is
        the encoded image.

    Raises:
        HTTPException: Status 500 with the original error message as
            ``detail`` when generation or JPEG encoding fails. The original
            exception is logged and chained as the cause.
    """
    try:
        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is occupied until it returns -- confirm that is
        # acceptable for the expected request load.
        img = diffusion_transformer.generate_image_from_text(
            prompt=request.prompt,
            class_guidance=request.class_guidance,
            seed=request.seed,
            num_imgs=request.num_imgs,
            img_size=request.img_size,
        )

        # Encode into an in-memory buffer and rewind it so the response
        # streams from the first byte. Presumably `img` is a PIL image; the
        # code only requires that it expose .save(fp, format=...).
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)
    except Exception as e:
        # Converting to HTTPException hides the failure from the server's
        # default error logging, so record the traceback here first.
        logging.getLogger(__name__).exception("Image generation failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e
    else:
        return StreamingResponse(img_byte_arr, media_type="image/jpeg")


# TODO: add a build job that tests the API and deploys it as a Docker image (possibly on Azure).