Spaces:
Running
Running
johnpaulbin
committed on
Commit
·
3affa92
1
Parent(s):
ff8641c
First model version
Browse files
- app.py +102 -6
- large.pth +3 -0
- model/.gitkeep +1 -0
- model/vocabulary.json +0 -0
- pytorch_model.bin +3 -0
- requirements.txt +2 -1
- torchmoji/.gitkeep +1 -0
- torchmoji/__init__.py +0 -0
- torchmoji/__pycache__/__init__.cpython-310.pyc +0 -0
- torchmoji/__pycache__/attlayer.cpython-310.pyc +0 -0
- torchmoji/__pycache__/create_vocab.cpython-310.pyc +0 -0
- torchmoji/__pycache__/filter_utils.cpython-310.pyc +0 -0
- torchmoji/__pycache__/global_variables.cpython-310.pyc +0 -0
- torchmoji/__pycache__/lstm.cpython-310.pyc +0 -0
- torchmoji/__pycache__/model_def.cpython-310.pyc +0 -0
- torchmoji/__pycache__/sentence_tokenizer.cpython-310.pyc +0 -0
- torchmoji/__pycache__/tokenizer.cpython-310.pyc +0 -0
- torchmoji/__pycache__/word_generator.cpython-310.pyc +0 -0
- torchmoji/attlayer.py +71 -0
- torchmoji/class_avg_finetuning.py +315 -0
- torchmoji/create_vocab.py +271 -0
- torchmoji/filter_input.py +36 -0
- torchmoji/filter_utils.py +194 -0
- torchmoji/finetuning.py +674 -0
- torchmoji/global_variables.py +28 -0
- torchmoji/lstm.py +356 -0
- torchmoji/model_def.py +315 -0
- torchmoji/sentence_tokenizer.py +245 -0
- torchmoji/tokenizer.py +156 -0
- torchmoji/word_generator.py +312 -0
- vocabulary.json +0 -0
app.py
CHANGED
@@ -3,23 +3,119 @@ import asyncio
|
|
3 |
from hypercorn.asyncio import serve
|
4 |
from hypercorn.config import Config
|
5 |
from setfit import SetFitModel
|
|
|
|
|
6 |
|
7 |
app = Flask(__name__)
|
8 |
|
9 |
-
model = SetFitModel.from_pretrained("johnpaulbin/toxic-gte-small-3")
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
@app.route('/infer', methods=['POST'])
|
13 |
def translate():
|
14 |
data = request.get_json()
|
15 |
-
result = model.predict_proba([data['text']])
|
16 |
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
else:
|
20 |
-
|
21 |
|
22 |
-
return
|
23 |
|
24 |
# Define more routes for other operations like download_model, etc.
|
25 |
if __name__ == "__main__":
|
|
|
3 |
from hypercorn.asyncio import serve
|
4 |
from hypercorn.config import Config
|
5 |
from setfit import SetFitModel
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
|
9 |
app = Flask(__name__)
|
10 |
|
|
|
11 |
|
12 |
+
from sentence_transformers import SentenceTransformer
|
13 |
+
sentencemodel = SentenceTransformer('johnpaulbin/toxic-gte-small-3')
|
14 |
+
|
15 |
+
USE_GPU = False
|
16 |
+
|
17 |
+
|
18 |
+
""" Use torchMoji to predict emojis from a single text input
|
19 |
+
"""
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import emoji, json
|
23 |
+
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
|
24 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
25 |
+
from torchmoji.model_def import torchmoji_emojis
|
26 |
+
import torch
|
27 |
+
|
28 |
+
# Emoji map in emoji_overview.png
|
29 |
+
EMOJIS = ":joy: :unamused: :weary: :sob: :heart_eyes: \
|
30 |
+
:pensive: :ok_hand: :blush: :heart: :smirk: \
|
31 |
+
:grin: :notes: :flushed: :100: :sleeping: \
|
32 |
+
:relieved: :relaxed: :raised_hands: :two_hearts: :expressionless: \
|
33 |
+
:sweat_smile: :pray: :confused: :kissing_heart: :heartbeat: \
|
34 |
+
:neutral_face: :information_desk_person: :disappointed: :see_no_evil: :tired_face: \
|
35 |
+
:v: :sunglasses: :rage: :thumbsup: :cry: \
|
36 |
+
:sleepy: :yum: :triumph: :hand: :mask: \
|
37 |
+
:clap: :eyes: :gun: :persevere: :smiling_imp: \
|
38 |
+
:sweat: :broken_heart: :yellow_heart: :musical_note: :speak_no_evil: \
|
39 |
+
:wink: :skull: :confounded: :smile: :stuck_out_tongue_winking_eye: \
|
40 |
+
:angry: :no_good: :muscle: :facepunch: :purple_heart: \
|
41 |
+
:sparkling_heart: :blue_heart: :grimacing: :sparkles:".split(' ')
|
42 |
+
|
43 |
+
def top_elements(array, k):
    """Return the indices of the ``k`` largest values in ``array``, best first.

    # Arguments:
        array: 1-D NumPy array of scores.
        k: Number of indices to return.

    # Returns:
        NumPy array of ``k`` indices into ``array``, ordered from the
        highest score to the lowest.
    """
    # argpartition leaves the k largest entries (unordered) in the tail.
    candidates = np.argpartition(array, -k)[-k:]
    # Rank only those candidates, ascending, then flip to descending.
    ascending = np.argsort(array[candidates])
    return candidates[ascending][::-1]
|
46 |
+
|
47 |
+
|
48 |
+
# Load the vocabulary JSON that is handed to the torchMoji sentence tokenizer.
with open("vocabulary.json", 'r') as f:
    vocabulary = json.load(f)

# The second positional argument (100) is presumably the fixed maximum
# sentence length -- TODO confirm against torchmoji/sentence_tokenizer.py.
st = SentenceTokenizer(vocabulary, 100)

# Build the emoji-prediction model from the weights file shipped with the app.
emojimodel = torchmoji_emojis("pytorch_model.bin")

# Optionally move the emoji model to the first CUDA device.
if USE_GPU:
    emojimodel.to("cuda:0")
|
57 |
+
|
58 |
+
def deepmojify(sentence, top_n=5, prob_only=False):
    """ Predict emojis for a single sentence with the torchMoji model.

    # Arguments:
        sentence: Text to analyse.
        top_n: Number of highest-scoring emojis to render.
        prob_only: If True, skip the emoji rendering and return only the
            raw model output for the sentence.

    # Returns:
        If ``prob_only`` is True, the model output for the sentence exactly
        as returned by ``emojimodel``. Otherwise a tuple
        ``(list_emojis, prob)`` where ``list_emojis`` is a one-element list
        holding the ``top_n`` emojis joined by spaces.
    """
    tokenized, _, _ = st.tokenize_sentences([sentence])
    # The model consumes integer token ids.
    tokenized = np.array(tokenized).astype(int)
    if USE_GPU:
        tokenized = torch.tensor(tokenized).cuda()

    # Only one sentence was tokenized, so keep the first (only) row.
    prob = emojimodel.forward(tokenized)[0]
    if prob_only:
        return prob

    # NOTE(review): whether the model returns a tensor or a NumPy array is
    # decided in torchmoji/model_def.py (not visible here), so accept both.
    # detach() is needed because .numpy() refuses grad-tracking tensors.
    if torch.is_tensor(prob):
        scores = prob.detach().cpu().numpy()
    else:
        scores = np.asarray(prob)

    emoji_ids = top_elements(scores, top_n)
    aliases = ' '.join(EMOJIS[idx] for idx in emoji_ids)
    list_emojis = [emoji.emojize(aliases, language='alias')]
    return list_emojis, prob
|
77 |
+
|
78 |
+
|
79 |
+
# Classification head applied to the concatenated feature vector built in the
# /infer route (emoji probabilities followed by the sentence embedding, 448
# values in total). It outputs two logits.
model = nn.Sequential(
    nn.Linear(448, 300),
    nn.ReLU(),
    nn.BatchNorm1d(300),

    nn.Linear(300, 300),
    nn.ReLU(),
    nn.BatchNorm1d(300),

    nn.Linear(300, 200),
    nn.ReLU(),
    nn.BatchNorm1d(200),

    nn.Linear(200, 125),
    nn.ReLU(),
    nn.BatchNorm1d(125),

    nn.Linear(125, 2),
    nn.Dropout(0.05),  # inactive in eval mode; kept so the layout matches the checkpoint
)

# Load the trained weights shipped as large.pth. The previous code called
# torch.save() here, which overwrote that checkpoint with the freshly
# random-initialised weights on every start-up and served an untrained model.
model.load_state_dict(torch.load('large.pth', map_location='cpu'))
# Inference only: freezes BatchNorm statistics and disables Dropout.
model.eval()
|
102 |
|
103 |
@app.route('/infer', methods=['POST'])
def translate():
    """ Classify the text in a JSON POST body.

    Expects a JSON object with a string field ``text``. The lower-cased text
    is turned into torchMoji emoji probabilities and a sentence embedding;
    both are concatenated and fed to the classification head.

    # Returns:
        The string "true" when the second class scores at least as high as
        the first, otherwise "false". Malformed requests get a 400 response
        instead of an unhandled exception.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('text'), str):
        return 'Request body must be JSON with a string "text" field', 400

    text = data['text'].lower()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding = sentencemodel.encode(text, convert_to_tensor=True)
        probs = deepmojify(text, prob_only=True)
        # The classification head is never moved off the CPU, so bring both
        # feature parts there; as_tensor also accepts a NumPy result.
        embedding = embedding.cpu()
        probs = torch.as_tensor(probs, dtype=embedding.dtype).cpu()
        features = torch.cat((probs, embedding))
        scores = F.softmax(model(features.view(1, -1)), dim=1)

    if scores[0][0] > scores[0][1]:
        return "false"
    return "true"
|
119 |
|
120 |
# Define more routes for other operations like download_model, etc.
|
121 |
if __name__ == "__main__":
|
large.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e649a63d0f6cdce341cfde2ac8210cd7a1796420aac21585b1380e283cac01c
|
3 |
+
size 1265336
|
model/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
model/vocabulary.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8cbf6f7067d56aa1c2d571bb169f05fba16cea4c263c06fb3f217f42c591a978
|
3 |
+
size 89616062
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
Flask==2.0.1
|
2 |
tqdm
|
3 |
hypercorn
|
4 |
-
|
|
|
|
1 |
Flask==2.0.1
|
2 |
tqdm
|
3 |
hypercorn
|
4 |
+
emoji
|
5 |
+
sentence-transformers
|
torchmoji/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
torchmoji/__init__.py
ADDED
File without changes
|
torchmoji/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (133 Bytes). View file
|
|
torchmoji/__pycache__/attlayer.cpython-310.pyc
ADDED
Binary file (2.57 kB). View file
|
|
torchmoji/__pycache__/create_vocab.cpython-310.pyc
ADDED
Binary file (9.19 kB). View file
|
|
torchmoji/__pycache__/filter_utils.cpython-310.pyc
ADDED
Binary file (5.56 kB). View file
|
|
torchmoji/__pycache__/global_variables.cpython-310.pyc
ADDED
Binary file (922 Bytes). View file
|
|
torchmoji/__pycache__/lstm.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
torchmoji/__pycache__/model_def.cpython-310.pyc
ADDED
Binary file (9.7 kB). View file
|
|
torchmoji/__pycache__/sentence_tokenizer.cpython-310.pyc
ADDED
Binary file (8.84 kB). View file
|
|
torchmoji/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (2.71 kB). View file
|
|
torchmoji/__pycache__/word_generator.cpython-310.pyc
ADDED
Binary file (8.79 kB). View file
|
|
torchmoji/attlayer.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Define the Attention Layer of the model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from __future__ import print_function, division
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from torch.autograd import Variable
|
10 |
+
from torch.nn import Module
|
11 |
+
from torch.nn.parameter import Parameter
|
12 |
+
|
13 |
+
class Attention(Module):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter pr. channel to compute the attention value for a single timestep.
    """

    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If true, output will include the weight for each input token
                              used for the prediction

        """
        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05)  # Initialize attention vector

    def __repr__(self):
        s = '{name}({attention_size}, return attention={return_attention})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (torch.Tensor): Tensor of input sequences; the last
                dimension is multiplied with the attention vector and
                dimension 1 is treated as the timestep axis.
            input_lengths (torch.LongTensor): Lengths of the sequences

        # Return:
            Tuple with (representations and attentions if self.return_attention else None).
        """
        logits = inputs.matmul(self.attention_vector)
        # Subtract the max before exponentiating for numerical stability.
        unnorm_ai = (logits - logits.max()).exp()

        # Compute a mask for the attention on the padded sequences
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        # Build it directly on the device of the scores so it works for CPU
        # and for any CUDA device, wherever input_lengths lives.
        device = unnorm_ai.device
        max_len = unnorm_ai.size(1)
        idxes = torch.arange(max_len, dtype=torch.long, device=device).unsqueeze(0)
        lengths = torch.as_tensor(input_lengths, device=device).unsqueeze(1)
        mask = (idxes < lengths).float()

        # apply mask and renormalize attention scores (weights)
        masked_weights = unnorm_ai * mask
        att_sums = masked_weights.sum(dim=1, keepdim=True)  # sums per sequence
        # A sequence of length 0 has a zero sum here and yields NaN weights.
        attentions = masked_weights.div(att_sums)

        # apply attention weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(dim=1)

        return (representations, attentions if self.return_attention else None)
|
torchmoji/class_avg_finetuning.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Class average finetuning functions. Before using any of these finetuning
|
3 |
+
functions, ensure that the model is set up with nb_classes=2.
|
4 |
+
"""
|
5 |
+
from __future__ import print_function
|
6 |
+
|
7 |
+
import uuid
|
8 |
+
from time import sleep
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.optim as optim
|
14 |
+
|
15 |
+
from torchmoji.global_variables import (
|
16 |
+
FINETUNING_METHODS,
|
17 |
+
WEIGHTS_DIR)
|
18 |
+
from torchmoji.finetuning import (
|
19 |
+
freeze_layers,
|
20 |
+
get_data_loader,
|
21 |
+
fit_model,
|
22 |
+
train_by_chain_thaw,
|
23 |
+
find_f1_threshold)
|
24 |
+
|
25 |
+
def relabel(y, current_label_nr, nb_classes):
    """ Turns the labels of a multi-class dataset into a one-vs-rest binary
        labelling for a single class.

    # Arguments:
        y: Outputs to be relabelled.
        current_label_nr: Current label number.
        nb_classes: Total number of classes.

    # Returns:
        ``y`` itself when the data is already binary (two classes and 1-D
        labels); otherwise a new float array with 1 wherever column
        ``current_label_nr`` of ``y`` equals 1 and 0 elsewhere.
    """
    already_binary = nb_classes == 2 and len(y.shape) == 1
    if already_binary:
        return y

    binary = np.zeros(len(y))
    positives = np.flatnonzero(y[:, current_label_nr] == 1)
    binary[positives] = 1
    return binary
|
48 |
+
|
49 |
+
|
50 |
+
def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Sets up the loss and optimizer for the chosen method and finetunes
        the given model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. Must be one of
            FINETUNING_METHODS in global_variables.py.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        embed_l2: Weight decay applied to the embedding parameters only
            (methods 'full' and 'new').
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the class average F1 metric.

    # Raises:
        ValueError: If ``method`` is not in FINETUNING_METHODS.
    """
    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_tune_trainable): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    X_train, y_train = texts[0], labels[0]
    X_val, y_val = texts[1], labels[1]
    X_test, y_test = texts[2], labels[2]

    # Unique file names so that concurrent runs do not clobber each other.
    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin'.format(
        WEIGHTS_DIR, str(uuid.uuid4()))
    f1_init_path = '{}/torchmoji-f1-init-{}.bin'.format(
        WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001

    loss_op = nn.BCEWithLogitsLoss()

    # Build the optimizer here for every method except chain-thaw, which
    # creates its own optimizers after freezing layers.
    if method == 'last':
        # Only the output layer stays trainable.
        model = freeze_layers(model, unfrozen_keyword='output_layer')
        trainable = (p for p in model.parameters() if p.requires_grad)
        adam = optim.Adam(trainable, lr=lr)
    elif method in ['full', 'new']:
        # Weight decay on the embedding parameters only.
        embed_ids = {id(p) for p in model.embed.parameters()}
        base_params, embed_parameters = [], []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            if id(p) in embed_ids:
                embed_parameters.append(p)
            else:
                base_params.append(p)
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)

    # Training
    if verbose:
        print('Method: {}'.format(method))
        print('Classes: {}'.format(nb_classes))

    train_split = (X_train, y_train)
    val_split = (X_val, y_val)
    test_split = (X_test, y_test)

    if method == 'chain-thaw':
        result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                     loss_op=loss_op,
                                     train=train_split,
                                     val=val_split,
                                     test=test_split,
                                     batch_size=batch_size,
                                     epoch_size=epoch_size,
                                     nb_epochs=nb_epochs,
                                     checkpoint_weight_path=checkpoint_path,
                                     f1_init_weight_path=f1_init_path,
                                     verbose=verbose)
    else:
        result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                          loss_op=loss_op,
                                          optim_op=adam,
                                          train=train_split,
                                          val=val_split,
                                          test=test_split,
                                          epoch_size=epoch_size,
                                          nb_epochs=nb_epochs,
                                          batch_size=batch_size,
                                          init_weight_path=f1_init_path,
                                          checkpoint_weight_path=checkpoint_path,
                                          verbose=verbose)
    return model, result
