|
from src.data.CodeGeneration.APPS_dataloader import APPS |
|
from src.data.CodeGeneration.MBPP_dataloader import MBPP |
|
from src.data.Arithmetic.python_scripts.Arithmetic_Dataset import Arithmetic_Dataset |
|
|
|
# Debug switch. When True, the `if DEBUG:` section at the bottom of this
# module runs and applies its overrides to `config`.
DEBUG: bool = False

# Device identifier string ("cuda:0").
# NOTE(review): presumably passed to the training / inference code that
# imports this module as the target accelerator — confirm against callers.
DEVICE: str = "cuda:0"
|
|
|
def _lesson_names(ops_order, bounds_order, suffix=""):
    """Return curriculum lesson names grouped by operation count.

    Each name has the form ``Max_Ops<ops>_Bounds<bounds><suffix>``. All
    bounds are emitted for the first ops value, then all bounds for the next
    one, and so on, so the order of both sequences is significant.

    Args:
        ops_order: Operation counts, in the order they should appear.
        bounds_order: Bound labels (e.g. ``"0_100"``), in the order they
            should appear within each operation count.
        suffix: Optional text appended to every name (e.g. ``".json"``).

    Returns:
        A new list of lesson-name strings.
    """
    return [
        f"Max_Ops{ops}_Bounds{bounds}{suffix}"
        for ops in ops_order
        for bounds in bounds_order
    ]


