update config
Browse files- model_index.json +3 -4
- unet/config.json +7 -2
- unet/mv_unet.py +163 -0
- unet_state_dict.pth +0 -3
model_index.json
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
{
|
2 |
"_class_name": "StableDiffusionImage2MVCustomPipeline",
|
3 |
-
"_diffusers_version": "0.27.2",
|
4 |
-
"_name_or_path": "lambdalabs/sd-image-variations-diffusers",
|
5 |
"condition_offset": true,
|
6 |
"feature_extractor": [
|
7 |
"transformers",
|
@@ -21,8 +20,8 @@
|
|
21 |
"DDIMScheduler"
|
22 |
],
|
23 |
"unet": [
|
24 |
-
"
|
25 |
-
"
|
26 |
],
|
27 |
"vae": [
|
28 |
"diffusers",
|
|
|
1 |
{
|
2 |
"_class_name": "StableDiffusionImage2MVCustomPipeline",
|
3 |
+
"_diffusers_version": "0.27.2",
|
|
|
4 |
"condition_offset": true,
|
5 |
"feature_extractor": [
|
6 |
"transformers",
|
|
|
20 |
"DDIMScheduler"
|
21 |
],
|
22 |
"unet": [
|
23 |
+
"mv_unet",
|
24 |
+
"UnifieldWrappedUNet"
|
25 |
],
|
26 |
"vae": [
|
27 |
"diffusers",
|
unet/config.json
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
{
|
2 |
"_class_name": "UnifieldWrappedUNet",
|
3 |
"_diffusers_version": "0.27.2",
|
4 |
-
"_name_or_path": "lambdalabs/sd-image-variations-diffusers",
|
5 |
"act_fn": "silu",
|
6 |
"addition_embed_type": null,
|
7 |
"addition_embed_type_num_heads": 64,
|
@@ -64,5 +63,11 @@
|
|
64 |
"CrossAttnUpBlock2D"
|
65 |
],
|
66 |
"upcast_attention": false,
|
67 |
-
"use_linear_projection": false
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
}
|
|
|
1 |
{
|
2 |
"_class_name": "UnifieldWrappedUNet",
|
3 |
"_diffusers_version": "0.27.2",
|
|
|
4 |
"act_fn": "silu",
|
5 |
"addition_embed_type": null,
|
6 |
"addition_embed_type_num_heads": 64,
|
|
|
63 |
"CrossAttnUpBlock2D"
|
64 |
],
|
65 |
"upcast_attention": false,
|
66 |
+
"use_linear_projection": false,
|
67 |
+
|
68 |
+
"multiview_attn_position": "attn1",
|
69 |
+
"num_modalities": 1,
|
70 |
+
"latent_size": 64,
|
71 |
+
"multiview_chain_pose": "parralle"
|
72 |
+
|
73 |
}
|
unet/mv_unet.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
+
from diffusers import UNet2DConditionModel
|
4 |
+
from diffusers.models.attention_processor import Attention
|
5 |
+
|
6 |
+
|
7 |
+
def switch_multiview_processor(model, enable_filter=lambda x:True):
|
8 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
9 |
+
for sub_name, child in module.named_children():
|
10 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
11 |
+
|
12 |
+
if isinstance(module, Attention):
|
13 |
+
processor = module.get_processor()
|
14 |
+
if isinstance(processor, multiviewAttnProc):
|
15 |
+
processor.enabled = enable_filter(f"{name}.processor")
|
16 |
+
|
17 |
+
for name, module in model.named_children():
|
18 |
+
recursive_add_processors(name, module)
|
19 |
+
|
20 |
+
|
21 |
+
def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
|
22 |
+
return_dict = torch.nn.ModuleDict()
|
23 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
24 |
+
for sub_name, child in module.named_children():
|
25 |
+
if "ref_unet" not in (sub_name + name):
|
26 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
27 |
+
|
28 |
+
if isinstance(module, Attention):
|
29 |
+
new_processor = multiviewAttnProc(
|
30 |
+
chained_proc=module.get_processor(),
|
31 |
+
enabled=enable_filter(f"{name}.processor"),
|
32 |
+
name=f"{name}.processor",
|
33 |
+
hidden_states_dim=module.inner_dim,
|
34 |
+
**kwargs
|
35 |
+
)
|
36 |
+
module.set_processor(new_processor)
|
37 |
+
return_dict[f"{name}.processor".replace(".", "__")] = new_processor
|
38 |
+
|
39 |
+
for name, module in model.named_children():
|
40 |
+
recursive_add_processors(name, module)
|
41 |
+
|
42 |
+
return return_dict
|
43 |
+
|
44 |
+
|
45 |
+
class multiviewAttnProc(torch.nn.Module):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
chained_proc,
|
49 |
+
enabled=False,
|
50 |
+
name=None,
|
51 |
+
hidden_states_dim=None,
|
52 |
+
chain_pos="parralle", # before or parralle or after
|
53 |
+
num_modalities=1,
|
54 |
+
views=4,
|
55 |
+
base_img_size=64,
|
56 |
+
) -> None:
|
57 |
+
super().__init__()
|
58 |
+
self.enabled = enabled
|
59 |
+
self.chained_proc = chained_proc
|
60 |
+
self.name = name
|
61 |
+
self.hidden_states_dim = hidden_states_dim
|
62 |
+
self.num_modalities = num_modalities
|
63 |
+
self.views = views
|
64 |
+
self.base_img_size = base_img_size
|
65 |
+
self.chain_pos = chain_pos
|
66 |
+
self.diff_joint_attn = True
|
67 |
+
|
68 |
+
def __call__(
|
69 |
+
self,
|
70 |
+
attn: Attention,
|
71 |
+
hidden_states: torch.FloatTensor,
|
72 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
73 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
74 |
+
**kwargs
|
75 |
+
) -> torch.Tensor:
|
76 |
+
if not self.enabled:
|
77 |
+
return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
78 |
+
|
79 |
+
B, L, C = hidden_states.shape
|
80 |
+
mv = self.views
|
81 |
+
hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
|
82 |
+
hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
83 |
+
return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
class UnifieldWrappedUNet(UNet2DConditionModel):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
sample_size: Optional[int] = None,
|
91 |
+
in_channels: int = 4,
|
92 |
+
out_channels: int = 4,
|
93 |
+
center_input_sample: bool = False,
|
94 |
+
flip_sin_to_cos: bool = True,
|
95 |
+
freq_shift: int = 0,
|
96 |
+
down_block_types: Tuple[str] = (
|
97 |
+
"CrossAttnDownBlock2D",
|
98 |
+
"CrossAttnDownBlock2D",
|
99 |
+
"CrossAttnDownBlock2D",
|
100 |
+
"DownBlock2D",
|
101 |
+
),
|
102 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
103 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
104 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
105 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
106 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
107 |
+
downsample_padding: int = 1,
|
108 |
+
mid_block_scale_factor: float = 1,
|
109 |
+
dropout: float = 0.0,
|
110 |
+
act_fn: str = "silu",
|
111 |
+
norm_num_groups: Optional[int] = 32,
|
112 |
+
norm_eps: float = 1e-5,
|
113 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
114 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
115 |
+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
116 |
+
encoder_hid_dim: Optional[int] = None,
|
117 |
+
encoder_hid_dim_type: Optional[str] = None,
|
118 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
119 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
120 |
+
dual_cross_attention: bool = False,
|
121 |
+
use_linear_projection: bool = False,
|
122 |
+
class_embed_type: Optional[str] = None,
|
123 |
+
addition_embed_type: Optional[str] = None,
|
124 |
+
addition_time_embed_dim: Optional[int] = None,
|
125 |
+
num_class_embeds: Optional[int] = None,
|
126 |
+
upcast_attention: bool = False,
|
127 |
+
resnet_time_scale_shift: str = "default",
|
128 |
+
resnet_skip_time_act: bool = False,
|
129 |
+
resnet_out_scale_factor: float = 1.0,
|
130 |
+
time_embedding_type: str = "positional",
|
131 |
+
time_embedding_dim: Optional[int] = None,
|
132 |
+
time_embedding_act_fn: Optional[str] = None,
|
133 |
+
timestep_post_act: Optional[str] = None,
|
134 |
+
time_cond_proj_dim: Optional[int] = None,
|
135 |
+
conv_in_kernel: int = 3,
|
136 |
+
conv_out_kernel: int = 3,
|
137 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
138 |
+
attention_type: str = "default",
|
139 |
+
class_embeddings_concat: bool = False,
|
140 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
141 |
+
cross_attention_norm: Optional[str] = None,
|
142 |
+
addition_embed_type_num_heads: int = 64,
|
143 |
+
multiview_attn_position: str = "attn1",
|
144 |
+
num_modalities: int = 1,
|
145 |
+
latent_size: int = 64,
|
146 |
+
multiview_chain_pose: str = "parralle",
|
147 |
+
**kwargs
|
148 |
+
):
|
149 |
+
super().__init__(**{
|
150 |
+
k: v for k, v in locals().items() if k not in
|
151 |
+
["self", "kwargs", "__class__", "multiview_attn_position", "num_modalities", "latent_size", "multiview_chain_pose"]
|
152 |
+
})
|
153 |
+
|
154 |
+
add_multiview_processor(
|
155 |
+
model = self,
|
156 |
+
enable_filter = lambda name: name.endswith(f"{multiview_attn_position}.processor"),
|
157 |
+
num_modalities = num_modalities,
|
158 |
+
base_img_size = latent_size,
|
159 |
+
chain_pos = multiview_chain_pose,
|
160 |
+
)
|
161 |
+
|
162 |
+
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|
163 |
+
|
unet_state_dict.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:0dff2fdba450af0e10c3a847ba66a530170be2e9b9c9f4c834483515e82738b5
|
3 |
-
size 3438460972
|
|
|
|
|
|
|
|