|
146 |
+
|
147 |
+
|
148 |
+
def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
    """ Relabels the train, validation and test labels into a binary
        task for class ``iter_i``.

    # Returns:
        Tuple (y_train_new, y_val_new, y_test_new).
    """
    return tuple(relabel(split, iter_i, nb_classes)
                 for split in (y_train, y_val, y_test))
|
154 |
+
|
155 |
+
def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
    """ Creates the training data loader and one fixed validation batch.

    # Returns:
        Tuple (train_gen, X_val_resamp, y_val_resamp), where the last two
        are the first batch (of size ``epoch_size``) drawn from the
        validation loader.
    """
    train_loader = get_data_loader(X_train, y_train_new, batch_size,
                                   extended_batch_sampler=True)
    val_loader = get_data_loader(X_val, y_val_new, epoch_size,
                                 extended_batch_sampler=True)
    # Draw the validation batch once so that every epoch is scored on the
    # same samples, avoiding fluctuations in validation.
    first_val_batch = next(iter(val_loader))
    X_val_resamp, y_val_resamp = first_val_batch
    return train_loader, X_val_resamp, y_val_resamp
|
164 |
+
|
165 |
+
|
166 |
+
def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
                             epoch_size, nb_epochs, batch_size,
                             init_weight_path, checkpoint_weight_path, patience=5,
                             verbose=True):
    """ Finetunes the given model once per class and averages the F1 scores.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss passed on to fit_model.
        optim_op: Optimizer passed on to fit_model.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        batch_size: Batch size.
        init_weight_path: Filepath where weights will be initially saved before
            training each class. This file will be rewritten by the function.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Patience value passed on to fit_model.
        verbose: Verbosity flag.

    # Returns:
        F1 score of the trained model, averaged over the iterations.
    """
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # With two classes (or fewer) a single binary run is performed.
    nb_iter = nb_classes if nb_classes > 2 else 1
    f1_scores = []

    # Snapshot the starting weights; they are restored before every class
    # to avoid learning across classes.
    torch.save(model.state_dict(), init_weight_path)
    for class_idx in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(class_idx+1, nb_iter))

        model.load_state_dict(torch.load(init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(
            y_train, y_val, y_test, class_idx, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = prepare_generators(
            X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size)

        if verbose:
            print("Training..")
        fixed_val = [(X_val_resamp, y_val_resamp)]
        fit_model(model, loss_op, optim_op, train_gen, fixed_val,
                  nb_epochs, checkpoint_weight_path, patience, verbose=0)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_weight_path))

        # Evaluate
        y_pred_val = model(X_val).cpu().numpy()
        y_pred_test = model(X_test).cpu().numpy()
        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)
        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))
        f1_scores.append(f1_test)

    return sum(f1_scores) / nb_iter
