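"""Speech service utilities: Whisper speech-to-text and Edge TTS text-to-speech."""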
import os
import tempfile
import logging
from typing import Optional

import torch
import librosa
import edge_tts
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from config.config import VOICE, FALLBACK_VOICES

logger = logging.getLogger(__name__)

# Whisper model for speech to text
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny",
    local_files_only=False,
)
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    local_files_only=False,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
).to("cpu")

# Voice selection handling
async def get_valid_voice() -> str:
    """Return the configured voice if available, otherwise the first working fallback."""
    available_voices = await edge_tts.list_voices()
    voice_names = [VOICE] + FALLBACK_VOICES
    available_voice_names = {v["ShortName"] for v in available_voices}
    for voice in voice_names:
        if voice in available_voice_names:
            return voice
    raise RuntimeError("No valid voice found")

# Text-to-speech conversion using Edge TTS
async def generate_speech(text: str) -> Optional[str]:
    """Synthesize `text` to an MP3 file with Edge TTS and return the file path."""
    if not text or not isinstance(text, str):
        raise ValueError("Invalid text input")

    voice = await get_valid_voice()
    logger.info(f"Using voice: {voice}")

    # Reserve a temporary .mp3 path; Edge TTS writes the audio into it below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(tmp_path)

    if not os.path.exists(tmp_path) or os.path.getsize(tmp_path) == 0:
        raise RuntimeError("Speech file empty or not created")

    logger.info(f"Speech generated successfully: {tmp_path}")
    return tmp_path

# Speech-to-text using Whisper
async def transcribe(audio_file: str) -> str:
    """Transcribe up to 30 seconds of `audio_file` to English text with Whisper."""
    # Whisper expects 16 kHz mono audio.
    audio, sr = librosa.load(
        audio_file,
        sr=16000,
        mono=True,
        duration=30,
    )

    inputs = processor(
        audio,
        sampling_rate=sr,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            input_features=inputs.input_features,
            attention_mask=inputs.attention_mask,
            language="en",
            task="transcribe",
            max_length=448,
            temperature=0.0,
        )

    transcription = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0].strip()

    logger.info(f"Transcribed text: {transcription}")
    return transcription
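

# Minimal usage sketch (an assumed local entry point, not part of the original
# module): synthesize a sample sentence with Edge TTS, then round-trip it back
# through Whisper transcription. Decoding the generated MP3 relies on librosa's
# audio backend (soundfile/audioread with ffmpeg).
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        audio_path = await generate_speech("Hello from the speech service.")
        print(f"Generated audio at: {audio_path}")
        text = await transcribe(audio_path)
        print(f"Round-trip transcription: {text}")

    asyncio.run(_demo())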