import torch
import torch.nn as nn
import torch.nn.functional as F
from speechbrain.pretrained import EncoderClassifier
import numpy as np
from scipy.spatial.distance import cosine
import librosa
import torchaudio
import gradio as gr
import noisereduce as nr

# Import WavLM components from Hugging Face
from transformers import WavLMForXVector, Wav2Vec2FeatureExtractor


# ---------------- Noise Reduction and Silence Removal Functions ----------------

def reduce_noise(waveform, sample_rate=16000):
    """
    Apply a mild noise reduction to the waveform specialized for voice audio.
    The parameters are chosen to minimize alteration to the original voice.

    Parameters:
        waveform (torch.Tensor): Audio tensor of shape (1, n_samples)
        sample_rate (int): Sampling rate of the audio

    Returns:
        torch.Tensor: Denoised audio tensor of shape (1, n_samples)
    """
    # Convert tensor to numpy array
    waveform_np = waveform.squeeze(0).cpu().numpy()
    # Perform noise reduction with conservative parameters.
    reduced_noise = nr.reduce_noise(y=waveform_np, sr=sample_rate, prop_decrease=0.5)
    return torch.from_numpy(reduced_noise).unsqueeze(0)


def remove_long_silence(waveform, sample_rate=16000, top_db=20, max_silence_length=1.0):
    """
    Remove silence segments longer than max_silence_length seconds from the audio.
    This function uses librosa.effects.split to detect non-silent intervals and
    preserves at most max_silence_length seconds of silence between speech segments.

    Parameters:
        waveform (torch.Tensor): Audio tensor of shape (1, n_samples)
        sample_rate (int): Sampling rate of the audio
        top_db (int): The threshold (in decibels) below reference to consider as silence
        max_silence_length (float): Maximum allowed silence duration in seconds

    Returns:
        torch.Tensor: Processed audio tensor with long silences removed
    """
    # Convert tensor to numpy array
    waveform_np = waveform.squeeze(0).cpu().numpy()
    # Identify non-silent intervals
    non_silent_intervals = librosa.effects.split(waveform_np, top_db=top_db)
    if len(non_silent_intervals) == 0:
        return waveform

    output_segments = []
    max_silence_samples = int(max_silence_length * sample_rate)

    # Handle silence before the first non-silent interval
    if non_silent_intervals[0][0] > 0:
        output_segments.append(waveform_np[:min(non_silent_intervals[0][0], max_silence_samples)])

    # Process each non-silent interval and the gap following it
    for i, (start, end) in enumerate(non_silent_intervals):
        output_segments.append(waveform_np[start:end])
        if i < len(non_silent_intervals) - 1:
            next_start = non_silent_intervals[i + 1][0]
            gap = next_start - end
            if gap > max_silence_samples:
                output_segments.append(waveform_np[end:end + max_silence_samples])
            else:
                output_segments.append(waveform_np[end:next_start])

    # Handle silence after the last non-silent interval
    if non_silent_intervals[-1][1] < len(waveform_np):
        gap = len(waveform_np) - non_silent_intervals[-1][1]
        if gap > max_silence_samples:
            output_segments.append(waveform_np[-max_silence_samples:])
        else:
            output_segments.append(waveform_np[non_silent_intervals[-1][1]:])

    processed_waveform = np.concatenate(output_segments)
    return torch.from_numpy(processed_waveform).unsqueeze(0)

# -----------------------------------------------------------------------------


class EnhancedECAPATDNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Primary pretrained model from SpeechBrain (ECAPA-TDNN, trained on VoxCeleb)
        self.ecapa = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir="pretrained_models/spkrec-ecapa-voxceleb",
            run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"}
        )
        # Secondary pretrained model: Microsoft WavLM for Speaker Verification
        self.wavlm_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-sv")
        self.wavlm = WavLMForXVector.from_pretrained("microsoft/wavlm-base-sv")
        self.wavlm.to("cuda" if torch.cuda.is_available() else "cpu")
        # Projection layer to map WavLM's embedding (512-dim) to 192-dim (to match ECAPA)
        self.wavlm_proj = nn.Linear(512, 192)
        # Enhancement network: increase dimensionality, then reduce back to 192.
        self.enhancement = nn.Sequential(
            nn.Linear(192, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 192)
        )
        # Transformer encoder block (with batch_first=True)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=192, nhead=4, dropout=0.3, batch_first=True),
            num_layers=2
        )

    @torch.no_grad()
    def forward(self, x):
        """
        x: input waveform tensor of shape (1, T) on device.
        """
        # Extract ECAPA embedding
        emb_ecapa = self.ecapa.encode_batch(x)

        # Prepare input for WavLM: x is a waveform tensor of shape (1, T)
        waveform_np = x.squeeze(0).cpu().numpy()  # shape (T,)
        wavlm_inputs = self.wavlm_feature_extractor(waveform_np, sampling_rate=16000, return_tensors="pt")
        wavlm_inputs = {k: v.to(x.device) for k, v in wavlm_inputs.items()}
        wavlm_out = self.wavlm(**wavlm_inputs)
        # Extract embeddings; expected shape (batch, 512)
        emb_wavlm = wavlm_out.embeddings
        # Project WavLM embedding to 192-dim
        emb_wavlm_proj = self.wavlm_proj(emb_wavlm)

        # Process ECAPA embedding:
        if emb_ecapa.dim() > 2 and emb_ecapa.size(1) > 1:
            emb_ecapa_proc = self.transformer(emb_ecapa)
            emb_ecapa_proc = emb_ecapa_proc.mean(dim=1)
        else:
            emb_ecapa_proc = emb_ecapa

        # Fuse the two embeddings by averaging
        fused = (emb_ecapa_proc + emb_wavlm_proj) / 2
        # Apply enhancement layers and normalize
        enhanced = self.enhancement(fused)
        output = F.normalize(enhanced, p=2, dim=-1)
        return output


class ForensicSpeakerVerification:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model = EnhancedECAPATDNN().to(self.device)
        self.model.eval()
        # Optimize only the enhancement and transformer layers if fine-tuning
        trainable_params = list(self.model.enhancement.parameters()) + list(self.model.transformer.parameters())
        self.optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
        self.training_embeddings = []

    def preprocess_audio(self, file_path, max_duration=10):
        try:
            waveform, sample_rate = torchaudio.load(file_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
            max_length = int(16000 * max_duration)
            if waveform.shape[1] > max_length:
                waveform = waveform[:, :max_length]
            waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
            # Apply noise reduction
            waveform = reduce_noise(waveform, sample_rate=16000)
            # Remove silences longer than 1 second
            waveform = remove_long_silence(waveform, sample_rate=16000)
            return waveform.to(self.device)
        except Exception as e:
            raise ValueError(f"Error preprocessing audio: {str(e)}")

    @torch.no_grad()
    def extract_embedding(self, file_path, chunk_duration=3, overlap=0.5):
        waveform = self.preprocess_audio(file_path)
        sample_rate = 16000
        chunk_size = int(chunk_duration * sample_rate)
        hop_size = int(chunk_size * (1 - overlap))
        embeddings = []
        if waveform.shape[1] > chunk_size:
            for start in range(0, waveform.shape[1] - chunk_size + 1, hop_size):
                chunk = waveform[:, start:start + chunk_size]
                emb = self.model(chunk)
                embeddings.append(emb)
            final_emb = torch.mean(torch.cat(embeddings, dim=0), dim=0, keepdim=True)
        else:
            final_emb = self.model(waveform)
        return final_emb.cpu().numpy()

    def verify_speaker(self, questioned_audio, suspect_audio, progress=gr.Progress()):
        if not questioned_audio or not suspect_audio:
            return "⚠️ Please provide both audio samples"
        try:
            progress(0.2, desc="Processing questioned audio...")
            questioned_emb = self.extract_embedding(questioned_audio)
            progress(0.4, desc="Processing suspect audio...")
            suspect_emb = self.extract_embedding(suspect_audio)
            progress(0.6, desc="Computing similarity...")
            score = 1 - cosine(questioned_emb.flatten(), suspect_emb.flatten())
            # Convert similarity score to probability (percentage)
            probability = score * 100
            # Create heat bar HTML
            # NOTE: the original inline markup/CSS of this widget was lost in the
            # source; only the "Similarity Score" label survived. The styling below
            # is a plausible reconstruction, not the original markup.
            bar_width = max(0.0, min(probability, 100.0))  # clamp for display
            heat_bar = f"""
            <div style="margin:10px 0;">
                <div style="font-weight:bold;margin-bottom:5px;">Similarity Score: {probability:.1f}%</div>
                <div style="background:#e0e0e0;border-radius:5px;height:20px;">
                    <div style="width:{bar_width:.1f}%;height:100%;border-radius:5px;
                                background:linear-gradient(90deg,#f44336,#ff9800,#4caf50);"></div>
                </div>
            </div>
            """

            progress(0.8, desc="Preparing verdict...")
            # The original decision thresholds were not preserved; the cut-offs below
            # are assumptions and should be calibrated before any forensic use.
            if probability >= 75:
                verdict_text = "✅ Strong support for same speaker"
            elif probability >= 50:
                verdict_text = "⚖️ Inconclusive – further analysis recommended"
            else:
                verdict_text = "❌ Support for different speakers"

            progress(1.0, desc="Done")
            return f"{heat_bar}{verdict_text}"
        except Exception as e:
            return f"⚠️ Error during verification: {str(e)}"
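
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original snippet, which ends inside
# verify_speaker). Gradio is imported and verify_speaker takes a gr.Progress()
# callback and returns HTML, so a two-input audio Interface with an HTML output
# is assumed here. Component labels, title, and launch options are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    verifier = ForensicSpeakerVerification()

    demo = gr.Interface(
        fn=verifier.verify_speaker,
        inputs=[
            gr.Audio(type="filepath", label="Questioned Audio"),
            gr.Audio(type="filepath", label="Suspect Audio"),
        ],
        outputs=gr.HTML(label="Verification Result"),
        title="Forensic Speaker Verification",
        description="Compare a questioned recording against a suspect recording.",
    )
    demo.launch()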