# Spaces:
# Runtime error
# (Status text captured from the hosting Spaces page; kept as a comment.)
import io
import asyncio
import os

import torch
import torchvision.transforms as transforms

# Directory containing this script; the cache folders are created next to it
# so downloads land in a predictable, writable location.
script_directory = os.path.dirname(os.path.realpath(__file__))
cache_directory = os.path.join(script_directory, "cache")
home_directory = os.path.join(script_directory, "home")

# Create the cache directories if they don't exist yet.
os.makedirs(cache_directory, exist_ok=True)
os.makedirs(home_directory, exist_ok=True)

# These must be set BEFORE importing `diffusion`: huggingface_hub and
# transformers read HF_HOME / TRANSFORMERS_CACHE into module-level constants
# when they are first imported. If they were already imported, later
# assignments are ignored.
# NOTE(review): assumes `diffusion` imports those libraries — confirm.
os.environ["TRANSFORMERS_CACHE"] = cache_directory
os.environ["HF_HOME"] = home_directory

from diffusion import DiffusionTransformer, LTDConfig  # noqa: E402  (must follow env setup)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
to_pil = transforms.ToPILImage()

ltdconfig = LTDConfig()
# Constructing the transformer downloads the model.
diffusion_transformer = DiffusionTransformer(ltdconfig)
async def generate_image(prompt, *, output_path="generated_img.png"):
    """Generate one image for ``prompt`` and save it to ``output_path``.

    Calls the module-level ``diffusion_transformer.generate_image_from_text``
    with fixed settings: class_guidance=6, seed=11, num_imgs=1, img_size=32.
    The meaning of ``img_size`` is defined by the model library (possibly a
    latent size rather than a pixel size — confirm).

    The model call is synchronous, so it blocks the running event loop for
    its whole duration despite this being a coroutine.

    Failures are best-effort: any ``Exception`` raised while generating or
    saving is reported with its full traceback on stderr and is not
    re-raised.

    Args:
        prompt: Text prompt passed to the model.
        output_path: File path passed to the image's ``save`` method.
            Relative paths resolve against the current working directory.

    Returns:
        The object returned by ``generate_image_from_text`` on success, or
        ``None`` if generation or saving raised.
    """
    import traceback

    try:
        img = diffusion_transformer.generate_image_from_text(
            prompt=prompt,
            class_guidance=6,
            seed=11,
            num_imgs=1,
            img_size=32,
        )
        img.save(output_path)
    except Exception:
        # Keep the original "don't crash" behavior, but print the exception
        # type, message and stack instead of only the message.
        traceback.print_exc()
        return None
    return img
if __name__ == "__main__":
    # Run the demo generation only when executed as a script, not on import.
    asyncio.run(generate_image("a cute cat"))