[#1] Merging issue_1 to main

Files changed:
- .gitignore +5 -0
- README.md +8 -75
- config.yaml +8 -0
- explore/explore_bart.py +16 -0
- explore/explore_bart_for_conditional_generation.py +10 -0
- explore/explore_bart_logits_shape.py +39 -0
- explore/explore_fetch_idioms.py +9 -0
- explore/explore_fetch_literal2idiomatic.py +10 -0
- explore/explore_fetch_pie.py +14 -0
- explore/explore_fetch_seq2seq.py +10 -0
- explore/explore_fetch_seq2seq_predict.py +19 -0
- explore/explore_idiom2subwords.py +0 -0
- explore/explore_idiomifydatamodule.py +26 -0
- explore/explore_nlpaug.py +21 -0
- explore/explore_src_builder.py +18 -0
- explore/explore_tgt_builder.py +19 -0
- idiomify/__init__.py +0 -0
- idiomify/builders.py +87 -0
- idiomify/data.py +78 -0
- idiomify/fetchers.py +71 -0
- idiomify/metrics.py +4 -0
- idiomify/models.py +78 -0
- idiomify/paths.py +17 -0
- idiomify/urls.py +13 -0
- main_infer.py +25 -0
- main_train.py +56 -0
- main_upload_idioms.py +37 -0
- main_upload_literal2idiomatic.py +40 -0
- requirements.txt +3 -0
.gitignore
CHANGED
@@ -127,3 +127,8 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+artifacts
+wandb
+.idea
+
README.md
CHANGED
@@ -1,80 +1,13 @@
-#
-A human-inspired Idiomifier.
-
-keywords: idiomify, idioms, inductive biases, novel predictions
-
-## What are your research questions?
-
-Given the following two [connectionist models](https://huggingface.co/bert-base-uncased) (two versions of a language model called BERT):
-
-| models | what task has it learned already? | what new task will it be taught? |
-| --- | --- | --- |
-| L1 Idiomifier | Has been pre-trained with a fill-in-the-blank task on English Wikipedia only (i.e. monolingual BERT) | Eng2Eng Idiomify task. |
-| L2 Idiomifier | Has been pre-trained with a fill-in-the-blank task on Wikipedia in multiple languages, including English (i.e. multilingual BERT) | Eng2Eng Idiomify task (the same). |
-
-where examples of the Eng2Eng Idiomify task are:
-<img width="813" alt="image" src="https://user-images.githubusercontent.com/56193069/154847480-adacff57-68fc-40c1-af73-dab478f8ab19.png">
-
-I have the following two research questions:
-
-1. (SLA → NLP) If we have both of the models **decrementally infer** the figurative meaning of idioms from their constituents, will this lead to increased performance on the Eng2Eng Idiomify task?
-2. (NLP → SLA) What differences can we observe between the L1 & L2 Idiomifiers in how they learn the Eng2Eng Idiomify task? From this, can we draw any **novel predictions** on how L1 & L2 learners might differ in learning idioms?
-
-## But why? What is your rationale?
-
-<img width="581" alt="image" src="https://user-images.githubusercontent.com/56193069/154847506-88c4283d-8a35-4c53-81c1-83c193ecf739.png">
-
-In short, the reason I have the two questions is to **kill two birds with one stone**, where the two birds are ***suggesting better biases*** and ***suggesting novel predictions***, and the stone is ***designing a human-inspired Idiomifier***.
-
-### What do you mean by the first bird, *suggest better biases*? (SLA → NLP)
-
-I think we could improve machines at processing idioms if we draw inspiration from how humans go about learning idioms. That is, if we could introduce human-inspired biases to machines, we may be able to improve their performance on figurative processing.
-
-<img width="800" alt="image" src="https://user-images.githubusercontent.com/56193069/154848885-0e40af8d-7554-429e-aff3-965e6121afec.png">
-
-But first, why do we even need machines to better understand idioms? Because, although huge progress has been made within Natural Language Processing (NLP) in recent years, **figurative processing has always been a "pain in the neck" in NLP, so to speak.** Take [BERT](https://arxiv.org/abs/1810.04805) as an example. It is a connectionist language model that can be finetuned to fill in the blanks (top left), answer a question (top right), summarize a paragraph (bottom left), analyse sentiments (bottom right), etc. These are by no means easy tasks for machines, but as you can see from the examples above, the performance of BERT on these colloquial tasks is quite impressive.
-
-<img width="893" alt="image" src="https://user-images.githubusercontent.com/56193069/154848914-67a3aa0f-2171-433e-8a56-2187fff60f7c.png">
-
-However, when it comes to processing idioms, BERT is far from impressive. Without even getting into the literature, you can already see how replacing *get ready* (left) with *wet my gills* (right) substantially changes the predictions on the fill-in-the-blank task, although the two phrases essentially mean the same thing. Ideally, the probability distribution should stay more or less the same, but it doesn't. This is because, as with many other language models, BERT falls short at processing figures of speech.
-
-<img width="872" alt="image" src="https://user-images.githubusercontent.com/56193069/154848931-2b81a5fe-85b0-4868-bd20-d7326f83b9f3.png">
-
-Given that the goal of NLP is to "process all forms of natural language well" (Haagsma, 2020), NLP researchers have started to point out this problem in recent years. Just like humans, a well-designed NLP system should be able to process any form of natural language, whether it be formal (e.g. writing an email), colloquial (e.g. chatting with friends), or canonical / structured (e.g. writing essays). While some success has been made in processing canonical language as we saw above, language models are "still far from revealing implicit meaning" of figures of speech (Shwartz & Dagan, 2019). Likewise, "idiomatic meaning gets overpowered by compositional meaning of the expressions" (Saxena & Paul, 2020), partly because their constituents are more often found separately in corpora than together as idioms. All in all, "figurative language is an important research area for computational & cognitive linguistics", as the ACL remarks in their report on the 2020 workshop, which was aptly named *Figurative Language Processing*.
-
-<img width="529" alt="image" src="https://user-images.githubusercontent.com/56193069/154848936-206d4d8a-3232-412c-91c6-62719207e1f0.png">
-
-So, there is huge room for improvement in figurative language processing, but where do we get the ideas for the improvement? We could take various approaches to this, but Shwartz & Dagan (2019) suggest what I think is arguably the most sensible one: "get some inspiration from the way that **humans learn idioms**". We at least have a working answer in the human brain, however elusive it may be, so it is sensible to at least try to replicate this in machines rather than to invent a completely new solution from scratch. It works in the human brain, so it may as well work in connectionist language models (layers of artificial neural networks). This is what I mean by SLA suggesting *better biases* to NLP: we could improve the performance of such language models on processing idioms, specifically BERT for my dissertation, by drawing inspiration (i.e. biases) from how humans learn idioms.
-
-<img width="676" alt="image" src="https://user-images.githubusercontent.com/56193069/154848943-c800b0ca-5ad1-437a-9590-46b6b5d5cfb2.png">
-
-What better biases have I found, then? The Global Elaboration Hypothesis (Levorato & Cacciari, 1995; Karlson, 2019) posits that both L1 and L2 learners may start learning idioms by first deducing the figurative meaning from the literal meaning, for those idioms that have yet to take their place in the learners' mental lexicon (vocabulary). It is not as if they get the metaphor behind the literal interpretation right off the bat. However, as the learners age and continue learning those idioms, they gradually treat idioms as a single chunk and stop relying on analogies to get the figurative meaning. For example, when L2 learners of English encounter the idiom *throw the baby out with the bathwater* for the first time, their first reaction is to interpret the meaning literally, which they analogize with a given context to guess the figurative meaning, *to discard something valuable along with the unwanted*. However, as they go along, they gradually stop imagining babies being thrown out together with dirty water, and in the end they don't even think of babies when using *throw the baby out with the bathwater* in its idiomatic sense - they just use it as a single chunk at the end of their learning.
-
-If that's how we go about learning idioms, that is, if humans use the literal interpretation of idioms to "bootstrap" their understanding of the figurative meaning, so to speak, then there is nothing stopping us from expecting that such a bootstrapping bias may be useful for teaching idioms to machines.
-
-Hence, I believe it is sensible to ask the first question:
-
-1. (SLA → NLP) If we have both of the models **decrementally infer** the figurative meaning of idioms from their constituents, will this lead to increased performance on the Eng2Eng Idiomify task? If so, what would be the mathematical interpretation of such human-inspired success?
-
-### What do you mean by the second bird, *suggest novel predictions*? (NLP → SLA)
-(work in progress)
-
-2. (NLP → SLA) What differences can we observe between the L1 & L2 Idiomifiers in how they learn idioms? From this, **can we draw any novel predictions on how L1 & L2 learners learn idioms?**
-
-## Miscellaneous
-
-
+# Idiomify
+
+A human-inspired Idiomifier based on BERT
+
+<img width="807" alt="image" src="https://user-images.githubusercontent.com/56193069/153775460-5ca04edd-e788-442d-b0f1-e780dc0a5724.png">
+
+## Requirements
+- wandb
+- pytorch-lightning
+- transformers
+- pandas
config.yaml
ADDED
@@ -0,0 +1,8 @@
+tag011:
+  desc: just overfitting
+  bart: facebook/bart-base
+  lr: 0.0001
+  literal2idiomatic_ver: tag01
+  max_epochs: 100
+  batch_size: 100
+  shuffle: true
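Each top-level key in config.yaml is a version tag. As main_train.py and main_infer.py below show, a run picks one block with `fetch_config()[args.ver]` and then overlays the CLI arguments on top of it. A minimal sketch of that lookup, assuming the idiomify.fetchers module introduced later in this diff:

from idiomify.fetchers import fetch_config

config = fetch_config()["tag011"]    # the tag011 block as a plain dict
config.update({"num_workers": 4})    # CLI args are overlaid like this in main_train.py
print(config["bart"], config["lr"])  # facebook/bart-base 0.0001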
explore/explore_bart.py
ADDED
@@ -0,0 +1,16 @@
+from transformers import BartTokenizer, BartModel
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    model = BartModel.from_pretrained('facebook/bart-large')
+    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+    outputs = model(**inputs)
+    H_all = outputs.last_hidden_state  # noqa
+    print(H_all.shape)  # (1, 8, 1024)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_bart_for_conditional_generation.py
ADDED
@@ -0,0 +1,10 @@
+
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+
+def main():
+    pass
+
+
+if __name__ == '__main__':
+    main()
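The file above is committed as a stub. For reference only, here is a minimal sketch of what exercising BartForConditionalGeneration could look like; this is not part of the commit, just an illustration consistent with how Idiomifier calls `generate` in idiomify/models.py below:

from transformers import BartTokenizer, BartForConditionalGeneration


def main():
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    # greedy decoding from the pretrained (not yet idiomify-finetuned) checkpoint
    pred_ids = bart.generate(inputs['input_ids'],
                             attention_mask=inputs['attention_mask'],
                             max_length=20)
    print(tokenizer.decode(pred_ids.squeeze(), skip_special_tokens=True))


if __name__ == '__main__':
    main()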
explore/explore_bart_logits_shape.py
ADDED
@@ -0,0 +1,39 @@
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+from data import IdiomifyDataModule
+
+
+CONFIG = {
+    "literal2idiomatic_ver": "pie_v0",
+    "batch_size": 20,
+    "num_workers": 4,
+    "shuffle": True
+}
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
+    datamodule.prepare_data()
+    datamodule.setup()
+    for batch in datamodule.train_dataloader():
+        srcs, tgts_r, tgts = batch
+        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]  # noqa
+        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]
+        outputs = bart(input_ids=input_ids,
+                       attention_mask=attention_mask,
+                       decoder_input_ids=decoder_input_ids,
+                       decoder_attention_mask=decoder_attention_mask)
+        logits = outputs[0]
+        print(logits.shape)
+        """
+        torch.Size([20, 47, 50265])
+        (N, L, |V|)
+        """
+        break
+
+
+if __name__ == '__main__':
+    main()
explore/explore_fetch_idioms.py
ADDED
@@ -0,0 +1,9 @@
+from idiomify.fetchers import fetch_idioms
+
+
+def main():
+    print(fetch_idioms("pie_v0"))
+
+
+if __name__ == '__main__':
+    main()
explore/explore_fetch_literal2idiomatic.py
ADDED
@@ -0,0 +1,10 @@
+from idiomify.fetchers import fetch_literal2idiomatic
+
+
+def main():
+    for src, tgt in fetch_literal2idiomatic("pie_v0"):
+        print(src, "->", tgt)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_fetch_pie.py
ADDED
@@ -0,0 +1,14 @@
+
+from idiomify.fetchers import fetch_pie
+
+
+def main():
+    for idx, row in enumerate(fetch_pie()):
+        print(idx, row)
+        # the first 105 = V0.
+        if idx == 105:
+            break
+
+
+if __name__ == '__main__':
+    main()
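Judging from the upload scripts later in this diff (main_upload_idioms.py and main_upload_literal2idiomatic.py), each PIE row packs the idiom and its sentence pair together. A sketch of how this commit indexes the columns; this is an inference from those scripts, not from the PIE dataset's own documentation:

from idiomify.fetchers import fetch_pie

for row in fetch_pie()[:3]:
    idiom = row[0]      # the idiom itself, per main_upload_idioms.py
    idiomatic = row[2]  # sentence containing the idiom, per main_upload_literal2idiomatic.py
    literal = row[3]    # literal paraphrase of the same sentence
    print(idiom, "|", literal, "->", idiomatic)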
explore/explore_fetch_seq2seq.py
ADDED
@@ -0,0 +1,10 @@
+from idiomify.fetchers import fetch_seq2seq
+
+
+def main():
+    model = fetch_seq2seq("overfit")
+    print(model.bart.config)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_fetch_seq2seq_predict.py
ADDED
@@ -0,0 +1,19 @@
+from transformers import BartTokenizer
+from builders import SourcesBuilder
+from fetchers import fetch_seq2seq
+
+
+def main():
+    model = fetch_seq2seq("overfit")
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+    lit2idi = [
+        ("my man", ""),
+        ("hello", "")
+    ]  # just some dummy stuff
+    srcs = SourcesBuilder(tokenizer)(lit2idi)
+    out = model.predict(srcs=srcs)
+    print(out)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_idiom2subwords.py
ADDED
File without changes
explore/explore_idiomifydatamodule.py
ADDED
@@ -0,0 +1,26 @@
+from transformers import BartTokenizer
+from idiomify.data import IdiomifyDataModule
+
+
+CONFIG = {
+    "literal2idiomatic_ver": "pie_v0",
+    "batch_size": 20,
+    "num_workers": 4,
+    "shuffle": True
+}
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
+    datamodule.prepare_data()
+    datamodule.setup()
+    for batch in datamodule.train_dataloader():
+        srcs, tgts_r, tgts = batch
+        print(srcs.shape)    # (N, 2, L)
+        print(tgts_r.shape)  # (N, 2, L)
+        print(tgts.shape)    # (N, L)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_nlpaug.py
ADDED
@@ -0,0 +1,21 @@
+
+import nlpaug.augmenter.word as naw
+import nlpaug.augmenter.sentence as nas
+
+import nltk
+
+
+sent = "I am really happy with the new job and I mean that with sincere feeling"
+
+
+def main():
+    nltk.download("omw-1.4")
+    # this seems legit! I could definitely use this to increase the accuracy of the model
+    # for a few idioms (possibly ten very different but frequent idioms)
+    aug = naw.ContextualWordEmbsAug()
+    augmented = aug.augment(sent, n=10)
+    print(augmented)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_src_builder.py
ADDED
@@ -0,0 +1,18 @@
+from transformers import BartTokenizer
+from idiomify.builders import SourcesBuilder
+
+BATCH = [
+    ("I could die at any moment", "I could kick the bucket at any moment"),
+    ("Speak plainly", "Don't beat around the bush")
+]
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+    builder = SourcesBuilder(tokenizer)
+    src = builder(BATCH)
+    print(src)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_tgt_builder.py
ADDED
@@ -0,0 +1,19 @@
+from transformers import BartTokenizer
+from idiomify.builders import TargetsBuilder
+
+BATCH = [
+    ("I could die at any moment", "I could kick the bucket at any moment"),
+    ("Speak plainly", "Don't beat around the bush")
+]
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+    builder = TargetsBuilder(tokenizer)
+    tgt_r, tgt = builder(BATCH)
+    print(tgt_r)
+    print(tgt)
+
+
+if __name__ == '__main__':
+    main()
idiomify/__init__.py
ADDED
File without changes
idiomify/builders.py
ADDED
@@ -0,0 +1,87 @@
+"""
+all the classes for building tensors are defined here.
+builders must accept a tokenizer as one of the parameters.
+"""
+import torch
+from typing import List, Tuple
+from transformers import BartTokenizer
+
+
+class TensorBuilder:
+
+    def __init__(self, tokenizer: BartTokenizer):
+        self.tokenizer = tokenizer
+
+    def __call__(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class Idiom2SubwordsBuilder(TensorBuilder):
+
+    def __call__(self, idioms: List[str], k: int) -> torch.Tensor:
+        """
+        1. The function takes in a list of idioms and a maximum length of the input sequence.
+        2. It then splits the idioms into words, and pads the sequence to the maximum length.
+        3. It masks the padding tokens, and returns the input ids.
+        :param idioms: a list of idioms, each of which is a string of space-separated words
+        :type idioms: List[str]
+        :param k: the maximum length of the idioms
+        :type k: int
+        :return: the input_ids of the idioms, with the pad tokens replaced by the mask token
+        """
+        mask_id = self.tokenizer.mask_token_id
+        pad_id = self.tokenizer.pad_token_id
+        # temporarily disable the single-token status of the idioms
+        idioms = [idiom.split(" ") for idiom in idioms]
+        encodings = self.tokenizer(text=idioms,
+                                   add_special_tokens=False,
+                                   # should be set to True, as we already have the idioms split
+                                   is_split_into_words=True,
+                                   padding='max_length',
+                                   max_length=k,  # set to k
+                                   return_tensors="pt")
+        input_ids = encodings['input_ids']
+        input_ids[input_ids == pad_id] = mask_id
+        return input_ids
+
+
+class SourcesBuilder(TensorBuilder):
+    """
+    to be used for both training and inference
+    """
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
+        encodings = self.tokenizer(text=[literal for literal, _ in literal2idiomatic],
+                                   return_tensors="pt",
+                                   padding=True,
+                                   truncation=True,
+                                   add_special_tokens=True)
+        src = torch.stack([encodings['input_ids'],
+                           encodings['attention_mask']], dim=1)  # (N, 2, L)
+        return src  # (N, 2, L)
+
+
+class TargetsRightShiftedBuilder(TensorBuilder):
+    """
+    This is to be used only for training. As for inference, we don't need this.
+    """
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
+        encodings = self.tokenizer([
+            self.tokenizer.bos_token + idiomatic  # starts with bos, but does not end with eos (right-shifted)
+            for _, idiomatic in literal2idiomatic
+        ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
+        tgts_r = torch.stack([encodings['input_ids'],
+                              encodings['attention_mask']], dim=1)  # (N, 2, L)
+        return tgts_r
+
+
+class TargetsBuilder(TensorBuilder):
+
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
+        encodings = self.tokenizer([
+            idiomatic + self.tokenizer.eos_token  # no bos, but ends with eos
+            for _, idiomatic in literal2idiomatic
+        ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
+        tgts = encodings['input_ids']
+        return tgts  # (N, L)
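To make the right-shifting concrete: for teacher forcing, the decoder input (tgts_r) is the idiomatic sentence prefixed with <s>, while the label sequence (tgts) is the same sentence suffixed with </s>, so the token at position t of the decoder input is trained to predict the label at position t. A small sketch using the builders above (the sentence pair is just an illustration):

from transformers import BartTokenizer
from idiomify.builders import TargetsRightShiftedBuilder, TargetsBuilder

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
pair = [("Speak plainly", "Don't beat around the bush")]

tgts_r = TargetsRightShiftedBuilder(tokenizer)(pair)  # (1, 2, L)
tgts = TargetsBuilder(tokenizer)(pair)                # (1, L)

print(tokenizer.decode(tgts_r[0, 0]))  # starts with <s>, no trailing </s>
print(tokenizer.decode(tgts[0]))       # no leading <s>, ends with </s>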
idiomify/data.py
ADDED
@@ -0,0 +1,78 @@
+import torch
+from typing import Tuple, Optional, List
+from torch.utils.data import Dataset, DataLoader
+from pytorch_lightning import LightningDataModule
+from wandb.sdk.wandb_run import Run
+
+from idiomify.fetchers import fetch_literal2idiomatic
+from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
+from transformers import BartTokenizer
+
+
+class IdiomifyDataset(Dataset):
+    def __init__(self,
+                 srcs: torch.Tensor,
+                 tgts_r: torch.Tensor,
+                 tgts: torch.Tensor):
+        self.srcs = srcs      # (N, 2, L)
+        self.tgts_r = tgts_r  # (N, 2, L)
+        self.tgts = tgts      # (N, L)
+
+    def __len__(self) -> int:
+        """
+        Returns the size of the dataset.
+        """
+        assert self.srcs.shape[0] == self.tgts_r.shape[0] == self.tgts.shape[0]
+        return self.srcs.shape[0]
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+        """
+        Returns the features & the label at the given index.
+        """
+        return self.srcs[idx], self.tgts_r[idx], self.tgts[idx]
+
+
+class IdiomifyDataModule(LightningDataModule):
+
+    # boilerplate - just ignore these
+    def test_dataloader(self):
+        pass
+
+    def val_dataloader(self):
+        pass
+
+    def predict_dataloader(self):
+        pass
+
+    def __init__(self,
+                 config: dict,
+                 tokenizer: BartTokenizer,
+                 run: Run = None):
+        super().__init__()
+        self.config = config
+        self.tokenizer = tokenizer
+        self.run = run
+        # --- to be downloaded & built --- #
+        self.literal2idiomatic: Optional[List[Tuple[str, str]]] = None
+        self.dataset: Optional[IdiomifyDataset] = None
+
+    def prepare_data(self):
+        """
+        prepare: download all the data needed for this datamodule from wandb to local.
+        """
+        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'], self.run)
+
+    def setup(self, stage: Optional[str] = None):
+        # --- set up the builders & build the dataset --- #
+        srcs = SourcesBuilder(self.tokenizer)(self.literal2idiomatic)
+        tgts_r = TargetsRightShiftedBuilder(self.tokenizer)(self.literal2idiomatic)
+        tgts = TargetsBuilder(self.tokenizer)(self.literal2idiomatic)
+        self.dataset = IdiomifyDataset(srcs, tgts_r, tgts)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.dataset, batch_size=self.config['batch_size'],
+                          shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
idiomify/fetchers.py
ADDED
@@ -0,0 +1,71 @@
+import csv
+from os import path
+import yaml
+import wandb
+import requests
+from typing import Tuple, List
+from wandb.sdk.wandb_run import Run
+from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, seq2seq_dir
+from idiomify.urls import PIE_URL
+from transformers import AutoModelForSeq2SeqLM, AutoConfig
+from idiomify.models import Seq2Seq
+
+
+def fetch_pie() -> list:
+    text = requests.get(PIE_URL).text
+    lines = (line for line in text.split("\n") if line)
+    reader = csv.reader(lines)
+    next(reader)  # skip the header
+    return [
+        row
+        for row in reader
+    ]
+
+
+# --- from wandb --- #
+def fetch_idioms(ver: str, run: Run = None) -> List[str]:
+    """
+    why do you need this? -> you need this to have access to the idiom embeddings.
+    """
+    # if a run object is given, we track the lineage of the data.
+    # if not, we get the dataset via the wandb Api.
+    if run:
+        artifact = run.use_artifact(f"idioms:{ver}", type="dataset")
+    else:
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/idioms:{ver}", type="dataset")
+    artifact_dir = artifact.download(root=idioms_dir(ver))
+    txt_path = path.join(artifact_dir, "all.txt")
+    with open(txt_path, 'r') as fh:
+        return [line.strip() for line in fh]
+
+
+def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+    # if a run object is given, we track the lineage of the data.
+    # if not, we get the dataset via the wandb Api.
+    if run:
+        artifact = run.use_artifact(f"literal2idiomatic:{ver}", type="dataset")
+    else:
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
+    artifact_dir = artifact.download(root=literal2idiomatic(ver))
+    tsv_path = path.join(artifact_dir, "all.tsv")
+    with open(tsv_path, 'r') as fh:
+        reader = csv.reader(fh, delimiter="\t")
+        return [(row[0], row[1]) for row in reader]
+
+
+def fetch_seq2seq(ver: str, run: Run = None) -> Seq2Seq:
+    if run:
+        artifact = run.use_artifact(f"seq2seq:{ver}", type="model")
+    else:
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/seq2seq:{ver}", type="model")
+    config = artifact.metadata
+    artifact_dir = artifact.download(root=seq2seq_dir(ver))
+    ckpt_path = path.join(artifact_dir, "model.ckpt")
+    bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
+    alpha = Seq2Seq.load_from_checkpoint(ckpt_path, bart=bart)
+    return alpha
+
+
+def fetch_config() -> dict:
+    with open(str(CONFIG_YAML), 'r', encoding="utf-8") as fh:
+        return yaml.safe_load(fh)
idiomify/metrics.py
ADDED
@@ -0,0 +1,4 @@
+"""
+You may want to include the BLEU score here,
+and more metrics for paraphrasing.
+"""
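Following up on the note above, a minimal sketch of what a BLEU metric could look like here. This is an assumption on my part: it uses NLTK (which the exploration scripts already import), not anything committed to the repo:

from typing import List
from nltk.translate.bleu_score import corpus_bleu


def bleu(preds: List[str], tgts: List[str]) -> float:
    """corpus-level BLEU of the predicted idiomatic sentences against single references"""
    hypotheses = [pred.split() for pred in preds]
    references = [[tgt.split()] for tgt in tgts]  # one reference per hypothesis
    return corpus_bleu(references, hypotheses)


# e.g. bleu(["Don't beat around the bush"], ["Don't beat around the bush"]) == 1.0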
idiomify/models.py
ADDED
@@ -0,0 +1,78 @@
+"""
+The reverse-dictionary models below are based off of: https://github.com/yhcc/BertForRD/blob/master/mono/model/bert.py
+"""
+from typing import Tuple
+import torch
+from torch.nn import functional as F
+import pytorch_lightning as pl
+from transformers import BartForConditionalGeneration, BartTokenizer
+from idiomify.builders import SourcesBuilder
+
+
+# for training
+class Seq2Seq(pl.LightningModule):  # noqa
+    """
+    the baseline is in here.
+    """
+    def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int):  # noqa
+        super().__init__()
+        self.bart = bart
+        self.save_hyperparameters(ignore=["bart"])
+
+    def forward(self, srcs: torch.Tensor, tgts_r: torch.Tensor) -> torch.Tensor:
+        """
+        As for using BART for conditional generation, refer to:
+        https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartForQuestionAnswering.forward
+        :param srcs: (N, 2, L_s)
+        :param tgts_r: (N, 2, L_t)
+        :return: (N, L, |V|)
+        """
+        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]
+        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]
+        outputs = self.bart(input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            decoder_input_ids=decoder_input_ids,
+                            decoder_attention_mask=decoder_attention_mask)
+        logits = outputs[0]  # (N, L, |V|)
+        return logits
+
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
+        srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, L_t)
+        logits = self.forward(srcs, tgts_r)  # -> (N, L, |V|)
+        logits = logits.transpose(1, 2)  # (N, L, |V|) -> (N, |V|, L)
+        loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
+            .sum()  # (N, |V|, L), (N, L) -> (1,); cross_entropy already averages over the non-pad tokens
+        return {
+            "loss": loss
+        }
+
+    def on_train_batch_end(self, outputs: dict, *args, **kwargs):
+        self.log("Train/Loss", outputs['loss'])
+
+    def configure_optimizers(self) -> torch.optim.Optimizer:
+        """
+        Instantiates and returns the optimizer to be used for this model,
+        e.g. torch.optim.Adam
+        """
+        # the authors used Adam, so we might as well use it too
+        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
+
+
+# for inference
+class Idiomifier:
+
+    def __init__(self, model: Seq2Seq, tokenizer: BartTokenizer):
+        self.model = model
+        self.builder = SourcesBuilder(tokenizer)
+        self.model.eval()
+
+    def __call__(self, src: str, max_length=100) -> str:
+        srcs = self.builder(literal2idiomatic=[(src, "")])
+        pred_ids = self.model.bart.generate(
+            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
+            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
+            decoder_start_token_id=self.model.hparams['bos_token_id'],
+            max_length=max_length,
+        ).squeeze()  # (N, L_t) -> (L_t,)
+        tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
+        return tgt
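One detail in training_step worth spelling out: ignore_index=pad_token_id keeps the padded positions of tgts out of the loss, so a batch padded to a common length is not penalized for whatever it predicts at pad positions. A self-contained toy sketch of that effect (the numbers are made up; 1 is BART's actual pad_token_id):

import torch
from torch.nn import functional as F

PAD = 1                                # BART's pad_token_id
logits = torch.randn(2, 5, 3)          # (N=2, |V|=5, L=3), as in training_step after the transpose
tgts = torch.tensor([[0, 2, PAD],
                     [4, PAD, PAD]])   # (N, L); trailing PADs come from batch padding

loss_all = F.cross_entropy(logits, tgts)                       # pad positions contribute
loss_masked = F.cross_entropy(logits, tgts, ignore_index=PAD)  # pad positions ignored
print(loss_all.item(), loss_masked.item())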
idiomify/paths.py
ADDED
@@ -0,0 +1,17 @@
+from pathlib import Path
+
+ROOT_DIR = Path(__file__).resolve().parent.parent
+ARTIFACTS_DIR = ROOT_DIR / "artifacts"
+CONFIG_YAML = ROOT_DIR / "config.yaml"
+
+
+def idioms_dir(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"idioms-{ver}"
+
+
+def literal2idiomatic(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"literal2idiomatic-{ver}"
+
+
+def seq2seq_dir(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"seq2seq-{ver}"
idiomify/urls.py
ADDED
@@ -0,0 +1,13 @@
+
+# EPIE dataset
+EPIE_IMMUTABLE_IDIOMS_TAGS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Static_Idioms_Corpus/Static_Idioms_Tags.txt"  # noqa
+EPIE_IMMUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Static_Idioms_Corpus/Static_Idioms_Candidates.txt"  # noqa
+EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Static_Idioms_Corpus/Static_Idioms_Words.txt"  # noqa
+EPIE_MUTABLE_IDIOMS_TAGS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Tags.txt"  # noqa
+EPIE_MUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Candidates.txt"  # noqa
+EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Corpus/blob/master/Formal_Idioms_Corpus/Formal_Idioms_Words.txt"  # noqa
+
+# PIE dataset (Zhou, 2021)
+# https://aclanthology.org/2021.mwe-1.5/
+# right, let's just work on it.
+PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
main_infer.py
ADDED
@@ -0,0 +1,25 @@
+import argparse
+from idiomify.models import Idiomifier
+from idiomify.fetchers import fetch_config, fetch_seq2seq
+from transformers import BartTokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ver", type=str, default="tag011")
+    parser.add_argument("--src", type=str,
+                        default="If there's any good to losing my job,"
+                                " it's that I'll now be able to go to school full-time and finish my degree earlier.")
+    args = parser.parse_args()
+    config = fetch_config()[args.ver]
+    config.update(vars(args))
+    model = fetch_seq2seq(config['ver'])
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    idiomifier = Idiomifier(model, tokenizer)
+    src = config['src']
+    tgt = idiomifier(src=src)
+    print(src, "\n->", tgt)
+
+
+if __name__ == '__main__':
+    main()
main_train.py
ADDED
@@ -0,0 +1,56 @@
+import os
+import torch.cuda
+import wandb
+import argparse
+import pytorch_lightning as pl
+from termcolor import colored
+from pytorch_lightning.loggers import WandbLogger
+from transformers import BartTokenizer, BartForConditionalGeneration
+from idiomify.data import IdiomifyDataModule
+from idiomify.fetchers import fetch_config
+from idiomify.models import Seq2Seq
+from idiomify.paths import ROOT_DIR
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ver", type=str, default="tag011")
+    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
+    parser.add_argument("--log_every_n_steps", type=int, default=1)
+    parser.add_argument("--fast_dev_run", action="store_true", default=False)
+    parser.add_argument("--upload", dest='upload', action='store_true', default=False)
+    args = parser.parse_args()
+    config = fetch_config()[args.ver]
+    config.update(vars(args))
+    if not config['upload']:
+        print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
+
+    # prepare the model
+    bart = BartForConditionalGeneration.from_pretrained(config['bart'])
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    model = Seq2Seq(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
+    # prepare the datamodule
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        datamodule = IdiomifyDataModule(config, tokenizer, run)
+        logger = WandbLogger(log_model=False)
+        trainer = pl.Trainer(max_epochs=config['max_epochs'],
+                             fast_dev_run=config['fast_dev_run'],
+                             log_every_n_steps=config['log_every_n_steps'],
+                             gpus=torch.cuda.device_count(),
+                             default_root_dir=str(ROOT_DIR),
+                             enable_checkpointing=False,
+                             logger=logger)
+        # start training
+        trainer.fit(model=model, datamodule=datamodule)
+        # upload the model to wandb only if --upload is set and the training has properly finished #
+        if config['upload'] and not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
+            ckpt_path = ROOT_DIR / "model.ckpt"
+            trainer.save_checkpoint(str(ckpt_path))
+            artifact = wandb.Artifact(name="seq2seq", type="model", metadata=config)
+            artifact.add_file(str(ckpt_path))
+            run.log_artifact(artifact, aliases=["latest", config['ver']])
+            os.remove(str(ckpt_path))  # make sure you remove it after you are done with uploading it
+
+
+if __name__ == '__main__':
+    main()
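For reference, the argument parser above implies two intended invocations: `python main_train.py --ver tag011 --upload` for a real training run that pushes the checkpoint to wandb, and `python main_train.py --ver tag011 --fast_dev_run` for a smoke test (this is inferred from the flags; the commit itself documents no usage).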
main_upload_idioms.py
ADDED
@@ -0,0 +1,37 @@
+"""
+What should this script do?
+Just upload all the idioms as a wandb dataset artifact, named "idioms".
+"""
+import os
+from idiomify.paths import ROOT_DIR
+from idiomify.fetchers import fetch_pie
+import argparse
+import wandb
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ver", type=str, default="tag01")
+    config = vars(parser.parse_args())
+
+    # get the idioms here
+    if config['ver'] == "tag01":
+        # only the first 106, and this is for piloting
+        idioms = set([row[0] for row in fetch_pie()[:106]])
+    else:
+        raise NotImplementedError
+    idioms = list(idioms)
+
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        artifact = wandb.Artifact(name="idioms", type="dataset")
+        txt_path = ROOT_DIR / "all.txt"
+        with open(txt_path, 'w') as fh:
+            for idiom in idioms:
+                fh.write(idiom + "\n")
+        artifact.add_file(txt_path)
+        run.log_artifact(artifact, aliases=["latest", config['ver']])
+        os.remove(txt_path)
+
+
+if __name__ == '__main__':
+    main()
main_upload_literal2idiomatic.py
ADDED
@@ -0,0 +1,40 @@
+"""
+What should this script do?
+Just upload the literal -> idiomatic sentence pairs as a wandb dataset artifact, named "literal2idiomatic".
+"""
+import csv
+import os
+from idiomify.paths import ROOT_DIR
+from idiomify.fetchers import fetch_pie
+import argparse
+import wandb
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ver", type=str, default="tag01")
+    config = vars(parser.parse_args())
+
+    # get the pairs here
+    if config['ver'] == "tag01":
+        # only the first 106, and we use this just for piloting
+        literal2idiom = [
+            (row[3], row[2]) for row in fetch_pie()[:106]
+        ]
+    else:
+        raise NotImplementedError
+
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        artifact = wandb.Artifact(name="literal2idiomatic", type="dataset")
+        tsv_path = ROOT_DIR / "all.tsv"
+        with open(tsv_path, 'w') as fh:
+            writer = csv.writer(fh, delimiter="\t")
+            for row in literal2idiom:
+                writer.writerow(row)
+        artifact.add_file(tsv_path)
+        run.log_artifact(artifact, aliases=["latest", config['ver']])
+        os.remove(tsv_path)
+
+
+if __name__ == '__main__':
+    main()
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+pytorch-lightning==1.5.10
+transformers==4.16.2
+wandb==0.12.10