File size: 1,899 Bytes
3fc663b
 
42559f5
3fc663b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef45f75
 
 
3fc663b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d82ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from tqdm import tqdm
import numpy as np

from pathlib import Path
import json

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image

# dalle related classes and utils

from dalle_pytorch import VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer

from io import BytesIO
import gradio as gr


# load DALL-E

def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``."""
    is_missing = val is None
    return not is_missing

# Mapping of model name -> checkpoint path, read from model_paths.json.
# A context manager closes the file handle promptly instead of leaking it.
with open("model_paths.json", encoding="utf-8") as model_paths_file:
    models = json.load(model_paths_file)


# One VAE instance shared by every loaded DALL-E model.
# NOTE(review): VQGanVAE(None, None) presumably loads the library's default
# pretrained VQGAN — confirm against dalle_pytorch.
vae = VQGanVAE(None, None)

# Model name -> DALLE instance moved to the GPU with its trained weights.
dalles = {}

for name, model_path in models.items():
    # Raise explicitly: an `assert` here would be stripped under `python -O`.
    if not Path(model_path).exists():
        raise FileNotFoundError('trained DALL-E ' + model_path + ' must exist')
    load_obj = torch.load(model_path)
    # The checkpoint dict carries DALLE hyperparameters, VAE params (ignored,
    # since the shared `vae` above is used instead) and the model state dict.
    dalle_params, _, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
    dalle_params.pop('vae', None) # cleanup later

    dalle = DALLE(vae = vae, **dalle_params).cuda()

    dalle.load_state_dict(weights)
    dalles[name] = dalle



# Number of prompts sent to generate_images per call.
batch_size = 4

# Passed to DALLE.generate_images as `filter_thres`.
top_k = 0.9

# generate images

image_size = vae.image_size

def generate(text, dalle_name="weird_car", num_images=4):
    """Generate images for a text prompt using one of the loaded DALL-E models.

    Args:
        text: The prompt string.
        dalle_name: Key into ``dalles`` (the names from model_paths.json).
            Defaults to "weird_car", the previously hard-coded model.
        num_images: How many images to sample. They are generated in chunks
            of ``batch_size``.

    Returns:
        A list of ``PIL.Image`` objects, one per generated image. The list is
        empty when ``num_images`` is not positive.

    Raises:
        KeyError: If ``dalle_name`` is not a loaded model.
    """
    dalle = dalles[dalle_name]

    # Without this guard, torch.cat below would fail on an empty list.
    if num_images <= 0:
        return []

    # Keep `text` as the prompt string; the tokens get their own name so the
    # progress-bar label shows the prompt rather than a tensor repr.
    tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    # Repeat the single tokenized prompt once per requested image.
    tokens = repeat(tokens, '() n -> b n', b = num_images)

    outputs = [
        dalle.generate_images(token_chunk, filter_thres = top_k)
        for token_chunk in tqdm(tokens.split(batch_size), desc = f'generating images for - {text}')
    ]
    outputs = torch.cat(outputs)

    response = []
    for image in tqdm(outputs, desc = 'saving images'):
        # Move axis 0 to the end (presumably channel-first CHW -> HWC for PIL;
        # TODO confirm the layout returned by generate_images).
        np_image = np.moveaxis(image.cpu().numpy(), 0, -1)
        # Clip before scaling so out-of-range values cannot wrap around when
        # cast to uint8 (assumes pixel values are nominally in [0, 1]).
        formatted = (np.clip(np_image, 0, 1) * 255).astype('uint8')
        response.append(Image.fromarray(formatted))

    return response

# Web UI: a text prompt goes in, and a carousel of the generated images comes out.
iface = gr.Interface(fn=generate, inputs="text", outputs=gr.outputs.Carousel("image"))

# Start the server, with a public share link, only when run as a script.
# Importing this module should not open a public tunnel.
if __name__ == "__main__":
    iface.launch(share=True)