|
from transformers import PreTrainedModel |
|
import torch |
|
import torch.nn as nn |
|
|
|
from .configuration_vnsabsa import VnSmartphoneAbsaConfig |
|
|
|
from typing import Tuple |
|
|
|
|
|
class VnSmartphoneAbsaModel(PreTrainedModel):
    """Hugging Face wrapper around ``SmartphoneBERT`` for Vietnamese
    smartphone aspect-based sentiment analysis (ABSA).

    ``forward`` returns, for each input sequence, a dict mapping detected
    aspect names to polarity labels.  The last aspect ("OTHERS") has no
    polarity head in the backbone, so it is reported with an empty-string
    polarity.
    """

    config_class = VnSmartphoneAbsaConfig

    def __init__(
        self,
        config: VnSmartphoneAbsaConfig
    ):
        super().__init__(config)
        self.model = SmartphoneBERT(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            num_encoders=config.num_encoders,
            encoder_dropout=config.encoder_dropout,
            fc_dropout=config.fc_dropout,
            fc_hidden_size=config.fc_hidden_size
        )
        # Index -> label tables used when decoding raw logits.  The final
        # aspect ("OTHERS", index 10) is detection-only: the polarity head
        # covers indices 0-9 only.
        self.ASPECT_LOOKUP = {
            i: a
            for i, a in enumerate(["CAMERA", "FEATURES", "BATTERY", "PRICE", "GENERAL", "SER&ACC", "PERFORMANCE", "SCREEN", "DESIGN", "STORAGE", "OTHERS"])
        }
        self.POLARITY_LOOKUP = {
            i: p
            for i, p in enumerate(["Negative", "Neutral", "Positive"])
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        aspect_thresholds: float | torch.Tensor = 0.5
    ):
        """Run the backbone and decode its logits into readable predictions.

        Args:
            input_ids: token id tensor fed to ``SmartphoneBERT``.
            attention_mask: 1 for real tokens, 0 for padding.
            aspect_thresholds: scalar or per-aspect detection threshold(s),
                forwarded to :meth:`decode_absa`.

        Returns:
            list of per-sample dicts mapping aspect name to polarity label.
        """
        pred = self.model(input_ids, attention_mask)
        return self.decode_absa(pred, aspect_thresholds=aspect_thresholds)

    def decode_absa(
        self,
        pred: Tuple[torch.Tensor, torch.Tensor],
        aspect_thresholds: float | torch.Tensor = 0.5
    ):
        """Turn raw (aspect, polarity) logits into per-sample predictions.

        Args:
            pred: tuple of aspect logits and polarity logits as produced by
                ``SmartphoneBERT`` (aspect logits have 11 entries per sample;
                polarity logits have one 3-way distribution per named aspect).
            aspect_thresholds: a single number applied to all 11 aspects, or
                an indexable of 11 per-aspect sigmoid thresholds.

        Returns:
            list of dicts, one per batch item, mapping each detected aspect
            name to its polarity label ("" for "OTHERS", which has no
            polarity head).
        """
        # Accept any plain number (the original only accepted float, so an
        # int threshold crashed on indexing) and expand to one per aspect.
        if isinstance(aspect_thresholds, (int, float)):
            aspect_thresholds = torch.full((11,), float(aspect_thresholds))

        # Distinct names here: the original reused `a` and `p` for the loop
        # lookups below, shadowing these batch tensors.
        aspect_probs, polarity_logits = pred
        aspect_probs = aspect_probs.sigmoid().cpu()
        polarity_ids = polarity_logits.argmax(dim=-1).cpu()

        results = []
        for a_i, p_i in zip(aspect_probs, polarity_ids):
            res_i = {}
            for i in range(10):
                if a_i[i] >= aspect_thresholds[i]:
                    res_i[self.ASPECT_LOOKUP[i]] = self.POLARITY_LOOKUP[p_i[i].item()]

            # "OTHERS" (index 10) is detection-only: no polarity head exists
            # for it, so it is reported with an empty polarity string.
            if a_i[-1] >= aspect_thresholds[-1]:
                res_i["OTHERS"] = ""

            results.append(res_i)

        return results
