Spaces:
Runtime error
Runtime error
from dataclasses import dataclass | |
from pathlib import Path | |
import torch | |
@dataclass
class TrainingConfig:
    """Configuration for model training.

    Every field has a default, so ``TrainingConfig()`` is valid, and any field
    can be overridden by keyword, e.g. ``TrainingConfig(batch_size=16)``.

    Note: constructing an instance creates ``output_dir`` on disk
    (see ``__post_init__``).
    """

    # Model parameters
    model_name: str = "microsoft/deberta-v3-large"
    dropout: float = 0.1

    # Training parameters
    num_epochs: int = 5
    batch_size: int = 8
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # Data parameters
    max_length: int = 512
    train_ratio: float = 0.8

    # Output parameters
    output_dir: Path = Path("outputs")
    save_steps: int = 100
    eval_steps: int = 50

    # Device. This default is evaluated once, when the class body executes
    # (i.e. at import time), not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        """Normalise ``output_dir`` to a ``Path`` and create it if missing."""
        # Accept plain strings (e.g. from CLI arguments) as well as Path objects.
        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
# Demo: build one config from defaults and one with keyword overrides.
if __name__ == "__main__":
    # Every field left at its default value
    defaults = TrainingConfig()
    print("\n=== Default Configuration ===")
    print(f"Model name: {defaults.model_name}")
    print(f"Batch size: {defaults.batch_size}")
    print(f"Learning rate: {defaults.learning_rate}")
    print(f"Device: {defaults.device}")

    # Only a few fields overridden; the rest keep their defaults
    overrides = {"batch_size": 16, "num_epochs": 10, "learning_rate": 1e-5}
    custom = TrainingConfig(**overrides)
    print("\n=== Custom Configuration ===")
    print(f"Model name: {custom.model_name}")  # default
    print(f"Batch size: {custom.batch_size}")  # overridden
    print(f"Learning rate: {custom.learning_rate}")  # overridden
    print(f"Number of epochs: {custom.num_epochs}")  # overridden