acvss24 committed on
Commit
1b9e93b
·
verified ·
1 Parent(s): da76f9a
Files changed (2) hide show
  1. app.py +254 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,13 +1,264 @@
1
  import gradio as gr
2
  import PIL.Image as Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
def process_image(image):
    """Placeholder pipeline: return canned captions in four languages."""
    captions = ("English text", "Yoruba text", "Swahili text", "Twi text")
    return captions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  app = gr.Interface(
8
  fn=process_image,
9
- inputs = gr.Image(type="pil", label="Upload an Image"),
10
-
11
  outputs = [
12
  gr.Text(label="English: "),
13
  gr.Text(label="Yoruba: "),
@@ -22,4 +273,3 @@ app = gr.Interface(
22
  )
23
  if __name__ == "__main__":
24
  app.launch()
25
-
 
1
  import gradio as gr
2
  import PIL.Image as Image
3
+ import clip
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+ import torch.nn.functional as nnf
8
+ from typing import Tuple, List, Union, Optional
9
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
10
+ from tqdm import tqdm, trange
11
+ import skimage.io as io
12
+ import os
13
+ import requests
14
+
15
# --- Shorthand type aliases (ClipCap-style) --------------------------------
# NOTE(review): V is bound to np.array (a function), not np.ndarray (the
# type); typing accepts it because it is callable, but VS/VN/VNS are
# therefore not meaningful type hints. Kept as-is from the upstream notebook.
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]

D = torch.device
CPU = torch.device('cpu')

# Where the pretrained caption-model weights are expected to live:
# a "pretrained_models" folder in the PARENT of the current working
# directory. NOTE(review): nothing in this file downloads the weights;
# torch.load in process_image will fail unless model_weights.pt is
# placed there out of band — confirm deployment provides it.
current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
os.makedirs(save_path, exist_ok=True)
model_path = os.path.join(save_path, 'model_weights.pt')
36
+
37
class MLP(nn.Module):
    """Small feed-forward projector: Linear layers with an activation
    (Tanh by default) between every pair, and none after the last.

    Used to map a CLIP embedding onto the flattened GPT-2 prefix
    embeddings.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        last = len(sizes) - 1
        # Pair consecutive sizes: (sizes[0], sizes[1]), (sizes[1], sizes[2]), ...
        for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:]), start=1):
            layers.append(nn.Linear(n_in, n_out, bias=bias))
            if idx < last:  # no activation after the final Linear
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
50
+
51
+
52
class ClipCaptionModel(nn.Module):
    """GPT-2 language model conditioned on a projected CLIP-embedding prefix."""

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        # GPT-2 token-embedding width, read off the wte weight matrix.
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            # Long prefixes: a single Linear projection.
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            # Short prefixes: two-layer MLP with a bottleneck of half the output size.
            hidden = (self.gpt_embedding_size * prefix_length) // 2
            self.clip_project = MLP((prefix_size, hidden, self.gpt_embedding_size * prefix_length))

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Zero tokens acting as placeholder labels for the prefix positions."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor,
                mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None):
        """Run GPT-2 on [projected CLIP prefix ++ token embeddings].

        When labels are supplied, dummy zeros are prepended so the label
        sequence aligns with the prefix-extended input.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        return self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
79
+
80
+
81
class ClipCaptionPrefix(ClipCaptionModel):
    """Variant that trains only the CLIP->GPT-2 projection; GPT-2 stays frozen."""

    def parameters(self, recurse: bool = True):
        # Expose only the projector's parameters to optimizers.
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        # Keep GPT-2 in eval mode even while the projector trains.
        self.gpt.eval()
        return self
90
+
91
+ #@title Caption prediction
92
+
93
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
                  entry_length=67, temperature=1., stop_token: str = '.'):
    """Beam-search decoding with GPT-2.

    Starts either from a prefix embedding `embed` (as produced by
    ClipCaptionModel.clip_project, reshaped to (1, prefix_length, -1) in
    process_image) or from text `prompt`. Keeps the `beam_size` best
    partial sequences; scores are length-normalised summed log-probs.
    Returns the decoded strings sorted best-first.
    """
    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)
        for i in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()
            if scores is None:
                # First step: fan the single sequence out into beam_size beams.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # Freeze finished beams: only token 0 keeps probability mass,
                # so their score no longer changes.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank candidates by length-normalised score across all beams.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # Flat index -> (source beam, token id).
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                # Re-gather all per-beam state from the surviving source beams.
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    # Decode each beam only up to its recorded length.
    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts
152
+
153
+
154
def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
):
    """Token-by-token decoding with GPT-2, starting from `embed` or `prompt`.

    NOTE(review): logits are top-p (nucleus) filtered but the next token is
    chosen with argmax, so this is effectively greedy decoding; sampling
    would require torch.multinomial. Returns the first generated string.
    """
    model.eval()
    generated_num = 0  # unused; kept from the original notebook
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in trange(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                # Nucleus filter: drop tokens outside the smallest set whose
                # cumulative probability exceeds top_p (always keep the best one).
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]
215
# Which pretrained ClipCap checkpoint the app expects (COCO-trained weights).
# NOTE(review): this flag is never read below — confirm whether it should
# select between checkpoints.
pretrained_model = 'COCO'
216
+
217
 
218
 
219
def _load_caption_pipeline(device: str = "cpu", prefix_length: int = 10):
    """Load and cache the captioning components (CLIP encoder + preprocess,
    GPT-2 tokenizer, ClipCaption model with weights from `model_path`).

    The original code reloaded all of these on EVERY request, which made
    each Gradio call pay the full model-loading cost; caching on the
    function object makes subsequent calls reuse the loaded models.
    """
    if not hasattr(_load_caption_pipeline, "_cache"):
        clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = ClipCaptionModel(prefix_length)
        # strict=False: checkpoint may omit the (frozen) GPT-2 weights.
        model.load_state_dict(torch.load(model_path, map_location=CPU), strict=False)
        model = model.eval()
        _load_caption_pipeline._cache = (clip_model, preprocess, tokenizer, model)
    return _load_caption_pipeline._cache


def process_image(image):
    """Caption an uploaded image.

    Parameters:
        image: numpy array from the Gradio Image input (converted with
               PIL.Image.fromarray below).
    Returns:
        4-tuple of strings: the generated English caption plus placeholder
        texts for Yoruba, Swahili and Twi (translation not implemented yet).
    """
    device = "cpu"
    prefix_length = 10
    clip_model, preprocess, tokenizer, model = _load_caption_pipeline(device, prefix_length)

    use_beam_search = False  # beam search is slower; greedy-ish generate2 by default

    pil_image = Image.fromarray(image)
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        # Project the CLIP embedding to a (1, prefix_length, gpt_dim) prefix.
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
        if use_beam_search:
            generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
        else:
            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)

    return generated_text_prefix, "Yoruba text", "Swahili text", "Twi text"
256
+
257
+
258
  app = gr.Interface(
259
  fn=process_image,
260
+ inputs = gr.Image(label="Upload an Image"),
261
+
262
  outputs = [
263
  gr.Text(label="English: "),
264
  gr.Text(label="Yoruba: "),
 
273
  )
274
  if __name__ == "__main__":
275
  app.launch()
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pillow
2
+ torch
3
+ torchvision
4
+ transformers
5
+ git+https://github.com/openai/CLIP.git
6
+ requests