ptdat commited on
Commit
9b0df61
·
verified ·
1 Parent(s): 7738c2e

Upload model

Browse files
Files changed (1) hide show
  1. modeling_vnsabsa.py +130 -2
modeling_vnsabsa.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import PreTrainedModel
2
- from modules import SmartphoneBERT
3
  import torch
 
4
 
5
  from .configuration_vnsabsa import VnSmartphoneAbsaConfig
6
 
@@ -72,4 +72,132 @@ class VnSmartphoneAbsaModel(PreTrainedModel):
72
  if a_i[-1] >= aspect_thresholds[-1]:
73
  res_i["OTHERS"] = ""
74
 
75
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import PreTrainedModel
 
2
  import torch
3
+ import torch.nn as nn
4
 
5
  from .configuration_vnsabsa import VnSmartphoneAbsaConfig
6
 
 
72
  if a_i[-1] >= aspect_thresholds[-1]:
73
  res_i["OTHERS"] = ""
74
 
75
+ return results
76
+
77
+
78
+ class AspectClassifier(nn.Module):
79
+ def __init__(
80
+ self,
81
+ input_size: int,
82
+ dropout: float = 0.3,
83
+ hidden_size: int = 64,
84
+ *args, **kwargs
85
+ ) -> None:
86
+ super().__init__(*args, **kwargs)
87
+
88
+ self.input_size = input_size
89
+
90
+ self.fc = nn.Sequential(
91
+ nn.Dropout(dropout),
92
+ nn.Linear(
93
+ in_features=input_size,
94
+ out_features=hidden_size
95
+ ),
96
+ nn.ReLU(),
97
+ nn.Dropout(dropout),
98
+ nn.Linear(
99
+ in_features=hidden_size,
100
+ out_features=10+1
101
+ )
102
+ )
103
+
104
+ def forward(self, input: torch.Tensor):
105
+ x = self.fc(input)
106
+ return x
107
+
108
+
109
+ class PolarityClassifier(nn.Module):
110
+ def __init__(
111
+ self,
112
+ input_size: int,
113
+ dropout: float = 0.5,
114
+ hidden_size: int = 64,
115
+ *args, **kwargs
116
+ ) -> None:
117
+ super().__init__(*args, **kwargs)
118
+ self.polarity_fcs = nn.ModuleList([
119
+ nn.Sequential(
120
+ nn.Dropout(dropout),
121
+ nn.Linear(
122
+ in_features=input_size,
123
+ out_features=hidden_size
124
+ ),
125
+ nn.ReLU(),
126
+ nn.Dropout(dropout),
127
+ nn.Linear(
128
+ in_features=hidden_size,
129
+ out_features=3
130
+ )
131
+ )
132
+ for _ in torch.arange(10)
133
+ ])
134
+
135
+ def forward(self, input: torch.Tensor):
136
+ polarities = torch.stack([
137
+ fc(input)
138
+ for fc in self.polarity_fcs
139
+ ])
140
+
141
+ if input.ndim == 2:
142
+ polarities = polarities.transpose(0, 1)
143
+ return polarities
144
+
145
+
146
+ class SmartphoneBERT(nn.Module):
147
+ def __init__(
148
+ self,
149
+ vocab_size: int,
150
+ embed_dim: int = 768,
151
+ num_heads: int = 8,
152
+ num_encoders: int = 4,
153
+ encoder_dropout: float = 0.1,
154
+ fc_dropout: float =0.4,
155
+ fc_hidden_size: int = 128,
156
+ *args, **kwargs
157
+ ):
158
+ super().__init__(*args, **kwargs)
159
+ self.embed = nn.Embedding(
160
+ num_embeddings=vocab_size,
161
+ embedding_dim=embed_dim,
162
+ padding_idx=0
163
+ )
164
+ self.encoder = nn.TransformerEncoder(
165
+ nn.TransformerEncoderLayer(
166
+ d_model=embed_dim,
167
+ nhead=num_heads,
168
+ dim_feedforward=embed_dim,
169
+ dropout=encoder_dropout,
170
+ batch_first=True
171
+ ),
172
+ num_layers=num_encoders,
173
+ norm=nn.LayerNorm(embed_dim),
174
+ enable_nested_tensor=False
175
+ )
176
+ self.a_fc = AspectClassifier(
177
+ input_size=2*embed_dim,
178
+ dropout=fc_dropout,
179
+ hidden_size=fc_hidden_size
180
+ )
181
+ self.p_fc = PolarityClassifier(
182
+ input_size=2*embed_dim,
183
+ dropout=fc_dropout,
184
+ hidden_size=fc_hidden_size
185
+ )
186
+
187
+ def forward(
188
+ self,
189
+ input_ids: torch.Tensor,
190
+ attention_mask: torch.Tensor
191
+ ):
192
+ padding_mask = ~attention_mask.bool()
193
+ x = self.embed(input_ids)
194
+ x = self.encoder(x, src_key_padding_mask=padding_mask)
195
+ x[padding_mask] = 0
196
+ x = torch.cat([
197
+ x[..., 0, :],
198
+ torch.mean(x, dim=-2)
199
+ ], dim=-1)
200
+
201
+ a_logits = self.a_fc(x)
202
+ p_logits = self.p_fc(x)
203
+ return a_logits, p_logits