Spaces:
Running
on
A10G
Running
on
A10G
removed offloading
Browse files
app.py
CHANGED
@@ -97,7 +97,7 @@ model.eval()
|
|
97 |
audio_tokenizer = AudioTokenizer(device)
|
98 |
|
99 |
# ASR
|
100 |
-
whisper_model = whisper.load_model("medium")
|
101 |
|
102 |
def clear_prompts():
|
103 |
try:
|
@@ -125,7 +125,7 @@ def transcribe_one(model, audio_path):
|
|
125 |
print(f"Detected language: {max(probs, key=probs.get)}")
|
126 |
lang = max(probs, key=probs.get)
|
127 |
# decode the audio
|
128 |
-
options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=
|
129 |
result = whisper.decode(model, mel, options)
|
130 |
|
131 |
# print the recognized text
|
@@ -168,7 +168,6 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio):
|
|
168 |
def make_prompt(name, wav, sr, save=True):
|
169 |
|
170 |
global whisper_model
|
171 |
-
whisper_model.to(device)
|
172 |
if not isinstance(wav, torch.FloatTensor):
|
173 |
wav = torch.tensor(wav)
|
174 |
if wav.abs().max() > 1:
|
@@ -188,7 +187,6 @@ def make_prompt(name, wav, sr, save=True):
|
|
188 |
os.remove(f"./prompts/{name}.wav")
|
189 |
os.remove(f"./prompts/{name}.txt")
|
190 |
|
191 |
-
whisper_model.cpu()
|
192 |
torch.cuda.empty_cache()
|
193 |
return text, lang
|
194 |
|
|
|
97 |
audio_tokenizer = AudioTokenizer(device)
|
98 |
|
99 |
# ASR
|
100 |
+
whisper_model = whisper.load_model("medium")
|
101 |
|
102 |
def clear_prompts():
|
103 |
try:
|
|
|
125 |
print(f"Detected language: {max(probs, key=probs.get)}")
|
126 |
lang = max(probs, key=probs.get)
|
127 |
# decode the audio
|
128 |
+
options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
|
129 |
result = whisper.decode(model, mel, options)
|
130 |
|
131 |
# print the recognized text
|
|
|
168 |
def make_prompt(name, wav, sr, save=True):
|
169 |
|
170 |
global whisper_model
|
|
|
171 |
if not isinstance(wav, torch.FloatTensor):
|
172 |
wav = torch.tensor(wav)
|
173 |
if wav.abs().max() > 1:
|
|
|
187 |
os.remove(f"./prompts/{name}.wav")
|
188 |
os.remove(f"./prompts/{name}.txt")
|
189 |
|
|
|
190 |
torch.cuda.empty_cache()
|
191 |
return text, lang
|
192 |
|