File size: 636 Bytes
0144345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
from transformers import AutoModel


class ThaiEncoder(nn.Module):
    """Wrap a Hugging Face ``AutoModel`` and return one token's hidden state.

    The pretrained model named by ``model_name`` is loaded with
    ``AutoModel.from_pretrained``. ``forward`` runs it and returns the
    final-layer hidden state at position ``target_token_idx`` for every
    sequence in the batch.

    Args:
        model_name: Hub identifier or local path passed to
            ``AutoModel.from_pretrained``.
        trainable: Value assigned to ``requires_grad`` on every parameter of
            the wrapped model. ``False`` (the default) freezes the encoder;
            ``True`` makes all of its parameters trainable.
        target_token_idx: Sequence position whose hidden state is returned.
            Defaults to ``0``, the first token. For BERT/RoBERTa-style
            tokenizers this is presumably ``[CLS]``/``<s>``; confirm this
            for the chosen model. Negative values index from the end, as in
            ordinary tensor indexing.
    """

    def __init__(
        self,
        model_name: str,
        trainable: bool = False,
        *,
        target_token_idx: int = 0,
    ) -> None:
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)

        # Freeze (or explicitly unfreeze) the whole backbone.
        # NOTE(review): this only toggles requires_grad. It does not switch
        # the backbone to eval mode, so any dropout inside it still follows
        # the parent module's train()/eval() state.
        for param in self.model.parameters():
            param.requires_grad = trainable

        self.target_token_idx = target_token_idx

    def forward(self, input_ids, attention_mask=None):
        """Return the final hidden state at ``self.target_token_idx``.

        Args:
            input_ids: Token ids forwarded to the wrapped model. Presumably
                of shape ``(batch, seq_len)``; confirm against the tokenizer.
            attention_mask: Mask forwarded unchanged to the wrapped model.
                ``None`` is passed through as-is. Hugging Face models
                generally treat it as "attend to every token"; confirm this
                for the chosen model.

        Returns:
            ``last_hidden_state[:, target_token_idx, :]``. This is one vector
            per batch element, of shape ``(batch, hidden_size)`` for the
            rank-3 output this indexing expects.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return output.last_hidden_state[:, self.target_token_idx, :]