Update app.py
app.py
CHANGED
@@ -246,22 +246,27 @@ def inference(image, audio, video, task_type, instruction):
     elif task_type == 'Video Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-
+        msg = MODEL.load_state_dict(state_dict_video_caption,strict=False)
+        model = MODEL
     elif task_type == 'Audio Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-
+        msg = MODEL.load_state_dict(state_dict_audio_caption,strict=False)
+        model = MODEL
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-
+        msg = MODEL.load_state_dict(state_dict_vqa,strict=False)
+        model = MODEL
+        print(msg)
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-
+        msg = MODEL.load_state_dict(state_dict_video_qa,strict=False)
+        model = MODEL
     else:
         raise NotImplemented
-
+
     if "Video" in task_type:
         image = read_video(image)
     elif "Audio" in task_type:
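The change loads a task-specific checkpoint into the shared MODEL with load_state_dict(..., strict=False) before inference, then aliases it as model. Below is a minimal, self-contained sketch of that checkpoint-swapping pattern; ToyModel, select_model, and the in-memory state dicts are placeholders for illustration and are not part of app.py.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for the real MODEL in app.py."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

MODEL = ToyModel()

# One checkpoint per task, built in memory here; in app.py these would be
# the pre-loaded state_dict_video_caption, state_dict_audio_caption, etc.
state_dicts = {
    'Video Captioning': ToyModel().state_dict(),
    'Audio Captioning': ToyModel().state_dict(),
}

def select_model(task_type):
    # strict=False mirrors the diff: any missing or unexpected keys are
    # returned in `msg` (an _IncompatibleKeys namedtuple) rather than raising.
    msg = MODEL.load_state_dict(state_dicts[task_type], strict=False)
    print(msg)  # "<All keys matched successfully>" when keys line up
    return MODEL

model = select_model('Video Captioning')

Because the same MODEL instance is reused for every request, swapping only the weights avoids re-instantiating the network per task; strict=False tolerates task-specific heads that are absent from a given checkpoint.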