# TruthCheck: src/models/hybrid_model.py
import torch
import torch.nn as nn
from transformers import BertModel
from typing import Tuple, Dict
class AttentionLayer(nn.Module):
    """Additive attention pooling over a sequence of hidden states."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch_size, seq_len, hidden_size)
        # Score each position, then normalize across the sequence dimension
        attention_weights = torch.softmax(self.attention(x), dim=1)
        # Weighted sum over the sequence: (batch_size, hidden_size)
        attended = torch.sum(attention_weights * x, dim=1)
        return attended, attention_weights
class HybridFakeNewsDetector(nn.Module):
    """BERT + BiLSTM + attention hybrid classifier for fake-news detection."""

    def __init__(self,
                 bert_model_name: str = "bert-base-uncased",
                 lstm_hidden_size: int = 256,
                 lstm_num_layers: int = 2,
                 dropout_rate: float = 0.3,
                 num_classes: int = 2):
        super().__init__()
# BERT encoder
self.bert = BertModel.from_pretrained(bert_model_name)
bert_hidden_size = self.bert.config.hidden_size
# BiLSTM layer
self.lstm = nn.LSTM(
input_size=bert_hidden_size,
hidden_size=lstm_hidden_size,
num_layers=lstm_num_layers,
batch_first=True,
bidirectional=True
)
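        # Note: nn.LSTM also accepts a `dropout` argument applied between stacked
        # layers (only effective when num_layers > 1); it is not used here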
# Attention layer
self.attention = AttentionLayer(lstm_hidden_size * 2)
# Classification head
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(lstm_hidden_size * 2, lstm_hidden_size),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(lstm_hidden_size, num_classes)
)
    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the full pipeline and return logits and attention weights."""
        # Get contextual token embeddings from BERT
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # last_hidden_state: (batch_size, seq_len, bert_hidden_size)
        bert_embeddings = bert_outputs.last_hidden_state
        # Process through the BiLSTM: (batch_size, seq_len, 2 * lstm_hidden_size)
        lstm_output, _ = self.lstm(bert_embeddings)
        # Pool the sequence with additive attention; padding positions are not
        # masked here, so they can receive non-zero attention weight
        attended, attention_weights = self.attention(lstm_output)
        # Classification head: (batch_size, num_classes)
        logits = self.classifier(attended)
return {
'logits': logits,
'attention_weights': attention_weights
}
    def predict(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        """Return class probabilities (softmax over the logits)."""
        outputs = self.forward(input_ids, attention_mask)
        return torch.softmax(outputs['logits'], dim=1)
def get_attention_weights(self, input_ids: torch.Tensor,
attention_mask: torch.Tensor) -> torch.Tensor:
"""Get attention weights for interpretability."""
outputs = self.forward(input_ids, attention_mask)
return outputs['attention_weights']
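# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal example of how this model might be run for inference. The example
# text, max_length, and any label interpretation are assumptions, not taken
# from the original code.
if __name__ == "__main__":
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = HybridFakeNewsDetector()
    model.eval()  # disable dropout for inference

    text = "Scientists announce a breakthrough in renewable energy storage."
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    )

    with torch.no_grad():  # no gradients needed at inference time
        probs = model.predict(encoded["input_ids"], encoded["attention_mask"])
        weights = model.get_attention_weights(
            encoded["input_ids"], encoded["attention_mask"]
        )

    # probs: (1, num_classes); weights: (1, seq_len, 1)
    print("class probabilities:", probs.squeeze().tolist())
    print("attention weight shape:", tuple(weights.shape))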