mup base shapes path
Browse files- modeling_nt_bert.py +4 -2
modeling_nt_bert.py
CHANGED
@@ -78,11 +78,13 @@ class BertPreTrainedModel(PreTrainedModel):
|
|
78 |
|
79 |
# since we used MuP, need to reset values since they're not saved with the model
|
80 |
if os.path.exists("base_shapes.bsh") is False:
|
81 |
-
hf_hub_download(
|
82 |
"zpn/human_bp_bert", "base_shapes.bsh"
|
83 |
)
|
|
|
|
|
84 |
|
85 |
-
set_base_shapes(model,
|
86 |
|
87 |
return model
|
88 |
|
|
|
78 |
|
79 |
# since we used MuP, need to reset values since they're not saved with the model
|
80 |
if os.path.exists("base_shapes.bsh") is False:
|
81 |
+
path = hf_hub_download(
|
82 |
"zpn/human_bp_bert", "base_shapes.bsh"
|
83 |
)
|
84 |
+
else:
|
85 |
+
path = "base_shapes.bsh"
|
86 |
|
87 |
+
set_base_shapes(model, path, rescale_params=False)
|
88 |
|
89 |
return model
|
90 |
|