import gradio as gr
import librosa
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import torch
from speechbrain.inference.speaker import EncoderClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.graph_objects as go
from sklearn.preprocessing import normalize
import os
from cryptography.fernet import Fernet
import pickle

# --- Configuration using Environment Variables ---
encrypted_file_path = os.environ.get("SPEAKER_EMBEDDINGS_FILE")
metadata_file = os.environ.get("METADATA_FILE")
visualization_method = os.environ.get("VISUALIZATION_METHOD", "pca")
max_length = 5 * 16000  # 5 seconds of audio at a 16 kHz sample rate
num_closest_speakers = 20
pca_dim = 50

# --- Check for Missing Environment Variables ---
if not encrypted_file_path:
    raise ValueError("SPEAKER_EMBEDDINGS_FILE environment variable is not set.")
if not metadata_file:
    raise ValueError("METADATA_FILE environment variable is not set.")

# --- Check for Valid Visualization Method ---
if visualization_method not in ["pca", "tsne"]:
    raise ValueError("Invalid VISUALIZATION_METHOD. Choose 'pca' or 'tsne'.")

# --- Debugging: Check Environment Variables ---
# Report presence only; never print the decryption key itself to the logs.
print(f"DECRYPTION_KEY set: {os.getenv('DECRYPTION_KEY') is not None}")
print(f"SPEAKER_EMBEDDINGS_FILE: {os.getenv('SPEAKER_EMBEDDINGS_FILE')}")
if os.getenv("SPEAKER_EMBEDDINGS_FILE"):
    print(
        f"Encrypted file path exists: {os.path.exists(os.getenv('SPEAKER_EMBEDDINGS_FILE'))}"
    )
else:
    print("SPEAKER_EMBEDDINGS_FILE environment variable not set.")

# --- Decryption ---
key = os.getenv("DECRYPTION_KEY")
if not key:
    raise ValueError(
        "Decryption key is missing. Ensure DECRYPTION_KEY is set in the environment variables."
    )
fernet = Fernet(key.encode("utf-8"))

# --- Sample Audio Files ---
sample_audio_dir = "sample_audio"
sample_audio_files = [
    "Bob_Barker.mp3",
    "Howie_Mandel.m4a",
    "Katherine_Jenkins.mp3",
]

# --- Load Embeddings and Metadata ---
try:
    with open(encrypted_file_path, "rb") as encrypted_file:
        encrypted_data = encrypted_file.read()
    decrypted_data_bytes = fernet.decrypt(encrypted_data)
    # Deserialize with pickle; the payload maps speaker IDs to lists of embeddings.
    speaker_embeddings = pickle.loads(decrypted_data_bytes)
    print("Speaker embeddings loaded successfully!")
except FileNotFoundError:
    raise FileNotFoundError(
        f"Could not find encrypted embeddings file at: {encrypted_file_path}"
    )
except Exception as e:
    raise RuntimeError(f"Error during decryption or loading embeddings: {e}") from e

df = pd.read_csv(metadata_file, delimiter="\t")

# --- Convert Embeddings to NumPy Arrays ---
for spk_id, embeddings in speaker_embeddings.items():
    speaker_embeddings[spk_id] = [np.array(embedding) for embedding in embeddings]

# --- Speaker ID to Name Mapping ---
speaker_id_to_name = dict(zip(df["VoxCeleb1 ID"], df["VGGFace1 ID"]))

# --- Load SpeechBrain Classifier ---
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    savedir="pretrained_models/spkrec-xvect-voxceleb",
)


# --- Function to Calculate Average Embedding (Centroid) ---
def calculate_average_embedding(embeddings):
    """Average a speaker's embeddings and L2-normalize the result."""
    avg_embedding = np.mean(embeddings, axis=0)
    return normalize(avg_embedding.reshape(1, -1)).flatten()


# --- Precompute Speaker Centroids ---
speaker_centroids = {
    spk_id: calculate_average_embedding(embeddings)
    for spk_id, embeddings in speaker_embeddings.items()
}
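# Quick illustration of the centroid math above (a hedged sketch, not part of
# the app's logic): averaging unit vectors and re-normalizing keeps each
# centroid on the unit sphere, which is what the cosine-similarity scoring
# further below assumes.
#
#     >>> demo = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
#     >>> calculate_average_embedding(demo)
#     array([0.70710678, 0.70710678])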
# --- Function to Prepare Data for Visualization ---
def prepare_data_for_visualization(speaker_centroids, closest_speaker_ids):
    all_embeddings = [
        centroid
        for speaker_id, centroid in speaker_centroids.items()
        if speaker_id in closest_speaker_ids
    ]
    all_speaker_ids = [
        speaker_id
        for speaker_id in speaker_centroids
        if speaker_id in closest_speaker_ids
    ]
    return np.array(all_embeddings), np.array(all_speaker_ids)


# --- Function to Reduce Dimensionality ---
def reduce_dimensionality(all_embeddings, method="tsne", perplexity=5, pca_dim=50):
    if method == "pca":
        reducer = PCA(n_components=2)
    elif method == "tsne":
        # PCA first to compress, then t-SNE down to 2-D. PCA's n_components may
        # not exceed the number of samples, so cap it when few speakers remain.
        pca_reducer = PCA(n_components=min(pca_dim, len(all_embeddings)))
        all_embeddings = pca_reducer.fit_transform(all_embeddings)
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    else:
        raise ValueError("Invalid method. Choose 'pca' or 'tsne'.")
    reduced_embeddings = reducer.fit_transform(all_embeddings)
    return reducer, reduced_embeddings


# --- Function to Get Speaker Name from ID ---
def get_speaker_name(speaker_id):
    return speaker_id_to_name.get(speaker_id, f"Unknown ({speaker_id})")


# --- Function to Generate Visualization ---
def generate_visualization(
    pca_reducer,
    reduced_embeddings,
    all_embeddings,
    all_speaker_ids,
    new_embedding,
    predicted_speaker_id,
    visualization_method,
    perplexity,
    pca_dim,
):
    if visualization_method == "pca":
        new_embedding_reduced = pca_reducer.transform(new_embedding.reshape(1, -1))
    elif visualization_method == "tsne":
        # t-SNE cannot project unseen points, so re-fit on the original
        # high-dimensional centroids with the new voice appended. (Stacking the
        # already-reduced 2-D points with a raw embedding would mix
        # incompatible shapes and crash np.vstack.)
        combined_embeddings = np.vstack(
            [all_embeddings, new_embedding.reshape(1, -1)]
        )
        combined_embeddings = PCA(
            n_components=min(pca_dim, len(combined_embeddings))
        ).fit_transform(combined_embeddings)
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        combined_reduced = reducer.fit_transform(combined_embeddings)
        reduced_embeddings = combined_reduced[:-1]
        new_embedding_reduced = combined_reduced[-1].reshape(1, -1)
    else:
        raise ValueError("Invalid visualization method.")

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=reduced_embeddings[:, 0],
            y=reduced_embeddings[:, 1],
            mode="markers",
            marker=dict(color="blue", size=8, opacity=0.5),
            text=[get_speaker_name(speaker_id) for speaker_id in all_speaker_ids],
            name="Other Speakers",
        )
    )
    if predicted_speaker_id in all_speaker_ids:
        predicted_speaker_index = list(all_speaker_ids).index(predicted_speaker_id)
        fig.add_trace(
            go.Scatter(
                x=[reduced_embeddings[predicted_speaker_index, 0]],
                y=[reduced_embeddings[predicted_speaker_index, 1]],
                mode="markers",
                marker=dict(
                    color="green",
                    size=10,
                    symbol="circle",
                    line=dict(color="black", width=2),
                ),
                name=get_speaker_name(predicted_speaker_id),
                text=[get_speaker_name(predicted_speaker_id)],
            )
        )
    fig.add_trace(
        go.Scatter(
            x=new_embedding_reduced[:, 0],
            y=new_embedding_reduced[:, 1],
            mode="markers",
            marker=dict(color="red", size=12, symbol="star"),
            name="New Voice",
            text=["New Voice"],
        )
    )
    fig.update_layout(
        title=f"Dimensionality Reduction of Speaker Embeddings using {visualization_method.upper()}",
        xaxis_title="Component 1",
        yaxis_title="Component 2",
        legend=dict(x=0, y=1, traceorder="normal", orientation="h"),
        hovermode="closest",
    )
    return fig
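# Why the two branches above differ (an illustrative sketch, assuming
# scikit-learn's public API): PCA learns a linear projection, so a point unseen
# at fit time can be mapped afterwards, while TSNE exposes only fit_transform()
# and must be re-fit whenever a new point is added.
#
#     >>> points = np.random.rand(20, 50)
#     >>> PCA(n_components=2).fit(points).transform(points[:1]).shape
#     (1, 2)
#     >>> hasattr(TSNE(n_components=2), "transform")
#     False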
# --- Main Function ---
def identify_voice_and_visualize_with_averaging(audio_file, perplexity=5):
    try:
        if isinstance(audio_file, str):
            signal, fs = librosa.load(audio_file, sr=16000)
        elif isinstance(audio_file, np.ndarray):
            signal = audio_file
            fs = 16000
        else:
            raise ValueError(
                "Invalid audio input. Must be a file path or a NumPy array."
            )

        # Truncate to max_length, then zero-pad shorter clips up to it, so the
        # encoder always receives exactly 5 seconds of 16 kHz audio.
        signal_tensor = torch.tensor(signal, dtype=torch.float32).unsqueeze(0)
        signal_tensor = signal_tensor[:, :max_length]
        signal_tensor = torch.nn.functional.pad(
            signal_tensor, (0, max_length - signal_tensor.shape[1])
        )

        user_embedding = classifier.encode_batch(signal_tensor).cpu().detach().numpy()
        user_embedding = normalize(
            user_embedding.squeeze(axis=(0, 1)).reshape(1, -1)
        ).flatten()

        # Score the new voice against every precomputed speaker centroid.
        similarity_scores = {
            spk_id: cosine_similarity(
                user_embedding.reshape(1, -1), centroid.reshape(1, -1)
            )[0][0]
            for spk_id, centroid in speaker_centroids.items()
        }
        closest_speaker_ids = sorted(
            similarity_scores, key=similarity_scores.get, reverse=True
        )[:num_closest_speakers]
        predicted_speaker_id = closest_speaker_ids[0]
        highest_similarity = similarity_scores[predicted_speaker_id]

        all_embeddings, all_speaker_ids = prepare_data_for_visualization(
            speaker_centroids, closest_speaker_ids
        )
        reducer, reduced_embeddings = reduce_dimensionality(
            all_embeddings,
            method=visualization_method,
            perplexity=perplexity,
            pca_dim=pca_dim,
        )

        predicted_speaker_name = get_speaker_name(predicted_speaker_id)
        similarity_percentage = round(highest_similarity * 100, 2)
        visualization = generate_visualization(
            reducer,
            reduced_embeddings,
            all_embeddings,
            all_speaker_ids,
            user_embedding,
            predicted_speaker_id,
            visualization_method,
            perplexity,
            pca_dim,
        )
        result_text = (
            f"The voice resembles speaker: {predicted_speaker_name} "
            f"with a similarity of {similarity_percentage:.2f}%"
        )
        return result_text, visualization
    except Exception as e:
        return f"Error during processing: {e}", None


# --- Gradio Interface ---
# Create a directory for caching examples if it doesn't exist
cache_dir = "examples_cache"
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)

# Define the Gradio interface
iface = gr.Interface(
    fn=identify_voice_and_visualize_with_averaging,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=["text", gr.Plot()],
    title="Discover Your Celebrity Voice Twin!",
    description="Record your voice or upload an audio file, and see your celebrity match! Not ready to record? Try our sample voices to see how it works!",
    cache_examples=False,
    examples_per_page=3,
    examples=[
        [os.path.join(sample_audio_dir, sample_audio_files[0])],
        [os.path.join(sample_audio_dir, sample_audio_files[1])],
        [os.path.join(sample_audio_dir, sample_audio_files[2])],
    ],
)

# Launch the interface
iface.launch(debug=True, share=True)
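# For reference, an encrypted embeddings file in the format consumed above can
# be produced with the snippet below (a minimal sketch; the actual pipeline
# that built this app's pickle is not part of this file, and the output
# filename is only a placeholder):
#
#     from cryptography.fernet import Fernet
#     import pickle
#
#     key = Fernet.generate_key()  # store this value as DECRYPTION_KEY
#     token = Fernet(key).encrypt(pickle.dumps(speaker_embeddings))
#     with open("speaker_embeddings.enc", "wb") as f:  # hypothetical path
#         f.write(token)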