Antonio
commited on
Commit
·
f50e742
1
Parent(s):
357c13c
Cleaned Names
Browse files
app.py
CHANGED
@@ -184,14 +184,18 @@ decision_frameworks = {
|
|
184 |
def predict(video_file, video_model_name, audio_model_name, framework_name):
|
185 |
|
186 |
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
|
187 |
-
|
|
|
|
|
|
|
188 |
|
189 |
model_id = "facebook/wav2vec2-large"
|
190 |
config = AutoConfig.from_pretrained(model_id, num_labels=6)
|
191 |
audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
|
192 |
audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
|
193 |
-
|
194 |
-
|
|
|
195 |
|
196 |
delete_directory_path = "./temp/"
|
197 |
|
@@ -219,8 +223,8 @@ def predict(video_file, video_model_name, audio_model_name, framework_name):
|
|
219 |
|
220 |
inputs = [
|
221 |
gr.File(label="Upload Video"),
|
222 |
-
gr.Dropdown(["
|
223 |
-
gr.Dropdown(["
|
224 |
gr.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
|
225 |
]
|
226 |
|
|
|
184 |
def predict(video_file, video_model_name, audio_model_name, framework_name):
|
185 |
|
186 |
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
|
187 |
+
if video_model_name == "60% Accuracy":
|
188 |
+
video_model = torch.load("video_model_60_acc.pth", map_location=torch.device('cpu'))
|
189 |
+
elif video_model_name == "80% Accuracy":
|
190 |
+
video_model = torch.load("video_model_80_acc.pth", map_location=torch.device('cpu'))
|
191 |
|
192 |
model_id = "facebook/wav2vec2-large"
|
193 |
config = AutoConfig.from_pretrained(model_id, num_labels=6)
|
194 |
audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
|
195 |
audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
|
196 |
+
if audio_model_name == "60% Accuracy":
|
197 |
+
audio_model.load_state_dict(torch.load("audio_model_state_dict_6e.pth", map_location=torch.device('cpu')))
|
198 |
+
audio_model.eval()
|
199 |
|
200 |
delete_directory_path = "./temp/"
|
201 |
|
|
|
223 |
|
224 |
inputs = [
|
225 |
gr.File(label="Upload Video"),
|
226 |
+
gr.Dropdown(["60% Accuracy", "80% Accuracy"], label="Select Video Model"),
|
227 |
+
gr.Dropdown(["60% Accuracy"], label="Select Audio Model"),
|
228 |
gr.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
|
229 |
]
|
230 |
|