eubinecto committed · Commit 539e83f · 1 Parent(s): d2dce47

infer logic added

Browse files:
- config.yaml +4 -4
- idiomify/fetchers.py +21 -9
- idiomify/models.py +12 -12
- idiomify/tensors.py +1 -1
- main_infer.py +34 -0
- main_train.py +4 -5
config.yaml
CHANGED
@@ -7,7 +7,7 @@ alpha:
     idiom2def_ver: c
     k: 11
     lr: 0.00001
-    max_epochs:
+    max_epochs: 10
     batch_size: 64
     shuffle: true
   kor2eng:
@@ -18,7 +18,7 @@ alpha:
     idiom2def_ver: d
     k: 11
     lr: 0.00001
-    max_epochs:
+    max_epochs: 20
     batch_size: 64
     num_workers: 4
     shuffle: true
@@ -30,7 +30,7 @@ gamma:
     idiom2def_ver: c
     k: 11
     lr: 0.00001
-    max_epochs:
+    max_epochs: 10
     batch_size: 64
     shuffle: true
   kor2eng:
@@ -40,7 +40,7 @@ gamma:
     idiom2def_ver: d
     k: 11
     lr: 0.00001
-    max_epochs:
+    max_epochs: 20
     batch_size: 64
     num_workers: 4
     shuffle: true
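For context, the scripts in this commit read config.yaml through fetch_config() and index it as config[model][ver] (see main_infer.py below), so the blocks above correspond to the alpha/gamma models and their per-version hyperparameters. A minimal sketch of that lookup; the exact nesting and the "eng2eng" key are inferred from the defaults used elsewhere in this commit, not spelled out here:

# sketch only: assumes config.yaml nests hyperparameters as <model> -> <ver> -> params
import yaml

with open("config.yaml") as fh:
    config = yaml.safe_load(fh)

hparams = config["alpha"]["eng2eng"]   # "eng2eng" is the default --ver used by the scripts
print(hparams["max_epochs"])           # 10 after this change
print(hparams["lr"], hparams["k"])     # 1e-05, 11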
idiomify/fetchers.py
CHANGED
@@ -2,8 +2,10 @@ import csv
 import yaml
 import wandb
 from typing import Tuple, List
-from
-from idiomify.
+from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
+from idiomify.models import Alpha, Gamma, RD
+from idiomify.paths import idiom2def_dir, CONFIG_YAML, idioms_dir, alpha_dir
+from idiomify import tensors as T


 # dataset
@@ -35,13 +37,23 @@ def fetch_idioms(ver: str) -> List[str]:
     ]


-
-
-
-
-
-
-
+def fetch_rd(model: str, ver: str) -> RD:
+    artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/{model}:{ver}", type="model")
+    config = artifact.metadata
+    artifact_path = alpha_dir(ver)
+    artifact.download(root=str(artifact_path))
+    mlm = AutoModelForMaskedLM.from_config(AutoConfig.from_pretrained(config['bert']))
+    ckpt_path = artifact_path / "rd.ckpt"
+    idioms = fetch_idioms(config['idioms_ver'])
+    tokenizer = BertTokenizer.from_pretrained(config['bert'])
+    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])
+    if model == Alpha.name():
+        rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
+    elif model == Gamma.name():
+        rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
+    else:
+        raise ValueError
+    return rd


 def fetch_config() -> dict:
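A hedged usage sketch of the new fetch_rd: it downloads the named W&B model artifact, rebuilds the bare MLM from the artifact's metadata, and restores the Lightning checkpoint together with the idiom2subwords buffer. The model and version names below are just the defaults used elsewhere in this commit; a matching artifact must already have been logged under eubinecto/idiomify-demo.

# usage sketch; assumes an "alpha:eng2eng" model artifact exists on W&B
from idiomify.fetchers import fetch_rd

rd = fetch_rd("alpha", "eng2eng")  # downloads rd.ckpt and rebuilds the Alpha model
rd.eval()                          # inference only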
idiomify/models.py
CHANGED
@@ -29,17 +29,17 @@ class RD(pl.LightningModule):
     def predict_dataloader(self):
         pass

-    def __init__(self, mlm: BertForMaskedLM,
+    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):  # noqa
         """
         :param mlm: a bert model for masked language modeling
-        :param
+        :param idiom2subwords: (|W|, K)
         :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
         """
         super().__init__()
         # -- hyper params --- #
         # should be saved to self.hparams
         # https://github.com/PyTorchLightning/pytorch-lightning/issues/4390#issue-730493746
-        self.save_hyperparameters(ignore=["mlm", "
+        self.save_hyperparameters(ignore=["mlm", "idiom2subwords"])
         # -- the only neural network we need -- #
         self.mlm = mlm
         # --- to be used for getting H_k --- #
@@ -47,7 +47,7 @@ class RD(pl.LightningModule):
         # --- to be used for getting H_desc --- #
         self.desc_mask: Optional[torch.Tensor] = None  # (N, L)
         # -- constant tensors -- #
-        self.register_buffer("
+        self.register_buffer("idiom2subwords", idiom2subwords)  # (|W|, K)

     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
@@ -94,7 +94,7 @@ class RD(pl.LightningModule):
         :return: S_wisdom_literal (N, |W|)
         """
         S_vocab = self.mlm.cls(H_k)  # bmm; (N, K, H) * (H, |V|) -> (N, K, |V|)
-        indices = self.
+        indices = self.idiom2subwords.T.repeat(S_vocab.shape[0], 1, 1)  # (|W|, K) -> (N, K, |W|)
         S_wisdom_literal = S_vocab.gather(dim=-1, index=indices)  # (N, K, |V|) -> (N, K, |W|)
         S_wisdom_literal = S_wisdom_literal.sum(dim=1)  # (N, K, |W|) -> (N, |W|)
         return S_wisdom_literal
@@ -194,9 +194,9 @@ class Gamma(RD):
     but the way we get S_wisdom_figurative is much simplified, compared with RDBeta.
     """

-    def __init__(self, mlm: BertForMaskedLM,
+    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):
-        super().__init__(mlm,
+        super().__init__(mlm, idiom2subwords, k, lr)
-        # a pooler is a multilayer perceptron that pools wisdom_embeddings from
+        # a pooler is a multilayer perceptron that pools wisdom_embeddings from idiom2subwords_embeddings
         self.pooler = BiLSTMPooler(self.mlm.config.hidden_size)
         # --- to be used to compute attentions --- #
         self.attention_mask: Optional[torch.Tensor] = None
@@ -232,11 +232,11 @@ class Gamma(RD):
         return S_wisdom, S_wisdom_literal, S_wisdom_figurative

     def S_wisdom_figurative(self, H_all: torch.Tensor) -> torch.Tensor:
-        # --- draw the embeddings for wisdoms from the embeddings of
+        # --- draw the embeddings for wisdoms from the embeddings of idiom2subwords -- #
         # this is to use as less of newly initialised weights as possible
-
-            .embeddings.word_embeddings(self.
-        wisdom_embeddings = self.pooler(
+        idiom2subwords_embeddings = self.mlm.bert \
+            .embeddings.word_embeddings(self.idiom2subwords)  # (W, K) -> (W, K, H)
+        wisdom_embeddings = self.pooler(idiom2subwords_embeddings).squeeze()  # (W, H, K) -> (W, H, 1) -> (W, H)
         # --- draw H_wisdom from H_desc with attention --- #
         H_cls = H_all[:, 0]  # (N, L, H) -> (N, H)
         H_desc = self.H_desc(H_all)  # (N, L, H) -> (N, D, H)
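The gather in S_wisdom_literal above is the crux of the renamed buffer: per-position vocabulary scores are reduced to per-idiom scores by picking out each idiom's K subword logits and summing over the K positions. A toy shape check of that indexing (all sizes here are made up):

# toy shape check for the S_wisdom_literal gather; sizes are illustrative only
import torch

N, K, V, W = 2, 3, 10, 4                      # batch, masked slots, subword vocab, idioms
S_vocab = torch.randn(N, K, V)                # (N, K, |V|): logits over all subwords
idiom2subwords = torch.randint(0, V, (W, K))  # (|W|, K): each idiom's K subword ids

indices = idiom2subwords.T.repeat(N, 1, 1)                           # -> (N, K, |W|)
S_wisdom_literal = S_vocab.gather(dim=-1, index=indices).sum(dim=1)  # -> (N, |W|)
print(S_wisdom_literal.shape)                 # torch.Size([2, 4])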
idiomify/tensors.py
CHANGED
@@ -7,7 +7,7 @@ from typing import List
 from transformers import BertTokenizer


-def
+def idiom2subwords(idioms: List[str], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
     mask_id = tokenizer.mask_token_id
     pad_id = tokenizer.pad_token_id
     # temporarily disable single-token status of the wisdoms
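Only the renamed signature of idiom2subwords is visible in this hunk, but its contract (as used by fetch_rd, main_train.py, and the (|W|, K) buffer in models.py) is a LongTensor of each idiom's subword ids padded to length k. A rough approximation of that contract, not the actual implementation:

# rough sketch of the contract only; the real function also toggles the
# single-token status of the idioms and uses mask_id, which is not shown here
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # hypothetical 'bert' choice
idioms = ["beat around the bush", "piece of cake"]              # illustrative idioms
k = 11

rows = [tokenizer.encode(idiom, add_special_tokens=False)[:k] for idiom in idioms]
idiom2subwords = torch.LongTensor(
    [row + [tokenizer.pad_token_id] * (k - len(row)) for row in rows]
)
print(idiom2subwords.shape)  # torch.Size([2, 11]) -> (|W|, K)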
main_infer.py
CHANGED
@@ -0,0 +1,34 @@
+import argparse
+from idiomify.fetchers import fetch_config, fetch_idioms, fetch_rd
+from idiomify import tensors as T
+from transformers import BertTokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str,
+                        default="alpha")
+    parser.add_argument("--ver", type=str,
+                        default="eng2eng")
+    parser.add_argument("--sent", type=str,
+                        default="avoid getting to the point")
+    args = parser.parse_args()
+    config = fetch_config()[args.model][args.ver]
+    config.update(vars(args))
+    tokenizer = BertTokenizer.from_pretrained(config['bert'])
+    idioms = fetch_idioms(config['idioms_ver'])
+    X = T.inputs([config['sent']], tokenizer, config['k'])
+    rd = fetch_rd(config['model'], config['ver'])
+    probs = rd.P_wisdom(X).squeeze().tolist()
+    wisdom2prob = [
+        (wisdom, prob)
+        for wisdom, prob in zip(idioms, probs)
+    ]
+    # sort and append
+    res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
+    for idx, (idiom, prob) in enumerate(res):
+        print(idx, idiom, prob)
+
+
+if __name__ == '__main__':
+    main()
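Putting it together, the new script ranks every idiom by P_wisdom for the given sentence. A hedged invocation example (it assumes the W&B artifact and the config entries referenced above exist):

python3 main_infer.py --model alpha --ver eng2eng --sent "avoid getting to the point"

Each output line is "<rank> <idiom> <probability>", sorted from most to least probable.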
main_train.py
CHANGED
@@ -15,7 +15,6 @@ from idiomify import tensors as T

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("entity", type=str)
     parser.add_argument("--model", type=str, default="alpha")
     parser.add_argument("--ver", type=str, default="eng2eng")
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
@@ -32,18 +31,18 @@ def main():
     mlm = BertForMaskedLM.from_pretrained(config['bert'])
     tokenizer = BertTokenizer.from_pretrained(config['bert'])
     idioms = fetch_idioms(config['idioms_ver'])
-
+    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])
     # choose the model to train
     if config['model'] == Alpha.name():
-        rd = Alpha(mlm,
+        rd = Alpha(mlm, idiom2subwords, config['k'], config['lr'])
     elif config['model'] == Gamma.name():
-        rd = Gamma(mlm,
+        rd = Gamma(mlm, idiom2subwords, config['k'], config['lr'])
     else:
         raise ValueError
     # prepare datamodule
     datamodule = IdiomifyDataModule(config, tokenizer, idioms)

-    with wandb.init(entity=
+    with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(max_epochs=config['max_epochs'],
                              fast_dev_run=config['fast_dev_run'],
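With the positional entity argument dropped, the W&B entity and project are now fixed inside wandb.init, so training is launched with only the optional flags; for example (assuming W&B credentials are already configured locally):

python3 main_train.py --model alpha --ver eng2eng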