File size: 3,534 Bytes
402daee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
-- Let `require` find shared project libraries and vendored dependencies.
-- The paths are relative, so they resolve against the working directory the
-- tests are launched from.
package.path = '../common/libs/?.lua;../common/vendor/?.lua;' .. package.path

-- Deliberately global: TestSRTWriter:setUp installs `app:trap` on it.
-- NOTE(review): presumably the modules under test look `app` up as a global;
-- confirm before making it local.
app = {}

local lu = require('luaunit')

-- NOTE(review): mock_reaper presumably defines the global `reaper` stub that
-- is patched below and reset in setUp; confirm.
require('mock_reaper')
require('Polo')
require('source/SRTWriter')
require('source/Transcript')

--

-- Stub the REAPER media-source API: every take reports a fresh fake source
-- whose file name is "test_audio.wav". The file-name lookup simply reads that
-- field back off whatever source it is handed.
function reaper.GetMediaItemTake_Source()
  local fake_source = { fileName = "test_audio.wav" }
  return fake_source
end

function reaper.GetMediaSourceFileName(source)
  local name = source.fileName
  return name
end

TestSRTWriter = {}

--- Per-test fixture.
-- Installs `app:trap`, which runs `f` under `xpcall` and prints any error
-- instead of propagating it. It then resets the REAPER mock via
-- `reaper.__test_setUp()`.
function TestSRTWriter:setUp()
  -- Handler returns nothing, so a trapped failure yields `false, nil`.
  local function report(err)
    print(tostring(err))
  end

  app.trap = function (_, fn)
    return xpcall(fn, report)
  end

  reaper.__test_setUp()
end

--- Build the two-segment fixture transcript used by the write tests.
-- Segments: "hello" spanning 0-1 and "world" spanning 1-2. Each segment gets
-- its own empty `item`/`take` placeholder tables. `update()` is called once,
-- after all segments have been added.
-- @treturn Transcript the populated transcript
function TestSRTWriter.make_transcript()
  local rows = {
    {start = 0, ['end'] = 1, text = 'hello'},
    {start = 1, ['end'] = 2, text = 'world'},
  }

  local transcript = Transcript.new()
  for _, row in ipairs(rows) do
    transcript:add_segment(TranscriptSegment.new {
      data = row,
      item = {},
      take = {},
    })
  end

  transcript:update()
  return transcript
end

--- format_time renders a number of seconds as an "HH:MM:SS,mmm" timestamp.
-- Cases cover zero, whole and fractional seconds, and minute and hour
-- rollover.
function TestSRTWriter:testFormatTime()
  local cases = {
    {0,      '00:00:00,000'},
    {1,      '00:00:01,000'},
    {1.5,    '00:00:01,500'},
    {60,     '00:01:00,000'},
    {60.5,   '00:01:00,500'},
    {3600,   '01:00:00,000'},
    {3600.5, '01:00:00,500'},
  }

  for _, case in ipairs(cases) do
    local seconds, expected = case[1], case[2]
    lu.assertEquals(SRTWriter.format_time(seconds), expected)
  end
end

--- Constructing a writer with a file handle succeeds and yields an object.
-- The original version asserted nothing, so it could only fail by raising.
-- It now also checks that the constructor actually returns something.
function TestSRTWriter:testInit()
  local writer = SRTWriter.new { file = {} }
  lu.assertNotNil(writer)
end

--- Constructing a writer with no arguments must raise an error whose message
-- contains "missing file".
function TestSRTWriter:testInitNoFile()
  local construct_without_file = function ()
    SRTWriter.new()
  end
  lu.assertErrorMsgContains('missing file', construct_without_file)
end

--- Build a fake file handle that records everything written to it.
-- Follows the parts of Lua's `file:write` contract a writer may rely on:
-- it accepts any number of string/number arguments and returns the handle,
-- so chained writes keep working.
-- @treturn table the fake handle, to pass as `SRTWriter.new { file = ... }`
-- @treturn function returns all recorded output joined into one string
function TestSRTWriter.make_file()
  local chunks = {}
  local file = {}

  function file:write(...)
    for i = 1, select('#', ...) do
      chunks[#chunks + 1] = (select(i, ...))
    end
    return self
  end

  local function contents()
    return table.concat(chunks)
  end

  return file, contents
end

--- Writing a whole transcript emits one numbered cue per segment.
-- Each cue is index, "start --> end" timing line, text, then a blank line.
function TestSRTWriter:testWrite()
  local t = TestSRTWriter.make_transcript()
  local f, contents = TestSRTWriter.make_file()
  local writer = SRTWriter.new { file = f }
  writer:write(t)
  lu.assertEquals(contents(), '1\n00:00:00,000 --> 00:00:01,000\nhello\n\n2\n00:00:01,000 --> 00:00:02,000\nworld\n\n')
end

--- With coords_* options set, every timing line gets a position suffix.
-- The suffix is appended in X1 X2 Y1 Y2 order.
function TestSRTWriter:testXYCoordinates()
  local t = TestSRTWriter.make_transcript()
  local f, contents = TestSRTWriter.make_file()
  local writer = SRTWriter.new {
    file = f,
    options = {
      coords_x1 = '1',
      coords_y1 = '2',
      coords_x2 = '3',
      coords_y2 = '4'
    }
  }
  writer:write(t)
  lu.assertEquals(contents(), '1\n00:00:00,000 --> 00:00:01,000 X1:1 X2:3 Y1:2 Y2:4\nhello\n\n2\n00:00:01,000 --> 00:00:02,000 X1:1 X2:3 Y1:2 Y2:4\nworld\n\n')
end

--- write_segment emits a single cue from a segment-like object.
function TestSRTWriter:testWriteSegment()
  local f, contents = TestSRTWriter.make_file()
  local writer = SRTWriter.new { file = f }
  -- Minimal stand-in for a TranscriptSegment: only `get` is provided.
  -- It returns nil for any key other than start/end/text.
  local fields = {start = 0, ['end'] = 1, text = 'hello'}
  local segment = {
    get = function (_, key) return fields[key] end
  }
  writer:write_segment(segment, 1)
  lu.assertEquals(contents(), '1\n00:00:00,000 --> 00:00:01,000\nhello\n\n')
end

--- write_line emits a single cue from raw values.
-- NOTE(review): the argument order looks like (text, index, start, end),
-- inferred from the expected output; confirm against SRTWriter.
function TestSRTWriter:testWriteLine()
  local f, contents = TestSRTWriter.make_file()
  local writer = SRTWriter.new { file = f }
  writer:write_line('hello', 1, 0, 1)
  lu.assertEquals(contents(), '1\n00:00:00,000 --> 00:00:01,000\nhello\n\n')
end

--

-- Run all registered Test* tables, then hand whatever the luaunit runner
-- returns to os.exit as the process exit status.
local status = lu.LuaUnit.run()
os.exit(status)