|
234 |
+
|
235 |
+
|
236 |
+
def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
                        epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001, verbose=True):
    """ Finetunes given model using chain-thaw and evaluates using F1.
        For a dataset with multiple classes, the model is trained once for
        each class, relabeling those classes into a binary classification task.
        The result is an average of all F1 scores for each class.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss passed on to train_by_chain_thaw.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where weights will be saved to and
            reloaded from before training each class. This ensures that
            each class is trained independently. This file will be rewritten.
        patience: Patience value passed on to train_by_chain_thaw.
        initial_lr: Initial learning rate, passed on to train_by_chain_thaw.
        next_lr: Learning rate for every subsequent step, passed on to
            train_by_chain_thaw.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # With two classes (or fewer) a single binary run is performed.
    nb_iter = nb_classes if nb_classes > 2 else 1
    f1_scores = []

    # Snapshot the starting weights so every class starts from the same point.
    torch.save(model.state_dict(), f1_init_weight_path)

    for class_idx in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(class_idx+1, nb_iter))

        model.load_state_dict(torch.load(f1_init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(
            y_train, y_val, y_test, class_idx, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = prepare_generators(
            X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size)

        if verbose:
            print("Training..")

        # Train using chain-thaw
        fixed_val = [(X_val_resamp, y_val_resamp)]
        train_by_chain_thaw(model=model, train_gen=train_gen,
                            val_gen=fixed_val,
                            loss_op=loss_op, patience=patience,
                            nb_epochs=nb_epochs,
                            checkpoint_path=checkpoint_weight_path,
                            initial_lr=initial_lr, next_lr=next_lr,
                            verbose=verbose)

        # Evaluate
        y_pred_val = model(X_val).cpu().numpy()
        y_pred_test = model(X_test).cpu().numpy()
        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))
        f1_scores.append(f1_test)

    return sum(f1_scores) / nb_iter
|
torchmoji/create_vocab.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from __future__ import print_function, division
|
3 |
+
|
4 |
+
import glob
|
5 |
+
import json
|
6 |
+
import uuid
|
7 |
+
from copy import deepcopy
|
8 |
+
from collections import defaultdict, OrderedDict
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmoji.filter_utils import is_special_token
|
12 |
+
from torchmoji.word_generator import WordGenerator
|
13 |
+
from torchmoji.global_variables import SPECIAL_TOKENS, VOCAB_PATH
|
14 |
+
|
15 |
+
class VocabBuilder():
    """ Create vocabulary with words extracted from sentences as fed from a
        word generator.
    """
    def __init__(self, word_gen):
        """ Set up empty counts and register the special tokens.

        # Arguments:
            word_gen: Iterable yielding (tokenized sentence, extra) pairs.
        """
        # initialize any new key with value of 0
        self.word_counts = defaultdict(lambda: 0, {})
        self.word_length_limit=30

        for token in SPECIAL_TOKENS:
            assert len(token) < self.word_length_limit
            self.word_counts[token] = 0
        self.word_gen = word_gen

    def count_words_in_sentence(self, words):
        """ Generates word counts for all tokens in the given sentence.

        Empty tokens and tokens longer than ``word_length_limit`` are skipped.

        # Arguments:
            words: Tokenized sentence whose words should be counted.
        """
        for word in words:
            if 0 < len(word) <= self.word_length_limit:
                # .get() keeps this working even if word_counts has been
                # replaced by a plain dict.
                self.word_counts[word] = self.word_counts.get(word, 0) + 1

    def save_vocab(self, path=None):
        """ Saves the vocabulary into a file.

        # Arguments:
            path: Where the vocabulary should be saved. If not specified, a
                  randomly generated filename is used instead.
        """
        dtype = ([('word','|S{}'.format(self.word_length_limit)),('count','int')])
        # dict.items() is a view on Python 3 and must be materialised for
        # NumPy. Words are encoded explicitly because NumPy's implicit
        # str -> bytes conversion only accepts ASCII.
        # NOTE(review): the byte field holds word_length_limit bytes, so long
        # multi-byte words are truncated -- confirm this is acceptable.
        rows = [(w.encode('utf-8') if isinstance(w, str) else w, c)
                for w, c in self.word_counts.items()]
        np_dict = np.array(rows, dtype=dtype)

        # sort from highest to lowest frequency (ascending sort of the
        # reversed view leaves the array itself in descending order)
        np_dict[::-1].sort(order='count')
        data = np_dict

        if path is None:
            path = str(uuid.uuid4())

        np.savez_compressed(path, data=data)
        print("Saved dict to {}".format(path))

    def get_next_word(self):
        """ Returns next tokenized sentence from the word generator.

        # Returns:
            The first item produced by a fresh iterator over ``word_gen``.
        """
        # next(iter(...)) works on Python 2 and 3; iterator.next() is
        # Python 2 only.
        return next(iter(self.word_gen))

    def count_all_words(self):
        """ Generates word counts for all words in all sentences of the word
            generator.
        """
        for words, _ in self.word_gen:
            self.count_words_in_sentence(words)
|
76 |
+
|
77 |
+
class MasterVocab():
|
78 |
+
""" Combines vocabularies.
|
79 |
+
"""
|
80 |
+
def __init__(self):
|
81 |
+
|
82 |
+
# initialize custom tokens
|
83 |
+
self.master_vocab = {}
|
84 |
+
|
85 |
+
def populate_master_vocab(self, vocab_path, min_words=1, force_appearance=None):
|
86 |
+
""" Populates the master vocabulary using all vocabularies found in the
|
87 |
+
given path. Vocabularies should be named *.npz. Expects the
|
88 |
+
vocabularies to be numpy arrays with counts. Normalizes the counts
|
89 |
+
and combines them.
|
90 |
+
|
91 |
+
# Arguments:
|
92 |
+
vocab_path: Path containing vocabularies to be combined.
|
93 |
+
min_words: Minimum amount of occurences a word must have in order
|
94 |
+
to be included in the master vocabulary.
|
95 |
+
force_appearance: Optional vocabulary filename that will be added
|
96 |
+
to the master vocabulary no matter what. This vocabulary must
|
97 |
+
be present in vocab_path.
|
98 |
+
"""
|
99 |
+
|
100 |
+
paths = glob.glob(vocab_path + '*.npz')
|
101 |
+
sizes = {path: 0 for path in paths}
|
102 |
+
dicts = {path: {} for path in paths}
|
103 |
+
|
104 |
+
# set up and get sizes of individual dictionaries
|
105 |
+
for path in paths:
|
106 |
+
np_data = np.load(path)['data']
|
107 |
+
|
108 |
+
for entry in np_data:
|
109 |
+
word, count = entry
|
110 |
+
if count < min_words:
|
111 |
+
continue
|
112 |
+
if is_special_token(word):
|
113 |
+
continue
|
114 |
+
dicts[path][word] = count
|
115 |
+
|
116 |
+
sizes[path] = sum(dicts[path].values())
|
117 |
+
print('Overall word count for {} -> {}'.format(path, sizes[path]))
|
118 |
+
print('Overall word number for {} -> {}'.format(path, len(dicts[path])))
|
119 |
+
|
120 |
+
vocab_of_max_size = max(sizes, key=sizes.get)
|
121 |
+
max_size = sizes[vocab_of_max_size]
|
122 |
+
print('Min: {}, {}, {}'.format(sizes, vocab_of_max_size, max_size))
|
123 |
+
|
124 |
+
# can force one vocabulary to always be present
|
125 |
+
if force_appearance is not None:
|
126 |
+
force_appearance_path = [p for p in paths if force_appearance in p][0]
|
127 |
+
force_appearance_vocab = deepcopy(dicts[force_appearance_path])
|
128 |
+
print(force_appearance_path)
|
129 |
+
else:
|
130 |
+
force_appearance_path, force_appearance_vocab = None, None
|
131 |
+
|
132 |
+
# normalize word counts before inserting into master dict
|
133 |
+
for path in paths:
|
134 |
+
normalization_factor = max_size / sizes[path]
|
135 |
+
print('Norm factor for path {} -> {}'.format(path, normalization_factor))
|
136 |
+
|
137 |
+
for word in dicts[path]:
|
138 |
+
if is_special_token(word):
|
139 |
+
print("SPECIAL - ", word)
|
140 |
+
continue
|
141 |
+
normalized_count = dicts[path][word] * normalization_factor
|
142 |
+
|
143 |
+
# can force one vocabulary to always be present
|
144 |
+
if force_appearance_vocab is not None:
|
145 |
+
try:
|
146 |
+
force_word_count = force_appearance_vocab[word]
|
147 |
+
except KeyError:
|
148 |
+
continue
|
149 |
+
#if force_word_count < 5:
|
150 |
+
#continue
|
151 |
+
|
152 |
+
if word in self.master_vocab:
|
153 |
+
self.master_vocab[word] += normalized_count
|
154 |
+
else:
|
155 |
+
self.master_vocab[word] = normalized_count
|
156 |
+
|
157 |
+
print('Size of master_dict {}'.format(len(self.master_vocab)))
|
158 |
+
print("Hashes for master dict: {}".format(
|
159 |
+
len([w for w in self.master_vocab if '#' in w[0]])))
|
160 |
+
|
161 |
+
def save_vocab(self, path_count, path_vocab, word_limit=100000):
    """ Saves the master vocabulary into a file.

    # Arguments:
        path_count: Destination of the compressed .npz file holding the
            (word, count) records under the key 'counts' (for debugging).
        path_vocab: Destination of the JSON file mapping word -> index.
        word_limit: Maximum number of entries (special tokens included)
            written to either file.
    """

    # reserve space for the special tokens so they get the lowest indices
    # store -1 instead of np.inf, which can overflow
    words = OrderedDict((token, -1) for token in SPECIAL_TOKENS)

    # sort words by frequency (most frequent first)
    by_frequency = sorted(self.master_vocab.items(),
                          key=lambda kv: kv[1], reverse=True)
    words.update(by_frequency)

    # Dictionary views can neither be sliced nor converted by np.array on
    # Python 3, so materialise the (word, count) pairs first.
    limited = list(words.items())[:word_limit]

    # use encoding of up to 30 bytes (no token conversions)
    # use float to store large numbers (we don't care about precision loss)
    # Words are encoded explicitly because numpy can only coerce pure-ASCII
    # text into an 'S' field on its own.
    records = [(w if isinstance(w, bytes) else w.encode('utf-8'), c)
               for w, c in limited]
    counts = np.array(records,
                      dtype=([('word', '|S30'), ('count', 'float')]))

    # output count for debugging
    np.savez_compressed(path_count, counts=counts)

    # output the index of each word for easy lookup
    final_words = OrderedDict((w, i) for i, (w, _) in enumerate(limited))
    with open(path_vocab, 'w') as f:
        f.write(json.dumps(final_words, indent=4, separators=(',', ': ')))
|
191 |
+
|
192 |
+
|
193 |
+
def all_words_in_sentences(sentences):
    """ Extracts all unique words from a given list of sentences.

    # Arguments:
        sentences: List or word generator of sentences to be processed.

    # Returns:
        List of all unique words contained in the given sentences, in order
        of first appearance.
    """
    if isinstance(sentences, WordGenerator):
        sentences = [s for s, _ in sentences]

    vocab = []
    # A set gives O(1) membership tests; the list keeps first-seen order.
    # assumes words are hashable (e.g. strings) -- TODO confirm
    seen = set()
    for sentence in sentences:
        for word in sentence:
            if word not in seen:
                seen.add(word)
                vocab.append(word)

    return vocab
|
212 |
+
|
213 |
+
|
214 |
+
def extend_vocab_in_file(vocab, max_tokens=10000, vocab_path=VOCAB_PATH):
    """ Adds words from vocab to the JSON vocabulary stored at vocab_path and
        writes the result back to the same file. Only words missing from the
        stored vocabulary are added, and at most max_tokens of them.

    # Arguments:
        vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
            must have run count_all_words() previously.
        max_tokens: Maximum number of words to be added.
        vocab_path: Path to the vocabulary json which is to be extended.
    """
    try:
        with open(vocab_path, 'r') as source:
            stored_vocab = json.load(source)
    except IOError:
        # Nothing to extend: report the problem and leave the disk untouched.
        print('Vocabulary file not found, expected at ' + vocab_path)
        return

    # Mutates stored_vocab in place.
    extend_vocab(stored_vocab, vocab, max_tokens)

    # Save back to file
    with open(vocab_path, 'w') as target:
        json.dump(stored_vocab, target, sort_keys=True, indent=4,
                  separators=(',', ': '))
|
237 |
+
|
238 |
+
|
239 |
+
def extend_vocab(current_vocab, new_vocab, max_tokens=10000):
    """ Adds the most frequent words of new_vocab that are missing from
        current_vocab, giving them consecutive indices after the existing
        entries. current_vocab is modified in place.

    # Arguments:
        current_vocab: Current dictionary of tokens.
        new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
            must have run count_all_words() previously.
        max_tokens: Maximum number of words to be added. A negative value
            falls back to 10000.

    # Returns:
        How many new tokens have been added.
    """
    if max_tokens < 0:
        max_tokens = 10000

    # Most frequent candidates first (stable for equal counts).
    ranked = sorted(new_vocab.word_counts.items(),
                    key=lambda kv: kv[1], reverse=True)

    first_free_index = len(current_vocab)
    added = 0
    for candidate, _ in ranked:
        if added >= max_tokens:
            break
        if candidate in current_vocab:
            continue
        current_vocab[candidate] = first_free_index + added
        added += 1

    return added
|
torchmoji/filter_input.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from __future__ import print_function, division
|
3 |
+
import codecs
|
4 |
+
import csv
|
5 |
+
import numpy as np
|
6 |
+
from emoji import UNICODE_EMOJI
|
7 |
+
|
8 |
+
def read_english(path="english_words.txt", add_emojis=True):
    """ Reads the word list used for English filtering.

    # Arguments:
        path: UTF-8 text file with one word per line.
        add_emojis: If True, every key of UNICODE_EMOJI is added to the set.

    # Returns:
        Set of lower-cased, stripped, non-empty lines (plus emojis if
        requested).
    """
    with codecs.open(path, "r", "utf-8") as handle:
        normalized = (raw.strip().lower().replace('\n', '') for raw in handle)
        english = set(entry for entry in normalized if len(entry))

    if add_emojis:
        english.update(UNICODE_EMOJI)
    return english
|
20 |
+
|
21 |
+
def read_wanted_emojis(path="wanted_emojis.csv"):
    """ Reads the emojis listed in the first column of a CSV file.

    Each cell may hold the emoji itself or a Python-style escape such as
    \\U0001f602; escapes are converted to the character they denote.

    NOTE: Uses the Python 3 text-mode open() (csv.reader needs str rows).

    # Arguments:
        path: Path to the CSV file (read as UTF-8).

    # Returns:
        List of emoji strings, in file order.
    """
    emojis = []
    # newline='' is what the csv module expects for files it parses itself.
    with open(path, 'r', newline='', encoding='utf-8') as f:
        for row in csv.reader(f):
            if not row:
                # blank line: nothing to read
                continue
            cell = row[0].strip().replace('\n', '')
            # Turn any literal non-ASCII character into its escape first, so
            # that 'unicode-escape' decodes escaped and raw emojis alike.
            cell = cell.encode('ascii', 'backslashreplace').decode('unicode-escape')
            emojis.append(cell)
    return emojis
|
30 |
+
|
31 |
+
def read_non_english_users(path="unwanted_users.npz"):
    """ Loads the 'userids' array of the given .npz file as a set.

    # Returns:
        Set of user ids, or an empty set if loading raises IOError
        (e.g. the file is missing).
    """
    try:
        return set(np.load(path)['userids'])
    except IOError:
        return set()
|
torchmoji/filter_utils.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from __future__ import print_function, division, unicode_literals
|
4 |
+
import sys
|
5 |
+
import re
|
6 |
+
import string
|
7 |
+
import emoji
|
8 |
+
from itertools import groupby
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from torchmoji.tokenizer import RE_MENTION, RE_URL
|
12 |
+
from torchmoji.global_variables import SPECIAL_TOKENS
|
13 |
+
|
14 |
+
# Make unichr available under both major Python versions.
try:
    unichr        # Python 2
except NameError:
    unichr = chr  # Python 3


# Pre-compiled patterns for @mentions and URLs.
AtMentionRegex = re.compile(RE_MENTION)
urlRegex = re.compile(RE_URL)

# from http://bit.ly/2rdjgjE (UTF-8 encodings and Unicode chars)
# The sixteen code points U+FE00 .. U+FE0F.
VARIATION_SELECTORS = list(map(unichr, range(0xFE00, 0xFE10)))

# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
# NOTE: ALL_CHARS is a one-shot generator, not a reusable sequence.
ALL_CHARS = (unichr(i) for i in range(sys.maxunicode))
# Code points 0-31 and 127-159.
CONTROL_CHARS = ''.join(map(unichr, range(0, 32))) + ''.join(map(unichr, range(127, 160)))
CONTROL_CHAR_REGEX = re.compile('[%s]' % re.escape(CONTROL_CHARS))
|
45 |
+
|
46 |
+
def is_special_token(word):
    """ Returns True if word equals one of the entries of SPECIAL_TOKENS. """
    return any(word == spec for spec in SPECIAL_TOKENS)
|
53 |
+
|
54 |
+
def mostly_english(words, english, pct_eng_short=0.5, pct_eng_long=0.6, ignore_special_tokens=True, min_length=2):
    """ Ensure text meets threshold for containing English words.

    # Arguments:
        words: Iterable of tokens to examine.
        english: Container of known English words; None disables the check.
        pct_eng_short: Required English fraction when fewer than 5 words count.
        pct_eng_long: Required English fraction when 5 or more words count.
        ignore_special_tokens: If True, special tokens are not counted.
        min_length: Tokens shorter than this are not counted.

    # Returns:
        Tuple (valid, number of counted words, number of English words).
        Texts with fewer than 2 counted words are always valid.
    """
    if english is None:
        return True, 0, 0

    def _counts(token):
        # Too-short tokens, pure punctuation and (optionally) special tokens
        # are left out of the statistics.
        if len(token) < min_length:
            return False
        if punct_word(token):
            return False
        return not (ignore_special_tokens and is_special_token(token))

    counted = [w for w in words if _counts(w)]
    n_words = len(counted)
    n_english = sum(1 for w in counted if w in english)

    if n_words < 2:
        return True, n_words, n_english

    required_fraction = pct_eng_short if n_words < 5 else pct_eng_long
    valid_english = n_english >= n_words * required_fraction
    return valid_english, n_words, n_english
|
81 |
+
|
82 |
+
def correct_length(words, min_words, max_words, ignore_special_tokens=True):
    """ Ensure the number of counted words is within the min and max limits.

    Pure-punctuation tokens and (optionally) special tokens are not counted.
    A limit of None means 0 for min_words and 99999 for max_words.
    """
    lower = 0 if min_words is None else min_words
    upper = 99999 if max_words is None else max_words

    n_words = 0
    for token in words:
        skip = punct_word(token) or (ignore_special_tokens and is_special_token(token))
        if not skip:
            n_words += 1

    return lower <= n_words <= upper
|
101 |
+
|
102 |
+
def punct_word(word, punctuation=string.punctuation):
    """ Returns True if every character of word is in punctuation
        (which also holds for an empty word). """
    return all(char in punctuation for char in word)
|
104 |
+
|
105 |
+
def load_non_english_user_set():
    """ Loads the 'data' array of 'uids.npz' (relative to the working
        directory) and returns its entries as a set. """
    archive = np.load('uids.npz')
    return set(archive['data'])
|
108 |
+
|
109 |
+
def non_english_user(userid, non_english_user_set):
    """ Returns True if userid, converted to int, is in non_english_user_set. """
    return int(userid) in non_english_user_set
|
112 |
+
|
113 |
+
def separate_emojis_and_text(text):
    """ Splits text into the characters found in emoji.UNICODE_EMOJI and all
        remaining characters.

    # Returns:
        Tuple (emoji characters, other characters), each joined into a string
        with the original order kept.
    """
    emoji_part, text_part = [], []
    for char in text:
        bucket = emoji_part if char in emoji.UNICODE_EMOJI else text_part
        bucket.append(char)
    return ''.join(emoji_part), ''.join(text_part)
|
122 |
+
|
123 |
+
def extract_emojis(text, wanted_emojis):
    """ Returns the characters of text (variation selectors removed first)
        that are contained in wanted_emojis, as a list in text order. """
    cleaned = remove_variation_selectors(text)
    return [char for char in cleaned if char in wanted_emojis]
|
126 |
+
|
127 |
+
def remove_variation_selectors(text):
    """ Deletes every occurrence of the characters listed in
        VARIATION_SELECTORS from text and returns the result.
    """
    cleaned = text
    for selector in VARIATION_SELECTORS:
        cleaned = cleaned.replace(selector, '')
    return cleaned
|
134 |
+
|
135 |
+
def shorten_word(word):
    """ Shorten groupings of 3+ identical consecutive chars to 2, e.g. '!!!!' --> '!!'

    Only pure-ASCII words are shortened; anything else (non-ASCII text, or an
    object without an encode() method) is returned unchanged.
    """

    # only shorten ASCII words
    # encode() works for text on both Python 2 and 3, whereas str.decode()
    # does not exist on Python 3 and made this function a silent no-op there.
    try:
        word.encode('ascii')
    except (UnicodeDecodeError, UnicodeEncodeError, AttributeError):
        return word

    # must have at least 3 char to be shortened
    if len(word) < 3:
        return word

    # Rebuild the word run by run, capping each run at two characters.
    # This avoids chained str.replace calls, which interfere with each other
    # when the same character occurs in runs of different lengths.
    return ''.join(char * min(len(list(run)), 2) for char, run in groupby(word))
|
161 |
+
|
162 |
+
def detect_special_tokens(word):
    """ Replaces word by an entry of SPECIAL_TOKENS where applicable:
        index 4 if int(word) succeeds, index 2 if the @mention pattern
        matches, index 3 if the URL pattern matches. Any other word is
        returned unchanged.
    """
    try:
        int(word)
    except ValueError:
        # Not a number: check for a mention first, then for a URL.
        if AtMentionRegex.findall(word):
            return SPECIAL_TOKENS[2]
        if urlRegex.findall(word):
            return SPECIAL_TOKENS[3]
        return word
    return SPECIAL_TOKENS[4]
|
172 |
+
|
173 |
+
def process_word(word):
    """ Applies shorten_word and then detect_special_tokens to word and
        returns the result.
    """
    return detect_special_tokens(shorten_word(word))
|
179 |
+
|
180 |
+
def remove_control_chars(text):
    """ Returns text with every match of CONTROL_CHAR_REGEX removed. """
    stripped = CONTROL_CHAR_REGEX.sub('', text)
    return stripped
|
182 |
+
|
183 |
+
def convert_nonbreaking_space(text):
    """ Replaces non-breaking spaces by a regular space.

    Ugly hack: handles the raw characters as well as their singly and doubly
    backslash-escaped spellings, no matter how badly the input was encoded.
    """
    # Longest spellings first, so an escaped form is consumed as a whole.
    spellings = ('\\\\xc2', '\\xc2', '\xc2', '\\\\xa0', '\\xa0', '\xa0')
    converted = text
    for spelling in spellings:
        converted = converted.replace(spelling, ' ')
    return converted
|
188 |
+
|
189 |
+
def convert_linebreaks(text):
    """ Replaces line breaks by SPECIAL_TOKENS[5] surrounded by spaces.

    Ugly hack: handles raw newlines / carriage returns, their singly and
    doubly backslash-escaped spellings, and the literal '<br>'.
    """
    # space around to ensure proper tokenization
    replacement = ' ' + SPECIAL_TOKENS[5] + ' '
    converted = text
    for spelling in ('\\\\n', '\\n', '\n', '\\\\r', '\\r', '\r', '<br>'):
        converted = converted.replace(spelling, replacement)
    return converted
|
torchmoji/finetuning.py
ADDED
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Finetuning functions for doing transfer learning to new datasets.
|
3 |
+
"""
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import uuid
|
7 |
+
from time import sleep
|
8 |
+
from io import open
|
9 |
+
|
10 |
+
import math
|
11 |
+
import pickle
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.optim as optim
|
17 |
+
from sklearn.metrics import accuracy_score
|
18 |
+
from torch.autograd import Variable
|
19 |
+
from torch.utils.data import Dataset, DataLoader
|
20 |
+
from torch.utils.data.sampler import BatchSampler, SequentialSampler
|
21 |
+
from torch.nn.utils import clip_grad_norm
|
22 |
+
|
23 |
+
from sklearn.metrics import f1_score
|
24 |
+
|
25 |
+
from torchmoji.global_variables import (FINETUNING_METHODS,
|
26 |
+
FINETUNING_METRICS,
|
27 |
+
WEIGHTS_DIR)
|
28 |
+
from torchmoji.tokenizer import tokenize
|
29 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
30 |
+
|
31 |
+
# Python 2/3 compatibility shim: the bare name lookup below raises NameError
# on Python 3, where `unicode` does not exist. In that case `unicode` is
# aliased to `str`. IS_PYTHON2 records which branch was taken.
try:
    unicode
    IS_PYTHON2 = True
except NameError:
    unicode = str
    IS_PYTHON2 = False
|
37 |
+
|
38 |
+
|
39 |
+
def load_benchmark(path, vocab, extend_with=0):
    """ Loads the given benchmark dataset.

        Tokenizes the texts using the provided vocabulary, extending it with
        words from the training dataset if extend_with > 0. Splits them into
        three lists: training, validation and testing (in that order).

        Also calculates the maximum length of the texts and the
        suggested batch_size.

    # Arguments:
        path: Path to the dataset to be loaded (a pickle file).
        vocab: Vocabulary to be used for tokenizing texts.
        extend_with: If > 0, the vocabulary will be extended with up to
            extend_with tokens from the training set before tokenizing.

    # Returns:
        A dictionary with the following fields:
            texts: List of three lists, containing tokenized inputs for
                training, validation and testing (in that order).
            labels: List of three lists, containing labels for training,
                validation and testing (in that order).
            added: Number of tokens added to the vocabulary.
            batch_size: Batch size.
            maxlen: Maximum length of an input.
    """
    # Pre-processing dataset
    # NOTE(security): unpickling can execute arbitrary code. Only load
    # benchmark files that come from a trusted source.
    with open(path, 'rb') as dataset:
        if IS_PYTHON2:
            data = pickle.load(dataset)
        else:
            data = pickle.load(dataset, fix_imports=True)

    # Decode data
    # Byte strings are decoded explicitly: on Python 3, str() of a bytes
    # object yields its "b'...'" repr rather than the decoded text.
    texts = [x.decode('utf-8') if isinstance(x, bytes) else unicode(x)
             for x in data['texts']]

    # Extract labels
    labels = [x['label'] for x in data['info']]

    batch_size, maxlen = calculate_batchsize_maxlen(texts)

    st = SentenceTokenizer(vocab, maxlen)

    # Split up dataset. Extend the existing vocabulary with up to extend_with
    # tokens from the training dataset.
    texts, labels, added = st.split_train_val_test(texts,
                                                   labels,
                                                   [data['train_ind'],
                                                    data['val_ind'],
                                                    data['test_ind']],
                                                   extend_with=extend_with)
    return {'texts': texts,
            'labels': labels,
            'added': added,
            'batch_size': batch_size,
            'maxlen': maxlen}
