# vn-smartphone-absa / modeling_vnsabsa.py
from typing import Tuple, Union

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_vnsabsa import VnSmartphoneAbsaConfig


class VnSmartphoneAbsaModel(PreTrainedModel):
    """Aspect-based sentiment analysis (ABSA) model for Vietnamese smartphone reviews."""

    config_class = VnSmartphoneAbsaConfig

    def __init__(
        self,
        config: VnSmartphoneAbsaConfig
    ):
        super().__init__(config)
        self.model = SmartphoneBERT(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            num_encoders=config.num_encoders,
            encoder_dropout=config.encoder_dropout,
            fc_dropout=config.fc_dropout,
            fc_hidden_size=config.fc_hidden_size
        )
        # 10 named aspect classes plus a catch-all OTHERS class.
        self.ASPECT_LOOKUP = {
            i: a
            for i, a in enumerate([
                "CAMERA", "FEATURES", "BATTERY", "PRICE", "GENERAL", "SER&ACC",
                "PERFORMANCE", "SCREEN", "DESIGN", "STORAGE", "OTHERS"
            ])
        }
        self.POLARITY_LOOKUP = {
            i: p
            for i, p in enumerate(["Negative", "Neutral", "Positive"])
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        aspect_thresholds: Union[float, torch.Tensor] = 0.5
    ):
        # Returns decoded {aspect: polarity} dicts rather than raw logits.
        pred = self.model(input_ids, attention_mask)
        return self.decode_absa(
            pred,
            aspect_thresholds=aspect_thresholds
        )

    def decode_absa(
        self,
        pred: Tuple[torch.Tensor, torch.Tensor],
        aspect_thresholds: Union[float, torch.Tensor] = 0.5
    ):
        # One detection threshold per aspect (11 = 10 named aspects + OTHERS).
        if isinstance(aspect_thresholds, float):
            aspect_thresholds = torch.full((11,), aspect_thresholds)
        a_logits, p_logits = pred
        a_probs = a_logits.sigmoid().cpu()
        p_ids = p_logits.argmax(dim=-1).cpu()
        results = []
        for a_i, p_i in zip(a_probs, p_ids):
            res_i = {}
            for i in range(10):
                aspect = self.ASPECT_LOOKUP[i]
                polarity = self.POLARITY_LOOKUP[p_i[i].item()]
                if a_i[i] >= aspect_thresholds[i]:
                    res_i[aspect] = polarity
            # OTHERS has no polarity head, so it is reported without a polarity.
            if a_i[-1] >= aspect_thresholds[-1]:
                res_i["OTHERS"] = ""
            results.append(res_i)
        return results


class AspectClassifier(nn.Module):
    """Multi-label aspect head: one logit per aspect (10 named aspects + OTHERS)."""

    def __init__(
        self,
        input_size: int,
        dropout: float = 0.3,
        hidden_size: int = 64,
        *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.input_size = input_size
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(
                in_features=input_size,
                out_features=hidden_size
            ),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(
                in_features=hidden_size,
                out_features=10 + 1  # 10 named aspects + OTHERS
            )
        )

    def forward(self, input: torch.Tensor):
        return self.fc(input)


class PolarityClassifier(nn.Module):
    """One 3-way (Negative/Neutral/Positive) head per named aspect."""

    def __init__(
        self,
        input_size: int,
        dropout: float = 0.5,
        hidden_size: int = 64,
        *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.polarity_fcs = nn.ModuleList([
            nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(
                    in_features=input_size,
                    out_features=hidden_size
                ),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(
                    in_features=hidden_size,
                    out_features=3
                )
            )
            for _ in range(10)
        ])

    def forward(self, input: torch.Tensor):
        # Stack per-aspect logits: shape (10, batch, 3) for batched input.
        polarities = torch.stack([
            fc(input)
            for fc in self.polarity_fcs
        ])
        if input.ndim == 2:
            # Move the batch dimension first: (batch, 10, 3).
            polarities = polarities.transpose(0, 1)
        return polarities


class SmartphoneBERT(nn.Module):
    """Transformer encoder over token embeddings with aspect and polarity heads."""

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 768,
        num_heads: int = 8,
        num_encoders: int = 4,
        encoder_dropout: float = 0.1,
        fc_dropout: float = 0.4,
        fc_hidden_size: int = 128,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim,
                dropout=encoder_dropout,
                batch_first=True
            ),
            num_layers=num_encoders,
            norm=nn.LayerNorm(embed_dim),
            enable_nested_tensor=False
        )
        # Both heads consume [first-token embedding ; mean-pooled embedding].
        self.a_fc = AspectClassifier(
            input_size=2 * embed_dim,
            dropout=fc_dropout,
            hidden_size=fc_hidden_size
        )
        self.p_fc = PolarityClassifier(
            input_size=2 * embed_dim,
            dropout=fc_dropout,
            hidden_size=fc_hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ):
        # attention_mask: 1 for real tokens, 0 for padding.
        padding_mask = ~attention_mask.bool()
        x = self.embed(input_ids)
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        # Zero out padded positions so they do not leak into the pooled features.
        # Note: the mean below still divides by the full (padded) sequence length.
        x[padding_mask] = 0
        x = torch.cat([
            x[..., 0, :],          # first-token ([CLS]-style) embedding
            torch.mean(x, dim=-2)  # mean pooling over the sequence
        ], dim=-1)
        a_logits = self.a_fc(x)
        p_logits = self.p_fc(x)
        return a_logits, p_logits
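

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only. It assumes the checkpoint is
    # published on the Hub as "ptdat/vn-smartphone-absa", ships a compatible
    # tokenizer, and registers this class for AutoModel via auto_map; the repo
    # id and those registrations are assumptions, not guaranteed by this file.
    # In practice this would live in a separate script, since running this
    # module directly fails on the relative config import above.
    from transformers import AutoModel, AutoTokenizer

    repo_id = "ptdat/vn-smartphone-absa"  # assumed Hub repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    batch = tokenizer(
        ["Pin trâu, camera chụp đẹp"],  # "Great battery, camera takes nice photos"
        return_tensors="pt",
        padding=True
    )
    with torch.no_grad():
        # forward() returns one {aspect: polarity} dict per input review, with
        # keys drawn from ASPECT_LOOKUP and values from POLARITY_LOOKUP.
        print(model(batch["input_ids"], batch["attention_mask"]))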