DataRaptor commited on
Commit
3419697
·
1 Parent(s): f8a1225

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -0
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+
6
+
7
+
8
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
9
+ import torch
10
+ from tqdm import tqdm
11
+ import gan_cls_768
12
+ from torch.autograd import Variable
13
+ from PIL import Image
14
+ import matplotlib.pyplot as plt
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ def clean(txt):
18
+ txt = txt.lower()
19
+ txt = txt.strip()
20
+ txt = txt.strip('.')
21
+ return txt
22
+
23
+
24
+ max_len = 76
25
+
26
+ def tokenize(tokenizer, txt):
27
+ return tokenizer(
28
+ txt,
29
+ max_length=max_len,
30
+ padding='max_length',
31
+ truncation=True,
32
+ return_offsets_mapping=False
33
+ )
34
+
35
+
36
+ def encode(model, tokenizer, txt):
37
+ txt = clean(txt)
38
+ txt_tokenized = tokenize(tokenizer, txt)
39
+
40
+ for k, v in txt_tokenized.items():
41
+ txt_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)[None]
42
+
43
+ model.eval()
44
+ with torch.no_grad():
45
+ encoded = model(**txt_tokenized)
46
+
47
+ return encoded.last_hidden_state.squeeze()[0].cpu().numpy()
48
+
49
+
50
+ @st.cache_resource
51
+ def get_model_roberta():
52
+ model_name = 'roberta-base'
53
+
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+ model = AutoModel.from_pretrained(
56
+ model_name,
57
+ config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)
58
+
59
+ return model, tokenizer
60
+
61
+
62
+ @st.cache_resource
63
+ def get_model_gan():
64
+ generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
65
+ generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
66
+ generator.eval()
67
+ return generator
68
+
69
+
70
+
71
+ def generate_image(text, n):
72
+ model, tokenizer = get_model_roberta()
73
+ generator = get_model_gan()
74
+
75
+ embed = encode(model, tokenizer, text)
76
+ embed2 = torch.FloatTensor(embed)
77
+ embed2 = embed2.unsqueeze(0)
78
+ right_embed = Variable(embed2.float()).to(device)
79
+
80
+ l = []
81
+ for i in tqdm(range(n)):
82
+ noise = Variable(torch.randn(1, 100)).to(device)
83
+ noise = noise.view(noise.size(0), 100, 1, 1)
84
+ fake_images = generator(right_embed, noise)
85
+
86
+ for idx, image in enumerate(fake_images):
87
+ im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
88
+ l.append(im)
89
+ return l
90
+
91
+
92
+
93
+
94
+ st.set_page_config(
95
+ page_title="ImageGen",
96
+ page_icon="🧊",
97
+ layout="centered",
98
+ initial_sidebar_state="expanded",
99
+ )
100
+
101
+
102
+ hide_st_style = """
103
+ <style>
104
+ #MainMenu {visibility: hidden;}
105
+ footer {visibility: hidden;}
106
+ header {visibility: hidden;}
107
+ </style>
108
+ """
109
+ st.markdown(hide_st_style, unsafe_allow_html=True)
110
+
111
+
112
+
113
+ examples = [
114
+ "this petal has gorgeous purple petals and a long green pedicel",
115
+ "this petal has gorgeous green petals and a long green pedicel",
116
+ "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
117
+ "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistol and pollen tube.",
118
+ "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
119
+ "delicated pink petals clumped on one green pedicel with small sepals.",
120
+ "the flower has big yellow upright petals attached to a thick vine",
121
+ "these bright flowers have many yellow strip petals and stamen.",
122
+ "a large red flower with black dots and a very long stigmas.",
123
+ "this flower has petals that are pink and bell shaped",
124
+ "this flower has petals that are yellow and has black lines",
125
+ "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
126
+ "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
127
+ "this flower has petals that are white and has a yellow style",
128
+ "his flower has petals that are orange and are very thin",
129
+ "a flower with singular conical purple petal and large white pistil.",
130
+ "this flower is yellow in color, and has petals that are very skinny.",
131
+ "a velvet large flower with a dark marking and a green stem.",
132
+ "this flower is yellow in color, and has petals that are very skinny.",
133
+ "the flower has bright yellow soft petals with yellow stamens.",
134
+ "this flower has petals that are pink and has red stamen",
135
+ "this flower has petals that are purple and have dark lines",
136
+ "this purple flower has pointy short petals and green sepal.",
137
+ "this flower has petals that are purple and has a yellow style",
138
+ "this flower is yellow in color, with petals that are skinny and pointed.",
139
+ "the petals on this flower are orange with a purple pistil.",
140
+ "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
141
+ "this purple color flower has the simple row of petals arranged in the circle with the red color pistils at the center",
142
+ "this flower has petals that are red and are very thin",
143
+ "a flower with many folded over bright yellow petals",
144
+ "a flower with no visible petals and purple pistils in the center.",
145
+ "a star shaped flower with five white petals with purple lines running through them.",
146
+ "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
147
+ "this flower features a purple stigma surrounded by pointed waxy orange petals.",
148
+ "this flower is yellow and brown in color, with petals that are oval shaped.",
149
+ "this flower has petals that are white and has a yellow stigma",
150
+ "a flower with folded open and back red petals with black spots and think red anther",
151
+ "this flower has large light red petals and a few white stamen in the center",
152
+ "this flower has bright orange tubular petals rising out of a thick receptacle on a green pedicel.",
153
+ "this flower is a beauty with light red leaves in an equal circle.",
154
+ "a flower with an open conical red petal and white anther supported by red filaments",
155
+ "this flower is red in color, with petals that are bell shaped.",
156
+ "the petals of this flower are yellow with a long stigma",
157
+ ]
158
+
159
+
160
+
161
+ def app():
162
+
163
+ st.title("Text to Flower")
164
+ st.markdown(
165
+ """
166
+ **Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
167
+ Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
168
+ """
169
+ )
170
+
171
+
172
+
173
+ se = st.selectbox("Select from example",
174
+ examples)
175
+
176
+ row1_col1, row1_col2 = st.columns([2, 3])
177
+ width = 950
178
+ height = 600
179
+
180
+ with row1_col1:
181
+ caption = st.text_area("Write your flower description here:", se, height=120)
182
+
183
+
184
+ backend = st.selectbox(
185
+ "Select a Model", ["Convolutional GAN with RoBERTa", ], index=0
186
+ )
187
+
188
+
189
+
190
+ if st.button("Generate", type="primary"):
191
+ with st.spinner("Generating Flower Images..."):
192
+
193
+ imgs = generate_image(caption, 12)
194
+ #ss = st.success("Scores predicted successfully!")
195
+
196
+ with row1_col2:
197
+ st.markdown("Generated Flower Images:")
198
+
199
+ fig, ax = plt.subplots(nrows=3, ncols=4)
200
+ ax = ax.flatten()
201
+
202
+ for idx, ax in enumerate(ax):
203
+ ax.imshow(imgs[idx])
204
+ ax.axis('off')
205
+
206
+ fig.tight_layout()
207
+ st.pyplot(fig)
208
+
209
+
210
+
211
+
212
+ # with row1_col2:
213
+ # img1 = Image.open('./images/t2i/1.jpg')
214
+ # img2 = Image.open('./images/t2i/2.jpg')
215
+ # img3 = Image.open('./images/t2i/3.jpg')
216
+ # img4 = Image.open('./images/t2i/4.jpg')
217
+ # cont = st.container()
218
+ # with cont:
219
+
220
+ # st.write("This is a container with a caption like a button.")
221
+ # col1, col2, col3, col4 = st.columns(4)
222
+ # with col1:
223
+ # st.image(img1, width=128)
224
+ # with col2:
225
+ # st.image(img2, width=128)
226
+ # with col3:
227
+ # st.image(img3, width=128)
228
+ # with col4:
229
+ # st.image(img4, width=128)
230
+
231
+
232
+
233
+
234
+ app()
235
+
236
+ # # Display a footer with links and credits
237
+ st.markdown("---")
238
+ st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")
239
+ # #st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")