LLM / app.py
dushuai112233's picture
Update app.py
9aa9618 verified
raw
history blame
2.78 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from torch.utils.tensorboard import SummaryWriter
import torch
import os
def _add_causal_lm_labels(batch):
    """Attach causal-LM ``labels`` to a tokenized batch.

    The labels are a copy of ``input_ids`` in which every padded position
    (``attention_mask == 0``) is replaced by -100, the label value ignored by
    the Hugging Face causal-LM loss. Without a ``labels`` column the model
    returns no loss and ``Trainer.train()`` cannot run.

    Args:
        batch: Mapping returned by a batched tokenizer call; must contain
            list-of-lists ``input_ids`` and ``attention_mask``.

    Returns:
        The same mapping with a ``labels`` entry added in place.
    """
    batch["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(batch["input_ids"], batch["attention_mask"])
    ]
    return batch


def main():
    """Fine-tune a Qwen2-1.5B-Instruct checkpoint with LoRA on a medical dataset.

    Loads the base model and tokenizer, wraps the model in a LoRA adapter,
    tokenizes the ``train`` and ``validation`` splits of
    ``dushuai112233/medical``, trains with ``Trainer`` and saves the adapter
    weights to ``./output``.
    """
    import dataclasses

    # Base model location (the model provided by the author).
    model_name = "dushuai112233/Qwen2-1.5B-Instruct"

    # Load tokenizer and model. Device placement is left to Trainer.
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # model repository. Qwen2 is supported natively by recent transformers
    # releases, so this flag is probably unnecessary — confirm and drop it.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # padding="max_length" below fails without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Set up PEFT (Low-Rank Adaptation) on all attention and MLP projections.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=8,  # rank of the low-rank update matrices
        lora_alpha=32,  # LoRA alpha scaling hyperparameter
        lora_dropout=0.1,
    )
    model = get_peft_model(model, peft_config)

    # Load the Hugging Face dataset; the train and validation splits are used.
    ds = load_dataset("dushuai112233/medical")
    train_dataset = ds["train"]
    val_dataset = ds["validation"]

    def tokenize_function(examples):
        # NOTE(review): only the "question" column is used as training text.
        # If the dataset also has an answer/response column, supervised
        # fine-tuning would normally train on question + answer — confirm.
        tokenized = tokenizer(
            examples["question"], padding="max_length", truncation=True, max_length=128
        )
        return _add_causal_lm_labels(tokenized)

    # Drop the raw text columns so only tensor-convertible fields reach the collator.
    train_dataset = train_dataset.map(
        tokenize_function, batched=True, remove_columns=train_dataset.column_names
    )
    val_dataset = val_dataset.map(
        tokenize_function, batched=True, remove_columns=val_dataset.column_names
    )

    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases (the old name was deprecated, then removed); use whichever
    # keyword the installed TrainingArguments declares.
    arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir="./output",  # where checkpoints and the model are written
        **{eval_key: "epoch"},  # evaluate after every epoch
        per_device_train_batch_size=1,  # training batch size per device
        per_device_eval_batch_size=1,  # evaluation batch size per device
        logging_dir="./logs",  # logging directory
        logging_steps=10,  # log every 10 steps
        save_steps=100,  # save a checkpoint every 100 steps
        num_train_epochs=10,  # number of training epochs
        save_total_limit=2,  # keep at most 2 checkpoints
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Save the trained model (the PEFT-wrapped model) to ./output.
    model.save_pretrained("./output")


if __name__ == "__main__":
    main()