Spaces:
Runtime error
Runtime error
add audio extractor
Browse files- __init__.py +0 -0
- app.py +146 -0
- extractors/__init__.py +0 -0
- extractors/asrdiarization/__init__.py +0 -0
- extractors/asrdiarization/asr_extractor.py +134 -0
- extractors/asrdiarization/diarization_utils.py +141 -0
- requirements.txt +12 -0
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import gradio as gr
|
3 |
+
import base64
|
4 |
+
import librosa
|
5 |
+
from extractors.asrdiarization.asr_extractor import ASRExtractorConfig, ASRExtractor
|
6 |
+
from indexify_extractor_sdk import Content
|
7 |
+
|
8 |
+
MAX_AUDIO_MINUTES = 60 # wont try to transcribe if longer than this
|
9 |
+
|
10 |
+
asr_extractor = ASRExtractor()
|
11 |
+
|
12 |
+
def check_audio(audio_filepath):
|
13 |
+
"""
|
14 |
+
Do not convert and raise error if audio too long.
|
15 |
+
"""
|
16 |
+
data, sr = librosa.load(audio_filepath, sr=None, mono=True)
|
17 |
+
duration = librosa.get_duration(y=data, sr=sr)
|
18 |
+
|
19 |
+
if duration / 60.0 > MAX_AUDIO_MINUTES:
|
20 |
+
raise gr.Error(
|
21 |
+
f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
|
22 |
+
"If you wish, you may trim the audio using the Audio viewer in Step 1 "
|
23 |
+
"(click on the scissors icon to start trimming audio)."
|
24 |
+
)
|
25 |
+
|
26 |
+
return audio_filepath
|
27 |
+
|
28 |
+
@spaces.GPU
|
29 |
+
def transcribe(audio_filepath, task, batch_size, chunk_length_s, sampling_rate, language, num_speakers, min_speakers, max_speakers, assisted):
|
30 |
+
if audio_filepath is None:
|
31 |
+
raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
|
32 |
+
|
33 |
+
audio_filepath = check_audio(audio_filepath)
|
34 |
+
|
35 |
+
with open(audio_filepath, "rb") as f:
|
36 |
+
converted_audio_filepath = base64.b64encode(f.read()).decode("utf-8")
|
37 |
+
|
38 |
+
content = Content(content_type="audio/mpeg", data=converted_audio_filepath)
|
39 |
+
config = ASRExtractorConfig(task=task, batch_size=batch_size, chunk_length_s=chunk_length_s, sampling_rate=sampling_rate, language=language, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers, assisted=assisted)
|
40 |
+
|
41 |
+
result = asr_extractor.extract(content, config)
|
42 |
+
text_content = next(content.data.decode('utf-8') for content in result)
|
43 |
+
|
44 |
+
return text_content
|
45 |
+
|
46 |
+
with gr.Blocks(
|
47 |
+
title="ASR + diarization + speculative decoding with Indexify"
|
48 |
+
) as audio_demo:
|
49 |
+
|
50 |
+
gr.HTML("<h1 style='text-align: center'>ASR + diarization + speculative decoding with Indexify</h1>")
|
51 |
+
gr.HTML("<p style='text-align: center'>Indexify is a scalable realtime and continuous indexing and structured extraction engine for unstructured data to build generative AI applications</p>")
|
52 |
+
gr.HTML("<h3 style='text-align: center'>If you like this demo, please ⭐ Star us on <a href='https://github.com/tensorlakeai/indexify' target='_blank'>GitHub</a>!</h3>")
|
53 |
+
|
54 |
+
with gr.Row():
|
55 |
+
with gr.Column():
|
56 |
+
gr.HTML(
|
57 |
+
"<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
|
58 |
+
|
59 |
+
"<p style='color: #A0A0A0;'>Use this demo for audio files only up to 60 mins long. "
|
60 |
+
"You can transcribe longer files and try various other extractors locally with "
|
61 |
+
"<a href='https://getindexify.io/'>Indexify</a>.</p>"
|
62 |
+
)
|
63 |
+
|
64 |
+
audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
|
65 |
+
|
66 |
+
gr.HTML("<p><b>Step 2:</b> Choose the parameters or leave to default.</p>")
|
67 |
+
|
68 |
+
task = gr.Dropdown(
|
69 |
+
choices=["transcribe", "translate"],
|
70 |
+
value="transcribe",
|
71 |
+
info="passed to the ASR pipeline",
|
72 |
+
label="Task:"
|
73 |
+
)
|
74 |
+
|
75 |
+
with gr.Column():
|
76 |
+
batch_size = gr.Number(
|
77 |
+
value=24,
|
78 |
+
info="for assisted generation the `batch_size` must be set to 1",
|
79 |
+
label="Batch Size:"
|
80 |
+
)
|
81 |
+
chunk_length_s = gr.Number(
|
82 |
+
value=30,
|
83 |
+
info="passed to the ASR pipeline",
|
84 |
+
label="Chunk Length:"
|
85 |
+
)
|
86 |
+
sampling_rate = gr.Number(
|
87 |
+
value=16000,
|
88 |
+
info="`sampling_rate` indicates the sampling rate of the audio to process and is used for preprocessing",
|
89 |
+
label="Sampling Rate:"
|
90 |
+
)
|
91 |
+
language = gr.Dropdown(
|
92 |
+
choices=['english', 'chinese', 'german', 'spanish', 'russian', 'korean', 'french', 'japanese', 'portuguese', 'turkish', 'polish', 'catalan', 'dutch', 'arabic', 'swedish', 'italian', 'indonesian', 'hindi', 'finnish', 'vietnamese', 'hebrew', 'ukrainian', 'greek', 'malay', 'czech', 'romanian', 'danish', 'hungarian', 'tamil', 'norwegian', 'thai', 'urdu', 'croatian', 'bulgarian', 'lithuanian', 'latin', 'maori', 'malayalam', 'welsh', 'slovak', 'telugu', 'persian', 'latvian', 'bengali', 'serbian', 'azerbaijani', 'slovenian', 'kannada', 'estonian', 'macedonian', 'breton', 'basque', 'icelandic', 'armenian', 'nepali', 'mongolian', 'bosnian', 'kazakh', 'albanian', 'swahili', 'galician', 'marathi', 'punjabi', 'sinhala', 'khmer', 'shona', 'yoruba', 'somali', 'afrikaans', 'occitan', 'georgian', 'belarusian', 'tajik', 'sindhi', 'gujarati', 'amharic', 'yiddish', 'lao', 'uzbek', 'faroese', 'haitian creole', 'pashto', 'turkmen', 'nynorsk', 'maltese', 'sanskrit', 'luxembourgish', 'myanmar', 'tibetan', 'tagalog', 'malagasy', 'assamese', 'tatar', 'hawaiian', 'lingala', 'hausa', 'bashkir', 'javanese', 'sundanese', 'cantonese', 'burmese', 'valencian', 'flemish', 'haitian', 'letzeburgesch', 'pushto', 'panjabi', 'moldavian', 'moldovan', 'sinhalese', 'castilian', 'mandarin'],
|
93 |
+
info="passed to the ASR pipeline",
|
94 |
+
label="Language:"
|
95 |
+
)
|
96 |
+
num_speakers = gr.Number(
|
97 |
+
info="passed to diarization pipeline",
|
98 |
+
label="Number of Speakers:"
|
99 |
+
)
|
100 |
+
min_speakers = gr.Number(
|
101 |
+
info="passed to diarization pipeline",
|
102 |
+
label="Minimum Speakers:"
|
103 |
+
)
|
104 |
+
max_speakers = gr.Number(
|
105 |
+
info="passed to diarization pipeline",
|
106 |
+
label="Maximum Speakers:"
|
107 |
+
)
|
108 |
+
assisted = gr.Checkbox(
|
109 |
+
value=False,
|
110 |
+
info="the `assisted` flag tells the pipeline whether to use speculative decoding",
|
111 |
+
label="Assisted?",
|
112 |
+
)
|
113 |
+
|
114 |
+
with gr.Column():
|
115 |
+
|
116 |
+
gr.HTML("<p><b>Step 3:</b> Run the extractor.</p>")
|
117 |
+
|
118 |
+
go_button = gr.Button(
|
119 |
+
value="Run extractor",
|
120 |
+
variant="primary", # make "primary" so it stands out (default is "secondary")
|
121 |
+
)
|
122 |
+
|
123 |
+
model_output_text_box = gr.Textbox(
|
124 |
+
label="Extractor Output",
|
125 |
+
elem_id="model_output_text_box",
|
126 |
+
)
|
127 |
+
|
128 |
+
with gr.Row():
|
129 |
+
|
130 |
+
gr.HTML(
|
131 |
+
"<p style='text-align: center'>"
|
132 |
+
"Developed with 🫶 by <a href='https://getindexify.io/' target='_blank'>Indexify</a> | "
|
133 |
+
"a <a href='https://www.tensorlake.ai/' target='_blank'>Tensorlake</a> product"
|
134 |
+
"</p>"
|
135 |
+
)
|
136 |
+
|
137 |
+
go_button.click(
|
138 |
+
fn=transcribe,
|
139 |
+
inputs = [audio_file, task, batch_size, chunk_length_s, sampling_rate, language, num_speakers, min_speakers, max_speakers, assisted],
|
140 |
+
outputs = [model_output_text_box]
|
141 |
+
)
|
142 |
+
|
143 |
+
demo = gr.TabbedInterface([audio_demo], ["Audio Extraction"], theme=gr.themes.Soft())
|
144 |
+
|
145 |
+
demo.queue()
|
146 |
+
demo.launch()
|
extractors/__init__.py
ADDED
File without changes
|
extractors/asrdiarization/__init__.py
ADDED
File without changes
|
extractors/asrdiarization/asr_extractor.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import torch
|
3 |
+
import base64
|
4 |
+
import os
|
5 |
+
|
6 |
+
from indexify_extractor_sdk import Content, Extractor, Feature
|
7 |
+
from pyannote.audio import Pipeline
|
8 |
+
from transformers import pipeline, AutoModelForCausalLM
|
9 |
+
from .diarization_utils import diarize
|
10 |
+
from huggingface_hub import HfApi
|
11 |
+
from starlette.exceptions import HTTPException
|
12 |
+
|
13 |
+
from pydantic import BaseModel
|
14 |
+
from pydantic_settings import BaseSettings
|
15 |
+
from typing import Optional, Literal, List, Union
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
token = os.getenv('HF_TOKEN')
|
19 |
+
|
20 |
+
class ModelSettings(BaseSettings):
|
21 |
+
asr_model: str = "openai/whisper-large-v3"
|
22 |
+
assistant_model: Optional[str] = "distil-whisper/distil-large-v3"
|
23 |
+
diarization_model: Optional[str] = "pyannote/speaker-diarization-3.1"
|
24 |
+
hf_token: Optional[str] = token
|
25 |
+
|
26 |
+
model_settings = ModelSettings()
|
27 |
+
|
28 |
+
class ASRExtractorConfig(BaseModel):
|
29 |
+
task: Literal["transcribe", "translate"] = "transcribe"
|
30 |
+
batch_size: int = 24
|
31 |
+
assisted: bool = False
|
32 |
+
chunk_length_s: int = 30
|
33 |
+
sampling_rate: int = 16000
|
34 |
+
language: Optional[str] = None
|
35 |
+
num_speakers: Optional[int] = None
|
36 |
+
min_speakers: Optional[int] = None
|
37 |
+
max_speakers: Optional[int] = None
|
38 |
+
|
39 |
+
class ASRExtractor(Extractor):
|
40 |
+
name = "tensorlake/asrdiarization"
|
41 |
+
description = "Powerful ASR + diarization + speculative decoding."
|
42 |
+
system_dependencies = ["ffmpeg"]
|
43 |
+
input_mime_types = ["audio", "audio/mpeg"]
|
44 |
+
|
45 |
+
def __init__(self):
|
46 |
+
super(ASRExtractor, self).__init__()
|
47 |
+
|
48 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
49 |
+
logger.info(f"Using device: {device.type}")
|
50 |
+
torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
|
51 |
+
|
52 |
+
self.assistant_model = AutoModelForCausalLM.from_pretrained(
|
53 |
+
model_settings.assistant_model,
|
54 |
+
torch_dtype=torch_dtype,
|
55 |
+
low_cpu_mem_usage=True,
|
56 |
+
use_safetensors=True
|
57 |
+
) if model_settings.assistant_model else None
|
58 |
+
|
59 |
+
if self.assistant_model:
|
60 |
+
self.assistant_model.to(device)
|
61 |
+
|
62 |
+
self.asr_pipeline = pipeline(
|
63 |
+
"automatic-speech-recognition",
|
64 |
+
model=model_settings.asr_model,
|
65 |
+
torch_dtype=torch_dtype,
|
66 |
+
device=device
|
67 |
+
)
|
68 |
+
|
69 |
+
if model_settings.diarization_model:
|
70 |
+
# diarization pipeline doesn't raise if there is no token
|
71 |
+
HfApi().whoami(model_settings.hf_token)
|
72 |
+
self.diarization_pipeline = Pipeline.from_pretrained(
|
73 |
+
checkpoint_path=model_settings.diarization_model,
|
74 |
+
use_auth_token=model_settings.hf_token,
|
75 |
+
)
|
76 |
+
self.diarization_pipeline.to(device)
|
77 |
+
else:
|
78 |
+
self.diarization_pipeline = None
|
79 |
+
|
80 |
+
def extract(self, content: Content, params: ASRExtractorConfig) -> List[Union[Feature, Content]]:
|
81 |
+
file = base64.b64decode(content.data)
|
82 |
+
logger.info(f"inference params: {params}")
|
83 |
+
|
84 |
+
generate_kwargs = {
|
85 |
+
"task": params.task,
|
86 |
+
"language": params.language,
|
87 |
+
"assistant_model": self.assistant_model if params.assisted else None
|
88 |
+
}
|
89 |
+
|
90 |
+
try:
|
91 |
+
asr_outputs = self.asr_pipeline(
|
92 |
+
file,
|
93 |
+
chunk_length_s=params.chunk_length_s,
|
94 |
+
batch_size=params.batch_size,
|
95 |
+
generate_kwargs=generate_kwargs,
|
96 |
+
return_timestamps=True,
|
97 |
+
)
|
98 |
+
except RuntimeError as e:
|
99 |
+
logger.error(f"ASR inference error: {str(e)}")
|
100 |
+
raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
|
101 |
+
except Exception as e:
|
102 |
+
logger.error(f"Unknown error diring ASR inference: {str(e)}")
|
103 |
+
raise HTTPException(status_code=500, detail=f"Unknown error diring ASR inference: {str(e)}")
|
104 |
+
|
105 |
+
if self.diarization_pipeline:
|
106 |
+
try:
|
107 |
+
transcript = diarize(self.diarization_pipeline, file, params, asr_outputs)
|
108 |
+
except RuntimeError as e:
|
109 |
+
logger.error(f"Diarization inference error: {str(e)}")
|
110 |
+
raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
|
111 |
+
except Exception as e:
|
112 |
+
logger.error(f"Unknown error during diarization: {str(e)}")
|
113 |
+
raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
|
114 |
+
else:
|
115 |
+
transcript = []
|
116 |
+
|
117 |
+
feature = Feature.metadata(value={"chunks": asr_outputs["chunks"], "text": asr_outputs["text"]})
|
118 |
+
return [Content.from_text(str(transcript), features=[feature])]
|
119 |
+
|
120 |
+
def sample_input(self) -> Content:
|
121 |
+
filepath = "sample.mp3"
|
122 |
+
with open(filepath, 'rb') as f:
|
123 |
+
audio_encoded = base64.b64encode(f.read()).decode("utf-8")
|
124 |
+
return Content(content_type="audio/mpeg", data=audio_encoded)
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
filepath = "sample.mp3"
|
128 |
+
with open(filepath, 'rb') as f:
|
129 |
+
audio_encoded = base64.b64encode(f.read()).decode("utf-8")
|
130 |
+
data = Content(content_type="audio/mpeg", data=audio_encoded)
|
131 |
+
params = ASRExtractorConfig(batch_size=24)
|
132 |
+
extractor = ASRExtractor()
|
133 |
+
results = extractor.extract(data, params=params)
|
134 |
+
print(results)
|
extractors/asrdiarization/diarization_utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torchaudio import functional as F
|
4 |
+
from transformers.pipelines.audio_utils import ffmpeg_read
|
5 |
+
from starlette.exceptions import HTTPException
|
6 |
+
import sys
|
7 |
+
|
8 |
+
# Code from insanely-fast-whisper:
|
9 |
+
# https://github.com/Vaibhavs10/insanely-fast-whisper
|
10 |
+
|
11 |
+
import logging
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
def preprocess_inputs(inputs, sampling_rate):
|
15 |
+
inputs = ffmpeg_read(inputs, sampling_rate)
|
16 |
+
|
17 |
+
if sampling_rate != 16000:
|
18 |
+
inputs = F.resample(
|
19 |
+
torch.from_numpy(inputs), sampling_rate, 16000
|
20 |
+
).numpy()
|
21 |
+
|
22 |
+
if len(inputs.shape) != 1:
|
23 |
+
logger.error(f"Diarization pipeline expecs single channel audio, received {inputs.shape}")
|
24 |
+
raise HTTPException(
|
25 |
+
status_code=400,
|
26 |
+
detail=f"Diarization pipeline expecs single channel audio, received {inputs.shape}"
|
27 |
+
)
|
28 |
+
|
29 |
+
# diarization model expects float32 torch tensor of shape `(channels, seq_len)`
|
30 |
+
diarizer_inputs = torch.from_numpy(inputs).float()
|
31 |
+
diarizer_inputs = diarizer_inputs.unsqueeze(0)
|
32 |
+
|
33 |
+
return inputs, diarizer_inputs
|
34 |
+
|
35 |
+
|
36 |
+
def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
|
37 |
+
diarization = diarization_pipeline(
|
38 |
+
{"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate},
|
39 |
+
num_speakers=parameters.num_speakers,
|
40 |
+
min_speakers=parameters.min_speakers,
|
41 |
+
max_speakers=parameters.max_speakers,
|
42 |
+
)
|
43 |
+
|
44 |
+
segments = []
|
45 |
+
for segment, track, label in diarization.itertracks(yield_label=True):
|
46 |
+
segments.append(
|
47 |
+
{
|
48 |
+
"segment": {"start": segment.start, "end": segment.end},
|
49 |
+
"track": track,
|
50 |
+
"label": label,
|
51 |
+
}
|
52 |
+
)
|
53 |
+
|
54 |
+
# diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
|
55 |
+
# we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
|
56 |
+
new_segments = []
|
57 |
+
prev_segment = cur_segment = segments[0]
|
58 |
+
|
59 |
+
for i in range(1, len(segments)):
|
60 |
+
cur_segment = segments[i]
|
61 |
+
|
62 |
+
# check if we have changed speaker ("label")
|
63 |
+
if cur_segment["label"] != prev_segment["label"] and i < len(segments):
|
64 |
+
# add the start/end times for the super-segment to the new list
|
65 |
+
new_segments.append(
|
66 |
+
{
|
67 |
+
"segment": {
|
68 |
+
"start": prev_segment["segment"]["start"],
|
69 |
+
"end": cur_segment["segment"]["start"],
|
70 |
+
},
|
71 |
+
"speaker": prev_segment["label"],
|
72 |
+
}
|
73 |
+
)
|
74 |
+
prev_segment = segments[i]
|
75 |
+
|
76 |
+
# add the last segment(s) if there was no speaker change
|
77 |
+
new_segments.append(
|
78 |
+
{
|
79 |
+
"segment": {
|
80 |
+
"start": prev_segment["segment"]["start"],
|
81 |
+
"end": cur_segment["segment"]["end"],
|
82 |
+
},
|
83 |
+
"speaker": prev_segment["label"],
|
84 |
+
}
|
85 |
+
)
|
86 |
+
|
87 |
+
return new_segments
|
88 |
+
|
89 |
+
|
90 |
+
def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
|
91 |
+
# get the end timestamps for each chunk from the ASR output
|
92 |
+
end_timestamps = np.array(
|
93 |
+
[chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
|
94 |
+
segmented_preds = []
|
95 |
+
|
96 |
+
# align the diarizer timestamps and the ASR timestamps
|
97 |
+
for segment in new_segments:
|
98 |
+
# get the diarizer end timestamp
|
99 |
+
end_time = segment["segment"]["end"]
|
100 |
+
# find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
|
101 |
+
upto_idx = np.argmin(np.abs(end_timestamps - end_time))
|
102 |
+
|
103 |
+
if group_by_speaker:
|
104 |
+
segmented_preds.append(
|
105 |
+
{
|
106 |
+
"speaker": segment["speaker"],
|
107 |
+
"text": "".join(
|
108 |
+
[chunk["text"] for chunk in transcript[: upto_idx + 1]]
|
109 |
+
),
|
110 |
+
"timestamp": (
|
111 |
+
transcript[0]["timestamp"][0],
|
112 |
+
transcript[upto_idx]["timestamp"][1],
|
113 |
+
),
|
114 |
+
}
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
for i in range(upto_idx + 1):
|
118 |
+
segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
|
119 |
+
|
120 |
+
# crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
|
121 |
+
transcript = transcript[upto_idx + 1:]
|
122 |
+
end_timestamps = end_timestamps[upto_idx + 1:]
|
123 |
+
|
124 |
+
if len(end_timestamps) == 0:
|
125 |
+
break
|
126 |
+
|
127 |
+
return segmented_preds
|
128 |
+
|
129 |
+
|
130 |
+
def diarize(diarization_pipeline, file, parameters, asr_outputs):
|
131 |
+
_, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
|
132 |
+
|
133 |
+
segments = diarize_audio(
|
134 |
+
diarizer_inputs,
|
135 |
+
diarization_pipeline,
|
136 |
+
parameters
|
137 |
+
)
|
138 |
+
|
139 |
+
return post_process_segments_and_transcripts(
|
140 |
+
segments, asr_outputs["chunks"], group_by_speaker=False
|
141 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
indexify-extractor-sdk
|
2 |
+
accelerate==0.27.2
|
3 |
+
pyannote-audio==3.1.1
|
4 |
+
transformers==4.40.2
|
5 |
+
numpy==1.26.4
|
6 |
+
torchaudio==2.2.0
|
7 |
+
pydantic==2.6.3
|
8 |
+
pydantic-settings==2.2.1
|
9 |
+
librosa==0.10.2
|
10 |
+
torch==2.2.0
|
11 |
+
bitsandbytes
|
12 |
+
peft
|