qwen2vl-flux-mini-demo / flux / flux_network.py
import torch
import torch.nn as nn

from .transformer_flux import FluxTransformer2DModel


class FluxNetwork(nn.Module):
    TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]  # trainable module types
    FLUX_PREFIX = "flux"

    def __init__(self, flux_model: FluxTransformer2DModel):
        super().__init__()
        self.flux_model = flux_model
        self.trainable_component_names = []  # names of the components marked trainable

    @staticmethod
    def generate_trainable_components(layers, num_transformer_blocks=19):
        """Build dotted attribute paths for the sub-modules to train.

        Indices below ``num_transformer_blocks`` select the double-stream
        ``transformer_blocks``; higher indices map onto the single-stream
        ``single_transformer_blocks``.
        """
        transformer_components = [
            "attn.to_q",
            "attn.to_k",
            "attn.to_v",
            "attn.to_out",
            "norm1",
            "norm1_context",
        ]
        single_transformer_components = [
            "attn.to_q",
            "attn.to_k",
            "attn.to_v",
            "norm",
            # "proj_mlp",
        ]
        components = ["context_embedder"]  # always include the context_embedder
        for layer in layers:
            if layer < num_transformer_blocks:
                prefix = f"transformer_blocks.{layer}"
                base_components = transformer_components
            else:
                prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
                base_components = single_transformer_components
            components.extend([f"{prefix}.{comp}" for comp in base_components])
        return components
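
    # Example (follows directly from the lists above):
    #   generate_trainable_components([0]) ->
    #     ["context_embedder",
    #      "transformer_blocks.0.attn.to_q", "transformer_blocks.0.attn.to_k",
    #      "transformer_blocks.0.attn.to_v", "transformer_blocks.0.attn.to_out",
    #      "transformer_blocks.0.norm1", "transformer_blocks.0.norm1_context"]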

    def apply_to(self, layers=None, additional_components=None):
        if layers is None:
            layers = list(range(57))  # default: every layer (19 double-stream + 38 single-stream blocks)

        component_names = self.generate_trainable_components(layers)
        if additional_components:
            component_names.extend(additional_components)

        self.trainable_component_names = []  # reset
        for name in component_names:
            try:
                component = recursive_getattr(self.flux_model, name)
                if isinstance(component, nn.Module):
                    component.requires_grad_(True)
                    self.trainable_component_names.append(name)
                else:
                    print(f"Warning: {name} is not a Module, skipping.")
            except AttributeError:
                print(f"Warning: {name} not found in the model, skipping.")

    def prepare_grad_etc(self):
        # Freeze the whole flux_model, then re-enable gradients on the
        # recorded trainable components.
        self.flux_model.requires_grad_(False)
        for name in self.trainable_component_names:
            recursive_getattr(self.flux_model, name).requires_grad_(True)

    def get_trainable_params(self):
        # Collect the parameters that should be handed to the optimizer.
        params = []
        for name in self.trainable_component_names:
            params.extend(recursive_getattr(self.flux_model, name).parameters())
        return params

    def print_trainable_params_info(self):
        total_params = 0
        for name in self.trainable_component_names:
            module = recursive_getattr(self.flux_model, name)
            module_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            total_params += module_params
            # print(f'{name}: {module_params} trainable parameters')
        print(f'Total trainable params: {total_params}')

    def save_weights(self, file, dtype=None):
        # Save only the state dicts of the trainable components.
        state_dict = {}
        for name in self.trainable_component_names:
            state_dict[name] = recursive_getattr(self.flux_model, name).state_dict()

        if dtype is not None:
            # Cast each tensor on CPU and write the cast dict back into
            # state_dict, so the conversion ends up in the saved file.
            for name, v in state_dict.items():
                state_dict[name] = {k: t.detach().clone().to("cpu").to(dtype) for k, t in v.items()}

        torch.save(state_dict, file)
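
    # Save/load round trip (sketch; the file name and the bfloat16 cast are
    # illustrative choices, not part of the original code):
    #   network.save_weights("flux_components.pt", dtype=torch.bfloat16)
    #   network.load_weights("flux_components.pt", device="cpu")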

    def load_weights(self, file, device):
        print(f"Loading weights from {file}")
        try:
            state_dict = torch.load(file, map_location=device, weights_only=True)
        except Exception as e:
            print(f"Failed to load weights from {file}: {str(e)}")
            return False

        successfully_loaded = []
        failed_to_load = []
        for name in state_dict:
            try:
                module = recursive_getattr(self.flux_model, name)
                module_state_dict = module.state_dict()
                # Check that the saved keys match the module's keys exactly.
                if set(state_dict[name].keys()) != set(module_state_dict.keys()):
                    raise ValueError(f"State dict keys for {name} do not match")
                # Check that every tensor shape matches.
                for key in state_dict[name]:
                    if state_dict[name][key].shape != module_state_dict[key].shape:
                        raise ValueError(f"Shape mismatch for {name}.{key}")
                module.load_state_dict(state_dict[name])
                successfully_loaded.append(name)
            except Exception as e:
                print(f"Failed to load weights for {name}: {str(e)}")
                failed_to_load.append(name)

        if successfully_loaded:
            print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
        if failed_to_load:
            print(f"Failed to load weights for: {', '.join(failed_to_load)}")

        return len(failed_to_load) == 0  # True only if every component loaded

# Resolve a dotted attribute path, e.g. recursive_getattr(model, "transformer_blocks.0.attn.to_q").
def recursive_getattr(obj, attr):
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj


# Set a dotted attribute path, walking down to the parent object first.
def recursive_setattr(obj, attr, val):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], val)
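

# Usage sketch (assumptions: the local FluxTransformer2DModel supports the
# diffusers-style from_pretrained, and the checkpoint id below is illustrative;
# adapt both to your setup):
#
#   flux_model = FluxTransformer2DModel.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", subfolder="transformer"
#   )
#   network = FluxNetwork(flux_model)
#   network.apply_to(layers=[0, 1, 19])  # two double-stream blocks + one single-stream block
#   network.prepare_grad_etc()
#   network.print_trainable_params_info()
#   optimizer = torch.optim.AdamW(network.get_trainable_params(), lr=1e-5)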