# Source: E2-F5-TTS / mcqa_bert.py (Hugging Face file view)
# Uploaded by kevinwang676 — commit d831168 "Create mcqa_bert.py" (verified), 883 bytes
# mcqa_bert.py
# --------------------------------------------------
# Plain BertModel encoder + single-unit linear scoring head on [CLS]
# --------------------------------------------------
import torch
import torch.nn as nn
from transformers import BertModel
class MCQABERT(nn.Module):
    """Score token sequences with a plain BERT encoder and a linear head.

    The hidden state at position 0 (the [CLS] token) of ``BertModel`` is
    projected by ``nn.Linear(hidden_size, 1)`` to one scalar logit per
    sequence.
    """

    def __init__(self, ckpt: str = "bert-base-uncased"):
        """Load the pretrained encoder and build the scoring head.

        Args:
            ckpt: model name or path passed to ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.encoder = BertModel.from_pretrained(ckpt)
        # Single output unit: every sequence gets exactly one score.
        self.head = nn.Linear(self.encoder.config.hidden_size, 1)

    # --------------------------------------------------
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return one logit per input sequence.

        Args:
            input_ids: token ids shaped ``(B, L)``, or ``(B, C, L)`` where
                ``C`` is the number of answer choices per example.
            attention_mask: optional mask with the same shape as
                ``input_ids``; ``None`` is forwarded to the encoder as-is.
            token_type_ids: optional segment ids with the same shape as
                ``input_ids`` (e.g. to separate question and option tokens);
                ``None`` is forwarded to the encoder as-is.

        Returns:
            Logits shaped ``(B,)`` for 2-D input, ``(B, C)`` for 3-D input
            (in general: ``input_ids.shape[:-1]``).
        """
        lead_shape = input_ids.shape[:-1]
        seq_len = input_ids.shape[-1]

        def _flatten(t):
            # BertModel only accepts (batch, seq_len); fold any leading
            # (batch, choice) dims into a single batch dim.
            return None if t is None else t.reshape(-1, seq_len)

        out = self.encoder(
            input_ids=_flatten(input_ids),
            attention_mask=_flatten(attention_mask),
            token_type_ids=_flatten(token_type_ids),
            return_dict=True,
        )
        cls_vec = out.last_hidden_state[:, 0]  # [CLS] vector of each sequence
        logits = self.head(cls_vec).squeeze(-1)  # (B * C,) or (B,)
        # Restore the caller's leading dims: (B,) stays (B,), (B, C, L) -> (B, C).
        return logits.reshape(lead_shape)