from detectfaces import fer | |
from models.PosterV2_7cls import pyramid_trans_expr2 | |
import os | |
import torch | |
from main import RecorderMeter1, RecorderMeter # noqa: F401 | |
# Directory that contains this script, so the model file is located relative
# to the script rather than to the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Full path to the model file that is passed to fer().
model_path = os.path.join(
    script_dir, "models", "checkpoints", "raf-db-model_best.pth"
)


def _select_device():
    """Return the name of the preferred torch device available on this host.

    Preference order is "mps", then "cuda", then "cpu".

    Returns:
        One of the strings "mps", "cuda" or "cpu".
    """
    # Older PyTorch releases do not ship torch.backends.mps at all; look it
    # up defensively so those installs fall through to CUDA/CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


# Device used for model execution.
device = _select_device()
# Build the network (image size 224, 7 classes), wrap it in DataParallel for
# potential multi-GPU usage, and place the wrapped module on the chosen
# device, all in a single expression.
model = torch.nn.DataParallel(
    pyramid_trans_expr2(img_size=224, num_classes=7)
).to(device)
def main():
    """Run fer() with the module-level model path, device and model."""
    fer(
        model_path=model_path,
        device=device,
        model=model,
    )


# Only start when executed as a script, not when imported.
if __name__ == "__main__":
    main()