# app.py — LoRA fine-tuning script for Qwen2-1.5B-Instruct.
# Origin: Hugging Face Space "LLM" by dushuai112233, revision 9d1ec84 (2.63 kB).
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import torch
import os
def main():
    """Fine-tune Qwen2-1.5B-Instruct with LoRA adapters on a medical QA dataset.

    Loads the base model and tokenizer, wraps the model with a PEFT LoRA
    configuration, tokenizes the dataset, trains with the HF ``Trainer``
    (resuming from the most recent ``checkpoint-*`` directory in ``./output``
    when one exists), and saves the final adapter weights to ``./output``.
    """
    # Base model repository on the Hugging Face Hub.
    model_name = "dushuai112233/Qwen2-1.5B-Instruct"

    # Load tokenizer and model. trust_remote_code is required for models that
    # ship custom code with their checkpoint. Trainer handles device placement,
    # so no manual .to(device) is needed (the original computed an unused
    # `device` variable).
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # LoRA configuration: adapt all attention and MLP projection matrices.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(model, peft_config)

    # Load dataset splits.
    ds = load_dataset("dushuai112233/medical")
    train_dataset = ds["train"]
    val_dataset = ds["validation"]

    def tokenize_function(examples):
        """Tokenize questions; labels mirror input_ids with padding masked out."""
        encodings = tokenizer(
            examples['question'], padding='max_length', truncation=True, max_length=128
        )
        # BUGFIX: pad positions must be -100 so CrossEntropyLoss ignores them.
        # The original copied input_ids verbatim, which trains the model to
        # predict padding tokens and skews the loss.
        encodings['labels'] = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(encodings['input_ids'], encodings['attention_mask'])
        ]
        return encodings

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)

    # Training hyper-parameters.
    training_args = TrainingArguments(
        output_dir="./output",
        evaluation_strategy="epoch",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=100,          # save a checkpoint every 100 steps
        save_total_limit=2,      # keep at most 2 checkpoints on disk
        num_train_epochs=10,
        load_best_model_at_end=False,  # do not reload the best model at the end
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    # BUGFIX: only consider real "checkpoint-*" directories when resuming.
    # The original took the most recently modified entry of ./output, which
    # after a completed run would be the final adapter files (or any stray
    # file) — passing that to resume_from_checkpoint fails.
    checkpoint = None
    if os.path.isdir("./output"):
        candidates = [
            os.path.join("./output", name)
            for name in os.listdir("./output")
            if name.startswith("checkpoint-") and os.path.isdir(os.path.join("./output", name))
        ]
        if candidates:
            checkpoint = max(candidates, key=os.path.getmtime)
            print(f"Resuming training from checkpoint: {checkpoint}")

    # Train (resumes automatically when a valid checkpoint was found).
    trainer.train(resume_from_checkpoint=checkpoint)

    # Persist the final LoRA adapter, plus the tokenizer for easy reloading.
    model.save_pretrained('./output')
    tokenizer.save_pretrained('./output')


if __name__ == '__main__':
    main()