mshukor committed
Commit 0665edf
1 Parent(s): 26fbfb5

epalm images

Files changed (1):
app.py +11 -151
app.py CHANGED
@@ -83,55 +83,15 @@ msg = MODEL.load_state_dict(state_dict,strict=False)
 
 MODEL.bfloat16()
 
-# ###### VQA
-# config = 'configs/image/ePALM_vqa.yaml'
-# config = yaml.load(open(config, 'r'))
-
-# start_layer_idx = 19
-# end_layer_idx = 31
-# low_cpu = True
-# model_vqa = ePALM(opt_model_name=text_model,
-#     vision_model_name=vision_model_name,
-#     use_vis_prefix=True,
-#     start_layer_idx=start_layer_idx,
-#     end_layer_idx=end_layer_idx,
-#     return_hidden_state_vision=True,
-#     config=config,
-#     low_cpu=low_cpu
-# )
-# print("Model Built")
-# model_vqa.to(device)
+

 
 checkpoint_path = 'checkpoints/float32/ePALM_vqa/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict_vqa = checkpoint['model']
-# msg = model_vqa.load_state_dict(state_dict,strict=False)
-
-
-# model_vqa.bfloat16()
-
-
-
-# Video Captioning
-checkpoint_path = 'checkpoints/float32/ePALM_video_caption_msrvtt/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_video_caption = checkpoint['model']
 
-# Video QA
-checkpoint_path = 'checkpoints/float32/ePALM_video_qa_msrvtt/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_video_qa = checkpoint['model']
 
 
-# Audio Captioning
-checkpoint_path = 'checkpoints/float32/ePALM_audio_caption/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_audio_caption = checkpoint['model']
-
 
 
 
@@ -155,125 +115,32 @@ transform = transforms.Compose([
         normalize,
     ])
 
-type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
-test_transform = transforms.Compose([
-    transforms.Resize((image_size,image_size),interpolation=Image.BICUBIC),
-    type_transform,
-    normalize,
-    ])
-from dataset.video_utils import VIDEO_READER_FUNCS
-video_reader = VIDEO_READER_FUNCS['decord']
-
-def read_video(path, num_frames=16):
-
-
-    frames, frame_indices, video_duration = video_reader(
-        path, num_frames, 'rand', max_num_frames=-1
-    )
-    video = test_transform(frames)
-
-    return video
-
-def read_audio(path):
-
-    melbins = 128
-    target_length = 1024
-    skip_norm = False
-    norm_mean = -4.2677393
-    norm_std = 4.5689974
-
-    waveform, sr = torchaudio.load(path)
-    waveform = waveform - waveform.mean()
-
-    # audio
-    fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
-                                              window_type='hanning', num_mel_bins=melbins, dither=0.0,
-                                              frame_shift=10)
-
-    n_frames = fbank.shape[0]
-
-    p = target_length - n_frames
-
-    # cut and pad
-    if p > 0:
-        m = torch.nn.ZeroPad2d((0, 0, 0, p))
-        fbank = m(fbank)
-    elif p < 0:
-        fbank = fbank[0:target_length, :]
-
-
-
-
-    # SpecAug, not do for eval set
-
-    fbank = torch.transpose(fbank, 0, 1)
-    # this is just to satisfy new torchaudio version, which only accept [1, freq, time]
-    fbank = fbank.unsqueeze(0)
-
-
-
-    # squeeze it back, it is just a trick to satisfy new torchaudio version
-    fbank = fbank.squeeze(0)
-    fbank = torch.transpose(fbank, 0, 1)
-
-
-    # normalize the input for both training and test
-    if not skip_norm:
-        fbank = (fbank - norm_mean) / (norm_std * 2)
-        # skip normalization the input if you are trying to get the normalization stats.
-    else:
-        pass
-
-
-    audio = fbank
-
-    return audio
 
 do_sample=False
-num_beams=3
+num_beams=5
 max_length=30
 
 
 
 
 
-def inference(image, audio, video, task_type, instruction):
+def inference(image, task_type, instruction):
 
     if task_type == 'Image Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
         model = MODEL
-    elif task_type == 'Video Captioning':
-        text = ['']
-        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        msg = MODEL.load_state_dict(state_dict_video_caption,strict=False)
-        model = MODEL
-    elif task_type == 'Audio Captioning':
-        text = ['']
-        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        msg = MODEL.load_state_dict(state_dict_audio_caption,strict=False)
-        model = MODEL
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
         msg = MODEL.load_state_dict(state_dict_vqa,strict=False)
         model = MODEL
         print(msg)
-    elif task_type == 'Visual Question Answering':
-        question = instruction+'?'+special_answer_token
-        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        msg = MODEL.load_state_dict(state_dict_video_qa,strict=False)
-        model = MODEL
     else:
         raise NotImplemented
 
-    if "Video" in task_type:
-        image = read_video(image)
-    elif "Audio" in task_type:
-        image = read_audio(image)
-    else:
-        image = transform(image)
-    image = image.to(device,non_blocking=True).unsqueeze(0)
+    image = transform(image)
+    image = image.to(device,non_blocking=True).unsqueeze(0)
 
 
 
@@ -296,21 +163,14 @@ def inference(image, audio, video, task_type, instruction):
     return response
 
 
-inputs = [gr.inputs.Image(type='pil'), gr.Audio(source="upload", type="filepath"), gr.Video(source="upload", type="filepath"), gr.inputs.Radio(choices=['Image Captioning', 'Video Captioning', 'Audio Captioning', "Visual Question Answering", "Visual Grounding", "General", "General Video"], type="value", default="Image Captioning", label="Task"), gr.inputs.Textbox(lines=1, label="Instruction")]
+inputs = [gr.inputs.Image(type='pil'), gr.Audio(source="upload", type="filepath"), gr.Video(source="upload", type="filepath"), gr.inputs.Radio(choices=['Image Captioning', "Visual Question Answering",], type="value", default="Image Captioning", label="Task"), gr.inputs.Textbox(lines=1, label="Instruction")]
 outputs = ['text']
 examples = [
-    ['examples/images/soccer.jpg', None, None, 'Image Captioning', None],
-    ['examples/images/ski.jpg', None, None, 'Visual Question Answering', 'what does the woman wearing black do?'],
-    ['examples/images/banana.jpg', None, None, 'Image Captioning', None],
-    ['examples/images/skateboard.jpg', None, None, 'Visual Question Answering', 'what is on top of the skateboard?'],
-    ['examples/images/baseball.jpg', None, None, 'Image Captioning', None],
-    [None, None, 'examples/videos/video7014.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7017.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7019.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, 'examples/audios/6cS0FsUM-cQ.wav', None, 'Audio Captioning', None],
-    [None, 'examples/audios/AJtNitYMa1I.wav', None, 'Audio Captioning', None],
+    ['examples/images/soccer.jpg', 'Image Captioning', None],
+    ['examples/images/ski.jpg', 'Visual Question Answering', 'what does the woman do?'],
+    ['examples/images/banana.jpg', 'Image Captioning', None],
+    ['examples/images/skateboard.jpg', 'Visual Question Answering', 'what is on top of the skateboard?'],
+    ['examples/images/baseball.jpg', 'Image Captioning', None],
     ]
 
 title = "eP-ALM"
 