# NOTE: the original paste began with HuggingFace Spaces page chrome
# ("Spaces:" / "Paused" status lines) — UI residue, not code; kept here
# only as a comment so the file stays valid Python.
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
def main():
    """Fine-tune Qwen2-1.5B-Instruct on a medical QA dataset with LoRA.

    Loads the base model and tokenizer, wraps the model with a PEFT LoRA
    adapter, tokenizes the dataset, trains with the HF ``Trainer`` —
    resuming from the most recent ``checkpoint-*`` directory under
    ``./output`` if one exists — and saves the final adapter weights and
    tokenizer to ``./output``.
    """
    # Base model repo id on the Hugging Face Hub.
    model_name = "dushuai112233/Qwen2-1.5B-Instruct"

    # Load tokenizer and model. Trainer handles device placement itself,
    # so no manual .to(device) is needed (the original computed an unused
    # `device` local).
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # LoRA on all attention and MLP projection layers of the Qwen2 blocks.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(model, peft_config)

    # Load train/validation splits of the medical dataset.
    ds = load_dataset("dushuai112233/medical")
    train_dataset = ds["train"]
    val_dataset = ds["validation"]

    def tokenize_function(examples):
        # Causal-LM objective: labels are a copy of the input ids.
        # NOTE(review): only the 'question' field is tokenized — answers are
        # not part of the training text; confirm this is intentional.
        encodings = tokenizer(examples['question'], padding='max_length',
                              truncation=True, max_length=128)
        encodings['labels'] = encodings['input_ids'].copy()
        return encodings

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)

    # Training configuration.
    training_args = TrainingArguments(
        output_dir="./output",
        evaluation_strategy="epoch",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=100,          # save a checkpoint every 100 steps
        save_total_limit=2,      # keep at most 2 checkpoints on disk
        num_train_epochs=10,
        load_best_model_at_end=False,  # do not reload best model at the end
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    # Resume from the newest checkpoint if one exists. Only consider
    # "checkpoint-*" subdirectories: the original code took max() over
    # *every* entry in ./output, so a previously saved final model file
    # or log could be selected and break resume_from_checkpoint.
    checkpoint = None
    output_dir = training_args.output_dir
    if os.path.isdir(output_dir):
        candidates = [
            os.path.join(output_dir, name)
            for name in os.listdir(output_dir)
            if name.startswith("checkpoint-")
            and os.path.isdir(os.path.join(output_dir, name))
        ]
        if candidates:
            checkpoint = max(candidates, key=os.path.getmtime)
            print(f"Resuming training from checkpoint: {checkpoint}")

    # Start (or resume) training.
    trainer.train(resume_from_checkpoint=checkpoint)

    # Persist the final LoRA adapter, plus the tokenizer so ./output is
    # self-contained for reloading.
    model.save_pretrained('./output')
    tokenizer.save_pretrained('./output')
# Script entry point: run training only when executed directly, not on import.
# (Fix: strip the " | |" scrape residue that made these lines invalid Python.)
if __name__ == '__main__':
    main()