mskov committed
Commit 0aa0d62 · 1 Parent(s): eb23c5b

Update app.py

Files changed (1)
  1. app.py +7 -15
app.py CHANGED
@@ -1,9 +1,8 @@
 from typing import Dict
-
+import torch
 import gradio as gr
 import whisper
 from whisper.tokenizer import get_tokenizer
-
 import classify
 
 model_cache = {}
@@ -11,22 +10,15 @@ model_cache = {}
 
 def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
     class_names = class_names.split(",")
-    print("model name: ", model_name)
+    # Specify the path to your fine-tuned model and configuration
+    model_path = "mskov/whisper-small-esc50"
 
-    # Specify the path to the "whisper-small-esc50" model
-    model_path = "mskov/whisper-small-esc50"
+    # Load the model
+    model = whisper.Whisper.load(model_path)
 
-    if model_name not in model_cache:
-        model = whisper.load_model(model_path)
-        print("whisper model not in model_cache ", model)
-        model_cache[model_name] = model
-    else:
-        model = model_cache[model_name]
-        print("model is in cache ", model)
+    # Load the tokenizer
+    tokenizer = whisper.tokenizer.get_tokenizer(multilingual=".en" not in model.config.name)
 
-    # Rest of your code remains the same
-    tokenizer = get_tokenizer(multilingual=".en" not in model_name)
-
     internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
         model=model,
         class_names=class_names,
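
The added code path calls `whisper.Whisper.load(model_path)` and reads `model.config.name`, neither of which appears to be part of the stock openai-whisper API. For reference, below is a minimal sketch (not the committed code) of loading and caching the fine-tuned checkpoint with the standard `whisper.load_model` and `get_tokenizer` calls. It assumes the Hub repo `mskov/whisper-small-esc50` hosts an openai-whisper style checkpoint and that `huggingface_hub` is installed; the filename `model.pt` is hypothetical.

from typing import Dict

import whisper
from huggingface_hub import hf_hub_download  # assumption: checkpoint is fetched from the HF Hub
from whisper.tokenizer import get_tokenizer

model_cache: Dict[str, whisper.Whisper] = {}


def load_finetuned_whisper(repo_id: str, filename: str = "model.pt") -> whisper.Whisper:
    # Load the checkpoint once per repo id and reuse it, mirroring the
    # model_cache behaviour of the previous revision.
    if repo_id not in model_cache:
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # whisper.load_model accepts a local checkpoint path as well as an
        # official model name such as "small".
        model_cache[repo_id] = whisper.load_model(checkpoint_path)
    return model_cache[repo_id]


repo_id = "mskov/whisper-small-esc50"
model = load_finetuned_whisper(repo_id)
# English-only checkpoints (names containing ".en") use the monolingual tokenizer.
tokenizer = get_tokenizer(multilingual=".en" not in repo_id)

If the repo instead stores the model in Hugging Face transformers format, `transformers.WhisperForConditionalGeneration.from_pretrained(repo_id)` would be the more natural loader, but the resulting object is not interchangeable with the openai-whisper model that `classify.calculate_internal_lm_average_logprobs` is called with here.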