j
fixed typos
e559e5b
--[[
ReaSpeechWorker.lua - Speech transcription worker
]]--
-- Speech transcription worker class. The class table is built by Polo;
-- presumably Polo invokes init() on construction (confirm in Polo).
ReaSpeechWorker = Polo {}

--- Validate required fields and reset job bookkeeping.
-- The instance must already carry:
--   requests  - queue of incoming requests, consumed from the front
--   responses - list that results and errors are appended to
--   logs      - required here, though not referenced by the code shown here
function ReaSpeechWorker:init()
  for _, field in ipairs({ 'requests', 'responses', 'logs' }) do
    assert(self[field], 'missing ' .. field)
  end

  self.active_job = nil   -- job currently being submitted/polled/downloaded
  self.pending_jobs = {}  -- jobs waiting their turn, oldest first
  self.job_count = 0      -- number of jobs in the current batch (for progress)
end
--- Drive the worker: tick each periodic task with the current time.
-- Every task runs inside app:trap (presumably an error-catching wrapper), so
-- a failure in one task should not prevent the others from running.
function ReaSpeechWorker:react()
  local now = reaper.time_precise()
  for _, task in ipairs(self:interval_functions()) do
    app:trap(function ()
      task:react(now)
    end)
  end
end
--- Lazily build, cache and return the worker's periodic tasks.
-- One task polls for new requests (0.3) and one advances jobs (0.5); the
-- numbers are the first argument to IntervalFunction.new, presumably the
-- interval in seconds (confirm in IntervalFunction).
function ReaSpeechWorker:interval_functions()
  local tasks = self._interval_functions
  if not tasks then
    tasks = {
      IntervalFunction.new(0.3, function () self:react_handle_request() end),
      IntervalFunction.new(0.5, function () self:react_handle_jobs() end),
    }
    self._interval_functions = tasks
  end
  return tasks
end
--- Take the oldest queued request, if any, and process it.
-- At most one request is handled per call.
function ReaSpeechWorker:react_handle_request()
  local next_request = table.remove(self.requests, 1)
  if not next_request then
    return
  end
  self:handle_request(next_request)
end
--- Make progress on jobs, one step per call.
-- If a job is active, only that job is checked. Otherwise the next pending
-- job is promoted and started. When nothing is left and a batch was running,
-- "Processing finished" is logged once and the batch counter is reset.
function ReaSpeechWorker:react_handle_jobs()
  if self.active_job then
    self:check_active_job()
    return
  end

  local next_job = table.remove(self.pending_jobs, 1)
  if next_job then
    self.active_job = next_job
    self:start_active_job()
    return
  end

  -- Queue drained: announce completion once, then reset the batch counter.
  if self.job_count ~= 0 then
    app:log('Processing finished')
    self.job_count = 0
  end
end
--- Fraction of the current batch that is complete, clamped to [0, 1].
-- Returns nil when no batch is in progress (job_count == 0).
-- Each finished job counts as 1. The active job counts as not finished, but
-- contributes current/total from its server-reported progress when that is
-- known (and total is positive), and 0 otherwise.
function ReaSpeechWorker:progress()
  local job_count = self.job_count
  if job_count == 0 then
    return nil
  end

  local remaining = #self.pending_jobs
  local active_fraction = 0

  local active_job = self.active_job
  if active_job then
    remaining = remaining + 1
    local job_progress = active_job.job and active_job.job.progress
    -- Guard against a zero/missing total, which would yield nan or inf.
    if job_progress and job_progress.current
        and job_progress.total and job_progress.total > 0 then
      active_fraction = job_progress.current / job_progress.total
      active_fraction = math.max(0, math.min(1, active_fraction))
    end
  end

  local completed = job_count - remaining + active_fraction
  -- A progress value outside [0, 1] is never meaningful to callers.
  return math.max(0, math.min(1, completed / job_count))
end
--- Last job_status recorded for the active job.
-- Returns nothing when there is no active job or it has no job table.
function ReaSpeechWorker:status()
  local active = self.active_job
  local job = active and active.job
  if job then
    return job.job_status
  end
