jhj0517 committed
Commit: 0e66735 · Parent(s): 1f71b24

add spaces annotation
Changed files:
- modules/diarize/diarizer.py +6 -0
- modules/translation/nllb_inference.py +4 -0
- modules/vad/silero_vad.py +3 -0
- modules/whisper/whisper_base.py +8 -0
- requirements.txt +2 -1
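This commit readies the app for Hugging Face ZeroGPU Spaces: it imports the spaces package and marks GPU-touching methods with the @spaces.GPU decorator, which asks ZeroGPU to attach a GPU for the duration of each decorated call and is documented to have no effect outside Spaces, so local runs are unchanged. A minimal sketch of the pattern, with an illustrative function that is not part of this repo:

    import spaces  # Hugging Face ZeroGPU helper; import early, before any CUDA work
    import torch

    @spaces.GPU  # on ZeroGPU, a GPU is attached only while this call runs
    def warm_matmul() -> float:
        # Inside a decorated call on ZeroGPU hardware, CUDA is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.rand(512, 512, device=device)
        return (x @ x).sum().item()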
modules/diarize/diarizer.py
CHANGED
@@ -3,12 +3,14 @@ import torch
 from typing import List
 import time
 import logging
+import spaces
 
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
 
 
 class Diarizer:
+    @spaces.GPU
     def __init__(self,
                  model_dir: str = os.path.join("models", "Diarization")
                  ):
@@ -19,6 +21,7 @@ class Diarizer:
         os.makedirs(self.model_dir, exist_ok=True)
         self.pipe = None
 
+    @spaces.GPU
     def run(self,
             audio: str,
             transcribed_result: List[dict],
@@ -73,6 +76,7 @@ class Diarizer:
         elapsed_time = time.time() - start_time
         return diarized_result["segments"], elapsed_time
 
+    @spaces.GPU
     def update_pipe(self,
                     use_auth_token: str,
                     device: str
@@ -110,6 +114,7 @@ class Diarizer:
         logger.disabled = False
 
     @staticmethod
+    @spaces.GPU
     def get_device():
         if torch.cuda.is_available():
             return "cuda"
@@ -119,6 +124,7 @@ class Diarizer:
         return "cpu"
 
     @staticmethod
+    @spaces.GPU
     def get_available_device():
         devices = ["cpu"]
         if torch.cuda.is_available():
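Worth noting in this file: on the static helpers, @spaces.GPU is stacked beneath @staticmethod, so Python applies the GPU decorator to the raw function first and then wraps the result as a static method. A minimal illustration of that ordering (hypothetical class and method, not from the repo):

    import spaces
    import torch

    class Helper:
        @staticmethod
        @spaces.GPU  # innermost decorator, applied before staticmethod
        def device_probe() -> str:
            return "cuda" if torch.cuda.is_available() else "cpu"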
modules/translation/nllb_inference.py
CHANGED
@@ -1,11 +1,13 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import gradio as gr
 import os
+import spaces
 
 from modules.translation.translation_base import TranslationBase
 
 
 class NLLBInference(TranslationBase):
+    @spaces.GPU
     def __init__(self,
                  model_dir: str,
                  output_dir: str
@@ -20,12 +22,14 @@ class NLLBInference(TranslationBase):
         self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
         self.pipeline = None
 
+    @spaces.GPU
     def translate(self,
                   text: str
                   ):
         result = self.pipeline(text)
         return result[0]['translation_text']
 
+    @spaces.GPU
     def update_model(self,
                      model_size: str,
                      src_lang: str,
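For context, the decorated translate method here drives a transformers translation pipeline. A self-contained sketch of that pattern under ZeroGPU, where the checkpoint and language pair are illustrative assumptions rather than values from this commit:

    import spaces
    from transformers import pipeline

    # Illustrative NLLB checkpoint and language pair, not taken from this diff.
    translator = pipeline("translation",
                          model="facebook/nllb-200-distilled-600M",
                          src_lang="eng_Latn", tgt_lang="fra_Latn")

    @spaces.GPU
    def translate(text: str) -> str:
        return translator(text)[0]["translation_text"]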
modules/vad/silero_vad.py
CHANGED
@@ -4,12 +4,14 @@ from typing import BinaryIO, Union, List, Optional
 import warnings
 import faster_whisper
 import gradio as gr
+import spaces
 
 
 class SileroVAD:
     def __init__(self):
         self.sampling_rate = 16000
 
+    @spaces.GPU
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             vad_parameters: VadOptions,
@@ -55,6 +57,7 @@ class SileroVAD:
         return audio
 
     @staticmethod
+    @spaces.GPU
     def get_speech_timestamps(
             audio: np.ndarray,
             vad_options: Optional[VadOptions] = None,
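The decorated helpers in this file wrap the Silero VAD utilities that ship with faster-whisper (pinned to 1.0.2 in requirements.txt). A minimal standalone sketch using those utilities directly, with an illustrative file path and threshold:

    import spaces
    import faster_whisper
    from faster_whisper.vad import VadOptions, get_speech_timestamps

    @spaces.GPU
    def speech_chunks(path: str):
        # decode_audio resamples to 16 kHz mono, matching SileroVAD.sampling_rate
        audio = faster_whisper.decode_audio(path, sampling_rate=16000)
        return get_speech_timestamps(audio, VadOptions(threshold=0.5))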
modules/whisper/whisper_base.py
CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
 from argparse import Namespace
 from faster_whisper.vad import VadOptions
 from dataclasses import astuple
+import spaces
 
 from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
@@ -18,6 +19,7 @@ from modules.vad.silero_vad import SileroVAD
 
 
 class WhisperBase(ABC):
+    @spaces.GPU
     def __init__(self,
                  model_dir: str,
                  output_dir: str,
@@ -41,6 +43,7 @@ class WhisperBase(ABC):
         self.vad = SileroVAD()
 
     @abstractmethod
+    @spaces.GPU
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
                    progress: gr.Progress,
@@ -49,6 +52,7 @@ class WhisperBase(ABC):
         pass
 
     @abstractmethod
+    @spaces.GPU
     def update_model(self,
                      model_size: str,
                      compute_type: str,
@@ -56,6 +60,7 @@ class WhisperBase(ABC):
         ):
         pass
 
+    @spaces.GPU
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress,
@@ -121,6 +126,7 @@ class WhisperBase(ABC):
         elapsed_time += elapsed_time_diarization
         return result, elapsed_time
 
+    @spaces.GPU
     def transcribe_file(self,
                         files: list,
                         file_format: str,
@@ -191,6 +197,7 @@ class WhisperBase(ABC):
         if not files:
             self.remove_input_files([file.name for file in files])
 
+    @spaces.GPU
     def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str,
@@ -402,6 +409,7 @@ class WhisperBase(ABC):
         return "cpu"
 
     @staticmethod
+    @spaces.GPU
     def release_cuda_memory():
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
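One knob this commit does not use: long transcriptions can outlive ZeroGPU's default allocation window, and the decorator accepts a duration argument (in seconds) to request a longer one. A hedged sketch, where the 120-second figure and the function are illustrative:

    import spaces

    @spaces.GPU(duration=120)  # request up to 120 s of GPU time per call
    def transcribe_long(audio_path: str):
        ...  # placeholder body; the real work would run the Whisper model here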
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ faster-whisper==1.0.2
 transformers
 gradio
 pytube
-pyannote.audio==3.3.1
+pyannote.audio==3.3.1
+spaces
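With spaces listed in requirements.txt, the Space installs it automatically at build time; for a local checkout the same package can be pulled from PyPI (assuming pip):

    pip install spaces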