Spaces:
Runtime error
Runtime error
Update ip_adapter/attention_processor.py
Browse files
ip_adapter/attention_processor.py
CHANGED
@@ -216,51 +216,11 @@ class IPAttnProcessor(nn.Module):
|
|
216 |
ip_key = attn.head_to_batch_dim(ip_key)
|
217 |
ip_value = attn.head_to_batch_dim(ip_value)
|
218 |
|
219 |
-
# ip_key.shape=[16, 4, 40]
|
220 |
-
# query = [16,6025,40]
|
221 |
-
# target = [16,6025,4]
|
222 |
-
# print("**************************************************")
|
223 |
-
# print(query)
|
224 |
-
# print("**************************************************")
|
225 |
-
# threshold = 5
|
226 |
-
# tensor_from_data = torch.tensor(query).to("cuda")
|
227 |
-
# binary_mask = torch.where(tensor_from_data > threshold, torch.tensor(0).to("cuda"), torch.tensor(1).to("cuda"))
|
228 |
-
# binary_mask = binary_mask.to(torch.float16)
|
229 |
-
# print("**************************************************")
|
230 |
-
# print(binary_mask)
|
231 |
-
# print("**************************************************")
|
232 |
-
|
233 |
-
|
234 |
-
# query.shape=[16,6205,40]
|
235 |
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
236 |
-
##########################################
|
237 |
-
# attention_probs
|
238 |
-
#ip_attention_probs = attn.get_attention_scores(keyforIPADAPTER, ip_key, None)
|
239 |
-
##########################################
|
240 |
-
|
241 |
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
242 |
-
##########################################
|
243 |
-
# ip_hidden_states = ip_hidden_states*binary_mask +(1-binary_mask)*query
|
244 |
-
##########################################
|
245 |
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
246 |
-
|
247 |
-
# hidden_states.shape=【2,6205,320】s
|
248 |
-
# ip_hidden_states.shape=【2,3835,320】
|
249 |
-
# hidden_states = hidden_states + self.scale* ip_hidden_states
|
250 |
-
#print("Control_factor:{}".format(self.Control_factor))
|
251 |
-
#print("IP_factor:{}".format(self.IP_factor))
|
252 |
hidden_states = self.Control_factor * hidden_states + self.IP_factor * self.scale * ip_hidden_states
|
253 |
|
254 |
-
|
255 |
-
|
256 |
-
#hidden_states = 2*hidden_states +0.6*self.scale*ip_hidden_states
|
257 |
-
# if self.roundNumber < 12:
|
258 |
-
# hidden_states = hidden_states
|
259 |
-
# else:
|
260 |
-
# hidden_states = 1.2*hidden_states +0.6*self.scale*ip_hidden_states
|
261 |
-
# self.roundNumber = self.roundNumber + 1
|
262 |
-
|
263 |
-
|
264 |
# linear proj
|
265 |
hidden_states = attn.to_out[0](hidden_states)
|
266 |
# dropout
|
|
|
216 |
ip_key = attn.head_to_batch_dim(ip_key)
|
217 |
ip_value = attn.head_to_batch_dim(ip_value)
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
|
|
|
|
|
|
|
|
|
|
220 |
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
|
|
|
|
|
|
221 |
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
hidden_states = self.Control_factor * hidden_states + self.IP_factor * self.scale * ip_hidden_states
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
# linear proj
|
225 |
hidden_states = attn.to_out[0](hidden_states)
|
226 |
# dropout
|