Nachaphat's picture
Upload model
0144345
raw
history blame contribute delete
636 Bytes
import torch.nn as nn
from transformers import AutoModel
class ThaiEncoder(nn.Module):
    """Wrap a pretrained Transformers model and expose one token's embedding.

    The wrapped model is loaded with ``AutoModel.from_pretrained`` and every
    one of its parameters is marked either frozen or trainable. ``forward``
    returns the last-layer hidden state at a single sequence position
    (``target_token_idx``, 0 by default).
    """

    def __init__(
        self,
        model_name: str,
        trainable: bool = False,
        *,
        target_token_idx: int = 0,
    ) -> None:
        """Load the pretrained model and configure gradient tracking.

        Args:
            model_name: Identifier or path handed to
                ``AutoModel.from_pretrained``.
            trainable: Value assigned to ``requires_grad`` on every parameter
                of the wrapped model.
            target_token_idx: Sequence position whose hidden state
                ``forward`` returns. Defaults to 0 (presumably the
                sentence-level [CLS]/<s> token -- confirm for the tokenizer
                in use).
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # Sets requires_grad on every parameter of the wrapped model.
        # NOTE(review): this does not change train/eval mode; if the frozen
        # encoder must run without dropout, the caller still has to put it in
        # eval mode -- confirm against the training loop.
        self.model.requires_grad_(trainable)
        self.target_token_idx = target_token_idx

    def forward(self, input_ids, attention_mask):
        """Return the last hidden state at ``target_token_idx``.

        Args:
            input_ids: Token ids forwarded to the wrapped model
                (assumed shape ``(batch, seq_len)`` -- TODO confirm).
            attention_mask: Attention mask forwarded to the wrapped model
                (assumed same shape as ``input_ids`` -- TODO confirm).

        Returns:
            The slice ``last_hidden_state[:, target_token_idx, :]`` of the
            wrapped model's output, i.e. one vector per batch element.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Keep only the chosen sequence position; the middle axis is dropped.
        return output.last_hidden_state[:, self.target_token_idx, :]