import os
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import Audio, load_dataset
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoFeatureExtractor, WhisperModel

from .config import *

model_ids = ENABLED_MODELS

# Load dataset
dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
if MAX_SAMPLES is not None:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
    print(f"Limited dataset to {len(dataset)} samples for testing")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")


def extract_layer_reps_generator(model_id, batch_size=4):
    """
    Use a generator to process samples in batches, avoiding loading all hidden
    states into memory at once.

    Yields (sample_idx, layer_reps) pairs, where layer_reps is a list of all
    layer representations for the sample.
    """
    model = WhisperModel.from_pretrained(model_id).to(device)
    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
    model.eval()

    for i in tqdm(
        range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
    ):
        batch_end = min(i + batch_size, len(dataset))
        batch_samples = dataset.select(range(i, batch_end))

        # Process each sample in the batch individually
        for j, sample in enumerate(batch_samples):
            audio = sample["audio"]
            samples = audio["array"]
            sr = audio["sampling_rate"]

            inputs = feat_ext(
                samples, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)

            with torch.no_grad():
                outputs = model.encoder(
                    inputs, return_dict=True, output_hidden_states=True
                )

            # Keep the full sequence for each layer and move it to CPU immediately;
            # optionally cast to half precision to save memory
            layer_reps_for_sample = []
            for hs in outputs.hidden_states:
                # hs: [1, T, D] -> [T, D]
                layer_rep = hs.squeeze(0).cpu()
                if USE_HALF_PRECISION:
                    layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
                layer_reps_for_sample.append(layer_rep)

            yield i + j, layer_reps_for_sample

            # Clean up GPU memory
            del outputs, inputs
            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Clean up model memory
    del model, feat_ext
    if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
        torch.cuda.empty_cache()


def compute_linear_mse_matrix_temporal_memory_efficient(
    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
):
    """
    Memory-efficient version: for each layer pair (i, j), trains a 1x1 convolution
    as a linear probe from model A's layer i to model B's layer j and records the MSE.

    Hidden states are extracted with a generator and kept on the CPU (optionally in
    half precision), so only one Whisper model is resident on the device at a time.

    Returns an MSE matrix of shape (layers_a, layers_b) and all trained probes.
""" print(f"Computing alignment between {model_a_id} and {model_b_id}...") # First, get the number of layers sample_gen_a = extract_layer_reps_generator(model_a_id, batch_size=1) _, sample_reps_a = next(sample_gen_a) layers_a = len(sample_reps_a) sample_gen_b = extract_layer_reps_generator(model_b_id, batch_size=1) _, sample_reps_b = next(sample_gen_b) layers_b = len(sample_reps_b) mse_mat = np.zeros((layers_a, layers_b)) trained_probes = {} pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs") # Re-initialize generators to process all samples gen_a = extract_layer_reps_generator(model_a_id, batch_size=batch_size) gen_b = extract_layer_reps_generator(model_b_id, batch_size=batch_size) # Collect all sample representations for specified layers reps_a_dict_all = {} for sample_idx, layer_reps in gen_a: reps_a_dict_all[sample_idx] = layer_reps reps_b_dict_all = {} for sample_idx, layer_reps in gen_b: reps_b_dict_all[sample_idx] = layer_reps for i in range(layers_a): for j in range(layers_b): # Collect all sample representations for the specified layer reps_a_dict = {} for sample_idx, layer_reps in reps_a_dict_all.items(): if i < len(layer_reps): reps_a_dict[sample_idx] = layer_reps[i] reps_b_dict = {} for sample_idx, layer_reps in reps_b_dict_all.items(): if j < len(layer_reps): reps_b_dict[sample_idx] = layer_reps[j] # Concatenate representations in order X_list = [reps_a_dict[idx] for idx in sorted(reps_a_dict.keys())] Y_list = [reps_b_dict[idx] for idx in sorted(reps_b_dict.keys())] # Process in batches to avoid memory issues X_cat = torch.cat(X_list, dim=0).to(device) Y_cat = torch.cat(Y_list, dim=0).to(device) dim_a = X_cat.shape[1] dim_b = Y_cat.shape[1] # For Conv1d, reshape to [Batch, Channels, Length] X = X_cat.T.unsqueeze(0) # [1, Dim_A, Total_Tokens] Y = Y_cat.T.unsqueeze(0) # [1, Dim_B, Total_Tokens] # 2. Define and train linear probe (1x1 Conv) probe = nn.Conv1d( in_channels=dim_a, out_channels=dim_b, kernel_size=1, bias=False ).to(device=device, dtype=HALF_PRECISION_DTYPE) probe.train() optimizer = torch.optim.Adam(probe.parameters(), lr=lr) loss_fn = nn.MSELoss() for step in tqdm(range(n_steps), desc=f"Training probe {i}->{j}"): optimizer.zero_grad() Y_pred = probe(X) loss = loss_fn(Y_pred, Y) loss.backward() optimizer.step() # 3. Record final MSE and trained probe final_mse = loss.item() mse_mat[i, j] = final_mse trained_probes[f"layer_{i}_to_{j}"] = probe.state_dict()["weight"] # Clean up memory del ( X_cat, Y_cat, X, Y, probe, optimizer, reps_a_dict, reps_b_dict, X_list, Y_list, ) if torch.cuda.is_available(): torch.cuda.empty_cache() pbar.update(1) pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"}) pbar.close() return mse_mat, trained_probes if __name__ == "__main__": print(f"Memory optimization settings:") print(f" Batch size: {BATCH_SIZE}") print(f" Training steps: {TRAINING_STEPS}") if USE_HALF_PRECISION: dtype_name = "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16" print(f" Half precision: {USE_HALF_PRECISION} ({dtype_name})") else: print(f" Half precision: {USE_HALF_PRECISION}") print(f" Aggressive cleanup: {AGGRESSIVE_CLEANUP}") print(f" Models: {list(model_ids.keys())}") print(f" Dataset size: {len(dataset)} samples") # Create results directory os.makedirs(OUTPUT_DIR, exist_ok=True) # 2. Compare all model pairs - using memory-efficient method model_names = list(model_ids.keys()) all_pairs = list(combinations(model_names, 2)) print( f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..." 
    )

    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
        print(
            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE "
            f"for whisper-{model_a} vs whisper-{model_b}..."
        )

        # Compute linear MSE along the temporal dimension and get the trained probes
        # (memory-efficient version)
        mse_mat_temporal, trained_probes = (
            compute_linear_mse_matrix_temporal_memory_efficient(
                model_ids[model_a],
                model_ids[model_b],
                n_steps=TRAINING_STEPS,
                lr=LEARNING_RATE,
                batch_size=BATCH_SIZE,
            )
        )

        # Save the trained probes
        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
        save_file(
            trained_probes,
            model_save_path,
            {
                "from_model": model_a,
                "to_model": model_b,
                "from_layers": str(mse_mat_temporal.shape[0]),
                "to_layers": str(mse_mat_temporal.shape[1]),
            },
        )
        print(f"Saved trained probes to: {model_save_path}")

        if SAVE_PLOTS:
            # Visualize results; add a small epsilon to avoid log(0)
            eps = 1e-10
            log_mse_mat = -np.log10(mse_mat_temporal + eps)

            plt.figure(figsize=(8, 6))
            plt.imshow(
                log_mse_mat, aspect="auto", origin="lower"
            )  # origin="lower" is the usual orientation for layer matrices
            plt.colorbar(label="-log10(MSE)")
            plt.title(
                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
            )
            plt.xlabel(f"whisper-{model_b} layers")
            plt.ylabel(f"whisper-{model_a} layers")
            plt.tight_layout()

            # Save visualization results
            plot_save_path = (
                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
            )
            plt.savefig(plot_save_path, dpi=PLOT_DPI)
            plt.close()  # Close the figure to free memory
            print(f"Saved plot to: {plot_save_path}")

    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
    print(
        f"Generated {len(all_pairs)} visualization plots and {len(all_pairs)} trained probe models"
    )
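

# For reference, a minimal sketch of the values this script expects `.config` to
# provide via the star import above. The names are taken from their uses in this
# file; the specific values and model checkpoints below are illustrative
# assumptions only, not the actual config.py.
#
#   ENABLED_MODELS = {
#       "tiny": "openai/whisper-tiny",
#       "base": "openai/whisper-base",
#   }
#   MAX_SAMPLES = 50                  # or None to use the full dataset
#   BATCH_SIZE = 4
#   TRAINING_STEPS = 200
#   LEARNING_RATE = 1e-3
#   USE_HALF_PRECISION = True
#   HALF_PRECISION_DTYPE = torch.bfloat16
#   AGGRESSIVE_CLEANUP = True
#   OUTPUT_DIR = "results"
#   SAVE_PLOTS = True
#   PLOT_DPI = 150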