# app/model/config.py
from dataclasses import dataclass, field
from typing import List
import torch

@dataclass
class ModelConfig:
    model_name: str = "gpt2"
    max_length: int = 128
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    # A mutable default needs default_factory in a dataclass.
    languages: List[str] = field(default_factory=lambda: ["YORUBA", "IGBO", "HAUSA"])
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir: str = "outputs"

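# Usage sketch (illustrative assumption, not part of the original listing):
# the defaults above can be overridden per run, for example
#   config = ModelConfig(model_name="gpt2-medium", batch_size=8, languages=["YORUBA"])
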
# app/model/model.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .config import ModelConfig

class NigerianLanguageModel:
    def __init__(self, config: ModelConfig):
        self.config = config
        self.setup_model()
    
    def setup_model(self):
        # Load the pretrained tokenizer and causal LM weights, register the
        # language tag tokens, then move the model to the configured device.
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.config.model_name)
        self._setup_special_tokens()
        self.model.to(self.config.device)

    def _setup_special_tokens(self):
        # Add a language tag token (e.g. "[YORUBA]") for each configured language.
        special_tokens = {
            "additional_special_tokens": [f"[{lang}]" for lang in self.config.languages]
        }
        self.tokenizer.add_special_tokens(special_tokens)
        # GPT-2's tokenizer ships without a pad token; reuse EOS so batched inputs can pad.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Grow the embedding matrix to cover the newly added tokens.
        self.model.resize_token_embeddings(len(self.tokenizer))
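
# app/main.py -- usage sketch. This file, the prompt text, and the generation
# settings are assumptions for illustration; only config.py and model.py above
# come from the original listing.
from app.model.config import ModelConfig
from app.model.model import NigerianLanguageModel

if __name__ == "__main__":
    config = ModelConfig()
    lm = NigerianLanguageModel(config)

    # Prefix the prompt with one of the registered language tags.
    prompt = "[YORUBA] Bawo ni"
    inputs = lm.tokenizer(prompt, return_tensors="pt").to(config.device)
    output_ids = lm.model.generate(
        **inputs,
        max_length=config.max_length,
        pad_token_id=lm.tokenizer.pad_token_id,
    )
    print(lm.tokenizer.decode(output_ids[0], skip_special_tokens=True))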