|
|
|
|
|
class AspectClassifier(nn.Module):
    """MLP head that scores all 11 aspects at once.

    Produces one logit per aspect (10 named aspects plus "OTHERS") from a
    pooled sentence representation of ``input_size`` features.
    """

    def __init__(
        self,
        input_size: int,
        dropout: float = 0.3,
        hidden_size: int = 64,
        *args, **kwargs
    ) -> None:
        """Build the two-layer classifier.

        Args:
            input_size: number of input features.
            dropout: dropout probability applied before each linear layer.
            hidden_size: width of the hidden layer.
        """
        super().__init__(*args, **kwargs)

        self.input_size = input_size

        # Dropout -> Linear -> ReLU -> Dropout -> Linear, ending in
        # 11 logits (10 named aspects + "OTHERS").
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features=input_size, out_features=hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_size, out_features=10 + 1),
        )

    def forward(self, input: torch.Tensor):
        """Map pooled features to 11 aspect logits."""
        return self.fc(input)
|
|
|
|
|
class PolarityClassifier(nn.Module): |
|
def __init__( |
|
self, |
|
input_size: int, |
|
dropout: float = 0.5, |
|
hidden_size: int = 64, |
|
*args, **kwargs |
|
) -> None: |
|
super().__init__(*args, **kwargs) |
|
self.polarity_fcs = nn.ModuleList([ |
|
nn.Sequential( |
|
nn.Dropout(dropout), |
|
nn.Linear( |
|
in_features=input_size, |
|
out_features=hidden_size |
|
), |
|
nn.ReLU(), |
|
nn.Dropout(dropout), |
|
nn.Linear( |
|
in_features=hidden_size, |
|
out_features=3 |
|
) |
|
) |
|
for _ in torch.arange(10) |
|
]) |
|
|
|
def forward(self, input: torch.Tensor): |
|
polarities = torch.stack([ |
|
fc(input) |
|
for fc in self.polarity_fcs |
|
]) |
|
|
|
if input.ndim == 2: |
|
polarities = polarities.transpose(0, 1) |
|
return polarities |
|
|
|
|
|
class SmartphoneBERT(nn.Module):
    """Transformer-encoder backbone with aspect and polarity heads.

    Embeds token ids, runs a Transformer encoder over them, pools the output
    (first-position vector concatenated with the sequence mean), and feeds
    the pooled vector to both classification heads.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 768,
        num_heads: int = 8,
        num_encoders: int = 4,
        encoder_dropout: float = 0.1,
        fc_dropout: float = 0.4,
        fc_hidden_size: int = 128,
        *args, **kwargs
    ):
        """Assemble embedding, encoder stack, and the two heads.

        Args:
            vocab_size: embedding table size.
            embed_dim: embedding / encoder model dimension.
            num_heads: attention heads per encoder layer.
            num_encoders: number of stacked encoder layers.
            encoder_dropout: dropout inside the encoder layers.
            fc_dropout: dropout inside both classifier heads.
            fc_hidden_size: hidden width of both classifier heads.
        """
        super().__init__(*args, **kwargs)
        # Token id 0 is the padding index and keeps a zero embedding.
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim,
                dropout=encoder_dropout,
                batch_first=True
            ),
            num_layers=num_encoders,
            norm=nn.LayerNorm(embed_dim),
            enable_nested_tensor=False
        )
        # Both heads consume the pooled vector: first-token state
        # concatenated with the sequence mean, hence 2 * embed_dim.
        self.a_fc = AspectClassifier(
            input_size=2 * embed_dim,
            dropout=fc_dropout,
            hidden_size=fc_hidden_size
        )
        self.p_fc = PolarityClassifier(
            input_size=2 * embed_dim,
            dropout=fc_dropout,
            hidden_size=fc_hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ):
        """Encode a (batch of) sequence(s) and return head logits.

        Args:
            input_ids: token id tensor.
            attention_mask: 1 for real tokens, 0 for padding.

        Returns:
            tuple of (aspect logits, polarity logits).
        """
        # TransformerEncoder expects True where positions should be IGNORED,
        # the inverse of the HF-style attention mask.
        pad_positions = ~attention_mask.bool()

        hidden = self.embed(input_ids)
        hidden = self.encoder(hidden, src_key_padding_mask=pad_positions)

        # Zero out padded positions so they contribute nothing to the sum.
        # NOTE(review): the mean below still divides by the full sequence
        # length (padding included), not the real token count — presumably
        # the model was trained this way; confirm before changing.
        hidden[pad_positions] = 0

        first_token = hidden[..., 0, :]
        mean_pool = torch.mean(hidden, dim=-2)
        pooled = torch.cat([first_token, mean_pool], dim=-1)

        return self.a_fc(pooled), self.p_fc(pooled)