Upload 4 files
- model/frequential.py +45 -0
- model/network.py +50 -0
- model/positional.py +41 -0
- model/sequential.py +68 -0
model/frequential.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
import numpy as np


class FreqNetwork(nn.Module):
    def __init__(self,
                 tfidf_input_dim,
                 tfidf_output_dim,
                 tfidf_hidden_dim,
                 n_layers,
                 skip_in=(4,),
                 weight_norm=True):
        super(FreqNetwork, self).__init__()
        dims = [tfidf_input_dim] + [tfidf_hidden_dim for _ in range(n_layers)] + [tfidf_output_dim]
        self.num_layers = n_layers
        self.skip_in = skip_in
        for l in range(0, self.num_layers + 1):
            # Layers listed in skip_in also take the raw inputs, concatenated to
            # the previous activation, so their input width grows by dims[0].
            if l in self.skip_in:
                in_dim = dims[l] + dims[0]
            else:
                in_dim = dims[l]
            lin = nn.Linear(in_dim, dims[l + 1])
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(l), lin)
        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = inputs
        for l in range(0, self.num_layers + 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers:
                x = self.activation(x)
                # x = torch.dropout(x, p=0.2, train=self.training)

        return x
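A minimal usage sketch for FreqNetwork, assuming a TF-IDF input of 5,000 features, a 256-dimensional hidden size, a 128-dimensional output, and 2 layers; all sizes and the batch of random vectors are illustrative, not values taken from this repository.

# Hypothetical dimensions, chosen only for illustration.
freq_net = FreqNetwork(tfidf_input_dim=5000,
                       tfidf_output_dim=128,
                       tfidf_hidden_dim=256,
                       n_layers=2)
tfidf_batch = torch.randn(32, 5000)    # stand-in TF-IDF vectors for a batch of 32 documents
freq_feature = freq_net(tfidf_batch)   # -> (32, 128)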
model/network.py
ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
import numpy as np


class Classifier(nn.Module):
    def __init__(self,
                 combined_input,
                 combined_dim,
                 num_classes,
                 n_layers,
                 skip_in=(4,),
                 weight_norm=True):
        super(Classifier, self).__init__()
        self.num_layers = n_layers
        self.skip_in = skip_in
        self.model = LogisticRegression()
        # Combined classification layers
        dims = [combined_input] + [combined_dim for _ in range(n_layers)] + [num_classes]
        for l in range(0, self.num_layers + 1):
            # Layers listed in skip_in also take the raw inputs, concatenated to
            # the previous activation, so their input width grows by dims[0].
            if l in self.skip_in:
                in_dim = dims[l] + dims[0]
            else:
                in_dim = dims[l]
            lin = nn.Linear(in_dim, dims[l + 1])
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(l), lin)
        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = inputs
        for l in range(0, self.num_layers + 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)
            x = lin(x)

            if l < self.num_layers:
                x = self.activation(x)
                # x = torch.dropout(x, p=0.2, train=self.training)
        # Output layer: raw logits of shape (batch_size, num_classes)
        return x
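A minimal sketch of driving the Classifier head, assuming a 256-dimensional combined feature (e.g. two 128-dimensional features concatenated) and binary classification; the tensors and sizes below are stand-ins, not values from this repository.

# Hypothetical sizes, for illustration only.
clf = Classifier(combined_input=256, combined_dim=128, num_classes=2, n_layers=2)
combined = torch.randn(32, 256)              # stand-in for two concatenated (32, 128) feature tensors
logits = clf(combined)                       # -> (32, 2), raw logits
labels = torch.randint(0, 2, (32,))          # stand-in class indices
loss = nn.CrossEntropyLoss()(logits, labels)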
model/positional.py
ADDED
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import numpy as np
from model.siren import SIREN


class PosNetwork(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_dim,
                 n_layers,
                 skip_in=(4,),
                 weight_norm=True):
        super(PosNetwork, self).__init__()
        dims = [input_dim] + [hidden_dim for _ in range(n_layers)] + [output_dim]
        self.num_layers = n_layers
        self.skip_in = skip_in
        self.siren_layers = nn.ModuleList()
        for l in range(0, self.num_layers + 1):
            # Layers listed in skip_in also take the raw inputs, concatenated to
            # the previous activation, so their input width grows by dims[0].
            if l in self.skip_in:
                in_dim = dims[l] + dims[0]
            else:
                in_dim = dims[l]
            lin = nn.Linear(in_dim, dims[l + 1])
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(l), lin)
        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = inputs
        for l in range(0, self.num_layers + 1):
            lin = getattr(self, "lin" + str(l))
            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)
            x = lin(x)
            if l < self.num_layers:
                x = self.activation(x)
        x = x.mean(dim=1)  # mean-pool over dim 1, e.g. (32, seq_len, 128) -> (32, 128)
        return x
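A minimal sketch, assuming PosNetwork receives per-token positional encodings of shape (batch, seq_len, input_dim) so that the final mean over dim 1 pools the sequence; the shapes below are illustrative assumptions, not values from this repository.

# Hypothetical shapes, for illustration only.
pos_net = PosNetwork(input_dim=64, output_dim=128, hidden_dim=256, n_layers=2)
pos_enc = torch.randn(32, 40, 64)   # (batch, seq_len, input_dim) positional encodings
pos_feature = pos_net(pos_enc)      # linear layers act on the last dim, then mean-pool -> (32, 128)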
model/sequential.py
ADDED
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import numpy as np


class SeqNetwork(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_dim,
                 lstm_in,
                 n_layers,
                 skip_in=(4,),
                 weight_norm=True,
                 freeze=False,
                 use_LSTM=False):
        super(SeqNetwork, self).__init__()

        self.freeze = freeze
        self.skip_in = skip_in
        self.num_layers = n_layers
        self.weight_norm = weight_norm
        self.use_LSTM = use_LSTM
        # BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        if self.freeze:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Sequential model for BERT embeddings
        self.lstm = nn.LSTM(input_size=lstm_in, hidden_size=input_dim, batch_first=True)
        dims = [input_dim] + [hidden_dim for _ in range(n_layers)] + [output_dim]
        for l in range(0, self.num_layers + 1):
            # Layers listed in skip_in also take the raw inputs, concatenated to
            # the previous activation, so their input width grows by dims[0].
            if l in self.skip_in:
                in_dim = dims[l] + dims[0]
            else:
                in_dim = dims[l]
            lin = nn.Linear(in_dim, dims[l + 1])
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(l), lin)
        self.activation = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        # BERT embeddings
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_sequence_output = bert_outputs.pooler_output  # Shape: (batch_size, feature_size)
        # print(bert_sequence_output.shape)
        # LSTM over BERT embeddings
        if self.use_LSTM:
            lstm_out, (h_n, c_n) = self.lstm(bert_sequence_output)
            inputs = h_n[-1]  # Use the last hidden state
        else:
            inputs = bert_sequence_output
        x = inputs
        for l in range(0, self.num_layers + 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers:
                x = self.activation(x)
        bert_feature = x  # Shape: (batch_size, feature_size)
        return bert_feature
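A minimal sketch of tokenizing a small batch and passing it through SeqNetwork, using the BertTokenizer the file already imports; the sentences, max length, and layer sizes are illustrative assumptions. input_dim=768 matches the width of bert-base-uncased's pooled output, and with the default use_LSTM=False the LSTM is constructed but not used.

# Hypothetical sizes, for illustration only; BERT's pooled output is 768-dim.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
seq_net = SeqNetwork(input_dim=768, output_dim=128, hidden_dim=256,
                     lstm_in=768, n_layers=2, freeze=True)
batch = tokenizer(["an example sentence", "another example"],
                  padding=True, truncation=True, max_length=64, return_tensors='pt')
bert_feature = seq_net(batch['input_ids'], batch['attention_mask'])  # -> (2, 128)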