File size: 405 Bytes
4bb9d41
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dataclasses import dataclass, field
from typing import List, Optional

import torch

@dataclass
class ModelConfig:
    """Bundle of model and run settings with ready-to-use defaults.

    Attributes:
        model_name: Name of the model to use; defaults to ``"gpt2"``.
        max_length: Maximum length setting; defaults to 128
            (presumably a token count -- confirm against the consumer).
        batch_size: Batch size; defaults to 16.
        learning_rate: Learning rate; defaults to 2e-5.
        num_train_epochs: Number of training epochs; defaults to 3.
        languages: Language labels. Every instance receives its own new
            list, so mutating one config never affects another.
        device: ``"cuda"`` when ``torch.cuda.is_available()`` reported
            True at the time this class was defined, otherwise ``"cpu"``.
        output_dir: Output directory name; defaults to ``"outputs"``.
    """

    model_name: str = "gpt2"
    max_length: int = 128
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    # A real list (matching the List[str] annotation) built per instance;
    # dataclasses reject a bare list literal as a default, hence the factory.
    languages: List[str] = field(
        default_factory=lambda: ["YORUBA", "IGBO", "HAUSA"]
    )
    # NOTE: this expression runs once, when the class body is executed,
    # not on each instantiation.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir: str = "outputs"