Anustup commited on
Commit
b0207ad
·
1 Parent(s): e763e44

Create vad_test.py

Browse files
Files changed (1) hide show
  1. tests/vad_test.py +66 -0
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../NS_AI_LABS')
7
+
8
+ from src.vad import AbstractTranscription, VadSileroTranscription
9
+
10
+ class TestVad(unittest.TestCase):
11
+ def __init__(self, *args, **kwargs):
12
+ super(TestVad, self).__init__(*args, **kwargs)
13
+ self.transcribe_calls = []
14
+
15
+ def test_transcript(self):
16
+ mock = MockVadTranscription()
17
+
18
+ self.transcribe_calls.clear()
19
+ result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
+
21
+ self.assertListEqual(self.transcribe_calls, [
22
+ [30, 30],
23
+ [100, 100]
24
+ ])
25
+
26
+ self.assertListEqual(result['segments'],
27
+ [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
28
+ {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
29
+ )
30
+
31
+ def transcribe_segments(self, segment):
32
+ self.transcribe_calls.append(segment.tolist())
33
+
34
+ # Dummy text
35
+ return {
36
+ 'text': "Hello world ",
37
+ 'segments': [
38
+ {
39
+ "start": 10.0,
40
+ "end": 20.0,
41
+ "text": "Hello world "
42
+ }
43
+ ],
44
+ 'language': ""
45
+ }
46
+
47
+ class MockVadTranscription(AbstractTranscription):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
+ start_time_seconds = float(start_time.removesuffix("s"))
53
+ duration_seconds = float(duration.removesuffix("s"))
54
+
55
+ # For mocking, this just returns a simple numppy array
56
+ return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
57
+
58
+ def get_transcribe_timestamps(self, audio: str):
59
+ result = []
60
+
61
+ result.append( { 'start': 30, 'end': 60 } )
62
+ result.append( { 'start': 100, 'end': 200 } )
63
+ return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()