File size: 1,749 Bytes
d1b31ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b404794
c0f6432
b404794
c0f6432
d1b31ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
File: model.py
Author: Elena Ryumina and Dmitry Ryumin
Description: This module provides functions for loading and processing a pre-trained deep learning model
             for facial expression recognition.
License: MIT License
"""

import torch
import requests
from PIL import Image
from torchvision import transforms

# Importing necessary components for the Gradio app
from app.config import config_data


def load_model(model_url, model_path, *, timeout=60):
    """Download a TorchScript model to disk and load it in eval mode.

    Args:
        model_url: URL the serialized TorchScript model is streamed from.
        model_path: Local file path the download is written to (overwritten).
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout,
            so a stalled server cannot hang the caller indefinitely.

    Returns:
        The loaded TorchScript module with ``.eval()`` applied, or ``None`` if
        downloading or loading failed (the error is printed to stdout).
    """
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            # Fail fast on HTTP errors; otherwise an error page would be saved
            # as the "model" and torch.jit.load would fail with a confusing message.
            response.raise_for_status()
            with open(model_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        return torch.jit.load(model_path).eval()
    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure with None.
        print(f"Error loading model: {e}")
        return None


def _load_model_to_device(model_url, model_path, device):
    """Load a model with ``load_model`` and move it to ``device``.

    Raises:
        RuntimeError: if ``load_model`` returned ``None``. Without this check the
            failure would surface as ``'NoneType' object has no attribute 'to'``.
    """
    model = load_model(model_url, model_path)
    if model is None:
        raise RuntimeError(
            f"Could not load model from {model_url!r} into {model_path!r}"
        )
    return model.to(device)


pth_model_static = _load_model_to_device(
    config_data.model_static_url, config_data.model_static_path, config_data.DEVICE
)

pth_model_dynamic = _load_model_to_device(
    config_data.model_dynamic_url, config_data.model_dynamic_path, config_data.DEVICE
)



def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def init(self):
            super(PreprocessInput, self).init()

        def forward(self, x):
            x = x.to(torch.float32)
            x = torch.flip(x, dims=(0,))
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img, target_size=(224, 224)):
        transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
        img = img.resize(target_size, Image.Resampling.NEAREST)
        img = transform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)