mshukor committed
Commit d5f4cd4 · Parent: 02eb18e

Update app.py

Files changed (1)
  app.py: +8 -13
app.py CHANGED
@@ -63,7 +63,7 @@ vision_model_name = 'vit_base_patch16_224'
 start_layer_idx = 19
 end_layer_idx = 31
 low_cpu = True
-model_caption = ePALM(opt_model_name=text_model,
+model = ePALM(opt_model_name=text_model,
     vision_model_name=vision_model_name,
     use_vis_prefix=True,
     start_layer_idx=start_layer_idx,
@@ -73,15 +73,15 @@ model_caption = ePALM(opt_model_name=text_model,
     low_cpu=low_cpu
     )
 print("Model Built")
-model_caption.to(device)
+model.to(device)
 
 checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
 # checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict = checkpoint['model']
-msg = model_caption.load_state_dict(state_dict,strict=False)
+msg = model.load_state_dict(state_dict,strict=False)
 
-model_caption.bfloat16()
+model.bfloat16()
 
 # ###### VQA
 # config = 'configs/image/ePALM_vqa.yaml'
@@ -242,27 +242,22 @@ def inference(image, audio, video, task_type, instruction):
     if task_type == 'Image Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model_caption
     elif task_type == 'Video Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model_caption = model_caption.load_state_dict(state_dict_video_caption,strict=False)
-        model = model_caption
+        model = model.load_state_dict(state_dict_video_caption,strict=False)
     elif task_type == 'Audio Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model_caption = model_caption.load_state_dict(state_dict_audio_caption,strict=False)
-        model = model_caption
+        model = model.load_state_dict(state_dict_audio_caption,strict=False)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model_caption = model_caption.load_state_dict(state_dict_vqa,strict=False)
-        model = model_caption
+        model = model.load_state_dict(state_dict_vqa,strict=False)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model_caption = model_caption.load_state_dict(state_dict_video_qa,strict=False)
-        model = model_caption
+        model = model.load_state_dict(state_dict_video_qa,strict=False)
     else:
         raise NotImplemented
 
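
The renamed startup path (build ePALM, load the float32 caption checkpoint, move to the accelerator, cast to bfloat16) relies on the fact that `Module.load_state_dict`, `Module.to`, and `Module.bfloat16` all operate on the module in place, with the latter two also returning it. Below is a minimal sketch of that sequence with a stand-in `nn.Linear` instead of ePALM, since the real constructor and checkpoint live in the Space's repo.

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Stand-in for ePALM(...); any nn.Module behaves the same for this sequence.
model = nn.Linear(8, 8)

# In app.py the state dict comes from torch.load(checkpoint_path, map_location='cpu');
# a toy dict stands in for it here. strict=False tolerates partial key overlap.
state_dict = {'weight': torch.zeros(8, 8), 'bias': torch.zeros(8)}
msg = model.load_state_dict(state_dict, strict=False)

# .to() and .bfloat16() mutate in place and return the module, so the two
# separate statements in app.py could equally be chained like this.
model = model.to(device).bfloat16()
print(next(model.parameters()).dtype)  # torch.bfloat16
```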
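
One behaviour worth noting for the new per-task branches: `nn.Module.load_state_dict` copies the weights into the module in place and returns a namedtuple of `missing_keys` and `unexpected_keys`, not the module itself, so rebinding its result to `model` leaves `model` pointing at that tuple for the rest of the request. A small sketch of the swap written so the module reference is preserved; the helper name `swap_task_weights` is illustrative, not something defined in app.py.

```python
import torch
import torch.nn as nn

def swap_task_weights(model: nn.Module, state_dict: dict) -> nn.Module:
    # load_state_dict works in place; strict=False allows keys that are
    # missing on either side (e.g. frozen backbone weights kept from startup).
    msg = model.load_state_dict(state_dict, strict=False)
    print('missing:', msg.missing_keys, 'unexpected:', msg.unexpected_keys)
    return model  # the same module object, now carrying the task weights

# Toy usage with a Linear layer standing in for ePALM.
toy = nn.Linear(4, 4)
toy = swap_task_weights(toy, {'weight': torch.zeros(4, 4)})
assert isinstance(toy, nn.Module)  # still a module, not a key-report tuple
```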
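
Two smaller observations on the dispatch chain: the two `elif task_type == 'Visual Question Answering':` tests are identical, so the branch that loads `state_dict_video_qa` is unreachable, and `raise NotImplemented` raises a `TypeError` in Python 3 because `NotImplemented` is a constant rather than an exception class (`NotImplementedError` is the intended one). A hedged sketch of the same selection as a lookup table follows; the function name and signature are illustrative rather than part of app.py.

```python
from typing import Mapping
import torch.nn as nn

def prepare_model(model: nn.Module, task_type: str,
                  task_weights: Mapping[str, dict]) -> nn.Module:
    # 'Image Captioning' keeps the weights loaded at startup; every other task
    # swaps in its own state dict with strict=False, as in the committed branches.
    if task_type == 'Image Captioning':
        return model
    if task_type not in task_weights:
        raise NotImplementedError(task_type)  # not `raise NotImplemented`
    model.load_state_dict(task_weights[task_type], strict=False)
    return model
```

With app.py's module-level objects, `task_weights` would map 'Video Captioning' to `state_dict_video_caption`, 'Audio Captioning' to `state_dict_audio_caption`, 'Visual Question Answering' to `state_dict_vqa`, and, assuming that was the intended label for the unreachable branch, 'Video Question Answering' to `state_dict_video_qa`.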