""" | |
Text Preprocessing Module | |
======================== | |
Handles text normalization, translation, chunking, and optimization for TTS processing. | |
Implements caching and batch processing for improved performance. | |
""" | |

import re
import logging
import time
from typing import Dict, List
from concurrent.futures import ThreadPoolExecutor

import inflect
import requests
from requests.exceptions import RequestException

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextProcessor:
    """High-performance text processor with caching and optimization."""

    def __init__(self, max_chunk_length: int = 200, overlap_words: int = 5,
                 translation_timeout: int = 10):
        """
        Initialize the text processor.

        Args:
            max_chunk_length: Maximum characters per chunk
            overlap_words: Number of words to overlap between chunks
            translation_timeout: Timeout for translation requests in seconds
        """
        self.max_chunk_length = max_chunk_length
        self.overlap_words = overlap_words
        self.translation_timeout = translation_timeout
        self.inflect_engine = inflect.engine()
        self.translation_cache: Dict[str, str] = {}
        self.number_cache: Dict[str, str] = {}
        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=4)

    def _cached_translate(self, text: str) -> str:
        """
        Cached translation to avoid repeated API calls for the same input.

        Args:
            text: Text to translate

        Returns:
            Translated text in Armenian, or the original text on failure
        """
        if not text.strip():
            return text
        # Serve repeated inputs from the in-memory cache instead of the API
        if text in self.translation_cache:
            return self.translation_cache[text]
        try:
            response = requests.get(
                "https://translate.googleapis.com/translate_a/single",
                params={
                    'client': 'gtx',
                    'sl': 'auto',
                    'tl': 'hy',
                    'dt': 't',
                    'q': text,
                },
                timeout=self.translation_timeout,
            )
            response.raise_for_status()
            translation = response.json()[0][0][0]
            self.translation_cache[text] = translation
            logger.debug(f"Translated '{text}' to '{translation}'")
            return translation
        except (RequestException, ValueError, IndexError, TypeError) as e:
            logger.warning(f"Translation failed for '{text}': {e}")
            return text  # Return original text if translation fails
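
    # The 'gtx' client hits an unofficial endpoint with no documented schema.
    # As an illustration only, responses observed in practice look roughly like
    #   [[["translated text", "source text", ...]], None, "en", ...]
    # which is why the translation above is read from [0][0][0].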

    def _convert_number_to_armenian_words(self, number: int) -> str:
        """
        Convert a number to Armenian words, with caching.

        Args:
            number: Integer to convert

        Returns:
            Number as Armenian words
        """
        cache_key = str(number)
        if cache_key in self.number_cache:
            return self.number_cache[cache_key]
        try:
            # Convert to English words first
            english_words = self.inflect_engine.number_to_words(number)
            # Translate to Armenian
            armenian_words = self._cached_translate(english_words)
            # Cache the result
            self.number_cache[cache_key] = armenian_words
            return armenian_words
        except Exception as e:
            logger.warning(f"Number conversion failed for {number}: {e}")
            return str(number)  # Fall back to the original number
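
    # Illustrative two-step conversion (the English step is deterministic via
    # inflect; the Armenian step depends on the live translation API):
    #   inflect.engine().number_to_words(42)  ->  'forty-two'
    #   _cached_translate('forty-two')        ->  e.g. 'քառասուներկու'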

    def _normalize_text(self, text: str) -> str:
        """
        Normalize text by handling numbers, punctuation, and special characters.

        Args:
            text: Input text to normalize

        Returns:
            Normalized text
        """
        if not text:
            return ""
        # Convert to string and strip
        text = str(text).strip()
        # Process each word
        words = []
        for word in text.split():
            if re.search(r'\d', word):
                # Extract just the digits. Note: this concatenates every digit
                # in the token, so a mixed token like '3.14' becomes 314.
                digits = ''.join(filter(str.isdigit, word))
                if digits:
                    try:
                        number = int(digits)
                        armenian_word = self._convert_number_to_armenian_words(number)
                        words.append(armenian_word)
                    except ValueError:
                        words.append(word)  # Keep original if conversion fails
                else:
                    words.append(word)
            else:
                words.append(word)
        return ' '.join(words)
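
    # Illustrative behaviour (the Armenian rendering depends on the API, and
    # this sample sentence is invented): 'Գինը 25 դրամ է' ("The price is 25
    # dram") would become roughly 'Գինը քսանհինգ դրամ է'.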

    def _split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences using Latin and Armenian delimiters.

        Args:
            text: Text to split

        Returns:
            List of sentences
        """
        # Sentence-final delimiters: Latin . ! ? plus the Armenian full stop ։.
        # The Armenian question mark ՞ and exclamation mark ՜ attach to a word
        # inside the sentence, so splitting on them would cut words apart.
        sentence_endings = r'[.!?։]+'
        sentences = re.split(sentence_endings, text)
        # Clean and filter empty sentences
        sentences = [s.strip() for s in sentences if s.strip()]
        return sentences
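
    # Illustrative split on the Armenian full stop, with ՞ left in place:
    #   _split_into_sentences('Բարև ձեզ։ Ո՞վ է այնտեղ')
    #   ->  ['Բարև ձեզ', 'Ո՞վ է այնտեղ']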

    def chunk_text(self, text: str) -> List[str]:
        """
        Intelligently chunk text for optimal TTS processing.

        This method implements sophisticated chunking that:
        1. Respects sentence boundaries
        2. Maintains semantic coherence
        3. Includes overlap for smooth transitions
        4. Optimizes chunk sizes for the TTS model

        Args:
            text: Input text to chunk

        Returns:
            List of text chunks optimized for TTS
        """
        if not text or len(text) <= self.max_chunk_length:
            return [text] if text else []
        sentences = self._split_into_sentences(text)
        if not sentences:
            return [text]
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            # If a single sentence is too long, split it by clauses
            if len(sentence) > self.max_chunk_length:
                # Split on commas, semicolons, and Armenian conjunctions
                # (և 'and', կամ 'or', բայց 'but')
                clauses = re.split(r'[,;]|\sև\s|\sկամ\s|\sբայց\s', sentence)
                for clause in clauses:
                    clause = clause.strip()
                    if not clause:
                        continue
                    if len(current_chunk + " " + clause) <= self.max_chunk_length:
                        current_chunk = (current_chunk + " " + clause).strip()
                    else:
                        if current_chunk:
                            chunks.append(current_chunk)
                        # A clause longer than max_chunk_length is kept whole
                        current_chunk = clause
            else:
                # Try to add the whole sentence
                test_chunk = (current_chunk + " " + sentence).strip()
                if len(test_chunk) <= self.max_chunk_length:
                    current_chunk = test_chunk
                else:
                    # Current chunk is full, start a new one
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = sentence
        # Add final chunk
        if current_chunk:
            chunks.append(current_chunk)
        # Implement overlap for smooth transitions
        if len(chunks) > 1:
            chunks = self._add_overlap(chunks)
        logger.info(f"Split text into {len(chunks)} chunks")
        return chunks
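
    # Illustrative run with a hypothetical max_chunk_length=30: the three
    # short sentences in 'One two three. Four five six. Seven eight nine.'
    # pack into two chunks at the sentence boundaries, and _add_overlap below
    # then prepends the tail of each chunk to the next one.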

    def _add_overlap(self, chunks: List[str]) -> List[str]:
        """
        Add overlapping words between chunks for smoother transitions.

        Args:
            chunks: List of text chunks

        Returns:
            Chunks with added overlap
        """
        if len(chunks) <= 1:
            return chunks
        overlapped_chunks = [chunks[0]]
        for i in range(1, len(chunks)):
            prev_words = chunks[i - 1].split()
            # Take the last few words of the previous chunk; slicing already
            # returns the whole list when it is shorter than overlap_words
            overlap_text = " ".join(prev_words[-self.overlap_words:])
            # Prepend the overlap to the current chunk
            overlapped_chunks.append(f"{overlap_text} {chunks[i]}".strip())
        return overlapped_chunks
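
    # Deterministic example with overlap_words=2:
    #   _add_overlap(['a b c d', 'e f g'])  ->  ['a b c d', 'c d e f g']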

    def process_text(self, text: str) -> str:
        """
        Main text processing pipeline.

        Args:
            text: Raw input text

        Returns:
            Processed and normalized text ready for TTS
        """
        start_time = time.time()
        if not text or not text.strip():
            return ""
        try:
            # Normalize the text
            processed_text = self._normalize_text(text)
            processing_time = time.time() - start_time
            logger.info(f"Text processed in {processing_time:.3f}s")
            return processed_text
        except Exception as e:
            logger.error(f"Text processing failed: {e}")
            return str(text)  # Return original text as fallback

    def process_chunks(self, text: str) -> List[str]:
        """
        Process text and return optimized chunks for TTS.

        Args:
            text: Input text

        Returns:
            List of processed text chunks
        """
        # First normalize the text, then chunk it
        processed_text = self.process_text(text)
        return self.chunk_text(processed_text)

    def clear_cache(self):
        """Clear all caches to free memory."""
        self.translation_cache.clear()
        self.number_cache.clear()
        logger.info("Caches cleared")

    def get_cache_stats(self) -> Dict[str, int]:
        """Get statistics about cache usage."""
        return {
            "translation_cache_size": len(self.translation_cache),
            "number_cache_size": len(self.number_cache),
        }