clement-pages committed on
Commit
4712d9f
·
1 Parent(s): 61dee21

implement component postprocessing

Browse files
sourceviewer/backend/gradio_sourceviewer/sourceviewer.py CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
4
 
5
  import dataclasses
6
  from pathlib import Path
7
- from typing import Any, Callable, Literal
8
 
9
  import httpx
10
  import numpy as np
@@ -18,6 +18,8 @@ from gradio.data_classes import FileData
18
  from gradio.events import Events
19
  from gradio.exceptions import Error
20
 
 
 
21
 
22
  @dataclasses.dataclass
23
  class WaveformOptions:
@@ -42,6 +44,13 @@ class WaveformOptions:
42
  sample_rate: int = 44100
43
 
44
 
 
 
 
 
 
 
 
45
  class SourceViewer(
46
  StreamingInput,
47
  StreamingOutput,
@@ -239,7 +248,7 @@ class SourceViewer(
239
  )
240
 
241
  def postprocess(
242
- self, value: str | Path | bytes | tuple[int, np.ndarray] | None
243
  ) -> FileData | bytes | None:
244
  """
245
  Parameters:
@@ -247,28 +256,33 @@ class SourceViewer(
247
  Returns:
248
  FileData object, bytes, or None.
249
  """
250
- orig_name = None
251
  if value is None:
252
  return None
253
- if isinstance(value, bytes):
254
- if self.streaming:
255
- return value
256
- file_path = processing_utils.save_bytes_to_cache(
257
- value, "audio", cache_dir=self.GRADIO_CACHE
258
- )
259
- orig_name = Path(file_path).name
260
- elif isinstance(value, tuple):
261
- sample_rate, data = value
262
- file_path = processing_utils.save_audio_to_cache(
263
- data, sample_rate, format=self.format, cache_dir=self.GRADIO_CACHE
264
  )
265
- orig_name = Path(file_path).name
266
- else:
267
- if not isinstance(value, (str, Path)):
268
- raise ValueError(f"Cannot process {value} as SourceViewer")
269
- file_path = str(value)
270
- orig_name = Path(file_path).name if Path(file_path).exists() else None
271
- return FileData(path=file_path, orig_name=orig_name)
 
 
 
 
 
 
 
272
 
273
  def stream_output(
274
  self, value, output_id: str, first_chunk: bool
 
4
 
5
  import dataclasses
6
  from pathlib import Path
7
+ from typing import Any, Callable, Literal, Tuple
8
 
9
  import httpx
10
  import numpy as np
 
18
  from gradio.events import Events
19
  from gradio.exceptions import Error
20
 
21
+ from pyannote.core.annotation import Annotation
22
+
23
 
24
  @dataclasses.dataclass
25
  class WaveformOptions:
 
44
  sample_rate: int = 44100
45
 
46
 
47
+ @dataclasses.dataclass
48
+ class Segment:
49
+ start: float
50
+ end: float
51
+ channel: int
52
+
53
+
54
  class SourceViewer(
55
  StreamingInput,
56
  StreamingOutput,
 
248
  )
249
 
250
  def postprocess(
251
+ self, value: Tuple[Annotation, np.ndarray] | None
252
  ) -> FileData | bytes | None:
253
  """
254
  Parameters:
 
256
  Returns:
257
  FileData object, bytes, or None.
258
  """
 
259
  if value is None:
260
  return None
261
+
262
+ annotations, sources = value
263
+ labels = annotations.labels()
264
+
265
+ # format diarization output
266
+ segments = []
267
+ for segment, _, label in annotations.itertracks(yield_label=True):
268
+ label_idx = labels.index(label)
269
+ segments.append(
270
+ Segment(start=segment.start, end=segment.end, channel=label_idx)
 
271
  )
272
+
273
+ # save sources in cache
274
+ source_filepath = processing_utils.save_audio_to_cache(
275
+ data=sources.data,
276
+ sample_rate=16_000,
277
+ format=self.format,
278
+ cache_dir=self.GRADIO_CACHE,
279
+ )
280
+ orig_name = Path(source_filepath).name
281
+
282
+ return {
283
+ "segments": segments,
284
+ "sources_file": FileData(path=source_filepath, orig_name=orig_name),
285
+ }
286
 
287
  def stream_output(
288
  self, value, output_id: str, first_chunk: bool