# Hugging Face Space (status at capture time: Sleeping)
import gradio as gr | |
from PIL import Image | |
import requests | |
from tld.denoiser import Denoiser | |
from tld.diffusion import DiffusionGenerator | |
from diffusers import AutoencoderKL, AutoencoderTiny | |
from tqdm import tqdm | |
import clip | |
import torch | |
import numpy as np | |
import torchvision.utils as vutils | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader, TensorDataset | |
from PIL import Image | |
# Run on the first CUDA GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Shared converter used to turn the generated image tensor into a PIL image.
to_pil = transforms.ToPILImage()
def download_file(url, filename, timeout=30):
    """Stream the resource at ``url`` into the local file ``filename``.

    The body is fetched with a streaming GET and written in 8 KiB chunks, so
    the full payload is never held in memory at once.

    Args:
        url: Address to fetch.
        filename: Path of the file to create (or overwrite) in binary mode.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.
            Without one, a server that stops responding would block the
            caller forever.

    Raises:
        requests.HTTPError: If the response status signals an error
            (raised by ``raise_for_status``).
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    with requests.get(url, stream=True, timeout=timeout) as r:
        # Fail before touching the destination file on a 4xx/5xx response.
        r.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
def encode_text(label, model):
    """Encode text with ``model`` and return the embedding on the CPU.

    Args:
        label: Text to encode, passed straight to ``clip.tokenize`` with
            ``truncate=True``. The caller in this file passes a list of
            prompt strings.
        model: Object exposing ``encode_text(tokens)``; this file passes the
            model returned by ``clip.load``.

    Returns:
        The result of ``model.encode_text`` moved to the CPU.
    """
    text_tokens = clip.tokenize(label, truncate=True).to(device)
    # Inference only: skip autograd bookkeeping so no graph is kept alive
    # alongside the returned embedding.
    with torch.no_grad():
        text_encoding = model.encode_text(text_tokens)
    return text_encoding.cpu()
# Upper bound on how much of the prompt is kept in the saved file name, so
# long prompts cannot exceed the file system's name-length limit.
_MAX_PROMPT_CHARS_IN_FILENAME = 50


def _filename_stem_from_prompt(prompt):
    """Return ``prompt`` reduced to characters safe inside one file name."""
    cleaned = "".join(
        ch if ch.isalnum() or ch in " -_.," else "_"
        for ch in str(prompt)
    )
    return cleaned[:_MAX_PROMPT_CHARS_IN_FILENAME]


def generate_image_from_text(prompt, class_guidance=6, seed=11, num_imgs=1, img_size = 32):
    """Generate image(s) for ``prompt`` and return them as one PIL grid image.

    The prompt is repeated ``num_imgs`` times, encoded with the module-level
    ``clip_model`` and handed to the module-level ``diffuser``. The generated
    batch is tiled into a grid, saved as a PNG in the working directory and
    returned.

    Args:
        prompt: Text description of the image.
        class_guidance: Forwarded to ``diffuser.generate`` as
            ``class_guidance``.
        seed: Forwarded to ``diffuser.generate`` as ``seed``.
        num_imgs: Number of images to generate; the grid uses
            ``int(sqrt(num_imgs))`` images per row.
        img_size: Currently unused by this function; kept for interface
            compatibility.

    Returns:
        A PIL image containing the grid of generated images.
    """
    n_iter = 15
    nrow = int(np.sqrt(num_imgs))

    cur_prompts = [prompt] * num_imgs
    labels = encode_text(cur_prompts, clip_model)

    # The second return value (the latent) is not needed here.
    out, _ = diffuser.generate(labels=labels,
                               num_imgs=num_imgs,
                               class_guidance=class_guidance,
                               seed=seed,
                               n_iter=n_iter,
                               exponent=1,
                               scale_factor=8,
                               sharp_f=0,
                               bright_f=0
                               )

    # Shift/scale the output, tile it into a grid and clamp to [0, 1] before
    # converting to a PIL image.
    grid = vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)
    out = to_pil(grid.float().clip(0, 1))

    # The prompt is free-form user input: strip path separators and other
    # unsafe characters (and cap the length) so saving cannot fail or write
    # outside the working directory.
    out.save(f'{_filename_stem_from_prompt(prompt)}_cfg:{class_guidance}_seed:{seed}.png')
    print("Images Generated and Saved. They will shortly output below.")
    return out
###config:
vae_scale_factor = 8
img_size = 32
model_dtype = torch.float32
file_url = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
local_filename = "state_dict_378000.pth"

# Fetch the denoiser checkpoint into the working directory.
download_file(file_url, local_filename)

# Build the denoiser from the configured image size rather than a repeated
# literal, so the two cannot drift apart.
denoiser = Denoiser(image_size=img_size, noise_embed_dims=256, patch_size=2,
                    embed_dim=768, dropout=0, n_layers=12)

# NOTE(review): torch.load unpickles the downloaded file; on PyTorch versions
# where it is not the default, consider weights_only=True so a tampered
# checkpoint cannot execute code.
state_dict = torch.load(local_filename, map_location=torch.device('cpu'))

# Cast to the target dtype, load the weights on CPU, then move to the device.
denoiser = denoiser.to(model_dtype)
denoiser.load_state_dict(state_dict)
denoiser = denoiser.to(device)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix",
                                    torch_dtype=model_dtype).to(device)

clip_model, preprocess = clip.load("ViT-L/14")
clip_model = clip_model.to(device)

diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
# Define the Gradio interface
iface = gr.Interface(
    fn=generate_image_from_text,  # The function to generate the image
    inputs=[
        "text",
        # Same 0-100 range as the plain "slider" shortcut, but starting at 6
        # so the UI default matches the function's class_guidance default
        # instead of starting at the slider minimum.
        gr.Slider(minimum=0, maximum=100, value=6),
    ],
    outputs="image",
    title="Text-to-Image Generator",
    description="Enter a text prompt to generate an image."
)

# Launch the app
iface.launch()