VictorSanh
commited on
Commit
·
58b4a7a
1
Parent(s):
003d72c
fix ops
Browse files- modeling_siglip.py +5 -4
modeling_siglip.py
CHANGED
@@ -95,11 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
|
95 |
|
96 |
# Use inverse cdf transform for normal distribution to get truncated
|
97 |
# standard normal
|
98 |
-
if tensor.dtype
|
99 |
-
# The `erfinv_` op is not (yet?) defined in float16
|
|
|
100 |
tensor = tensor.to(torch.float32)
|
101 |
tensor.erfinv_()
|
102 |
-
tensor = tensor.to(
|
103 |
else:
|
104 |
tensor.erfinv_()
|
105 |
|
@@ -109,7 +110,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
|
109 |
|
110 |
# Clamp to ensure it's in the proper range
|
111 |
if tensor.dtype == torch.float16:
|
112 |
-
# The `clamp_` op is not (yet?) defined in float16
|
113 |
tensor = tensor.to(torch.float32)
|
114 |
tensor.clamp_(min=a, max=b)
|
115 |
tensor = tensor.to(torch.float16)
|
|
|
95 |
|
96 |
# Use inverse cdf transform for normal distribution to get truncated
|
97 |
# standard normal
|
98 |
+
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
99 |
+
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
100 |
+
og_dtype = tensor.dtype
|
101 |
tensor = tensor.to(torch.float32)
|
102 |
tensor.erfinv_()
|
103 |
+
tensor = tensor.to(og_dtype)
|
104 |
else:
|
105 |
tensor.erfinv_()
|
106 |
|
|
|
110 |
|
111 |
# Clamp to ensure it's in the proper range
|
112 |
if tensor.dtype == torch.float16:
|
113 |
+
# The `clamp_` op is not (yet?) defined in float16+cpu
|
114 |
tensor = tensor.to(torch.float32)
|
115 |
tensor.clamp_(min=a, max=b)
|
116 |
tensor = tensor.to(torch.float16)
|