# Custom ResNet-18 for 7-Class Classification

This is a fine-tuned **`ResNet-18`** model for a 7-class image classification task. Every **ReLU** activation is replaced with a **PReLU** (a learnable, per-channel negative slope) followed by a **Dropout2d** layer (`p=0.2`) to improve generalization. The model was trained on a custom dataset with various augmentations.
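To illustrate what the two substitutions do, here is a small sketch using the standard `torch.nn` modules on an illustrative tensor (the shapes and values are made up, not taken from the model). `nn.PReLU` computes `max(0, x) + a * min(0, x)` with one learnable slope `a` per channel, initialised to 0.25 by PyTorch, and `nn.Dropout2d` zeroes entire channels during training but does nothing once the module is in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative activation map: batch of 1, 3 channels, 2x2 spatial grid.
x = torch.tensor([[-2.0, -1.0], [0.5, 3.0]]).expand(1, 3, 2, 2)

# PReLU with one learnable negative slope per channel (PyTorch initialises it to 0.25).
prelu = nn.PReLU(num_parameters=3)
y = prelu(x)  # x where x > 0, slope * x otherwise

# Dropout2d zeroes whole channels with probability p during training and
# rescales the surviving channels by 1 / (1 - p).
dropout = nn.Dropout2d(p=0.2)
dropout.train()
z_train = dropout(y)

# After .eval(), Dropout2d is the identity, so inference is deterministic.
dropout.eval()
z_eval = dropout(y)
```

Dropping whole channels (rather than individual activations) is the usual choice for convolutional feature maps, where neighbouring pixels in a channel are strongly correlated.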

---

## 📜 Model Details

- **Base Model:** ResNet-18 (pre-trained on ImageNet).
- **Activations:** ReLU layers replaced with PReLU.
- **Dropout:** `nn.Dropout2d` (`p=0.2`) placed after every PReLU to improve generalization.
- **Classes:** 7 output classes.
- **Input Size:** Configurable in the training pipeline; this checkpoint was trained on `[100, 100]` images.
- **Normalization:** Input images are normalized using the following statistics:
  - Mean: `[0.485, 0.456, 0.406]`
  - Std: `[0.229, 0.224, 0.225]`

---

## 📈 Evaluation Metrics on Test Data 
![confusion matrix](images/confusion_matrix.png)

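- **Accuracy:** 79.92%
- **Precision:** 79.80%
- **Recall:** 79.92%
- **F1-Score:** 79.80%

Classification report:

| Class | Precision | Recall | F1-score | Support |
|---:|---:|---:|---:|---:|
| 1 | 0.79 | 0.81 | 0.80 | 329 |
| 2 | 0.58 | 0.47 | 0.52 | 74 |
| 3 | 0.51 | 0.42 | 0.46 | 160 |
| 4 | 0.92 | 0.90 | 0.91 | 1185 |
| 5 | 0.74 | 0.78 | 0.76 | 478 |
| 6 | 0.68 | 0.72 | 0.70 | 162 |
| 7 | 0.75 | 0.78 | 0.77 | 680 |
| **Accuracy** | | | 0.80 | 3068 |
| **Macro avg** | 0.71 | 0.70 | 0.70 | 3068 |
| **Weighted avg** | 0.80 | 0.80 | 0.80 | 3068 |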
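The summary rows follow the usual definitions used by scikit-learn's `classification_report`: the macro average is the unweighted mean over the 7 classes, while the weighted average weights each class by its support. The sketch below recomputes both from the per-class rows. It shows why the weighted figures (about 0.80) sit well above the macro figures (about 0.70): class 4 holds 1,185 of the 3,068 test samples and scores highest, while the small classes 2 and 3 score lowest. Judging from the numbers, the headline Precision, Recall and F1-Score values appear to be these support-weighted averages.

```python
# Per-class rows copied from the classification report: (precision, recall, f1, support).
report = {
    1: (0.79, 0.81, 0.80, 329),
    2: (0.58, 0.47, 0.52, 74),
    3: (0.51, 0.42, 0.46, 160),
    4: (0.92, 0.90, 0.91, 1185),
    5: (0.74, 0.78, 0.76, 478),
    6: (0.68, 0.72, 0.70, 162),
    7: (0.75, 0.78, 0.77, 680),
}

total = sum(row[3] for row in report.values())  # 3068 test samples


def macro_avg(i):
    """Unweighted mean of metric column i over all classes."""
    return sum(row[i] for row in report.values()) / len(report)


def weighted_avg(i):
    """Support-weighted mean of metric column i."""
    return sum(row[i] * row[3] for row in report.values()) / total


for i, metric in enumerate(["precision", "recall", "f1-score"]):
    print(f"{metric}: macro={macro_avg(i):.2f}  weighted={weighted_avg(i):.2f}")
# precision: macro=0.71  weighted=0.80
# recall: macro=0.70  weighted=0.80
# f1-score: macro=0.70  weighted=0.80
```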

## 🧑‍💻 How to Use

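The fine-tuned weights are stored in `model.safetensors`. The file contains only the weights, so you first rebuild the modified architecture and then load the state dict, for inference or further fine-tuning: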

### **Using PyTorch**
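```python
import torch
import torch.nn as nn
from torchvision import models
from safetensors.torch import load_file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_out_channels(module):
    """Return the number of output channels of a Conv2d, BatchNorm2d or Linear layer (None otherwise)."""
    if isinstance(module, nn.Conv2d):
        return module.out_channels
    elif isinstance(module, nn.BatchNorm2d):
        return module.num_features
    elif isinstance(module, nn.Linear):
        return module.out_features
    return None


def replace_relu_with_prelu_and_dropout(module):
    """Recursively replace every nn.ReLU with nn.Sequential(PReLU, Dropout2d(p=0.2))."""
    for name, child in module.named_children():
        # Recurse first so ReLUs inside nested blocks (layer1..layer4) are converted too.
        replace_relu_with_prelu_and_dropout(child)

        if isinstance(child, nn.ReLU):
            # PReLU learns one slope per channel; take the channel count from the
            # last Conv/BN/Linear layer registered before this ReLU.
            out_channels = None
            for prev_name, prev_child in module.named_children():
                if prev_name == name:
                    break
                out_channels = get_out_channels(prev_child) or out_channels

            if out_channels is None:
                raise ValueError(f"Cannot determine `out_channels` for {child}. Please check the model structure.")

            prelu = nn.PReLU(num_parameters=out_channels, device=device)
            dropout = nn.Dropout2d(p=0.2)
            setattr(module, name, nn.Sequential(prelu, dropout).to(device))


# Rebuild the modified architecture. ImageNet weights are not needed here because
# the fine-tuned checkpoint below overwrites every parameter.
model = models.resnet18(weights=None).to(device)
replace_relu_with_prelu_and_dropout(model)

# 7-class head wrapped in nn.Sequential so the keys match the checkpoint (fc.0.weight, fc.0.bias).
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, 7)).to(device)

state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()  # disables Dropout2d and uses the BatchNorm running statistics
```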

## ⚠️ Limitations and Considerations
- **Input Dimensions:** Resize input images to the training resolution (100 × 100) and normalize them with the mean and std listed above before inference.
- **Number of Classes:** The model predicts exactly the 7 classes defined in the training dataset.
- **Output:** The model returns raw logits for the 7 face-type labels; its last layer has no softmax. Apply a softmax to obtain class probabilities (taking the `argmax` of the logits gives the same predicted class).
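Because the network returns raw logits, a prediction helper needs an explicit softmax. A minimal sketch follows; the helper name `predict` is hypothetical, and mapping output index `i` to class label `i + 1` (to match the labels 1 to 7 in the classification report) is an assumption about how the training labels were encoded.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict(model: nn.Module, batch: torch.Tensor):
    """Return (labels, probabilities) for a preprocessed batch of shape (N, 3, H, W).

    Hypothetical helper: assumes output index i corresponds to class label i + 1.
    """
    model.eval()                          # make sure Dropout2d is disabled
    logits = model(batch)                 # (N, 7) raw scores; the model has no softmax layer
    probs = torch.softmax(logits, dim=1)  # each row now sums to 1
    labels = probs.argmax(dim=1) + 1      # assumed 0-based index -> 1-based label
    return labels, probs
```

With the network loaded in the PyTorch snippet above and a normalized `(N, 3, 100, 100)` tensor `batch`, call `labels, probs = predict(model, batch)`.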