# image-gen/tld/gen_img.py
# Last commit: 0ea1363 by BeveledCube ("Hopefully fixed shi"), 1.2 kB
import asyncio
import io
import os

# --- Cache locations ---------------------------------------------------------
# Model files are meant to be cached next to this script rather than in the
# user's global cache. The environment variables are exported BEFORE torch and
# the local `diffusion` module are imported: Hugging Face libraries resolve
# their cache paths from HF_HOME / TRANSFORMERS_CACHE when they are first
# imported, so exporting them afterwards has no effect on any such library that
# was already loaded.
# NOTE(review): assumes `diffusion` imports Hugging Face packages at import
# time -- confirm; the ordering is harmless either way.

# Directory containing this script (symlinks resolved).
script_directory = os.path.dirname(os.path.realpath(__file__))
# Script-local directories exported below as TRANSFORMERS_CACHE and HF_HOME.
cache_directory = os.path.join(script_directory, "cache")
home_directory = os.path.join(script_directory, "home")
# Create both directories if they do not exist yet.
os.makedirs(cache_directory, exist_ok=True)
os.makedirs(home_directory, exist_ok=True)
# TRANSFORMERS_CACHE is kept for older `transformers` releases; newer releases
# prefer HF_HOME and may warn that TRANSFORMERS_CACHE is deprecated.
os.environ["TRANSFORMERS_CACHE"] = cache_directory
os.environ["HF_HOME"] = home_directory

# Heavy imports deliberately come after the environment setup above.
import torch  # noqa: E402
import torchvision.transforms as transforms  # noqa: E402
from diffusion import DiffusionTransformer, LTDConfig  # noqa: E402

# First CUDA device when available, otherwise the CPU.
# NOTE(review): not referenced anywhere in this file; kept in case another
# module imports it -- confirm before removing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Tensor -> PIL image converter; likewise not referenced in this file.
to_pil = transforms.ToPILImage()
# Model configuration, constructed with no arguments.
ltdconfig = LTDConfig()
# Built once at import time and shared by every generation call.
diffusion_transformer = DiffusionTransformer(ltdconfig)  # Downloads model here
import concurrent.futures  # noqa: E402  -- used by generate_image below
import traceback  # noqa: E402

# One dedicated worker thread. It keeps the blocking model call off the event
# loop while still running generations one at a time, exactly as the original
# fully-synchronous code did. The shared `diffusion_transformer` is therefore
# never used from two threads at once (its seeding may also rely on global RNG
# state -- unverified, another reason not to run calls in parallel).
_generation_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="gen_img"
)


def _generate_and_save(prompt, class_guidance, seed, num_imgs, img_size, output_path):
    """Synchronously run the model for *prompt* and save the result.

    Executed on ``_generation_executor``. Returns *output_path*.
    """
    img = diffusion_transformer.generate_image_from_text(
        prompt=prompt,
        class_guidance=class_guidance,
        seed=seed,
        num_imgs=num_imgs,
        img_size=img_size,
    )
    img.save(output_path)
    return output_path


async def generate_image(
    prompt,
    *,
    class_guidance=6,
    seed=11,
    num_imgs=1,
    img_size=32,
    output_path="generated_img.png",
):
    """Generate an image for *prompt* and save it to *output_path*.

    The model call and the file write are blocking. They run on a dedicated
    worker thread so that other coroutines on the event loop keep running
    while an image is being generated.

    Args:
        prompt: Text prompt passed to the model.
        class_guidance: Passed through to ``generate_image_from_text``.
        seed: Passed through to ``generate_image_from_text``.
        num_imgs: Passed through to ``generate_image_from_text``.
        img_size: Passed through to ``generate_image_from_text``.
        output_path: Where the image is saved. A relative path is resolved
            against the current working directory.

    Returns:
        *output_path* on success. On failure, returns None. Errors are not
        propagated; the traceback is printed to stderr instead.
    """
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(
            _generation_executor,
            _generate_and_save,
            prompt,
            class_guidance,
            seed,
            num_imgs,
            img_size,
            output_path,
        )
    except Exception:
        # Best-effort like before, but keep the full traceback rather than
        # printing only the exception message.
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Demo generation, only when this file is run as a script. The guard means
    # that importing `generate_image` from another module does not start a
    # generation. It also avoids calling asyncio.run() from a module that
    # already has a running event loop, which would raise RuntimeError.
    asyncio.run(generate_image("a cute cat"))