end
--- Abandon all work.
-- If the active job already has a server-side id, a cancel request is sent
-- for it. Then the active job, the pending queue and the batch counter are
-- all cleared.
function ReaSpeechWorker:cancel()
  local active = self.active_job
  if active then
    local job_id = active.job and active.job.job_id
    if job_id then
      self:cancel_job(job_id)
    end
    self.active_job = nil
  end

  self.pending_jobs = {}
  self.job_count = 0
end
--- Ask the server to delete job `job_id` (DELETE jobs/<job_id>).
-- Only an error handler is supplied; a failure is queued as an error response.
function ReaSpeechWorker:cancel_job(job_id)
  local on_error = function(error_message)
    self:handle_error(self.active_job, error_message)
  end
  ReaSpeechAPI:fetch_json("jobs/" .. job_id, 'DELETE', on_error)
end
--- Fetch the status of job `job_id` and apply it to the active job.
-- Errors whose message contains "500" are retried up to 5 times with
-- exponential backoff (1, 2, 4, ... seconds). Other errors are reported and
-- the job is dropped.
-- @param job_id server-side job identifier
-- @param retry_count retries already attempted (defaults to 0)
function ReaSpeechWorker:get_job_status(job_id, retry_count)
  retry_count = retry_count or 0
  local max_retries = 5
  local retry_delay = 1 * (2 ^ retry_count) -- Exponential backoff
  local url_path = "jobs/" .. job_id

  -- Remember which job this request belongs to, so a late result cannot
  -- modify or clear a job that was cancelled or replaced in the meantime.
  local active_job = self.active_job
  local function is_current()
    return self.active_job == active_job
  end

  ReaSpeechAPI:fetch_json(url_path, 'GET', function(error_message)
    if not is_current() then return end

    if tostring(error_message):match("500") and retry_count < max_retries then
      app:debug("Got 500 error, retrying in " .. retry_delay .. " seconds. Retry " .. (retry_count + 1) .. " of " .. max_retries)
      -- reaper.defer only postpones to the next defer cycle, so keep
      -- re-deferring until the backoff delay has actually elapsed.
      local retry_at = reaper.time_precise() + retry_delay
      local function retry_when_due()
        if not is_current() then
          return -- job was cancelled or replaced while waiting
        end
        if reaper.time_precise() < retry_at then
          reaper.defer(retry_when_due)
          return
        end
        self:get_job_status(job_id, retry_count + 1)
      end
      reaper.defer(retry_when_due)
    else
      self:handle_error(active_job, error_message)
      self.active_job = nil
    end
  end, function(response)
    if not is_current() then return end
    if self:handle_job_status(active_job, response) then
      self.active_job = nil
    end
  end)
end
--- Turn a transcription request into queued jobs.
-- Builds the form data shared by all of the request's jobs (later posted to
-- '/asr' by start_active_job), then queues one pending job per distinct
-- `path` in request.jobs, in request order.
-- job_count grows by the number of jobs actually queued. This keeps
-- progress() correct when duplicate paths are skipped, or when a request
-- arrives while jobs from an earlier one are still pending.
function ReaSpeechWorker:handle_request(request)
  app:log('Processing speech...')

  local data = {
    task = request.translate and 'translate' or 'transcribe',
    output = 'json',
    use_async = 'true',
    vad_filter = request.vad_filter and 'true' or 'false',
    word_timestamps = 'true',
    model_name = request.model_name,
  }

  -- Optional text parameters are only sent when non-empty.
  if request.language and request.language ~= '' then
    data.language = request.language
  end
  if request.hotwords and request.hotwords ~= '' then
    data.hotwords = request.hotwords
  end
  if request.initial_prompt and request.initial_prompt ~= '' then
    data.initial_prompt = request.initial_prompt
  end

  local seen_path = {}
  local queued = 0
  -- ipairs preserves request order; pairs' traversal order is unspecified.
  for _, job in ipairs(request.jobs) do
    if not seen_path[job.path] then
      seen_path[job.path] = true
      table.insert(self.pending_jobs, {job = job, data = data})
      queued = queued + 1
    end
  end

  self.job_count = (self.job_count or 0) + queued
