eubinecto committed
Commit cffca27 · unverified · 2 Parent(s): 210581d c1728bd

Merge pull request #10 from eubinecto/issue-9

config.yaml CHANGED
@@ -1,12 +1,13 @@
 # for training an idiomifier
 idiomifier:
-  ver: m-1-2
-  desc: just overfitting the model, but on the entire PIE dataset.
+  ver: m-1-3
+  desc: Just overfitting on PIE dataset, but now with <idiom> & </idiom> special tokens.
   bart: facebook/bart-base
-  lr: 0.0001
-  literal2idiomatic_ver: d-1-2
-  idioms_ver: d-1-2
-  max_epochs: 2
+  lr: 0.00005
+  literal2idiomatic_ver: d-1-3
+  idioms_ver: d-1-3
+  tokenizer_ver: t-1-1
+  max_epochs: 8
   batch_size: 40
   shuffle: true
   seed: 104
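On the new tokenizer_ver key: fetch_config (defined in idiomify/fetchers.py, not shown in this diff) presumably just parses config.yaml, so config['tokenizer_ver'] resolves to "t-1-1" everywhere below. A minimal sketch of that assumption:

import yaml

def fetch_config() -> dict:
    # assumption: config.yaml sits at the project root and is parsed with PyYAML
    with open("config.yaml", "r") as fh:
        return yaml.safe_load(fh)  # -> {'idiomifier': {'ver': 'm-1-3', 'tokenizer_ver': 't-1-1', ...}}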
explore/explore_bart_tokenizer_decode_idiom_special_tokens.py ADDED
@@ -0,0 +1,14 @@
+from idiomify.fetchers import fetch_tokenizer
+
+
+def main():
+    tokenizer = fetch_tokenizer("t-1-1")
+    sent = "There will always be a <idiom> silver lining </idiom> even when things look pitch black"
+    ids = tokenizer(sent)['input_ids']
+    print(ids)
+    decoded = tokenizer.decode(ids)
+    print(decoded)
+
+
+if __name__ == '__main__':
+    main()
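For context: t-1-1 is a BartTokenizer with <idiom> and </idiom> registered as additional special tokens. The build script is not part of this diff, so the following is only a sketch of how such a tokenizer could be produced:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
tokenizer.add_special_tokens({"additional_special_tokens": ["<idiom>", "</idiom>"]})
# special tokens survive the encode/decode round trip, which is what the script above verifies
ids = tokenizer("a <idiom> silver lining </idiom>")["input_ids"]
print(tokenizer.decode(ids))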
explore/explore_fetch_tokenizer.py CHANGED
@@ -12,6 +12,9 @@ def main():
     print(tokenizer.unk_token)
     print(tokenizer.additional_special_tokens)  # this should have been added
 
+    # the size of the vocab
+    print(len(tokenizer))
+
 
     """
     <s>
@@ -22,6 +25,7 @@
     <pad>
     <unk>
     ['<idiom>', '</idiom>']
+    50267
     """
 
 if __name__ == '__main__':
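The printed 50267 is consistent with the two added tokens: facebook/bart-base ships with a vocabulary of 50,265 entries, and <idiom> plus </idiom> bring len(tokenizer) up by two.

# 50265 (facebook/bart-base vocab) + 2 special tokens == 50267
assert 50265 + 2 == 50267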
idiomify/fetchers.py CHANGED
@@ -27,7 +27,7 @@ def fetch_idioms(ver: str, run: Run = None) -> pd.DataFrame:
         artifact = run.use_artifact(f"idioms:{ver}", type="dataset")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/idioms:{ver}", type="dataset")
-    artifact_dir = artifact.download(root=idioms_dir(ver))
+    artifact_dir = artifact.download(root=str(idioms_dir(ver)))
     tsv_path = path.join(artifact_dir, "all.tsv")
     return pd.read_csv(tsv_path, sep="\t")
 
@@ -39,7 +39,7 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
         artifact = run.use_artifact(f"literal2idiomatic:{ver}", type="dataset")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
-    artifact_dir = artifact.download(root=literal2idiomatic(ver))
+    artifact_dir = artifact.download(root=str(literal2idiomatic(ver)))
     train_path = path.join(artifact_dir, "train.tsv")
     test_path = path.join(artifact_dir, "test.tsv")
     train_df = pd.read_csv(train_path, sep="\t")
@@ -57,9 +57,10 @@ def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/idiomifier:{ver}", type="model")
     config = artifact.metadata
-    artifact_dir = artifact.download(root=idiomifier_dir(ver))
+    artifact_dir = artifact.download(root=str(idiomifier_dir(ver)))
     ckpt_path = path.join(artifact_dir, "model.ckpt")
     bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
+    bart.resize_token_embeddings(config['vocab_size'])
     model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
     return model
 
@@ -69,7 +70,7 @@ def fetch_tokenizer(ver: str, run: Run = None) -> BartTokenizer:
         artifact = run.use_artifact(f"tokenizer:{ver}", type="other")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/tokenizer:{ver}", type="other")
-    artifact_dir = artifact.download(root=tokenizer_dir(ver))
+    artifact_dir = artifact.download(root=str(tokenizer_dir(ver)))
     tokenizer = BartTokenizer.from_pretrained(artifact_dir)
     return tokenizer
 
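Why the resize happens before load_from_checkpoint: the checkpoint was trained with the enlarged embedding matrix, so the freshly configured BART must be resized to match before its weights are overwritten. A minimal sketch, assuming vocab_size == 50267 was stored in the artifact metadata at train time:

from transformers import AutoConfig, AutoModelForSeq2SeqLM

bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained("facebook/bart-base"))
bart.resize_token_embeddings(50267)  # 50265 -> 50267, matching the checkpoint
# without this, loading the checkpoint would raise a size-mismatch error
# on the embedding weights ([50267, 768] vs [50265, 768] for bart-base)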
idiomify/models.py CHANGED
@@ -71,4 +71,3 @@ class Idiomifier(pl.LightningModule):  # noqa
     """
     # The authors used Adam, so we might as well use it as well.
     return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
-
idiomify/paths.py CHANGED
@@ -19,4 +19,3 @@ def idiomifier_dir(ver: str) -> Path:
 
 def tokenizer_dir(ver: str) -> Path:
     return ARTIFACTS_DIR / f"tokenizer_{ver}"
-
idiomify/pipeline.py CHANGED
@@ -1,3 +1,4 @@
+import re
 from typing import List
 from transformers import BartTokenizer
 from idiomify.builders import SourcesBuilder
@@ -18,5 +19,9 @@ class Pipeline:
             decoder_start_token_id=self.model.hparams['bos_token_id'],
             max_length=max_length,
         )  # -> (N, L_t)
-        tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
+        tgts = [
+            re.sub(r"<s>|</s>", "", tgt)
+            for tgt in tgts
+        ]
         return tgts
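The switch to skip_special_tokens=False is deliberate: <idiom> and </idiom> are registered as special tokens, so skip_special_tokens=True would strip them along with <s> and </s>. Decoding everything and then removing only the BOS/EOS markers keeps the idiom tags in the output. An illustrative example (the decoded string is assumed, not a real model output):

import re

decoded = "<s>There will always be a <idiom> silver lining </idiom></s>"
print(re.sub(r"<s>|</s>", "", decoded))
# -> There will always be a <idiom> silver lining </idiom>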
idiomify/preprocess.py CHANGED
@@ -59,4 +59,3 @@ def stratified_split(df: pd.DataFrame, ratio: float, seed: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
                                          test_size=other_size, random_state=seed,
                                          shuffle=True)
     return ratio_df, other_df
-
main_deploy.py CHANGED
@@ -1,20 +1,18 @@
 """
 we deploy the pipeline via streamlit.
 """
-from typing import Tuple, List
+import re
 import streamlit as st
-from transformers import BartTokenizer
-from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms
+from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms, fetch_tokenizer
 from idiomify.pipeline import Pipeline
-from idiomify.models import Idiomifier
 
 
 @st.cache(allow_output_mutation=True)
-def fetch_resources() -> Tuple[dict, Idiomifier, BartTokenizer, List[str]]:
+def fetch_resources() -> tuple:
     config = fetch_config()['idiomifier']
     model = fetch_idiomifier(config['ver'])
+    tokenizer = fetch_tokenizer(config['tokenizer_ver'])
     idioms = fetch_idioms(config['idioms_ver'])
-    tokenizer = BartTokenizer.from_pretrained(config['bart'])
     return config, model, tokenizer, idioms
 
 
@@ -24,20 +22,21 @@ def main():
     model.eval()
     pipeline = Pipeline(model, tokenizer)
     st.title("Idiomify Demo")
-    st.markdown(f"Author: `Eu-Bin KIM`")
-    st.markdown(f"Version: `{config['ver']}`")
     text = st.text_area("Type sentences here",
-                        value="Just remember there will always be a hope even when things look black")
+                        value="Just remember that there will always be a hope even when things look hopeless")
     with st.sidebar:
         st.subheader("Supported idioms")
+        idioms = [row["Idiom"] for _, row in idioms.iterrows()]
         st.write(" / ".join(idioms))
 
     if st.button(label="Idiomify"):
        with st.spinner("Please wait..."):
            sents = [sent for sent in text.split(".") if sent]
-           sents = pipeline(sents, max_length=200)
+           preds = pipeline(sents, max_length=200)
            # highlight the rule & honorifics that were applied
-           st.write(". ".join(sents))
+           preds = [re.sub(r"<idiom>|</idiom>", "`", pred)
+                    for pred in preds]
+           st.markdown(". ".join(preds))
 
 
 if __name__ == '__main__':
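Substituting both tags with a backtick turns each <idiom> ... </idiom> span into inline code, which st.markdown renders highlighted. Illustrative:

import re

pred = "There will always be a <idiom> silver lining </idiom>"
print(re.sub(r"<idiom>|</idiom>", "`", pred))
# -> There will always be a ` silver lining `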
main_eval.py CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from transformers import BartTokenizer
 from idiomify.datamodules import IdiomifyDataModule
-from idiomify.fetchers import fetch_config, fetch_idiomifier
+from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
 from idiomify.paths import ROOT_DIR
 
 
@@ -17,10 +17,10 @@ def main():
     args = parser.parse_args()
     config = fetch_config()['idiomifier']
     config.update(vars(args))
-    tokenizer = BartTokenizer.from_pretrained(config['bart'])
     # prepare the datamodule
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
         model = fetch_idiomifier(config['ver'], run)  # fetch a pre-trained model
+        tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
         datamodule = IdiomifyDataModule(config, tokenizer, run)
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
main_infer.py CHANGED
@@ -3,25 +3,24 @@ This is for just a simple sanity check on the inference.
 """
 import argparse
 from idiomify.pipeline import Pipeline
-from idiomify.fetchers import fetch_config, fetch_idiomifier
+from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
 from transformers import BartTokenizer
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--sent", type=str,
-                        default="If there's any good to loosing my job,"
-                                " it's that I'll now be able to go to school full-time and finish my degree earlier.")
+                        default="Just remember that there will always be a hope even when things look hopeless")
     args = parser.parse_args()
     config = fetch_config()['idiomifier']
     config.update(vars(args))
     model = fetch_idiomifier(config['ver'])
+    tokenizer = fetch_tokenizer(config['tokenizer_ver'])
     model.eval()  # this is crucial
-    tokenizer = BartTokenizer.from_pretrained(config['bart'])
     pipeline = Pipeline(model, tokenizer)
     src = config['sent']
-    tgt = pipeline(sents=[config['sent']])
-    print(src, "\n->", tgt)
+    tgts = pipeline(sents=[src])
+    print(src, "\n->", tgts[0])
 
 
 if __name__ == '__main__':
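A sanity-check invocation would look like this (assuming the script is run from the project root; output not reproduced here):

python3 main_infer.py --sent "Just remember that there will always be a hope even when things look hopeless"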
main_train.py CHANGED
@@ -5,9 +5,9 @@ import argparse
 import pytorch_lightning as pl
 from termcolor import colored
 from pytorch_lightning.loggers import WandbLogger
-from transformers import BartTokenizer, BartForConditionalGeneration
+from transformers import BartForConditionalGeneration
 from idiomify.datamodules import IdiomifyDataModule
-from idiomify.fetchers import fetch_config
+from idiomify.fetchers import fetch_config, fetch_tokenizer
 from idiomify.models import Idiomifier
 from idiomify.paths import ROOT_DIR
 
@@ -23,12 +23,13 @@ def main():
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
-    # prepare the model
+    # prepare a pre-trained BART
     bart = BartForConditionalGeneration.from_pretrained(config['bart'])
-    tokenizer = BartTokenizer.from_pretrained(config['bart'])
-    model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
     # prepare the datamodule
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
+        bart.resize_token_embeddings(len(tokenizer))  # because new tokens are added, this process is necessary
+        model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
         datamodule = IdiomifyDataModule(config, tokenizer, run)
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(max_epochs=config['max_epochs'],
@@ -44,6 +45,7 @@ def main():
     if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
         ckpt_path = ROOT_DIR / "model.ckpt"
         trainer.save_checkpoint(str(ckpt_path))
+        config['vocab_size'] = len(tokenizer)  # this will be needed to fetch a pretrained idiomifier later
         artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
         artifact.add_file(str(ckpt_path))
         run.log_artifact(artifact, aliases=["latest", config['ver']])
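Storing vocab_size in the artifact metadata closes the loop with fetch_idiomifier above, which reads it back via artifact.metadata and resizes the embeddings before loading the checkpoint. A self-contained check of that invariant (assuming the 50,267-token tokenizer from this PR):

from transformers import BartForConditionalGeneration

bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
bart.resize_token_embeddings(50267)
assert bart.get_input_embeddings().weight.shape[0] == 50267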
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 pytorch-lightning==1.5.10
-transformers==4.16.2
-wandb==0.12.10
+transformers==4.17.0
+wandb==0.12.11
 scikit-learn==1.0.2
-pandas==1.3.5
+pandas==1.4.1
 streamlit==1.7.0
 watchdog==2.1.6