File size: 3,658 Bytes
85378a6
 
 
 
 
7ec99bc
85378a6
 
 
 
475a8bc
85378a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ec99bc
85378a6
475a8bc
85378a6
 
 
 
 
 
 
 
 
 
7ec99bc
85378a6
 
 
7ec99bc
85378a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import base64
import io
import logging
import os
from pathlib import Path
from typing import Literal, Union

from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator

from .base import BasicMessageReceiver, BasicMessageSender


class TranscribeInputMessage(BaseModel):
    """Transcription request received from the input queue.

    ``audio_file`` holds either a filesystem path to the audio, or — when
    ``using_file_content`` is true — the audio content itself, expected to be
    base64-encoded bytes.
    """

    uuid: str = Field(title='Request Unique Id.')
    # Path to the audio file, or base64-encoded audio when using_file_content is True.
    audio_file: str
    # Language code forwarded to the ASR model; only these two are accepted.
    language: Literal['en', 'zh']
    # Selects how audio_file is interpreted (inline content vs. path).
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure ``audio_file`` points to an existing regular file unless it carries inline content.

        Returns:
            The validated model instance.

        Raises:
            FileNotFoundError: if ``audio_file`` is a path that does not name a
                regular file (missing, or e.g. a directory). Pydantic only wraps
                ``ValueError``/``AssertionError`` into ``ValidationError``, so this
                exception propagates to the caller of ``model_validate`` unchanged.
        """
        if self.using_file_content:
            # Inline content has no path to check.
            return self

        # is_file() rather than exists(): a directory would pass exists() and only
        # fail later, with a less helpful error, inside the transcription model.
        if not Path(self.audio_file).is_file():
            raise FileNotFoundError(f'Audio file does not exist: {self.audio_file}')
        return self


class TranscribeOutputMessage(BaseModel):
    """Reply published for each transcription request.

    ``if_success`` reports the outcome and ``msg`` carries a status or error
    text. ``transcribed_text`` defaults to an empty string, which is what
    failure replies end up with since they do not set it.
    """

    uuid: str  # uuid of the request this reply answers
    if_success: bool  # True when transcription completed
    msg: str  # human-readable status or error description
    transcribed_text: str = ''  # recognized text; empty when not provided


class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests and publish the results.

    Requests are read from the ``transcribe-input`` queue bound to the
    ``transcribe`` exchange and validated as ``TranscribeInputMessage``. Every
    request is answered with a ``TranscribeOutputMessage`` published to the same
    exchange under the ``transcribe-output`` routing key.
    """

    def __init__(self):
        super().__init__()

        self.exchange_name = 'transcribe'
        self.queue_name = 'transcribe-input'
        self.routing_key = 'transcribe-input'

        self.setup_consume_parameters()
        self.setup_message_sender()

        # The hyphenated name is the original setting and still takes precedence.
        # The underscore form is also accepted because POSIX shells cannot
        # ``export`` variable names that contain '-'.
        model_size = (
            os.environ.get('WHISPER-MODEL-SIZE')
            or os.environ.get('WHISPER_MODEL_SIZE')
            or 'large-v3'
        )
        # Run on GPU with FP16
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')

    def setup_consume_parameters(self):
        """Declare the exchange and input queue, and bind them with the input routing key."""
        self.declare_exchange(self.exchange_name)
        # NOTE(review): meaning of max_priority=-1 is defined in .base — presumably
        # "no priority queue"; confirm there.
        self.declare_queue(self.queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.queue_name, self.routing_key)

    def setup_message_sender(self):
        """Create the sender used to publish replies."""
        self.sender = BasicMessageSender()

    def send_message(self, message: Union[dict, str]):
        """Publish *message* to the ``transcribe`` exchange with routing key ``transcribe-output``."""
        routing_key = 'transcribe-output'
        self.sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None,
        )

    def send_success_message(self, uuid: str, transcribed_text: str):
        """Publish a success reply carrying *transcribed_text* for request *uuid*."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure reply for request *uuid* with *error* as the message."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @staticmethod
    def _request_uuid(body) -> str:
        """Best-effort extraction of the request uuid from a decoded message body.

        Returns '' when *body* is not a dict or has no ``uuid``. This means a
        failure reply can always be built, because ``TranscribeOutputMessage.uuid``
        must be a str.
        """
        if isinstance(body, dict):
            uuid = body.get('uuid')
            if uuid is not None:
                return str(uuid)
        return ''

    def _transcribe(self, request: TranscribeInputMessage) -> str:
        """Run the ASR model on the request's audio and return the complete transcript."""
        audio = request.audio_file
        if request.using_file_content:
            # audio_file carries base64-encoded audio bytes rather than a path.
            audio = io.BytesIO(base64.b64decode(request.audio_file))

        segments, _ = self.asr_model.transcribe(audio, language=request.language)
        # faster-whisper yields segments lazily: decoding happens while iterating.
        # Consume every segment so long audio is not truncated to its first segment.
        return ''.join(segment.text for segment in segments)

    def consume(self, channel, method, properties, message):
        """Handle one request: validate it, transcribe it, and publish a reply.

        Any exception raised while validating or transcribing is logged and
        answered with a failure reply instead of propagating.
        """
        body = self.decode_message(message)

        try:
            validated_message = TranscribeInputMessage.model_validate(body)
            transcribed_text = self._transcribe(validated_message)
        except Exception as exc:  # broad on purpose: every failure must be reported back
            uuid = self._request_uuid(body)
            # Truncate the logged message: with using_file_content it embeds the
            # entire base64-encoded audio file.
            logging.exception('Consume message failed (uuid=%r): \n message: %.500r\n\n exception info: %s',
                              uuid, message, exc)
            self.send_fail_message(uuid, f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)