end
--- Apply a job-status response from the server to `active_job`.
-- Records the job id and status. On SUCCESS it starts the transcript
-- download; on FAILURE it reports an error. Otherwise it stores any progress
-- info the server sent.
-- Returns true when the job is finished and should no longer be active.
-- Returns false while it is still running or its transcript is downloading.
function ReaSpeechWorker:handle_job_status(active_job, response)
  app:debug('Active job: ' .. dump(active_job))
  app:debug('Status: ' .. dump(response))

  if response.error then
    table.insert(self.responses, { error = response.error })
    return true
  end

  -- Nothing to update when there is no job (e.g. it was cancelled).
  if not active_job then
    return true
  end

  local job = active_job.job
  -- Keep the known id if a response omits it; otherwise
  -- check_active_job_status would stop polling this job.
  job.job_id = response.job_id or job.job_id
  job.job_status = response.job_status

  local status = response.job_status
  if not status then
    return false
  end

  local result = response.job_result or {}

  if status == 'SUCCESS' then
    -- If SUCCESS is reported more than once (e.g. by overlapping polls), only
    -- start one transcript download per job.
    if active_job.transcript_output_file then
      return false
    end
    if not result.url_path then
      self:handle_error(active_job, 'Job finished but no transcript location was returned')
      return true
    end
    active_job.transcript_output_file, active_job.transcript_output_sentinel_file =
      ReaSpeechAPI:fetch_large(result.url_path)
    -- Job completion depends on the non-blocking transcript download, which
    -- check_active_job_transcript_output_file picks up later.
    return false
  end

  if status == 'FAILURE' then
    -- Never queue an error entry without a message.
    self:handle_error(active_job, result.error or 'Job failed')
    return true
  end

  if result.progress then
    job.progress = result.progress
  end
  return false
