Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
from typing import Dict
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import whisper
|
5 |
from whisper.tokenizer import get_tokenizer
|
6 |
-
|
7 |
import classify
|
8 |
|
9 |
model_cache = {}
|
@@ -11,22 +10,15 @@ model_cache = {}
|
|
11 |
|
12 |
def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
|
13 |
class_names = class_names.split(",")
|
14 |
-
|
|
|
15 |
|
16 |
-
#
|
17 |
-
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
print("whisper model not in model_cache ", model)
|
22 |
-
model_cache[model_name] = model
|
23 |
-
else:
|
24 |
-
model = model_cache[model_name]
|
25 |
-
print("model is in cache ", model)
|
26 |
|
27 |
-
# Rest of your code remains the same
|
28 |
-
tokenizer = get_tokenizer(multilingual=".en" not in model_name)
|
29 |
-
|
30 |
internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
|
31 |
model=model,
|
32 |
class_names=class_names,
|
|
|
1 |
from typing import Dict
|
2 |
+
import torch
|
3 |
import gradio as gr
|
4 |
import whisper
|
5 |
from whisper.tokenizer import get_tokenizer
|
|
|
6 |
import classify
|
7 |
|
8 |
model_cache = {}
|
|
|
10 |
|
11 |
def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
|
12 |
class_names = class_names.split(",")
|
13 |
+
# Specify the path to your fine-tuned model and configuration
|
14 |
+
model_path = "mskov/whisper-small-esc50"
|
15 |
|
16 |
+
# Load the model
|
17 |
+
model = whisper.Whisper.load(model_path)
|
18 |
|
19 |
+
# Load the tokenizer
|
20 |
+
tokenizer = whisper.tokenizer.get_tokenizer(multilingual=".en" not in model.config.name)
|
|
|
|
|
|
|
|
|
|
|
21 |
|
|
|
|
|
|
|
22 |
internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
|
23 |
model=model,
|
24 |
class_names=class_names,
|