|
98 |
+
|
99 |
+
|
100 |
+
def calculate_batchsize_maxlen(texts):
    """ Derives a maximum sequence length and a batch size from the texts.

        maxlen is the 80th percentile of the tokenized text lengths, rounded
        up to the nearest multiple of ten. The batch size is 250 for
        maxlen <= 100 and 50 otherwise.

    # Arguments:
        texts: List of inputs.

    # Returns:
        Batch size,
        max length
    """
    # Calculate max length of sequences considered
    token_counts = [len(tokenize(text)) for text in texts]
    percentile_80 = np.percentile(token_counts, 80.0)
    maxlen = int(math.ceil(percentile_80 / 10.0)) * 10

    # Adjust batch_size accordingly to prevent GPU overflow
    if maxlen <= 100:
        batch_size = 250
    else:
        batch_size = 50
    return batch_size, maxlen
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None):
    """ Freezes every direct child module of the model that has parameters,
        except for the ones explicitly selected to stay trainable.

    # Arguments:
        model: Model whose layers should be modified.
        unfrozen_types: List of layer types which shouldn't be frozen
            (matched as substrings of str(module)).
        unfrozen_keyword: Name keywords of layers that shouldn't be frozen
            (case-insensitive substring of the child's name).

    # Returns:
        Model with the selected layers frozen.
    """
    # Children without any parameter have nothing to freeze.
    with_params = [(child_name, child)
                   for child_name, child in model.named_children()
                   if len(list(child.parameters())) != 0]

    for child_name, child in with_params:
        keep_trainable = any(typ in str(child) for typ in unfrozen_types)
        if not keep_trainable:
            keep_trainable = (unfrozen_keyword is not None and
                              unfrozen_keyword.lower() in child_name.lower())
        change_trainable(child, keep_trainable, verbose=False)
    return model
|
142 |
+
|
143 |
+
|
144 |
+
def change_trainable(module, trainable, verbose=False):
    """ Sets requires_grad of every parameter of the module.

    # Arguments:
        module: Module to be modified.
        trainable: True to unfreeze the parameters, False to freeze them.
        verbose: Verbosity flag.
    """
    if verbose:
        print('Changing MODULE', module, 'to trainable =', trainable)

    for param_name, parameter in module.named_parameters():
        if verbose:
            print('Setting weight', param_name, 'to trainable =', trainable)
        parameter.requires_grad = trainable

    if verbose:
        print("{} {}".format('Unfroze' if trainable else 'Froze', module))
|
161 |
+
|
162 |
+
|
163 |
+
def find_f1_threshold(model, val_gen, test_gen, average='binary'):
    """ Choose a threshold for F1 based on the validation dataset
        (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
        for details on why to find another threshold than simply 0.5)

    # Arguments:
        model: pyTorch model
        val_gen: Validation set dataloader.
        test_gen: Testing set dataloader.
        average: Averaging mode passed on to sklearn's f1_score.

    # Returns:
        F1 score for the test data and
        the corresponding F1 threshold (selected on the validation data)
    """
    def _as_array(batch):
        # Accepts torch tensors as well as anything numpy can convert.
        if hasattr(batch, 'detach'):
            batch = batch.detach().cpu().numpy()
        return np.atleast_1d(np.asarray(batch))

    def _predict(gen):
        # Runs the model over a dataloader and merges all batches into one
        # (labels, predictions) pair of arrays. A trailing axis of size 1
        # is dropped so that both arrays line up.
        merged = []
        pairs = [(_as_array(y), _as_array(model(X))) for X, y in gen]
        for batches in zip(*pairs):
            stacked = np.concatenate(batches, axis=0)
            if stacked.ndim > 1 and stacked.shape[-1] == 1:
                stacked = stacked[..., 0]
            merged.append(stacked)
        return merged

    thresholds = np.arange(0.01, 0.5, step=0.01)

    model.eval()
    y_val, y_pred_val = _predict(val_gen)
    # The test predictions must come from the test loader (they were
    # previously unpacked from the validation outputs by mistake).
    y_test, y_pred_test = _predict(test_gen)

    # NOTE(review): the thresholds span (0, 0.5), which suggests the model
    # output is expected to be a probability; elsewhere in this file binary
    # outputs are compared against 0 like logits -- confirm which applies.
    f1_scores = [f1_score(y_val, y_pred_val > t, average=average)
                 for t in thresholds]

    best_t = thresholds[np.argmax(f1_scores)]
    f1_test = f1_score(y_test, y_pred_test > best_t, average=average)
    return f1_test, best_t
|
196 |
+
|
197 |
+
|
198 |
+
def finetune(model, texts, labels, nb_classes, batch_size, method,
             metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
             verbose=1):
    """ Sets up data loaders, loss and optimizer for the chosen finetuning
        method, then finetunes the given pytorch model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py.
        metric: Evaluation metric to be used. For available metrics, see
            FINETUNING_METRICS in global_variables.py.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the provided metric.

    # Raises:
        ValueError: If method or metric is not one of the available options.
    """

    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (finetune): Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))
    if metric not in FINETUNING_METRICS:
        raise ValueError('ERROR (finetune): Invalid metric parameter. '
                         'Available options: {}'.format(FINETUNING_METRICS))

    # Only the training loader uses the extended batch sampler.
    train_gen = get_data_loader(texts[0], labels[0], batch_size,
                                extended_batch_sampler=True, epoch_size=epoch_size)
    val_gen = get_data_loader(texts[1], labels[1], batch_size,
                              extended_batch_sampler=False)
    test_gen = get_data_loader(texts[2], labels[2], batch_size,
                               extended_batch_sampler=False)

    # A random suffix keeps concurrent runs from sharing a checkpoint file.
    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ('last', 'new'):
        lr = 0.001
    elif method in ('full', 'chain-thaw'):
        lr = 0.0001

    if nb_classes <= 2:
        loss_op = nn.BCEWithLogitsLoss()
    else:
        loss_op = nn.CrossEntropyLoss()

    # Define optimizer, for chain-thaw we define it later (after freezing)
    if method == 'last':
        # Freeze layers if using last
        model = freeze_layers(model, unfrozen_keyword='output_layer')
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ('full', 'new'):
        # Add L2 regulation on embeddings only
        embed_ids = [id(p) for p in model.embed.parameters()]
        output_ids = [id(p) for p in model.output_layer.parameters()]
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        base_params = [p for p in trainable_params
                       if id(p) not in embed_ids and id(p) not in output_ids]
        embed_params = [p for p in trainable_params if id(p) in embed_ids]
        output_layer_params = [p for p in trainable_params if id(p) in output_ids]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_params, 'weight_decay': embed_l2},
            {'params': output_layer_params, 'lr': 0.001},
        ], lr=lr)

    # Training
    if verbose:
        print('Method:  {}'.format(method))
        print('Metric:  {}'.format(metric))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
                            embed_l2=embed_l2, evaluate=metric, verbose=verbose)
    else:
        result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path,
                                evaluate=metric, verbose=verbose)
    return model, result
|
284 |
+
|
285 |
+
|
286 |
+
def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
                   nb_epochs, checkpoint_path, patience=5, evaluate='acc',
                   verbose=2):
    """ Trains the model with fit_model, reloads the checkpointed weights
        and evaluates the result on the test set.

    # Arguments:
        model: Model to be finetuned.
        loss_op: Loss function passed on to fit_model.
        optim_op: Optimizer passed on to fit_model.
        train_gen: Training data loader.
        val_gen: Validation data loader.
        test_gen: Testing data loader.
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Patience for callback methods.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        Score of the trained model, ONLY if 'evaluate' is 'acc' or
        'weighted_f1' (None otherwise).
    """
    def _score():
        # Evaluates the model in its current state with the chosen metric.
        if evaluate == 'acc':
            return evaluate_using_acc(model, test_gen)
        return evaluate_using_weighted_f1(model, test_gen, val_gen)

    if verbose:
        trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
        print("Trainable weights: {}".format(trainable_names))
        print("Training...")
        if evaluate in ('acc', 'weighted_f1'):
            print("Evaluation on test set prior training:", _score())

    fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience)

    # Reload the best weights found to avoid overfitting
    # Wait a bit to allow proper closing of weights file
    sleep(1)
    model.load_state_dict(torch.load(checkpoint_path))
    if verbose >= 2:
        print("Loaded weights from {}".format(checkpoint_path))

    if evaluate in ('acc', 'weighted_f1'):
        return _score()
|
330 |
+
|
331 |
+
|
332 |
+
def evaluate_using_weighted_f1(model, test_gen, val_gen):
    """ Evaluation function using weighted F1 score.

        The decision threshold is selected on the validation data and the
        score is then computed on the test data (see find_f1_threshold).

    # Arguments:
        model: Model to be evaluated.
        test_gen: Testing set dataloader.
        val_gen: Validation set dataloader.

    # Returns:
        Weighted F1 score of the given model.
    """
    # Evaluate on test and val data
    # Keyword arguments keep the two loaders from being mixed up (this
    # function takes test first, find_f1_threshold takes validation first).
    # 'weighted' is the name scikit-learn uses for this averaging mode.
    f1_test, _ = find_f1_threshold(model, val_gen=val_gen, test_gen=test_gen,
                                   average='weighted')
    return f1_test
|
349 |
+
|
350 |
+
|
351 |
+
def evaluate_using_acc(model, test_gen):
    """ Evaluation function using accuracy.

        Accuracy is computed per batch and the per-batch values are averaged
        with equal weight.

    # Arguments:
        model: Model to be evaluated.
        test_gen: Testing data iterator (DataLoader)

    # Returns:
        Accuracy of the given model.
    """

    # Validate on test_data
    model.eval()
    batch_accuracies = []
    for x, y in test_gen:
        outs = model(x)
        if model.nb_classes > 2:
            # Multi-class: predicted class is the index of the largest output.
            pred = torch.max(outs, 1)[1]
            batch_acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
        else:
            # Binary: a non-negative output counts as the positive class.
            # NOTE(review): assumes pred and y have matching shapes -- confirm
            pred = (outs >= 0).long()
            batch_acc = (pred == y).double().sum() / len(pred)
        batch_accuracies.append(batch_acc)
    return np.mean(batch_accuracies)
|
376 |
+
|
377 |
+
|
378 |
+
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
               patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
    """ Finetunes given model using chain-thaw and evaluates it on the test data.

    # Arguments:
        model: Model to be finetuned.
        train_gen: Training data iterator.
        val_gen: Validation data iterator.
        test_gen: Testing data iterator.
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        loss_op: Loss function to be used during training.
        patience: Early-stopping patience, forwarded to the training loop.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: Weight decay applied to the embedding parameters.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        Accuracy or weighted F1 of the finetuned model, depending on
        `evaluate`. None for any other value of `evaluate`.
    """
    if verbose:
        print('Training..')

    # Train using chain-thaw
    train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr, next_lr, embed_l2, verbose)

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    if evaluate == 'weighted_f1':
        return evaluate_using_weighted_f1(model, test_gen, val_gen)
    return None
|
414 |
+
|
415 |
+
|
416 |
+
def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
    """ Finetunes model using the chain-thaw method.

    This is done as follows:
    1) Freeze every layer except the last (output_layer) layer and train it.
    2) Freeze every layer except the first layer and train it.
    3) Freeze every layer except the second etc., until the second last layer.
    4) Unfreeze all layers and train entire model.

    After every step the weights stored at checkpoint_path are loaded back
    into the model.

    # Arguments:
        model: Model to be trained. Must expose an `embed` sub-module.
        train_gen: Training data iterator.
        val_gen: Validation data iterator.
        loss_op: Loss function to be used.
        patience: Early-stopping patience, forwarded to fit_model.
        nb_epochs: Number of epochs per step.
        checkpoint_path: Where weight checkpoints should be saved.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: Weight decay applied to the embedding parameters.
        verbose: Verbosity flag.
    """
    # Child modules owning at least one parameter are the trainable layers
    trainable_layers = [child for child in model.children()
                        if any(True for _ in child.parameters())]

    # Training order: last layer first, the others in model order, and
    # finally None, which stands for "all layers at once"
    schedule = [trainable_layers.pop()] + trainable_layers + [None]

    for step, layer in enumerate(schedule):
        # Only the very first step uses the initial learning rate
        lr = initial_lr if step == 0 else next_lr
        train_everything = layer is None

        # Freeze all except current layer
        for candidate in schedule:
            if candidate is None:
                continue
            change_trainable(candidate,
                             trainable=(candidate == layer or train_everything),
                             verbose=False)

        # Verify we froze the right layers
        for child in model.children():
            assert all(p.requires_grad == (child == layer) for p in child.parameters()) or layer is None

        if verbose:
            if train_everything:
                print('Finetuning all layers')
            else:
                print('Finetuning {}'.format(layer))

        # Embedding parameters get their own group so that weight decay
        # is applied to them only
        embed_ids = set(id(p) for p in model.embed.parameters())
        base_params = []
        embed_parameters = []
        for param in model.parameters():
            if not param.requires_grad:
                continue
            if id(param) in embed_ids:
                embed_parameters.append(param)
            else:
                base_params.append(param)

        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)

        fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
                  checkpoint_path, patience)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_path))
        if verbose >= 2:
            print("Loaded weights from {}".format(checkpoint_path))
|
492 |
+
|
493 |
+
|
494 |
+
def calc_loss(loss_op, pred, yv):
    """ Applies loss_op to squeezed predictions and targets.

    # Arguments:
        loss_op: Loss module. For nn.CrossEntropyLoss (or a subclass) the
            targets are passed unchanged; for any other loss they are cast
            to float first.
        pred: Model output.
        yv: Target labels.

    # Returns:
        The value returned by loss_op.
    """
    targets = yv.squeeze()
    # isinstance rather than an exact type check, so that subclasses of
    # CrossEntropyLoss also receive their targets uncast.
    if not isinstance(loss_op, nn.CrossEntropyLoss):
        targets = targets.float()
    return loss_op(pred.squeeze(), targets)
|
499 |
+
|
500 |
+
|
501 |
+
def _mean_val_loss(model, loss_op, val_gen):
    """ Returns the mean of the per-batch losses of model over val_gen. """
    # Validation only: gradients are not needed
    with torch.no_grad():
        return np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).item()
                        for xv, yv in val_gen])