end
--- Tag a transcript response with its originating job and queue it.
function ReaSpeechWorker:handle_response(active_job, response)
  response._job = active_job.job
  local responses = self.responses
  responses[#responses + 1] = response
end
--- Queue an error entry `{ error = error_message }` for consumers.
-- The job argument is accepted for symmetry with handle_response but unused.
function ReaSpeechWorker:handle_error(_active_job, error_message)
  local responses = self.responses
  responses[#responses + 1] = { error = error_message }
end
--- Submit the active job's audio and form data to '/asr'.
-- On success, the returned output/sentinel file names are stored on the job
-- so check_active_job can pick up the server's reply.
-- If post_request returns no output file, the failure is reported and the
-- job is dropped, instead of vanishing silently.
-- NOTE(review): confirm post_request does not already report this failure.
function ReaSpeechWorker:start_active_job()
  local active_job = self.active_job
  if not active_job then
    return
  end

  local output_file, sentinel_file =
    ReaSpeechAPI:post_request('/asr', active_job.data, active_job.job.path)

  if not output_file then
    self:handle_error(active_job, "Couldn't send transcription request for " .. tostring(active_job.job.path))
    self.active_job = nil
    return
  end

  active_job.request_output_file = output_file
  active_job.request_output_sentinel_file = sentinel_file
end
--- Advance the active job by one step.
-- First consumes the initial /asr reply if it is still outstanding. Then
-- either checks the transcript download (once it has started) or polls the
-- job's status.
function ReaSpeechWorker:check_active_job()
  local active_job = self.active_job
  if not active_job then return end

  if active_job.request_output_file then
    self:check_active_job_request_output_file()
    -- That handler may finish or drop the job; don't keep working on it.
    if self.active_job ~= active_job then return end
  end

  if active_job.transcript_output_file then
    self:check_active_job_transcript_output_file()
  else
    self:check_active_job_status()
  end
end
--- Poll the server (GET jobs/<job_id>) for the active job's status.
-- Does nothing until the job has a server-side id. Errors are reported and
-- the job is dropped. Status responses go through handle_job_status, which
-- decides whether the job is finished.
function ReaSpeechWorker:check_active_job_status()
  local active_job = self.active_job
  if not (active_job and active_job.job.job_id) then return end

  -- The reply may arrive after the job was cancelled or replaced; such stale
  -- results must not modify the old job or clear the current one.
  local function is_current()
    return self.active_job == active_job
  end

  ReaSpeechAPI:fetch_json("jobs/" .. active_job.job.job_id, 'GET',
    function(error_message)
      -- Error handler
      if not is_current() then return end
      self:handle_error(active_job, error_message)
      self.active_job = nil
    end,
    function(response)
      -- Success handler
      if not is_current() then return end
      if self:handle_job_status(active_job, response) then
        self.active_job = nil
      end
    end
  )
end
--- Return true if `filename` can be opened for reading, false otherwise.
-- Used to detect that a sentinel file exists. Plain function (called with
-- `.`, not `:`).
function ReaSpeechWorker.check_sentinel(filename)
  local handle = io.open(filename, 'r')
  if handle then
    handle:close()
  end
  return handle ~= nil
end
--- Read and decode a JSON HTTP reply written to `output_file`.
-- Does nothing until `sentinel_file` exists. Status -1 from
-- http_status_and_body means "not ready yet" and is retried on a later call.
-- Otherwise exactly one of success_f(response_table) or fail_f(message) is
-- called. Temp files are removed only after success, so failed replies stay
-- on disk for debugging.
function ReaSpeechWorker:handle_response_json(output_file, sentinel_file, success_f, fail_f)
  if not self.check_sentinel(sentinel_file) then
    return
  end

  local f = io.open(output_file, 'r')
  if not f then
    fail_f("Couldn't open output file: " .. tostring(output_file))
    Tempfile:remove(sentinel_file)
    return
  end

  local http_status, body = ReaSpeechAPI.http_status_and_body(f)
  f:close()
  -- Treat a missing body as empty so the string operations below are safe.
  body = body or ''

  if http_status == -1 then
    app:debug(body .. ", trying again later")
    return
  end

  if http_status ~= 200 then
    local msg = "Server responded with status " .. tostring(http_status)
    fail_f(msg)
    app:log(msg)
    app:debug(body)
    return
  end

  if body == '' then
    fail_f("Empty response from server")
    return
  end

  local response = nil
  local decoded = app:trap(function ()
    response = json.decode(body)
  end)
  -- Callers index the response as a table, so anything else (e.g. a body of
  -- "null") is treated as a parse failure.
  if not decoded or type(response) ~= 'table' then
    fail_f("Error parsing response JSON")
    return
  end

  success_f(response)

  -- Clean up temp files only on success; failures keep them for debugging.
  Tempfile:remove(output_file)
  Tempfile:remove(sentinel_file)
end
--- Consume the server's reply to the initial /asr submission.
-- A successful reply is applied via handle_job_status. Its file fields are
-- then cleared, so check_active_job stops re-probing files that
-- handle_response_json has already removed. A failed reply drops the job.
function ReaSpeechWorker:check_active_job_request_output_file()
  local active_job = self.active_job
  if not active_job then return end

  self:handle_response_json(
    active_job.request_output_file,
    active_job.request_output_sentinel_file,
    function(response)
      active_job.request_output_file = nil
      active_job.request_output_sentinel_file = nil
      if self:handle_job_status(active_job, response) then
        self.active_job = nil
      end
    end,
    function(error_message)
      self:handle_error(active_job, error_message)
      self.active_job = nil
    end
  )
end
--- Consume the downloaded transcript for the active job.
-- When the download is ready, the decoded transcript is queued via
-- handle_response, or the error via handle_error. Either way the job is then
-- no longer active. Safe to call when there is no active job.
function ReaSpeechWorker:check_active_job_transcript_output_file()
  local active_job = self.active_job
  if not active_job then return end

  self:handle_response_json(
    active_job.transcript_output_file,
    active_job.transcript_output_sentinel_file,
    function(response)
      self:handle_response(active_job, response)
      self.active_job = nil
    end,
    function(error_message)
      self:handle_error(active_job, error_message)
      self.active_job = nil
    end
  )
end