mshukor committed
Commit 4da2434 · 1 Parent(s): 87d7283
Files changed (1): app.py +23 -157

app.py CHANGED
@@ -9,7 +9,6 @@ os.system('pwd')
 import os, sys
 sys.path.append("/home/user/app/TimeSformer/")
 
-import timesformer
 
 
 import torch
@@ -37,7 +36,6 @@ from ruamel.yaml import YAML
 import torch
 import gradio as gr
 
-import torchaudio
 
 yaml=YAML(typ='safe')
 
@@ -50,7 +48,7 @@ device_type = 'cuda' if use_cuda else 'cpu'
 ## Load model
 
 ### Captioning
-config = 'configs/image/ePALM_caption.yaml'
+config = 'configs/video/ePALM_video_caption_msrvtt.yaml'
 # config = yaml.load(open(config, 'r'), Loader=yaml.Loader)
 config = yaml.load(open(config, 'r'))
 
@@ -63,7 +61,7 @@ vision_model_name = 'vit_base_patch16_224'
 start_layer_idx = 19
 end_layer_idx = 31
 low_cpu = True
-model = ePALM(opt_model_name=text_model,
+MODEL = ePALM(opt_model_name=text_model,
               vision_model_name=vision_model_name,
               use_vis_prefix=True,
               start_layer_idx=start_layer_idx,
@@ -73,64 +71,14 @@ model = ePALM(opt_model_name=text_model,
               low_cpu=low_cpu
               )
 print("Model Built")
-model.to(device)
+MODEL.to(device)
 
-checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict = checkpoint['model']
-msg = model.load_state_dict(state_dict,strict=False)
-
-model.bfloat16()
-
-# ###### VQA
-# config = 'configs/image/ePALM_vqa.yaml'
-# config = yaml.load(open(config, 'r'))
-
-# start_layer_idx = 19
-# end_layer_idx = 31
-# low_cpu = True
-# model_vqa = ePALM(opt_model_name=text_model,
-#               vision_model_name=vision_model_name,
-#               use_vis_prefix=True,
-#               start_layer_idx=start_layer_idx,
-#               end_layer_idx=end_layer_idx,
-#               return_hidden_state_vision=True,
-#               config=config,
-#               low_cpu=low_cpu
-#               )
-# print("Model Built")
-# model_vqa.to(device)
-
-
-checkpoint_path = 'checkpoints/float32/ePALM_vqa/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_vqa = checkpoint['model']
-# msg = model_vqa.load_state_dict(state_dict,strict=False)
-
-
-# model_vqa.bfloat16()
-
-
-
-# Video Captioning
 checkpoint_path = 'checkpoints/float32/ePALM_video_caption_msrvtt/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_video_caption = checkpoint['model']
-
-# Video QA
-checkpoint_path = 'checkpoints/float32/ePALM_video_qa_msrvtt/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_video_qa = checkpoint['model']
-
+state_dict = checkpoint['model']
+msg = MODEL.load_state_dict(state_dict,strict=False)
 
-# Audio Captioning
-checkpoint_path = 'checkpoints/float32/ePALM_audio_caption/checkpoint_best.pth'
-# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
-checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict_audio_caption = checkpoint['model']
+
+MODEL.bfloat16()
 
 
 
@@ -149,11 +97,7 @@ tokenizer.add_special_tokens(special_tokens_dict)
 image_size = 224
 normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
 
-transform = transforms.Compose([
-    transforms.Resize((image_size,image_size),interpolation=Image.BICUBIC),
-    transforms.ToTensor(),
-    normalize,
-    ])
+
 
 type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
 test_transform = transforms.Compose([
@@ -161,6 +105,7 @@ test_transform = transforms.Compose([
     type_transform,
     normalize,
     ])
+
 from dataset.video_utils import VIDEO_READER_FUNCS
 video_reader = VIDEO_READER_FUNCS['decord']
 
@@ -174,60 +119,6 @@ def read_video(path, num_frames=16):
 
     return video
 
-def read_audio(path):
-
-    melbins = 128
-    target_length = 1024
-    skip_norm = False
-    norm_mean = -4.2677393
-    norm_std = 4.5689974
-
-    waveform, sr = torchaudio.load(path)
-    waveform = waveform - waveform.mean()
-
-    # audio
-    fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
-                                              window_type='hanning', num_mel_bins=melbins, dither=0.0,
-                                              frame_shift=10)
-
-    n_frames = fbank.shape[0]
-
-    p = target_length - n_frames
-
-    # cut and pad
-    if p > 0:
-        m = torch.nn.ZeroPad2d((0, 0, 0, p))
-        fbank = m(fbank)
-    elif p < 0:
-        fbank = fbank[0:target_length, :]
-
-
-
-
-    # SpecAug, not do for eval set
-
-    fbank = torch.transpose(fbank, 0, 1)
-    # this is just to satisfy new torchaudio version, which only accept [1, freq, time]
-    fbank = fbank.unsqueeze(0)
-
-
-
-    # squeeze it back, it is just a trick to satisfy new torchaudio version
-    fbank = fbank.squeeze(0)
-    fbank = torch.transpose(fbank, 0, 1)
-
-
-
-    # normalize the input for both training and test
-    if not skip_norm:
-        fbank = (fbank - norm_mean) / (norm_std * 2)
-    # skip normalization the input if you are trying to get the normalization stats.
-    else:
-        pass
-
-
-    audio = fbank
-
-    return audio
 
 do_sample=False
 num_beams=3
@@ -237,37 +128,19 @@ max_length=30
 
 
 
-def inference(image, audio, video, task_type, instruction):
+def inference(image, task_type, instruction):
 
-    if task_type == 'Image Captioning':
-        text = ['']
-        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-    elif task_type == 'Video Captioning':
-        text = ['']
-        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_video_caption,strict=False)
-    elif task_type == 'Audio Captioning':
+
+    if task_type == 'Video Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_audio_caption,strict=False)
-    elif task_type == 'Visual Question Answering':
-        question = instruction+'?'+special_answer_token
-        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_vqa,strict=False)
-    elif task_type == 'Visual Question Answering':
-        question = instruction+'?'+special_answer_token
-        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_video_qa,strict=False)
+        model = MODEL
     else:
         raise NotImplemented
 
-    if "Video" in task_type:
-        image = read_video(image)
-    elif "Audio" in task_type:
-        image = read_audio(image)
-    else:
-        image = transform(image)
-        image = image.to(device,non_blocking=True).unsqueeze(0)
+    image = read_video(image)
+
+
 
 
 
@@ -290,25 +163,18 @@ def inference(image, audio, video, task_type, instruction):
     return response
 
 
-inputs = [gr.inputs.Image(type='pil'), gr.Audio(source="upload", type="filepath"), gr.Video(source="upload", type="filepath"), gr.inputs.Radio(choices=['Image Captioning', 'Video Captioning', 'Audio Captioning', "Visual Question Answering", "Visual Grounding", "General", "General Video"], type="value", default="Image Captioning", label="Task"), gr.inputs.Textbox(lines=1, label="Instruction")]
+inputs = [gr.Video(source="upload", type="filepath"), gr.inputs.Radio(choices=['Video Captioning'], type="value", default="Video Captioning", label="Task"), gr.inputs.Textbox(lines=1, label="Instruction")]
 outputs = ['text']
 examples = [
-    ['examples/images/soccer.jpg', None, None, 'Image Captioning', None],
-    ['examples/images/ski.jpg', None, None, 'Visual Question Answering', 'what does the woman wearing black do?'],
-    ['examples/images/banana.jpg', None, None, 'Image Captioning', None],
-    ['examples/images/skateboard.jpg', None, None, 'Visual Question Answering', 'what is on top of the skateboard?'],
-    ['examples/images/baseball.jpg', None, None, 'Image Captioning', None],
-    [None, None, 'examples/videos/video7014.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7017.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7019.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, 'examples/audios/6cS0FsUM-cQ.wav', None, 'Audio Captioning', None],
-    [None, 'examples/audios/AJtNitYMa1I.wav', None, 'Audio Captioning', None],
+    ['examples/videos/video7014.mp4', 'Video Captioning', None],
+    ['examples/videos/video7017.mp4', 'Video Captioning', None],
+    ['examples/videos/video7019.mp4', 'Video Captioning', None],
+    ['examples/videos/video7021.mp4', 'Video Captioning', None],
+    ['examples/videos/video7021.mp4', 'Video Captioning', None],
    ]
 
-title = "eP-ALM"
-description = "Gradio Demo for eP-ALM: "
+title = "eP-ALM for Video-Text tasks"
+description = "Gradio Demo for eP-ALM. For this demo, we use 2.7B OPT. As the model runs on CPUs and float16 mixed precision is not supported on CPUs, the generation can take up to 2 mins."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2303.11403' target='_blank'>Paper</a> | <a href='https://github.com/mshukor/eP-ALM' target='_blank'>Github Repo</a></p>"
 
 io = gr.Interface(fn=inference, inputs=inputs, outputs=outputs,
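
After this commit, app.py serves a single task: video captioning with the MSRVTT-finetuned ePALM checkpoint loaded once at startup. For reference, below is a minimal sketch of the resulting Gradio wiring, reusing the same gradio 3.x-era calls that appear in the diff. `caption_video` is a hypothetical stand-in for the real `inference()`, which reads 16 frames with decord, normalizes them, and decodes a caption with the OPT-based ePALM model.

```python
# Minimal sketch of the video-only demo left by this commit (stand-in inference).
import gradio as gr

def caption_video(video_path, task_type, instruction):
    # Hypothetical stand-in for inference(): the real app builds ePALM, loads
    # checkpoints/float32/ePALM_video_caption_msrvtt/checkpoint_best.pth, reads
    # frames via VIDEO_READER_FUNCS['decord'], and generates a caption with OPT.
    if task_type != 'Video Captioning':
        raise NotImplementedError(task_type)
    return f"caption for {video_path}"

# Same input/output layout as the committed app.py (gradio 3.x-era API).
inputs = [
    gr.Video(source="upload", type="filepath"),
    gr.inputs.Radio(choices=['Video Captioning'], type="value",
                    default="Video Captioning", label="Task"),
    gr.inputs.Textbox(lines=1, label="Instruction"),
]

io = gr.Interface(fn=caption_video, inputs=inputs, outputs=['text'])

if __name__ == "__main__":
    io.launch()
```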