# Spaces:
# Runtime error
# Runtime error
# mcqa_bert.py
# --------------------------------------------------
# Plain BertModel + single-unit classification head
# --------------------------------------------------
import torch
import torch.nn as nn
from transformers import BertModel
class MCQABERT(nn.Module):
    """Score token sequences with a plain ``BertModel`` and a one-unit linear head.

    Each input sequence is encoded by BERT. The final hidden state at position 0
    (the ``[CLS]`` slot for BERT-tokenized input) is projected to a single
    unnormalized score.
    """

    def __init__(self, ckpt: str = "bert-base-uncased"):
        """Load the pretrained encoder and build the scoring head.

        Args:
            ckpt: Model name or local path passed to ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.encoder = BertModel.from_pretrained(ckpt)
        # A single output unit: one raw score (logit) per encoded sequence.
        self.head = nn.Linear(self.encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return one logit per input sequence.

        Args:
            input_ids: Token ids whose last dim is the sequence length. Either
                ``(B, L)``, or any number of extra leading dims such as
                ``(B, num_choices, L)``.
            attention_mask: Mask with the same shape as ``input_ids``.
            token_type_ids: Optional segment ids with the same shape as
                ``input_ids``. ``None`` (the default) lets BERT use its own default.

        Returns:
            Logits shaped like ``input_ids`` without its last dim: ``(B,)`` for
            2-D input, ``(B, num_choices)`` for 3-D input.
        """
        # Remember the leading dims so multi-choice inputs can be restored after
        # BERT, which only accepts (N, L) batches.
        lead_shape = input_ids.shape[:-1]
        out = self.encoder(
            input_ids=self._flatten(input_ids),
            attention_mask=self._flatten(attention_mask),
            token_type_ids=self._flatten(token_type_ids),
            return_dict=True,
        )
        cls_vec = out.last_hidden_state[:, 0]  # hidden state at position 0 ([CLS])
        logits = self.head(cls_vec).squeeze(-1)  # (N,)
        return logits.reshape(lead_shape)

    @staticmethod
    def _flatten(t):
        """Collapse all leading dims of ``t`` into one (-> ``(N, L)``); pass ``None`` through."""
        return None if t is None else t.reshape(-1, t.size(-1))