asr / consumer /asr.py
maolin.liu
[bugfix]Create producer queue.
e5e423e
raw
history blame
5.4 kB
import base64
import io
import json
import logging
import os
from pathlib import Path
from typing import Literal, Union
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator
from .base import BasicMessageReceiver, BasicMessageSender
def setup_logger():
    """Configure logging at INFO level and return this module's logger.

    Returns:
        The ``logging.Logger`` registered under this module's ``__name__``.
    """
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)


# Shared module-level logger used by everything below.
logger = setup_logger()
class TranscribeInputMessage(BaseModel):
    """Schema of a transcription request.

    ``audio_file`` carries either a filesystem path or, when
    ``using_file_content`` is true, base64 text that the consumer decodes
    into an in-memory audio buffer.
    """

    uuid: str = Field(title='Request Unique Id.')
    audio_file: str
    language: Literal['en', 'zh']
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure a path-style ``audio_file`` points at an existing regular file.

        Returns:
            The validated model instance.

        Raises:
            FileNotFoundError: ``using_file_content`` is false and
                ``audio_file`` is not an existing file. This is not a
                ``ValueError``, so it propagates as-is instead of being
                wrapped in a pydantic ``ValidationError``.
        """
        if self.using_file_content:
            # Inline content: there is no path to check.
            return self
        # is_file() rather than exists(): a directory is not a usable audio
        # input, and Path('') resolves to '.', which would pass exists().
        if not Path(self.audio_file).is_file():
            raise FileNotFoundError('Audio file not exists.')
        return self
class TranscribeOutputMessage(BaseModel):
    """Schema of the result message published after handling a request."""

    # Identifier of the request this result belongs to.
    uuid: str
    # True when transcription finished, False when it failed.
    if_success: bool
    # Status text on success, error description on failure.
    msg: str
    # Transcription result; stays empty unless a text is supplied.
    transcribed_text: str = ''
class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests, run Whisper and publish the outcome.

    Requests are read from the ``transcribe-input`` queue and results are
    published with the ``transcribe-output`` routing key; both queues are
    bound to the ``transcribe`` exchange.
    """

    def __init__(self):
        """Declare the exchange and queues, then load the Whisper model."""
        super().__init__()

        self.exchange_name = 'transcribe'

        self.input_queue_name = 'transcribe-input'
        self.input_routing_key = 'transcribe-input'

        self.output_queue_name = 'transcribe-output'
        self.output_routing_key = 'transcribe-output'

        self.setup_consume_parameters()
        self.setup_producer_parameters()

        logger.info('Loading model...')
        # The hyphenated name is kept for existing deployments. The
        # underscore spelling is accepted too, because POSIX shells cannot
        # export variable names that contain '-'.
        model_size = (
            os.environ.get('WHISPER-MODEL-SIZE')
            or os.environ.get('WHISPER_MODEL_SIZE')
            or 'large-v3'
        )
        # Run on GPU with FP16
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logger.info('Load model finished.')

    def _declare_bound_queue(self, role: str, queue_name: str, routing_key: str):
        """Declare the exchange and ``queue_name``, then bind them with ``routing_key``.

        Args:
            role: Word used in the log line ('consumer' or 'producer').
            queue_name: Queue to declare.
            routing_key: Routing key used for the binding.
        """
        logger.info(
            f'Create {role} exchange: {self.exchange_name}, '
            f'routing-key: {routing_key}, '
            f'queue: {queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        # NOTE(review): the meaning of max_priority=-1 is defined by the
        # base-class helper, which is not visible here. Confirm before changing.
        self.declare_queue(queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, queue_name, routing_key)

    def setup_consume_parameters(self):
        """Declare and bind the queue that requests are consumed from."""
        self._declare_bound_queue('consumer', self.input_queue_name, self.input_routing_key)

    def setup_producer_parameters(self):
        """Declare and bind the queue that results are published to."""
        self._declare_bound_queue('producer', self.output_queue_name, self.output_routing_key)

    def send_message(self, message: Union[dict, str]):
        """Publish ``message`` to the exchange using the output routing key.

        Args:
            message: Body to publish, either a dict or an already
                serialized string.
        """
        # Use the configured key instead of repeating the literal, so the
        # publisher always agrees with the binding made in
        # setup_producer_parameters().
        routing_key = self.output_routing_key

        sender = BasicMessageSender()
        sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None
        )
        logger.info(f'{"-" * 80}')
        logger.info(f"Send message to Exchange: {self.exchange_name}, Routing-key: {routing_key}, \n"
                    f"Messgae body: {message}")
        logger.info(f'{"-" * 80}')

    def send_success_message(self, uuid: str, transcribed_text):
        """Publish a success result carrying ``transcribed_text`` for request ``uuid``."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure result with ``error`` as its message for request ``uuid``."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @staticmethod
    def _extract_uuid(body) -> str:
        """Return the request uuid from a decoded body, or '' if unavailable.

        The failure reply requires a string uuid. The decoded body may be
        something other than a dict, or may lack ``uuid`` (that can be the
        very reason validation failed), so the lookup must never raise.
        """
        if isinstance(body, dict):
            value = body.get('uuid')
            if value is not None:
                return str(value)
        return ''

    def consume(self, channel, method, properties, message):
        """Handle one incoming request and always publish a success or failure reply.

        Args:
            channel: Unused by this method.
            method: Unused by this method.
            properties: Unused by this method.
            message: Raw message, decoded via ``self.decode_message``.
        """
        logger.info(f'Recevied a message: {message}')
        try:
            body = self.decode_message(message)
        except json.JSONDecodeError as exc:
            logger.exception('Message decode failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            # No uuid can be recovered from an undecodable message.
            self.send_fail_message('', f'Message decode failed, message: \n {message}')
            return

        # Resolved up front so the error handlers below cannot fail themselves.
        request_uuid = self._extract_uuid(body)

        try:
            validated_message = TranscribeInputMessage.model_validate(body)

            audio_file = validated_message.audio_file
            if validated_message.using_file_content:
                # Inline payload: decode the base64 text into a file-like buffer.
                audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))

            logger.info('Start transcribe input...')
            segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)
            # The segments are consumed inside the try block so that errors
            # raised while iterating them are reported as a failure too.
            transcribed_text = ', '.join(segment.text for segment in segments)
            logger.info(f'Transcribed text: {transcribed_text}')
        except ValidationError as exc:
            logger.exception('Message validated failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(request_uuid, f'{exc}')
        except Exception as exc:
            # Top-level boundary of the handler: report any failure to the
            # requester instead of letting it escape the callback.
            logger.exception('Consume message failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(request_uuid, f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)