Spaces:
Running
Running
Update model.py
Browse filesupdated model module
model.py
CHANGED
@@ -2,11 +2,21 @@
|
|
2 |
StableResNet Model for Biomass Prediction
|
3 |
A numerically stable ResNet architecture for regression tasks
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
Author: najahpokkiri
|
6 |
Date: 2025-05-17
|
7 |
"""
|
8 |
import torch
|
9 |
import torch.nn as nn
|
|
|
10 |
|
11 |
class StableResNet(nn.Module):
|
12 |
"""Numerically stable ResNet for biomass regression"""
|
@@ -33,20 +43,27 @@ class StableResNet(nn.Module):
|
|
33 |
self._init_weights()
|
34 |
|
35 |
def _make_simple_resblock(self, in_dim, out_dim):
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
nn.
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
def _init_weights(self):
|
|
|
50 |
for m in self.modules():
|
51 |
if isinstance(m, nn.Linear):
|
52 |
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
|
@@ -54,14 +71,35 @@ class StableResNet(nn.Module):
|
|
54 |
nn.init.zeros_(m.bias)
|
55 |
|
56 |
def forward(self, x):
|
|
|
57 |
x = self.input_proj(x)
|
58 |
|
|
|
59 |
identity = x
|
60 |
out = self.layer1(x)
|
61 |
x = out + identity
|
62 |
|
|
|
63 |
x = self.layer2(x)
|
64 |
x = self.layer3(x)
|
65 |
|
|
|
66 |
x = self.regressor(x)
|
67 |
-
return x.squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
StableResNet Model for Biomass Prediction
|
3 |
A numerically stable ResNet architecture for regression tasks
|
4 |
|
5 |
+
Author: najahpokkiri
|
6 |
+
Date: 2025-05-17
|
7 |
+
"""
|
8 |
+
"""
|
9 |
+
StableResNet Model Architecture
|
10 |
+
|
11 |
+
This module defines the StableResNet architecture used for biomass prediction.
|
12 |
+
The model is designed for numerical stability with batch normalization and residual connections.
|
13 |
+
|
14 |
Author: najahpokkiri
|
15 |
Date: 2025-05-17
|
16 |
"""
|
17 |
import torch
|
18 |
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
|
21 |
class StableResNet(nn.Module):
|
22 |
"""Numerically stable ResNet for biomass regression"""
|
|
|
43 |
self._init_weights()
|
44 |
|
45 |
def _make_simple_resblock(self, in_dim, out_dim):
|
46 |
+
"""Create a simple residual block or downsampling block"""
|
47 |
+
if in_dim == out_dim:
|
48 |
+
# Residual block
|
49 |
+
return nn.Sequential(
|
50 |
+
nn.Linear(in_dim, out_dim),
|
51 |
+
nn.BatchNorm1d(out_dim),
|
52 |
+
nn.ReLU(),
|
53 |
+
nn.Linear(out_dim, out_dim),
|
54 |
+
nn.BatchNorm1d(out_dim),
|
55 |
+
nn.ReLU()
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
# Downsampling block
|
59 |
+
return nn.Sequential(
|
60 |
+
nn.Linear(in_dim, out_dim),
|
61 |
+
nn.BatchNorm1d(out_dim),
|
62 |
+
nn.ReLU(),
|
63 |
+
)
|
64 |
|
65 |
def _init_weights(self):
|
66 |
+
"""Initialize weights for better convergence"""
|
67 |
for m in self.modules():
|
68 |
if isinstance(m, nn.Linear):
|
69 |
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
|
|
|
71 |
nn.init.zeros_(m.bias)
|
72 |
|
73 |
def forward(self, x):
|
74 |
+
"""Forward pass through the network"""
|
75 |
x = self.input_proj(x)
|
76 |
|
77 |
+
# First residual block
|
78 |
identity = x
|
79 |
out = self.layer1(x)
|
80 |
x = out + identity
|
81 |
|
82 |
+
# Remaining blocks
|
83 |
x = self.layer2(x)
|
84 |
x = self.layer3(x)
|
85 |
|
86 |
+
# Regression output
|
87 |
x = self.regressor(x)
|
88 |
+
return x.squeeze()
|
89 |
+
|
90 |
+
def get_model_info():
|
91 |
+
"""Return information about the model architecture"""
|
92 |
+
return {
|
93 |
+
'name': 'StableResNet',
|
94 |
+
'description': 'Numerically stable ResNet for biomass regression',
|
95 |
+
'parameters': {
|
96 |
+
'n_features': 'Number of input features',
|
97 |
+
'dropout': 'Dropout rate (default: 0.2)'
|
98 |
+
},
|
99 |
+
'architecture': [
|
100 |
+
'Input projection with layer normalization',
|
101 |
+
'Residual blocks with batch normalization',
|
102 |
+
'Downsampling blocks',
|
103 |
+
'Regression head'
|
104 |
+
]
|
105 |
+
}
|