# FluxNetwork: utilities for selectively fine-tuning components of a FluxTransformer2DModel.
import torch
import torch.nn as nn
from .transformer_flux import FluxTransformer2DModel
class FluxNetwork(nn.Module):
    """Manages selective fine-tuning of a wrapped FluxTransformer2DModel.

    Tracks a set of sub-module names (dotted attribute paths) inside the
    wrapped transformer, toggles their ``requires_grad`` flags, and
    saves/loads only those components' weights.
    """

    # Module class names whose sub-components are eligible for training.
    TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    FLUX_PREFIX = "flux"

    def __init__(self, flux_model: "FluxTransformer2DModel"):
        super().__init__()
        self.flux_model = flux_model
        # Dotted attribute paths of the components currently marked trainable.
        self.trainable_component_names = []

    def generate_trainable_components(self, layers, num_transformer_blocks=19):
        """Build dotted paths to the trainable sub-modules for the given layers.

        Args:
            layers: Iterable of global block indices. Indices below
                ``num_transformer_blocks`` address ``transformer_blocks``;
                the remainder address ``single_transformer_blocks``.
            num_transformer_blocks: Number of dual-stream transformer blocks
                in the model (default 19).

        Returns:
            List of dotted attribute paths, always starting with
            ``context_embedder``.
        """
        # NOTE: the original definition omitted ``self``, so the call
        # ``self.generate_trainable_components(layers)`` bound the instance
        # to the ``layers`` parameter and failed at runtime; fixed here.
        transformer_components = [
            "attn.to_q",
            "attn.to_k",
            "attn.to_v",
            "attn.to_out",
            "norm1",
            "norm1_context",
        ]
        single_transformer_components = [
            "attn.to_q",
            "attn.to_k",
            "attn.to_v",
            "norm",
            # "proj_mlp",
        ]
        components = ["context_embedder"]
        for layer in layers:
            if layer < num_transformer_blocks:
                prefix = f"transformer_blocks.{layer}"
                base_components = transformer_components
            else:
                prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
                base_components = single_transformer_components
            components.extend(f"{prefix}.{comp}" for comp in base_components)
        return components

    def apply_to(self, layers=None, additional_components=None):
        """Mark the selected components as trainable and record their names.

        Args:
            layers: Global block indices to unfreeze; defaults to all 57
                blocks (19 transformer + 38 single-transformer — TODO confirm
                the 57 total matches the wrapped model's configuration).
            additional_components: Extra dotted paths to unfreeze.
        """
        if layers is None:
            layers = list(range(57))  # include every block by default
        component_names = self.generate_trainable_components(layers)
        if additional_components:
            component_names.extend(additional_components)

        self.trainable_component_names = []  # reset before re-recording
        for name in component_names:
            try:
                component = recursive_getattr(self.flux_model, name)
                if isinstance(component, nn.Module):
                    component.requires_grad_(True)
                    self.trainable_component_names.append(name)
                else:
                    print(f"Warning: {name} is not a Module, skipping.")
            except AttributeError:
                print(f"Warning: {name} not found in the model, skipping.")

    def prepare_grad_etc(self):
        """Freeze the whole model, then re-enable grads on tracked components."""
        self.flux_model.requires_grad_(False)
        for name in self.trainable_component_names:
            recursive_getattr(self.flux_model, name).requires_grad_(True)

    def get_trainable_params(self):
        """Return a flat list of parameters from all tracked components."""
        params = []
        for name in self.trainable_component_names:
            params.extend(recursive_getattr(self.flux_model, name).parameters())
        return params

    def print_trainable_params_info(self):
        """Print the total count of trainable parameters being tracked."""
        total_params = 0
        for name in self.trainable_component_names:
            module = recursive_getattr(self.flux_model, name)
            total_params += sum(p.numel() for p in module.parameters() if p.requires_grad)
        print(f'Total trainable params: {total_params}')

    def save_weights(self, file, dtype=None):
        """Save the state dicts of tracked components to ``file``.

        Args:
            file: Path or file-like object accepted by ``torch.save``.
            dtype: Optional dtype; if given, tensors are detached, moved to
                CPU, and cast before saving.
        """
        state_dict = {}
        for name in self.trainable_component_names:
            sd = recursive_getattr(self.flux_model, name).state_dict()
            if dtype is not None:
                # BUG FIX: the original rebound a loop variable with the
                # converted dict, silently discarding the cast; store the
                # converted dict back instead.
                sd = {k: t.detach().clone().to("cpu").to(dtype) for k, t in sd.items()}
            state_dict[name] = sd
        torch.save(state_dict, file)

    def load_weights(self, file, device):
        """Load per-component weights previously saved by :meth:`save_weights`.

        Validates key sets and tensor shapes before loading each component;
        a failing component is reported and skipped without aborting the rest.

        Args:
            file: Checkpoint path for ``torch.load``.
            device: ``map_location`` target for the load.

        Returns:
            True if every component loaded successfully, else False.
        """
        print(f"Loading weights from {file}")
        try:
            state_dict = torch.load(file, map_location=device, weights_only=True)
        except Exception as e:
            print(f"Failed to load weights from {file}: {str(e)}")
            return False

        successfully_loaded = []
        failed_to_load = []
        for name in state_dict:
            try:
                module = recursive_getattr(self.flux_model, name)
                module_state_dict = module.state_dict()
                # Verify the key sets match before attempting the load.
                if set(state_dict[name].keys()) != set(module_state_dict.keys()):
                    raise ValueError(f"State dict keys for {name} do not match")
                # Verify every tensor shape matches.
                for key in state_dict[name]:
                    if state_dict[name][key].shape != module_state_dict[key].shape:
                        raise ValueError(f"Shape mismatch for {name}.{key}")
                module.load_state_dict(state_dict[name])
                successfully_loaded.append(name)
            except Exception as e:
                print(f"Failed to load weights for {name}: {str(e)}")
                failed_to_load.append(name)

        if successfully_loaded:
            print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
        if failed_to_load:
            print(f"Failed to load weights for: {', '.join(failed_to_load)}")
        return len(failed_to_load) == 0  # True only if nothing failed to load
# Recursively resolve a dotted attribute path.
def recursive_getattr(obj, attr):
    """Return the attribute reached by following the dotted path *attr* from *obj*.

    E.g. ``recursive_getattr(model, "blocks.0.attn")`` is
    ``model.blocks.0.attn``.
    """
    target = obj
    for part in attr.split("."):
        target = getattr(target, part)
    return target
# Recursively set an attribute via a dotted path.
def recursive_setattr(obj, attr, val):
    """Set the attribute named by the dotted path *attr* on *obj* to *val*.

    All path segments except the last are traversed with ``getattr``;
    the final segment is assigned with ``setattr``.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, val)