from collections import OrderedDict
from typing import Dict
import typing
import types
import gc
import os
import time
import re

import torch
from torch.nn import functional as F
from rwkv.model import RWKV as RWKV_UPSTREAM

# Accepted ``layer_filter`` grammar (whitespace-separated tokens):
# valid_filter_pattern = r"(((\d+\.\d+\*)?(\d+)(-\d+)?(/\S+)?|(/\S+))(\s+|$))+"


def get_filter_keys_and_merge_coef(layer_filter):
    """Build the key filter and per-layer merge coefficient for a LoRA merge.

    ``layer_filter`` is a whitespace-separated list of tokens, each one of:

    * ``N`` or ``A-B`` (inclusive) -- keep LoRA keys of layer N / layers A..B;
    * ``C*N`` or ``C*A-B`` -- same, and scale those layers' LoRA delta by the
      float C (default coefficient is 1);
    * any of the above followed by ``/REGEX`` -- additionally drop keys of those
      layers that match REGEX (``re.search``);
    * ``/REGEX`` on its own -- drop keys of any layer that match REGEX.

    Note: once a filter is given, ``blocks.<n>.*`` keys whose layer is not
    listed are dropped, even if the filter only contains a ``/REGEX`` token.
    Keys not starting with ``blocks.`` are kept unless a global REGEX removes
    them.

    Returns:
        ``(filter_keys, merge_coef)`` where ``filter_keys(keys)`` returns the
        keys to keep and ``merge_coef(key)`` returns the scale for that key's
        LoRA delta. When ``layer_filter`` is falsy every key is kept and every
        coefficient is 1.

    Raises:
        NotImplementedError: for a token matching none of the forms above.
    """
    if not layer_filter:
        def filter_keys(keys):
            return keys

        def merge_coef(key):
            return 1

        return filter_keys, merge_coef

    allowed_layers = set()
    layer_coef = {}
    layer_remove_patterns = {}
    global_remove_pattern = None
    # split() without an argument skips empty tokens. split(' ') produced ''
    # for doubled/trailing spaces, and each '' token reset the global remove
    # pattern to None.
    for token in layer_filter.split():
        # Optional removal pattern after '/', written as a regular expression.
        layer, sep, pattern_src = token.partition('/')
        remove_pattern = re.compile(pattern_src) if sep else None
        if layer == '':
            # "/REGEX" alone: removal rule applied to every key.
            global_remove_pattern = remove_pattern
            continue
        if '*' in layer:
            coef_src, _, layer = layer.partition('*')
            coef = float(coef_src)
        else:
            coef = 1
        if layer.isdecimal():
            selected = [int(layer)]
        elif '-' in layer:
            start, _, end = layer.partition('-')
            selected = range(int(start), int(end) + 1)
        else:
            raise NotImplementedError("layer_filter Not implemented:", layer_filter)
        for layer_id in selected:
            allowed_layers.add(layer_id)
            layer_coef[layer_id] = coef
            layer_remove_patterns[layer_id] = remove_pattern

    def filter_keys(keys):
        new_keys = []
        for key in keys:
            if global_remove_pattern is not None and global_remove_pattern.search(key):
                continue  # matches the global removal rule
            if key.startswith("blocks."):
                layer_id = int(key.split('.')[1])
                if layer_id not in allowed_layers:
                    continue  # layer not selected by the filter
                pattern = layer_remove_patterns[layer_id]
                if pattern is not None and pattern.search(key):
                    continue  # matches this layer's removal rule
            new_keys.append(key)
        return new_keys

    def merge_coef(key):
        if key.startswith('blocks.'):
            return layer_coef.get(int(key.split('.')[1]), 1)
        return 1

    return filter_keys, merge_coef


