VictorSanh commited on
Commit
58b4a7a
·
1 Parent(s): 003d72c
Files changed (1) hide show
  1. modeling_siglip.py +5 -4
modeling_siglip.py CHANGED
@@ -95,11 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
95
 
96
  # Use inverse cdf transform for normal distribution to get truncated
97
  # standard normal
98
- if tensor.dtype == torch.float16:
99
- # The `erfinv_` op is not (yet?) defined in float16
 
100
  tensor = tensor.to(torch.float32)
101
  tensor.erfinv_()
102
- tensor = tensor.to(torch.float16)
103
  else:
104
  tensor.erfinv_()
105
 
@@ -109,7 +110,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
109
 
110
  # Clamp to ensure it's in the proper range
111
  if tensor.dtype == torch.float16:
112
- # The `clamp_` op is not (yet?) defined in float16
113
  tensor = tensor.to(torch.float32)
114
  tensor.clamp_(min=a, max=b)
115
  tensor = tensor.to(torch.float16)
 
95
 
96
  # Use inverse cdf transform for normal distribution to get truncated
97
  # standard normal
98
+ if tensor.dtype in [torch.float16, torch.bfloat16]:
99
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
100
+ og_dtype = tensor.dtype
101
  tensor = tensor.to(torch.float32)
102
  tensor.erfinv_()
103
+ tensor = tensor.to(og_dtype)
104
  else:
105
  tensor.erfinv_()
106
 
 
110
 
111
  # Clamp to ensure it's in the proper range
112
  if tensor.dtype == torch.float16:
113
+ # The `clamp_` op is not (yet?) defined in float16+cpu
114
  tensor = tensor.to(torch.float32)
115
  tensor.clamp_(min=a, max=b)
116
  tensor = tensor.to(torch.float16)