File size: 9,029 Bytes
f8a1225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import streamlit as st
import time
from PIL import Image
import matplotlib.pyplot as plt



from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
# Compute device chosen once at import time: CUDA when torch can see a GPU,
# otherwise CPU. The models and input tensors below are moved to this device.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def clean(txt):
    """Normalise a caption before tokenisation.

    Lower-cases the text, trims surrounding whitespace, and then trims any
    leading/trailing period characters. Whitespace uncovered by removing the
    periods is not trimmed again.

    Args:
        txt: Raw caption string.

    Returns:
        The normalised caption string.
    """
    return txt.lower().strip().strip('.')


# Fixed token length every caption is padded/truncated to.
# NOTE(review): presumably matches the length used when the GAN was trained —
# confirm before changing.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenise *txt* into a fixed-length encoding.

    Args:
        tokenizer: Callable tokenizer (e.g. a Hugging Face ``AutoTokenizer``).
        txt: Caption text to tokenise.

    Returns:
        Whatever *tokenizer* returns when called with padding to exactly
        ``max_len`` tokens, truncation enabled and offset mappings disabled.
    """
    options = dict(
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False,
    )
    return tokenizer(txt, **options)


def encode(model, tokenizer, txt):
    """Embed a single caption with a transformer encoder.

    The caption is normalised with ``clean`` and tokenised with ``tokenize``.
    Each tokenizer field becomes a batch-of-one ``torch.long`` tensor on
    ``device``, and *model* is run in eval mode without gradient tracking.

    Args:
        model: Hugging Face model whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer compatible with *model*.
        txt: Raw caption string.

    Returns:
        A numpy array with the last hidden state at the first token position.
        NOTE(review): assumed shape is ``(hidden_size,)`` — confirm.
    """
    encoding = tokenize(tokenizer, clean(txt))
    # Prepend a batch axis of size 1 to every tokenizer field.
    batch = {
        field: torch.tensor(values, dtype=torch.long, device=device)[None]
        for field, values in encoding.items()
    }

    model.eval()
    with torch.no_grad():
        outputs = model(**batch)

    # squeeze() removes the size-1 batch axis (the padded sequence axis is
    # max_len long, so it survives); [0] then selects the first token position.
    first_token = outputs.last_hidden_state.squeeze()[0]
    return first_token.cpu().numpy()


@st.cache_resource
def get_model_roberta():
    """Load the ``roberta-base`` encoder together with its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects
    across script reruns instead of reloading them on every interaction.

    Returns:
        ``(model, tokenizer)``, with the model already moved to ``device``.
    """
    checkpoint = 'roberta-base'

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    config = AutoConfig.from_pretrained(checkpoint, output_hidden_states=True)
    model = AutoModel.from_pretrained(checkpoint, config=config).to(device)

    return model, tokenizer


@st.cache_resource
def get_model_gan():
    """Build the text-conditioned generator and load its trained weights.

    The generator is wrapped in ``torch.nn.DataParallel`` before loading.
    NOTE(review): presumably because ``gen_125.pth`` was saved from a
    DataParallel model and its keys carry a ``module.`` prefix — confirm.

    Returns:
        The wrapped generator, on ``device`` and in eval mode.
    """
    net = torch.nn.DataParallel(gan_cls_768.generator().to(device))

    # Checkpoint tensors are mapped to CPU while loading; load_state_dict then
    # copies them into the parameters, which already live on ``device``.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    weights = torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    net.load_state_dict(weights)

    net.eval()
    return net
    


def generate_image(text, n):
    """Generate images for the caption *text*.

    The caption is embedded once with the cached RoBERTa encoder. The cached
    generator is then called *n* times, each time with a fresh batch-of-one
    Gaussian noise vector.

    Args:
        text: Caption describing the desired flower.
        n: Number of generator calls (``range(n)``, so ``n <= 0`` yields none).

    Returns:
        List of ``PIL.Image`` objects, one per sample in each generator output.
    """
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()

    # (1, embed_dim) conditioning vector shared by every sample.
    embedding = torch.as_tensor(encode(model, tokenizer, text), dtype=torch.float32)
    condition = embedding.unsqueeze(0).to(device)

    images = []
    # Pure inference: no_grad stops every generator call from building an
    # autograd graph that is never used.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            noise = torch.randn(1, 100).view(1, 100, 1, 1).to(device)
            fake_images = generator(condition, noise)

            for image in fake_images:
                # Map generator output to [0, 255].
                # NOTE(review): assumes outputs lie in [-1, 1] (e.g. tanh) — confirm.
                # Clamp so any out-of-range value saturates instead of
                # overflowing the uint8 cast.
                pixels = (image * 127.5 + 127.5).clamp(0, 255).byte()
                # CHW -> HWC, as PIL expects.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images




# Browser-tab title/icon, centered layout and an expanded sidebar. Streamlit
# has historically required set_page_config to precede other st.* commands,
# so it stays first in the script body.
st.set_page_config(
    page_title="ImageGen",
    page_icon="🧊",
    initial_sidebar_state="expanded",
    layout="centered",
)

# CSS injected through st.markdown(unsafe_allow_html=True) that hides
# Streamlit's main menu, footer and header chrome.
hide_st_style = """
            <style>
            #MainMenu {visibility: hidden;}
            footer {visibility: hidden;}
            header {visibility: hidden;}
            </style>
            """
st.markdown(body=hide_st_style, unsafe_allow_html=True)



# Example captions offered in the "Select from example" box; the first entry
# is the caption pre-filled when the page loads. Entries must be unique so
# every option in the selectbox is distinguishable.
examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistil and pollen tube.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "delicate pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.",
    "this flower has petals that are pink and bell shaped",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "this flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "this flower is yellow in color, and has petals that are very skinny.",
    "a velvet large flower with a dark marking and a green stem.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are pink and has red stamen",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "this flower is yellow in color, with petals that are skinny and pointed.",
    "the petals on this flower are orange with a purple pistil.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "this purple color flower has the simple row of petals arranged in the circle with the red color pistils at the center",
    "this flower has petals that are red and are very thin",
    "a flower with many folded over bright yellow petals",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
    "this flower is yellow and brown in color, with petals that are oval shaped.",
    "this flower has petals that are white and has a yellow stigma",
    "a flower with folded open and back red petals with black spots and think red anther",
    "this flower has large light red petals and a few white stamen in the center",
    "this flower has bright orange tubular petals rising out of a thick receptacle on a green pedicel.",
    "this flower is a beauty with light red leaves in an equal circle.",
    "a flower with an open conical red petal and white anther supported by red filaments",
    "this flower is red in color, with petals that are bell shaped.",
    "the petals of this flower are yellow with a long stigma",
    ]



def app():
    """Render the text-to-flower demo page.

    Shows the title and paper reference and lets the user pick an example
    caption (editable in a text area). When "Generate" is pressed, it draws a
    grid of generated images in the right-hand column.
    """

    st.title("Text to Flower")
    st.markdown(
        """
        **Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
        Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
    """
    )

    # Output grid shape; the number of generated images is derived from it so
    # the two can never disagree.
    n_rows, n_cols = 3, 4

    example = st.selectbox("Select from example", examples)

    input_col, output_col = st.columns([2, 3])

    with input_col:
        caption = st.text_area("Write your flower description here:", example, height=120)

        # Only one backend exists today; the selector is rendered for the UI only.
        st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa", ], index=0
        )

        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                imgs = generate_image(caption, n_rows * n_cols)

                with output_col:
                    st.markdown("Generated Flower Images:")

                    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)
                    for axis, img in zip(axes.flat, imgs):
                        axis.imshow(img)
                    # Hide ticks/frames on every cell, including any left empty.
                    for axis in axes.flat:
                        axis.axis('off')

                    fig.tight_layout()
                    st.pyplot(fig)
                    # pyplot keeps a global reference to every figure it
                    # creates; close it so figures don't accumulate across reruns.
                    plt.close(fig)
    
app()

# Footer rendered below the app on every run: a horizontal rule, then a back-link.
for footer_line in ("---", "Back to [www.shamimahamed.com](https://www.shamimahamed.com/)."):
    st.markdown(footer_line)