import torch

from comfy.model_detection import count_blocks
from comfy.utils import flux_to_diffusers


def convert_diffusers_flux_lora(state_dict, output_prefix=""):
    out_sd = {}
    # Strip the diffusers "transformer." prefix from every key.
    state_dict = {key.replace('transformer.', ''): value for key, value in state_dict.items()}

    # Collapse each low-rank (lora_A, lora_B) pair into a full-rank delta,
    # delta_W = B @ A, keyed by the name of the weight it patches.
    new_sd = {}
    for lora_a in state_dict.keys():
        if "lora_A" not in lora_a:
            continue
        lora_b = lora_a.replace("lora_A", "lora_B")
        assert lora_b in state_dict, f"Invalid LoRA checkpoint, missing {lora_b} for {lora_a}"
        key = lora_a.replace("lora_A.", "")
        new_sd[key] = state_dict[lora_b] @ state_dict[lora_a]
    state_dict = new_sd

    # Detect the model geometry from the merged keys so the generated key map
    # matches this checkpoint's block counts and hidden size.
    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
    depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
    hidden_size = state_dict["x_embedder.weight"].shape[0]

    sd_map = flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)

    for k in sd_map:
        weight = state_dict.get(k, None)
        if weight is not None:
            t = sd_map[k]
            if not isinstance(t, str):
                # Tuple entry: (target_key, offset, optional transform).
                if len(t) > 2:
                    fun = t[2]
                else:
                    fun = lambda a: a
                offset = t[1]
                if offset is not None:
                    # offset = (dim, start, size): this weight is one slice of
                    # a fused tensor (e.g. separate q/k/v projections packed
                    # into a single qkv weight).
                    old_weight = out_sd.get(t[0], None)
                    if old_weight is None:
                        old_weight = torch.empty_like(weight)
                    if old_weight.shape[offset[0]] < offset[1] + offset[2]:
                        # Grow the destination tensor so the slice fits,
                        # keeping what was already written. Assumes slices
                        # are fused along dim 0.
                        exp = list(weight.shape)
                        exp[offset[0]] = offset[1] + offset[2]
                        new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
                        new[:old_weight.shape[0]] = old_weight
                        old_weight = new
                    w = old_weight.narrow(offset[0], offset[1], offset[2])
                else:
                    old_weight = weight
                    w = weight
                w[:] = fun(weight)
                out_sd[t[0]] = old_weight
            else:
                out_sd[t] = weight
            state_dict.pop(k)

    return out_sd
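

if __name__ == "__main__":
    # Usage sketch, not part of the converter: load a diffusers-format Flux
    # LoRA and convert it. The file path and the "diffusion_model." prefix
    # below are illustrative assumptions; pass whatever prefix your loader
    # expects for the target model's keys.
    from safetensors.torch import load_file

    lora_sd = load_file("flux_lora.safetensors")  # hypothetical local file
    merged = convert_diffusers_flux_lora(lora_sd, output_prefix="diffusion_model.")
    for name, delta in sorted(merged.items()):
        print(name, tuple(delta.shape))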