import os
import json

import torch
from transformers import AutoTokenizer, AutoModel
import peft
from peft import LoraConfig
""" | |
extra requirements: | |
pip install icetk | |
""" | |
# Reload the model. No int8 quantization is used, so about 14 GB of GPU memory is needed.
version = 'no.pt'  # set to a saved LoRA checkpoint, e.g. 'finetune_0.pt', to evaluate it
model_dir = '/home/liang/lzy_tmp/models/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
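# If 14 GB is too tight, the upstream chatglm-6b model code also exposes
# on-the-fly int8 quantization (per the THUDM/chatglm-6b README; not used in
# this script):
# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().quantize(8)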
config = LoraConfig(
    peft_type="LORA",
    task_type="SEQ_2_SEQ_LM",
    r=32,
    lora_alpha=32,
    target_modules=["q", "k", "v"],
    lora_dropout=0.1,
)
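# Note: with a list-valued `target_modules`, peft (0.2.x) matches modules by
# name suffix (`key.endswith(target_key)`), so "q", "k" and "v" will match the
# linear_q / linear_k / linear_v submodules introduced by QKV_layer below.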
class QKV_layer(torch.nn.Module):
    """Splits ChatGLM's fused query_key_value projection into three separate
    Linear layers, so LoRA adapters can target Q, K and V individually."""

    def __init__(self, in_features, out_features):
        super(QKV_layer, self).__init__()
        self.linear_q = torch.nn.Linear(in_features, out_features // 3)
        self.linear_k = torch.nn.Linear(in_features, out_features // 3)
        self.linear_v = torch.nn.Linear(in_features, out_features // 3)

    def update(self, target_layer):
        # Copy the fused layer's weights and biases into the three splits;
        # the fused output is laid out as [Q | K | V] along dim 0.
        chunk = target_layer.out_features // 3
        self.linear_q.weight.data = target_layer.weight[:chunk, :].data
        self.linear_q.bias.data = target_layer.bias[:chunk].data
        self.linear_k.weight.data = target_layer.weight[chunk:chunk * 2, :].data
        self.linear_k.bias.data = target_layer.bias[chunk:chunk * 2].data
        self.linear_v.weight.data = target_layer.weight[chunk * 2:, :].data
        self.linear_v.bias.data = target_layer.bias[chunk * 2:].data

    def forward(self, x):
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        return torch.concat([q, k, v], dim=-1)
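# Minimal sanity check for the split (a hypothetical helper, not part of the
# original script): the concatenated Q/K/V output should match the fused layer
# up to float tolerance, since update() only re-slices the same weights.
# Call _check_qkv_split() manually to verify.
def _check_qkv_split(in_features=16, out_features=48):
    fused = torch.nn.Linear(in_features, out_features)
    split = QKV_layer(in_features, out_features)
    split.update(fused)
    x = torch.randn(2, in_features)
    assert torch.allclose(split(x), fused(x), atol=1e-6)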
if version != 'no.pt':
    # Convert the model again: replace each attention block's fused
    # query_key_value projection with the split QKV_layer, then wrap it
    # with LoRA adapters.
    for key, module in model.named_modules():
        if key.endswith('attention'):
            try:
                qkv_layer = QKV_layer(module.query_key_value.in_features,
                                      module.query_key_value.out_features)
                qkv_layer.update(module.query_key_value)
                module.query_key_value = qkv_layer
            except AttributeError:
                # Modules named 'attention' without a query_key_value layer are skipped.
                continue
            # LoraModel(config, module) is the peft 0.2.x call signature; newer
            # peft releases take (model, config, adapter_name) instead.
            module.query_key_value = peft.tuners.lora.LoraModel(config, module.query_key_value)
    # Load the LoRA checkpoint on top of the base weights.
    model.load_state_dict(torch.load(f'{model_dir}/{version}'), strict=False)
model.half().cuda().eval()
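# For reference, a checkpoint such as 'finetune_0.pt' could be produced at the
# end of fine-tuning by saving only the LoRA parameters (a sketch assuming
# peft's 'lora_' naming convention; `save_lora_checkpoint` is a hypothetical
# helper, not part of the original script):
def save_lora_checkpoint(model, path):
    lora_state = {k: v for k, v in model.state_dict().items() if 'lora_' in k}
    torch.save(lora_state, path)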
# Let's chat!
os.makedirs('outs/chatglm-6b/', exist_ok=True)
with open(f'outs/chatglm-6b/test_{version}.txt', 'w') as f:
    for line in open('sample/test.jsonl'):
        sample = json.loads(line)
        inputs = sample['instruction']
        print('Q:', inputs)
        print('Q:', inputs, file=f)
        # Single-turn generation; history=[] starts each question from a fresh dialogue.
        response, history = model.chat(tokenizer, inputs, history=[])
        print('A:', response)
        print('A:', response, '\n', file=f)
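# Expected format of sample/test.jsonl (inferred from the loop above, which
# reads only the 'instruction' field): one JSON object per line, e.g.
# {"instruction": "Explain what LoRA fine-tuning does."}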