import torch
import torch.nn as nn
from transformers import BertModel
from typing import Tuple, Dict


class AttentionLayer(nn.Module):
    """Additive attention that pools a sequence of hidden states into a single vector."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, seq_len, hidden_size); scores are normalized over the sequence dim.
        # Note: padding positions are not masked here, so padded tokens also
        # receive attention mass.
        attention_weights = torch.softmax(self.attention(x), dim=1)
        attended = torch.sum(attention_weights * x, dim=1)
        return attended, attention_weights
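
# Illustrative shapes (assumed for the sake of example, not from the original
# source): for x of shape (batch=8, seq_len=128, hidden=512), `attended` comes
# out as (8, 512) and `attention_weights` as (8, 128, 1), with the weights
# summing to 1 along seq_len.
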
class HybridFakeNewsDetector(nn.Module):
    """Hybrid fake-news classifier: BERT encoder -> BiLSTM -> attention pooling -> MLP head."""

    def __init__(self,
                 bert_model_name: str = "bert-base-uncased",
                 lstm_hidden_size: int = 256,
                 lstm_num_layers: int = 2,
                 dropout_rate: float = 0.3,
                 num_classes: int = 2):
        super().__init__()
        # BERT encoder
        self.bert = BertModel.from_pretrained(bert_model_name)
        bert_hidden_size = self.bert.config.hidden_size
        # BiLSTM layer
        self.lstm = nn.LSTM(
            input_size=bert_hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True
        )
        # Attention layer over the concatenated forward/backward LSTM states
        self.attention = AttentionLayer(lstm_hidden_size * 2)
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(lstm_hidden_size * 2, lstm_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(lstm_hidden_size, num_classes)
        )

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get BERT embeddings
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        bert_embeddings = bert_outputs.last_hidden_state
        # Process through BiLSTM
        lstm_output, _ = self.lstm(bert_embeddings)
        # Apply attention
        attended, attention_weights = self.attention(lstm_output)
        # Classification
        logits = self.classifier(attended)
        return {
            'logits': logits,
            'attention_weights': attention_weights
        }

    @torch.no_grad()
    def predict(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        """Get class probabilities (inference only, so gradients are disabled)."""
        outputs = self.forward(input_ids, attention_mask)
        return torch.softmax(outputs['logits'], dim=1)

    @torch.no_grad()
    def get_attention_weights(self, input_ids: torch.Tensor,
                              attention_mask: torch.Tensor) -> torch.Tensor:
        """Get attention weights for interpretability."""
        outputs = self.forward(input_ids, attention_mask)
        return outputs['attention_weights']
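

# Minimal smoke test (an illustrative sketch, not part of the original model
# code). It assumes network access to download "bert-base-uncased" and feeds
# random token ids instead of real tokenizer output, purely to check tensor
# shapes end to end.
if __name__ == "__main__":
    model = HybridFakeNewsDetector()
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, model.bert.config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    probs = model.predict(input_ids, attention_mask)
    weights = model.get_attention_weights(input_ids, attention_mask)
    print(probs.shape)    # (2, num_classes)
    print(weights.shape)  # (2, 16, 1)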