Commit
·
4712d9f
1
Parent(s):
61dee21
implement component postprocessing
Browse files
sourceviewer/backend/gradio_sourceviewer/sourceviewer.py
CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
4 |
|
5 |
import dataclasses
|
6 |
from pathlib import Path
|
7 |
-
from typing import Any, Callable, Literal
|
8 |
|
9 |
import httpx
|
10 |
import numpy as np
|
@@ -18,6 +18,8 @@ from gradio.data_classes import FileData
|
|
18 |
from gradio.events import Events
|
19 |
from gradio.exceptions import Error
|
20 |
|
|
|
|
|
21 |
|
22 |
@dataclasses.dataclass
|
23 |
class WaveformOptions:
|
@@ -42,6 +44,13 @@ class WaveformOptions:
|
|
42 |
sample_rate: int = 44100
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
class SourceViewer(
|
46 |
StreamingInput,
|
47 |
StreamingOutput,
|
@@ -239,7 +248,7 @@ class SourceViewer(
|
|
239 |
)
|
240 |
|
241 |
def postprocess(
|
242 |
-
self, value:
|
243 |
) -> FileData | bytes | None:
|
244 |
"""
|
245 |
Parameters:
|
@@ -247,28 +256,33 @@ class SourceViewer(
|
|
247 |
Returns:
|
248 |
FileData object, bytes, or None.
|
249 |
"""
|
250 |
-
orig_name = None
|
251 |
if value is None:
|
252 |
return None
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
data, sample_rate, format=self.format, cache_dir=self.GRADIO_CACHE
|
264 |
)
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
def stream_output(
|
274 |
self, value, output_id: str, first_chunk: bool
|
|
|
4 |
|
5 |
import dataclasses
|
6 |
from pathlib import Path
|
7 |
+
from typing import Any, Callable, Literal, Tuple
|
8 |
|
9 |
import httpx
|
10 |
import numpy as np
|
|
|
18 |
from gradio.events import Events
|
19 |
from gradio.exceptions import Error
|
20 |
|
21 |
+
from pyannote.core.annotation import Annotation
|
22 |
+
|
23 |
|
24 |
@dataclasses.dataclass
|
25 |
class WaveformOptions:
|
|
|
44 |
sample_rate: int = 44100
|
45 |
|
46 |
|
47 |
+
@dataclasses.dataclass
|
48 |
+
class Segment:
|
49 |
+
start: float
|
50 |
+
end: float
|
51 |
+
channel: int
|
52 |
+
|
53 |
+
|
54 |
class SourceViewer(
|
55 |
StreamingInput,
|
56 |
StreamingOutput,
|
|
|
248 |
)
|
249 |
|
250 |
def postprocess(
|
251 |
+
self, value: Tuple[Annotation, np.ndarray] | None
|
252 |
) -> FileData | bytes | None:
|
253 |
"""
|
254 |
Parameters:
|
|
|
256 |
Returns:
|
257 |
FileData object, bytes, or None.
|
258 |
"""
|
|
|
259 |
if value is None:
|
260 |
return None
|
261 |
+
|
262 |
+
annotations, sources = value
|
263 |
+
labels = annotations.labels()
|
264 |
+
|
265 |
+
# format diarization output
|
266 |
+
segments = []
|
267 |
+
for segment, _, label in annotations.itertracks(yield_label=True):
|
268 |
+
label_idx = labels.index(label)
|
269 |
+
segments.append(
|
270 |
+
Segment(start=segment.start, end=segment.end, channel=label_idx)
|
|
|
271 |
)
|
272 |
+
|
273 |
+
# save sources in cache
|
274 |
+
source_filepath = processing_utils.save_audio_to_cache(
|
275 |
+
data=sources.data,
|
276 |
+
sample_rate=16_000,
|
277 |
+
format=self.format,
|
278 |
+
cache_dir=self.GRADIO_CACHE,
|
279 |
+
)
|
280 |
+
orig_name = Path(source_filepath).name
|
281 |
+
|
282 |
+
return {
|
283 |
+
"segments": segments,
|
284 |
+
"sources_file": FileData(path=source_filepath, orig_name=orig_name),
|
285 |
+
}
|
286 |
|
287 |
def stream_output(
|
288 |
self, value, output_id: str, first_chunk: bool
|