Spaces:
Sleeping
Sleeping
Souha Ben Hassine
commited on
Commit
·
e017cd8
1
Parent(s):
e09e818
initial commit
Browse files
app.py
CHANGED
@@ -70,26 +70,29 @@ class MultimodalRiskBehaviorModel(nn.Module):
|
|
70 |
@classmethod
|
71 |
def from_pretrained(cls, load_directory, map_location=None):
|
72 |
if os.path.exists(load_directory):
|
73 |
-
#
|
74 |
-
|
75 |
-
|
76 |
-
else:
|
77 |
-
# Hugging Face Hub: load configuration using AutoConfig
|
78 |
-
config = AutoConfig.from_pretrained(load_directory)
|
79 |
-
config_dict = {
|
80 |
-
"text_model_name": config.name_or_path,
|
81 |
-
"hidden_dim": config.hidden_size if hasattr(config, 'hidden_size') else 512 # Default to 512 if not available
|
82 |
-
}
|
83 |
-
|
84 |
-
model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
|
85 |
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
state_dict = torch.load(state_dict_path, map_location=map_location)
|
|
|
|
|
89 |
else:
|
90 |
-
|
|
|
91 |
|
92 |
-
|
|
|
|
|
|
|
93 |
|
94 |
return model
|
95 |
|
|
|
70 |
@classmethod
|
71 |
def from_pretrained(cls, load_directory, map_location=None):
|
72 |
if os.path.exists(load_directory):
|
73 |
+
# Handling local paths
|
74 |
+
config_path = os.path.join(load_directory, 'config.json')
|
75 |
+
state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
+
if not os.path.exists(config_path):
|
78 |
+
raise FileNotFoundError(f"{config_path} not found.")
|
79 |
+
if not os.path.exists(state_dict_path):
|
80 |
+
raise FileNotFoundError(f"{state_dict_path} not found.")
|
81 |
+
|
82 |
+
with open(config_path, 'r') as f:
|
83 |
+
config_dict = json.load(f)
|
84 |
+
model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
|
85 |
state_dict = torch.load(state_dict_path, map_location=map_location)
|
86 |
+
model.load_state_dict(state_dict)
|
87 |
+
|
88 |
else:
|
89 |
+
# Handling Hugging Face Hub paths
|
90 |
+
config = AutoConfig.from_pretrained(load_directory)
|
91 |
|
92 |
+
# Use AutoModel to load a pre-trained model with weights
|
93 |
+
hf_model = AutoModel.from_pretrained(load_directory, config=config)
|
94 |
+
model = cls(text_model_name=config.name_or_path, hidden_dim=hf_model.config.hidden_size)
|
95 |
+
model.text_model = hf_model
|
96 |
|
97 |
return model
|
98 |
|