def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
              checkpoint_path, patience):
    """ Analog to Keras fit_generator function.

    Trains the model and writes its weights to checkpoint_path every time
    the validation loss improves. Progress is printed unconditionally.

    # Arguments:
        model: Model to be finetuned.
        loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
        optim_op: optimization operation (Adam e.g.)
        train_gen: Training data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)
        epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Number of epochs without improvement of the validation
            loss after which training stops.

    # Returns:
        None. The best weights found are stored at checkpoint_path; the
        model itself is left with the weights of the last epoch run.
    """
    # Save original checkpoint
    torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    best_loss = _mean_val_loss(model, loss_op, val_gen)
    print("original val loss", best_loss)

    # NOTE(review): this counter is never reset after an improvement, so
    # `patience` bounds the total number of non-improving epochs rather than
    # consecutive ones (unlike Keras EarlyStopping). Kept as is.
    epoch_without_impr = 0
    for epoch in range(epochs):
        for i, (X_train, y_train) in enumerate(train_gen):
            X_train = Variable(X_train, requires_grad=False)
            y_train = Variable(y_train, requires_grad=False)
            model.train()
            optim_op.zero_grad()
            output = model(X_train)
            loss = calc_loss(loss_op, output, y_train)
            loss.backward()
            clip_grad_norm(model.parameters(), 1)
            optim_op.step()

            acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
            # .item() works for 0-d losses, where .numpy()[0] raises IndexError
            print("== Epoch", epoch, "step", i, "train loss", loss.item(), "train acc", acc)

        model.eval()
        acc = evaluate_using_acc(model, val_gen)
        print("val acc", acc)

        val_loss = _mean_val_loss(model, loss_op, val_gen)
        print("val loss", val_loss)

        if val_loss < best_loss:
            # Save checkpoint
            best_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print('Saving model at', checkpoint_path)
        elif val_loss >= best_loss:
            epoch_without_impr += 1
            print('No improvement over previous best loss: ', best_loss)

        # Early stopping
        if epoch_without_impr >= patience:
            break
|
563 |
+
|
564 |
+
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Returns a dataloader that enables larger epochs on small datasets and
        has upsampling functionality.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: If true, batches are drawn by a
            DeepMojiBatchSampler (random sampling, optional upsampling);
            otherwise the dataset is read sequentially, keeping the last
            incomplete batch.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed for the extended sampler.

    # Returns:
        DataLoader.
    """
    data = DeepMojiDataset(X_in, y_in)

    if not extended_batch_sampler:
        sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)
    else:
        sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)

    return DataLoader(data, batch_sampler=sampler, num_workers=0)
|
587 |
+
|
588 |
+
class DeepMojiDataset(Dataset):
    """ A simple Dataset class holding one (input, label) pair per sample.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.

    # __getitem__ output:
        (torch.LongTensor, torch.LongTensor)
    """
    def __init__(self, X_in, y_in):
        inputs = self._as_long_tensor(X_in)
        labels = self._as_long_tensor(y_in)

        # One chunk of size 1 per sample along the first dimension
        self.X_in = torch.split(inputs, 1, dim=0)
        self.y_in = torch.split(labels, 1, dim=0)

    @staticmethod
    def _as_long_tensor(data):
        # Anything that is not already a torch.LongTensor is assumed to be
        # a Numpy array
        if isinstance(data, torch.LongTensor):
            return data
        return torch.from_numpy(data.astype('int64')).long()

    def __len__(self):
        return len(self.X_in)

    def __getitem__(self, idx):
        sample = self.X_in[idx]
        label = self.y_in[idx]
        return sample.squeeze(), label.squeeze()
|
613 |
+
|
614 |
+
class DeepMojiBatchSampler(object):
    """A Batch sampler that enables larger epochs on small datasets and
        has upsampling functionality.

    The sample indices of the epoch are drawn once, at construction time,
    after seeding Numpy's global random number generator with `seed`.

    # Arguments:
        y_in: Labels of the dataset. With upsample=True this must be a
            1-dimensional object exposing .numpy() (e.g. a torch tensor).
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.

    # __iter__ output:
        iterator of index arrays (batches) into the dataset; the last batch
        may be smaller than batch_size.
    """

    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.upsample = upsample

        np.random.seed(seed)

        if upsample:
            # Should only be used on binary class problems
            assert len(y_in.shape) == 1
            labels = y_in.numpy()
            neg = np.where(labels == 0)[0]
            pos = np.where(labels == 1)[0]
            assert epoch_size % 2 == 0
            samples_pr_class = epoch_size // 2

            # Randomly sample observations in a balanced way
            sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
            sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
            concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

            # Shuffle to avoid labels being in specific order
            # (all negative then positive)
            p = np.random.permutation(len(concat_ind))
            self.sample_ind = concat_ind[p]

            label_dist = np.mean(labels[self.sample_ind])
            assert(label_dist > 0.45)
            assert(label_dist < 0.55)
        else:
            # Sample uniformly, with replacement, over the whole dataset
            ind = range(len(y_in))
            self.sample_ind = np.random.choice(ind, epoch_size, replace=True)

    def __iter__(self):
        # Hand-off data using batch_size. Iterate len(self) times so that the
        # last (maybe incomplete) batch is produced too and iteration agrees
        # with __len__.
        for i in range(len(self)):
            start = i * self.batch_size
            end = min(start + self.batch_size, self.epoch_size)
            yield self.sample_ind[start:end]

    def __len__(self):
        # Take care of the last (maybe incomplete) batch
        return (self.epoch_size + self.batch_size - 1) // self.batch_size
|
torchmoji/global_variables.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Global variables.
|
3 |
+
"""
|
4 |
+
import tempfile
|
5 |
+
from os.path import abspath, dirname
|
6 |
+
|
7 |
+
# The ordering of these special tokens matters
|
8 |
+
# blank tokens can be used for new purposes
|
9 |
+
# Tokenizer should be updated if special token prefix is changed
|
10 |
+
SPECIAL_PREFIX = 'CUSTOM_'
# Six named tokens first, followed by the spare BLANK_6 .. BLANK_9 tokens
SPECIAL_TOKENS = [SPECIAL_PREFIX + name
                  for name in ('MASK', 'UNKNOWN', 'AT', 'URL', 'NUMBER', 'BREAK')]
SPECIAL_TOKENS += [SPECIAL_PREFIX + 'BLANK_' + str(i) for i in range(6, 10)]

# Directory two levels above this file
ROOT_PATH = dirname(dirname(abspath(__file__)))
VOCAB_PATH = ROOT_PATH + '/model/vocabulary.json'
PRETRAINED_PATH = ROOT_PATH + '/model/pytorch_model.bin'

# A fresh temporary directory is created every time this module is imported
WEIGHTS_DIR = tempfile.mkdtemp()

NB_TOKENS = 50000
NB_EMOJI_CLASSES = 64
FINETUNING_METHODS = ['last', 'full', 'new', 'chain-thaw']
FINETUNING_METRICS = ['acc', 'weighted']
|
torchmoji/lstm.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Implement a pyTorch LSTM with hard sigmoid recurrent activation functions.
|
3 |
+
Adapted from the non-cuda variant of pyTorch LSTM at
|
4 |
+
https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
from __future__ import print_function, division
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from torch.nn import Module
|
12 |
+
from torch.nn.parameter import Parameter
|
13 |
+
from torch.nn.utils.rnn import PackedSequence
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
class LSTMHardSigmoid(Module):
    """ LSTM module evaluated through AutogradRNN instead of the built-in
        PyTorch kernels.

    Parameters follow the torch.nn.LSTM naming scheme
    (weight_ih_l{k}[_reverse], weight_hh_l{k}[_reverse],
    bias_ih_l{k}[_reverse], bias_hh_l{k}[_reverse]).

    # Arguments:
        input_size: Number of features of the input.
        hidden_size: Number of features of the hidden state.
        num_layers: Number of stacked recurrent layers.
        bias: If False, no bias parameters are registered.
        batch_first: If True, padded (non-packed) inputs and outputs are
            (batch, seq, feature) instead of (seq, batch, feature).
        dropout: Dropout probability applied to the output of every layer
            except the last one.
        bidirectional: If True, each layer also runs in the reverse direction.
    """

    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False):
        super(LSTMHardSigmoid, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        # The four gates are stored stacked in a single matrix
        gate_size = 4 * hidden_size

        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                layer_params = [Parameter(torch.Tensor(gate_size, layer_input_size)),
                                Parameter(torch.Tensor(gate_size, hidden_size))]
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    layer_params += [Parameter(torch.Tensor(gate_size)),
                                     Parameter(torch.Tensor(gate_size))]
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']

                suffix = '_reverse' if direction == 1 else ''
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this is a no-op since we don't use CUDA acceleration.
        """
        self._data_ptrs = []

    def _apply(self, fn):
        ret = super(LSTMHardSigmoid, self)._apply(fn)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        """Fills every parameter uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx=None):
        """Runs the LSTM over a padded tensor or a PackedSequence.

        # Arguments:
            input: Padded tensor or PackedSequence.
            hx: Optional (h_0, c_0) tuple; zeros are used when omitted.

        # Returns:
            (output, (h_n, c_n)); output is a PackedSequence when the input
            was one.
        """
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            # Keep the permutation indices: they are needed to restore the
            # original batch order of sequences packed with enforce_sorted=False
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            sorted_indices = unsorted_indices = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = input.new_zeros(self.num_layers * num_directions,
                                    max_batch_size,
                                    self.hidden_size)
            hx = (zeros, zeros)
        elif sorted_indices is not None:
            # Same convention as torch.nn.LSTM: a user-supplied state is in
            # original batch order and must follow the packing permutation
            hx = tuple(h.index_select(1, sorted_indices) for h in hx)

        # The flat-weight fast path of the original PyTorch code is never
        # taken here (flatten_parameters is a no-op), hence flat_weight=None.
        func = AutogradRNN(
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            batch_sizes=batch_sizes,
            dropout_state=self.dropout_state,
            flat_weight=None
        )
        output, hidden = func(input, self.all_weights, hx)
        if is_packed:
            output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            if unsorted_indices is not None:
                # Return the final states in the original batch order as well
                hidden = tuple(h.index_select(1, unsorted_indices) for h in hidden)
        return output, hidden

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def __setstate__(self, d):
        super(LSTMHardSigmoid, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        # Nothing to rebuild when the state already stores parameter names
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        """Parameters grouped per (layer, direction), in registration order."""
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
|
154 |
+
|
155 |
+
def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):
    """
    Builds the forward function of a (stacked, optionally bidirectional)
    LSTM made of LSTMCell steps.

    input_size, hidden_size, dropout_state and flat_weight are accepted but
    not used by this implementation.

    Returns a function forward(input, weight, hidden) -> (output, next_hidden).
    """
    # Packed input (batch_sizes given) needs the variable-length recurrences
    if batch_sizes is None:
        make_direction = Recurrent
    else:
        make_direction = variable_recurrent_factory(batch_sizes)

    directions = [make_direction(LSTMCell)]
    if bidirectional:
        directions.append(make_direction(LSTMCell, reverse=True))

    stacked = StackedRNN(tuple(directions),
                         num_layers,
                         True,
                         dropout=dropout,
                         train=train)

    def forward(input, weight, hidden):
        # Only padded input can be batch-first; packed data is left untouched
        swap_axes = batch_first and batch_sizes is None
        if swap_axes:
            input = input.transpose(0, 1)

        nexth, output = stacked(input, hidden, weight)

        if swap_axes:
            output = output.transpose(0, 1)

        return output, nexth

    return forward
|
189 |
+
|
190 |
+
def Recurrent(inner, reverse=False):
    """
    Wraps the single-step function `inner` into a function that unrolls it
    over the first dimension of its input, backwards when reverse is True.

    The returned forward(input, hidden, weight) gives back the last hidden
    state and the per-step outputs, ordered like the input steps.
    """
    def forward(input, hidden, weight):
        seq_len = input.size(0)
        order = range(seq_len - 1, -1, -1) if reverse else range(seq_len)

        per_step = []
        for t in order:
            hidden = inner(input[t], hidden, *weight)
            # hack to handle LSTM: its state is a (h, c) tuple, the output is h
            if isinstance(hidden, tuple):
                per_step.append(hidden[0])
            else:
                per_step.append(hidden)

        if reverse:
            # Put the outputs back in input order
            per_step.reverse()
        output = torch.cat(per_step, 0).view(seq_len, *per_step[0].size())

        return hidden, output

    return forward
|
206 |
+
|
207 |
+
|
208 |
+
def variable_recurrent_factory(batch_sizes):
    """
    Returns a factory with the same signature as Recurrent that builds the
    packed-sequence recurrences for the given batch_sizes.
    """
    def fac(inner, reverse=False):
        build = VariableRecurrentReverse if reverse else VariableRecurrent
        return build(batch_sizes, inner)
    return fac
|
215 |
+
|
216 |
+
def VariableRecurrent(batch_sizes, inner):
    """
    Forward recurrence over packed data: step t consumes the next
    batch_sizes[t] rows of the input.

    The returned forward(input, hidden, weight) gives back the final hidden
    state (first dimension of size batch_sizes[0]) and the concatenated
    per-step outputs.
    """
    def forward(input, hidden, weight):
        was_flat = not isinstance(hidden, tuple)
        if was_flat:
            hidden = (hidden,)

        step_outputs = []
        finished = []  # states of the sequences that have already ended
        offset = 0
        prev_size = batch_sizes[0]
        for size in batch_sizes:
            chunk = input[offset:offset + size]
            offset += size

            # The batch shrinks: set aside the states of the trailing rows
            shrink = prev_size - size
            if shrink > 0:
                finished.append(tuple(h[-shrink:] for h in hidden))
                hidden = tuple(h[:-shrink] for h in hidden)
            prev_size = size

            if was_flat:
                hidden = (inner(chunk, hidden[0], *weight),)
            else:
                hidden = inner(chunk, hidden, *weight)

            step_outputs.append(hidden[0])

        # Rows still active come first, then those that ended, latest first
        finished.append(hidden)
        finished.reverse()

        hidden = tuple(torch.cat(parts, 0) for parts in zip(*finished))
        assert hidden[0].size(0) == batch_sizes[0]
        if was_flat:
            hidden = hidden[0]

        return hidden, torch.cat(step_outputs, 0)

    return forward
|
253 |
+
|
254 |
+
|
255 |
+
def VariableRecurrentReverse(batch_sizes, inner):
    """
    Backward recurrence over packed data: the steps are visited from the
    last one to the first one, the batch growing along the way.

    The returned forward(input, hidden, weight) gives back the final hidden
    state and the per-step outputs concatenated in forward step order.
    """
    def forward(input, hidden, weight):
        was_flat = not isinstance(hidden, tuple)
        if was_flat:
            hidden = (hidden,)
        start_state = hidden

        # Start with the rows that are still present at the last step
        prev_size = batch_sizes[-1]
        hidden = tuple(h[:prev_size] for h in hidden)

        step_outputs = []
        offset = input.size(0)
        for size in reversed(batch_sizes):
            # Rows joining at this step start from their initial state
            growth = size - prev_size
            if growth > 0:
                hidden = tuple(torch.cat((h, h0[prev_size:size]), 0)
                               for h, h0 in zip(hidden, start_state))
            prev_size = size

            chunk = input[offset - size:offset]
            offset -= size

            if was_flat:
                hidden = (inner(chunk, hidden[0], *weight),)
            else:
                hidden = inner(chunk, hidden, *weight)
            step_outputs.append(hidden[0])

        step_outputs.reverse()
        output = torch.cat(step_outputs, 0)
        if was_flat:
            hidden = hidden[0]
        return hidden, output

    return forward
