# Spaces status at capture time: Runtime error
import json
from io import BytesIO
from pathlib import Path

import gradio as gr
import numpy as np
# torch
import torch
from einops import repeat
# vision imports
from PIL import Image
from tqdm import tqdm

# dalle related classes and utils
from dalle_pytorch import VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer
# load DALL-E | |
def exists(val: object) -> bool:
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only the ``None`` singleton does not.
    """
    return val is not None
# Mapping of model name -> checkpoint path, read from model_paths.json in the
# current working directory. The context manager closes the file handle
# instead of leaving it to the garbage collector.
with open("model_paths.json", encoding="utf-8") as paths_file:
    models = json.load(paths_file)

# Single VAE instance shared by every DALLE model built below; both
# constructor arguments are left as None.
vae = VQGanVAE(None, None)

# Model name -> ready-to-use DALLE instance living on the GPU.
dalles = {}
for name, model_path in models.items():
    # Raise explicitly rather than `assert`, which is stripped under `python -O`.
    if not Path(model_path).exists():
        raise FileNotFoundError(f'trained DALL-E {model_path} must exist')

    # NOTE(review): torch.load unpickles the file; only list trusted
    # checkpoints in model_paths.json.
    # map_location='cpu' keeps the loaded tensors in host memory;
    # load_state_dict() below copies them into the CUDA-resident model, so no
    # second copy of the weights lingers on the GPU.
    load_obj = torch.load(model_path, map_location='cpu')

    dalle_params = load_obj.pop('hparams')
    # The checkpoint must carry 'vae_params', but the value is unused because
    # every model is wired to the shared `vae` above.
    load_obj.pop('vae_params')
    weights = load_obj.pop('weights')

    dalle_params.pop('vae', None)  # cleanup later

    dalle = DALLE(vae=vae, **dalle_params).cuda()
    dalle.load_state_dict(weights)
    dalles[name] = dalle

# Maximum number of prompt rows handed to one generate_images() call.
batch_size = 4
# Passed to generate_images() as `filter_thres`.
top_k = 0.9

# generate images
image_size = vae.image_size
def generate(text: str, *, num_images: int = 4, dalle_name: str = "weird_car") -> list:
    """Generate images for a text prompt with one of the loaded DALL-E models.

    Args:
        text: Prompt to render.
        num_images: Number of images to produce (default 4, the value that
            used to be hard-coded).
        dalle_name: Key into the module-level ``dalles`` mapping selecting
            which model to sample from (default ``"weird_car"``).

    Returns:
        A list of ``num_images`` PIL images.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
        KeyError: If ``dalle_name`` is not a key of ``dalles``.
    """
    if num_images < 1:
        # Otherwise torch.cat() below fails on an empty list with an opaque error.
        raise ValueError(f'num_images must be >= 1, got {num_images}')

    dalle = dalles[dalle_name]

    # Tokenize the single prompt, then replicate that row once per requested
    # image. The prompt string stays bound to `text` so the progress-bar
    # description shows the prompt rather than the token tensor.
    tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()
    tokens = repeat(tokens, '() n -> b n', b=num_images)

    # Sample in chunks of at most `batch_size` rows per call.
    outputs = [
        dalle.generate_images(chunk, filter_thres=top_k)
        for chunk in tqdm(tokens.split(batch_size), desc=f'generating images for - {text}')
    ]
    outputs = torch.cat(outputs)

    response = []
    for image in tqdm(outputs, desc='saving images'):
        # Move axis 0 to the end; presumably (C, H, W) -> (H, W, C), the
        # layout Image.fromarray expects -- TODO confirm.
        np_image = np.moveaxis(image.cpu().numpy(), 0, -1)
        # The *255 scaling implies values nominally in [0, 1]; clip so a stray
        # out-of-range value saturates instead of wrapping in the uint8 cast.
        formatted = (np_image * 255).clip(0, 255).astype('uint8')
        response.append(Image.fromarray(formatted))
    return response
# Web UI: a single free-text input is passed to `generate`, and the list of
# images it returns is displayed in a carousel.
# NOTE(review): `gr.outputs.*` is a legacy namespace in newer Gradio releases;
# confirm it is still available in the Gradio version this app is pinned to.
iface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs=gr.outputs.Carousel("image"),
)
# share=True additionally requests a public share link from Gradio.
iface.launch(share=True)