Commit f8a1225 · Parent(s): 3face26

Upload 6 files

- T2I.py +112 -0
- gan_cls_768.py +151 -0
- gen_125.pth +3 -0
- main.py +239 -0
- requirements.txt +9 -0
- run.sh +1 -0
T2I.py
ADDED
@@ -0,0 +1,112 @@
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"


def clean(txt):
    txt = txt.lower()
    txt = txt.strip()
    txt = txt.strip('.')
    return txt


max_len = 76


def tokenize(tokenizer, txt):
    return tokenizer(
        txt,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False
    )


def encode(model, tokenizer, txt):
    txt = clean(txt)
    txt_tokenized = tokenize(tokenizer, txt)

    for k, v in txt_tokenized.items():
        txt_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)[None]

    model.eval()
    with torch.no_grad():
        encoded = model(**txt_tokenized)

    # Use the first token's (<s>) last-layer hidden state as the 768-d sentence embedding.
    return encoded.last_hidden_state.squeeze()[0].cpu().numpy()


model_name = 'roberta-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name,
    config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)


def generate_image(text, n):
    embed = encode(model, tokenizer, text)

    # The DataParallel wrapper matches the "module."-prefixed keys in the checkpoint.
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
    generator.eval()

    right_embed = Variable(torch.FloatTensor(embed).unsqueeze(0)).to(device)

    images = []
    for _ in tqdm(range(n)):
        noise = Variable(torch.randn(1, 100)).to(device)
        noise = noise.view(noise.size(0), 100, 1, 1)
        fake_images = generator(right_embed, noise)

        for image in fake_images:
            # Rescale the tanh output from [-1, 1] to [0, 255] and reorder CHW -> HWC.
            im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
            images.append(im)
    return images


if __name__ == '__main__':
    n = 10
    imgs = generate_image('Red images', n)

    fig, axes = plt.subplots(nrows=5, ncols=2)
    for ax, img in zip(axes.flatten(), imgs):
        ax.imshow(img)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

    # while True:
    #     print('Type Caption: ')
    #     txt = input()
    #     print('Generating images...')
    #     generate_image(txt)
    #     print('Completed')
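
For quick reference, the same entry point can be driven from another script; a minimal usage sketch, assuming T2I.py and the gen_125.pth checkpoint sit in the working directory (the RoBERTa encoder loads once at import time; the caption is one of the demo examples and the output filenames are illustrative):

    # Hypothetical driver script; each returned item is a 64x64 PIL image.
    from T2I import generate_image

    imgs = generate_image("this flower has petals that are yellow and has black lines", 4)
    for i, im in enumerate(imgs):
        im.save(f"flower_{i}.png")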
gan_cls_768.py
ADDED
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn


class Concat_embed4(nn.Module):
    """Projects a text embedding and concatenates it with a 4x4 feature map."""

    def __init__(self, embed_dim, projected_embed_dim):
        super(Concat_embed4, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp, embed):
        projected_embed = self.projection(embed)
        # Tile the projected embedding over the 4x4 spatial grid, then concat on channels.
        replicated_embed = projected_embed.repeat(4, 4, 1, 1).permute(2, 3, 0, 1)
        hidden_concat = torch.cat([inp, replicated_embed], 1)
        return hidden_concat


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.noise_dim = 100
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = 64

        self.projection = nn.Sequential(
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Linear(in_features=self.embed_dim, out_features=self.projected_embed_dim),
            nn.BatchNorm1d(num_features=self.projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

        self.netG = nn.ModuleList([
            nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),

            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),

            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),

            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),

            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (num_channels) x 64 x 64
        ])

    def forward(self, embed_vector, z):
        projected_embed = self.projection(embed_vector)
        # Concatenate the projected text embedding with the noise along the channel dim.
        out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
        for m in self.netG:
            out = m(out)
        return out


class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.ndf = 64
        self.B_dim = 128
        self.C_dim = 16

        self.netD_1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(self.num_channels, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32

            # SelfAttention(self.ndf),
            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16

            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8

            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        self.netD_2 = nn.Sequential(
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(self.ndf * 8 + self.projected_embed_dim,
                      self.ndf * 8, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inp, embed):
        x_intermediate = self.netD_1(inp)
        x = self.projector(x_intermediate, embed)
        x = self.netD_2(x)

        return x.view(-1, 1).squeeze(1), x_intermediate
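
As a reading aid, here is a shape smoke test for the two networks; an illustrative sketch only (random tensors, arbitrary batch size 2), not part of the upload:

    # Generator: (batch, 768) text embedding + (batch, 100, 1, 1) noise -> 64x64 RGB.
    # Discriminator: image + embedding -> per-image score and a 4x4 feature map.
    import torch
    import gan_cls_768

    G, D = gan_cls_768.generator(), gan_cls_768.discriminator()
    G.eval(); D.eval()  # eval mode so BatchNorm uses running statistics

    embed = torch.randn(2, 768)
    noise = torch.randn(2, 100, 1, 1)
    with torch.no_grad():
        fake = G(embed, noise)         # -> (2, 3, 64, 64), in [-1, 1] from the Tanh
        score, feats = D(fake, embed)  # -> (2,) probabilities and (2, 512, 4, 4) features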
gen_125.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd835271c23087d4cf25a974b51e6680592d906da9cd20159a060123fdc7b8c5
size 23668507
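
The weights file is tracked with Git LFS, so the diff records only the pointer: the payload's SHA-256 and its size (23,668,507 bytes, about 23 MB of generator parameters).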
main.py
ADDED
@@ -0,0 +1,239 @@
import streamlit as st
from PIL import Image
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable

device = "cuda" if torch.cuda.is_available() else "cpu"


def clean(txt):
    txt = txt.lower()
    txt = txt.strip()
    txt = txt.strip('.')
    return txt


max_len = 76


def tokenize(tokenizer, txt):
    return tokenizer(
        txt,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False
    )


def encode(model, tokenizer, txt):
    txt = clean(txt)
    txt_tokenized = tokenize(tokenizer, txt)

    for k, v in txt_tokenized.items():
        txt_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)[None]

    model.eval()
    with torch.no_grad():
        encoded = model(**txt_tokenized)

    # First-token (<s>) hidden state serves as the sentence embedding.
    return encoded.last_hidden_state.squeeze()[0].cpu().numpy()


@st.cache_resource
def get_model_roberta():
    # Cached across Streamlit reruns so the encoder is loaded only once.
    model_name = 'roberta-base'

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(
        model_name,
        config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)

    return model, tokenizer


@st.cache_resource
def get_model_gan():
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
    generator.eval()
    return generator


def generate_image(text, n):
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()

    embed = encode(model, tokenizer, text)
    right_embed = Variable(torch.FloatTensor(embed).unsqueeze(0)).to(device)

    images = []
    for _ in tqdm(range(n)):
        noise = Variable(torch.randn(1, 100)).to(device)
        noise = noise.view(noise.size(0), 100, 1, 1)
        fake_images = generator(right_embed, noise)

        for image in fake_images:
            im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
            images.append(im)
    return images


st.set_page_config(
    page_title="ImageGen",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
)


hide_st_style = """
    <style>
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    header {visibility: hidden;}
    </style>
    """
st.markdown(hide_st_style, unsafe_allow_html=True)


examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistol and pollen tube.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "delicated pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.",
    "this flower has petals that are pink and bell shaped",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "his flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "this flower is yellow in color, and has petals that are very skinny.",
    "a velvet large flower with a dark marking and a green stem.",
    "this flower is yellow in color, and has petals that are very skinny.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are pink and has red stamen",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "this flower is yellow in color, with petals that are skinny and pointed.",
    "the petals on this flower are orange with a purple pistil.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "this purple color flower has the simple row of petals arranged in the circle with the red color pistils at the center",
    "this flower has petals that are red and are very thin",
    "a flower with many folded over bright yellow petals",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
    "this flower is yellow and brown in color, with petals that are oval shaped.",
    "this flower has petals that are white and has a yellow stigma",
    "a flower with folded open and back red petals with black spots and think red anther",
    "this flower has large light red petals and a few white stamen in the center",
    "this flower has bright orange tubular petals rising out of a thick receptacle on a green pedicel.",
    "this flower is a beauty with light red leaves in an equal circle.",
    "a flower with an open conical red petal and white anther supported by red filaments",
    "this flower is red in color, with petals that are bell shaped.",
    "the petals of this flower are yellow with a long stigma",
]


def app():
    st.title("Text to Flower")
    st.markdown(
        """
        **Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
        Presented at the *International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)*.
        """
    )

    se = st.selectbox("Select from example", examples)

    row1_col1, row1_col2 = st.columns([2, 3])

    with row1_col1:
        caption = st.text_area("Write your flower description here:", se, height=120)

        backend = st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa"], index=0
        )

    if st.button("Generate", type="primary"):
        with st.spinner("Generating Flower Images..."):
            imgs = generate_image(caption, 12)

        with row1_col2:
            st.markdown("Generated Flower Images:")

            fig, axes = plt.subplots(nrows=3, ncols=4)
            for ax, img in zip(axes.flatten(), imgs):
                ax.imshow(img)
                ax.axis('off')

            fig.tight_layout()
            st.pyplot(fig)

    # with row1_col2:
    #     img1 = Image.open('./images/t2i/1.jpg')
    #     img2 = Image.open('./images/t2i/2.jpg')
    #     img3 = Image.open('./images/t2i/3.jpg')
    #     img4 = Image.open('./images/t2i/4.jpg')
    #     cont = st.container()
    #     with cont:
    #         st.write("This is a container with a caption like a button.")
    #         col1, col2, col3, col4 = st.columns(4)
    #         with col1:
    #             st.image(img1, width=128)
    #         with col2:
    #             st.image(img2, width=128)
    #         with col3:
    #             st.image(img3, width=128)
    #         with col4:
    #             st.image(img4, width=128)


app()

# Display a footer with links and credits
st.markdown("---")
st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")
# st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")
requirements.txt
ADDED
@@ -0,0 +1,9 @@
streamlit==1.21.0
Pillow
torch==2.0.1
numpy
transformers==4.30.2
tokenizers==0.13.3
matplotlib==3.7.1
run.sh
ADDED
@@ -0,0 +1 @@
streamlit run main.py --server.runOnSave True