Fix OOM
Browse files- app.py +0 -3
- models/vallex.py +3 -0
app.py
CHANGED
@@ -116,7 +116,6 @@ def transcribe_one(model, audio_path):
|
|
116 |
return lang, text_pr
|
117 |
|
118 |
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
|
119 |
-
global model, text_collater, text_tokenizer, audio_tokenizer
|
120 |
clear_prompts()
|
121 |
audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
|
122 |
sr, wav_pr = audio_prompt
|
@@ -159,7 +158,6 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
|
|
159 |
|
160 |
|
161 |
def make_prompt(name, wav, sr, save=True):
|
162 |
-
global whisper_model
|
163 |
if not isinstance(wav, torch.FloatTensor):
|
164 |
wav = torch.tensor(wav)
|
165 |
if wav.abs().max() > 1:
|
@@ -185,7 +183,6 @@ def make_prompt(name, wav, sr, save=True):
|
|
185 |
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
|
186 |
if len(text) > 150:
|
187 |
return "Rejected, Text too long (should be less than 150 characters)", None
|
188 |
-
global model, text_collater, text_tokenizer, audio_tokenizer
|
189 |
audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
|
190 |
sr, wav_pr = audio_prompt
|
191 |
if len(wav_pr) / sr > 15:
|
|
|
116 |
return lang, text_pr
|
117 |
|
118 |
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
|
|
|
119 |
clear_prompts()
|
120 |
audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
|
121 |
sr, wav_pr = audio_prompt
|
|
|
158 |
|
159 |
|
160 |
def make_prompt(name, wav, sr, save=True):
|
|
|
161 |
if not isinstance(wav, torch.FloatTensor):
|
162 |
wav = torch.tensor(wav)
|
163 |
if wav.abs().max() > 1:
|
|
|
183 |
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
|
184 |
if len(text) > 150:
|
185 |
return "Rejected, Text too long (should be less than 150 characters)", None
|
|
|
186 |
audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
|
187 |
sr, wav_pr = audio_prompt
|
188 |
if len(wav_pr) / sr > 15:
|
models/vallex.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
|
15 |
import random
|
16 |
from typing import Dict, Iterator, List, Tuple, Union
|
|
|
17 |
|
18 |
import numpy as np
|
19 |
import torch
|
@@ -462,6 +463,7 @@ class VALLE(VALLF):
|
|
462 |
**kwargs,
|
463 |
):
|
464 |
raise NotImplementedError
|
|
|
465 |
def inference(
|
466 |
self,
|
467 |
x: torch.Tensor,
|
@@ -674,6 +676,7 @@ class VALLE(VALLF):
|
|
674 |
y_emb[:, prefix_len:] += embedding_layer(samples)
|
675 |
|
676 |
assert len(codes) == self.num_quantizers
|
|
|
677 |
return torch.stack(codes, dim=-1)
|
678 |
|
679 |
def continual(
|
|
|
14 |
|
15 |
import random
|
16 |
from typing import Dict, Iterator, List, Tuple, Union
|
17 |
+
import gc
|
18 |
|
19 |
import numpy as np
|
20 |
import torch
|
|
|
463 |
**kwargs,
|
464 |
):
|
465 |
raise NotImplementedError
|
466 |
+
|
467 |
def inference(
|
468 |
self,
|
469 |
x: torch.Tensor,
|
|
|
676 |
y_emb[:, prefix_len:] += embedding_layer(samples)
|
677 |
|
678 |
assert len(codes) == self.num_quantizers
|
679 |
+
gc.collect()
|
680 |
return torch.stack(codes, dim=-1)
|
681 |
|
682 |
def continual(
|