from typing import Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_vnsabsa import VnSmartphoneAbsaConfig


class VnSmartphoneAbsaModel(PreTrainedModel):
    config_class = VnSmartphoneAbsaConfig

    def __init__(self, config: VnSmartphoneAbsaConfig):
        super().__init__(config)
        self.model = SmartphoneBERT(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            num_encoders=config.num_encoders,
            encoder_dropout=config.encoder_dropout,
            fc_dropout=config.fc_dropout,
            fc_hidden_size=config.fc_hidden_size
        )
        # Index -> label maps. The last aspect, "OTHERS", is detected but
        # carries no polarity (see decode_absa).
        self.ASPECT_LOOKUP = {
            i: a for i, a in enumerate([
                "CAMERA", "FEATURES", "BATTERY", "PRICE", "GENERAL", "SER&ACC",
                "PERFORMANCE", "SCREEN", "DESIGN", "STORAGE", "OTHERS"
            ])
        }
        self.POLARITY_LOOKUP = {
            i: p for i, p in enumerate(["Negative", "Neutral", "Positive"])
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        aspect_thresholds: float | torch.Tensor = 0.5
    ):
        # Returns decoded (aspect -> polarity) dicts rather than raw logits.
        pred = self.model(input_ids, attention_mask)
        return self.decode_absa(pred, aspect_thresholds=aspect_thresholds)

    def decode_absa(
        self,
        pred: Tuple[torch.Tensor, torch.Tensor],
        aspect_thresholds: float | torch.Tensor = 0.5
    ):
        # One threshold per aspect; a scalar is broadcast to all 11 aspects.
        if isinstance(aspect_thresholds, float):
            aspect_thresholds = torch.full((11,), aspect_thresholds)

        a_probs, p_preds = pred
        a_probs = a_probs.sigmoid().cpu()       # (batch, 11) aspect probabilities
        p_preds = p_preds.argmax(dim=-1).cpu()  # (batch, 10) polarity indices

        results = []
        for a_i, p_i in zip(a_probs, p_preds):
            res_i = {}
            for i in range(10):
                if a_i[i] >= aspect_thresholds[i]:
                    aspect = self.ASPECT_LOOKUP[i]
                    polarity = self.POLARITY_LOOKUP[p_i[i].item()]
                    res_i[aspect] = polarity
            # "OTHERS" has no polarity head, so it maps to an empty string.
            if a_i[-1] >= aspect_thresholds[-1]:
                res_i["OTHERS"] = ""
            results.append(res_i)
        return results


class AspectClassifier(nn.Module):
    """Multi-label aspect head: one logit per aspect (10 aspects + OTHERS)."""

    def __init__(
        self,
        input_size: int,
        dropout: float = 0.3,
        hidden_size: int = 64,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.input_size = input_size
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features=input_size, out_features=hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_size, out_features=10 + 1)
        )

    def forward(self, input: torch.Tensor):
        return self.fc(input)


class PolarityClassifier(nn.Module):
    """One independent 3-way (Neg/Neu/Pos) head per polarity-bearing aspect."""

    def __init__(
        self,
        input_size: int,
        dropout: float = 0.5,
        hidden_size: int = 64,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.polarity_fcs = nn.ModuleList([
            nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(in_features=input_size, out_features=hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(in_features=hidden_size, out_features=3)
            )
            for _ in range(10)
        ])

    def forward(self, input: torch.Tensor):
        # Stack the per-aspect heads: (10, batch, 3) for batched input.
        polarities = torch.stack([fc(input) for fc in self.polarity_fcs])
        if input.ndim == 2:
            polarities = polarities.transpose(0, 1)  # -> (batch, 10, 3)
        return polarities


class SmartphoneBERT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 768,
        num_heads: int = 8,
        num_encoders: int = 4,
        encoder_dropout: float = 0.1,
        fc_dropout: float = 0.4,
        fc_hidden_size: int = 128,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim,
                dropout=encoder_dropout,
                batch_first=True
            ),
            num_layers=num_encoders,
            norm=nn.LayerNorm(embed_dim),
            enable_nested_tensor=False
        )
        # Both heads consume [first-token state ; mean-pooled state],
        # hence 2 * embed_dim input features.
        self.a_fc = AspectClassifier(
            input_size=2 * embed_dim,
            dropout=fc_dropout,
            hidden_size=fc_hidden_size
        )
        self.p_fc = PolarityClassifier(
            input_size=2 * embed_dim,
            dropout=fc_dropout,
            hidden_size=fc_hidden_size
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        # nn.TransformerEncoder expects True at *padded* positions, i.e. the
        # inverse of an HF-style attention_mask (1 = real token).
        padding_mask = ~attention_mask.bool()
        x = self.embed(input_ids)
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        # Zero padded states so they contribute nothing to the pooled sum
        # (note: the mean below still divides by the full padded length).
        x[padding_mask] = 0
        # Sentence representation: first-token ([CLS]-style) state
        # concatenated with mean pooling over the sequence dimension.
        x = torch.cat([x[..., 0, :], torch.mean(x, dim=-2)], dim=-1)
        a_logits = self.a_fc(x)
        p_logits = self.p_fc(x)
        return a_logits, p_logits
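

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the checkpoint). The config
# values and the dummy batch below are assumptions: a real checkpoint reads
# its hyperparameters from config.json, and `input_ids`/`attention_mask` come
# from the repo's tokenizer, which is not shown here. Because of the relative
# import above, run this as a module, e.g. `python -m <package>.modeling_vnsabsa`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = VnSmartphoneAbsaConfig(
        vocab_size=30_000,  # hypothetical vocabulary size
        embed_dim=768,
        num_heads=8,
        num_encoders=4,
        encoder_dropout=0.1,
        fc_dropout=0.4,
        fc_hidden_size=128,
    )
    model = VnSmartphoneAbsaModel(config).eval()

    # Two fake reviews of length 8; the second is padded (id 0) at the tail.
    input_ids = torch.randint(1, config.vocab_size, (2, 8))
    input_ids[1, 5:] = 0
    attention_mask = (input_ids != 0).long()

    with torch.no_grad():
        print(model(input_ids, attention_mask))
    # -> one dict per review, e.g. [{"BATTERY": "Positive", ...}, {...}];
    #    with random weights the labels are arbitrary, this only demonstrates
    #    the input/output contract.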