ljchang commited on
Commit
842709a
1 Parent(s): a0f970c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -2
README.md CHANGED
@@ -49,9 +49,14 @@ import torch.nn as nn
49
  from feat.emo_detectors.ResMaskNet.resmasknet_test import ResMasking
50
  from huggingface_hub import hf_hub_download
51
 
 
 
 
 
 
52
  device = 'cpu'
53
- emotion_detector = ResMasking("", in_channels=3)
54
- emotion_detector.fc = nn.Sequential(nn.Dropout(0.4), nn.Linear(512, 7))
55
  emotion_model_file = hf_hub_download(repo_id='py-feat/resmasknet', filename="ResMaskNet_Z_resmasking_dropout1_rot30.pth")
56
  emotion_checkpoint = torch.load(emotion_model_file, map_location=device)["net"]
57
  emotion_detector.load_state_dict(emotion_checkpoint)
 
49
  from feat.emo_detectors.ResMaskNet.resmasknet_test import ResMasking
50
  from huggingface_hub import hf_hub_download
51
 
52
+ # Load Configs
53
+ emotion_config_file = hf_hub_download(repo_id= "py-feat/resmasknet", filename="config.json", cache_dir=get_resource_path())
54
+ with open(emotion_config_file, "r") as f:
55
+ emotion_config = json.load(f)
56
+
57
  device = 'cpu'
58
+ emotion_detector = ResMasking("", in_channels=emotion_config['in_channels'])
59
+ emotion_detector.fc = nn.Sequential(nn.Dropout(0.4), nn.Linear(512, emotion_config['num_classes']))
60
  emotion_model_file = hf_hub_download(repo_id='py-feat/resmasknet', filename="ResMaskNet_Z_resmasking_dropout1_rot30.pth")
61
  emotion_checkpoint = torch.load(emotion_model_file, map_location=device)["net"]
62
  emotion_detector.load_state_dict(emotion_checkpoint)