VictorSanh
commited on
Commit
·
003d72c
1
Parent(s):
b54299a
ops in fp16
Browse files- modeling_siglip.py +11 -4
modeling_siglip.py
CHANGED
@@ -95,10 +95,11 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
|
95 |
|
96 |
# Use inverse cdf transform for normal distribution to get truncated
|
97 |
# standard normal
|
98 |
-
if tensor.dtype == torch.
|
|
|
99 |
tensor = tensor.to(torch.float32)
|
100 |
tensor.erfinv_()
|
101 |
-
tensor = tensor.to(torch.
|
102 |
else:
|
103 |
tensor.erfinv_()
|
104 |
|
@@ -107,7 +108,13 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
|
107 |
tensor.add_(mean)
|
108 |
|
109 |
# Clamp to ensure it's in the proper range
|
110 |
-
tensor.
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
|
113 |
def trunc_normal_tf_(
|
@@ -732,7 +739,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
|
|
732 |
nn.init.normal_(module.attention.in_proj_weight.data)
|
733 |
nn.init.zeros_(module.attention.in_proj_bias.data)
|
734 |
elif isinstance(module, SiglipModel):
|
735 |
-
logit_scale_init = torch.
|
736 |
module.logit_scale.data.fill_(logit_scale_init)
|
737 |
module.logit_bias.data.zero_()
|
738 |
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
|
|
95 |
|
96 |
# Use inverse cdf transform for normal distribution to get truncated
|
97 |
# standard normal
|
98 |
+
if tensor.dtype == torch.float16:
|
99 |
+
# The `erfinv_` op is not (yet?) defined in float16
|
100 |
tensor = tensor.to(torch.float32)
|
101 |
tensor.erfinv_()
|
102 |
+
tensor = tensor.to(torch.float16)
|
103 |
else:
|
104 |
tensor.erfinv_()
|
105 |
|
|
|
108 |
tensor.add_(mean)
|
109 |
|
110 |
# Clamp to ensure it's in the proper range
|
111 |
+
if tensor.dtype == torch.float16:
|
112 |
+
# The `clamp_` op is not (yet?) defined in float16
|
113 |
+
tensor = tensor.to(torch.float32)
|
114 |
+
tensor.clamp_(min=a, max=b)
|
115 |
+
tensor = tensor.to(torch.float16)
|
116 |
+
else:
|
117 |
+
tensor.clamp_(min=a, max=b)
|
118 |
|
119 |
|
120 |
def trunc_normal_tf_(
|
|
|
739 |
nn.init.normal_(module.attention.in_proj_weight.data)
|
740 |
nn.init.zeros_(module.attention.in_proj_bias.data)
|
741 |
elif isinstance(module, SiglipModel):
|
742 |
+
logit_scale_init = torch.tensor(0.0)
|
743 |
module.logit_scale.data.fill_(logit_scale_init)
|
744 |
module.logit_bias.data.zero_()
|
745 |
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|