Chen YiJia commited on
Commit
091d980
1 Parent(s): 5cedadf

reorganise

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
first_epoch/config.json → config.json RENAMED
File without changes
first_epoch/eval_results.txt → eval_results.txt RENAMED
File without changes
first_epoch/generation_config.json → generation_config.json RENAMED
File without changes
first_epoch/model_args.json → model_args.json RENAMED
File without changes
first_epoch/optimizer.pt → optimizer.pt RENAMED
File without changes
first_epoch/pytorch_model.bin → pytorch_model.bin RENAMED
File without changes
first_epoch/scheduler.pt → scheduler.pt RENAMED
File without changes
first_epoch/special_tokens_map.json → special_tokens_map.json RENAMED
File without changes
first_epoch/spiece.model → spiece.model RENAMED
File without changes
test.py DELETED
@@ -1,33 +0,0 @@
1
- from datasets.load import load_dataset
2
- import logging
3
- import sacrebleu
4
- import pandas as pd
5
- from simpletransformers.t5 import T5Model, T5Args
6
-
7
- raw_datasets = load_dataset('iwslt2017', 'iwslt2017-zh-en')
8
-
9
- logging.basicConfig(level=logging.INFO)
10
- transformers_logger = logging.getLogger("transformers")
11
- transformers_logger.setLevel(logging.WARNING)
12
-
13
-
14
- model_args = T5Args()
15
- model_args.max_length = 512
16
- model_args.length_penalty = 1
17
- model_args.num_beams = 10
18
-
19
- model = T5Model("mt5", "outputs", args=model_args)
20
-
21
- en_zh_test = pd.DataFrame(raw_datasets['test']['translation'])
22
- zh_truth = en_zh_test['zh'].tolist()
23
- en_input = en_zh_test['en'].tolist()
24
-
25
- zh_preds = model.predict(en_input)
26
- en_zh_bleu = sacrebleu.corpus_bleu(zh_preds, zh_truth)
27
- print("----------------------------------------------")
28
- print("English to Chinese: ", en_zh_bleu.score)
29
-
30
- en_preds = model.predict(zh_truth)
31
- zh_en_bleu = sacrebleu.corpus_bleu(en_preds, en_input)
32
- print("----------------------------------------------")
33
- print("Chinese to English: ", zh_en_bleu.score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
first_epoch/tokenizer_config.json → tokenizer_config.json RENAMED
File without changes
train.py DELETED
@@ -1,49 +0,0 @@
1
- from datasets.load import load_dataset
2
- import pandas as pd
3
- import logging
4
- from simpletransformers.t5 import T5Args, T5Model
5
-
6
- logging.basicConfig(level=logging.INFO)
7
- transformers_logger = logging.getLogger("transformers")
8
- transformers_logger.setLevel(logging.WARNING)
9
-
10
- raw_datasets = load_dataset('iwslt2017', 'iwslt2017-zh-en')
11
-
12
- train_df = pd.DataFrame(raw_datasets['train']['translation'])
13
- train_df.columns = ['input_text', 'target_text']
14
- reverse_df = train_df.copy()
15
- reverse_df.columns = ['target_text', 'input_text']
16
- train_df['prefix'] = 'translate english to chinese'
17
- reverse_df['prefix'] = 'translate chinese to english'
18
- train_df = pd.concat([train_df, reverse_df])
19
-
20
- eval_df = pd.DataFrame(raw_datasets['validation']['translation'])
21
- eval_df.columns = ['input_text', 'target_text']
22
- reverse_df = eval_df.copy()
23
- reverse_df.columns = ['target_text', 'input_text']
24
- eval_df['prefix'] = 'translate english to chinese'
25
- reverse_df['prefix'] = 'translate chinese to english'
26
- eval_df = pd.concat([eval_df, reverse_df])
27
-
28
- model_args = T5Args()
29
- model_args.max_seq_length = 96
30
- model_args.train_batch_size = 20
31
- model_args.eval_batch_size = 20
32
- model_args.num_train_epochs = 4
33
- model_args.evaluate_during_training = True
34
- model_args.evaluate_during_training_steps = 5000
35
- model_args.use_multiprocessing = False
36
- model_args.fp16 = False
37
- model_args.save_steps = -1
38
- model_args.save_model_every_epoch = True
39
- model_args.save_eval_checkpoints = False
40
- model_args.no_cache = True
41
- model_args.reprocess_input_data = True
42
- model_args.overwrite_output_dir = False
43
- model_args.preprocess_inputs = False
44
- model_args.num_return_sequences = 1
45
- model_args.wandb_project = "MT5 English-Chinese Translation"
46
-
47
- model = T5Model("mt5", "outputs", args=model_args)
48
-
49
- model.train_model(train_df, eval_data=eval_df, output_dir='mt5_more_epochs')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
first_epoch/training_args.bin → training_args.bin RENAMED
File without changes