Upload app.py
app.py ADDED
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from PIL import Image
import clip
import numpy as np
import cv2
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering."""
    assert logits.dim() == 1  # only works for batch size 1 for now; could be generalized, but it would obfuscate the code a bit
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits

class ImageEncoder(nn.Module):

    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder, _ = clip.load("ViT-B/16", device=device)  # loads already in eval mode

    def forward(self, x):
        """
        Expects a tensor of size (batch_size, 3, 224, 224)
        """
        with torch.no_grad():
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            x = self.encoder.visual.ln_pre(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.encoder.visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)  # LND -> NLD (N, 197, 768)
            grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])

        return grid_feats.float()


def change_requires_grad(model, req_grad):
    for p in model.parameters():
        p.requires_grad = req_grad

def load_checkpoint(ckpt_path, epoch):

    model_name = 'nle_model_{}'.format(str(epoch))
    tokenizer_name = 'nle_gpt2_tokenizer_0'
    tokenizer = GPT2Tokenizer.from_pretrained(ckpt_path + tokenizer_name)  # load tokenizer
    model = GPT2LMHeadModel.from_pretrained(ckpt_path + model_name).to(device)  # load model with config
    return tokenizer, model

def sample_sequences(img, model, input_ids, segment_ids, tokenizer):

    SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    because_token = tokenizer.convert_tokens_to_ids('Ġbecause')  # GPT-2 BPE token for " because" ('Ġ' marks a leading space)
    max_len = 20
    current_output = []
    img_embeddings = image_encoder(img)
    always_exp = False

    with torch.no_grad():

        for step in range(max_len + 1):

            if step == max_len:
                break

            outputs = model(input_ids=input_ids,
                            past_key_values=None,
                            attention_mask=None,
                            token_type_ids=segment_ids,
                            position_ids=None,
                            encoder_hidden_states=img_embeddings,
                            encoder_attention_mask=None,
                            labels=None,
                            use_cache=False,
                            output_attentions=True,
                            return_dict=True)

            lm_logits = outputs.logits
            xa_maps = outputs.cross_attentions
            logits = lm_logits[0, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)

            if prev.item() in special_tokens_ids:
                break

            # Decide which segment id the new token gets: stay in the <answer> segment
            # until the first " because" token is generated, then switch to <explanation>.
            if not always_exp:
                if prev.item() != because_token:
                    new_segment = special_tokens_ids[-2]  # answer segment
                else:
                    new_segment = special_tokens_ids[-1]  # explanation segment
                    always_exp = True
            else:
                new_segment = special_tokens_ids[-1]  # explanation segment

            new_segment = torch.LongTensor([new_segment]).to(device)
            current_output.append(prev.item())
            input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
            segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0)), dim=1)

    decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()

    return decoded_sequences, xa_maps

def get_inputs(tokenizer):
    a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<answer>', '<explanation>'])
    tokens = [tokenizer.bos_token] + tokenizer.tokenize("the answer is")
    segment_ids = [a_segment_id] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long)

    return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)

# Decoding and checkpoint settings
img_size = 224
ckpt_path = 'ACTX_p/'   # directory containing the fine-tuned model and tokenizer checkpoints
max_seq_len = 30
load_from_epoch = 5
no_sample = True        # greedy decoding (argmax) rather than sampling
top_k = 0
top_p = 0.9
temperature = 1

image_encoder = ImageEncoder().to(device)
change_requires_grad(image_encoder, False)   # keep the CLIP encoder frozen
tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
model.eval()


img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])  # ImageNet statistics

def inference(raw_image):

    oimg = raw_image.convert('RGB').resize((224, 224))
    img = img_transform(oimg).unsqueeze(0).to(device)
    input_ids, segment_ids = get_inputs(tokenizer)
    seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)
    # Average the last layer's cross-attention over heads and take the first token position's
    # attention over the 14x14 ViT-B/16 patch grid (196 patches for a 224x224 input) as a mask
    last_am = xa_maps[-1].mean(1)[0]
    mask = last_am[0, :].reshape(14, 14).cpu().numpy()
    mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
    attention_map = (mask * oimg).astype("uint8")
    splitted_seq = seq.split("because")
    return splitted_seq[0].strip(), "because " + splitted_seq[-1].strip(), Image.fromarray(attention_map)


inputs = [gr.inputs.Image(type='pil', label="Load the image of your interest")]
outputs = [gr.outputs.Textbox(label="What action is this?"), gr.outputs.Textbox(label="Textual Explanation"), gr.outputs.Image(type='pil', label="Visual Explanation")]

title = "NLX-GPT: Explanations with Natural Text (Action Recognition Demo)"
gr.Interface(inference, inputs, outputs, title=title).launch()
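For local debugging, a minimal sanity check of the pipeline without the web UI could look like the sketch below. It is only an illustrative sketch, not part of the uploaded app.py: it assumes the ACTX_p/ checkpoint directory is available, that an image named example.jpg sits next to app.py, and that the final gr.Interface(...).launch() line is commented out so importing the module does not start the server.

# Hypothetical smoke test (assumptions: ACTX_p/ checkpoint present, example.jpg exists,
# and the gr.Interface(...).launch() call in app.py is commented out before importing).
from PIL import Image

from app import inference  # reuses the model, tokenizer, and transforms loaded in app.py

answer, explanation, attention = inference(Image.open("example.jpg"))
print("Predicted action:", answer)          # text before "because"
print("Textual explanation:", explanation)  # the "because ..." continuation
attention.save("attention_map.png")         # cross-attention overlay returned as a PIL image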