Spaces:
Running
on
Zero
Running
on
Zero
from typing import Callable, Optional | |
from .triposg_transformer import TripoSGDiTModel | |
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Keep the current attention processor unchanged.

    Default factory for ``set_transformer_attn_processor``. It accepts the
    same four arguments as any processor factory and hands back
    ``ori_attn_proc`` itself (the very same object), so layers routed through
    it are left as they were.

    Args:
        name: Processor name (unused).
        hidden_size: Hidden width of the layer (unused).
        cross_attention_dim: Cross-attention context width, or ``None`` for
            self-attention (unused).
        ori_attn_proc: The processor currently installed on the layer.

    Returns:
        ``ori_attn_proc``, unmodified.
    """
    # The first three arguments exist only to satisfy the factory signature.
    del name, hidden_size, cross_attention_dim
    return ori_attn_proc
def set_transformer_attn_processor(
    transformer: TripoSGDiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[list[str]] = None,
    set_cross_attn_1_module_names: Optional[list[str]] = None,
    set_cross_attn_2_module_names: Optional[list[str]] = None,
) -> None:
    """Replace selected attention processors on a TripoSG DiT transformer.

    Every entry of ``transformer.attn_processors`` is routed by the suffix of
    its name:

    * ``...attn1.processor``   -> ``set_self_attn_proc_func``, called with
      ``cross_attention_dim=None``;
    * ``...attn2.processor``   -> ``set_cross_attn_1_proc_func``, called with
      ``transformer.config.cross_attention_dim``;
    * ``...attn2_2.processor`` -> ``set_cross_attn_2_proc_func``, called with
      ``transformer.config.cross_attention_2_dim``.

    A factory is invoked as ``func(name, hidden_size, cross_attention_dim,
    current_processor)`` with ``hidden_size = transformer.config.width``, and
    its return value becomes that layer's new processor.

    Args:
        transformer: Model whose ``attn_processors`` are read and whose
            ``set_attn_processor`` receives the resulting mapping.
        set_self_attn_proc_func: Factory for ``attn1`` processors.
        set_cross_attn_1_proc_func: Factory for ``attn2`` processors.
        set_cross_attn_2_proc_func: Factory for ``attn2_2`` processors.
        set_self_attn_module_names: Name prefixes restricting which ``attn1``
            processors are replaced. ``None`` selects all of them; an empty
            list selects none.
        set_cross_attn_1_module_names: Same as above, for ``attn2``.
        set_cross_attn_2_module_names: Same as above, for ``attn2_2``.

    Processors that are not selected, or whose names match none of the three
    suffixes, keep their current processor. The mapping passed to
    ``transformer.set_attn_processor`` therefore always has an entry for
    every name in ``transformer.attn_processors``.
    """

    def is_selected(name: str, module_names: Optional[list[str]]) -> bool:
        # ``None`` means "no filter". Otherwise require a prefix match;
        # ``str.startswith(())`` is False, so an empty list selects nothing.
        if module_names is None:
            return True
        return name.startswith(tuple(module_names))

    hidden_size = transformer.config.width  # identical for every layer
    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        if name.endswith("attn1.processor"):
            # self attention
            proc_func = set_self_attn_proc_func
            module_names = set_self_attn_module_names
            cross_attention_dim = None
        elif name.endswith("attn2.processor"):
            # cross attention
            proc_func = set_cross_attn_1_proc_func
            module_names = set_cross_attn_1_module_names
            cross_attention_dim = transformer.config.cross_attention_dim
        elif name.endswith("attn2_2.processor"):
            # cross attention 2 (config attribute read lazily: only models
            # that have such layers are expected to define it)
            proc_func = set_cross_attn_2_proc_func
            module_names = set_cross_attn_2_module_names
            cross_attention_dim = transformer.config.cross_attention_2_dim
        else:
            # Unrecognized attention layer: keep its processor. Dropping it
            # would hand set_attn_processor an incomplete dict, which
            # diffusers-style implementations reject (NOTE(review): confirm
            # TripoSGDiTModel.set_attn_processor follows that convention).
            attn_procs[name] = attn_processor
            continue

        attn_procs[name] = (
            proc_func(name, hidden_size, cross_attention_dim, attn_processor)
            if is_selected(name, module_names)
            else attn_processor
        )

    transformer.set_attn_processor(attn_procs)