Niksa Praljak commited on
Commit
b497299
1 Parent(s): cb5bad1

Facilitator weights README.md instructions

Browse files
Files changed (1) hide show
  1. weights/Facilitator/README.md +49 -8
weights/Facilitator/README.md CHANGED
@@ -14,12 +14,11 @@ This folder will contain the pre-trained weights for the **Facilitator** model.
14
 
15
  The Google Drive link for downloading the Facilitator pre-trained weights will be added here soon.
16
 
17
- ---
18
-
19
- ## **File Details**
20
 
21
- - **File Name**: Facilitator pre-trained weights (TBD).
22
- - **Description**: Pre-trained weights for the Facilitator model.
 
 
23
 
24
  ---
25
 
@@ -28,8 +27,50 @@ The Google Drive link for downloading the Facilitator pre-trained weights will b
28
  Once available, the pre-trained weights can be loaded as follows:
29
 
30
  ```python
 
31
  import torch
32
- model = YourFacilitatorModel() # Replace with your model class
33
- model.load_state_dict(torch.load("weights/Facilitator/Facilitator_weights.bin", map_location="cpu"))
34
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
14
 
15
  The Google Drive link for downloading the Facilitator pre-trained weights will be added here soon.
16
 
 
 
 
17
 
18
+ ```bash
19
+ pip install gdown # assuming gdown package is not already installed
20
+ gdown --id 1_YWwILXDkx9MSoSA1kfS-y0jk3Vy4HJE -O BioM3_Facilitator_epoch20.bin
21
+ ```
22
 
23
  ---
24
 
 
27
  Once available, the pre-trained weights can be loaded as follows:
28
 
29
  ```python
30
+ import json
31
  import torch
32
+ from argparse import Namespace
33
+ import Stage1_source.model as mod
34
+
35
+ # Step 1: Load JSON Configuration
36
+ def load_json_config(json_path):
37
+ """
38
+ Load a JSON configuration file and return it as a dictionary.
39
+ """
40
+ with open(json_path, "r") as f:
41
+ config = json.load(f)
42
+ return config
43
+
44
+ # Step 2: Convert JSON Dictionary to Namespace
45
+ def convert_to_namespace(config_dict):
46
+ """
47
+ Recursively convert a dictionary to an argparse Namespace.
48
+ """
49
+ for key, value in config_dict.items():
50
+ if isinstance(value, dict):
51
+ config_dict[key] = convert_to_namespace(value)
52
+ return Namespace(**config_dict)
53
+
54
+ if __name__ == '__main__':
55
+ # Path to configuration and weights
56
+ config_path = "stage2_config.json"
57
+ model_weights_path = "weights/Facilitator/BioM3_Facilitator_epoch20.bin"
58
+
59
+ # Load Configuration
60
+ print("Loading configuration...")
61
+ config_dict = load_json_config(config_path)
62
+ config_args = convert_to_namespace(config_dict)
63
+
64
+ # Load Model
65
+ print("Loading pre-trained model weights...")
66
+ model = mod.Facilitator(
67
+ in_dim=config_args.emb_dim,
68
+ hid_dim=config_args.hid_dim,
69
+ out_dim=config_args.emb_dim,
70
+ dropout=config_args.dropout
71
+ ) # Initialize the model with arguments
72
+ model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
73
+ model.eval()
74
+ print("Model loaded successfully with weights!")
75
 
76
+ ```