from safetensors.torch import load_file
import torch
from tqdm import tqdm

__all__ = ['flux_load_lora']


def is_int(d):
    """Return True if `d` parses as an integer (used for ModuleList indices)."""
    try:
        int(d)
        return True
    except (TypeError, ValueError):
        return False


def flux_load_lora(self, lora_file, lora_weight=1.0):
    """Merge a LoRA checkpoint directly into the weights of a Flux pipeline.

    `self` is expected to be a FluxPipeline-like object exposing `transformer`,
    `text_encoder` and `lora_state_dict()`. The LoRA deltas are fused into the
    base weights in place, scaled by `lora_weight`.
    """
    device = self.transformer.device

    # DiT (transformer) part
    state_dict, network_alphas = self.lora_state_dict(lora_file, return_alphas=True)  # network_alphas unused here
    state_dict = {k: v.to(device) for k, v in state_dict.items()}
    model = self.transformer
    keys = [k for k in state_dict.keys() if k.startswith('transformer.')]

    for k_lora in tqdm(keys, total=len(keys), desc="loading lora in transformer ..."):
        v_lora = state_dict[k_lora]

        # Only process up (lora_B) weights; skip everything else.
        if '.lora_A.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("transformer.", "")
        k_lora_name = k_lora_name.replace(".lora_B.weight", "")
        attr_name_list = k_lora_name.split('.')

        # Walk the module tree to the target layer; integer pieces index into ModuleLists.
        cur_attr = model
        latest_attr_name = ''
        for attr_name in attr_name_list:
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = getattr(cur_attr, f"{latest_attr_name}.{attr_name}")
                    else:
                        cur_attr = getattr(cur_attr, attr_name)
                    latest_attr_name = ''
                except AttributeError:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}.{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_B.weight', '.lora_A.weight')]

        # Merge: delta_W = up @ down, computed in float32 and cast back to the weight dtype.
        einsum_a = "ijabcdefg"
        einsum_b = "jkabcdefg"
        einsum_res = "ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_weight

    # text encoder part
    raw_state_dict = load_file(lora_file)
    raw_state_dict = {k: v.to(device) for k, v in raw_state_dict.items()}

    # text encoder (kohya-style `lora_te1_` keys)
    state_dict = {k: v for k, v in raw_state_dict.items() if 'lora_te1_' in k}
    model = self.text_encoder
    keys = [k for k in state_dict.keys() if k.startswith('lora_te1_')]

    for k_lora in tqdm(keys, total=len(keys), desc="loading lora in text_encoder ..."):
        v_lora = state_dict[k_lora]

        # Only process up weights; skip everything else.
        if '.lora_down.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("lora_te1_", "")
        k_lora_name = k_lora_name.replace(".lora_up.weight", "")
        attr_name_list = k_lora_name.split('_')

        # Same module walk as above, but name pieces that do not resolve on their own
        # are greedily accumulated and retried joined with '_', because kohya keys
        # flatten module paths with underscores (e.g. `self_attn` splits into `self` + `attn`).
        cur_attr = model
        latest_attr_name = ''
        for attr_name in attr_name_list:
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = getattr(cur_attr, f"{latest_attr_name}_{attr_name}")
                    else:
                        cur_attr = getattr(cur_attr, attr_name)
                    latest_attr_name = ''
                except AttributeError:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}_{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_up.weight', '.lora_down.weight')]
        alpha = state_dict.get(k_lora.replace('.lora_up.weight', '.alpha'), None)
        if alpha is None:
            lora_scale = 1.0
        else:
            # Scale by alpha / rank; use a python float so the merged weight keeps its dtype.
            rank = up_w.shape[1]
            lora_scale = alpha.item() / rank

        # Merge: delta_W = up @ down, computed in float32 and cast back to the weight dtype.
        einsum_a = "ijabcdefg"
        einsum_b = "jkabcdefg"
        einsum_res = "ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_scale * lora_weight
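

# Example usage, a minimal sketch: `flux_load_lora` takes the pipeline as its first
# argument, so it can be called directly on a diffusers FluxPipeline (or bound onto it).
# The LoRA file name and lora_weight value below are illustrative, not part of this module.
#
#   from diffusers import FluxPipeline
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   flux_load_lora(pipe, "my_flux_lora.safetensors", lora_weight=0.8)
#   image = pipe("a photo of an astronaut riding a horse").images[0]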