import base64
import io
import json
import logging
import os
from pathlib import Path
from typing import Literal, Union

from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator

from .base import BasicMessageReceiver, BasicMessageSender


def setup_logger():
    """Return this module's logger after configuring root logging at INFO."""
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)
    return logger


logger = setup_logger()


class TranscribeInputMessage(BaseModel):
    """Schema of a transcription request taken from the input queue.

    ``audio_file`` is either a filesystem path or, when
    ``using_file_content`` is true, the base64-encoded audio bytes
    (``consume`` base64-decodes it in that case).
    """

    uuid: str = Field(title='Request Unique Id.')
    audio_file: str
    language: Literal['en', 'zh',]
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure the referenced audio file exists when a path is supplied.

        Raises:
            FileNotFoundError: ``using_file_content`` is false and
                ``audio_file`` does not point to an existing path.
        """
        if self.using_file_content:
            # Inline content: there is no path to check.
            return self
        if not Path(self.audio_file).exists():
            # NOTE: pydantic only wraps ValueError/AssertionError into
            # ValidationError, so this propagates as FileNotFoundError.
            # The exception type is kept as-is for backward compatibility.
            raise FileNotFoundError('Audio file not exists.')
        return self


class TranscribeOutputMessage(BaseModel):
    """Schema of the reply published for every processed request."""

    uuid: str
    if_success: bool
    msg: str
    transcribed_text: str = Field(default='')


class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests, run Whisper and publish the result.

    Requests are read from the ``transcribe-input`` queue and replies are
    published with the ``transcribe-output`` routing key, both on the
    ``transcribe`` exchange.
    """

    def __init__(self):
        """Declare the queues/bindings and load the Whisper model on GPU."""
        super().__init__()
        self.exchange_name = 'transcribe'
        self.input_queue_name = 'transcribe-input'
        self.input_routing_key = 'transcribe-input'
        self.output_queue_name = 'transcribe-output'
        self.output_routing_key = 'transcribe-output'
        self.setup_consume_parameters()
        self.setup_producer_parameters()

        logger.info('Loading model...')
        # The historical variable name contains hyphens, which POSIX shells
        # cannot export; it still takes precedence, with an underscore
        # spelling accepted as a fallback.
        model_size = os.environ.get('WHISPER-MODEL-SIZE')
        if model_size is None:
            model_size = os.environ.get('WHISPER_MODEL_SIZE', 'large-v3')
        # Run on GPU with FP16
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logger.info('Load model finished.')

    def setup_consume_parameters(self):
        """Declare the exchange and input queue, then bind them together."""
        logger.info(
            f'Create consumer exchange: {self.exchange_name}, '
            f'routing-key: {self.input_routing_key}, '
            f'queue: {self.input_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        # NOTE(review): max_priority=-1 presumably disables queue priority
        # in the base class -- confirm against BasicMessageReceiver.
        self.declare_queue(self.input_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.input_queue_name, self.input_routing_key)

    def setup_producer_parameters(self):
        """Declare the exchange and output queue, then bind them together."""
        logger.info(
            f'Create producer exchange: {self.exchange_name}, '
            f'routing-key: {self.output_routing_key}, '
            f'queue: {self.output_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        # NOTE(review): see setup_consume_parameters about max_priority=-1.
        self.declare_queue(self.output_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.output_queue_name, self.output_routing_key)

    def send_message(self, message: Union[dict, str]):
        """Publish ``message`` to the output routing key and log it.

        Args:
            message: Body handed to ``BasicMessageSender.send_message``.
        """
        # Use the configured attribute rather than repeating the literal so
        # the binding and the publisher cannot drift apart.
        routing_key = self.output_routing_key
        # No custom headers (job id / priority) are attached to replies.
        # NOTE(review): a new BasicMessageSender is created per message;
        # confirm it does not hold a connection that needs closing.
        sender = BasicMessageSender()
        sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None
        )
        separator = '-' * 80
        logger.info(separator)
        logger.info(f"Send message to Exchange: {self.exchange_name}, Routing-key: {routing_key}, \n"
                    f"Messgae body: {message}")
        logger.info(separator)

    def send_success_message(self, uuid: str, transcribed_text):
        """Publish a successful reply carrying ``transcribed_text``."""
        message = TranscribeOutputMessage(uuid=uuid,
                                          if_success=True,
                                          msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure reply carrying the ``error`` description.

        ``uuid`` is coerced to a string (``None`` becomes ``''``) so that a
        request whose id was missing or malformed still gets a reply instead
        of failing validation of the reply model itself.
        """
        safe_uuid = '' if uuid is None else str(uuid)
        message = TranscribeOutputMessage(uuid=safe_uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @staticmethod
    def _extract_uuid(body) -> str:
        """Return the request id from a decoded body, or ``''`` if absent."""
        if isinstance(body, dict):
            value = body.get('uuid')
            if value is not None:
                return str(value)
        return ''

    def consume(self, channel, method, properties, message):
        """Handle one incoming message: decode, validate, transcribe, reply.

        A reply is published in every handled outcome: a failure message
        when decoding, validation or transcription fails, and a success
        message otherwise.

        Args:
            channel, method, properties: Unused here.
                NOTE(review): presumably supplied by the broker client's
                consume callback -- confirm against BasicMessageReceiver.
            message: Raw message payload, passed to ``self.decode_message``.
        """
        logger.info(f'Recevied a message: {message}')
        try:
            body = self.decode_message(message)
        except json.JSONDecodeError as exc:
            logger.exception('Message decode failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message('', f'Message decode failed, message: \n {message}')
            return

        try:
            validated_message = TranscribeInputMessage.model_validate(body)
            audio_file = validated_message.audio_file
            if validated_message.using_file_content:
                # Inline audio arrives base64-encoded; wrap the decoded
                # bytes in a file-like object for the model.
                audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))

            logger.info('Start transcribe input...')
            segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)
            # Kept inside the try block so errors raised while iterating
            # the segments are reported as a failure reply too.
            transcribed_text = ', '.join([segment.text for segment in segments])
            logger.info(f'Transcribed text: {transcribed_text}')
        except ValidationError as exc:
            logger.exception('Message validated failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            # body may lack a usable uuid (that can be the very reason
            # validation failed), so extract it defensively.
            self.send_fail_message(self._extract_uuid(body), f'{exc}')
        except Exception as exc:
            logger.exception('Consume message failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(self._extract_uuid(body), f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)