File size: 4,331 Bytes
2d0cabb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5241f9
2d0cabb
01210e1
2d0cabb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import pipeline

import torch
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, one_hot_from_names, one_hot_from_int

# Runtime configuration. In this file only "model_name", "base_model_name" and
# "image_gen_model" are read; the remaining keys (max_length, freeze_*,
# *_embedding_dim) are not referenced by any code visible here.
# NOTE(review): confirm whether those keys are used elsewhere or can be dropped.
config = {
    "model_name": "keras-io/multimodal-entailment",
    "base_model_name": "distilbert-base-uncased",
    "image_gen_model": "biggan-deep-512",
    "max_length": 20,
    "freeze_text_model": True,
    "freeze_image_gen_model": True,
    "text_embedding_dim": 768,
    "class_embedding_dim": 128
}
# BigGAN truncation value; text_to_image() passes it to generate_image().
truncation=0.4

# Device selection is hard-coded: is_gpu = False means everything runs on CPU.
is_gpu = False
device = torch.device('cuda') if is_gpu else torch.device('cpu')
print(device)

# Sequence classifier whose arg-max class index is later used as the BigGAN
# class. The auth token is read from the 'huggingface-api-token' environment
# variable (os.environ.get returns None when it is unset).
model = AutoModelForSequenceClassification.from_pretrained(config["model_name"], use_auth_token=os.environ.get(
    'huggingface-api-token'))
# The tokenizer is loaded from the *base* model name, not config["model_name"].
# NOTE(review): assumes the fine-tuned model shares the base model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])
model.to(device)
model.eval()  # inference only: disables dropout etc.

# Pretrained BigGAN generator, also put in inference mode.
gan_model = BigGAN.from_pretrained(config["image_gen_model"])
gan_model.to(device)
gan_model.eval()
print("Models were loaded")


def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    """Generate a single image with the module-level BigGAN ``gan_model``.

    The class conditioning comes from one of two sources. ``int_index`` is
    turned into a one-hot vector and passed through ``gan_model.embeddings``.
    ``dense_class_vector`` is an already-embedded class vector, reshaped to
    ``(1, 128)``. When both are given, ``int_index`` wins.

    Args:
        dense_class_vector: Pre-computed class embedding (torch tensor or numpy
            array with 128 elements). Used only when ``int_index`` is None.
        int_index: Integer class index passed to ``one_hot_from_int``.
        noise_seed_vector: Optional tensor whose element sum seeds the noise
            sampler, so the same input reproduces the same noise. When None,
            the noise is unseeded.
        truncation: Truncation value used both for the noise sample and for
            the generator call.

    Returns:
        numpy.ndarray: uint8 image laid out as (H, W, C).

    Raises:
        ValueError: If neither ``int_index`` nor ``dense_class_vector`` is given.
    """
    if int_index is None and dense_class_vector is None:
        raise ValueError("Either int_index or dense_class_vector must be provided")

    # Build every input on the generator's own device so torch.cat and the
    # forward pass cannot fail with a CPU/GPU mismatch.
    gan_device = next(gan_model.parameters()).device

    seed = None
    if noise_seed_vector is not None:
        # np.random.RandomState only accepts seeds in [0, 2**32 - 1]; fold the
        # sum into that range so very long inputs or negative sums can't crash.
        seed = int(noise_seed_vector.sum().item()) % (2 ** 32)
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    noise_vector = torch.from_numpy(noise_vector).to(gan_device)

    with torch.no_grad():
        if int_index is not None:
            one_hot = torch.from_numpy(one_hot_from_int([int_index], batch_size=1))
            dense_class_vector = gan_model.embeddings(one_hot.to(gan_device))
        else:
            if isinstance(dense_class_vector, np.ndarray):
                dense_class_vector = torch.tensor(dense_class_vector)
            # Match the noise dtype too: numpy arrays default to float64,
            # which would not mix with the float32 noise/weights.
            dense_class_vector = dense_class_vector.view(1, 128).to(
                device=gan_device, dtype=noise_vector.dtype)

        input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)
        output = gan_model.generator(input_vector, truncation)

    # NCHW -> NHWC, then map the generator output (assumed roughly in [-1, 1])
    # onto [0, 255] before the uint8 cast.
    output = output.cpu().numpy().transpose((0, 2, 3, 1))
    output = np.clip(((output + 1.0) / 2.0) * 256, 0, 255)
    return output[0].astype(np.uint8)


def print_image(numpy_array):
    """Show a numpy uint8 array on screen via matplotlib.

    The array is wrapped in a PIL image first; ``plt.show()`` then displays it.
    """
    plt.imshow(Image.fromarray(numpy_array))
    plt.show()


def text_to_image(text):
    """Turn a text prompt into a BigGAN image.

    The prompt is tokenized and classified by the module-level ``model``. The
    arg-max class index selects the BigGAN class. The token ids are passed as
    the noise seed, so the same prompt reproduces the same noise.

    Args:
        text: Prompt string entered in the UI.

    Returns:
        PIL.Image.Image: The generated image.
    """
    # truncation=True (the tokenizer kwarg, unrelated to the global BigGAN
    # `truncation`) caps the sequence at the tokenizer's model_max_length, so
    # overly long prompts no longer crash the classifier.
    tokens = tokenizer.encode(text, add_special_tokens=True, truncation=True,
                              return_tensors='pt').to(device)
    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)
    pred_int_index = int(torch.argmax(lm_output.logits[0], dim=-1).item())
    print(pred_int_index)

    numpy_image = generate_image(int_index=pred_int_index,
                                 truncation=truncation,
                                 noise_seed_vector=tokens)
    return Image.fromarray(numpy_image)

# Example prompts offered in the Gradio UI (clicking one fills the text box).
examples = ["a high resolution photo of a pizza from a famous food magazine.",
            "this is a photo of my pet golden retriever.",
            "this is a photo of a troublesome street cat.",
            "a blurry image of a coral reef.",
            "a yellow taxi cab commonly found in USA.",
            "Once upon a time, there was a black ship full of pirates.",
            "a photo of a large castle.",
            "a sketch of an old Church"]

if __name__ == '__main__':
    # One single-line text box in, one generated image out.
    prompt_box = gr.inputs.Textbox(placeholder="Enter the text to generate an image",
                                   label="Text query",
                                   lines=1)
    image_panel = gr.outputs.Image(type="auto", label="Generated Image")

    demo = gr.Interface(
        fn=text_to_image,
        inputs=prompt_box,
        outputs=image_panel,
        verbose=True,
        examples=examples,
        title="Generate Image from Text",
        description="",
        theme="huggingface",
    )
    demo.launch()