from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import torch
import os

def main():
    # Base model to fine-tune
    model_name = "dushuai112233/Qwen2-1.5B-Instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"  # informational; the Trainer handles device placement itself

    # Load the tokenizer and base model
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Set up LoRA via PEFT: rank-8 adapters on all attention and MLP projection layers
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)

    # Load the dataset
    ds = load_dataset("dushuai112233/medical")
    train_dataset = ds["train"]
    val_dataset = ds["validation"]

    # Preprocess the dataset: tokenize the questions and reuse input_ids as labels (causal LM objective)
    def tokenize_function(examples):
        encodings = tokenizer(examples['question'], padding='max_length', truncation=True, max_length=128)
        encodings['labels'] = encodings['input_ids'].copy()
        return encodings

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)

    # Configure training arguments
    training_args = TrainingArguments(
        output_dir="./output",
        evaluation_strategy="epoch",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=100,                 # save a checkpoint every 100 steps
        save_total_limit=2,             # keep at most 2 checkpoints on disk
        num_train_epochs=10,
        load_best_model_at_end=False,   # whether to load the best model at the end of training
    )

    # Define the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    # Check for an interrupted run and resume from the newest checkpoint-* directory if one exists
    checkpoint = None
    if os.path.isdir("./output"):
        ckpts = [os.path.join("./output", d) for d in os.listdir("./output")
                 if d.startswith("checkpoint-") and os.path.isdir(os.path.join("./output", d))]
        if ckpts:
            checkpoint = max(ckpts, key=os.path.getmtime)
            print(f"Resuming training from checkpoint: {checkpoint}")

    # Start training (resumes from the checkpoint if one was found)
    trainer.train(resume_from_checkpoint=checkpoint)

    # Save the final model (the LoRA adapter weights)
    model.save_pretrained('./output')

if __name__ == '__main__':
    main()
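
# --- Optional follow-up: loading the fine-tuned adapter for inference ---
# A minimal sketch, not part of the training script above. It assumes the LoRA
# adapter was saved to ./output by main() and that the same base model is used;
# PeftModel.from_pretrained attaches the adapter weights to the base model.
#
# from peft import PeftModel
#
# def generate_answer(prompt: str) -> str:
#     base = AutoModelForCausalLM.from_pretrained(
#         "dushuai112233/Qwen2-1.5B-Instruct", trust_remote_code=True)
#     tok = AutoTokenizer.from_pretrained(
#         "dushuai112233/Qwen2-1.5B-Instruct", trust_remote_code=True)
#     peft_model = PeftModel.from_pretrained(base, "./output")
#     inputs = tok(prompt, return_tensors="pt")
#     outputs = peft_model.generate(**inputs, max_new_tokens=128)
#     return tok.decode(outputs[0], skip_special_tokens=True)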