def lora_merge(base_model, lora, lora_alpha, device="cuda", layer_filter=None):
    """Merge a LoRA-only checkpoint into a base checkpoint.

    Both files are loaded to CPU with ``torch.load``. Every key of the LoRA
    checkpoint that survives ``layer_filter`` is copied into the base state
    dict, overwriting a tensor of the same name if present. Then, for each
    ``<prefix>.weight`` with ``<prefix>.lora_A`` and ``<prefix>.lora_B``, the
    weight is updated on ``device`` as::

        W += (B @ A) * (lora_alpha / r) * merge_coef(key),   r = B.shape[1]

    Args:
        base_model: path of the base checkpoint.
        lora: path of the LoRA checkpoint.
        lora_alpha: LoRA alpha.
        device: device used for the merge matmul.
        layer_filter: filter string, see ``get_filter_keys_and_merge_coef``.

    Returns:
        OrderedDict of CPU tensors. Keys containing ``lora`` are not copied,
        except merged ``.weight`` tensors.
    """
    print(f"Loading LoRA: {lora}")
    print(f"LoRA alpha={lora_alpha}, layer_filter={layer_filter}")
    filter_keys, merge_coef = get_filter_keys_and_merge_coef(layer_filter)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
    for k in filter_keys(w_lora.keys()):
        # Handles tensors such as time_mixing that replace the base ones. The
        # copy must be unconditional: lora_A/lora_B are absent from the base
        # checkpoint and have to be inserted for the merge loop below.
        if k in w:
            print(f"replacing {k}")
        w[k] = w_lora[k]
    del w_lora  # every needed tensor is now referenced from `w`

    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights
    keys = list(w.keys())
    key_set = set(keys)  # O(1) membership (a list scan per key was O(n^2))
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            if lora_A in key_set:
                assert lora_B in key_set
                print(f'merging {lora_A} and {lora_B} into {k}')
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)
                w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) * merge_coef(k)
                output_w[k] = w[k].to(device='cpu', copy=True)
                # Drop the device copies right away to bound peak memory.
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue
        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]
    return output_w


