asr / consumer /asr.py
maolin.liu
[bugfix]Do not serialize string.
7ec99bc
raw
history blame
3.66 kB
import base64
import io
import logging
import os
from pathlib import Path
from typing import Literal, Union
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator
from .base import BasicMessageReceiver, BasicMessageSender
class TranscribeInputMessage(BaseModel):
    # Schema of one transcription request taken from the input queue.
    # (Documented with comments rather than a class docstring so the
    # generated JSON schema of the model is left untouched.)

    # Identifier chosen by the requester; echoed back in the reply.
    uuid: str = Field(title='Request Unique Id.')
    # Path of the audio file, or the base64-encoded audio itself when
    # ``using_file_content`` is true.
    audio_file: str
    # Language passed to the ASR model.
    language: Literal['en', 'zh']
    # True when ``audio_file`` carries the audio content instead of a path.
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        # Only a path-style request can be checked against the filesystem;
        # inline content is accepted as-is.
        is_path_request = not self.using_file_content
        if is_path_request and not Path(self.audio_file).exists():
            raise FileNotFoundError('Audio file not exists.')
        return self
class TranscribeOutputMessage(BaseModel):
    # Schema of the reply published for every transcription request.
    # (Documented with comments rather than a class docstring so the
    # generated JSON schema of the model is left untouched.)

    # Identifier copied from the request this reply answers.
    uuid: str
    # True when transcription finished, False when it failed.
    if_success: bool
    # Human-readable status or error description.
    msg: str
    # Transcription result; stays empty on failure.
    transcribed_text: str = ''
class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests and publish their results.

    Requests arrive on the ``transcribe-input`` queue, which is bound to
    the ``transcribe`` exchange. Each handled request is answered with a
    ``TranscribeOutputMessage`` published to the same exchange under the
    ``transcribe-output`` routing key.
    """

    def __init__(self):
        """Declare the queue topology, create the sender and load the model.

        The Whisper model size is taken from the ``WHISPER-MODEL-SIZE``
        environment variable. Because a name containing hyphens cannot be
        exported from a POSIX shell, ``WHISPER_MODEL_SIZE`` is accepted as
        a fallback. When neither is set, ``large-v3`` is used.
        """
        super().__init__()
        self.exchange_name = 'transcribe'
        self.queue_name = 'transcribe-input'
        self.routing_key = 'transcribe-input'
        self.setup_consume_parameters()
        self.setup_message_sender()

        # The historical (hyphenated) name keeps precedence so existing
        # deployments behave exactly as before.
        model_size = os.environ.get(
            'WHISPER-MODEL-SIZE',
            os.environ.get('WHISPER_MODEL_SIZE', 'large-v3'),
        )
        # Run on GPU with FP16.
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')

    def setup_consume_parameters(self):
        """Declare the exchange and the input queue, then bind them."""
        self.declare_exchange(self.exchange_name)
        # NOTE(review): the meaning of max_priority=-1 is defined by
        # BasicMessageReceiver.declare_queue, which is not visible here;
        # presumably it disables queue priorities - confirm.
        self.declare_queue(self.queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.queue_name, self.routing_key)

    def setup_message_sender(self):
        """Create the sender used to publish transcription replies."""
        self.sender = BasicMessageSender()

    def send_message(self, message: Union[dict, str]):
        """Publish ``message`` under the ``transcribe-output`` routing key.

        Args:
            message: Reply payload; passed to the sender unchanged and
                without custom headers.
        """
        self.sender.send_message(
            exchange_name=self.exchange_name,
            routing_key='transcribe-output',
            body=message,
            headers=None,
        )

    def send_success_message(self, uuid: str, transcribed_text: str):
        """Publish a successful reply carrying ``transcribed_text``.

        Args:
            uuid: Identifier of the request being answered.
            transcribed_text: Text produced by the ASR model.
        """
        reply = TranscribeOutputMessage(
            uuid=uuid,
            if_success=True,
            msg='Transcribe finished.',
            transcribed_text=transcribed_text,
        )
        self.send_message(reply.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure reply whose ``msg`` is ``error``.

        Args:
            uuid: Identifier of the request being answered.
            error: Description of what went wrong.
        """
        reply = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(reply.model_dump_json())

    @staticmethod
    def _extract_request_uuid(body) -> str:
        """Return the request uuid found in ``body`` as a string.

        Falls back to an empty string when ``body`` is not a mapping or has
        no uuid, so that building the failure reply can never itself fail
        validation on the ``uuid`` field.
        """
        if not isinstance(body, dict):
            return ''
        value = body.get('uuid')
        return '' if value is None else str(value)

    def consume(self, channel, method, properties, message):
        """Transcribe one request and publish a success or failure reply.

        Args:
            channel: Not used by this method.
            method: Not used by this method.
            properties: Not used by this method.
            message: Raw request; decoded with ``self.decode_message`` and
                validated against ``TranscribeInputMessage``.

        Any exception raised while validating or transcribing is logged and
        reported to the requester through a failure reply instead of being
        propagated.
        """
        body = self.decode_message(message)
        try:
            request = TranscribeInputMessage.model_validate(body)
            audio = request.audio_file
            if request.using_file_content:
                # The payload carries the audio itself, base64-encoded.
                audio = io.BytesIO(base64.b64decode(request.audio_file))
            segments, _ = self.asr_model.transcribe(audio, language=request.language)
            # Consume every segment: keeping only the first one truncates
            # any audio that is split into several segments.
            transcribed_text = ''.join(segment.text for segment in segments)
        except Exception as exc:
            # Covers pydantic's ValidationError as well as decoding and
            # transcription errors; all are reported the same way.
            logging.exception('Consume message failed: \n message: %s\n\n exception info: %s', message, exc)
            self.send_fail_message(self._extract_request_uuid(body), f'{exc}')
        else:
            self.send_success_message(request.uuid, transcribed_text)