|
288 |
+
|
289 |
+
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
    """
    Stacks num_layers layers, each made of the recurrences in `inners`
    (one per direction), feeding the concatenated outputs of a layer to the
    next one. Dropout is applied between layers, not after the last one.

    The returned forward(input, hidden, weight) gives back the stacked final
    states (a (h, c) pair when lstm is True) and the output of the last layer.
    """
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert(len(weight) == total_layers)

        if lstm:
            # (h, c) stacks -> one (h_l, c_l) pair per layer/direction
            hidden = list(zip(*hidden))

        final_states = []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            first_slot = layer_idx * num_directions
            direction_outputs = []
            for direction, inner in enumerate(inners):
                slot = first_slot + direction
                hy, output = inner(input, hidden[slot], weight[slot])
                final_states.append(hy)
                direction_outputs.append(output)

            # Concatenate the directions along the feature (last) dimension
            input = torch.cat(direction_outputs, input.dim() - 1)

            if dropout != 0 and layer_idx < last_layer:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if not lstm:
            stacked = torch.cat(final_states, 0).view(
                total_layers, *final_states[0].size())
            return stacked, input

        all_h, all_c = zip(*final_states)
        stacked = (
            torch.cat(all_h, 0).view(total_layers, *all_h[0].size()),
            torch.cat(all_c, 0).view(total_layers, *all_c[0].size())
        )
        return stacked, input

    return forward
|
328 |
+
|
329 |
+
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.

    `hidden` is the (h, c) pair of the previous step; the new (h, c) pair is
    returned.
    """
    prev_h, prev_c = hidden
    pre_activations = F.linear(input, w_ih, b_ih) + F.linear(prev_h, w_hh, b_hh)

    # The four gates are stored side by side along dimension 1
    in_pre, forget_pre, cell_pre, out_pre = pre_activations.chunk(4, 1)

    in_gate = hard_sigmoid(in_pre)
    forget_gate = hard_sigmoid(forget_pre)
    candidate = F.tanh(cell_pre)
    out_gate = hard_sigmoid(out_pre)

    cy = (forget_gate * prev_c) + (in_gate * candidate)
    hy = out_gate * F.tanh(cy)

    return hy, cy
|
347 |
+
|
348 |
+
def hard_sigmoid(x):
    """
    Computes element-wise hard sigmoid of x, i.e. 0.2 * x + 0.5 limited
    to the [0, 1] range.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    scaled = (0.2 * x) + 0.5
    # Upper bound: thresholding the negated values at -1 caps `scaled` at 1
    negated_capped = F.threshold(-scaled, -1, -1)
    # Lower bound: anything that is not strictly positive becomes 0
    return F.threshold(-negated_capped, 0, 0)
|
torchmoji/model_def.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Model definition functions and weight loading.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from __future__ import print_function, division, unicode_literals
|
6 |
+
|
7 |
+
from os.path import exists
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.autograd import Variable
|
12 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
|
13 |
+
|
14 |
+
from torchmoji.lstm import LSTMHardSigmoid
|
15 |
+
from torchmoji.attlayer import Attention
|
16 |
+
from torchmoji.global_variables import NB_TOKENS, NB_EMOJI_CLASSES
|
17 |
+
|
18 |
+
|
19 |
+
def torchmoji_feature_encoding(weight_path, return_attention=False):
    """ Loads the pretrained torchMoji model for extracting features
        from the penultimate feature layer. In this way, it transforms
        the text into its emotional encoding.

    The weights of the 'output_layer' are not loaded.

    # Arguments:
        weight_path: Path to model weights to be loaded.
        return_attention: If true, output will include weight of each input token
            used for the prediction

    # Returns:
        Pretrained model for encoding text into feature vectors.
    """
    encoder = TorchMoji(nb_classes=None,
                        nb_tokens=NB_TOKENS,
                        feature_output=True,
                        return_attention=return_attention)
    skipped = ['output_layer']
    load_specific_weights(encoder, weight_path, exclude_names=skipped)
    return encoder
|
39 |
+
|
40 |
+
|
41 |
+
def torchmoji_emojis(weight_path, return_attention=False):
    """ Loads the pretrained torchMoji model with its full output layer,
        i.e. a model with NB_EMOJI_CLASSES output classes.

    # Arguments:
        weight_path: Path to model weights to be loaded. The file must hold
            a complete state dict for the model (output layer included).
        return_attention: If true, output will include weight of each input token
            used for the prediction

    # Returns:
        Pretrained model with NB_EMOJI_CLASSES output classes.
    """

    model = TorchMoji(nb_classes=NB_EMOJI_CLASSES,
                      nb_tokens=NB_TOKENS,
                      return_attention=return_attention)
    # map_location lets a checkpoint saved on a GPU be read on a CPU-only
    # host; load_state_dict copies the values into the model's own tensors.
    # NOTE(review): torch.load unpickles the file - only load trusted weights.
    model.load_state_dict(torch.load(weight_path, map_location='cpu'))
    return model
|
60 |
+
|
61 |
+
|
62 |
+
def torchmoji_transfer(nb_classes, weight_path=None, extend_embedding=0,
                       embed_dropout_rate=0.1, final_dropout_rate=0.5):
    """ Creates a torchMoji model set up for finetuning/transfer learning.

    The model outputs logits. When pretrained weights are given, everything
    except the output (softmax) layer is loaded from them.

    Note that if you are planning to use class average F1 for evaluation,
    nb_classes should be set to 2 instead of the actual number of classes
    in the dataset, since binary classification will be performed on each
    class individually.

    Note that for the 'new' method, weight_path should be left as None.

    # Arguments:
        nb_classes: Number of classes in the dataset.
        weight_path: Path to model weights to be loaded.
        extend_embedding: Number of tokens that have been added to the
            vocabulary on top of NB_TOKENS. If this number is larger than 0,
            the embedding layer's dimensions are adjusted accordingly, with the
            additional weights being set to random values.
        embed_dropout_rate: Dropout rate for the embedding layer.
        final_dropout_rate: Dropout rate for the final Softmax layer.

    # Returns:
        Model with the given parameters.
    """
    vocabulary_size = NB_TOKENS + extend_embedding
    transfer_model = TorchMoji(nb_classes=nb_classes,
                               nb_tokens=vocabulary_size,
                               embed_dropout_rate=embed_dropout_rate,
                               final_dropout_rate=final_dropout_rate,
                               output_logits=True)

    # Without a checkpoint the model keeps its random initialisation.
    if weight_path is None:
        return transfer_model

    load_specific_weights(transfer_model, weight_path,
                          exclude_names=['output_layer'],
                          extend_embedding=extend_embedding)
    return transfer_model
|
98 |
+
|
99 |
+
|
100 |
+
class TorchMoji(nn.Module):
    """ Embedding -> two bidirectional LSTMs -> attention -> optional classifier.

    The embedding output and the outputs of both LSTM layers are concatenated
    before the attention layer (skip-connections), which is why the attention
    input has 4 * hidden_size + embedding_dim features.
    """
    def __init__(self, nb_classes, nb_tokens, feature_output=False, output_logits=False,
                 embed_dropout_rate=0, final_dropout_rate=0, return_attention=False):
        """
        torchMoji model.
        IMPORTANT: The model is loaded in evaluation mode by default (self.eval())

        # Arguments:
            nb_classes: Number of classes in the dataset.
            nb_tokens: Number of tokens in the dataset (i.e. vocabulary size).
            feature_output: If True the model returns the penultimate
                            feature vector rather than Softmax probabilities
                            (defaults to False).
            output_logits:  If True the model returns logits rather than probabilities
                            (defaults to False).
            embed_dropout_rate: Dropout rate for the embedding layer.
            final_dropout_rate: Dropout rate for the final Softmax layer.
            return_attention: If True the model also returns attention weights over the sentence
                              (defaults to False).
        """
        super(TorchMoji, self).__init__()

        embedding_dim = 256
        hidden_size = 512
        # Two bidirectional LSTM outputs (2 * hidden_size each) plus the embedding.
        attention_size = 4 * hidden_size + embedding_dim

        self.feature_output = feature_output
        self.embed_dropout_rate = embed_dropout_rate
        self.final_dropout_rate = final_dropout_rate
        self.return_attention = return_attention
        self.hidden_size = hidden_size
        self.output_logits = output_logits
        self.nb_classes = nb_classes

        self.add_module('embed', nn.Embedding(nb_tokens, embedding_dim))
        # dropout2D: embedding channels are dropped out instead of words
        # many examples in the datasets contain so few words that losing one or more words can alter the emotions completely
        self.add_module('embed_dropout', nn.Dropout2d(embed_dropout_rate))
        self.add_module('lstm_0', LSTMHardSigmoid(embedding_dim, hidden_size, batch_first=True, bidirectional=True))
        self.add_module('lstm_1', LSTMHardSigmoid(hidden_size*2, hidden_size, batch_first=True, bidirectional=True))
        self.add_module('attention_layer', Attention(attention_size=attention_size, return_attention=return_attention))
        if not feature_output:
            self.add_module('final_dropout', nn.Dropout(final_dropout_rate))
            # Binary problems (nb_classes <= 2) use a single output unit.
            if output_logits:
                self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1)))
            else:
                # NOTE(review): nn.Softmax() without `dim` relies on PyTorch's
                # deprecated implicit dimension choice; left unchanged here.
                self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1),
                                                              nn.Softmax() if self.nb_classes > 2 else nn.Sigmoid()))
        self.init_weights()
        # Put model in evaluation mode by default
        self.eval()

    def init_weights(self):
        """
        Here we reproduce Keras default initialization weights for consistency with Keras version
        """
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        # The in-place (trailing underscore) initialisers replace the
        # deprecated nn.init.uniform / xavier_uniform / orthogonal / constant.
        nn.init.uniform_(self.embed.weight.data, a=-0.5, b=0.5)
        for t in ih:
            nn.init.xavier_uniform_(t)
        for t in hh:
            nn.init.orthogonal_(t)
        for t in b:
            nn.init.constant_(t, 0)
        if not self.feature_output:
            nn.init.xavier_uniform_(self.output_layer[0].weight.data)

    def forward(self, input_seqs):
        """ Forward pass.

        # Arguments:
            input_seqs: Can be one of Numpy array, Torch.LongTensor, Torch.Variable, Torch.PackedSequence.
                Unpacked inputs are expected to be zero-padded token ids with at
                least one non-zero token per row (the length of a row is taken
                from its last non-zero position).

        # Return:
            Same format as input format (except for PackedSequence returned as Variable).
            If the model was built with return_attention=True, a tuple
            (outputs, attention weights) is returned instead.
        """
        # Check if we have Torch.LongTensor inputs or not Torch.Variable (assume Numpy array in this case), take note to return same format
        return_numpy = False
        return_tensor = False
        # A PackedSequence is neither a tensor nor a Numpy array, so it must be
        # recognised first; otherwise it would be sent down the Numpy branch.
        is_packed = isinstance(input_seqs, PackedSequence)
        if is_packed:
            pass
        elif isinstance(input_seqs, (torch.LongTensor, torch.cuda.LongTensor)):
            input_seqs = Variable(input_seqs)
            return_tensor = True
        elif not isinstance(input_seqs, Variable):
            input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())
            return_numpy = True

        # If we don't have a packed inputs, let's pack it
        reorder_output = False
        if not is_packed:
            batch_size = input_seqs.size()[0]

            # Reorder batch by sequence length
            input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].data.nonzero()) + 1 for i in range(input_seqs.size()[0])])
            input_lengths, perm_idx = input_lengths.sort(0, descending=True)
            input_seqs = input_seqs[perm_idx][:, :input_lengths.max()]

            # Pack sequence and work on data tensor to reduce embeddings/dropout computations
            packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True)
            reorder_output = True
        else:
            # The first entry of batch_sizes is the number of sequences in the batch.
            # NOTE(review): assumes the PackedSequence was packed from a
            # length-sorted batch; any sorted_indices it carries are not kept.
            batch_size = int(input_seqs.batch_sizes[0])
            # batch_sizes are not sequence lengths; the real lengths are
            # recovered from pad_packed_sequence further down.
            input_lengths = None
            packed_input = input_seqs

        ho = self.lstm_0.weight_hh_l0.data.new(2, batch_size, self.hidden_size).zero_()
        co = self.lstm_0.weight_hh_l0.data.new(2, batch_size, self.hidden_size).zero_()
        hidden = (Variable(ho, requires_grad=False), Variable(co, requires_grad=False))

        # Embed with an activation function to bound the values of the embeddings
        x = self.embed(packed_input.data)
        x = nn.Tanh()(x)

        # pyTorch 2D dropout2d operate on axis 1 which is fine for us
        x = self.embed_dropout(x)

        # Update packed sequence data for RNN
        packed_input = PackedSequence(x, packed_input.batch_sizes)

        # skip-connection from embedding to output eases gradient-flow and allows access to lower-level features
        # ordering of the way the merge is done is important for consistency with the pretrained model
        lstm_0_output, _ = self.lstm_0(packed_input, hidden)
        lstm_1_output, _ = self.lstm_1(lstm_0_output, hidden)

        # Update packed sequence data for attention layer
        packed_input = PackedSequence(torch.cat((lstm_1_output.data,
                                                 lstm_0_output.data,
                                                 packed_input.data), dim=1),
                                      packed_input.batch_sizes)

        input_seqs, padded_lengths = pad_packed_sequence(packed_input, batch_first=True)
        if input_lengths is None:
            input_lengths = padded_lengths

        x, att_weights = self.attention_layer(input_seqs, input_lengths)

        # output class probabilities or penultimate feature vector
        if not self.feature_output:
            x = self.final_dropout(x)
            outputs = self.output_layer(x)
        else:
            outputs = x

        # Reorder output if needed
        if reorder_output:
            reordered = Variable(outputs.data.new(outputs.size()))
            reordered[perm_idx] = outputs
            outputs = reordered
            # The attention weights were computed on the length-sorted batch as
            # well, so they need the same un-sorting to line up with outputs.
            # NOTE(review): assumes the first dimension of att_weights is the batch.
            if torch.is_tensor(att_weights):
                reordered_att = Variable(att_weights.data.new(att_weights.size()))
                reordered_att[perm_idx] = att_weights
                att_weights = reordered_att

        # Adapt return format if needed
        if return_tensor:
            outputs = outputs.data
        if return_numpy:
            outputs = outputs.data.numpy()

        if self.return_attention:
            return outputs, att_weights
        else:
            return outputs
|
258 |
+
|
259 |
+
|
260 |
+
def load_specific_weights(model, weight_path, exclude_names=(), extend_embedding=0, verbose=True):
    """ Loads model weights from the given file path, excluding any
        given layers.

    Weights are copied in place into the tensors of `model.state_dict()`.

    # Arguments:
        model: Model whose weights should be loaded.
        weight_path: Path to file containing model weights.
        exclude_names: Sequence of layer names whose weights should not be
            loaded. A saved parameter is skipped when any of these names is a
            substring of its key.
        extend_embedding: Number of new words being added to vocabulary.
        verbose: Verbosity flag.

    # Raises:
        ValueError if the file at weight_path does not exist, or if the
            embedding is both excluded and asked to be extended.
        KeyError if the file contains a parameter the model does not have.
    """
    if not exists(weight_path):
        raise ValueError('ERROR (load_weights): The weights file at {} does '
                         'not exist. Refer to the README for instructions.'
                         .format(weight_path))

    if extend_embedding and 'embed' in exclude_names:
        raise ValueError('ERROR (load_weights): Cannot extend a vocabulary '
                         'without loading the embedding weights.')

    # Copy only weights from the temporary model that are wanted
    # for the specific task (e.g. the Softmax is often ignored).
    # Loading onto the CPU first keeps GPU-saved checkpoints usable on
    # CPU-only hosts; copy_() below moves values to wherever the model lives.
    weights = torch.load(weight_path, map_location='cpu')
    # state_dict() assembles a new mapping on every call, so build it once.
    # Its tensors share storage with the model's parameters.
    model_state = model.state_dict()
    for key, weight in weights.items():
        if any(excluded in key for excluded in exclude_names):
            if verbose:
                print('Ignoring weights for {}'.format(key))
            continue

        try:
            model_w = model_state[key]
        except KeyError:
            raise KeyError("Weights had parameters {},".format(key)
                           + " but could not find this parameters in model.")

        if verbose:
            print('Loading weights for {}'.format(key))

        # extend embedding layer to allow new randomly initialized words
        # if requested. Otherwise, just load the weights for the layer.
        if 'embed' in key and extend_embedding > 0:
            # Rows beyond NB_TOKENS keep the model's own (random) initialisation.
            weight = torch.cat((weight, model_w[NB_TOKENS:, :]), dim=0)
            if verbose:
                print('Extended vocabulary for embedding layer ' +
                      'from {} to {} tokens.'.format(
                          NB_TOKENS, NB_TOKENS + extend_embedding))
        try:
            model_w.copy_(weight)
        except Exception:
            # Report which tensor could not be copied, then let the error propagate.
            print('While copying the weigths named {}, whose dimensions in the model are'
                  ' {} and whose dimensions in the saved file are {}, ...'.format(
                      key, model_w.size(), weight.size()))
            raise