class RWKV(RWKV_UPSTREAM):
    """``rwkv.model.RWKV`` whose constructor can merge a LoRA checkpoint at load time.

    The constructor reproduces the upstream loading steps (strategy parsing,
    per-layer placement, dtype conversion / int8 quantisation, device
    placement), but takes the state dict from ``lora_merge`` when ``lora`` is
    given.
    """

    def __init__(self, model, strategy, verbose=True, convert_and_save_and_exit=None,
                 lora=None, lora_alpha=0, lora_layer_filter=None):
        """Load the checkpoint ``model`` (``.pth`` is appended if missing).

        Args:
            model: path of the base checkpoint.
            strategy: placement string, e.g. ``"cuda fp16 *10 -> cpu fp32"``;
                must match ``STRATEGY_REGEX``.
            verbose: print loading progress when True.
            convert_and_save_and_exit: if set, save the converted weights to
                this path (``.pth`` appended if missing) and exit the process.
            lora: optional LoRA checkpoint path, merged via ``lora_merge``.
            lora_alpha: LoRA alpha (delta scaled by ``lora_alpha / r``).
            lora_layer_filter: see ``get_filter_keys_and_merge_coef``.

        Raises:
            ValueError: if ``strategy`` is malformed.
        """
        # Deliberately bypass RWKV_UPSTREAM.__init__ (this method re-implements
        # model loading) and run the next __init__ in the MRO instead.
        super(RWKV_UPSTREAM, self).__init__()
        if verbose:
            prxxx = lambda *args, **kwargs: print(*args, **kwargs)
        else:
            prxxx = lambda *args, **kwargs: None

        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
        if not re.match(STRATEGY_REGEX, strategy):
            raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")

        # Normalise spacing so stages are separated by exactly " -> ".
        strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ')
        self.args = types.SimpleNamespace()
        args = self.args
        args.MODEL_NAME = model
        args.strategy_string = strategy

        # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow)
        self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0
        prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n')

        args.MODEL_NAME = args.MODEL_NAME.strip()
        if not args.MODEL_NAME.endswith('.pth'):
            args.MODEL_NAME += '.pth'
        prxxx(f'Loading {args.MODEL_NAME} ...')
        with torch.no_grad():
            if lora:
                self.w = lora_merge(base_model=args.MODEL_NAME, lora=lora,
                                    lora_alpha=lora_alpha, layer_filter=lora_layer_filter,
                                    device=('cuda' if 'cuda' in strategy else 'cpu'))
            else:
                self.w = torch.load(args.MODEL_NAME, map_location='cpu')  # load model to CPU first
            gc.collect()
            w = self.w

            # A '_strategy' key marks a checkpoint written by convert_and_save_and_exit.
            ALREADY_CONVERTED = False
            if '_strategy' in w:
                ALREADY_CONVERTED = True
                assert convert_and_save_and_exit is None  # you should only convert a raw model
                prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n")
                assert w['_strategy'] == args.strategy_string  # if you are using a new strategy, re-convert the model
                assert float(w['_version']) >= 0.7  # sometimes you should re-convert using latest convert_model.py
                assert w['_rescale_layer'] == self.RESCALE_LAYER
                del w['_strategy']
                del w['_version']
                del w['_rescale_layer']

            args.n_embd = w['emb.weight'].shape[1]
            # n_layer = 1 + highest index seen in "blocks.<i>.*" keys.
            args.n_layer = 0
            keys = list(w.keys())
            for x in keys:
                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
                args.n_layer = max(args.n_layer, layer_id+1)

            ####################### Compute strategy

            # s[i] = [device, [activation dtype, weight dtype], optional "*N" / "*N+"]
            s = [x.strip().split(' ') for x in strategy.split('->')]
            plan = [0] * len(s)
            stream_i = -1
            stream_count = 0
            to_allocate = args.n_layer + 1  # +1 slot for ln_out/head
            allocated = 0
            free_slots = 0
            for i in range(len(s)):
                si = s[i]
                si1 = si[1]
                if si1.startswith('fp32'):
                    si[1] = [torch.float]
                elif si1.startswith('fp16'):
                    si[1] = [torch.float16]
                elif si1.startswith('bf16'):
                    si[1] = [torch.bfloat16]
                if si1.endswith('i8'):
                    si[1] += [torch.uint8]
                else:
                    si[1] += [si[1][0]]
                if len(si) > 2:
                    ss = si[2]
                    assert ss.startswith('*')
                    if ss.endswith('+'):
                        # "*N+": store N layers here and stream the remainder.
                        plan[i] = int(ss[1:-1])
                        stream_i = i
                    else:
                        plan[i] = int(ss[1:])
                    allocated += plan[i]
                    if allocated >= to_allocate:
                        plan[i] += to_allocate - allocated
                        break
                else:
                    free_slots += 1
            if stream_i < 0:
                # Split unallocated layers evenly across stages without "*N".
                if free_slots > 0 and to_allocate > allocated:
                    for i in range(len(s)):
                        if plan[i] == 0:
                            plan[i] = (to_allocate - allocated) // free_slots
                            allocated += plan[i]
                            free_slots -= 1
                if to_allocate > allocated:
                    plan[len(s)-1] += to_allocate - allocated
            else:
                if to_allocate > allocated:
                    stream_count = to_allocate - allocated
                    plan[stream_i] += stream_count
            prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)')
            for i in range(len(s)):
                ss = s[i]
                if i != stream_i:
                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers')
                else:
                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers')
                # Turn per-stage counts into cumulative upper bounds.
                plan[i] += (0 if i == 0 else plan[i-1])
            self.strategy = [None] * (args.n_layer + 1)
            strategy = self.strategy
            for n in range(args.n_layer + 1):
                for i in range(len(s)):
                    if n < plan[i]:
                        strategy[n] = types.SimpleNamespace()
                        strategy[n].device = s[i][0]
                        strategy[n].atype = s[i][1][0]
                        strategy[n].wtype = s[i][1][1]
                        strategy[n].stream = False
                        if i == stream_i and n >= (plan[i] - stream_count):
                            strategy[n].stream = True
                        break
                prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}", end=' ')
            prxxx()

            ####################### Load weights to self.w

            if not ALREADY_CONVERTED:
                try:  # precompute embedding
                    w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
                except RuntimeError:
                    # Retry in fp32 when layer_norm fails in the stored dtype
                    # (presumably half precision on CPU).
                    w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
                del w['blocks.0.ln0.weight']
                del w['blocks.0.ln0.bias']

            print_need_newline = False
            keys = list(w.keys())
            for x in keys:
                w[x].requires_grad = False
                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
                if ('ln_out.' in x) or ('head.' in x):
                    layer_id = args.n_layer
                dd = strategy[layer_id]
                DEVICE = dd.device
                ATYPE = dd.atype
                WTYPE = dd.wtype

                if not ALREADY_CONVERTED:
                    if self.RESCALE_LAYER > 0:
                        if 'att.output.weight' in x:
                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
                        if 'ffn.value.weight' in x:
                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))

                    if '.time_' in x:
                        w[x] = w[x].squeeze()
                    if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x:
                        w[x] = w[x].t()

                    if '.time_decay' in x:  # need fp32 for this
                        w[x] = -torch.exp(w[x].float())
                    elif '.time_first' in x:  # need fp32 for this
                        w[x] = w[x].float()
                    else:
                        if (len(w[x].shape) == 2) and ('emb' not in x):
                            if WTYPE != torch.uint8:
                                w[x] = w[x].to(dtype=WTYPE)
                            else:
                                # int8 quantisation: subtract per-row/column minima,
                                # divide by per-column/row maxima, then store
                                # floor(x*256) clipped to [0, 255] as uint8 with the
                                # four scale vectors kept as '_mx', '_my', '_rx', '_ry'.
                                w[x] = w[x].float()

                                if w[x].shape[0] > w[x].shape[1]:
                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] - w[x+'_my']
                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
                                    w[x] = w[x] - w[x+'_mx']
                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
                                    w[x] = w[x] / w[x+'_rx']
                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] / w[x+'_ry']
                                else:
                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
                                    w[x] = w[x] - w[x+'_mx']
                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] - w[x+'_my']
                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
                                    w[x] = w[x] / w[x+'_rx']
                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] / w[x+'_ry']

                                w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8)
                                # The /16 on the ranges presumably matches the upstream
                                # int8 matmul convention -- confirm against rwkv.model.
                                w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous()
                                w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous()
                                w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous()
                                w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous()
                        else:
                            w[x] = w[x].to(dtype=ATYPE)

                if convert_and_save_and_exit is None:
                    if 'emb.' in x:
                        w[x] = w[x].contiguous()
                    elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')):
                        # Streamed layers stay in (pinned) CPU memory.
                        try:
                            w[x] = w[x].contiguous().pin_memory()  # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :)
                        except RuntimeError:
                            print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.')
                    elif DEVICE != 'cpu':
                        w[x] = w[x].to(device=DEVICE).contiguous()

                    if (dd.stream) or (DEVICE != 'cpu'):
                        # Only int8-quantised weights have scale companions; a
                        # KeyError just means there is nothing to move.
                        try:
                            w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous()
                            w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous()
                            w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous()
                            w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous()
                        except KeyError:
                            pass

                if 'ffn.value.weight' in x:
                    gc.collect()
                    if 'cuda' in args.strategy_string:
                        torch.cuda.empty_cache()

                shape = [i for i in w[x].shape if i != 1]
                if len(shape) > 1:
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
                else:
                    shape = f" {str(shape[0]).rjust(5)} "
                # Print details for the first and last layers, a dot for the rest.
                if layer_id == 0 or layer_id >= args.n_layer-1:
                    if print_need_newline:
                        prxxx('\n', end='')
                        print_need_newline = False
                    dt = str(w[x].dtype).replace('torch.', '')
                    dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8')
                    prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '')
                else:
                    print_need_newline = True
                    prxxx('.', end='', flush=True)

            if convert_and_save_and_exit:
                w['_strategy'] = args.strategy_string
                w['_rescale_layer'] = self.RESCALE_LAYER
                w['_version'] = '0.7'
                if not convert_and_save_and_exit.endswith('.pth'):
                    convert_and_save_and_exit += '.pth'
                prxxx(f'Saving to {convert_and_save_and_exit}...')
                torch.save(w, convert_and_save_and_exit)
                prxxx(f'Converted and saved. Now this will exit.')
                # Same as exit(0), without relying on the `site` builtin.
                raise SystemExit(0)

            gc.collect()
            if 'cuda' in args.strategy_string:
                torch.cuda.empty_cache()