Luffuly committed on
Commit f3093e1 · 1 parent: ea543e0

update config

Files changed (4)
  1. model_index.json +3 -4
  2. unet/config.json +7 -2
  3. unet/mv_unet.py +163 -0
  4. unet_state_dict.pth +0 -3
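
In short: the commit drops the stale _name_or_path references, repoints the pipeline's unet component from the stock diffusers UNet2DConditionModel to the repo-local UnifieldWrappedUNet defined in unet/mv_unet.py, adds the multiview settings that class consumes to unet/config.json, and removes the obsolete unet_state_dict.pth LFS blob. After this change the pipeline should load roughly as sketched below; the repo id is a placeholder, and trust_remote_code=True is assumed to be needed so diffusers can import the custom classes shipped in the repo:

from diffusers import DiffusionPipeline

# Placeholder repo id; trust_remote_code lets diffusers import
# UnifieldWrappedUNet from this repo's unet/mv_unet.py.
pipe = DiffusionPipeline.from_pretrained(
    "Luffuly/your-mvimage-repo",  # hypothetical repo id
    trust_remote_code=True,
)
print(type(pipe.unet).__name__)  # expected: UnifieldWrappedUNet
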
model_index.json CHANGED
@@ -1,7 +1,6 @@
 {
   "_class_name": "StableDiffusionImage2MVCustomPipeline",
-  "_diffusers_version": "0.27.2",
-  "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+  "_diffusers_version": "0.27.2",
   "condition_offset": true,
   "feature_extractor": [
     "transformers",
@@ -21,8 +20,8 @@
     "DDIMScheduler"
   ],
   "unet": [
-    "diffusers",
-    "UNet2DConditionModel"
+    "mv_unet",
+    "UnifieldWrappedUNet"
   ],
   "vae": [
     "diffusers",
unet/config.json CHANGED
@@ -1,7 +1,6 @@
 {
   "_class_name": "UnifieldWrappedUNet",
   "_diffusers_version": "0.27.2",
-  "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
   "act_fn": "silu",
   "addition_embed_type": null,
   "addition_embed_type_num_heads": 64,
@@ -64,5 +63,11 @@
     "CrossAttnUpBlock2D"
   ],
   "upcast_attention": false,
-  "use_linear_projection": false
+  "use_linear_projection": false,
+
+  "multiview_attn_position": "attn1",
+  "num_modalities": 1,
+  "latent_size": 64,
+  "multiview_chain_pose": "parralle"
+
 }
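
The four keys appended above are the multiview extensions consumed by UnifieldWrappedUNet (added in mv_unet.py below) and stripped out before the base UNet2DConditionModel constructor runs: multiview_attn_position picks which attention layers get the multiview processor ("attn1", i.e. the self-attention blocks), num_modalities and latent_size are forwarded to the processor, and multiview_chain_pose (spelled "parralle" here to match the code's default) picks the chain position. A small sketch of the name filter this config implies, using the same endswith test the constructor builds:

# The constructor derives this filter from multiview_attn_position.
multiview_attn_position = "attn1"
enable_filter = lambda name: name.endswith(f"{multiview_attn_position}.processor")

# Self-attention processors match; cross-attention (attn2) ones do not.
print(enable_filter("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"))  # True
print(enable_filter("down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"))  # False
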
unet/mv_unet.py ADDED
@@ -0,0 +1,163 @@
+import torch
+from typing import Optional, Tuple, Union
+from diffusers import UNet2DConditionModel
+from diffusers.models.attention_processor import Attention
+
+
+def switch_multiview_processor(model, enable_filter=lambda x: True):
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            processor = module.get_processor()
+            if isinstance(processor, multiviewAttnProc):
+                processor.enabled = enable_filter(f"{name}.processor")
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+
+
+def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x: True, **kwargs):
+    return_dict = torch.nn.ModuleDict()
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            if "ref_unet" not in (sub_name + name):
+                recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            new_processor = multiviewAttnProc(
+                chained_proc=module.get_processor(),
+                enabled=enable_filter(f"{name}.processor"),
+                name=f"{name}.processor",
+                hidden_states_dim=module.inner_dim,
+                **kwargs
+            )
+            module.set_processor(new_processor)
+            return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+
+    return return_dict
+
+
+class multiviewAttnProc(torch.nn.Module):
+    def __init__(
+        self,
+        chained_proc,
+        enabled=False,
+        name=None,
+        hidden_states_dim=None,
+        chain_pos="parralle",  # before or parralle or after
+        num_modalities=1,
+        views=4,
+        base_img_size=64,
+    ) -> None:
+        super().__init__()
+        self.enabled = enabled
+        self.chained_proc = chained_proc
+        self.name = name
+        self.hidden_states_dim = hidden_states_dim
+        self.num_modalities = num_modalities
+        self.views = views
+        self.base_img_size = base_img_size
+        self.chain_pos = chain_pos
+        self.diff_joint_attn = True
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **kwargs
+    ) -> torch.Tensor:
+        if not self.enabled:
+            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+
+        B, L, C = hidden_states.shape
+        mv = self.views
+        hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
+        hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+        return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
+
+
+
+class UnifieldWrappedUNet(UNet2DConditionModel):
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: Union[int, Tuple[int]] = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        dropout: float = 0.0,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
+        time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
+        time_embedding_act_fn: Optional[str] = None,
+        timestep_post_act: Optional[str] = None,
+        time_cond_proj_dim: Optional[int] = None,
+        conv_in_kernel: int = 3,
+        conv_out_kernel: int = 3,
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
+        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads: int = 64,
+        multiview_attn_position: str = "attn1",
+        num_modalities: int = 1,
+        latent_size: int = 64,
+        multiview_chain_pose: str = "parralle",
+        **kwargs
+    ):
+        super().__init__(**{
+            k: v for k, v in locals().items() if k not in
+            ["self", "kwargs", "__class__", "multiview_attn_position", "num_modalities", "latent_size", "multiview_chain_pose"]
+        })
+
+        add_multiview_processor(
+            model=self,
+            enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"),
+            num_modalities=num_modalities,
+            base_img_size=latent_size,
+            chain_pos=multiview_chain_pose,
+        )
+
+        switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
+
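
The heart of the new file is the reshape pair in multiviewAttnProc.__call__: the views of each object in the batch are folded into one long token sequence before the wrapped processor runs, so self-attention mixes information across all views, and the output is split back into per-view sequences afterwards. A standalone sketch of just that bookkeeping, with illustrative shapes and mv matching the default views=4:

import torch

B, L, C, mv = 8, 64, 320, 4           # batch of 8 = 2 objects x 4 views
hidden_states = torch.randn(B, L, C)

# join: (B, L, C) -> (B/mv, mv*L, C), one joint sequence per object
joined = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
assert joined.shape == (2, 256, 320)

# ... the chained attention processor would run on `joined` here ...

# split: back to one sequence per view, restoring the original layout
split = joined.reshape(B // mv, mv, L, C).reshape(-1, L, C)
assert torch.equal(split, hidden_states)

Note that chain_pos and diff_joint_attn are stored but never branched on in this file, so the "parralle" setting is the only path exercised here.
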
unet_state_dict.pth DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0dff2fdba450af0e10c3a847ba66a530170be2e9b9c9f4c834483515e82738b5
-size 3438460972