TUEN-YUE commited on
Commit
e76c6f9
·
verified ·
1 Parent(s): 1ee6252

Upload 4 files

Browse files
model/frequential.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ class FreqNetwork(nn.Module):
6
+ def __init__(self,
7
+ tfidf_input_dim,
8
+ tfidf_output_dim,
9
+ tfidf_hidden_dim,
10
+ n_layers,
11
+ skip_in=(4,),
12
+ weight_norm=True):
13
+ super(FreqNetwork, self).__init__()
14
+ dims = [tfidf_input_dim] + [tfidf_hidden_dim for _ in range(n_layers)] + [tfidf_output_dim]
15
+ self.num_layers = n_layers
16
+ self.skip_in = skip_in
17
+ for l in range(0, self.num_layers+1):
18
+ if l+1 in self.skip_in:
19
+ out_dim = dims[l + 1] + dims[0]
20
+ dims[l + 1] = out_dim
21
+ else:
22
+ out_dim = dims[l + 1]
23
+ lin = nn.Linear(dims[l], out_dim)
24
+ if weight_norm:
25
+ lin = nn.utils.weight_norm(lin)
26
+ setattr(self, "lin" + str(l), lin)
27
+ self.activation = nn.ReLU
28
+
29
+ def forward(self, inputs):
30
+ x = inputs
31
+ for l in range(0, self.num_layers+1):
32
+ lin = getattr(self, "lin" + str(l))
33
+
34
+ if l in self.skip_in:
35
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
36
+
37
+ x = lin(x)
38
+
39
+ if l < self.num_layers:
40
+ x = self.activation()(x)
41
+
42
+ # x = torch.dropout(x, p=0.2, train=self.training)
43
+
44
+
45
+ return x
model/network.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from sklearn.linear_model import LogisticRegression
5
+ import numpy as np
6
+
7
+
8
+
9
+ class Classifier(nn.Module):
10
+ def __init__(self,
11
+ combined_input,
12
+ combined_dim,
13
+ num_classes,
14
+ n_layers,
15
+ skip_in=(4,),
16
+ weight_norm=True):
17
+ super(Classifier, self).__init__()
18
+ self.num_layers = n_layers
19
+ self.skip_in = skip_in
20
+ self.model = LogisticRegression()
21
+ # Combined classification layers
22
+ dims = [combined_input] + [combined_dim for _ in range(n_layers)] + [num_classes]
23
+ for l in range(0, self.num_layers + 1):
24
+ if l+1 in self.skip_in:
25
+ out_dim = dims[l + 1] + dims[0]
26
+ dims[l + 1] = out_dim
27
+ else:
28
+ out_dim = dims[l + 1]
29
+ lin = nn.Linear(dims[l], out_dim)
30
+ if weight_norm:
31
+ lin = nn.utils.weight_norm(lin)
32
+ setattr(self, "lin" + str(l), lin)
33
+ self.activation = nn.ReLU
34
+
35
+ def forward(self, inputs):
36
+
37
+ x = inputs
38
+ for l in range(0, self.num_layers + 1):
39
+ lin = getattr(self, "lin" + str(l))
40
+
41
+ if l+1 in self.skip_in:
42
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
43
+ x = lin(x)
44
+
45
+ if l < self.num_layers:
46
+ x = self.activation()(x)
47
+
48
+ # x = torch.dropout(x, p=0.2, train=self.training)
49
+ # Output layer
50
+ return x
model/positional.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from model.siren import SIREN
5
+
6
+ class PosNetwork(nn.Module):
7
+ def __init__(self,
8
+ input_dim,
9
+ output_dim,
10
+ hidden_dim,
11
+ n_layers,
12
+ skip_in=(4,),
13
+ weight_norm=True):
14
+ super(PosNetwork, self).__init__()
15
+ dims = [input_dim] + [hidden_dim for _ in range(n_layers)] + [output_dim]
16
+ self.num_layers = n_layers
17
+ self.skip_in = skip_in
18
+ self.siren_layers = nn.ModuleList()
19
+ for l in range(0, self.num_layers + 1):
20
+ if l + 1 in self.skip_in:
21
+ out_dim = dims[l + 1] + dims[0]
22
+ dims[l + 1] = out_dim
23
+ else:
24
+ out_dim = dims[l + 1]
25
+ lin = nn.Linear(dims[l], out_dim)
26
+ if weight_norm:
27
+ lin = nn.utils.weight_norm(lin)
28
+ setattr(self, "lin" + str(l), lin)
29
+ self.activation = nn.ReLU
30
+
31
+ def forward(self, inputs):
32
+ x = inputs
33
+ for l in range(0, self.num_layers + 1):
34
+ lin = getattr(self, "lin" + str(l))
35
+ if l + 1 in self.skip_in:
36
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
37
+ x = lin(x)
38
+ if l < self.num_layers:
39
+ x = self.activation()(x)
40
+ x = x.mean(dim=1) # Pooling to match Shape: (32, 128)
41
+ return x
model/sequential.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertModel, BertTokenizer
4
+ import numpy as np
5
+
6
+ class SeqNetwork(nn.Module):
7
+ def __init__(self,
8
+ input_dim,
9
+ output_dim,
10
+ hidden_dim,
11
+ lstm_in,
12
+ n_layers,
13
+ skip_in=(4,),
14
+ weight_norm=True,
15
+ freeze = False,
16
+ use_LSTM = False):
17
+ super(SeqNetwork, self).__init__()
18
+
19
+ self.freeze = freeze
20
+ self.skip_in = skip_in
21
+ self.num_layers = n_layers
22
+ self.weight_norm = weight_norm
23
+ self.use_LSTM = use_LSTM
24
+ # BERT model
25
+ self.bert = BertModel.from_pretrained('bert-base-uncased')
26
+ if self.freeze:
27
+ for param in self.bert.parameters():
28
+ param.requires_grad = False
29
+
30
+ # Sequential model for BERT embeddings
31
+ self.lstm = nn.LSTM(input_size=lstm_in, hidden_size=input_dim, batch_first=True)
32
+ dims = [input_dim] + [hidden_dim for _ in range(n_layers)] + [output_dim]
33
+ for l in range(0, self.num_layers + 1):
34
+ if l+1 in self.skip_in:
35
+ out_dim = dims[l + 1] + dims[0]
36
+ dims[l + 1] = out_dim
37
+ else:
38
+ out_dim = dims[l + 1]
39
+ lin = nn.Linear(dims[l], out_dim)
40
+ if weight_norm:
41
+ lin = nn.utils.weight_norm(lin)
42
+ setattr(self, "lin" + str(l), lin)
43
+ self.activation = nn.ReLU
44
+
45
+ def forward(self, input_ids, attention_mask):
46
+ # BERT embeddings
47
+ bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
48
+ bert_sequence_output = bert_outputs.pooler_output # Shape: (batch_size, feature_size)
49
+ # print(bert_sequence_output.shape)
50
+ # LSTM over BERT embeddings
51
+ if self.use_LSTM:
52
+ lstm_out, (h_n, c_n) = self.lstm(bert_sequence_output)
53
+ inputs = h_n[-1] # Use the last hidden state
54
+ else:
55
+ inputs = bert_sequence_output
56
+ x = inputs
57
+ for l in range(0, self.num_layers + 1):
58
+ lin = getattr(self, "lin" + str(l))
59
+
60
+ if l in self.skip_in:
61
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
62
+
63
+ x = lin(x)
64
+
65
+ if l < self.num_layers:
66
+ x = self.activation()(x)
67
+ bert_feature = x # Shape: (batch_size, feature_size)
68
+ return bert_feature