|
import base64 |
|
import io |
|
import json |
|
import logging |
|
import os |
|
from pathlib import Path |
|
from typing import Literal, Union |
|
|
|
from faster_whisper import WhisperModel |
|
from pydantic import BaseModel, Field, ValidationError, model_validator |
|
|
|
from .base import BasicMessageReceiver, BasicMessageSender |
|
|
|
|
|
def setup_logger():
    """Enable INFO-level root logging and hand back this module's logger.

    ``logging.basicConfig`` only takes effect if the root logger has no
    handlers yet, so calling this more than once is harmless.
    """
    logging.basicConfig(level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger


# Module-wide logger shared by every class below.
logger = setup_logger()
|
|
|
|
|
class TranscribeInputMessage(BaseModel):
    """Inbound transcription request.

    Attributes:
        uuid: Request id; the consumer echoes it back in its reply.
        audio_file: A filesystem path to the audio file or, when
            ``using_file_content`` is true, the base64-encoded audio bytes.
        language: Language code handed to the ASR model.
        using_file_content: Whether ``audio_file`` carries inline base64
            content instead of a path.
    """

    uuid: str = Field(title='Request Unique Id.')
    # Path on disk, or base64-encoded audio when using_file_content is True.
    audio_file: str
    language: Literal['en', 'zh',]
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure a path-based ``audio_file`` refers to an existing file.

        Inline (base64) content is not inspected here.

        Raises:
            FileNotFoundError: if ``audio_file`` is a path that does not exist.
                pydantic only converts ``ValueError``/``AssertionError`` into
                ``ValidationError``, so this exception propagates unchanged
                out of ``model_validate``.
        """
        if self.using_file_content:
            return self

        if not Path(self.audio_file).exists():
            # Include the path so the failure reply is actionable.
            raise FileNotFoundError(f'Audio file does not exist: {self.audio_file}')
        return self
|
|
|
|
|
class TranscribeOutputMessage(BaseModel):
    """Reply message describing the outcome of one transcription request.

    Attributes:
        uuid: Id of the request being answered.
        if_success: True when the transcription finished, False on failure.
        msg: Human-readable status or error description.
        transcribed_text: The transcript; defaults to an empty string, which
            is what failure replies carry since they do not set it.
    """

    uuid: str
    if_success: bool
    msg: str
    transcribed_text: str = ''
|
|
|
|
|
class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests and publish transcription results.

    Requests are read from the ``transcribe-input`` queue bound to the
    ``transcribe`` exchange, validated as ``TranscribeInputMessage``,
    transcribed with a faster-whisper ``WhisperModel`` and answered with a
    ``TranscribeOutputMessage`` published under the ``transcribe-output``
    routing key.

    Environment variables:
        WHISPER_MODEL_SIZE (legacy spelling: WHISPER-MODEL-SIZE): model
            size/name, default ``'large-v3'``.
        WHISPER_DEVICE: device passed to ``WhisperModel``, default ``'cuda'``.
        WHISPER_COMPUTE_TYPE: compute type passed to ``WhisperModel``,
            default ``'float16'``.
    """

    # Max characters of a raw request written to the log; inline requests
    # can carry an entire base64-encoded audio file.
    _LOG_PREVIEW_CHARS = 1024

    def __init__(self):
        super().__init__()

        self.exchange_name = 'transcribe'
        self.input_queue_name = 'transcribe-input'
        self.input_routing_key = 'transcribe-input'
        self.output_queue_name = 'transcribe-output'
        self.output_routing_key = 'transcribe-output'

        self.setup_consume_parameters()
        self.setup_producer_parameters()

        logger.info('Loading model...')
        # Prefer the underscore spelling: POSIX shells cannot `export` a name
        # containing '-'. The original hyphenated name remains a fallback.
        model_size = os.environ.get(
            'WHISPER_MODEL_SIZE',
            os.environ.get('WHISPER-MODEL-SIZE', 'large-v3'),
        )
        device = os.environ.get('WHISPER_DEVICE', 'cuda')
        compute_type = os.environ.get('WHISPER_COMPUTE_TYPE', 'float16')

        self.asr_model = WhisperModel(model_size, device=device, compute_type=compute_type)
        logger.info('Load model finished.')

    def setup_consume_parameters(self):
        """Declare the exchange and the input queue, then bind them by the input routing key."""
        logger.info(
            f'Create consumer exchange: {self.exchange_name}, '
            f'routing-key: {self.input_routing_key}, '
            f'queue: {self.input_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        self.declare_queue(self.input_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.input_queue_name, self.input_routing_key)

    def setup_producer_parameters(self):
        """Declare the exchange and the output queue, then bind them by the output routing key."""
        logger.info(
            f'Create producer exchange: {self.exchange_name}, '
            f'routing-key: {self.output_routing_key}, '
            f'queue: {self.output_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        self.declare_queue(self.output_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.output_queue_name, self.output_routing_key)

    def send_message(self, message: Union[dict, str]):
        """Publish *message* on this consumer's exchange under the output routing key."""
        # Use the configured key instead of repeating the literal.
        routing_key = self.output_routing_key

        # NOTE(review): a new sender is created per message; if it opens its
        # own broker connection, consider reusing one -- TODO confirm
        # against BasicMessageSender.
        sender = BasicMessageSender()
        sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None,
        )
        logger.info('-' * 80)
        logger.info('Send message to Exchange: %s, Routing-key: %s, \nMessage body: %s',
                    self.exchange_name, routing_key, message)
        logger.info('-' * 80)

    def send_success_message(self, uuid: str, transcribed_text):
        """Reply to request *uuid* with a successful result carrying *transcribed_text*."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Reply to request *uuid* with a failure whose ``msg`` is *error*."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @classmethod
    def _preview(cls, message) -> str:
        """Return ``str(message)`` truncated to ``_LOG_PREVIEW_CHARS`` for logging."""
        text = f'{message}'
        if len(text) <= cls._LOG_PREVIEW_CHARS:
            return text
        return f'{text[:cls._LOG_PREVIEW_CHARS]}... [truncated, {len(text)} chars total]'

    @staticmethod
    def _request_uuid(body) -> str:
        """Best-effort request uuid for failure replies; '' when unavailable.

        ``TranscribeOutputMessage`` requires a ``str`` uuid. Falling back to
        '' for non-dict bodies or missing/non-string ids keeps the error
        path itself from raising.
        """
        if isinstance(body, dict):
            uuid = body.get('uuid')
            if isinstance(uuid, str):
                return uuid
        return ''

    def consume(self, channel, method, properties, message):
        """Handle one raw request: decode, validate, transcribe, then reply.

        Args:
            channel, method, properties: Delivery metadata from the receiver
                base class; not used here.
            message: Raw request payload.

        Decoding, validation and transcription failures are logged and
        answered through ``send_fail_message`` instead of being raised.
        """
        logger.info('Received a message: %s', self._preview(message))
        try:
            body = self.decode_message(message)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError: presumably decode_message decodes raw bytes
            # before parsing JSON -- TODO confirm against BasicMessageReceiver.
            logger.exception('Message decode failed: \n message:\n %s\n\n exception info:\n %s',
                             self._preview(message), exc)
            self.send_fail_message('', f'Message decode failed, message: \n {message}')
            return

        # Resolved before the try so the handlers below cannot themselves
        # fail on a body that is not a dict or lacks a string 'uuid'.
        request_uuid = self._request_uuid(body)

        try:
            validated_message = TranscribeInputMessage.model_validate(body)

            audio_file = validated_message.audio_file
            if validated_message.using_file_content:
                # Inline request: audio_file holds base64-encoded audio.
                audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))

            logger.info('Start transcribe input...')
            segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)

            # Iterate inside the try: faster-whisper yields segments lazily,
            # so decoding/inference errors may surface during iteration.
            # NOTE(review): ', ' inserts commas between segments that may
            # already end in punctuation; kept for output compatibility.
            transcribed_text = ', '.join(segment.text for segment in segments)
            logger.info('Transcribed text: %s', transcribed_text)
        except ValidationError as exc:
            logger.exception('Message validated failed: \n message:\n %s\n\n exception info:\n %s',
                             self._preview(message), exc)
            self.send_fail_message(request_uuid, f'{exc}')
        except Exception as exc:
            # Top-level boundary: any failure is reported back to the requester.
            logger.exception('Consume message failed: \n message:\n %s\n\n exception info:\n %s',
                             self._preview(message), exc)
            self.send_fail_message(request_uuid, f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)
|
|