|
torchmoji/sentence_tokenizer.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
Provides functionality for converting a given list of tokens (words) into
|
4 |
+
numbers, according to the given vocabulary.
|
5 |
+
'''
|
6 |
+
from __future__ import print_function, division, unicode_literals
|
7 |
+
|
8 |
+
import numbers
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmoji.create_vocab import extend_vocab, VocabBuilder
|
12 |
+
from torchmoji.word_generator import WordGenerator
|
13 |
+
from torchmoji.global_variables import SPECIAL_TOKENS
|
14 |
+
|
15 |
+
# import torch
|
16 |
+
|
17 |
+
from sklearn.model_selection import train_test_split
|
18 |
+
|
19 |
+
from copy import deepcopy
|
20 |
+
|
21 |
+
class SentenceTokenizer():
    """ Create numpy array of tokens corresponding to input sentences.
        The vocabulary can include Unicode tokens.
    """
    def __init__(self, vocabulary, fixed_length, custom_wordgen=None,
                 ignore_sentences_with_only_custom=False, masking_value=0,
                 unknown_value=1):
        """ Needs a dictionary as input for the vocabulary.

        # Arguments:
            vocabulary: Dictionary mapping each word to its token id.
            fixed_length: Number of token columns per sentence; longer
                sentences are truncated and shorter ones padded.
            custom_wordgen: Optional word generator used instead of the
                default WordGenerator. Its stream must still be None.
            ignore_sentences_with_only_custom: If True, sentences made up
                exclusively of special tokens are dropped.
            masking_value: Token id used for padding.
            unknown_value: Token id used for words missing from the vocabulary.

        # Raises:
            ValueError: When the vocabulary does not fit the uint16 token arrays.
        """

        if len(vocabulary) > np.iinfo('uint16').max:
            raise ValueError('Dictionary is too big ({} tokens) for the numpy '
                             'datatypes used (max limit={}). Reduce vocabulary'
                             ' or adjust code accordingly!'
                             .format(len(vocabulary), np.iinfo('uint16').max))

        # Shouldn't be able to modify the given vocabulary
        self.vocabulary = deepcopy(vocabulary)
        self.fixed_length = fixed_length
        self.ignore_sentences_with_only_custom = ignore_sentences_with_only_custom
        self.masking_value = masking_value
        self.unknown_value = unknown_value

        # Initialized with an empty stream of sentences that must then be fed
        # to the generator at a later point for reusability.
        # A custom word generator can be used for domain-specific filtering etc
        if custom_wordgen is not None:
            assert custom_wordgen.stream is None
            self.wordgen = custom_wordgen
            self.uses_custom_wordgen = True
        else:
            self.wordgen = WordGenerator(None, allow_unicode_text=True,
                                         ignore_emojis=False,
                                         remove_variation_selectors=True,
                                         break_replacement=True)
            self.uses_custom_wordgen = False

    def tokenize_sentences(self, sentences, reset_stats=True, max_sentences=None):
        """ Converts a given list of sentences into a numpy array according to
            its vocabulary.

        # Arguments:
            sentences: List of sentences to be tokenized.
            reset_stats: Whether the word generator's stats should be reset.
            max_sentences: Maximum number of sentences. Must be set if the
                length cannot be inferred from the input.

        # Returns:
            Numpy array of the tokenization sentences with masking,
            infos,
            stats

        # Raises:
            ValueError: When maximum length is not set and cannot be inferred.
        """

        length_known = hasattr(sentences, '__len__')
        if max_sentences is None and not length_known:
            raise ValueError('Either you must provide an array with a length '
                             'attribute (e.g. a list) or specify the maximum '
                             'length yourself using `max_sentences`!')
        n_sentences = (max_sentences if max_sentences is not None
                       else len(sentences))

        if self.masking_value == 0:
            tokens = np.zeros((n_sentences, self.fixed_length), dtype='uint16')
        else:
            tokens = (np.ones((n_sentences, self.fixed_length), dtype='uint16')
                      * self.masking_value)

        if reset_stats:
            self.wordgen.reset_stats()

        # With a custom word generator info can be extracted from each
        # sentence (e.g. labels)
        infos = []

        # Returns words as strings and then map them to vocabulary
        self.wordgen.stream = sentences
        next_insert = 0
        n_ignored_unknowns = 0
        for s_words, s_info in self.wordgen:
            s_tokens = self.find_tokens(s_words)

            if (self.ignore_sentences_with_only_custom and
                    all(t < len(SPECIAL_TOKENS) for t in s_tokens)):
                n_ignored_unknowns += 1
                continue
            if len(s_tokens) > self.fixed_length:
                s_tokens = s_tokens[:self.fixed_length]
            tokens[next_insert, :len(s_tokens)] = s_tokens
            infos.append(s_info)
            next_insert += 1

        # For standard word generators all sentences should be tokenized
        # this is not necessarily the case for custom wordgenerators as they
        # may filter the sentences etc.
        # The sanity check needs len(); a stream whose size was supplied via
        # `max_sentences` is trimmed to what was actually read instead.
        if (length_known and not self.uses_custom_wordgen
                and not self.ignore_sentences_with_only_custom):
            assert len(sentences) == next_insert
        else:
            # adjust based on actual tokens received
            tokens = tokens[:next_insert]
            infos = infos[:next_insert]

        return tokens, infos, self.wordgen.stats

    def find_tokens(self, words):
        """ Maps a non-empty list of words to their token ids.

        Words missing from the vocabulary are mapped to `unknown_value`.
        """
        assert len(words) > 0
        tokens = []
        for w in words:
            try:
                tokens.append(self.vocabulary[w])
            except KeyError:
                tokens.append(self.unknown_value)
        return tokens

    def split_train_val_test(self, sentences, info_dicts,
                             split_parameter=(0.7, 0.1, 0.2), extend_with=0):
        """ Splits given sentences into three different datasets: training,
            validation and testing.

        # Arguments:
            sentences: The sentences to be tokenized.
            info_dicts: A list of dicts that contain information about each
                sentence (e.g. a label).
            split_parameter: A parameter for deciding the splits between the
                three different datasets. If instead of being passed three
                values, three lists are passed, then these will be used to
                specify which observation belong to which dataset.
            extend_with: An optional parameter. If > 0 then this is the number
                of tokens added to the vocabulary from this dataset. The
                expanded vocab will be generated using only the training set,
                but is applied to all three sets.

        # Returns:
            List of three lists of tokenized sentences,

            List of three corresponding dictionaries with information,

            How many tokens have been added to the vocab. Make sure to extend
            the embedding layer of the model accordingly.
        """

        # If passed three lists, use those directly
        if isinstance(split_parameter, list) and \
                all(isinstance(x, list) for x in split_parameter) and \
                len(split_parameter) == 3:

            # Helper function to verify provided indices are numbers in range
            def verify_indices(inds):
                return list(filter(lambda i: isinstance(i, numbers.Number)
                                   and i < len(sentences), inds))

            ind_train = verify_indices(split_parameter[0])
            ind_val = verify_indices(split_parameter[1])
            ind_test = verify_indices(split_parameter[2])
        else:
            # Split sentences and dicts
            # The test share is cut from everything, the validation share
            # from what remains afterwards.
            ind = list(range(len(sentences)))
            ind_train, ind_test = train_test_split(ind, test_size=split_parameter[2])
            ind_train, ind_val = train_test_split(ind_train, test_size=split_parameter[1])

        # Map indices to data
        train = np.array([sentences[x] for x in ind_train])
        test = np.array([sentences[x] for x in ind_test])
        val = np.array([sentences[x] for x in ind_val])

        info_train = np.array([info_dicts[x] for x in ind_train])
        info_test = np.array([info_dicts[x] for x in ind_test])
        info_val = np.array([info_dicts[x] for x in ind_val])

        added = 0
        # Extend vocabulary with training set tokens
        if extend_with > 0:
            wg = WordGenerator(train)
            vb = VocabBuilder(wg)
            vb.count_all_words()
            added = extend_vocab(self.vocabulary, vb, max_tokens=extend_with)

        # Wrap results
        result = [self.tokenize_sentences(s)[0] for s in [train, val, test]]
        result_infos = [info_train, info_val, info_test]

        return result, result_infos, added

    def to_sentence(self, sentence_idx):
        """ Converts a tokenized sentence back to a list of words.

        # Arguments:
            sentence_idx: List of numbers, representing a tokenized sentence
                given the current vocabulary.

        # Returns:
            String created by converting all numbers back to words and joined
            together with spaces.
        """
        # Have to recalculate the mappings in case the vocab was extended.
        ind_to_word = {ind: word for word, ind in self.vocabulary.items()}

        sentence_as_list = [ind_to_word[x] for x in sentence_idx]
        # Padding entries carry no text and are left out of the result.
        cleaned_list = [x for x in sentence_as_list if x != 'CUSTOM_MASK']
        return " ".join(cleaned_list)
|
225 |
+
|
226 |
+
|
227 |
+
def coverage(dataset, verbose=False, unknown_value=1):
    """ Computes the fraction of words in a given dataset that are known,
        i.e. one minus the share of unknown tokens.

    Entries equal to 0 are treated as padding and are not counted as words.

    # Arguments:
        dataset: Tokenized dataset to be checked.
        verbose: Verbosity flag.
        unknown_value: Token id that marks an unknown word (defaults to 1).

    # Returns:
        Fraction (between 0 and 1) of the words that are not unknown. A
        dataset without any words yields 1.0, as nothing in it is unknown.
    """
    # Plain (nested) lists would otherwise compare as a whole against the
    # unknown id instead of element by element.
    dataset = np.asarray(dataset)
    n_total = np.count_nonzero(dataset)
    n_unknown = np.sum(dataset == unknown_value)
    if n_total == 0:
        # Avoid dividing by zero for empty or fully padded input.
        known_fraction = 1.0
    else:
        known_fraction = 1.0 - float(n_unknown) / n_total

    if verbose:
        print("Unknown words: {}".format(n_unknown))
        print("Total words: {}".format(n_total))
        print("Coverage: {}".format(known_fraction))
    return known_fraction
|
torchmoji/tokenizer.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
Splits up a Unicode string into a list of tokens.
|
4 |
+
Recognises:
|
5 |
+
- Abbreviations
|
6 |
+
- URLs
|
7 |
+
- Emails
|
8 |
+
- #hashtags
|
9 |
+
- @mentions
|
10 |
+
- emojis
|
11 |
+
- emoticons (limited support)
|
12 |
+
|
13 |
+
Multiple consecutive symbols are also treated as a single token.
|
14 |
+
'''
|
15 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
16 |
+
|
17 |
+
import re
|
18 |
+
|
19 |
+
# Basic patterns.
RE_NUM = r'[0-9]+'
RE_WORD = r'[a-zA-Z]+'
RE_WHITESPACE = r'\s+'
RE_ANY = r'.'

# Combined words such as 'red-haired' or 'CUSTOM_TOKEN'
RE_COMB = r'[a-zA-Z]+[-_][a-zA-Z]+'

# English-specific patterns
RE_CONTRACTIONS = RE_WORD + r'\'' + RE_WORD

TITLES = [
    r'Mr\.',
    r'Ms\.',
    r'Mrs\.',
    r'Dr\.',
    r'Prof\.',
]
# Ensure case insensitivity.
# Each title gets a scoped (?i:...) group: a global (?i) flag that is not at
# the very start of a pattern is deprecated since Python 3.6 and rejected
# with re.error from Python 3.11 on.
RE_TITLES = r'|'.join([r'(?i:' + t + r')' for t in TITLES])

# Symbols have to be created as separate patterns in order to match consecutive
# identical symbols.
SYMBOLS = r'()<!?.,/\'\"-_=\\§|´ˇ°[]<>{}~$^&*;:%+\xa3€`'
RE_SYMBOL = r'|'.join([re.escape(s) + r'+' for s in SYMBOLS])

# Hash symbols and at symbols have to be defined separately in order to not
# clash with hashtags and mentions if there are multiple - i.e.
# ##hello -> ['#', '#hello'] instead of ['##', 'hello']
SPECIAL_SYMBOLS = r'|#+(?=#[a-zA-Z0-9_]+)|@+(?=@[a-zA-Z0-9_]+)|#+|@+'
RE_SYMBOL += SPECIAL_SYMBOLS

RE_ABBREVIATIONS = r'\b(?<!\.)(?:[A-Za-z]\.){2,}'

# Twitter-specific patterns
RE_HASHTAG = r'#[a-zA-Z0-9_]+'
RE_MENTION = r'@[a-zA-Z0-9_]+'

RE_URL = r'(?:https?://|www\.)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
RE_EMAIL = r'\b[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+\b'

# Emoticons and emojis
RE_HEART = r'(?:<+/?3+)+'
EMOTICONS_START = [
    r'>:',
    r':',
    r'=',
    r';',
]
EMOTICONS_MID = [
    r'-',
    r',',
    r'^',
    '\'',
    '\"',
]
EMOTICONS_END = [
    r'D',
    r'd',
    r'p',
    r'P',
    r'v',
    r')',
    r'o',
    r'O',
    r'(',
    r'3',
    r'/',
    r'|',
    '\\',
]
EMOTICONS_EXTRA = [
    r'-_-',
    r'x_x',
    r'^_^',
    r'o.o',
    r'o_o',
    r'(:',
    r'):',
    r');',
    r'(;',
]

# Fixed emoticons first, then every start / optional middle / repeated end
# combination.
RE_EMOTICON = r'|'.join([re.escape(s) for s in EMOTICONS_EXTRA])
for s in EMOTICONS_START:
    for m in EMOTICONS_MID:
        for e in EMOTICONS_END:
            RE_EMOTICON += '|{0}{1}?{2}+'.format(re.escape(s), re.escape(m), re.escape(e))

# requires ucs4 in python2.7 or python3+
# RE_EMOJI = r"""[\U0001F300-\U0001F64F\U0001F680-\U0001F6FF\u2600-\u26FF\u2700-\u27BF]"""
# safe for all python
# NOTE(review): on Python 3 an emoji outside the BMP is one code point, so the
# surrogate-pair alternatives presumably never match there and such emoji are
# picked up by RE_ANY instead - confirm before relying on RE_EMOJI alone.
RE_EMOJI = r"""\ud83c[\udf00-\udfff]|\ud83d[\udc00-\ude4f\ude80-\udeff]|[\u2600-\u26FF\u2700-\u27BF]"""

# List of matched token patterns, ordered from most specific to least specific.
TOKENS = [
    RE_URL,
    RE_EMAIL,
    RE_COMB,
    RE_HASHTAG,
    RE_MENTION,
    RE_HEART,
    RE_EMOTICON,
    RE_CONTRACTIONS,
    RE_TITLES,
    RE_ABBREVIATIONS,
    RE_NUM,
    RE_WORD,
    RE_SYMBOL,
    RE_EMOJI,
    RE_ANY
]

