Souha Ben Hassine commited on
Commit
e017cd8
·
1 Parent(s): e09e818

initial commit

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -70,26 +70,29 @@ class MultimodalRiskBehaviorModel(nn.Module):
70
  @classmethod
71
  def from_pretrained(cls, load_directory, map_location=None):
72
  if os.path.exists(load_directory):
73
- # Local directory: load configuration from file
74
- with open(os.path.join(load_directory, 'config.json'), 'r') as f:
75
- config_dict = json.load(f)
76
- else:
77
- # Hugging Face Hub: load configuration using AutoConfig
78
- config = AutoConfig.from_pretrained(load_directory)
79
- config_dict = {
80
- "text_model_name": config.name_or_path,
81
- "hidden_dim": config.hidden_size if hasattr(config, 'hidden_size') else 512 # Default to 512 if not available
82
- }
83
-
84
- model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
85
 
86
- state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
87
- if map_location:
 
 
 
 
 
 
88
  state_dict = torch.load(state_dict_path, map_location=map_location)
 
 
89
  else:
90
- state_dict = torch.load(state_dict_path)
 
91
 
92
- model.load_state_dict(state_dict)
 
 
 
93
 
94
  return model
95
 
 
70
  @classmethod
71
  def from_pretrained(cls, load_directory, map_location=None):
72
  if os.path.exists(load_directory):
73
+ # Handling local paths
74
+ config_path = os.path.join(load_directory, 'config.json')
75
+ state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
 
 
 
 
 
 
 
 
 
76
 
77
+ if not os.path.exists(config_path):
78
+ raise FileNotFoundError(f"{config_path} not found.")
79
+ if not os.path.exists(state_dict_path):
80
+ raise FileNotFoundError(f"{state_dict_path} not found.")
81
+
82
+ with open(config_path, 'r') as f:
83
+ config_dict = json.load(f)
84
+ model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
85
  state_dict = torch.load(state_dict_path, map_location=map_location)
86
+ model.load_state_dict(state_dict)
87
+
88
  else:
89
+ # Handling Hugging Face Hub paths
90
+ config = AutoConfig.from_pretrained(load_directory)
91
 
92
+ # Use AutoModel to load a pre-trained model with weights
93
+ hf_model = AutoModel.from_pretrained(load_directory, config=config)
94
+ model = cls(text_model_name=config.name_or_path, hidden_dim=hf_model.config.hidden_size)
95
+ model.text_model = hf_model
96
 
97
  return model
98