# Memory Configuration for Whisper Alignment Analysis

# Batch processing settings
BATCH_SIZE = 16  # Reduce if you get OOM errors, increase for faster processing
TRAINING_STEPS = 200  # Number of training steps for linear probes
LEARNING_RATE = 1e-3
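
# Illustrative sketch (assumption; not part of the original pipeline): the three
# constants above would presumably drive a linear-probe training loop roughly like
# the hypothetical helper below. `loader`, `hidden_dim`, and `num_classes` are
# placeholders for the actual feature pipeline.
def train_linear_probe(loader, hidden_dim, num_classes):
    """Hypothetical helper: fit a linear probe on pre-extracted Whisper features."""
    import torch

    probe = torch.nn.Linear(hidden_dim, num_classes)
    optimizer = torch.optim.Adam(probe.parameters(), lr=LEARNING_RATE)
    data_iter = iter(loader)
    for _ in range(TRAINING_STEPS):
        # Assumes `loader` yields at least TRAINING_STEPS batches of BATCH_SIZE samples.
        features, labels = next(data_iter)
        loss = torch.nn.functional.cross_entropy(probe(features), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return probe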

# Model selection
ENABLED_MODELS = {
    "tiny": "openai/whisper-tiny",  # ~39M parameters
    "base": "openai/whisper-base",  # ~74M parameters
    "small": "openai/whisper-small",  # ~244M parameters
    "medium": "openai/whisper-medium",  # ~769M parameters
    "large": "openai/whisper-large-v3-turbo",  # ~809M parameters
}

# Memory optimization settings
USE_HALF_PRECISION = (
    True  # Use half precision (bfloat16 preferred, float16 fallback) instead of float32
)
AGGRESSIVE_CLEANUP = False  # Clear GPU cache after each operation
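
# Illustrative sketch (assumption; not part of the original code): when
# AGGRESSIVE_CLEANUP is enabled, the analysis code would presumably free cached
# GPU memory after each heavy operation with a helper along these lines.
def maybe_cleanup():
    """Hypothetical helper: release cached GPU memory if AGGRESSIVE_CLEANUP is set."""
    import gc

    import torch

    if AGGRESSIVE_CLEANUP:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()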

# Dataset settings
MAX_SAMPLES = None  # Set to a number to limit dataset size for testing (e.g., 50)

# Output settings
OUTPUT_DIR = "whisper-alignment-results"
SAVE_PLOTS = True
PLOT_DPI = 300


# Half precision dtype selection (bfloat16 preferred if available, fallback to float16)
def get_half_precision_dtype():
    """
    Determine the best half precision dtype based on hardware support.
    bfloat16 is preferred when available as it has better numerical stability.
    """
    import torch

    if not USE_HALF_PRECISION:
        return torch.float32

    # Check if bfloat16 is supported
    if torch.cuda.is_available():
        # Check GPU support for bfloat16
        device_capability = torch.cuda.get_device_capability()
        # bfloat16 is supported on Ampere (8.x) and newer GPUs
        if device_capability[0] >= 8:
            return torch.bfloat16
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # Apple Silicon supports bfloat16
        return torch.bfloat16
    elif (
        hasattr(torch, "backends")
        and hasattr(torch.backends, "cpu")
        and hasattr(torch.backends.cpu, "supports_bfloat16")
    ):
        # Check CPU support for bfloat16 (newer PyTorch versions)
        if torch.backends.cpu.supports_bfloat16:
            return torch.bfloat16

    # Fallback to float16
    return torch.float16


HALF_PRECISION_DTYPE = get_half_precision_dtype()
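
# Minimal usage sketch (assumption; not part of the original analysis code):
# loading one of the configured checkpoints with the selected dtype via the
# Hugging Face transformers API. Run this module directly to try it.
if __name__ == "__main__":
    from transformers import WhisperModel

    model_name = ENABLED_MODELS["tiny"]  # smallest checkpoint for a quick smoke test
    model = WhisperModel.from_pretrained(model_name, torch_dtype=HALF_PRECISION_DTYPE)
    print(f"Loaded {model_name} with dtype {HALF_PRECISION_DTYPE}")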