# mcqa_bert.py
# --------------------------------------------------
# Plain BertModel + single-unit classification head
# --------------------------------------------------
import torch
import torch.nn as nn
from transformers import BertModel


class MCQABERT(nn.Module):
    """BERT encoder with a linear head that produces one score per sequence.

    Each input sequence is encoded by ``BertModel``. The final hidden state
    of its first token ([CLS] position) is projected to a single logit.
    """

    def __init__(self, ckpt: str = "bert-base-uncased"):
        """Load the pretrained encoder and build the scoring head.

        Args:
            ckpt: Checkpoint name or path passed to
                ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.encoder = BertModel.from_pretrained(ckpt)
        # One output unit: a scalar score per encoded sequence.
        self.head = nn.Linear(self.encoder.config.hidden_size, 1)

    # --------------------------------------------------
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Score every sequence in the batch.

        Args:
            input_ids: Token ids, either ``(B, L)`` (one sequence per row) or
                ``(B, C, L)`` (``C`` candidate sequences per example).
            attention_mask: Optional mask with the same shape as
                ``input_ids``. If ``None``, the encoder applies its own
                default.
            token_type_ids: Optional segment ids with the same shape as
                ``input_ids``. Supply these when each sequence is a
                question/choice pair so BERT can distinguish the segments.
                If ``None``, the encoder applies its own default (all-zero
                segment ids in the Hugging Face implementation).

        Returns:
            Logits of shape ``input_ids.shape[:-1]``: ``(B,)`` for 2-D input,
            ``(B, C)`` for 3-D input.
        """
        lead_shape = input_ids.shape[:-1]  # (B,) or (B, C)
        seq_len = input_ids.size(-1)

        def _flatten(t):
            # Collapse all leading dims so the encoder always sees (N, L).
            return None if t is None else t.reshape(-1, seq_len)

        out = self.encoder(
            input_ids=_flatten(input_ids),
            attention_mask=_flatten(attention_mask),
            token_type_ids=_flatten(token_type_ids),
            return_dict=True,
        )
        cls_vec = out.last_hidden_state[:, 0]  # [CLS] position -> (N, H)
        logits = self.head(cls_vec).squeeze(-1)  # (N,)
        # Restore the caller's leading layout; a no-op for 2-D input.
        return logits.reshape(lead_shape)