# Central experiment configuration.
#   config["model"][<name>]   -> base model id, precision tag ("quantitize"),
#                                default dataset key, data collator name,
#                                adapter settings, training arguments and
#                                tokenizer settings for that model.
#   config["dataset"][<name>] -> where a dataset comes from and which fields
#                                are listed under "filling_field".
config = {
    "model": {
        "codellama": {
            "base_model_id": "codellama/CodeLlama-7b-hf",
            "quantitize": "int8",
            "dataset": "Arithmetic_Simple",
            "data_collator": "DataCollatorForSeq2Seq",
            "lora_config": {
                "r": 16,
                "lora_alpha": 16,
                "target_modules": [
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                ],
                "lora_dropout": 0.05,
                "bias": "none",
                "task_type": "CAUSAL_LM",
            },
            "training_args": {
                "output_dir": "codellama-output",
                "warmup_steps": 100,
                "per_device_train_batch_size": 1,
                "per_device_eval_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "max_steps": 10000,
                "learning_rate": 3e-4,
                "optim": "adamw_torch",
                "logging_dir": "codellama-output-logs",
                "logging_steps": 10,
                "save_strategy": "steps",
                "save_steps": 500,
                "load_best_model_at_end": False,
                "group_by_length": True,
                "fp16": True,
                "evaluation_strategy": "steps",
                "eval_steps": 1000,
            },
            "tokenizer": {
                "tokenize_config": {
                    "truncation": True,
                    "max_length": 192,
                    "padding": "max_length",
                },
                "prompt_template": "config/qa_template.txt",
            },
        },
        "phi-2": {
            "base_model_id": "microsoft/phi-2",
            "quantitize": "fp16",
            "dataset": "Arithmetic_Simple",
            "data_collator": "DataCollatorForLanguageModeling",
            "lora_config": {
                "r": 32,
                "lora_alpha": 64,
                "target_modules": [
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "dense",
                    "fc1",
                    "fc2",
                ],
                "bias": "none",
                "lora_dropout": 0.05,
                "task_type": "CAUSAL_LM",
            },
            "training_args": {
                "output_dir": "phi2-output",
                "warmup_steps": 500,
                "per_device_train_batch_size": 1,
                "per_device_eval_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "max_steps": 100000,
                "learning_rate": 3e-4,
                "optim": "paged_adamw_8bit",
                "logging_dir": "phi2-output-logs",
                "logging_steps": 100,
                "save_strategy": "steps",
                "save_steps": 500,
                "evaluation_strategy": "steps",
                "eval_steps": 500,
                "fp16": True,
            },
            "tokenizer": {
                "tokenize_config": {
                    "truncation": True,
                    "max_length": 512,
                    "padding": "max_length",
                },
                "prompt_template": "config/qa_template.txt",
            },
        },
        # The config key keeps the dotted spelling so existing lookups keep
        # working; only the Hub repository id below uses the underscore form.
        "phi-1.5": {
            # The Hugging Face Hub repository is "microsoft/phi-1_5"; the
            # dotted spelling "microsoft/phi-1.5" does not resolve.
            "base_model_id": "microsoft/phi-1_5",
            "quantitize": "fp16",
            "dataset": "Arithmetic_Hard",
            "data_collator": "DataCollatorForLanguageModeling",
            "lora_config": {
                "r": 32,
                "lora_alpha": 64,
                "target_modules": ["q_proj", "k_proj", "v_proj"],
                "bias": "none",
                "lora_dropout": 0.05,
                "task_type": "CAUSAL_LM",
            },
            "training_args": {
                "output_dir": "phi-output",
                "warmup_steps": 1,
                "per_device_train_batch_size": 1,
                "per_device_eval_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "max_steps": 10000,
                "learning_rate": 3e-4,
                "optim": "paged_adamw_8bit",
                "logging_dir": "phi-output-logs",
                "logging_steps": 10,
                "save_strategy": "steps",
                "save_steps": 500,
                "evaluation_strategy": "steps",
                "eval_steps": 500,
                "fp16": True,
                "report_to": "none",
            },
            "tokenizer": {
                "tokenize_config": {
                    "truncation": True,
                    "max_length": 512,
                    "padding": "max_length",
                },
                "prompt_template": "config/qa_template.txt",
            },
        },
        "roberta": {
            "base_model_id": "FacebookAI/roberta-large",
            "quantitize": "fp16",
            "dataset": "Arithmetic_Hard",
            "data_collator": "DataCollatorForLanguageModeling",
            "lora_config": {
                "r": 32,
                "lora_alpha": 64,
                "target_modules": ["query", "key", "value"],
                "bias": "none",
                "lora_dropout": 0.05,
                # NOTE(review): roberta-large is an encoder model, yet the
                # task type is CAUSAL_LM like the decoder entries — confirm
                # this is intended.
                "task_type": "CAUSAL_LM",
            },
            # Unlike the other entries, no evaluation_strategy / eval_steps /
            # fp16 keys are set here.
            "training_args": {
                "output_dir": "roberta-output",
                "warmup_steps": 1,
                "per_device_train_batch_size": 1,
                "per_device_eval_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "max_steps": 10000,
                "learning_rate": 3e-4,
                "optim": "paged_adamw_8bit",
                "logging_dir": "roberta-output-logs",
                "logging_steps": 10,
                "save_strategy": "steps",
                "save_steps": 500,
                "report_to": "none",
            },
            "tokenizer": {
                "tokenize_config": {
                    "truncation": True,
                    "max_length": 512,
                    "padding": "max_length",
                },
                "prompt_template": "config/qa_template.txt",
            },
        },
        "deepseek": {
            "base_model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
            # NOTE(review): precision tag is bf16 while training_args below
            # sets "fp16": True — confirm the combination is intended.
            "quantitize": "bf16",
            "dataset": "Arithmetic_Simple",
            "data_collator": "DataCollatorForLanguageModeling",
            "lora_config": {
                "r": 32,
                "lora_alpha": 64,
                "target_modules": [
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                ],
                "bias": "none",
                "lora_dropout": 0.05,
                "task_type": "CAUSAL_LM",
            },
            # Same target modules as "lora_config" with rank and alpha
            # scaled up by 4x.
            "lora_large_config": {
                "r": 128,
                "lora_alpha": 256,
                "target_modules": [
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                ],
                "bias": "none",
                "lora_dropout": 0.05,
                "task_type": "CAUSAL_LM",
            },
            "p_tuning_config": {
                "num_virtual_tokens": 16,
                "num_transformer_submodules": 1,
                "token_dim": 2048,
                "encoder_hidden_size": 2048,
                "task_type": "CAUSAL_LM",
            },
            "training_args": {
                "output_dir": "runs/deepseek-continue",
                "warmup_steps": 500,
                "per_device_train_batch_size": 4,
                "per_device_eval_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "max_steps": 100000,
                "learning_rate": 5e-5,
                "optim": "paged_adamw_8bit",
                "logging_dir": "runs/deepseek-continue/logs",
                "logging_steps": 100,
                "save_strategy": "steps",
                "save_steps": 1000,
                "evaluation_strategy": "steps",
                "eval_steps": 1000,
                "fp16": True,
            },
            "tokenizer": {
                "tokenize_config": {
                    "truncation": True,
                    "max_length": 512,
                    "padding": "max_length",
                },
                "prompt_template": "config/qa_template.txt",
            },
        },
    },
    # NOTE(review): in several entries below the validation split points at a
    # "*_test" file and the test split at a "*_dev" file — confirm the naming
    # is intentional.
    "dataset": {
        "simple_dataset": {
            "type": "huggingface",
            "dataset_purpose": "downstream",
            "name": "b-mc2/sql-create-context",
            "train_split": 0.9,
            "max_train_size": 100,
            "filling_field": ["question", "context", "answer"],
        },
        # All three splits point at the same file.
        "testdset": {
            "type": "local",
            "dataset_purpose": "downstream",
            "train_file": "data/Test/TestDataset.json",
            "val_file": "data/Test/TestDataset.json",
            "test_file": "data/Test/TestDataset.json",
            "filling_field": ["prompted_question", "answer"],
        },
        "APPS_loader": {
            "type": "list-like",
            "dataset_purpose": "downstream",
            "train": "data/APPS/apps_train.json",
            "val": "data/APPS/test/apps_test_1.json",
            "test": "data/APPS/test/apps_test_75.json",
            "filling_field": ["Question", "Answer"],
        },
        "MBPP_loader": {
            "type": "list-like",
            "dataset_purpose": "downstream",
            "train": "data/MBPP/mbpp_train.json",
            "val": "data/MBPP/mbpp_test.json",
            "test": "data/MBPP/mbpp_dev.json",
            "filling_field": ["Question", "Answer"],
        },
        "Arithmetic_Simple": {
            "type": "list-like",
            "dataset_purpose": "downstream",
            "attributes": {
                "subjects": [1, 2, 3, 4, 5, 6, 7, 8, 9],
                # Ops 1..5, each with bounds 0_100 then 0_1000 (10 lessons).
                "lessons": _lesson_names(range(1, 6), ("0_100", "0_1000")),
            },
            "train": "data/Arithmetic/Curriculum_Simple",
            "val": "data/Arithmetic/Curriculum_Simple",
            "test": "data/Arithmetic/Curriculum_Simple",
            "filling_field": ["Question", "Answer"],
        },
        "Arithmetic_Hard": {
            "type": "list-like",
            "dataset_purpose": "downstream",
            "attributes": {
                "subjects": [1, 2, 3, 4, 5, 6, 7, 8, 9],
                # Ops 1..10, each with four bounds (40 lessons).
                "lessons": _lesson_names(
                    range(1, 11),
                    ("-1000_1000", "-100_100", "0_100", "0_1000"),
                ),
            },
            "train": "data/Arithmetic/Curriculum_Hard",
            "val": "data/Arithmetic/Curriculum_Hard",
            "test": "data/Arithmetic/Curriculum_Hard",
            "filling_field": ["Question", "Answer"],
        },
        "Arithmetic_XHard": {
            "type": "list-like",
            "dataset_purpose": "downstream",
            "attributes": {
                "subjects": [1, 2, 3, 4, 5, 6, 7, 8, 9],
                # 80 lessons. The operation counts are deliberately not in
                # numeric order: this reproduces the original hand-written
                # listing (10..19, 1, 20, 2..9) exactly.
                # NOTE(review): unlike the Simple/Hard curricula these names
                # carry a ".json" suffix — confirm the loader expects that.
                "lessons": _lesson_names(
                    (
                        10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                        1, 20, 2, 3, 4, 5, 6, 7, 8, 9,
                    ),
                    ("0_10000", "0_1000", "-10000_10000", "-1000_1000"),
                    suffix=".json",
                ),
            },
            "train": "data/Arithmetic/Curriculum_XHard",
            "val": "data/Arithmetic/Curriculum_XHard",
            "test": "data/Arithmetic/Curriculum_XHard",
            "filling_field": ["Question", "Answer"],
        },
        "GSM8K": {
            "type": "local",
            "dataset_purpose": "downstream",
            "train_file": "data/GSM8K/GSM8K_train.json",
            "val_file": "data/GSM8K/GSM8K_test.json",
            "test_file": "data/GSM8K/GSM8K_dev.json",
            "filling_field": ["Body", "Question", "Answer"],
        },
        "APPS": {
            "type": "local",
            "dataset_purpose": "downstream",
            "train_file": "data/APPS/apps_train.json",
            "val_file": "data/APPS/apps_test.json",
            "test_file": "data/APPS/apps_dev.json",
            "filling_field": ["Body", "Question", "Answer"],
        },
        "ghcode_python": {
            "type": "huggingface",
            "dataset_purpose": "pretrain",
            "name": "slseanwu/ghcode_python_split_700k",
            "max_eval_size": 1000,
            "max_train_size": 160000,
            "filling_field": ["code"],
        },
    },
}
|
|
|
|
|
# Debug overrides, applied only when DEBUG is True.
if DEBUG:
    # `config` is a plain dict, so the attribute-style assignment used
    # previously (`config.epochs = ...`) raised AttributeError as soon as
    # DEBUG was switched on. The overrides are stored as top-level keys.
    # NOTE(review): nothing visible in this module reads these keys, and the
    # dataset names below match no entry under config["dataset"] — confirm
    # which consumer expects them.
    config["epochs"] = 100
    config["save_steps"] = 10
    config["train_dataset"] = "local-test-train"
    config["val_dataset"] = "local-test-dev"
    config["test_dataset"] = "test-clean"
|
|