mshukor committed
Commit ea7ec0b
1 Parent(s): d5f4cd4

Update app.py

Files changed (1): app.py (+9 -8)
app.py CHANGED
@@ -63,7 +63,7 @@ vision_model_name = 'vit_base_patch16_224'
 start_layer_idx = 19
 end_layer_idx = 31
 low_cpu = True
-model = ePALM(opt_model_name=text_model,
+MODEL = ePALM(opt_model_name=text_model,
               vision_model_name=vision_model_name,
               use_vis_prefix=True,
               start_layer_idx=start_layer_idx,
@@ -73,15 +73,15 @@ model = ePALM(opt_model_name=text_model,
               low_cpu=low_cpu
               )
 print("Model Built")
-model.to(device)
+MODEL.to(device)
 
 checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
 # checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict = checkpoint['model']
-msg = model.load_state_dict(state_dict,strict=False)
+msg = MODEL.load_state_dict(state_dict,strict=False)
 
-model.bfloat16()
+MODEL.bfloat16()
 
 # ###### VQA
 # config = 'configs/image/ePALM_vqa.yaml'
@@ -242,22 +242,23 @@ def inference(image, audio, video, task_type, instruction):
     if task_type == 'Image Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
+        model = MODEL
     elif task_type == 'Video Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_video_caption,strict=False)
+        model = MODEL.load_state_dict(state_dict_video_caption,strict=False)
     elif task_type == 'Audio Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_audio_caption,strict=False)
+        model = MODEL.load_state_dict(state_dict_audio_caption,strict=False)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_vqa,strict=False)
+        model = MODEL.load_state_dict(state_dict_vqa,strict=False)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model.load_state_dict(state_dict_video_qa,strict=False)
+        model = MODEL.load_state_dict(state_dict_video_qa,strict=False)
     else:
         raise NotImplemented
 
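
Note on the per-task branches kept by this commit: torch.nn.Module.load_state_dict() returns a NamedTuple of missing_keys and unexpected_keys, not the module itself, so `model = MODEL.load_state_dict(...)` leaves `model` bound to that key report rather than the network. The snippet below is only a minimal sketch of the presumably intended pattern, reusing the MODEL global and the state_dict_* variables already defined in app.py; it is an illustration, not the author's code.

# Sketch only (assumed pattern): swap task-specific weights in place,
# then keep `model` pointing at the network itself.
if task_type == 'Video Captioning':
    msg = MODEL.load_state_dict(state_dict_video_caption, strict=False)  # msg lists missing/unexpected keys
elif task_type == 'Audio Captioning':
    msg = MODEL.load_state_dict(state_dict_audio_caption, strict=False)
elif task_type == 'Visual Question Answering':
    msg = MODEL.load_state_dict(state_dict_vqa, strict=False)
model = MODEL  # the module, not the load_state_dict() return value

Two smaller points outside the scope of this commit: the second `elif task_type == 'Visual Question Answering':` branch is unreachable because its condition duplicates the one above it (though it loads state_dict_video_qa, so it was presumably meant for video question answering), and `raise NotImplemented` raises the NotImplemented constant rather than NotImplementedError, which produces a TypeError in Python 3.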