Meaowangxi committed
Commit ac12192 · verified · 1 Parent(s): 96f3396

Update ip_adapter/attention_processor.py

Files changed (1): ip_adapter/attention_processor.py (+0 -40)
ip_adapter/attention_processor.py CHANGED
@@ -216,51 +216,11 @@ class IPAttnProcessor(nn.Module):
         ip_key = attn.head_to_batch_dim(ip_key)
         ip_value = attn.head_to_batch_dim(ip_value)
 
-        # ip_key.shape=[16, 4, 40]
-        # query = [16,6025,40]
-        # target = [16,6025,4]
-        # print("**************************************************")
-        # print(query)
-        # print("**************************************************")
-        # threshold = 5
-        # tensor_from_data = torch.tensor(query).to("cuda")
-        # binary_mask = torch.where(tensor_from_data > threshold, torch.tensor(0).to("cuda"), torch.tensor(1).to("cuda"))
-        # binary_mask = binary_mask.to(torch.float16)
-        # print("**************************************************")
-        # print(binary_mask)
-        # print("**************************************************")
-
-
-        # query.shape=[16,6205,40]
         ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-        ##########################################
-        # attention_probs
-        #ip_attention_probs = attn.get_attention_scores(keyforIPADAPTER, ip_key, None)
-        ##########################################
-
         ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-        ##########################################
-        # ip_hidden_states = ip_hidden_states*binary_mask +(1-binary_mask)*query
-        ##########################################
         ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
-        # hidden_states.shape=[2,6205,320]
-        # ip_hidden_states.shape=[2,3835,320]
-        # hidden_states = hidden_states + self.scale* ip_hidden_states
-        #print("Control_factor:{}".format(self.Control_factor))
-        #print("IP_factor:{}".format(self.IP_factor))
         hidden_states = self.Control_factor * hidden_states + self.IP_factor * self.scale * ip_hidden_states
 
-
-
-        #hidden_states = 2*hidden_states +0.6*self.scale*ip_hidden_states
-        # if self.roundNumber < 12:
-        #     hidden_states = hidden_states
-        # else:
-        #     hidden_states = 1.2*hidden_states +0.6*self.scale*ip_hidden_states
-        # self.roundNumber = self.roundNumber + 1
-
-
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
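
For readers following the hunk: the kept lines implement IP-Adapter's decoupled cross-attention, where the UNet query attends separately over the image-prompt (IP) tokens, and the result is blended with the text/control branch. Below is a minimal, self-contained sketch of that computation. The MiniIPAttnBlend class, its helper methods, and all shapes are illustrative assumptions modeled loosely on diffusers' Attention helpers, not code from this repository; only the final blend line mirrors the line kept in the diff.

# Minimal sketch of the decoupled cross-attention blend (illustrative, not
# repository code). Helper methods approximate diffusers' Attention API.
import torch
import torch.nn as nn


class MiniIPAttnBlend(nn.Module):
    def __init__(self, hidden_size, heads=8, scale=1.0, control_factor=1.0, ip_factor=1.0):
        super().__init__()
        self.heads = heads
        self.scale = scale
        self.Control_factor = control_factor
        self.IP_factor = ip_factor
        # Separate key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(hidden_size, hidden_size, bias=False)

    def head_to_batch_dim(self, t):
        # [B, L, D] -> [B*heads, L, D/heads]
        b, l, d = t.shape
        t = t.reshape(b, l, self.heads, d // self.heads)
        return t.permute(0, 2, 1, 3).reshape(b * self.heads, l, d // self.heads)

    def batch_to_head_dim(self, t):
        # [B*heads, L, D/heads] -> [B, L, D]
        bh, l, dh = t.shape
        b = bh // self.heads
        return t.reshape(b, self.heads, l, dh).permute(0, 2, 1, 3).reshape(b, l, dh * self.heads)

    def get_attention_scores(self, q, k):
        # softmax(q k^T / sqrt(d)) over the IP tokens
        return torch.softmax(q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5), dim=-1)

    def forward(self, hidden_states, query, ip_tokens):
        # Project image-prompt tokens to keys/values and attend with the query.
        ip_key = self.head_to_batch_dim(self.to_k_ip(ip_tokens))
        ip_value = self.head_to_batch_dim(self.to_v_ip(ip_tokens))
        ip_attention_probs = self.get_attention_scores(query, ip_key)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = self.batch_to_head_dim(ip_hidden_states)
        # The kept line from the diff: weight the control branch and the
        # IP branch independently.
        return self.Control_factor * hidden_states + self.IP_factor * self.scale * ip_hidden_states


# Tiny smoke test with made-up shapes.
blend = MiniIPAttnBlend(hidden_size=320, heads=8)
hidden = torch.randn(2, 64, 320)                           # text-branch attention output
query = blend.head_to_batch_dim(torch.randn(2, 64, 320))   # per-head query, [16, 64, 40]
ip_tokens = torch.randn(2, 4, 320)                         # image-prompt tokens
out = blend(hidden, query, ip_tokens)
print(out.shape)  # torch.Size([2, 64, 320])

Relative to the stock formulation visible in the deleted comment (hidden_states + self.scale * ip_hidden_states), the kept line exposes two independent weights, Control_factor and IP_factor, so the control-branch and image-prompt contributions can be rebalanced at inference time without retraining.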
 