import os
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import Audio, load_dataset
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoFeatureExtractor, WhisperModel

from .config import *

model_ids = ENABLED_MODELS

# Load dataset
dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
if MAX_SAMPLES is not None:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
    print(f"Limited dataset to {len(dataset)} samples for testing")

dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")


def extract_layer_reps_generator(model_id, batch_size=4):
    """
    Use a generator to process samples in batches, avoiding loading all hidden states into memory at once.
    Yields (sample_idx, layer_reps) pairs, where layer_reps is a list of all layer representations for the sample.
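
    Example (illustrative; exact shapes depend on the Whisper variant):

        for idx, reps in extract_layer_reps_generator("openai/whisper-tiny"):
            # reps holds (num_layers + 1) tensors, one per hidden state,
            # each of shape [T, D] (e.g. [1500, 384] for whisper-tiny)
            break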
    """
    model = WhisperModel.from_pretrained(model_id).to(device)
    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
    model.eval()

    for i in tqdm(
        range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
    ):
        batch_end = min(i + batch_size, len(dataset))
        batch_samples = dataset.select(range(i, batch_end))

        # Process each sample in the batch
        for j, sample in enumerate(batch_samples):
            audio = sample["audio"]
            samples = audio["array"]
            sr = audio["sampling_rate"]

            inputs = feat_ext(
                samples, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)
            with torch.no_grad():
                outputs = model.encoder(
                    inputs, return_dict=True, output_hidden_states=True
                )

            # Keep the full sequence for each layer; detach and move it to CPU
            # right away so GPU memory is freed, optionally in half precision
            # to save host memory
            layer_reps_for_sample = []
            for hs in outputs.hidden_states:
                # hs: [1, T, D] -> [T, D]
                layer_rep = hs.squeeze(0).detach().cpu()
                if USE_HALF_PRECISION:
                    layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
                layer_reps_for_sample.append(layer_rep)

            yield i + j, layer_reps_for_sample

            # Clean up GPU memory
            del outputs, inputs
            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Clean up model memory
    del model, feat_ext
    if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
        torch.cuda.empty_cache()


def compute_linear_mse_matrix_temporal_memory_efficient(
    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
):
    """
    Memory-conscious version: for each layer pair (i, j), trains a 1x1
    convolution as a linear probe and records the final MSE.
    Representations are extracted once per model via the generator and cached
    on CPU (optionally in half precision); only the tensors for the current
    layer pair are moved to the GPU.
    Returns an MSE matrix of shape (layers_a, layers_b) and all trained probes.
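
    Example (illustrative model IDs and step count):

        mse, probes = compute_linear_mse_matrix_temporal_memory_efficient(
            "openai/whisper-tiny", "openai/whisper-base", n_steps=50
        )
        # mse[i, j]: final probe MSE mapping tiny's layer i to base's layer j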
    """
    print(f"Computing alignment between {model_a_id} and {model_b_id}...")

    # First, probe the number of layers with a single sample per model; close
    # each generator afterwards so its model is released immediately
    sample_gen_a = extract_layer_reps_generator(model_a_id, batch_size=1)
    _, sample_reps_a = next(sample_gen_a)
    layers_a = len(sample_reps_a)
    sample_gen_a.close()

    sample_gen_b = extract_layer_reps_generator(model_b_id, batch_size=1)
    _, sample_reps_b = next(sample_gen_b)
    layers_b = len(sample_reps_b)
    sample_gen_b.close()

    mse_mat = np.zeros((layers_a, layers_b))
    trained_probes = {}

    pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs")

    # Re-initialize generators to process all samples
    gen_a = extract_layer_reps_generator(model_a_id, batch_size=batch_size)
    gen_b = extract_layer_reps_generator(model_b_id, batch_size=batch_size)

    # Cache every sample's representations for all layers (held on CPU)
    reps_a_dict_all = {}
    for sample_idx, layer_reps in gen_a:
        reps_a_dict_all[sample_idx] = layer_reps

    reps_b_dict_all = {}
    for sample_idx, layer_reps in gen_b:
        reps_b_dict_all[sample_idx] = layer_reps

    for i in range(layers_a):
        for j in range(layers_b):
            # Collect all sample representations for the specified layer
            reps_a_dict = {}
            for sample_idx, layer_reps in reps_a_dict_all.items():
                if i < len(layer_reps):
                    reps_a_dict[sample_idx] = layer_reps[i]

            reps_b_dict = {}
            for sample_idx, layer_reps in reps_b_dict_all.items():
                if j < len(layer_reps):
                    reps_b_dict[sample_idx] = layer_reps[j]

            # Concatenate representations in order
            X_list = [reps_a_dict[idx] for idx in sorted(reps_a_dict.keys())]
            Y_list = [reps_b_dict[idx] for idx in sorted(reps_b_dict.keys())]

            # Concatenate all samples along the token axis and move to device
            X_cat = torch.cat(X_list, dim=0).to(device)
            Y_cat = torch.cat(Y_list, dim=0).to(device)

            dim_a = X_cat.shape[1]
            dim_b = Y_cat.shape[1]

            # For Conv1d, reshape to [Batch, Channels, Length]
            X = X_cat.T.unsqueeze(0)  # [1, Dim_A, Total_Tokens]
            Y = Y_cat.T.unsqueeze(0)  # [1, Dim_B, Total_Tokens]
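            # A kernel-size-1 convolution applies the same linear map
            # W: R^{dim_a} -> R^{dim_b} at every token position, so the probe
            # is a token-wise linear regression fitted over the whole sequence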

            # Define and train the linear probe (1x1 conv); match its dtype to
            # the representations so inputs and weights agree
            probe_dtype = (
                HALF_PRECISION_DTYPE if USE_HALF_PRECISION else torch.float32
            )
            probe = nn.Conv1d(
                in_channels=dim_a, out_channels=dim_b, kernel_size=1, bias=False
            ).to(device=device, dtype=probe_dtype)
            probe.train()

            optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
            loss_fn = nn.MSELoss()

            for step in tqdm(
                range(n_steps), desc=f"Training probe {i}->{j}", leave=False
            ):
                optimizer.zero_grad()
                Y_pred = probe(X)
                loss = loss_fn(Y_pred, Y)
                loss.backward()
                optimizer.step()

            # Record the final MSE and the trained probe weight; move the
            # weight to CPU so it can be serialized with safetensors
            final_mse = loss.item()
            mse_mat[i, j] = final_mse
            trained_probes[f"layer_{i}_to_{j}"] = (
                probe.state_dict()["weight"].detach().cpu().contiguous()
            )

            # Clean up memory
            del (
                X_cat,
                Y_cat,
                X,
                Y,
                probe,
                optimizer,
                reps_a_dict,
                reps_b_dict,
                X_list,
                Y_list,
            )
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            pbar.update(1)
            pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"})

    pbar.close()
    return mse_mat, trained_probes
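

# Hedged sketch (not called anywhere in this script): how a saved probe file
# could be applied offline to map one model's layer representation into
# another's space. The file path and layer key are illustrative placeholders.
def apply_saved_probe_example():
    from safetensors.torch import load_file

    probes = load_file("results/tiny-to-base-probes.safetensors")  # hypothetical path
    weight = probes["layer_2_to_3"]  # Conv1d weight of shape [dim_b, dim_a, 1]

    probe = nn.Conv1d(
        in_channels=weight.shape[1],
        out_channels=weight.shape[0],
        kernel_size=1,
        bias=False,
    )
    # Cast to float32 in case the probes were saved in half precision
    probe.load_state_dict({"weight": weight.float()})
    probe.eval()

    x = torch.randn(1, weight.shape[1], 1500)  # dummy [1, dim_a, T] input
    with torch.no_grad():
        y_pred = probe(x)  # [1, dim_b, T]: predicted target-layer activations
    return y_pred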


if __name__ == "__main__":
    print(f"Memory optimization settings:")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Training steps: {TRAINING_STEPS}")
    if USE_HALF_PRECISION:
        dtype_name = "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16"
        print(f"  Half precision: {USE_HALF_PRECISION} ({dtype_name})")
    else:
        print(f"  Half precision: {USE_HALF_PRECISION}")
    print(f"  Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
    print(f"  Models: {list(model_ids.keys())}")
    print(f"  Dataset size: {len(dataset)} samples")

    # Create results directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Compare all model pairs using the memory-efficient method
    model_names = list(model_ids.keys())
    all_pairs = list(combinations(model_names, 2))

    print(
        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
    )

    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
        print(
            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE for whisper-{model_a} vs whisper-{model_b}..."
        )

        # Compute the temporal linear MSE matrix and collect the trained probes
        mse_mat_temporal, trained_probes = (
            compute_linear_mse_matrix_temporal_memory_efficient(
                model_ids[model_a],
                model_ids[model_b],
                n_steps=TRAINING_STEPS,
                lr=LEARNING_RATE,
                batch_size=BATCH_SIZE,
            )
        )

        # Save the trained probes
        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
        save_file(
            trained_probes,
            model_save_path,
            {
                "from_model": model_a,
                "to_model": model_b,
                "from_layers": str(len(mse_mat_temporal)),
                "to_layers": str(len(mse_mat_temporal[0])),
            },
        )
        print(f"Saved trained probes to: {model_save_path}")

        if SAVE_PLOTS:
            # Visualize results
            # Avoid log(0) by adding a small value
            eps = 1e-10
            log_mse_mat = -np.log10(mse_mat_temporal + eps)
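            # Higher -log10(MSE) means lower probe error, i.e. the two layers
            # are more linearly alignable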

            plt.figure(figsize=(8, 6))
            plt.imshow(
                log_mse_mat, aspect="auto", origin="lower"
            )  # origin="lower" puts layer 0 at the bottom-left
            plt.colorbar(label="-log10(MSE)")
            plt.title(
                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
            )
            plt.xlabel(f"whisper-{model_b} layers")
            plt.ylabel(f"whisper-{model_a} layers")
            plt.tight_layout()

            # Save visualization results
            plot_save_path = (
                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
            )
            plt.savefig(plot_save_path, dpi=PLOT_DPI)
            plt.close()  # Close figure to save memory
            print(f"Saved plot to: {plot_save_path}")

    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
    print(
        f"Saved {len(all_pairs)} trained probe files"
        + (f" and {len(all_pairs)} visualization plots" if SAVE_PLOTS else "")
    )