eubinecto committed · Commit 539e83f · 1 Parent(s): d2dce47

infer logic added

config.yaml CHANGED
@@ -7,7 +7,7 @@ alpha:
     idiom2def_ver: c
     k: 11
     lr: 0.00001
-    max_epochs: 200
+    max_epochs: 10
     batch_size: 64
     shuffle: true
   kor2eng:
@@ -18,7 +18,7 @@ alpha:
     idiom2def_ver: d
     k: 11
     lr: 0.00001
-    max_epochs: 200
+    max_epochs: 20
     batch_size: 64
     num_workers: 4
     shuffle: true
@@ -30,7 +30,7 @@ gamma:
     idiom2def_ver: c
     k: 11
     lr: 0.00001
-    max_epochs: 200
+    max_epochs: 10
     batch_size: 64
     shuffle: true
   kor2eng:
@@ -40,7 +40,7 @@ gamma:
     idiom2def_ver: d
     k: 11
     lr: 0.00001
-    max_epochs: 200
+    max_epochs: 20
     batch_size: 64
     num_workers: 4
     shuffle: true
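
After this commit the alpha section presumably reads as below; this is a sketch, since only the hunks above are visible. The eng2eng subsection name is inferred from the --ver default in main_infer.py, and keys outside the hunks (e.g. bert, idioms_ver) are omitted. The gamma section mirrors it.

    alpha:
      eng2eng:            # subsection name inferred, not shown in the hunks
        idiom2def_ver: c
        k: 11
        lr: 0.00001
        max_epochs: 10    # down from 200
        batch_size: 64
        shuffle: true
      kor2eng:
        idiom2def_ver: d
        k: 11
        lr: 0.00001
        max_epochs: 20    # down from 200
        batch_size: 64
        num_workers: 4
        shuffle: true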
idiomify/fetchers.py CHANGED
@@ -2,8 +2,10 @@ import csv
 import yaml
 import wandb
 from typing import Tuple, List
-from idiomify.models import Alpha, Gamma
-from idiomify.paths import idiom2def_dir, CONFIG_YAML, idioms_dir
+from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
+from idiomify.models import Alpha, Gamma, RD
+from idiomify.paths import idiom2def_dir, CONFIG_YAML, idioms_dir, alpha_dir
+from idiomify import tensors as T


 # dataset
@@ -35,13 +37,23 @@ def fetch_idioms(ver: str) -> List[str]:
     ]


-# models
-def fetch_alpha(ver: str) -> Alpha:
-    pass
-
-
-def fetch_gamma(ver: str) -> Gamma:
-    pass
+def fetch_rd(model: str, ver: str) -> RD:
+    artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/{model}:{ver}", type="model")
+    config = artifact.metadata
+    artifact_path = alpha_dir(ver)
+    artifact.download(root=str(artifact_path))
+    mlm = AutoModelForMaskedLM.from_config(AutoConfig.from_pretrained(config['bert']))
+    ckpt_path = artifact_path / "rd.ckpt"
+    idioms = fetch_idioms(config['idioms_ver'])
+    tokenizer = BertTokenizer.from_pretrained(config['bert'])
+    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])
+    if model == Alpha.name():
+        rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
+    elif model == Gamma.name():
+        rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
+    else:
+        raise ValueError
+    return rd


 def fetch_config() -> dict:
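
A minimal usage sketch for the new fetcher (not part of the commit; assumes W&B access to the eubinecto/idiomify-demo project and that an alpha:eng2eng model artifact exists):

    from idiomify.fetchers import fetch_rd

    # downloads eubinecto/idiomify-demo/alpha:eng2eng from W&B and restores
    # the Alpha checkpoint, rebuilding mlm and idiom2subwords from the
    # artifact's metadata along the way
    rd = fetch_rd("alpha", "eng2eng")
    rd.eval()  # inference only, no gradient updates needed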
idiomify/models.py CHANGED
@@ -29,17 +29,17 @@ class RD(pl.LightningModule):
     def predict_dataloader(self):
         pass

-    def __init__(self, mlm: BertForMaskedLM, wisdom2subwords: torch.Tensor, k: int, lr: float):  # noqa
+    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):  # noqa
         """
         :param mlm: a bert model for masked language modeling
-        :param wisdom2subwords: (|W|, K)
+        :param idiom2subwords: (|W|, K)
         :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
         """
         super().__init__()
         # -- hyper params --- #
         # should be saved to self.hparams
         # https://github.com/PyTorchLightning/pytorch-lightning/issues/4390#issue-730493746
-        self.save_hyperparameters(ignore=["mlm", "wisdom2subwords"])
+        self.save_hyperparameters(ignore=["mlm", "idiom2subwords"])
         # -- the only neural network we need -- #
         self.mlm = mlm
         # --- to be used for getting H_k --- #
@@ -47,7 +47,7 @@ class RD(pl.LightningModule):
         # --- to be used for getting H_desc --- #
         self.desc_mask: Optional[torch.Tensor] = None  # (N, L)
         # -- constant tensors -- #
-        self.register_buffer("wisdom2subwords", wisdom2subwords)  # (|W|, K)
+        self.register_buffer("idiom2subwords", idiom2subwords)  # (|W|, K)

     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
@@ -94,7 +94,7 @@ class RD(pl.LightningModule):
         :return: S_wisdom_literal (N, |W|)
         """
         S_vocab = self.mlm.cls(H_k)  # bmm; (N, K, H) * (H, |V|) -> (N, K, |V|)
-        indices = self.wisdom2subwords.T.repeat(S_vocab.shape[0], 1, 1)  # (|W|, K) -> (N, K, |W|)
+        indices = self.idiom2subwords.T.repeat(S_vocab.shape[0], 1, 1)  # (|W|, K) -> (N, K, |W|)
         S_wisdom_literal = S_vocab.gather(dim=-1, index=indices)  # (N, K, |V|) -> (N, K, |W|)
         S_wisdom_literal = S_wisdom_literal.sum(dim=1)  # (N, K, |W|) -> (N, |W|)
         return S_wisdom_literal
@@ -194,9 +194,9 @@ class Gamma(RD):
     but the way we get S_wisdom_figurative is much simplified, compared with RDBeta.
     """

-    def __init__(self, mlm: BertForMaskedLM, wisdom2subwords: torch.Tensor, k: int, lr: float):
-        super().__init__(mlm, wisdom2subwords, k, lr)
-        # a pooler is a multilayer perceptron that pools wisdom_embeddings from wisdom2subwords_embeddings
+    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):
+        super().__init__(mlm, idiom2subwords, k, lr)
+        # a pooler is a multilayer perceptron that pools wisdom_embeddings from idiom2subwords_embeddings
         self.pooler = BiLSTMPooler(self.mlm.config.hidden_size)
         # --- to be used to compute attentions --- #
         self.attention_mask: Optional[torch.Tensor] = None
@@ -232,11 +232,11 @@ class Gamma(RD):
         return S_wisdom, S_wisdom_literal, S_wisdom_figurative

     def S_wisdom_figurative(self, H_all: torch.Tensor) -> torch.Tensor:
-        # --- draw the embeddings for wisdoms from the embeddings of wisdom2subwords -- #
+        # --- draw the embeddings for wisdoms from the embeddings of idiom2subwords -- #
         # this is to use as few newly initialised weights as possible
-        wisdom2subwords_embeddings = self.mlm.bert \
-            .embeddings.word_embeddings(self.wisdom2subwords)  # (W, K) -> (W, K, H)
-        wisdom_embeddings = self.pooler(wisdom2subwords_embeddings).squeeze()  # (W, H, K) -> (W, H, 1) -> (W, H)
+        idiom2subwords_embeddings = self.mlm.bert \
+            .embeddings.word_embeddings(self.idiom2subwords)  # (W, K) -> (W, K, H)
+        wisdom_embeddings = self.pooler(idiom2subwords_embeddings).squeeze()  # (W, H, K) -> (W, H, 1) -> (W, H)
         # --- draw H_wisdom from H_desc with attention --- #
         H_cls = H_all[:, 0]  # (N, L, H) -> (N, H)
         H_desc = self.H_desc(H_all)  # (N, L, H) -> (N, D, H)
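
The diff above is only a rename, but the gather in S_wisdom_literal is easy to misread. A minimal standalone sketch with synthetic shapes (not the repo's code) shows what it computes:

    import torch

    N, K, V, W = 2, 3, 10, 4                      # batch, subwords per idiom, vocab size, num idioms
    S_vocab = torch.randn(N, K, V)                # per-slot scores over the subword vocabulary
    idiom2subwords = torch.randint(0, V, (W, K))  # each idiom as K subword ids
    indices = idiom2subwords.T.repeat(N, 1, 1)    # (K, |W|) -> (N, K, |W|)
    S_literal = S_vocab.gather(dim=-1, index=indices).sum(dim=1)  # (N, |W|)
    assert S_literal.shape == (N, W)

Transposing to (K, |W|) lets gather pick, for each of the K masked slots, the score of every idiom's k-th subword; summing over K then yields one literal score per idiom.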
idiomify/tensors.py CHANGED
@@ -7,7 +7,7 @@ from typing import List
 from transformers import BertTokenizer


-def wisdom2subwords(idioms: List[str], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
+def idiom2subwords(idioms: List[str], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
     mask_id = tokenizer.mask_token_id
     pad_id = tokenizer.pad_token_id
     # temporarily disable single-token status of the wisdoms
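
For reference, a hypothetical call to the renamed helper; the tokenizer name here is an assumption, since the repo reads it from config['bert']:

    from transformers import BertTokenizer
    from idiomify import tensors as T

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint name
    idioms = ["beat around the bush", "call it a day"]
    idiom2subwords = T.idiom2subwords(idioms, tokenizer, k=11)  # -> (|W|, K) = (2, 11)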
main_infer.py CHANGED
@@ -0,0 +1,34 @@
+import argparse
+from idiomify.fetchers import fetch_config, fetch_idioms, fetch_rd
+from idiomify import tensors as T
+from transformers import BertTokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str,
+                        default="alpha")
+    parser.add_argument("--ver", type=str,
+                        default="eng2eng")
+    parser.add_argument("--sent", type=str,
+                        default="avoid getting to the point")
+    args = parser.parse_args()
+    config = fetch_config()[args.model][args.ver]
+    config.update(vars(args))
+    tokenizer = BertTokenizer.from_pretrained(config['bert'])
+    idioms = fetch_idioms(config['idioms_ver'])
+    X = T.inputs([config['sent']], tokenizer, config['k'])
+    rd = fetch_rd(config['model'], config['ver'])
+    probs = rd.P_wisdom(X).squeeze().tolist()
+    wisdom2prob = [
+        (wisdom, prob)
+        for wisdom, prob in zip(idioms, probs)
+    ]
+    # sort and print
+    res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
+    for idx, (idiom, prob) in enumerate(res):
+        print(idx, idiom, prob)
+
+
+if __name__ == '__main__':
+    main()
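
Given the argparse defaults above, the new inference script would presumably be invoked along these lines (the exact interpreter is an assumption):

    python3 main_infer.py --model alpha --ver eng2eng --sent "avoid getting to the point"

This prints every idiom ranked by its predicted probability for the input sentence, highest first.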
main_train.py CHANGED
@@ -15,7 +15,6 @@ from idiomify import tensors as T

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("entity", type=str)
     parser.add_argument("--model", type=str, default="alpha")
     parser.add_argument("--ver", type=str, default="eng2eng")
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
@@ -32,18 +31,18 @@ def main():
     mlm = BertForMaskedLM.from_pretrained(config['bert'])
     tokenizer = BertTokenizer.from_pretrained(config['bert'])
     idioms = fetch_idioms(config['idioms_ver'])
-    wisdom2subwords = T.wisdom2subwords(idioms, tokenizer, config['k'])
+    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])
     # choose the model to train
     if config['model'] == Alpha.name():
-        rd = Alpha(mlm, wisdom2subwords, config['k'], config['lr'])
+        rd = Alpha(mlm, idiom2subwords, config['k'], config['lr'])
     elif config['model'] == Gamma.name():
-        rd = Gamma(mlm, wisdom2subwords, config['k'], config['lr'])
+        rd = Gamma(mlm, idiom2subwords, config['k'], config['lr'])
     else:
         raise ValueError
     # prepare datamodule
     datamodule = IdiomifyDataModule(config, tokenizer, idioms)

-    with wandb.init(entity=config['entity'], project="idiomify_demo", config=config) as run:
+    with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(max_epochs=config['max_epochs'],
                              fast_dev_run=config['fast_dev_run'],
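
With the positional entity argument removed and the W&B entity hard-coded to "eubinecto" (the project is also renamed from idiomify_demo to idiomify-demo), a training run would presumably launch with just the optional flags, e.g.:

    python3 main_train.py --model alpha --ver eng2eng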