# Hugging Face Spaces status when this file was captured: Runtime error
import io
import logging
import os
import threading
from typing import Optional

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

from diffusion import DiffusionTransformer, LTDConfig
# Point the Hugging Face cache/home directories at folders that live next to
# this script, so downloaded model files stay alongside the app.
script_directory = os.path.dirname(os.path.realpath(__file__))
cache_directory = os.path.join(script_directory, "cache")
home_directory = os.path.join(script_directory, "home")

# Create both folders up front; exist_ok=True makes this a no-op on restart.
os.makedirs(cache_directory, exist_ok=True)
os.makedirs(home_directory, exist_ok=True)

# NOTE(review): these variables are only honoured if transformers /
# huggingface_hub read them *after* this point. `diffusion` is imported above;
# if it imports those libraries at module load, the values set here may be
# ignored — consider setting them before that import. Recent transformers
# releases also deprecate TRANSFORMERS_CACHE in favour of HF_HOME.
os.environ["TRANSFORMERS_CACHE"] = cache_directory
os.environ["HF_HOME"] = home_directory
# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
# NOTE(review): `device` and `to_pil` are not referenced by the endpoints in
# this file, and DiffusionTransformer is constructed without a device argument;
# confirm whether the model picks its own device and whether these are needed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
to_pil = transforms.ToPILImage()

# Build the model once at import time so every request reuses the same
# instance. Construction downloads the model, so the app is not ready to
# serve until that finishes.
ltdconfig = LTDConfig()
diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here

app = FastAPI()
class ImageRequest(BaseModel):
    """JSON body accepted by the image-generation endpoint.

    Only ``prompt`` is required; every other field falls back to the default
    shown here when omitted. All fields are forwarded unchanged, as keyword
    arguments of the same name, to ``generate_image_from_text``.
    """

    # NOTE(review): `Optional[...]` means a client may send an explicit `null`,
    # which is forwarded as None — confirm the generator accepts None or make
    # these plain `int`. There are also no upper bounds on num_imgs/img_size,
    # so untrusted clients can request arbitrarily expensive generations.
    prompt: str  # text prompt the image is generated from
    class_guidance: Optional[int] = 6  # forwarded as class_guidance
    seed: Optional[int] = 11  # forwarded as seed
    num_imgs: Optional[int] = 1  # forwarded as num_imgs
    img_size: Optional[int] = 32  # forwarded as img_size
@app.get("/")
def read_root():
    """Return a static welcome message; usable as a simple liveness check."""
    return {"message": "Welcome to Image Generator"}
# Only one request drives the shared diffusion_transformer at a time; running
# several generations concurrently on one model/GPU could contend for memory.
_generation_lock = threading.Lock()


def _render_jpeg(request: ImageRequest) -> io.BytesIO:
    """Run the model for *request* and return the result JPEG-encoded in memory.

    Blocking: intended to be called from a worker thread, not the event loop.
    """
    with _generation_lock:
        img = diffusion_transformer.generate_image_from_text(
            prompt=request.prompt,
            class_guidance=request.class_guidance,
            seed=request.seed,
            num_imgs=request.num_imgs,
            img_size=request.img_size,
        )
    # Assumes a single PIL image is returned (it is saved directly below) —
    # TODO confirm what generate_image_from_text returns when num_imgs > 1.
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    buffer.seek(0)  # rewind so the response streams from the first byte
    return buffer


# NOTE(review): the route path is an assumption — confirm it matches clients.
@app.post("/generate-image/")
async def generate_image(request: ImageRequest):
    """Generate an image for ``request`` and stream it back as ``image/jpeg``.

    Model inference and JPEG encoding are blocking, so they run in FastAPI's
    worker thread pool; the event loop stays free to serve other requests
    (e.g. ``/``) while an image is being generated.

    Raises:
        HTTPException: status 500, carrying the underlying error message, if
            generation or encoding fails.
    """
    try:
        buffer = await run_in_threadpool(_render_jpeg, request)
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log, since the
        # client only receives the message text.
        logging.getLogger(__name__).exception("Image generation failed")
        # NOTE(review): str(e) exposes internal error text to clients.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
    return StreamingResponse(buffer, media_type="image/jpeg")
# TODO: add a build job that tests the API and deploys it as a Docker image (possibly on Azure).