Update app.py
Browse files
app.py
CHANGED
@@ -242,27 +242,27 @@ def inference(image, audio, video, task_type, instruction):
|
|
# Select the prompt text and load the task-specific checkpoint into the
# shared captioning model, then point `model` at it for generation.
# NOTE(review): task_type values presumably come from a UI dropdown — confirm
# the labels below match the choices offered to the user.
if task_type == 'Image Captioning':
    text = ['']
    text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
    model = model_caption
elif task_type == 'Video Captioning':
    text = ['']
    text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
    # load_state_dict mutates the module in place and returns an
    # _IncompatibleKeys namedtuple — assigning the result back (as the
    # original code did) replaces the model with that namedtuple.
    model_caption.load_state_dict(state_dict_video_caption, strict=False)
    model = model_caption
elif task_type == 'Audio Captioning':
    text = ['']
    text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
    model_caption.load_state_dict(state_dict_audio_caption, strict=False)
    model = model_caption
elif task_type == 'Visual Question Answering':
    question = instruction + '?' + special_answer_token
    text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
    model_caption.load_state_dict(state_dict_vqa, strict=False)
    model = model_caption
elif task_type == 'Video Question Answering':
    # This branch was a duplicate 'Visual Question Answering' test and
    # therefore unreachable; renamed so state_dict_video_qa can be selected.
    question = instruction + '?' + special_answer_token
    text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
    model_caption.load_state_dict(state_dict_video_qa, strict=False)
    model = model_caption
else:
    # NotImplemented is a comparison sentinel, not an exception class;
    # raising it is a TypeError in Python 3.
    raise NotImplementedError(f"unsupported task_type: {task_type}")
|
# Select the prompt text and load the task-specific checkpoint into the
# shared captioning model, then point `model` at it for generation.
# NOTE(review): task_type values presumably come from a UI dropdown — confirm
# the labels below match the choices offered to the user.
if task_type == 'Image Captioning':
    text = ['']
    text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
    model = model_caption
elif task_type == 'Video Captioning':
    text = ['']
    text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
    # load_state_dict mutates the module in place and returns an
    # _IncompatibleKeys namedtuple — assigning the result back (as the
    # original code did) replaces the model with that namedtuple.
    model_caption.load_state_dict(state_dict_video_caption, strict=False)
    model = model_caption
elif task_type == 'Audio Captioning':
    text = ['']
    text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
    model_caption.load_state_dict(state_dict_audio_caption, strict=False)
    model = model_caption
elif task_type == 'Visual Question Answering':
    question = instruction + '?' + special_answer_token
    text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
    model_caption.load_state_dict(state_dict_vqa, strict=False)
    model = model_caption
elif task_type == 'Video Question Answering':
    # This branch was a duplicate 'Visual Question Answering' test and
    # therefore unreachable; renamed so state_dict_video_qa can be selected.
    question = instruction + '?' + special_answer_token
    text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
    model_caption.load_state_dict(state_dict_video_qa, strict=False)
    model = model_caption
else:
    # NotImplemented is a comparison sentinel, not an exception class;
    # raising it is a TypeError in Python 3.
    raise NotImplementedError(f"unsupported task_type: {task_type}")