Meaowangxi committed (verified)
Commit 96f3396 · parent: 0060beb

Update ip_adapter/ip_adapter.py

Files changed (1)
  1. ip_adapter/ip_adapter.py +10 -69
ip_adapter/ip_adapter.py CHANGED
@@ -30,51 +30,17 @@ class ImageProjModel(torch.nn.Module):
     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
         super().__init__()
 
-        # cross_attention_dim = 768
-        # clip_extra_context_tokens = 4
-        # clip_embeddings_dim = 1024
         self.cross_attention_dim = cross_attention_dim
         self.clip_extra_context_tokens = clip_extra_context_tokens
-        # Create a linear layer self.proj, with clip_embeddings_dim as the input dimension and self.clip_extra_context_tokens * cross_attention_dim as the output dimension.
         self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
-        # self.proj_1 = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
-        #
-        # # Inspect the weight parameters of the linear layer
-        # weights = self.proj.weight
-        # print("proj_weights")
-        # print(weights)
-        # # Inspect the weight parameters of the linear layer
-        # weights_1 = self.proj_1.weight
-        # print("proj_1_weights")
-        # print(weights_1)
-        #
-        # # Inspect the bias parameters of the linear layer
-        # bias = self.proj.bias
-        # print("proj_bias")
-        # print(bias)
-        # # Inspect the bias parameters of the linear layer
-        # bias_1 = self.proj_1.bias
-        # print("proj_1_bias")
-        # print(bias_1)
-
-
-        # Next, create a LayerNorm layer self.norm with cross_attention_dim as the input dimension.
-        # LayerNorm normalizes each channel so that its mean and variance are consistent, keeping the per-channel feature distribution relatively uniform and helping the model learn features.
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
     def forward(self, image_embeds):
-        # In the forward pass, image_embeds is taken as input and assigned to embeds.
         embeds = image_embeds
-        # embeds.shape = [1, 1024]
-        # self.proj(embeds).shape = [1, 3072]
-        # Apply the linear projection self.proj to embeds and reshape the result.
         clip_extra_context_tokens = self.proj(embeds).reshape(
             -1, self.clip_extra_context_tokens, self.cross_attention_dim
         )
-        # clip_extra_context_tokens.shape = [1, 4, 768]
-        # Pass the result through self.norm (LayerNorm) and return the processed clip_extra_context_tokens.
         clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
-        # clip_extra_context_tokens.shape = [1, 4, 768]
         return clip_extra_context_tokens
 
 
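The comments removed above documented the shape flow through ImageProjModel: a global CLIP image embedding of shape [1, 1024] is projected to [1, clip_extra_context_tokens * cross_attention_dim], reshaped into clip_extra_context_tokens extra context tokens, and layer-normalized. A minimal sketch of that mapping, assuming cross_attention_dim=768 and the other values quoted in the removed comments:

import torch

# Stand-ins for self.proj and self.norm with the dimensions from the removed comments.
proj = torch.nn.Linear(1024, 4 * 768)   # clip_embeddings_dim -> clip_extra_context_tokens * cross_attention_dim
norm = torch.nn.LayerNorm(768)          # normalizes over the cross_attention_dim features

image_embeds = torch.randn(1, 1024)     # CLIP image_embeds: [1, 1024]
x = proj(image_embeds)                  # [1, 3072]
tokens = x.reshape(-1, 4, 768)          # [1, 4, 768]: four extra context tokens
tokens = norm(tokens)
print(tokens.shape)                     # torch.Size([1, 4, 768])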
@@ -110,7 +76,7 @@ class IPAdapter:
 
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float16
+            self.device, dtype=torch.float32
         )
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
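The only functional change in this hunk, and in the remaining hunks below, is the cast from torch.float16 to torch.float32. Whatever precision is chosen, the module's weights and the tensors passed into it have to share one dtype, since PyTorch layers typically raise an error on a half/float mismatch. A minimal sketch of that pattern, using a plain Linear layer as a hypothetical stand-in for the CLIP image encoder:

import torch

dtype = torch.float32                       # this commit switches every cast from torch.float16 to torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = torch.nn.Linear(1024, 768).to(device, dtype=dtype)   # stand-in for the image encoder
pixel_values = torch.randn(1, 1024)                            # preprocessed input, float32 by default
embeds = encoder(pixel_values.to(device, dtype=dtype))         # cast the input to the module's dtype before the call
print(embeds.dtype)                                            # torch.float32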
@@ -123,20 +89,14 @@ class IPAdapter:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     def set_ip_adapter(self):
-        # First, get the unet from self.pipe.unet
         unet = self.pipe.unet
-        # and initialize an empty dict attn_procs.
         attn_procs = {}
-        # Then iterate over every key name in unet.attn_processors.
         for name in unet.attn_processors.keys():
-            # Inside the loop, set cross_attention_dim and hidden_size according to name:
-            # if name ends with "attn1.processor", cross_attention_dim is set to None; otherwise it is set to unet.config.cross_attention_dim.
             cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-            # Next, set hidden_size according to the prefix of name.
             if name.startswith("mid_block"):
                 hidden_size = unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
@@ -145,12 +105,9 @@ class IPAdapter:
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = unet.config.block_out_channels[block_id]
-            # Finally, depending on the value of cross_attention_dim, create an AttnProcessor or an IPAttnProcessor for each name and add it to the attn_procs dict.
             if cross_attention_dim is None:
-                #print("initialization:attn_procs[name] = AttnProcessor()")
                 attn_procs[name] = AttnProcessor()
             else:
-                #print("initialization:attn_procs[name] = IPAttnProcessor()")
                 attn_procs[name] = IPAttnProcessor(
                     hidden_size=hidden_size,
                     cross_attention_dim=cross_attention_dim,
@@ -158,9 +115,7 @@ class IPAdapter:
                     num_tokens=self.num_tokens,
                     Control_factor=self.Control_factor,
                     IP_factor=self.IP_factor,
-                ).to(self.device, dtype=torch.float16)
+                ).to(self.device, dtype=torch.float32)
-        # Call unet.set_attn_processor(attn_procs) to install the attention processors on the unet,
-        # and call self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) to install the attention processor on self.pipe.controlnet.
         unet.set_attn_processor(attn_procs)
         #self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
         if hasattr(self.pipe, "controlnet"):
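The removed comments described how set_ip_adapter picks a processor per attention layer: self-attention layers (names ending in "attn1.processor") keep a plain AttnProcessor, while cross-attention layers get an IPAttnProcessor whose hidden_size is read off the UNet's block_out_channels according to the block prefix in the name. A minimal sketch of that name-to-configuration mapping, with hypothetical values standing in for unet.config:

# Hypothetical stand-ins for unet.config; a Stable Diffusion 1.x UNet reports similar numbers.
block_out_channels = [320, 640, 1280, 1280]
unet_cross_attention_dim = 768

def resolve(name):
    # Self-attention gets no image tokens; cross-attention layers receive the IP-Adapter processor.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet_cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = block_out_channels[block_id]
    if cross_attention_dim is None:
        return "AttnProcessor()"
    return f"IPAttnProcessor(hidden_size={hidden_size}, cross_attention_dim={cross_attention_dim})"

print(resolve("mid_block.attentions.0.transformer_blocks.0.attn1.processor"))
print(resolve("down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor"))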
@@ -171,12 +126,8 @@ class IPAdapter:
             self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
 
     def load_ip_adapter(self):
-        # This method loads the IP-Adapter state: it opens self.ip_ckpt with safe_open and iterates over the keys in the file.
-        # First, check whether the file extension of self.ip_ckpt is ".safetensors".
         if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
-            # If so, create an empty state_dict with the keys "image_proj" and "ip_adapter".
             state_dict = {"image_proj": {}, "ip_adapter": {}}
-            # Keys starting with "image_proj." are stored in state_dict["image_proj"]; keys starting with "ip_adapter." are stored in state_dict["ip_adapter"].
             with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                 for key in f.keys():
                     if key.startswith("image_proj."):
@@ -184,12 +135,7 @@ class IPAdapter:
                     elif key.startswith("ip_adapter."):
                         state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
         else:
-            # If the extension of self.ip_ckpt is not ".safetensors", load the checkpoint directly with torch.load and store the result in state_dict.
             state_dict = torch.load(self.ip_ckpt, map_location="cpu")
-        # The next two lines load the pretrained parameters:
-        # the first loads the "image_proj" part of state_dict into self.image_proj_model,
-        # and the second loads the "ip_adapter" part of state_dict into ip_layers.
-        # Note that ip_layers is a ModuleList wrapping the attn_processors, so the keys in state_dict must line up with its submodules.
         self.image_proj_model.load_state_dict(state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])
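The removed comments walked through load_ip_adapter: a ".safetensors" checkpoint is opened with safe_open, its keys are split into an "image_proj" group and an "ip_adapter" group, and those groups are then loaded into self.image_proj_model and into a ModuleList built from the UNet's attention processors. A minimal sketch of the key-splitting step, with a hypothetical checkpoint path:

import os
import torch
from safetensors import safe_open

ip_ckpt = "models/ip_adapter.safetensors"   # hypothetical path, for illustration only

if os.path.splitext(ip_ckpt)[-1] == ".safetensors":
    state_dict = {"image_proj": {}, "ip_adapter": {}}
    with safe_open(ip_ckpt, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.startswith("image_proj."):
                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
            elif key.startswith("ip_adapter."):
                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
    state_dict = torch.load(ip_ckpt, map_location="cpu")

# The two groups are then consumed by image_proj_model.load_state_dict(state_dict["image_proj"]) and by a
# torch.nn.ModuleList of the UNet's attn_processors, whose submodule keys must line up with the checkpoint.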
@@ -200,14 +146,9 @@ class IPAdapter:
             if isinstance(pil_image, Image.Image):
                 pil_image = [pil_image]
             clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
+            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
-
-            # clip_imageBroken = self.clip_image_processor(images=image_broken, return_tensors="pt").pixel_values
-            # clip_imageBroken_embeds = self.image_encoder(clip_imageBroken.to(self.device, dtype=torch.float16)).image_embeds
-            # clip_image_embeds.shape: torch.Size([1, 1024])
-            # style_vector = clip_image_embeds-clip_imageBroken_embeds
         else:
-            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
+            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32)
 
 
         # image_prompt_embeds = self.image_proj_model(style_vector)
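The commented-out lines removed here sketched a style_vector: the CLIP image embedding of a degraded image (clip_imageBroken_embeds) is subtracted from the embedding of the reference image before being fed to image_proj_model (also left commented out below). A minimal sketch of that idea, assuming a publicly available CLIP vision checkpoint; the repo id and the placeholder images are illustrative, not taken from this codebase:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Illustrative encoder; the embedding size depends on the checkpoint (a ViT-H/14 encoder gives [1, 1024] image_embeds).
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
clip_image_processor = CLIPImageProcessor()

@torch.inference_mode()
def clip_embed(pil_image):
    # Global image embedding, as in get_image_embeds above.
    pixels = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    return image_encoder(pixels).image_embeds

image = Image.new("RGB", (512, 512), "white")        # placeholder reference image
image_broken = Image.new("RGB", (512, 512), "gray")  # placeholder degraded image

clip_image_embeds = clip_embed(image)
clip_imageBroken_embeds = clip_embed(image_broken)
style_vector = clip_image_embeds - clip_imageBroken_embeds   # the difference the removed comments describe
print(style_vector.shape)                                    # [1, projection_dim]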
@@ -382,7 +323,7 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     @torch.inference_mode()
@@ -390,7 +331,7 @@ class IPAdapterPlus(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.float32)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
@@ -408,7 +349,7 @@ class IPAdapterFull(IPAdapterPlus):
         image_proj_model = MLPProjModel(
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.hidden_size,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
         # image_proj_model = Resampler(
@@ -425,7 +366,7 @@ class IPAdapterPlusXL(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     @torch.inference_mode()
@@ -433,7 +374,7 @@ class IPAdapterPlusXL(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.float32)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
 