Update ip_adapter/ip_adapter.py
ip_adapter/ip_adapter.py  (+10 -69)  CHANGED
@@ -30,51 +30,17 @@ class ImageProjModel(torch.nn.Module):
     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
         super().__init__()
 
-        # cross_attention_dim = 768
-        # clip_extra_context_tokens = 4
-        # clip_embeddings_dim = 1024
         self.cross_attention_dim = cross_attention_dim
         self.clip_extra_context_tokens = clip_extra_context_tokens
-        # Create a linear layer self.proj, with clip_embeddings_dim as the input dimension and self.clip_extra_context_tokens * cross_attention_dim as the output dimension.
         self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
-        # self.proj_1 = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
-        #
-        # # Access the linear layer's weight parameters
-        # weights = self.proj.weight
-        # print("proj_weights")
-        # print(weights)
-        # # Access the linear layer's weight parameters
-        # weights_1 = self.proj_1.weight
-        # print("proj_1_weights")
-        # print(weights_1)
-        #
-        # # Access the linear layer's bias parameters
-        # bias = self.proj.bias
-        # print("proj_bias")
-        # print(bias)
-        # # Access the linear layer's bias parameters
-        # bias_1 = self.proj_1.bias
-        # print("proj_1_bias")
-        # print(bias_1)
-
-
-        # Next, create a LayerNorm layer self.norm with cross_attention_dim as the input dimension.
-        # LayerNorm normalizes each channel so that its mean and variance are consistent, keeping the feature distribution of each channel relatively uniform and helping the model learn features.
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
     def forward(self, image_embeds):
-        # In the forward pass, image_embeds is taken as input and assigned to embeds.
         embeds = image_embeds
-        # embeds.shape = [1,1024]
-        # self.proj(embeds).shape = [1,3072]
-        # Then self.proj applies a linear transformation to embeds and the result is reshaped.
         clip_extra_context_tokens = self.proj(embeds).reshape(
             -1, self.clip_extra_context_tokens, self.cross_attention_dim
         )
-        # clip_extra_context_tokens.shape = [1,4,768]
-        # The result is then passed through self.norm (LayerNorm) and the processed clip_extra_context_tokens is returned.
         clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
-        # clip_extra_context_tokens.shape = [1,4,768]
         return clip_extra_context_tokens
 
 
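The comments removed in this hunk document the projection's shape math: a [1, 1024] CLIP image embedding is projected to [1, 3072] and reshaped into [1, 4, 768] extra context tokens when cross_attention_dim is 768. A minimal standalone sketch of that calculation, using the dimensions named in those comments rather than this Space's actual pipeline config:

    import torch

    # Dimensions taken from the removed comments (illustrative, not read from the real config):
    # clip_embeddings_dim = 1024, clip_extra_context_tokens = 4, cross_attention_dim = 768
    clip_embeddings_dim, num_tokens, cross_attention_dim = 1024, 4, 768

    proj = torch.nn.Linear(clip_embeddings_dim, num_tokens * cross_attention_dim)
    norm = torch.nn.LayerNorm(cross_attention_dim)

    image_embeds = torch.randn(1, clip_embeddings_dim)            # [1, 1024]
    tokens = proj(image_embeds)                                   # [1, 3072]
    tokens = tokens.reshape(-1, num_tokens, cross_attention_dim)  # [1, 4, 768]
    tokens = norm(tokens)                                         # LayerNorm over the last dim
    print(tokens.shape)  # torch.Size([1, 4, 768])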
@@ -110,7 +76,7 @@ class IPAdapter:
 
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float16
+            self.device, dtype=torch.float32
         )
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
@@ -123,20 +89,14 @@ class IPAdapter:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     def set_ip_adapter(self):
-        # First, it gets the unet from self.pipe.unet,
         unet = self.pipe.unet
-        # and initializes an empty dict attn_procs.
         attn_procs = {}
-        # Then it iterates over every key name in unet.attn_processors.
         for name in unet.attn_processors.keys():
-            # In the loop, cross_attention_dim and hidden_size are set depending on name.
-            # If name ends with "attn1.processor", cross_attention_dim is set to None; otherwise it is set to unet.config.cross_attention_dim.
             cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-            # Next, hidden_size is set according to the prefix of name.
             if name.startswith("mid_block"):
                 hidden_size = unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
@@ -145,12 +105,9 @@ class IPAdapter:
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = unet.config.block_out_channels[block_id]
-            # Next, depending on the value of cross_attention_dim, a corresponding AttnProcessor or IPAttnProcessor is created for each name and added to the attn_procs dict.
             if cross_attention_dim is None:
-                #print("initialization:attn_procs[name] = AttnProcessor()")
                 attn_procs[name] = AttnProcessor()
             else:
-                #print("initialization:attn_procs[name] = IPAttnProcessor()")
                 attn_procs[name] = IPAttnProcessor(
                     hidden_size= hidden_size,
                     cross_attention_dim=cross_attention_dim,
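The comments removed in this and the previous hunk walk through how set_ip_adapter picks hidden_size from each attention processor's name. That lookup can be exercised on its own; a minimal sketch with a hypothetical block_out_channels list standing in for the real UNet config (the up_blocks branch falls outside the diff context and is omitted):

    # Hypothetical channel list; a real UNet exposes this as unet.config.block_out_channels.
    block_out_channels = [320, 640, 1280, 1280]

    def hidden_size_for(name: str) -> int:
        # mid_block processors use the innermost channel count.
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        # down_blocks.<i> processors index directly into block_out_channels.
        if name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            return block_out_channels[block_id]
        raise ValueError(f"branch not shown in this diff: {name}")

    print(hidden_size_for("down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"))  # 640
    print(hidden_size_for("mid_block.attentions.0.transformer_blocks.0.attn1.processor"))      # 1280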
@@ -158,9 +115,7 @@ class IPAdapter:
                     num_tokens=self.num_tokens,
                     Control_factor=self.Control_factor,
                     IP_factor=self.IP_factor,
-                ).to(self.device, dtype=torch.float16)
-        # Call unet.set_attn_processor(attn_procs) to set unet's attention processors,
-        # and also call self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) to set the attention processors of self.pipe.controlnet.
+                ).to(self.device, dtype=torch.float32)
         unet.set_attn_processor(attn_procs)
         #self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
         if hasattr(self.pipe, "controlnet"):
@@ -171,12 +126,8 @@ class IPAdapter:
                 self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
 
     def load_ip_adapter(self):
-        # This method loads the IP adapter state: it opens the self.ip_ckpt file with safe_open and iterates over the keys in the file.
-        # First, it checks whether the file extension of self.ip_ckpt is ".safetensors".
         if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
-            # If so, it creates an empty state_dict with the two keys "image_proj" and "ip_adapter".
             state_dict = {"image_proj": {}, "ip_adapter": {}}
-            # Keys starting with "image_proj." have their tensors stored in state_dict["image_proj"]; keys starting with "ip_adapter." have theirs stored in state_dict["ip_adapter"].
             with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                 for key in f.keys():
                     if key.startswith("image_proj."):
@@ -184,12 +135,7 @@ class IPAdapter:
                     elif key.startswith("ip_adapter."):
                         state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
         else:
-            # If the extension of self.ip_ckpt is not ".safetensors", the state is loaded directly with torch.load and stored in state_dict.
             state_dict = torch.load(self.ip_ckpt, map_location="cpu")
-        # The next two lines load the pretrained parameters.
-        # The first uses load_state_dict to load the "image_proj" part of state_dict into self.image_proj_model,
-        # while the second loads the "ip_adapter" part of state_dict into ip_layers.
-        # Note that ip_layers is a ModuleList containing multiple attn_processors, so when loading the "ip_adapter" part the keys in state_dict must match up with the submodules of ip_layers.
         self.image_proj_model.load_state_dict(state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])
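The removed comments above describe how load_ip_adapter splits a .safetensors checkpoint into "image_proj" and "ip_adapter" buckets before calling load_state_dict. A minimal sketch of that key splitting, using an in-memory dict with made-up keys and shapes in place of a real checkpoint opened via safe_open:

    import torch

    # Made-up checkpoint entries; the real checkpoint is read key by key with safetensors' safe_open.
    fake_checkpoint = {
        "image_proj.proj.weight": torch.zeros(3072, 1024),
        "image_proj.proj.bias": torch.zeros(3072),
        "ip_adapter.1.to_k_ip.weight": torch.zeros(320, 768),
    }

    state_dict = {"image_proj": {}, "ip_adapter": {}}
    for key, tensor in fake_checkpoint.items():
        if key.startswith("image_proj."):
            state_dict["image_proj"][key.replace("image_proj.", "")] = tensor
        elif key.startswith("ip_adapter."):
            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = tensor

    print(sorted(state_dict["image_proj"]))  # ['proj.bias', 'proj.weight']
    print(sorted(state_dict["ip_adapter"]))  # ['1.to_k_ip.weight']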
@@ -200,14 +146,9 @@ class IPAdapter:
             if isinstance(pil_image, Image.Image):
                 pil_image = [pil_image]
             clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
-
-            # clip_imageBroken = self.clip_image_processor(images=image_broken, return_tensors="pt").pixel_values
-            # clip_imageBroken_embeds = self.image_encoder(clip_imageBroken.to(self.device, dtype=torch.float16)).image_embeds
-            # clip_image_embeds.shape: torch.Size([1, 1024])
-            # style_vector = clip_image_embeds-clip_imageBroken_embeds
+            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
         else:
-            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
+            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32)
 
 
         # image_prompt_embeds = self.image_proj_model(style_vector)
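The removed comments in this hunk note that clip_image_embeds comes out as a [1, 1024] tensor and sketch an abandoned "style vector" experiment (the difference of two image embeddings). The branching that survives can be illustrated with a stand-in encoder so it runs without downloading CLIP weights; the stand-in and its 1024-dim output are assumptions for illustration, not the real CLIPVisionModelWithProjection:

    import torch

    def fake_image_encoder(pixel_values: torch.Tensor) -> torch.Tensor:
        # Stand-in for the CLIP vision encoder's .image_embeds output: [batch, 1024].
        return torch.randn(pixel_values.shape[0], 1024)

    def get_image_embeds(clip_image=None, clip_image_embeds=None, device="cpu", dtype=torch.float32):
        if clip_image is not None:
            # Preprocessed pixel values were provided: encode them.
            clip_image_embeds = fake_image_encoder(clip_image.to(device, dtype=dtype))
        else:
            # Precomputed embeddings were provided: just move and cast them.
            clip_image_embeds = clip_image_embeds.to(device, dtype=dtype)
        return clip_image_embeds

    print(get_image_embeds(clip_image=torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1024])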
@@ -382,7 +323,7 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     @torch.inference_mode()
@@ -390,7 +331,7 @@ class IPAdapterPlus(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.float32)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
@@ -408,7 +349,7 @@ class IPAdapterFull(IPAdapterPlus):
         image_proj_model = MLPProjModel(
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.hidden_size,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
         # image_proj_model = Resampler(
@@ -425,7 +366,7 @@ class IPAdapterPlusXL(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.float32)
         return image_proj_model
 
     @torch.inference_mode()
@@ -433,7 +374,7 @@ class IPAdapterPlusXL(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.float32)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(
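Every dtype edit in this commit follows the same pattern: each `.to(self.device, dtype=torch.float16)` becomes `.to(self.device, dtype=torch.float32)`. If the intent is to keep one code path working on both CPU Spaces and GPU, a common alternative (shown here only as an illustration, not as part of this commit) is to derive the dtype from the device:

    import torch

    # Illustration only: pick half precision on CUDA, full precision on CPU,
    # instead of hard-coding either dtype.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    layer = torch.nn.Linear(1024, 3072).to(device, dtype=dtype)
    x = torch.randn(1, 1024, device=device, dtype=dtype)
    print(layer(x).dtype)  # torch.float16 on GPU, torch.float32 on CPU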