NIRVANALAN committed on
Commit bcb1668 · 1 Parent(s): 1786dd6
Files changed (1)
  1. vit/vision_transformer.py +2 -2
vit/vision_transformer.py CHANGED
@@ -74,8 +74,8 @@ class Attention(nn.Module):
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
         # https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L79C1-L80C78
-        self.q_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity() # sd-3
-        self.k_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
+        self.q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity() # sd-3
+        self.k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity()

         # if qk_norm:
         #     self.q_norm = LayerNorm(dim, eps=1e-5)
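
What the commit changes: both qk-norm layers now receive an explicit eps=1e-5 instead of relying on the default. If RMSNorm here is torch.nn.RMSNorm (PyTorch 2.4+), the default eps=None falls back to torch.finfo(dtype).eps at runtime, so pinning eps=1e-5 keeps the stabilizer constant across fp32/bf16/fp16 and matches the LayerNorm(dim, eps=1e-5) variant left commented out below the hunk.

Below is a minimal sketch of how these layers are typically wired into an attention block, following the timm qk-norm pattern linked in the comment. Only this hunk is in the commit, so the surrounding __init__/forward code (self.qkv, the head reshaping, scaled_dot_product_attention) is an assumption for illustration, not the repository's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qk_norm=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Explicit eps=1e-5 (the change in this commit); with eps=None,
        # torch.nn.RMSNorm would use torch.finfo(dtype).eps instead.
        self.q_norm = nn.RMSNorm(head_dim, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity()
        self.k_norm = nn.RMSNorm(head_dim, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity()
        self.qkv = nn.Linear(dim, dim * 3)   # assumed fused q/k/v projection
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)   # each (B, heads, N, head_dim)
        # qk-norm: normalize queries and keys per head before attention
        q, k = self.q_norm(q), self.k_norm(k)
        x = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.attn_drop if self.training else 0.0
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))
```

Since RMSNorm is applied over head_dim, the same learnable scale is shared by all heads and positions; the only behavioral difference introduced by this commit is the fixed epsilon in the denominator.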