File size: 947 Bytes
51ef5ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from detectfaces import fer
from models.PosterV2_7cls import pyramid_trans_expr2
import os 
import torch 
from main import RecorderMeter1, RecorderMeter  # noqa: F401 
script_dir = os.path.dirname(os.path.abspath(__file__))

# Absolute path to the RAF-DB checkpoint. It is resolved relative to this
# script's directory rather than the current working directory.
model_path = os.path.join(
    script_dir, "models", "checkpoints", "raf-db-model_best.pth"
)


def _pick_device():
    """Return the preferred torch device name.

    Preference order: "mps" (Apple GPU), then "cuda", then "cpu".
    """
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = _pick_device()

# Build the network (img_size=224, 7 output classes), wrap it in DataParallel
# and move it to the selected device. `.to()` returns the same module object.
# NOTE(review): the DataParallel wrapper presumably also keeps parameter names
# ("module." prefix) compatible with the saved checkpoint. Confirm this against
# how fer() loads the weights.
model = torch.nn.DataParallel(
    pyramid_trans_expr2(img_size=224, num_classes=7)
).to(device)
def main():
    """Run ``fer`` with the module-level checkpoint path, device and model."""
    fer_kwargs = {"model_path": model_path, "device": device, "model": model}
    fer(**fer_kwargs)


if __name__ == "__main__":
    main()