# List of ignored token patterns
IGNORED = [
    RE_WHITESPACE
]

# Final pattern
# Ignored patterns stay outside the single capture group, so only real tokens
# are captured. re.IGNORECASE is passed explicitly: the global (?i) formerly
# embedded in RE_TITLES made the whole pattern case-insensitive, and that
# matching behaviour is kept.
RE_PATTERN = re.compile(r'|'.join(IGNORED) + r'|(' + r'|'.join(TOKENS) + r')',
                        re.UNICODE | re.IGNORECASE)
|
141 |
+
|
142 |
+
|
143 |
+
def tokenize(text):
    '''Breaks the given input string up into its tokens.

    # Arguments:
        text: Input string to be tokenized.

    # Returns:
        List of strings (tokens).
    '''
    kept = []
    for candidate in RE_PATTERN.findall(text):
        # Ignored patterns match outside the capture group and therefore
        # show up as empty strings; skip those and blank-only matches.
        if not candidate.strip():
            continue
        kept.append(candidate)
    return kept
|
torchmoji/word_generator.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
''' Extracts lists of words from a given input to be used for later vocabulary
|
3 |
+
generation or for creating tokenized datasets.
|
4 |
+
Supports functionality for handling different file types and
|
5 |
+
filtering/processing of this input.
|
6 |
+
'''
|
7 |
+
|
8 |
+
from __future__ import division, print_function, unicode_literals
|
9 |
+
|
10 |
+
import re
|
11 |
+
import unicodedata
|
12 |
+
import numpy as np
|
13 |
+
from text_unidecode import unidecode
|
14 |
+
|
15 |
+
from torchmoji.tokenizer import RE_MENTION, tokenize
|
16 |
+
from torchmoji.filter_utils import (convert_linebreaks,
|
17 |
+
convert_nonbreaking_space,
|
18 |
+
correct_length,
|
19 |
+
extract_emojis,
|
20 |
+
mostly_english,
|
21 |
+
non_english_user,
|
22 |
+
process_word,
|
23 |
+
punct_word,
|
24 |
+
remove_control_chars,
|
25 |
+
remove_variation_selectors,
|
26 |
+
separate_emojis_and_text)
|
27 |
+
|
28 |
+
# Python 2/3 compatibility: Python 3 has no `unicode` builtin, so alias it
# to `str` when the name is missing.
try:
    unicode  # Python 2
except NameError:
    unicode = str  # Python 3

# Matches "RT" (any casing) only at the very start of a tweet, which is where
# the automatically added retweet markers sit.
# Tweets such as "Omg.. please RT this!!" must not be caught.
RETWEETS_RE = re.compile(r'^[rR][tT]')

# Deliberately fast, imprecise URL detector for discarding tweets with URLs.
# A few tweets with URLs slipping through does not matter much.
URLS_RE = re.compile(r'https?://|www\.')

# Compiled form of the tokenizer's @mention pattern.
MENTION_RE = re.compile(RE_MENTION)
ALLOWED_CONVERTED_UNICODE_PUNCTUATION = """!"#$'()+,-.:;<=>?@`~"""
|
44 |
+
|
45 |
+
|
46 |
+
class WordGenerator():
    ''' Cleanses input and converts into words. Needs all sentences to be in
        Unicode format. Has subclasses that read sentences differently based on
        file type.

    Takes an iterable of lines as input. This can be from e.g. a file.
    Iterating over the instance yields ``(words, info)`` tuples for every
    line that passes all filters; counters are kept in ``self.stats``.

    # Arguments:
        stream: Iterable yielding one line per item. Must not be None when
            iteration starts.
        allow_unicode_text: If False, a sentence is dropped as soon as one of
            its words still contains non-ASCII characters after punctuation
            conversion (and emoji removal, if enabled).
        ignore_emojis: If True, every non-ASCII word is passed through
            separate_emojis_and_text and only the second component of its
            result (presumably the plain text) is kept.
        remove_variation_selectors: If True, the sentence is passed through
            the remove_variation_selectors helper before being split.
        break_replacement: If True, the sentence is passed through
            convert_linebreaks before being split.
    '''
    def __init__(self, stream, allow_unicode_text=False, ignore_emojis=True,
                 remove_variation_selectors=True, break_replacement=True):
        self.stream = stream
        self.allow_unicode_text = allow_unicode_text
        self.remove_variation_selectors = remove_variation_selectors
        self.ignore_emojis = ignore_emojis
        self.break_replacement = break_replacement
        self.reset_stats()

    def get_words(self, sentence):
        """ Tokenizes a sentence into individual words.

        The sentence is stripped and lowercased, optionally cleaned of
        linebreak markers and variation selectors, and every
        whitespace-separated word is run through convert_unicode_word
        before the real tokenizer is applied.

        # Arguments:
            sentence: Unicode string to tokenize.

        # Returns:
            List of processed tokens. An empty list is returned if the
            sentence contains Unicode that is not allowed (or if it simply
            contains no tokens).

        # Raises:
            ValueError: If the sentence is not a Unicode string.
        """

        if not isinstance(sentence, unicode):
            raise ValueError("All sentences should be Unicode-encoded!")
        sentence = sentence.strip().lower()

        if self.break_replacement:
            sentence = convert_linebreaks(sentence)

        # The attribute is the boolean option; the bare name is the helper
        # function imported at module level.
        if self.remove_variation_selectors:
            sentence = remove_variation_selectors(sentence)

        # Split into words using simple whitespace splitting and convert
        # Unicode. This is done to prevent word splitting issues with
        # twokenize and Unicode
        converted_words = []
        for w in sentence.split():
            accept_sentence, c_w = self.convert_unicode_word(w)
            # Unicode word detected and not allowed
            if not accept_sentence:
                return []
            converted_words.append(c_w)
        sentence = ' '.join(converted_words)

        return [process_word(w) for w in tokenize(sentence)]

    def check_ascii(self, word):
        """ Returns whether a word is ASCII.

        Text strings are checked by encoding and byte strings by decoding.
        Calling ``.decode`` on a Python 3 ``str`` raises AttributeError, which
        previously made this method report every word as non-ASCII.

        # Arguments:
            word: Text or byte string to check.

        # Returns:
            True if the word only contains ASCII characters, False otherwise
            (including for objects that are not strings at all).
        """

        try:
            if isinstance(word, bytes):
                word.decode('ascii')
            else:
                word.encode('ascii')
            return True
        except (UnicodeDecodeError, UnicodeEncodeError, AttributeError):
            return False

    def convert_unicode_punctuation(self, word):
        """ Replaces Unicode punctuation in a word with its ASCII equivalent.

        Each character is unidecoded on its own. The decoded form is used
        only if it is non-empty and consists solely of characters from
        ALLOWED_CONVERTED_UNICODE_PUNCTUATION; otherwise the original
        character is kept.

        # Arguments:
            word: Unicode string.

        # Returns:
            The word with convertible punctuation replaced.
        """
        word_converted_punct = []
        for c in word:
            decoded_c = unidecode(c).lower()
            # An empty result means the character cannot be decoded to
            # anything reasonable, so the original is kept in that case too.
            if len(decoded_c) != 0 and punct_word(
                    decoded_c,
                    punctuation=ALLOWED_CONVERTED_UNICODE_PUNCTUATION):
                word_converted_punct.append(decoded_c)
            else:
                word_converted_punct.append(c)
        return ''.join(word_converted_punct)

    def convert_unicode_word(self, word):
        """ Converts Unicode words to ASCII using unidecode. If Unicode is not
            allowed (set as a variable during initialization), then only
            punctuation that can be converted to ASCII will be allowed.

        # Arguments:
            word: Unicode string without whitespace.

        # Returns:
            Tuple of (accept_sentence, converted_word). When the word is
            rejected the tuple is (False, '').
        """
        # Pure ASCII words need no conversion at all.
        if self.check_ascii(word):
            return True, word

        # First we ensure that the Unicode is normalized so it's
        # always a single character.
        word = unicodedata.normalize("NFKC", word)

        # Convert Unicode punctuation to ASCII equivalent. We want
        # e.g. "\u203c" (double exclamation mark) to be treated the same
        # as "!!" no matter if we allow other Unicode characters or not.
        word = self.convert_unicode_punctuation(word)

        if self.ignore_emojis:
            _, word = separate_emojis_and_text(word)

        # If conversion of punctuation and removal of emojis took care
        # of all the Unicode or if we allow Unicode then everything is fine
        if self.check_ascii(word) or self.allow_unicode_text:
            return True, word

        # Sometimes we might want to simply ignore Unicode sentences
        # (e.g. for vocabulary creation). This is another way to prevent
        # "pollution" of strange Unicode tokens from low quality datasets
        return False, ''

    def data_preprocess_filtering(self, line, iter_i):
        """ To be overridden with specific preprocessing/filtering behavior
            if desired.

            Returns a boolean of whether the line should be accepted, the
            preprocessed text and a dict with extra information about the
            line.

            Runs prior to tokenization.
        """
        return True, line, {}

    def data_postprocess_filtering(self, words, iter_i):
        """ To be overridden with specific postprocessing/filtering behavior
            if desired.

            Returns a boolean of whether the line should be accepted, the
            postprocessed words and a dict with extra information about the
            line.

            Runs after tokenization.
        """
        return True, words, {}

    def extract_valid_sentence_words(self, line):
        """ Line may be either a string or a list of strings depending on how
            the stream is being parsed.
            Domain-specific processing and filtering can be done both prior to
            and after tokenization.
            Custom information about the line can be extracted during the
            processing phases and returned as a dict.

        # Returns:
            Tuple of (valid, words, info). The matching counter in
            ``self.stats`` is incremented whenever a line is filtered away.
        """

        info = {}

        pre_valid, pre_line, pre_info = \
            self.data_preprocess_filtering(line, self.stats['total'])
        info.update(pre_info)
        if not pre_valid:
            self.stats['pretokenization_filtered'] += 1
            return False, [], info

        words = self.get_words(pre_line)
        if len(words) == 0:
            # get_words returns nothing when disallowed Unicode was found
            self.stats['unicode_filtered'] += 1
            return False, [], info

        post_valid, post_words, post_info = \
            self.data_postprocess_filtering(words, self.stats['total'])
        info.update(post_info)
        if not post_valid:
            self.stats['posttokenization_filtered'] += 1
        return post_valid, post_words, info

    def generate_array_from_input(self):
        """ Consumes the whole stream and returns the results as a list.

        Each element is the ``(words, info)`` tuple yielded by iterating
        over this object.
        """
        return list(self)

    def reset_stats(self):
        """ Resets all filtering counters to zero. """
        self.stats = {'pretokenization_filtered': 0,
                      'unicode_filtered': 0,
                      'posttokenization_filtered': 0,
                      'total': 0,
                      'valid': 0}

    def __iter__(self):
        """ Yields ``(words, info)`` for every valid line of the stream.

        # Raises:
            ValueError: If no stream has been set.
        """
        if self.stream is None:
            raise ValueError("Stream should be set before iterating over it!")

        for line in self.stream:
            valid, words, info = self.extract_valid_sentence_words(line)

            # Words may be filtered away due to unidecode etc.
            # In that case the words should not be passed on.
            if valid and len(words):
                self.stats['valid'] += 1
                yield words, info

            # Counted after the (optional) yield so that the filtering
            # hooks above see the number of lines processed before this one.
            self.stats['total'] += 1
|
238 |
+
|
239 |
+
|
240 |
+
class TweetWordGenerator(WordGenerator):
    ''' Returns np array or generator of ASCII sentences for given tweet input.
        Any file opening/closing should be handled outside of this class.

    Every line of the stream is expected to be a tab-separated record in
    which the field at index 9 holds the tweet text and the field at index 1
    is what gets checked against ``non_english_user_set``
    (presumably a user identifier; verify against the data format).

    # Arguments:
        stream: Iterable of tab-separated tweet records.
        wanted_emojis: If not None, tweets without any of these emojis are
            rejected and the unique emojis found are returned as info.
        english_words: Passed on to mostly_english when filtering after
            tokenization.
        non_english_user_set: If not None, tweets for which non_english_user
            returns a truthy value are rejected.
        allow_unicode_text: Forwarded to WordGenerator.
        ignore_retweets: Reject tweets matching RETWEETS_RE.
        ignore_url_tweets: Reject tweets matching URLS_RE.
        ignore_mention_tweets: Reject tweets matching MENTION_RE.
    '''
    def __init__(self, stream, wanted_emojis=None, english_words=None,
                 non_english_user_set=None, allow_unicode_text=False,
                 ignore_retweets=True, ignore_url_tweets=True,
                 ignore_mention_tweets=False):

        self.wanted_emojis = wanted_emojis
        self.english_words = english_words
        self.non_english_user_set = non_english_user_set
        self.ignore_retweets = ignore_retweets
        self.ignore_url_tweets = ignore_url_tweets
        self.ignore_mention_tweets = ignore_mention_tweets
        WordGenerator.__init__(self, stream,
                               allow_unicode_text=allow_unicode_text)

    def validated_tweet(self, data):
        ''' A bunch of checks to determine whether the tweet is valid.
            Also returns emojis contained by the tweet.

        # Arguments:
            data: List of fields of one tweet record.

        # Returns:
            Tuple of (valid, emojis). ``emojis`` is empty when the tweet is
            rejected or when no ``wanted_emojis`` were configured.
        '''

        # Ordering of validations is important for speed
        # If it passes all checks, then the tweet is validated for usage

        # Skips incomplete tweets
        if len(data) <= 9:
            return False, []

        text = data[9]

        if self.ignore_retweets and RETWEETS_RE.search(text):
            return False, []

        if self.ignore_url_tweets and URLS_RE.search(text):
            return False, []

        if self.ignore_mention_tweets and MENTION_RE.search(text):
            return False, []

        if self.wanted_emojis is not None:
            uniq_emojis = np.unique(extract_emojis(text, self.wanted_emojis))
            if len(uniq_emojis) == 0:
                return False, []
        else:
            uniq_emojis = []

        if self.non_english_user_set is not None and \
           non_english_user(data[1], self.non_english_user_set):
            return False, []
        return True, uniq_emojis

    def data_preprocess_filtering(self, line, iter_i):
        ''' Splits a raw record into fields, validates it and cleans the text.

        Literal "\\n" and "\\r" escape sequences are removed and the HTML
        entity for the ampersand is unescaped.

        # Returns:
            Tuple of (valid, text, info) where info holds the emojis found.
            The text is empty for rejected tweets.
        '''
        fields = line.strip().split("\t")
        valid, emojis = self.validated_tweet(fields)
        if valid:
            # The entity must be spelled out here: replacing '&' with '&'
            # would be a no-op and leave escaped ampersands in the text.
            text = fields[9].replace('\\n', '') \
                            .replace('\\r', '') \
                            .replace('&amp;', '&')
        else:
            # fields[9] may not exist for rejected (incomplete) records
            text = ''
        return valid, text, {'emojis': emojis}

    def data_postprocess_filtering(self, words, iter_i):
        ''' Accepts the tokenized tweet only if both the correct_length and
            mostly_english checks pass.

        # Returns:
            Tuple of (valid, words, info). The words are replaced by an
            empty list for rejected tweets; info always contains the
            length and the word counts reported by mostly_english.
        '''
        valid_length = correct_length(words, 1, None)
        valid_english, n_words, n_english = mostly_english(words,
                                                           self.english_words)
        info = {'length': len(words),
                'n_normal_words': n_words,
                'n_english': n_english}
        if valid_length and valid_english:
            return True, words, info
        return False, [], info
|
vocabulary.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|