mshukor committed on
Commit
02eb18e
·
1 Parent(s): ce7469b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -242,27 +242,27 @@ def inference(image, audio, video, task_type, instruction):
242
  if task_type == 'Image Captioning':
243
  text = ['']
244
  text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
245
- model = model_caption.clone()
246
  elif task_type == 'Video Captioning':
247
  text = ['']
248
  text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
249
  model_caption = model_caption.load_state_dict(state_dict_video_caption,strict=False)
250
- model = model_caption.clone()
251
  elif task_type == 'Audio Captioning':
252
  text = ['']
253
  text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
254
  model_caption = model_caption.load_state_dict(state_dict_audio_caption,strict=False)
255
- model = model_caption.clone()
256
  elif task_type == 'Visual Question Answering':
257
  question = instruction+'?'+special_answer_token
258
  text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
259
  model_caption = model_caption.load_state_dict(state_dict_vqa,strict=False)
260
- model = model_caption.clone()
261
  elif task_type == 'Visual Question Answering':
262
  question = instruction+'?'+special_answer_token
263
  text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
264
  model_caption = model_caption.load_state_dict(state_dict_video_qa,strict=False)
265
- model = model_caption.clone()
266
  else:
267
  raise NotImplemented
268
 
 
242
  if task_type == 'Image Captioning':
243
  text = ['']
244
  text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
245
+ model = model_caption
246
  elif task_type == 'Video Captioning':
247
  text = ['']
248
  text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
249
  model_caption = model_caption.load_state_dict(state_dict_video_caption,strict=False)
250
+ model = model_caption
251
  elif task_type == 'Audio Captioning':
252
  text = ['']
253
  text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
254
  model_caption = model_caption.load_state_dict(state_dict_audio_caption,strict=False)
255
+ model = model_caption
256
  elif task_type == 'Visual Question Answering':
257
  question = instruction+'?'+special_answer_token
258
  text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
259
  model_caption = model_caption.load_state_dict(state_dict_vqa,strict=False)
260
+ model = model_caption
261
  elif task_type == 'Visual Question Answering':
262
  question = instruction+'?'+special_answer_token
263
  text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
264
  model_caption = model_caption.load_state_dict(state_dict_video_qa,strict=False)
265
+ model = model_caption
266
  else:
267
  raise NotImplemented
268