Fawaz committed on
Commit ff116b0 · 1 Parent(s): 5de9a0d

Delete app.py

Files changed (1)
  1. app.py +0 -186
app.py DELETED
@@ -1,186 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.data
- import torchvision.transforms as transforms
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
- from PIL import Image
- import clip
- import numpy as np
- import cv2
- import gradio as gr
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
-
-     assert logits.dim() == 1  # Only work for batch size 1 for now - could update but it would obfuscate a bit the code
-     top_k = min(top_k, logits.size(-1))
-     if top_k > 0:
-         # Remove all tokens with a probability less than the last token in the top-k tokens
-         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-         logits[indices_to_remove] = filter_value
-
-     if top_p > 0.0:
-         # Compute cumulative probabilities of sorted tokens
-         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-         cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-         # Remove tokens with cumulative probability above the threshold
-         sorted_indices_to_remove = cumulative_probabilities > top_p
-         # Shift the indices to the right to keep also the first token above the threshold
-         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-         sorted_indices_to_remove[..., 0] = 0
-
-         # Back to unsorted indices and set them to -infinity
-         indices_to_remove = sorted_indices[sorted_indices_to_remove]
-         logits[indices_to_remove] = filter_value
-
-     indices_to_remove = logits < threshold
-     logits[indices_to_remove] = filter_value
-
-     return logits
-
- class ImageEncoder(nn.Module):
-
-     def __init__(self):
-         super(ImageEncoder, self).__init__()
-
-         self.encoder, _ = clip.load("ViT-B/16", device=device)  # loads already in eval mode
-
-     def forward(self, x):
-         """
-         Expects a tensor of size (batch_size, 3, 224, 224)
-         """
-         with torch.no_grad():
-             x = x.type(self.encoder.visual.conv1.weight.dtype)
-             x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
-             x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
-             x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
-             x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
-             x = x + self.encoder.visual.positional_embedding.to(x.dtype)
-             x = self.encoder.visual.ln_pre(x)
-             x = x.permute(1, 0, 2)  # NLD -> LND
-             x = self.encoder.visual.transformer(x)
-             grid_feats = x.permute(1, 0, 2)  # LND -> NLD (N, 197, 768)
-             grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])
-
-         return grid_feats.float()
-
- def change_requires_grad(model, req_grad):
-     for p in model.parameters():
-         p.requires_grad = req_grad
-
- def load_checkpoint(ckpt_path, epoch):
-
-     model_name = 'nle_model_{}'.format(str(epoch))
-     tokenizer_name = 'nle_gpt2_tokenizer_0'
-     tokenizer = GPT2Tokenizer.from_pretrained(ckpt_path + tokenizer_name)  # load tokenizer
-     model = GPT2LMHeadModel.from_pretrained(ckpt_path + model_name).to(device)  # load model with config
-     return tokenizer, model
-
- def sample_sequences(img, model, input_ids, segment_ids, tokenizer):
-
-     SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
-     special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
-     because_token = tokenizer.convert_tokens_to_ids('Ġbecause')
-     max_len = 20
-     current_output = []
-     img_embeddings = image_encoder(img)
-     always_exp = False
-
-     with torch.no_grad():
-
-         for step in range(max_len + 1):
-
-             if step == max_len:
-                 break
-
-             outputs = model(input_ids=input_ids,
-                             past_key_values=None,
-                             attention_mask=None,
-                             token_type_ids=segment_ids,
-                             position_ids=None,
-                             encoder_hidden_states=img_embeddings,
-                             encoder_attention_mask=None,
-                             labels=None,
-                             use_cache=False,
-                             output_attentions=True,
-                             return_dict=True)
-
-             lm_logits = outputs.logits
-             xa_maps = outputs.cross_attentions
-             logits = lm_logits[0, -1, :] / temperature
-             logits = top_filtering(logits, top_k=top_k, top_p=top_p)
-             probs = F.softmax(logits, dim=-1)
-             prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)
-
-             if prev.item() in special_tokens_ids:
-                 break
-
-             # take care of when to start the <explanation> token. Nasty code in here (i hate lots of ifs)
-             if not always_exp:
-
-                 if prev.item() != because_token:
-                     new_segment = special_tokens_ids[-2]  # answer segment
-                 else:
-                     new_segment = special_tokens_ids[-1]  # explanation segment
-                     always_exp = True
-             else:
-                 new_segment = special_tokens_ids[-1]  # explanation segment
-
-             new_segment = torch.LongTensor([new_segment]).to(device)
-             current_output.append(prev.item())
-             input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
-             segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0)), dim=1)
-
-     decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()
-
-     return decoded_sequences, xa_maps
-
- def get_inputs(tokenizer):
-     a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<answer>', '<explanation>'])
-     tokens = [tokenizer.bos_token] + tokenizer.tokenize("the answer is")
-     segment_ids = [a_segment_id] * len(tokens)
-     input_ids = tokenizer.convert_tokens_to_ids(tokens)
-     input_ids = torch.tensor(input_ids, dtype=torch.long)
-     segment_ids = torch.tensor(segment_ids, dtype=torch.long)
-
-     return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)
-
- img_size = 224
- ckpt_path = 'ACTX_p/'
- max_seq_len = 30
- load_from_epoch = 5
- no_sample = True
- top_k = 0
- top_p = 0.9
- temperature = 1
-
- image_encoder = ImageEncoder().to(device)
- change_requires_grad(image_encoder, False)
- tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
- model.eval()
-
-
- img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
-                                     transforms.ToTensor(),
-                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
-
- def inference(raw_image):
-
-     oimg = raw_image.convert('RGB').resize((224, 224))
-     img = img_transform(oimg).unsqueeze(0).to(device)
-     input_ids, segment_ids = get_inputs(tokenizer)
-     seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)
-     last_am = xa_maps[-1].mean(1)[0]
-     mask = last_am[0, :].reshape(14, 14).cpu().numpy()
-     mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
-     attention_map = (mask * oimg).astype("uint8")
-     splitted_seq = seq.split("because")
-     return splitted_seq[0].strip(), "because " + splitted_seq[-1].strip(), Image.fromarray(attention_map)
-
- inputs = [gr.inputs.Image(type='pil', label="Load the image of your interest")]
- outputs = [gr.outputs.Textbox(label="What action is this?"), gr.outputs.Textbox(label="Textual Explanation"), gr.outputs.Image(type='pil', label="Visual Explanation")]
-
- title = "NLX-GPT: Explanations with Natural Text (Action Recognition Demo)"
- gr.Interface(inference, inputs, outputs, title=title).launch()
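
Note: the deleted app.py builds its UI with the Gradio 2.x component namespaces (gr.inputs / gr.outputs), which were removed in Gradio 3. A minimal, untested sketch of the same Interface wiring against a current Gradio release, assuming the inference function defined above, could look like:

import gradio as gr

demo = gr.Interface(
    fn=inference,  # the inference() defined in the deleted app.py above
    inputs=gr.Image(type="pil", label="Load the image of your interest"),
    outputs=[gr.Textbox(label="What action is this?"),
             gr.Textbox(label="Textual Explanation"),
             gr.Image(type="pil", label="Visual Explanation")],
    title="NLX-GPT: Explanations with Natural Text (Action Recognition Demo)",
)
demo.launch()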