|
|
|
|
|
|
|
# --- Training hyperparameters ---

BATCH_SIZE = 16  # samples per optimization step

TRAINING_STEPS = 200  # total number of optimizer updates

LEARNING_RATE = 1e-3  # optimizer learning rate
|
|
|
|
|
# Short model aliases mapped to their Hugging Face Hub model identifiers.
# NOTE(review): the "large" alias points at the large-v3-turbo checkpoint,
# not the full large-v3 model — confirm this is intentional.
ENABLED_MODELS = {
    "tiny": "openai/whisper-tiny",
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    "medium": "openai/whisper-medium",
    "large": "openai/whisper-large-v3-turbo",
}
|
|
|
|
|
# --- Runtime / memory options ---

# When True, half precision is used; the concrete dtype (float16 vs bfloat16)
# is chosen by get_half_precision_dtype() based on hardware support.
USE_HALF_PRECISION = (
    True
)

# NOTE(review): presumably enables extra memory freeing between runs —
# verify against where this flag is consumed.
AGGRESSIVE_CLEANUP = False

# Cap on the number of samples to process; None means no limit.
MAX_SAMPLES = None

# --- Output options ---

OUTPUT_DIR = "whisper-alignment-results"  # directory for result artifacts

SAVE_PLOTS = True  # write figures to disk when True

PLOT_DPI = 300  # resolution (dots per inch) for saved figures
|
|
|
|
|
|
|
def get_half_precision_dtype(use_half_precision=None):
    """Determine the best half-precision dtype for the current hardware.

    bfloat16 is preferred when available as it has better numerical
    stability (wider exponent range) than float16.

    Args:
        use_half_precision: If None (the default), fall back to the
            module-level USE_HALF_PRECISION flag; otherwise a bool that
            overrides it. Parameterized so callers (and tests) can decide
            without mutating module state.

    Returns:
        torch.float32 when half precision is disabled; otherwise
        torch.bfloat16 where the hardware supports it, else torch.float16.
    """
    import torch

    if use_half_precision is None:
        use_half_precision = USE_HALF_PRECISION
    if not use_half_precision:
        return torch.float32

    if torch.cuda.is_available():
        # Ampere (compute capability 8.x) and newer GPUs have native bfloat16.
        major, _minor = torch.cuda.get_device_capability()
        if major >= 8:
            return torch.bfloat16
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # NOTE(review): assumes bfloat16 is usable on this MPS/macOS
        # combination — older macOS releases only support float16 on MPS;
        # confirm before relying on it.
        return torch.bfloat16
    elif (
        hasattr(torch, "backends")
        and hasattr(torch.backends, "cpu")
        and hasattr(torch.backends.cpu, "supports_bfloat16")
    ):
        # hasattr guards keep this working on older torch versions that
        # lack torch.backends.cpu.supports_bfloat16.
        if torch.backends.cpu.supports_bfloat16:
            return torch.bfloat16

    # Fallback: pre-Ampere CUDA GPUs and any other unsupported hardware.
    return torch.float16
|
|
|
|
|
# Resolved once at import time from USE_HALF_PRECISION and hardware probing.
HALF_PRECISION_DTYPE = get_half_precision_dtype()
|
|