SpeechT5_hy / src /preprocessing.py
Edmon02's picture
Implement optimized TTS pipeline with advanced text preprocessing, audio processing, and comprehensive error handling
b163aa7
"""
Text Preprocessing Module
========================
Handles text normalization, translation, chunking, and optimization for TTS processing.
Implements caching and batch processing for improved performance.
"""
import re
import string
import logging
import asyncio
from typing import List, Tuple, Dict, Optional
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
import time
import inflect
import requests
from requests.exceptions import Timeout, RequestException
# Logger used by every method of this module.
logger = logging.getLogger(__name__)
# Root-logger configuration at import time: INFO level and above.
logging.basicConfig(level=logging.INFO)
class TextProcessor:
    """High-performance text processor with caching and optimization.

    ``process_text`` normalizes raw text (numbers are spelled out as Armenian
    words), ``chunk_text`` splits text into TTS-sized pieces, and
    ``process_chunks`` runs both steps in sequence.
    """

    # Sentence boundary: whitespace that follows ".", "!", "?" or the Armenian
    # full stop "։" (U+0589). The punctuation stays attached to its sentence,
    # matching what short texts (returned unsplit) already keep.
    # The Armenian question mark "՞" and exclamation mark "՜" are deliberately
    # absent: they are written over the stressed vowel inside a word
    # (e.g. "Ինչո՞ւ"), so splitting on them would cut words in half.
    _SENTENCE_BREAK = re.compile(r'(?<=[.!?։])\s+')

    # Clause boundary inside an over-long sentence: whitespace after "," or
    # ";", or whitespace before the conjunctions "և" (and), "կամ" (or) and
    # "բայց" (but). Only whitespace is consumed, so no words are dropped.
    _CLAUSE_BREAK = re.compile(r'(?<=[,;])\s+|\s+(?=(?:և|կամ|բայց)\s)')

    # A digit run, optionally with internal "." / "," separators ("1,000").
    _NUMBER = re.compile(r'\d(?:[\d.,]*\d)?')

    def __init__(self, max_chunk_length: int = 200, overlap_words: int = 5,
                 translation_timeout: int = 10,
                 translation_cache_size: int = 1000):
        """
        Initialize the text processor.

        Args:
            max_chunk_length: Maximum characters per chunk before overlap is
                added (a single word longer than this is never split)
            overlap_words: Number of trailing words of each chunk repeated at
                the start of the next chunk; 0 or less disables overlap
            translation_timeout: Timeout for translation requests in seconds
            translation_cache_size: Maximum number of successful translations
                kept in ``translation_cache``; the least recently used entry
                is evicted first
        """
        self.max_chunk_length = max_chunk_length
        self.overlap_words = overlap_words
        self.translation_timeout = translation_timeout
        self.translation_cache_size = translation_cache_size
        self.inflect_engine = inflect.engine()
        # Successful translations only, ordered least- to most-recently used.
        self.translation_cache: Dict[str, str] = {}
        # str(number) -> Armenian words.
        self.number_cache: Dict[str, str] = {}
        # Translation-cache hit/miss counters reported by get_cache_stats().
        self._translation_hits = 0
        self._translation_misses = 0
        # Not used by the methods defined in this class; kept because it is a
        # public attribute that callers may rely on.
        self.executor = ThreadPoolExecutor(max_workers=4)

    def _cached_translate(self, text: str) -> str:
        """
        Translate text to Armenian, reusing earlier successful results.

        The cache is the per-instance ``translation_cache`` dict. A
        method-level ``lru_cache`` would key on ``self`` and keep every
        instance alive. Failed requests are not cached, so a transient
        network error does not pin the untranslated text.

        Args:
            text: Text to translate

        Returns:
            Translated text in Armenian, or ``text`` unchanged when it is
            blank or the translation request fails
        """
        if not text.strip():
            return text
        cache = self.translation_cache
        if text in cache:
            self._translation_hits += 1
            # Re-insert so the entry becomes the most recently used.
            translation = cache.pop(text)
            cache[text] = translation
            return translation
        self._translation_misses += 1
        translation = self._request_translation(text)
        if translation is None:
            return text
        cache[text] = translation
        # Dicts keep insertion order, so the first key is the least recently used.
        while cache and len(cache) > self.translation_cache_size:
            del cache[next(iter(cache))]
        return translation

    def _request_translation(self, text: str) -> Optional[str]:
        """
        Send one translation request (target language "hy") to Google Translate.

        Args:
            text: Non-blank text to translate

        Returns:
            The translated text, or None if the request fails or the response
            does not have the expected shape
        """
        try:
            response = requests.get(
                "https://translate.googleapis.com/translate_a/single",
                params={
                    'client': 'gtx',
                    'sl': 'auto',
                    'tl': 'hy',
                    'dt': 't',
                    'q': text,
                },
                timeout=self.translation_timeout,
            )
            response.raise_for_status()
            # Assumed shape of this unofficial endpoint: data[0] is a list of
            # [translated_text, source_text, ...] segments, one per sentence
            # of the input. Join all of them rather than keeping only the first.
            segments = response.json()[0]
            translation = ''.join(
                segment[0] for segment in segments
                if segment and isinstance(segment[0], str)
            )
        except (RequestException, ValueError, IndexError, TypeError) as e:
            # ValueError: body is not JSON; IndexError/TypeError: unexpected shape.
            logger.warning("Translation failed for '%s': %s", text, e)
            return None
        if not translation:
            logger.warning("Translation failed for '%s': empty response", text)
            return None
        logger.debug("Translated '%s' to '%s'", text, translation)
        return translation

    def _convert_number_to_armenian_words(self, number: int) -> str:
        """
        Convert an integer to Armenian words, with caching.

        The number is spelled out in English by inflect and then translated.

        Args:
            number: Integer to convert

        Returns:
            Number as Armenian words. The English words are returned if
            translation fails, and ``str(number)`` if conversion raises.
        """
        cache_key = str(number)
        if cache_key in self.number_cache:
            return self.number_cache[cache_key]
        try:
            english_words = self.inflect_engine.number_to_words(number)
            armenian_words = self._cached_translate(english_words)
        except Exception as e:
            # Best effort: any failure falls back to the digits themselves.
            logger.warning("Number conversion failed for %s: %s", number, e)
            return str(number)
        # _cached_translate hands back its input unchanged when translation
        # fails; caching that English fallback would make it permanent.
        if armenian_words != english_words:
            self.number_cache[cache_key] = armenian_words
        return armenian_words

    def _normalize_text(self, text: str) -> str:
        """
        Collapse whitespace and spell out numbers as Armenian words.

        Each digit run is replaced in place, so characters around it survive.
        For example "5." keeps its sentence-ending period for chunking, and
        "10%" keeps its percent sign.

        Args:
            text: Input text to normalize

        Returns:
            Normalized text with single spaces between words
        """
        if not text:
            return ""
        # str() tolerates non-string input; split/join collapses all whitespace.
        collapsed = ' '.join(str(text).split())

        def spell_out(match) -> str:
            # Separators are dropped, so "1,000" -> 1000. A decimal such as
            # "3.14" is therefore also read as a single integer (314).
            digits = ''.join(ch for ch in match.group() if ch.isdigit())
            return self._convert_number_to_armenian_words(int(digits))

        return self._NUMBER.sub(spell_out, collapsed)

    def _split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences, keeping each sentence's final punctuation.

        Args:
            text: Text to split

        Returns:
            List of non-empty, stripped sentences
        """
        return [part.strip() for part in self._SENTENCE_BREAK.split(text)
                if part.strip()]

    def chunk_text(self, text: str) -> List[str]:
        """
        Chunk text for TTS processing.

        Sentences are packed greedily into chunks of at most
        ``max_chunk_length`` characters. A sentence longer than that is split
        at commas/semicolons or before "և"/"կամ"/"բայց", and then between
        words if a clause is still too long. Finally, when there are several
        chunks, ``overlap_words`` words of each chunk are repeated at the start
        of the next one. That overlap can push a chunk past
        ``max_chunk_length``.

        Args:
            text: Input text to chunk

        Returns:
            List of text chunks (empty for empty input)
        """
        if not text or len(text) <= self.max_chunk_length:
            return [text] if text else []
        sentences = self._split_into_sentences(text)
        if not sentences:
            return [text]
        pieces: List[str] = []
        for sentence in sentences:
            if len(sentence) <= self.max_chunk_length:
                pieces.append(sentence)
            else:
                pieces.extend(self._split_long_sentence(sentence))
        chunks = self._pack_pieces(pieces)
        if len(chunks) > 1:
            chunks = self._add_overlap(chunks)
        logger.info("Split text into %d chunks", len(chunks))
        return chunks

    def _split_long_sentence(self, sentence: str) -> List[str]:
        """Break an over-long sentence into clauses, or words as a last resort."""
        pieces: List[str] = []
        for clause in self._CLAUSE_BREAK.split(sentence):
            clause = clause.strip()
            if not clause:
                continue
            if len(clause) <= self.max_chunk_length:
                pieces.append(clause)
            else:
                # No clause boundary helps: pack the clause word by word.
                pieces.extend(self._pack_pieces(clause.split()))
        return pieces

    def _pack_pieces(self, pieces: List[str]) -> List[str]:
        """Greedily space-join pieces into strings of at most max_chunk_length.

        A single piece longer than the limit is emitted on its own.
        """
        packed: List[str] = []
        current = ""
        for piece in pieces:
            candidate = f"{current} {piece}" if current else piece
            if len(candidate) <= self.max_chunk_length:
                current = candidate
            else:
                if current:
                    packed.append(current)
                current = piece
        if current:
            packed.append(current)
        return packed

    def _add_overlap(self, chunks: List[str]) -> List[str]:
        """
        Prefix each chunk with the last ``overlap_words`` words of its predecessor.

        Args:
            chunks: List of text chunks

        Returns:
            Chunks with added overlap. ``chunks`` is returned unchanged when it
            has at most one element or ``overlap_words`` is 0 or less (a
            ``[-0:]`` slice would otherwise copy the whole previous chunk).
        """
        if len(chunks) <= 1 or self.overlap_words <= 0:
            return chunks
        overlapped_chunks = [chunks[0]]
        for previous, current in zip(chunks, chunks[1:]):
            tail = " ".join(previous.split()[-self.overlap_words:])
            overlapped_chunks.append(f"{tail} {current}".strip())
        return overlapped_chunks

    def process_text(self, text: str) -> str:
        """
        Main text processing pipeline.

        Args:
            text: Raw input text

        Returns:
            Normalized text ready for TTS, "" for blank input, or the original
            text if normalization raises
        """
        start_time = time.perf_counter()
        if not text or not text.strip():
            return ""
        try:
            processed_text = self._normalize_text(text)
        except Exception as e:
            # Top-level fallback: log with traceback and keep the raw text.
            logger.exception("Text processing failed: %s", e)
            return str(text)
        logger.info("Text processed in %.3fs", time.perf_counter() - start_time)
        return processed_text

    def process_chunks(self, text: str) -> List[str]:
        """
        Normalize text and return chunks for TTS.

        Args:
            text: Input text

        Returns:
            List of processed text chunks
        """
        return self.chunk_text(self.process_text(text))

    def clear_cache(self):
        """Clear the translation and number caches and reset hit/miss counters."""
        self.translation_cache.clear()
        self.number_cache.clear()
        self._translation_hits = 0
        self._translation_misses = 0
        logger.info("Caches cleared")

    def get_cache_stats(self) -> Dict[str, int]:
        """Return cache sizes and translation-cache hit/miss counts."""
        return {
            "translation_cache_size": len(self.translation_cache),
            "number_cache_size": len(self.number_cache),
            "lru_cache_hits": self._translation_hits,
            "lru_cache_misses": self._translation_misses,
        }