Antonio committed on
Commit
f50e742
·
1 Parent(s): 357c13c

Cleaned Names

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -184,14 +184,18 @@ decision_frameworks = {
184
  def predict(video_file, video_model_name, audio_model_name, framework_name):
185
 
186
  image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
187
- video_model = torch.load('./' + video_model_name, map_location=torch.device('cpu'))
 
 
 
188
 
189
  model_id = "facebook/wav2vec2-large"
190
  config = AutoConfig.from_pretrained(model_id, num_labels=6)
191
  audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
192
  audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
193
- audio_model.load_state_dict(torch.load('./' + audio_model_name, map_location=torch.device('cpu')))
194
- audio_model.eval()
 
195
 
196
  delete_directory_path = "./temp/"
197
 
@@ -219,8 +223,8 @@ def predict(video_file, video_model_name, audio_model_name, framework_name):
219
 
220
  inputs = [
221
  gr.File(label="Upload Video"),
222
- gr.Dropdown(["video_model_60_acc.pth", "video_model_80_acc.pth"], label="Select Video Model"),
223
- gr.Dropdown(["audio_model_state_dict_6e.pth"], label="Select Audio Model"),
224
  gr.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
225
  ]
226
 
 
184
  def predict(video_file, video_model_name, audio_model_name, framework_name):
185
 
186
  image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
187
+ if video_model_name == "60% Accuracy":
188
+ video_model = torch.load("video_model_60_acc.pth", map_location=torch.device('cpu'))
189
+ elif video_model_name == "80% Accuracy":
190
+ video_model = torch.load("video_model_80_acc.pth", map_location=torch.device('cpu'))
191
 
192
  model_id = "facebook/wav2vec2-large"
193
  config = AutoConfig.from_pretrained(model_id, num_labels=6)
194
  audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
195
  audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
196
+ if audio_model_name == "60% Accuracy":
197
+ audio_model.load_state_dict(torch.load("audio_model_state_dict_6e.pth", map_location=torch.device('cpu')))
198
+ audio_model.eval()
199
 
200
  delete_directory_path = "./temp/"
201
 
 
223
 
224
  inputs = [
225
  gr.File(label="Upload Video"),
226
+ gr.Dropdown(["60% Accuracy", "80% Accuracy"], label="Select Video Model"),
227
+ gr.Dropdown(["60% Accuracy"], label="Select Audio Model"),
228
  gr.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
229
  ]
230