# tld/app.py
import io
import os
from typing import Optional
import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Get the directory of the script
script_directory = os.path.dirname(os.path.realpath(__file__))
# Store the Hugging Face caches in folders next to the script
cache_directory = os.path.join(script_directory, "cache")
home_directory = os.path.join(script_directory, "home")
# Create the cache directories if they don't exist
os.makedirs(cache_directory, exist_ok=True)
os.makedirs(home_directory, exist_ok=True)
# Point the Hugging Face libraries at the local caches *before* importing anything
# that reads these variables at import time
os.environ["TRANSFORMERS_CACHE"] = cache_directory
os.environ["HF_HOME"] = home_directory

from diffusion import DiffusionTransformer, LTDConfig  # noqa: E402

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
to_pil = transforms.ToPILImage()
ltdconfig = LTDConfig()
diffusion_transformer = DiffusionTransformer(ltdconfig)  # downloads the model weights here
app = FastAPI()

class ImageRequest(BaseModel):
    prompt: str
    class_guidance: Optional[int] = 6
    seed: Optional[int] = 11
    num_imgs: Optional[int] = 1
    img_size: Optional[int] = 32
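
# Example request body (numeric values mirror the defaults above; the prompt is illustrative):
# {"prompt": "a watercolor painting of a fox", "class_guidance": 6, "seed": 11, "num_imgs": 1, "img_size": 32}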

@app.get("/")
def read_root():
    return {"message": "Welcome to Image Generator"}

@app.post("/generate-image/")
async def generate_image(request: ImageRequest):
    try:
        img = diffusion_transformer.generate_image_from_text(
            prompt=request.prompt,
            class_guidance=request.class_guidance,
            seed=request.seed,
            num_imgs=request.num_imgs,
            img_size=request.img_size,
        )
        # Convert the PIL image to a byte stream suitable for an HTTP response
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
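
# A minimal client-side sketch for calling this endpoint, assuming the server is
# reachable at http://localhost:8000 (the URL and prompt are illustrative):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate-image/",
#       json={"prompt": "a watercolor painting of a fox", "img_size": 32},
#   )
#   resp.raise_for_status()
#   with open("generated.jpg", "wb") as f:
#       f.write(resp.content)  # raw JPEG bytes streamed by the endpoint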

# TODO: build a job to test and deploy the API as a Docker image (maybe on Azure?)
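
# A minimal sketch for running the server locally; uvicorn and the host/port values
# below are assumptions, not taken from the original deployment setup:
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app on all interfaces, port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)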