File size: 5,397 Bytes
85378a6
 
a829e96
85378a6
 
 
7ec99bc
85378a6
 
 
 
475a8bc
85378a6
 
d057131
 
 
 
 
 
 
 
 
85378a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f62fb80
 
 
 
85378a6
 
e5e423e
85378a6
d057131
85378a6
 
 
d057131
85378a6
 
f62fb80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85378a6
f62fb80
 
85378a6
7ec99bc
85378a6
475a8bc
fa21f8b
 
85378a6
 
 
 
 
d057131
 
 
 
85378a6
 
 
 
7ec99bc
85378a6
 
 
7ec99bc
85378a6
 
d057131
a829e96
 
 
 
 
 
85378a6
 
 
 
 
 
 
 
c08591d
85378a6
 
f62fb80
85378a6
f62fb80
 
 
85378a6
a829e96
85378a6
 
a829e96
85378a6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import base64
import io
import json
import logging
import os
from pathlib import Path
from typing import Literal, Union

from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator

from .base import BasicMessageReceiver, BasicMessageSender


def setup_logger():
    """Configure root logging at INFO level and return this module's logger."""
    # basicConfig and getLogger are independent calls, so their order is irrelevant.
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)


# Module-level logger, created once at import time via setup_logger().
logger = setup_logger()


class TranscribeInputMessage(BaseModel):
    """Schema of an incoming transcription request.

    ``audio_file`` is validated as an existing filesystem path unless
    ``using_file_content`` is set, in which case no path check is performed.
    """

    uuid: str = Field(title='Request Unique Id.')
    audio_file: str
    language: Literal['en', 'zh']
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure ``audio_file`` points to an existing path when it is used as a path.

        Raises:
            FileNotFoundError: ``using_file_content`` is false and the path does not exist.
        """
        must_exist_on_disk = not self.using_file_content
        if must_exist_on_disk and not Path(self.audio_file).exists():
            raise FileNotFoundError('Audio file not exists.')
        return self


class TranscribeOutputMessage(BaseModel):
    """Schema of the result message published for a transcription request."""

    # Identifier echoed back from the request.
    uuid: str
    # True when transcription finished, False on any failure.
    if_success: bool
    # Human-readable status or error description.
    msg: str
    # Transcription result; stays empty when not supplied.
    transcribed_text: str = ''


class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests and publish the transcription outcome.

    Requests arrive on the ``transcribe-input`` queue; each one is validated,
    transcribed with a faster-whisper model, and answered with a JSON-serialized
    ``TranscribeOutputMessage`` sent with the ``transcribe-output`` routing key
    on the ``transcribe`` exchange.
    """

    def __init__(self):
        """Declare exchange/queues and load the Whisper model on the GPU."""
        super().__init__()

        self.exchange_name = 'transcribe'
        self.input_queue_name = 'transcribe-input'
        self.input_routing_key = 'transcribe-input'
        self.output_queue_name = 'transcribe-output'
        self.output_routing_key = 'transcribe-output'

        self.setup_consume_parameters()
        self.setup_producer_parameters()

        logger.info('Loading model...')
        # Model size is configurable through the environment; defaults to large-v3.
        model_size = os.environ.get('WHISPER-MODEL-SIZE', 'large-v3')
        # Run on GPU with FP16
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logger.info('Load model finished.')

    def setup_consume_parameters(self):
        """Declare the exchange and the input queue, then bind them together."""
        logger.info(
            f'Create consumer exchange: {self.exchange_name}, '
            f'routing-key: {self.input_routing_key}, '
            f'queue: {self.input_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        # NOTE(review): max_priority=-1 presumably disables queue priorities -- confirm in base class.
        self.declare_queue(self.input_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.input_queue_name, self.input_routing_key)

    def setup_producer_parameters(self):
        """Declare the exchange and the output queue, then bind them together."""
        logger.info(
            f'Create producer exchange: {self.exchange_name}, '
            f'routing-key: {self.output_routing_key}, '
            f'queue: {self.output_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        # NOTE(review): max_priority=-1 presumably disables queue priorities -- confirm in base class.
        self.declare_queue(self.output_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.output_queue_name, self.output_routing_key)

    def send_message(self, message: Union[dict, str]):
        """Publish ``message`` to the output routing key of the exchange.

        Args:
            message: Message body handed unchanged to ``BasicMessageSender.send_message``.
        """
        # Use the configured key (the one the output queue is bound with) rather
        # than a duplicated string literal, so the two can never drift apart.
        routing_key = self.output_routing_key
        sender = BasicMessageSender()
        sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None
        )
        logger.info(f'{"-" * 80}')
        logger.info(f"Send message to Exchange: {self.exchange_name}, Routing-key: {routing_key}, \n"
                    f"Messgae body: {message}")
        logger.info(f'{"-" * 80}')

    def send_success_message(self, uuid: str, transcribed_text):
        """Publish a success result carrying ``transcribed_text`` for request ``uuid``."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure result for request ``uuid`` with ``error`` as the message."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @staticmethod
    def _extract_uuid(body) -> str:
        """Return the request uuid from a decoded body, or ``''`` when unavailable.

        The failure path must never raise itself: the body may not be a dict, and
        the uuid may be missing or not a string, any of which would otherwise break
        construction of the failure message.
        """
        if isinstance(body, dict):
            candidate = body.get('uuid')
            if isinstance(candidate, str):
                return candidate
        return ''

    def consume(self, channel, method, properties, message):
        """Handle one incoming message: decode, validate, transcribe, and reply.

        Args:
            channel: Unused by this handler.
            method: Unused by this handler.
            properties: Unused by this handler.
            message: Raw message passed to ``self.decode_message``.

        A failure message is published when decoding, validation, or transcription
        fails; a success message is published otherwise.
        """
        logger.info(f'Recevied a message: {message}')
        try:
            body = self.decode_message(message)
        except json.JSONDecodeError as exc:
            logger.exception('Message decode failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message('', f'Message decode failed, message: \n {message}')
            return

        # Resolve the uuid up front so the error handlers below cannot fail on a
        # malformed body (non-dict payload, missing or non-string uuid).
        request_uuid = self._extract_uuid(body)

        try:
            validated_message = TranscribeInputMessage.model_validate(body)

            audio_file = validated_message.audio_file
            if validated_message.using_file_content:
                # The field carries base64-encoded audio bytes instead of a path.
                audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))

            logger.info('Start transcribe input...')
            segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)

            # Iterating the segments is kept inside the try block so that errors
            # raised while consuming them are reported as a failure message.
            transcribed_text = ', '.join(segment.text for segment in segments)
            logger.info(f'Transcribed text: {transcribed_text}')
        except ValidationError as exc:
            logger.exception('Message validated failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(request_uuid, f'{exc}')
        except Exception as exc:
            # Top-level boundary of the handler: report any failure back to the requester.
            logger.exception('Consume message failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(request_uuid, f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)