File size: 5,484 Bytes
04ffec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f3bd14
04ffec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f3bd14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04ffec9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import hmac
import logging
import os
import time

from fastapi import File, Form, Body, Query
from fastapi.responses import FileResponse

from speakers.server.model.flow_data import PayLoad
from speakers.server.model.result import (BaseResponse,
                                          TaskInfoResponse,
                                          TaskVoiceFlowInfo,
                                          RunnerState,
                                          TaskRunnerResponse)
from speakers.server.bootstrap.bootstrap_register import get_bootstrap
from speakers.common.utils import get_tmp_path
from speakers.common.registry import registry

# Module-level logger used by every handler below; it can be rebound at
# runtime through set_server_runner_logger().
logger = logging.getLogger('server_runner')


def set_server_runner_logger(l: logging.Logger) -> None:
    """Replace the module-level ``logger`` with *l*.

    Every handler in this module logs through the global ``logger`` name,
    so rebinding it here redirects their output to the supplied logger.
    """
    global logger
    logger = l


def constant_compare(a, b):
    """Return True when *a* and *b* hold the same bytes, using a timing-safe compare.

    ``str`` arguments are encoded as UTF-8 before comparison, so a ``str``
    and a ``bytes`` value with the same content compare equal.  Any argument
    that is neither ``str`` nor ``bytes`` (for example ``None``) makes the
    function return False instead of raising.

    The comparison itself is delegated to ``hmac.compare_digest`` rather than
    a hand-written XOR loop, so the time taken does not depend on where the
    two values first differ.
    """
    if isinstance(a, str):
        a = a.encode('utf-8')
    if isinstance(b, str):
        b = b.encode('utf-8')
    # Reject anything that is still not bytes (None, int, bytearray, ...).
    if not isinstance(a, bytes) or not isinstance(b, bytes):
        return False
    return hmac.compare_digest(a, b)


async def submit_async(payload: PayLoad):
    """
        Register a task for the given payload and report its current state.

        The task class is looked up in the registry by
        ``payload.parameter.task_name``; its ``prepare`` call yields the
        runner whose ``task_id`` identifies the task.

        * If ``result/<task_id>.wav`` already exists, the reported state is
          ``saved`` / finished.
        * Otherwise, if the task is not yet tracked, the reported state is
          ``pending`` / unfinished.
        * Otherwise the state already stored for the task is returned.

        In the first two cases a task that is missing from ``task_data`` or
        ``task_states`` is recorded and its id appended to the queue.
    """
    web = get_bootstrap("runner_bootstrap_web")
    task_cls = registry.get_task_class(payload.parameter.task_name)
    task_id = task_cls.prepare(payload=payload).task_id

    # Both timestamps are stamped with the same instant.
    payload.created_at = payload.requested_at = time.time()

    result_exists = os.path.exists(get_tmp_path(f'result/{task_id}.wav'))
    untracked = not (task_id in web.task_data and task_id in web.task_states)

    # Already tracked and no result on disk yet: just echo the stored state.
    if not result_exists and not untracked:
        return TaskRunnerResponse(code=200, msg="提交任务成功",
                                  data=web.task_states[task_id])

    if result_exists:
        current_state = {
            'task_id': task_id,
            'info': 'saved',
            'finished': True,
        }
    else:
        os.makedirs(get_tmp_path('result'), exist_ok=True)
        current_state = {
            'task_id': task_id,
            'info': 'pending',
            'finished': False,
        }

    if untracked:
        logger.info(f'New `submit` task {task_id}')
        web.task_data[task_id] = payload
        web.queue.append(task_id)
        web.task_states[task_id] = current_state

    return TaskRunnerResponse(code=200, msg="提交任务成功", data=current_state)


async def get_task_async(nonce: str = Query(..., examples=["samples"])):
    """
    Hand the next queued task to a caller that presents the right nonce.

    Responses:
      * 401 when the nonce does not match the bootstrap's nonce;
      * 200 with msg ``max_ongoing_tasks`` when the number of ongoing tasks
        has reached ``max_ongoing_tasks``;
      * 200 with the task id and its payload when a queued task with stored
        data was popped (the id is then added to ``ongoing_tasks``);
      * 200 without data when the queue is empty or the popped id has no
        stored payload.
    """
    web = get_bootstrap("runner_bootstrap_web")

    if not constant_compare(nonce, web.nonce):
        return BaseResponse(code=401, msg="无法获取任务")

    if len(web.ongoing_tasks) >= web.max_ongoing_tasks:
        return BaseResponse(code=200, msg="max_ongoing_tasks")

    if len(web.queue) > 0:
        next_id = web.queue.popleft()
        if next_id in web.task_data:
            payload = web.task_data[next_id]
            web.ongoing_tasks.append(next_id)
            return TaskInfoResponse(
                code=200,
                msg="成功",
                data=TaskVoiceFlowInfo(task_id=next_id, data=payload),
            )

    return BaseResponse(code=200, msg="成功")


async def post_task_update_async(runner_state: RunnerState):
    """
    Record the state a runner reports for the task it is working on.

    The update is applied only when ``runner_state.nonce`` matches the
    bootstrap's nonce and the task id is present in both ``task_states``
    and ``task_data``.  A finished task is also removed from
    ``ongoing_tasks``.  The response is always code 200, whether or not
    anything was updated.
    """
    web = get_bootstrap("runner_bootstrap_web")

    if not constant_compare(runner_state.nonce, web.nonce):
        return BaseResponse(code=200, msg="成功")

    task_id = runner_state.task_id
    if task_id not in web.task_states or task_id not in web.task_data:
        return BaseResponse(code=200, msg="成功")

    new_state = {
        'info': runner_state.state,
        'finished': runner_state.finished,
    }
    web.task_states[task_id] = new_state

    if runner_state.finished:
        # The task may already be gone from the ongoing list; that is fine.
        try:
            web.ongoing_tasks.remove(task_id)
        except ValueError:
            pass

    logger.info(f'Task state {task_id} to {web.task_states[task_id]}')
    return BaseResponse(code=200, msg="成功")


async def result_async(task_id: str = Query(..., examples=["task_id"])):
    """
    Return the ``<task_id>.wav`` result file for a task.

    The file is looked up at ``get_tmp_path('result/<task_id>.wav')``.  When
    it exists it is returned as a ``FileResponse``; otherwise, or when any
    error occurs, a ``BaseResponse`` with code 500 is returned.

    ``task_id`` arrives from the query string, so an id that contains a
    directory component (for example ``../x``) is rejected with the same
    failure response rather than being used to build a path outside the
    ``result`` directory.
    """
    try:
        # A legitimate id is a bare file name; anything with a path
        # separator (or a drive prefix on Windows) is refused up front.
        if os.path.basename(task_id) != task_id:
            logger.warning(f'Task  {task_id!r} result_async rejected: task_id is not a bare file name')
            return BaseResponse(code=500, msg=f"{task_id}.wav 读取文件失败")

        filepath = get_tmp_path(f'result/{task_id}.wav')
        logger.info(f'Task  {task_id} result_async {filepath}')
        if os.path.exists(filepath):
            # NOTE(review): "multipart/form-data" is an unusual media type for
            # a WAV download; kept unchanged so existing clients are unaffected.
            return FileResponse(
                path=filepath,
                filename=f"{task_id}.wav",
                media_type="multipart/form-data")
        else:
            return BaseResponse(code=500, msg=f"{task_id}.wav 读取文件失败")

    except Exception as e:
        logger.error(f'{e.__class__.__name__}: {e}',
                     exc_info=e)
        return BaseResponse(code=500, msg=f"{task_id}.wav 读取文件失败")