Spaces:
Runtime error
Runtime error
File size: 883 Bytes
d831168 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
# mcqa_bert.py
# --------------------------------------------------
# Plain BertModel encoder + single-unit linear scoring head on [CLS]
# --------------------------------------------------
import torch
import torch.nn as nn
from transformers import BertModel
class MCQABERT(nn.Module):
    """Score sequences with a BERT encoder and a one-logit linear head.

    The encoder's final hidden state at position 0 (the ``[CLS]`` slot) is
    projected to a single scalar per encoded sequence. This module returns
    raw logits only; no softmax or loss is applied here.

    Args:
        ckpt: Name or path of a pretrained checkpoint passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, ckpt: str = "bert-base-uncased"):
        super().__init__()
        self.encoder = BertModel.from_pretrained(ckpt)
        # One output unit: a scalar score per encoded sequence.
        self.head = nn.Linear(self.encoder.config.hidden_size, 1)

    # --------------------------------------------------
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return one logit per encoded sequence.

        Accepts either a flat batch of sequences, shape ``(B, L)``, or a
        batch of multiple-choice groups, shape ``(B, C, L)``. Grouped input
        is flattened to ``(B * C, L)`` for the encoder and the logits are
        reshaped back to ``(B, C)``.

        Args:
            input_ids: Token ids, ``(B, L)`` or ``(B, C, L)``.
            attention_mask: Optional mask with the same shape as
                ``input_ids``. ``None`` defers to ``BertModel``'s default.
            token_type_ids: Optional segment ids with the same shape as
                ``input_ids`` (e.g. to distinguish question from choice when
                they are encoded as a sentence pair). ``None`` is equivalent
                to not passing segment ids to the encoder.

        Returns:
            Logits of shape ``(B,)`` for 2-D input or ``(B, C)`` for 3-D input.
        """
        grouped = input_ids.dim() == 3
        if grouped:
            # Fold the choice axis into the batch axis so the encoder sees
            # ordinary 2-D (N, L) input; undone after the head.
            batch_size, num_choices, seq_len = input_ids.shape
            input_ids = input_ids.reshape(-1, seq_len)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(-1, seq_len)
            if token_type_ids is not None:
                token_type_ids = token_type_ids.reshape(-1, seq_len)

        out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        cls_vec = out.last_hidden_state[:, 0]  # [CLS] position: (N, hidden)
        logits = self.head(cls_vec).squeeze(-1)  # (N, 1) -> (N,)

        if grouped:
            logits = logits.view(batch_size, num_choices)
        return logits
|