Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,20 +4,21 @@ import gradio as gr
|
|
4 |
import whisper
|
5 |
from whisper.tokenizer import get_tokenizer
|
6 |
import classify
|
|
|
|
|
7 |
|
8 |
model_cache = {}
|
9 |
|
10 |
|
11 |
def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
|
12 |
class_names = class_names.split(",")
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
tokenizer = whisper.tokenizer.get_tokenizer(multilingual=".en" not in model.config.name)
|
21 |
|
22 |
internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
|
23 |
model=model,
|
@@ -36,6 +37,7 @@ def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Di
|
|
36 |
return {class_name: score for class_name, score in zip(class_names, scores)}
|
37 |
|
38 |
|
|
|
39 |
def main():
|
40 |
CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking],[popping],[sneezing],[sigh],[slurping],[mouth sounds],[clearing thoat],"
|
41 |
AUDIO_PATHS = [
|
@@ -69,7 +71,7 @@ def main():
|
|
69 |
gr.Audio(label="Input Audio",show_label=False,source="microphone",type="filepath"),
|
70 |
gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
|
71 |
gr.Radio(
|
72 |
-
choices=["
|
73 |
value="small",
|
74 |
label="Model Name",
|
75 |
),
|
|
|
4 |
import whisper
|
5 |
from whisper.tokenizer import get_tokenizer
|
6 |
import classify
|
7 |
+
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
|
8 |
+
from datasets import load_dataset
|
9 |
|
10 |
model_cache = {}
|
11 |
|
12 |
|
13 |
def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
|
14 |
class_names = class_names.split(",")
|
15 |
+
tokenizer = get_tokenizer(multilingual=".en" not in model_name)
|
16 |
+
|
17 |
+
if model_name not in model_cache:
|
18 |
+
model = whisper.load_model(model_name)
|
19 |
+
model_cache[model_name] = model
|
20 |
+
else:
|
21 |
+
model = model_cache[model_name]
|
|
|
22 |
|
23 |
internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
|
24 |
model=model,
|
|
|
37 |
return {class_name: score for class_name, score in zip(class_names, scores)}
|
38 |
|
39 |
|
40 |
+
|
41 |
def main():
|
42 |
CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking],[popping],[sneezing],[sigh],[slurping],[mouth sounds],[clearing thoat],"
|
43 |
AUDIO_PATHS = [
|
|
|
71 |
gr.Audio(label="Input Audio",show_label=False,source="microphone",type="filepath"),
|
72 |
gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
|
73 |
gr.Radio(
|
74 |
+
choices=["tiny", "base", "small", "medium", "large"],
|
75 |
value="small",
|
76 |
label="Model Name",
|
77 |
),
|