# chatlawv1/tools/application/chatglm_lora_test.py
import os
import json
import torch
from transformers import AutoTokenizer, AutoModel
import peft
from peft import LoraConfig
"""
extra requirements:
pip install icetk
"""
# Reload the model in fp16 (no int8 quantization), so about 14 GB of GPU memory is needed.
version = 'no.pt'  # 'no.pt' tests the base model; set a checkpoint name such as 'finetune_0.pt' to test a LoRA fine-tune
model_dir = '/home/liang/lzy_tmp/models/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
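# If 14 GB does not fit, ChatGLM-6B's remote code also exposes an on-the-fly
# quantize(bits) method (a sketch, not used by this script):
#
#     model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().quantize(8)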
config = LoraConfig(
    peft_type="LORA",
    task_type="SEQ_2_SEQ_LM",
    r=32,
    lora_alpha=32,
    target_modules=["q", "k", "v"],
    lora_dropout=0.1,
)
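# target_modules=["q", "k", "v"] works because peft matches module names by
# suffix: the linear_q / linear_k / linear_v layers created by QKV_layer
# below end with those letters, so each projection gets its own LoRA adapter.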
class QKV_layer(torch.nn.Module):
    """Split ChatGLM's fused query_key_value Linear into separate q/k/v
    Linear layers so that LoRA can target each projection individually."""

    def __init__(self, in_features, out_features):
        super(QKV_layer, self).__init__()
        self.linear_q = torch.nn.Linear(in_features, out_features // 3)
        self.linear_k = torch.nn.Linear(in_features, out_features // 3)
        self.linear_v = torch.nn.Linear(in_features, out_features // 3)

    def update(self, target_layer):
        # Copy the fused weights and biases, sliced into thirds along the
        # output dimension, into the three separate layers.
        third = target_layer.out_features // 3
        self.linear_q.weight.data = target_layer.weight[:third, :].data
        self.linear_q.bias.data = target_layer.bias[:third].data
        self.linear_k.weight.data = target_layer.weight[third:third * 2, :].data
        self.linear_k.bias.data = target_layer.bias[third:third * 2].data
        self.linear_v.weight.data = target_layer.weight[third * 2:, :].data
        self.linear_v.bias.data = target_layer.bias[third * 2:].data

    def forward(self, x):
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        # Concatenate back so the output shape matches the fused layer.
        return torch.cat([q, k, v], dim=-1)
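# A minimal sanity check (a sketch, not part of the original script): after
# update(), the split layer should reproduce the fused layer's output
# exactly, since the slices are concatenated back in the original order.
# The shapes below are made up for illustration.
def _sanity_check_qkv_split():
    fused = torch.nn.Linear(8, 24)
    split = QKV_layer(8, 24)
    split.update(fused)
    x = torch.randn(2, 8)
    assert torch.allclose(split(x), fused(x), atol=1e-6)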
if version != 'no.pt':
    # Redo the same surgery that was applied before fine-tuning: split every
    # fused query_key_value layer into q/k/v and wrap it with LoRA, so the
    # checkpoint's parameter names line up with the model's.
    for key, module in model.named_modules():
        if key.endswith('attention'):
            if not hasattr(module, 'query_key_value'):
                continue  # not every module named '...attention' has a fused qkv layer
            qkv_layer = QKV_layer(module.query_key_value.in_features,
                                  module.query_key_value.out_features)
            qkv_layer.update(module.query_key_value)
            module.query_key_value = qkv_layer
            # LoraModel(config, module) is the positional order of the older
            # peft releases this script was written against.
            module.query_key_value = peft.tuners.lora.LoraModel(config, module.query_key_value)
    # load the LoRA checkpoint on top of the converted model
    model.load_state_dict(torch.load(f'{model_dir}/{version}'), strict=False)

model.half().cuda().eval()
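# Optional check (a sketch, not in the original script): confirm the LoRA
# adapters actually attached by counting parameters whose names carry the
# 'lora_' prefix that peft gives its injected weights.
#
#     n_lora = sum(p.numel() for n, p in model.named_parameters() if 'lora_' in n)
#     print(f'LoRA parameters: {n_lora}')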
# Let's chat!
os.makedirs('outs/chatglm-6b/', exist_ok=True)
# sample/test.jsonl holds one JSON object per line with an 'instruction'
# field containing the prompt.
with open(f'outs/chatglm-6b/test_{version}.txt', 'w') as f, \
        open('sample/test.jsonl') as test_file:
    for line in test_file:
        example = json.loads(line)
        inputs = example['instruction']
        print('Q:', inputs)
        print('Q:', inputs, file=f)
        response, history = model.chat(tokenizer, inputs, history=[])
        print('A:', response)
        print('A:', response, '\n', file=f)
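# Running the script (assumed invocation): `python chatglm_lora_test.py`
# prints each Q/A pair to stdout and mirrors it to
# outs/chatglm-6b/test_<version>.txt, so outputs of different checkpoints
# can be compared side by side.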