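"""QLoRA fine-tuning of a Mixtral base model on image-edit-classification
instructions: prompts built by the formatting functions below are tokenized and
trained with 4-bit quantization plus LoRA adapters via transformers.Trainer."""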
import os
import torch
import transformers
import matplotlib.pyplot as plt
from datetime import datetime
from functools import partial
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training
from datasets import load_dataset
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
def formatting_func_QA(example):
    text = f"### Question: Given an image prompt {example['input']}\n give me random Edit Action and the output prompt \n ### Answer: Here is the edit action {example['edit']}, and here is the output {example['output']}"
    return text
def formatting_func_Edit(example, is_train=True):
    text = (
        "### Categorizes image editing actions, outputting classifications in the format 'Edit Class: A,B,C'. "
        "In this format, 'A' represents whether the edit is 'Global' or 'Local', and 'B' denotes the specific type of manipulation, such as 'Filter', 'Stylization', 'SceneChange', etc. "
        "'C' denotes a specified 'B' such as 'FujiFilter', 'Part', etc. "
        "This structured approach provides clear and concise information, facilitating easy understanding of the edit class. "
        "The GPT remains committed to a formal, user-friendly communication style, ensuring the classifications are accessible and precise, without delving into technical complexities.\n"
        f"Question: Given the Edit Action {example['edit']}, what is its edit type?\n"
    )
    if is_train:
        text = text + f"### Answer: Edit Class: {example['class']}"
    return text
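# The JSONL records are expected to provide the fields referenced above:
# 'input'/'output' (image prompts), 'edit' (the edit action text), and
# 'class' (the "A,B,C" edit-class label used as the supervision target).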
def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset):
    lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
    lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
    print(len(lengths))
    # Plot a histogram of tokenized sequence lengths
    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=10, alpha=0.7, color='blue')
    plt.xlabel('Length of input_ids')
    plt.ylabel('Frequency')
    plt.title('Distribution of Lengths of input_ids')
    # Save the figure to a file (create the directory if it does not exist)
    os.makedirs('./experiments', exist_ok=True)
    plt.savefig('./experiments/figure.png')
def generate_and_tokenize_prompt(prompt, tokenizer=None, formatting=None):
    return tokenizer(formatting(prompt))
def generate_and_tokenize_prompt2(prompt, tokenizer=None, max_length=512, formatting=None):
    # Tokenize the formatted prompt to a fixed length so examples can be batched.
    result = tokenizer(
        formatting(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result
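# Note: labels mirror input_ids, so the causal-LM loss is computed over the
# whole prompt + answer sequence rather than the answer tokens alone.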
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
def train():
    generate_and_tokenize = partial(generate_and_tokenize_prompt2,
                                    max_length=128,
                                    formatting=formatting_func_Edit)
    # Paths and model configs; change later as needed.
    model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
    output_root = "/mlx/users/peng.wang/playground/data/chat_edit/models/llm"
    output_root = "/opt/tiger/llm"  # overrides the path above
    os.makedirs(output_root, exist_ok=True)
    ######### Tune model with Mixtral MoE #########
    base_model_id = f"{model_root}/Mixtral-8x7B-v0.1"
    base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
    base_model_name = "mixtral-8x7b"
    # ######### Tune model with Mistral Instruct 7B #########
    # base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
    # base_model_name = "mistral-7b"
    ######### Instructions #########
    train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
    val_json = train_json  # no separate validation split; evaluate on the training file
    project = "edit-finetune"
    run_name = base_model_name + "-" + project
    output_dir = f"{output_root}/{run_name}"
    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    train_dataset = load_dataset('json', data_files=train_json, split='train')
    eval_dataset = load_dataset('json', data_files=val_json, split='train')
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    # Bind the tokenizer (created above) before mapping over the datasets.
    generate_and_tokenize = partial(generate_and_tokenize, tokenizer=tokenizer)
    tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
    tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
    print(tokenized_train_dataset[1]['input_ids'])
    plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)
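    # QLoRA-style setup: the base weights are loaded in 4-bit with double
    # quantization and bfloat16 compute, then wrapped with LoRA adapters so
    # only a small set of parameters is trained.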
    # Load the quantized base model and prepare it for k-bit training.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, quantization_config=bnb_config, device_map="auto")
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    print(model)
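    # LoRA adapters (r=32, alpha=64) on the attention projections, the Mixtral
    # expert feed-forward weights (w1/w2/w3), and the LM head.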
    config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "w1",
            "w2",
            "w3",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.01,  # Conventional
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    print_trainable_parameters(model)
    print(model)
    ## RUN training ##
    # Re-create the tokenizer with the same settings for the training run.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    if torch.cuda.device_count() > 1:  # if more than one GPU is available
        model.is_parallelizable = True
        model.model_parallel = True
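    # With device_map="auto" the model is already spread across the available
    # GPUs; these flags mark it as model-parallel so the Trainer does not
    # additionally wrap it for data parallelism.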
    trainer = transformers.Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            warmup_steps=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            max_steps=100,
            learning_rate=2.5e-5,              # small learning rate for fine-tuning
            fp16=True,
            optim="paged_adamw_8bit",
            logging_steps=25,                  # log the loss every 25 steps
            logging_dir="./experiments/logs",  # directory for storing logs
            save_strategy="steps",             # save checkpoints by step count
            save_steps=100,                    # save a checkpoint every 100 steps
            evaluation_strategy="steps",       # evaluate by step count
            eval_steps=25,                     # evaluate every 25 steps
            do_eval=True,                      # run evaluation during training
            report_to="wandb",                 # comment out to disable Weights & Biases
            run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"  # W&B run name (optional)
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()
if __name__ == '__main__':
    train()
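# Minimal inference sketch (assumed checkpoint path; not executed by this script):
# load the 4-bit base model as above, then attach the saved LoRA adapter, e.g.
#   from peft import PeftModel
#   model = PeftModel.from_pretrained(base_model, f"{output_dir}/checkpoint-100")
#   model.config.use_cache = True  # re-enable the KV cache for generation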