Update app.py
app.py CHANGED
@@ -63,7 +63,7 @@ vision_model_name = 'vit_base_patch16_224'
 start_layer_idx = 19
 end_layer_idx = 31
 low_cpu = True
-model = ePALM(opt_model_name=text_model,
+MODEL = ePALM(opt_model_name=text_model,
 vision_model_name=vision_model_name,
 use_vis_prefix=True,
 start_layer_idx=start_layer_idx,
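The first hunk renames the module-level model to MODEL, so the Gradio callback further down can reuse one shared instance instead of rebuilding it per request. A minimal sketch of that startup path, with the constructor call abridged to the arguments visible in the diff; the import path, the text_model value, the device selection, and the eval() call are assumptions, not part of the commit:

import torch
from models.epalm import ePALM  # import path is an assumption; the diff only shows the call site

# Values visible in or implied by the surrounding app.py; text_model is an assumption.
text_model = 'facebook/opt-2.7b'
vision_model_name = 'vit_base_patch16_224'
start_layer_idx, end_layer_idx, low_cpu = 19, 31, True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the backbone once at module scope; every Gradio request reuses this single instance.
MODEL = ePALM(opt_model_name=text_model,
              vision_model_name=vision_model_name,
              use_vis_prefix=True,
              start_layer_idx=start_layer_idx,
              end_layer_idx=end_layer_idx,
              low_cpu=low_cpu)
MODEL.to(device)  # move weights to the target device once, before serving requests
MODEL.eval()      # inference-only app; eval() is an addition here, not part of the commit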
@@ -73,15 +73,15 @@ model = ePALM(opt_model_name=text_model,
 low_cpu=low_cpu
 )
 print("Model Built")
-model.to(device)
+MODEL.to(device)
 
 checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
 # checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict = checkpoint['model']
-msg = model.load_state_dict(state_dict,strict=False)
+msg = MODEL.load_state_dict(state_dict,strict=False)
 
-model.bfloat16()
+MODEL.bfloat16()
 
 # ###### VQA
 # config = 'configs/image/ePALM_vqa.yaml'
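The checkpoint hunk is the standard torch pattern of loading to CPU and copying weights in with strict=False. A short sketch of the same sequence, assuming MODEL is the module built above; the key-count printout and the stated reason for strict=False are illustrative additions, not part of the commit:

import torch

checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')  # load to CPU first to avoid a GPU memory spike
state_dict = checkpoint['model']

# strict=False tolerates keys present on only one side (for example, frozen backbone
# weights that the fine-tuned checkpoint does not store). The return value is not the
# model: it is an IncompatibleKeys named tuple with .missing_keys and .unexpected_keys.
msg = MODEL.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(msg.missing_keys)}, unexpected keys: {len(msg.unexpected_keys)}")

# Cast floating-point parameters and buffers to bfloat16 in place to roughly halve memory.
# bfloat16 keeps float32's exponent range, so it is less overflow-prone than float16.
MODEL.bfloat16()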
@@ -242,22 +242,23 @@ def inference(image, audio, video, task_type, instruction):
     if task_type == 'Image Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
+        model = MODEL
     elif task_type == 'Video Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model =
+        model = MODEL.load_state_dict(state_dict_video_caption,strict=False)
     elif task_type == 'Audio Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model =
+        model = MODEL.load_state_dict(state_dict_audio_caption,strict=False)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model =
+        model = MODEL.load_state_dict(state_dict_vqa,strict=False)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model =
+        model = MODEL.load_state_dict(state_dict_video_qa,strict=False)
     else:
         raise NotImplemented
 
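The last hunk swaps task-specific weights inside the Gradio callback. Two PyTorch/Python details matter when reusing this pattern: nn.Module.load_state_dict returns an IncompatibleKeys report rather than the module, so MODEL itself is what should be used after the call, and NotImplemented is a constant rather than an exception class (NotImplementedError is the one to raise). A hedged sketch of an equivalent table-driven dispatch, not the code in this commit:

# The state_dict_* objects are assumed to be loaded elsewhere in app.py, as in the diff;
# the 'Video Question Answering' key is an assumption for the second
# 'Visual Question Answering' branch, which loads state_dict_video_qa.
TASK_STATE_DICTS = {
    'Image Captioning': state_dict,               # default captioning weights loaded at startup
    'Video Captioning': state_dict_video_caption,
    'Audio Captioning': state_dict_audio_caption,
    'Visual Question Answering': state_dict_vqa,
    'Video Question Answering': state_dict_video_qa,
}

def select_model(task_type):
    try:
        task_sd = TASK_STATE_DICTS[task_type]
    except KeyError:
        # NotImplementedError is the exception class; bare NotImplemented is just a constant.
        raise NotImplementedError(f"unknown task type: {task_type}")
    # load_state_dict mutates MODEL in place and returns an IncompatibleKeys report,
    # so the module reference itself (MODEL), not the return value, is what to use afterwards.
    MODEL.load_state_dict(task_sd, strict=False)
    return MODEL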