"""
Base STT Implementation
=======================
Common audio processing and validation for all STT providers
"""
import struct
from typing import Optional, List
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from .stt_interface import STTInterface, STTConfig, TranscriptionResult
from utils.logger import log_info, log_error, log_debug, log_warning


class STTBase(STTInterface, ABC):
    """Base class for all STT implementations with common audio processing"""
    
    def __init__(self):
        super().__init__()
    
    async def transcribe(self, audio_data: bytes, config: STTConfig) -> Optional[TranscriptionResult]:
        """Main transcription method with preprocessing"""
        try:
            # 1. Validate input
            if not audio_data:
                log_warning("⚠️ No audio data provided")
                return None
            
            log_info(f"πŸ“Š Transcribing {len(audio_data)} bytes of audio")
            
            # 2. Analyze and validate audio
            analysis_result = self._analyze_audio(audio_data, config.sample_rate)
            if not analysis_result.is_valid:
                log_warning(f"⚠️ Audio validation failed: {analysis_result.reason}")
                return None
            
            # 3. Preprocess audio if needed
            processed_audio = self._preprocess_audio(audio_data, config)
            
            # 4. Call provider-specific implementation
            return await self._transcribe_impl(processed_audio, config, analysis_result)
            
        except Exception as e:
            log_error(f"❌ Error during transcription: {str(e)}")
            import traceback
            log_error(f"Traceback: {traceback.format_exc()}")
            return None
    
    @abstractmethod
    async def _transcribe_impl(self, audio_data: bytes, config: STTConfig, analysis: 'AudioAnalysis') -> Optional[TranscriptionResult]:
        """Provider-specific transcription implementation"""
        pass
    
    def _analyze_audio(self, audio_data: bytes, sample_rate: int) -> 'AudioAnalysis':
        """Analyze audio quality and content"""
        try:
            # Interpret raw bytes as 16-bit signed PCM samples (drop a trailing odd byte)
            sample_count = len(audio_data) // 2
            samples = struct.unpack(f'{sample_count}h', audio_data[:sample_count * 2])
            total_samples = len(samples)
            
            if total_samples == 0:
                return AudioAnalysis(sample_rate=sample_rate, is_valid=False,
                                     reason="Audio buffer too short to contain samples")
            
            # Basic statistics
            non_zero_samples = [s for s in samples if s != 0]
            zero_count = total_samples - len(non_zero_samples)
            
            if non_zero_samples:
                avg_amplitude = sum(abs(s) for s in non_zero_samples) / len(non_zero_samples)
                max_amplitude = max(abs(s) for s in non_zero_samples)
            else:
                avg_amplitude = 0
                max_amplitude = 0
            
            log_info(f"πŸ” Audio stats: {total_samples} total samples, {zero_count} zeros ({zero_count/total_samples:.1%})")
            log_info(f"πŸ” Non-zero stats: avg={avg_amplitude:.1f}, max={max_amplitude}")
            
            # Section analysis (10 sections)
            section_size = total_samples // 10
            sections = []
            
            for i in range(10):
                start_idx = i * section_size
                end_idx = (i + 1) * section_size if i < 9 else total_samples
                section = samples[start_idx:end_idx]
                
                section_non_zero = [s for s in section if s != 0]
                section_max = max(abs(s) for s in section_non_zero) if section_non_zero else 0
                section_avg = sum(abs(s) for s in section_non_zero) / len(section_non_zero) if section_non_zero else 0
                zero_ratio = (len(section) - len(section_non_zero)) / len(section) if section else 1.0
                
                sections.append({
                    'max': section_max,
                    'avg': section_avg,
                    'zero_ratio': zero_ratio
                })
                
                log_info(f"  Section {i+1}: max={section_max}, avg={section_avg:.1f}, zeros={zero_ratio:.1%}")
            
            # Find speech start
            speech_start_idx = self._find_speech_start(samples, sample_rate)
            speech_start_time = speech_start_idx / sample_rate if speech_start_idx >= 0 else -1
            
            if speech_start_idx >= 0:
                log_info(f"🎀 Speech detected starting at sample {speech_start_idx} ({speech_start_time:.2f}s)")
            else:
                log_warning("⚠️ No speech detected above threshold in entire audio")
            
            # Validation
            is_valid = True
            reason = ""
            
            if max_amplitude < 100:
                is_valid = False
                reason = f"Audio appears silent: max_amplitude={max_amplitude}"
            elif zero_count / total_samples > 0.95:
                is_valid = False
                reason = f"Audio is mostly zeros: {zero_count/total_samples:.1%}"
            elif speech_start_idx < 0:
                is_valid = False
                reason = "No speech detected"
            
            return AudioAnalysis(
                total_samples=total_samples,
                sample_rate=sample_rate,
                zero_count=zero_count,
                avg_amplitude=avg_amplitude,
                max_amplitude=max_amplitude,
                sections=sections,
                speech_start_idx=speech_start_idx,
                speech_start_time=speech_start_time,
                is_valid=is_valid,
                reason=reason
            )
            
        except Exception as e:
            log_error(f"Audio analysis failed: {e}")
            return AudioAnalysis(
                total_samples=0,
                sample_rate=sample_rate,
                is_valid=False,
                reason=f"Analysis failed: {e}"
            )
    
    def _find_speech_start(self, samples: List[int], sample_rate: int, threshold: int = 500) -> int:
        """Return the first sample index whose window RMS exceeds the threshold, or -1 if none"""
        # Scan non-overlapping 100-sample windows (~6 ms at 16 kHz) and compare RMS energy to the threshold
        window_size = 100
        
        for i in range(0, len(samples) - window_size, window_size):
            window = samples[i:i + window_size]
            rms = (sum(s * s for s in window) / window_size) ** 0.5
            
            if rms > threshold:
                return i
        
        return -1
    
    def _preprocess_audio(self, audio_data: bytes, config: STTConfig) -> bytes:
        """Preprocess audio if needed (can be overridden by providers)"""
        # Default: no preprocessing
        return audio_data
    
    def _clean_audio_silence(self, audio_data: bytes, threshold: int = 50) -> bytes:
        """Remove leading/trailing silence"""
        try:
            # Interpret raw bytes as 16-bit signed PCM samples (drop a trailing odd byte)
            sample_count = len(audio_data) // 2
            samples = struct.unpack(f'{sample_count}h', audio_data[:sample_count * 2])
            
            # Find first non-silent sample
            start_idx = 0
            for i, sample in enumerate(samples):
                if abs(sample) > threshold:
                    start_idx = i
                    break
            
            # Find last non-silent sample
            end_idx = len(samples) - 1
            for i in range(len(samples) - 1, -1, -1):
                if abs(samples[i]) > threshold:
                    end_idx = i
                    break
            
            # Add padding
            start_idx = max(0, start_idx - 100)
            end_idx = min(len(samples) - 1, end_idx + 100)
            
            # Convert back
            cleaned_samples = samples[start_idx:end_idx + 1]
            cleaned_audio = struct.pack(f'{len(cleaned_samples)}h', *cleaned_samples)
            
            log_debug(f"Audio cleaning: {len(audio_data)} β†’ {len(cleaned_audio)} bytes")
            return cleaned_audio
            
        except Exception as e:
            log_warning(f"Audio cleaning failed: {e}, using original")
            return audio_data


@dataclass
class AudioAnalysis:
    """Audio analysis results"""
    total_samples: int = 0
    sample_rate: int = 16000
    zero_count: int = 0
    avg_amplitude: float = 0.0
    max_amplitude: int = 0
    sections: List[dict] = field(default_factory=list)
    speech_start_idx: int = -1
    speech_start_time: float = -1.0
    is_valid: bool = False
    reason: str = ""
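

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the shipped module): a minimal, hypothetical
# provider showing how STTBase is intended to be subclassed. The class name
# "ExampleSTT" and the TranscriptionResult constructor arguments below are
# assumptions for demonstration only; a real provider would call its backend
# API inside _transcribe_impl and build the result according to the actual
# definition in stt_interface.
#
# class ExampleSTT(STTBase):
#     """Hypothetical provider used only to illustrate the extension points"""
#
#     def _preprocess_audio(self, audio_data: bytes, config: STTConfig) -> bytes:
#         # Optional hook: trim leading/trailing silence before sending upstream
#         return self._clean_audio_silence(audio_data)
#
#     async def _transcribe_impl(self, audio_data: bytes, config: STTConfig,
#                                analysis: 'AudioAnalysis') -> Optional[TranscriptionResult]:
#         # audio_data has already been validated and preprocessed by transcribe()
#         log_debug(f"Speech starts at {analysis.speech_start_time:.2f}s "
#                   f"in {analysis.total_samples} samples")
#         return TranscriptionResult(text="<provider transcription>")  # fields assumed
# ---------------------------------------------------------------------------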