File size: 5,397 Bytes
85378a6 a829e96 85378a6 7ec99bc 85378a6 475a8bc 85378a6 d057131 85378a6 f62fb80 85378a6 e5e423e 85378a6 d057131 85378a6 d057131 85378a6 f62fb80 85378a6 f62fb80 85378a6 7ec99bc 85378a6 475a8bc fa21f8b 85378a6 d057131 85378a6 7ec99bc 85378a6 7ec99bc 85378a6 d057131 a829e96 85378a6 c08591d 85378a6 f62fb80 85378a6 f62fb80 85378a6 a829e96 85378a6 a829e96 85378a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import base64
import io
import json
import logging
import os
from pathlib import Path
from typing import Literal, Union
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator
from .base import BasicMessageReceiver, BasicMessageSender
def setup_logger():
    """Configure root logging at INFO level and return this module's logger."""
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)


# Module-wide logger shared by every definition in this file.
logger = setup_logger()
class TranscribeInputMessage(BaseModel):
    """Payload of a transcription request.

    ``audio_file`` is a file-system path unless ``using_file_content`` is
    true, in which case it carries the audio content itself and is not
    checked as a path.
    """

    # Request identifier supplied by the sender.
    uuid: str = Field(title='Request Unique Id.')
    # Path to the audio file, or the inline audio content (see class docstring).
    audio_file: str
    # Language code of the audio; only these two values are accepted.
    language: Literal['en', 'zh']
    # True when ``audio_file`` holds content rather than a path.
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Verify that a path-style ``audio_file`` exists.

        Returns:
            The validated model instance.

        Raises:
            FileNotFoundError: ``using_file_content`` is false and the path
                in ``audio_file`` does not exist.
        """
        # Short-circuit keeps the file-system lookup out of the inline-content case.
        path_is_missing = not self.using_file_content and not Path(self.audio_file).exists()
        if path_is_missing:
            raise FileNotFoundError('Audio file not exists.')
        return self
class TranscribeOutputMessage(BaseModel):
    """Payload of a transcription reply."""

    # Identifier of the request this reply belongs to.
    uuid: str
    # Whether the transcription succeeded.
    if_success: bool
    # Human-readable status or error description.
    msg: str
    # Transcription result; stays empty unless a text is supplied.
    transcribed_text: str = ''
class TranscribeConsumer(BasicMessageReceiver):
    """Message receiver that transcribes audio with a faster-whisper model.

    On construction it declares the ``transcribe`` exchange, binds the
    ``transcribe-input`` and ``transcribe-output`` queues to it and loads the
    Whisper model. ``consume`` handles one request and publishes a
    ``TranscribeOutputMessage`` describing success or failure.
    """

    def __init__(self):
        """Declare the exchange and queues, then load the Whisper model."""
        super().__init__()
        self.exchange_name = 'transcribe'
        self.input_queue_name = 'transcribe-input'
        self.input_routing_key = 'transcribe-input'
        self.output_queue_name = 'transcribe-output'
        self.output_routing_key = 'transcribe-output'
        self.setup_consume_parameters()
        self.setup_producer_parameters()

        logger.info('Loading model...')
        # The hyphenated variable name is kept (and wins) for backward
        # compatibility; because most shells cannot export a name containing
        # '-', the underscore spelling is accepted as a fallback.
        model_size = os.environ.get(
            'WHISPER-MODEL-SIZE',
            os.environ.get('WHISPER_MODEL_SIZE', 'large-v3'),
        )
        # Run on GPU with FP16
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logger.info('Load model finished.')

    def setup_consume_parameters(self):
        """Declare the exchange and the input queue and bind them together."""
        logger.info(
            f'Create consumer exchange: {self.exchange_name}, '
            f'routing-key: {self.input_routing_key}, '
            f'queue: {self.input_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        self.declare_queue(self.input_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.input_queue_name, self.input_routing_key)

    def setup_producer_parameters(self):
        """Declare the exchange and the output queue and bind them together."""
        logger.info(
            f'Create producer exchange: {self.exchange_name}, '
            f'routing-key: {self.output_routing_key}, '
            f'queue: {self.output_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        self.declare_queue(self.output_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.output_queue_name, self.output_routing_key)

    def send_message(self, message: Union[dict, str]):
        """Publish ``message`` on the exchange using the output routing key.

        Args:
            message: Body handed unchanged to ``BasicMessageSender.send_message``.
        """
        # Use the configured key instead of repeating the literal so the
        # binding made in ``setup_producer_parameters`` and the publish agree.
        routing_key = self.output_routing_key
        # NOTE(review): a new sender is created per message, as before; whether
        # that opens a new connection each time depends on BasicMessageSender.
        sender = BasicMessageSender()
        sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None,
        )
        separator = '-' * 80
        logger.info(separator)
        logger.info(f"Send message to Exchange: {self.exchange_name}, Routing-key: {routing_key}, \n"
                    f"Messgae body: {message}")
        logger.info(separator)

    def send_success_message(self, uuid: str, transcribed_text):
        """Publish a successful reply carrying ``transcribed_text``.

        Args:
            uuid: Identifier of the request being answered.
            transcribed_text: Text produced by the transcription.
        """
        message = TranscribeOutputMessage(
            uuid=uuid,
            if_success=True,
            msg='Transcribe finished.',
            transcribed_text=transcribed_text,
        )
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure reply whose ``msg`` is ``error``.

        Args:
            uuid: Identifier of the request being answered (``''`` if unknown).
            error: Description of what went wrong.
        """
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @staticmethod
    def _request_uuid(body) -> str:
        """Return the request id found in ``body``, or ``''`` if there is none.

        ``TranscribeOutputMessage.uuid`` must be a string, so a non-dict body,
        a missing ``uuid`` or a non-string ``uuid`` must not be forwarded as-is
        to ``send_fail_message``; doing so would raise inside the error handler.
        """
        if not isinstance(body, dict):
            return ''
        request_id = body.get('uuid')
        return '' if request_id is None else str(request_id)

    def _transcribe(self, request: TranscribeInputMessage) -> str:
        """Run the model on a validated request and return the joined text."""
        audio = request.audio_file
        if request.using_file_content:
            # Inline content arrives base64-encoded; the model gets a file-like object.
            audio = io.BytesIO(base64.b64decode(request.audio_file))
        logger.info('Start transcribe input...')
        segments, _ = self.asr_model.transcribe(audio, language=request.language)
        # Segment texts are joined with ', ' exactly as before.
        transcribed_text = ', '.join(segment.text for segment in segments)
        logger.info(f'Transcribed text: {transcribed_text}')
        return transcribed_text

    def consume(self, channel, method, properties, message):
        """Handle one incoming message: decode, validate, transcribe, reply.

        A failure at any stage is logged and answered with a failure reply
        instead of being raised; a success is answered with the transcription.

        Args:
            channel: Unused here; part of the callback signature.
            method: Unused here; part of the callback signature.
            properties: Unused here; part of the callback signature.
            message: Raw message passed to ``self.decode_message``.
        """
        logger.info(f'Recevied a message: {message}')
        try:
            body = self.decode_message(message)
        except json.JSONDecodeError as exc:
            logger.exception('Message decode failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message('', f'Message decode failed, message: \n {message}')
            return

        try:
            validated_message = TranscribeInputMessage.model_validate(body)
            transcribed_text = self._transcribe(validated_message)
        except ValidationError as exc:
            logger.exception('Message validated failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(self._request_uuid(body), f'{exc}')
        except Exception as exc:
            # Boundary handler: any other failure (missing file, bad base64,
            # model error, ...) is reported to the requester rather than raised.
            logger.exception('Consume message failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(self._request_uuid(body), f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)
|