VictorSanh commited on
Commit
8084b2d
·
1 Parent(s): a9d91fb

formatting

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +24 -22
modeling_siglip.py CHANGED
@@ -284,7 +284,7 @@ class SiglipVisionEmbeddings(nn.Module):
284
  )
285
 
286
  self.num_patches_per_side = self.image_size // self.patch_size
287
- self.num_patches = self.num_patches_per_side ** 2
288
  self.num_positions = self.num_patches
289
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
290
 
@@ -295,16 +295,22 @@ class SiglipVisionEmbeddings(nn.Module):
295
  embeddings = patch_embeds.flatten(2).transpose(1, 2)
296
 
297
  max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
298
- max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
299
- boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
300
- position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w,), fill_value=0)
 
 
 
 
 
 
301
 
302
  for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
303
  nb_patches_h = p_attn_mask[:, 0].sum()
304
  nb_patches_w = p_attn_mask[0].sum()
305
 
306
- fractional_coords_h = torch.arange(0, 1-1e-6, 1/nb_patches_h)
307
- fractional_coords_w = torch.arange(0, 1-1e-6, 1/nb_patches_w)
308
 
309
  bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
310
  bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
@@ -1095,27 +1101,26 @@ class SiglipVisionTransformer(nn.Module):
1095
  batch_size = pixel_values.size(0)
1096
  if patch_attention_mask is None:
1097
  patch_attention_mask = torch.ones(
1098
- size=(batch_size, pixel_values.size(2)//self.config.patch_size, pixel_values.size(3)//self.config.patch_size),
 
 
 
 
1099
  dtype=torch.bool,
1100
  device=pixel_values.device,
1101
  )
1102
- # if pixel_attention_mask is None:
1103
- # # assuming `pixel_attention_mask` is of size bs x h x w
1104
- # pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)
1105
-
1106
- # subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
1107
- # patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
1108
 
1109
- hidden_states = self.embeddings(
1110
- pixel_values=pixel_values,
1111
- patch_attention_mask=patch_attention_mask
1112
- )
1113
 
1114
  patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1115
 
1116
  encoder_outputs = self.encoder(
1117
  inputs_embeds=hidden_states,
1118
- attention_mask=_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self.config._flash_attn_2_enabled else patch_attention_mask,
 
 
 
 
1119
  output_attentions=output_attentions,
1120
  output_hidden_states=output_hidden_states,
1121
  return_dict=return_dict,
@@ -1156,10 +1161,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1156
  probe = self.probe.repeat(batch_size, 1, 1)
1157
 
1158
  hidden_state = self.attention(
1159
- query=probe,
1160
- key=hidden_state,
1161
- value=hidden_state,
1162
- key_padding_mask=~attention_mask
1163
  )[0]
1164
 
1165
  residual = hidden_state
 
284
  )
285
 
286
  self.num_patches_per_side = self.image_size // self.patch_size
287
+ self.num_patches = self.num_patches_per_side**2
288
  self.num_positions = self.num_patches
289
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
290
 
 
295
  embeddings = patch_embeds.flatten(2).transpose(1, 2)
296
 
297
  max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
298
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
299
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
300
+ position_ids = torch.full(
301
+ size=(
302
+ batch_size,
303
+ max_nb_patches_h * max_nb_patches_w,
304
+ ),
305
+ fill_value=0,
306
+ )
307
 
308
  for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
309
  nb_patches_h = p_attn_mask[:, 0].sum()
310
  nb_patches_w = p_attn_mask[0].sum()
311
 
312
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
313
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
314
 
315
  bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
316
  bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
 
1101
  batch_size = pixel_values.size(0)
1102
  if patch_attention_mask is None:
1103
  patch_attention_mask = torch.ones(
1104
+ size=(
1105
+ batch_size,
1106
+ pixel_values.size(2) // self.config.patch_size,
1107
+ pixel_values.size(3) // self.config.patch_size,
1108
+ ),
1109
  dtype=torch.bool,
1110
  device=pixel_values.device,
1111
  )
 
 
 
 
 
 
1112
 
1113
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
 
 
1114
 
1115
  patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1116
 
1117
  encoder_outputs = self.encoder(
1118
  inputs_embeds=hidden_states,
1119
+ attention_mask=(
1120
+ _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
1121
+ if not self.config._flash_attn_2_enabled
1122
+ else patch_attention_mask
1123
+ ),
1124
  output_attentions=output_attentions,
1125
  output_hidden_states=output_hidden_states,
1126
  return_dict=return_dict,
 
1161
  probe = self.probe.repeat(batch_size, 1, 1)
1162
 
1163
  hidden_state = self.attention(
1164
+ query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
 
 
 
1165
  )[0]
1166
 
1167
  residual = hidden_state