# Memory Configuration for Whisper Alignment Analysis

# Batch processing settings
BATCH_SIZE = 16  # Reduce if you get OOM errors, increase for faster processing
TRAINING_STEPS = 200  # Number of training steps for linear probes
LEARNING_RATE = 1e-3
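
# Illustrative sketch only (an assumption, not part of this config's actual
# pipeline): a minimal linear probe training loop showing how BATCH_SIZE,
# TRAINING_STEPS, and LEARNING_RATE might be consumed. The helper name and
# signature are made up.
def _example_train_linear_probe(features, labels, num_classes):
    """Fit a linear probe on frozen Whisper features (hypothetical usage)."""
    import torch

    probe = torch.nn.Linear(features.shape[-1], num_classes)
    optimizer = torch.optim.Adam(probe.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(TRAINING_STEPS):
        idx = torch.randint(0, features.shape[0], (BATCH_SIZE,))
        loss = loss_fn(probe(features[idx]), labels[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return probe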

# Model selection
ENABLED_MODELS = {
    "tiny": "openai/whisper-tiny",  # ~39M parameters
    "base": "openai/whisper-base",  # ~74M parameters
    "small": "openai/whisper-small",  # ~244M parameters
    "medium": "openai/whisper-medium",  # ~769M parameters
    "large": "openai/whisper-large-v3-turbo",  # ~1550M parameters
}

# Memory optimization settings
USE_HALF_PRECISION = (
    True  # Use half precision (bfloat16 preferred, float16 fallback) instead of float32
)
AGGRESSIVE_CLEANUP = False  # If True, clear the GPU cache after each heavy operation
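
# Illustrative sketch only (an assumption, not part of the original config):
# what an "aggressive cleanup" pass might look like when AGGRESSIVE_CLEANUP is True.
def _example_cleanup():
    """Release cached GPU memory and run Python garbage collection."""
    import gc

    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()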

# Dataset settings
MAX_SAMPLES = None  # Set to a number to limit dataset size for testing (e.g., 50)

# Output settings
OUTPUT_DIR = "whisper-alignment-results"
SAVE_PLOTS = True
PLOT_DPI = 300


# Half precision dtype selection (bfloat16 preferred if available, fallback to float16)
def get_half_precision_dtype():
    """
    Determine the best half precision dtype based on hardware support.
    bfloat16 is preferred when available as it has better numerical stability.
    """
    import torch

    if not USE_HALF_PRECISION:
        return torch.float32

    # Check if bfloat16 is supported
    if torch.cuda.is_available():
        # Check GPU support for bfloat16
        device_capability = torch.cuda.get_device_capability()
        # bfloat16 is supported on Ampere (8.x) and newer GPUs
        if device_capability[0] >= 8:
            return torch.bfloat16
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # Apple Silicon supports bfloat16
        return torch.bfloat16
    elif hasattr(torch.backends, "cpu") and hasattr(
        torch.backends.cpu, "supports_bfloat16"
    ):
        # Check CPU support for bfloat16 (newer PyTorch versions)
        if torch.backends.cpu.supports_bfloat16:
            return torch.bfloat16

    # Fallback to float16
    return torch.float16


HALF_PRECISION_DTYPE = get_half_precision_dtype()
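

# Illustrative sketch only (an assumption, not part of the original config):
# loading one of the configured checkpoints in the selected dtype via
# Hugging Face transformers. The helper name is made up; the analysis pipeline
# may load models differently.
def _example_load_model(size="tiny"):
    """Load a Whisper model from ENABLED_MODELS using HALF_PRECISION_DTYPE."""
    from transformers import WhisperModel

    return WhisperModel.from_pretrained(
        ENABLED_MODELS[size], torch_dtype=HALF_PRECISION_DTYPE
    )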