Commit f0061fc
Parent(s): 074b1c4
wip: allow using any pyannote pipeline
pyannote_viewer/backend/pyannote_viewer/pyannote_viewer.py
CHANGED
@@ -19,7 +19,9 @@ from gradio.events import Events
 from gradio.exceptions import Error
 
 from pyannote.core.annotation import Annotation
+from pyannote.core.feature import SlidingWindowFeature
 
+import torchaudio
 
 @dataclasses.dataclass
 class WaveformOptions:
@@ -249,7 +251,7 @@ class PyannoteViewer(
     )
 
     def postprocess(
-        self, value: Tuple[Annotation, np.ndarray] | None
+        self, value: Tuple[Annotation, np.ndarray | Path | str] | None
     ) -> FileData | bytes | None:
         """
         Parameters:
@@ -260,30 +262,40 @@ class PyannoteViewer(
         if value is None:
             return None
 
-        annotations, sources = value
+        annotations, audio = value
+
         labels = annotations.labels()
 
         # format diarization output
         segments = []
         for segment, _, label in annotations.itertracks(yield_label=True):
-            label_idx = labels.index(label)
+            label_idx = labels.index(label) if isinstance(audio, SlidingWindowFeature) else 0
             segments.append(
                 Segment(start=segment.start, end=segment.end, channel=label_idx)
             )
 
-        # save sources in cache
-        sources_file = processing_utils.save_audio_to_cache(
-            data=sources,
-            sample_rate=16_000,
-            format=self.format,
-            cache_dir=self.GRADIO_CACHE,
-        )
-        orig_name = Path(sources_file).name
+        if isinstance(audio, SlidingWindowFeature):
+            # save sources in cache
+            audio_filepath = processing_utils.save_audio_to_cache(
+                data=audio.data,
+                sample_rate=16_000,
+                format=self.format,
+                cache_dir=self.GRADIO_CACHE,
+            )
+            multichannel = True
+        elif isinstance(audio, (Path, str)):
+            audio_filepath = audio
+            multichannel = False
+        else:
+            raise ValueError("Unknown type for audio value")
+
+        orig_name = Path(audio_filepath).name
 
         return {
             "segments": segments,
             "labels": labels,
-            "sources_file": FileData(path=sources_file, orig_name=orig_name),
+            "multichannel": multichannel,
+            "sources_file": FileData(path=audio_filepath, orig_name=orig_name),
         }
 
     def stream_output(
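For context, a minimal sketch of the two value shapes the reworked postprocess now accepts. Everything outside the diff is an assumption: the viewer instance, the dummy data, and the pipelines that would produce each shape.

# Illustrative only: the two (annotation, audio) pairs postprocess branches on.
# `viewer` is an assumed PyannoteViewer instance, not shown in the commit.
import numpy as np
from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature

annotation = Annotation()
annotation[Segment(0.0, 1.5)] = "SPEAKER_00"
annotation[Segment(1.0, 2.0)] = "SPEAKER_01"

# Shape 1: separation-style output, one channel per separated source.
# postprocess caches the sources (sample rate is hard-coded to 16 kHz in the
# diff) and sets multichannel=True, mapping each label to its channel index.
sources = SlidingWindowFeature(
    data=np.zeros((32_000, 2), dtype=np.float32),  # (frames, num_sources)
    sliding_window=SlidingWindow(start=0.0, duration=1 / 16_000, step=1 / 16_000),
)
viewer.postprocess((annotation, sources))

# Shape 2: diarization-style output paired with the original audio path;
# postprocess sets multichannel=False and puts every segment on channel 0.
viewer.postprocess((annotation, "meeting.wav"))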
pyannote_viewer/demo/app.py
CHANGED
@@ -5,10 +5,20 @@ import os
 
 
 def apply_pipeline(audio: str) -> tuple:
+    # pipeline = Pipeline.from_pretrained(
+    #     "pyannote/speech-separation-ami-1.0", use_auth_token=os.environ["HF_TOKEN"]
+    # )
+
     pipeline = Pipeline.from_pretrained(
-        "pyannote/speech-separation-ami-1.0", use_auth_token=os.environ["HF_TOKEN"]
+        "pyannote/speaker-diarization-3.1", use_auth_token=os.environ["HF_TOKEN"]
     )
-    return pipeline(audio)
+
+
+    outputs = pipeline(audio)
+    if isinstance(outputs, tuple):
+        return outputs
+    else:
+        return (outputs, audio)
 
 
 with gr.Blocks() as demo:
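The tuple normalization above is the crux of the demo change: a separation pipeline already returns (annotation, sources), while a diarization pipeline returns a bare Annotation, which gets paired with the input path so the component always receives a two-element tuple. A sketch of how this plausibly plugs into the Blocks app follows; the component wiring is not shown in the diff, so the names here are assumptions.

# Assumed wiring, not part of the commit: the PyannoteViewer import path,
# component arguments, and event hookup are illustrative only.
import gradio as gr
from pyannote_viewer import PyannoteViewer

with gr.Blocks() as demo:
    audio_in = gr.Audio(type="filepath")
    viewer = PyannoteViewer(interactive=False)
    run_btn = gr.Button("Apply pipeline")
    # apply_pipeline always yields (Annotation, sources-or-path), which
    # postprocess converts to segments, labels, and the multichannel flag.
    run_btn.click(fn=apply_pipeline, inputs=audio_in, outputs=viewer)

demo.launch()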
pyannote_viewer/frontend/Index.svelte
CHANGED
@@ -16,7 +16,7 @@
 	export let elem_classes: string[] = [];
 	export let visible = true;
 	export let interactive: boolean;
-	export let value: null | {"segments": Segment[], "labels" : string[], "sources_file": FileData} = null;
+	export let value: null | {"segments": Segment[], "labels" : string[], "multichannel": boolean, "sources_file": FileData} = null;
 	export let sources:
 		| ["microphone"]
 		| ["upload"]
pyannote_viewer/frontend/player/AudioPlayer.svelte
CHANGED
@@ -72,6 +72,7 @@
 	$: waveform?.on("decode", (duration: any) => {
 		audioDecoded = true;
 		const numChannels = waveform.getDecodedData().numberOfChannels;
+		console.log(numChannels);
 		audio_duration = duration;
 		durationRef && (durationRef.textContent = format_time(duration));
 