huanngzh's picture
update
c9724af
from typing import Callable, Optional
from .triposg_transformer import TripoSGDiTModel
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Identity factory: keep the attention processor that is already installed.

    Serves as the default callback for ``set_transformer_attn_processor`` so
    that any attention kind the caller does not customise is left unchanged.

    Args:
        name: Name of the processor entry being visited (unused here).
        hidden_size: Hidden width handed in by the caller (unused here).
        cross_attention_dim: Cross-attention width, or ``None`` (unused here).
        ori_attn_proc: The processor currently registered under ``name``.

    Returns:
        ``ori_attn_proc`` itself, untouched.
    """
    return ori_attn_proc
def set_transformer_attn_processor(
    transformer: TripoSGDiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[list[str]] = None,
    set_cross_attn_1_module_names: Optional[list[str]] = None,
    set_cross_attn_2_module_names: Optional[list[str]] = None,
) -> None:
    """Rebuild the transformer's attention-processor mapping and install it.

    Every entry of ``transformer.attn_processors`` is classified by the suffix
    of its name:

    * ``attn1.processor``   -> self attention
    * ``attn2.processor``   -> cross attention
    * ``attn2_2.processor`` -> second cross attention

    For a classified entry, the matching factory is called as
    ``factory(name, hidden_size, cross_attention_dim, current_processor)`` and
    its return value replaces the current processor, provided the entry passes
    the corresponding ``*_module_names`` filter. Entries that fail the filter,
    or whose name matches none of the suffixes, keep their current processor.
    The resulting dict is passed to ``transformer.set_attn_processor``.

    Args:
        transformer: Model exposing ``attn_processors``, ``config`` and
            ``set_attn_processor``.
        set_self_attn_proc_func: Factory for ``attn1`` entries; receives
            ``None`` as the cross-attention dim.
        set_cross_attn_1_proc_func: Factory for ``attn2`` entries; receives
            ``transformer.config.cross_attention_dim``.
        set_cross_attn_2_proc_func: Factory for ``attn2_2`` entries; receives
            ``transformer.config.cross_attention_2_dim``.
        set_self_attn_module_names: Name prefixes selecting which ``attn1``
            entries are replaced; ``None`` selects all of them.
        set_cross_attn_1_module_names: Same filter for ``attn2`` entries.
        set_cross_attn_2_module_names: Same filter for ``attn2_2`` entries.
    """

    def should_replace(name: str, module_names: Optional[list[str]]) -> bool:
        # ``None`` means "no filter"; otherwise this is a prefix match.
        # ``str.startswith`` accepts a tuple of candidates, and an empty
        # tuple matches nothing (same result as ``any([])``).
        if module_names is None:
            return True
        return name.startswith(tuple(module_names))

    # Loop-invariant, so read it once instead of once per processor.
    hidden_size = transformer.config.width

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        # The config dims are read lazily inside each branch so that a config
        # attribute is only touched when a processor of that kind exists.
        if name.endswith("attn1.processor"):
            # self attention
            factory = set_self_attn_proc_func
            module_names = set_self_attn_module_names
            cross_attention_dim = None
        elif name.endswith("attn2.processor"):
            # cross attention
            factory = set_cross_attn_1_proc_func
            module_names = set_cross_attn_1_module_names
            cross_attention_dim = transformer.config.cross_attention_dim
        elif name.endswith("attn2_2.processor"):
            # cross attention 2
            factory = set_cross_attn_2_proc_func
            module_names = set_cross_attn_2_module_names
            cross_attention_dim = transformer.config.cross_attention_2_dim
        else:
            # Unrecognised processor kind: carry it through unchanged rather
            # than dropping it from the mapping.
            # NOTE(review): diffusers-style ``set_attn_processor`` rejects a
            # dict whose size differs from the number of attention layers;
            # confirm TripoSGDiTModel behaves the same way.
            attn_procs[name] = attn_processor
            continue

        if should_replace(name, module_names):
            attn_procs[name] = factory(
                name, hidden_size, cross_attention_dim, attn_processor
            )
        else:
            attn_procs[name] = attn_processor

    transformer.set_attn_processor(attn_procs)