Update README.md
Browse files
README.md
CHANGED
@@ -49,9 +49,14 @@ import torch.nn as nn
|
|
49 |
from feat.emo_detectors.ResMaskNet.resmasknet_test import ResMasking
|
50 |
from huggingface_hub import hf_hub_download
|
51 |
|
|
|
|
|
|
|
|
|
|
|
52 |
device = 'cpu'
|
53 |
-
emotion_detector = ResMasking("", in_channels=
|
54 |
-
emotion_detector.fc = nn.Sequential(nn.Dropout(0.4), nn.Linear(512,
|
55 |
emotion_model_file = hf_hub_download(repo_id='py-feat/resmasknet', filename="ResMaskNet_Z_resmasking_dropout1_rot30.pth")
|
56 |
emotion_checkpoint = torch.load(emotion_model_file, map_location=device)["net"]
|
57 |
emotion_detector.load_state_dict(emotion_checkpoint)
|
|
|
49 |
from feat.emo_detectors.ResMaskNet.resmasknet_test import ResMasking
|
50 |
from huggingface_hub import hf_hub_download
|
51 |
|
52 |
+
# Load Configs
|
53 |
+
emotion_config_file = hf_hub_download(repo_id= "py-feat/resmasknet", filename="config.json", cache_dir=get_resource_path())
|
54 |
+
with open(emotion_config_file, "r") as f:
|
55 |
+
emotion_config = json.load(f)
|
56 |
+
|
57 |
device = 'cpu'
|
58 |
+
emotion_detector = ResMasking("", in_channels=emotion_config['in_channels'])
|
59 |
+
emotion_detector.fc = nn.Sequential(nn.Dropout(0.4), nn.Linear(512, emotion_config['num_classes']))
|
60 |
emotion_model_file = hf_hub_download(repo_id='py-feat/resmasknet', filename="ResMaskNet_Z_resmasking_dropout1_rot30.pth")
|
61 |
emotion_checkpoint = torch.load(emotion_model_file, map_location=device)["net"]
|
62 |
emotion_detector.load_state_dict(emotion_checkpoint)
|