mskov commited on
Commit
24a2384
·
1 Parent(s): 0aa0d62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -4,20 +4,21 @@ import gradio as gr
4
  import whisper
5
  from whisper.tokenizer import get_tokenizer
6
  import classify
 
 
7
 
8
  model_cache = {}
9
 
10
 
11
  def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
12
  class_names = class_names.split(",")
13
- # Specify the path to your fine-tuned model and configuration
14
- model_path = "mskov/whisper-small-esc50"
15
-
16
- # Load the model
17
- model = whisper.Whisper.load(model_path)
18
-
19
- # Load the tokenizer
20
- tokenizer = whisper.tokenizer.get_tokenizer(multilingual=".en" not in model.config.name)
21
 
22
  internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
23
  model=model,
@@ -36,6 +37,7 @@ def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Di
36
  return {class_name: score for class_name, score in zip(class_names, scores)}
37
 
38
 
 
39
  def main():
40
  CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking],[popping],[sneezing],[sigh],[slurping],[mouth sounds],[clearing thoat],"
41
  AUDIO_PATHS = [
@@ -69,7 +71,7 @@ def main():
69
  gr.Audio(label="Input Audio",show_label=False,source="microphone",type="filepath"),
70
  gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
71
  gr.Radio(
72
- choices=["whisper-small-esc50"],
73
  value="small",
74
  label="Model Name",
75
  ),
 
4
  import whisper
5
  from whisper.tokenizer import get_tokenizer
6
  import classify
7
+ from transformers import AutoFeatureExtractor, WhisperForAudioClassification
8
+ from datasets import load_dataset
9
 
10
  model_cache = {}
11
 
12
 
13
  def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
14
  class_names = class_names.split(",")
15
+ tokenizer = get_tokenizer(multilingual=".en" not in model_name)
16
+
17
+ if model_name not in model_cache:
18
+ model = whisper.load_model(model_name)
19
+ model_cache[model_name] = model
20
+ else:
21
+ model = model_cache[model_name]
 
22
 
23
  internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
24
  model=model,
 
37
  return {class_name: score for class_name, score in zip(class_names, scores)}
38
 
39
 
40
+
41
  def main():
42
  CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking],[popping],[sneezing],[sigh],[slurping],[mouth sounds],[clearing thoat],"
43
  AUDIO_PATHS = [
 
71
  gr.Audio(label="Input Audio",show_label=False,source="microphone",type="filepath"),
72
  gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
73
  gr.Radio(
74
+ choices=["tiny", "base", "small", "medium", "large"],
75
  value="small",
76
  label="Model Name",
77
  ),