pokkiri commited on
Commit
b9cff30
·
verified ·
1 Parent(s): c74a613

Update model.py

Browse files

updated model module

Files changed (1) hide show
  1. model.py +51 -13
model.py CHANGED
@@ -2,11 +2,21 @@
2
  StableResNet Model for Biomass Prediction
3
  A numerically stable ResNet architecture for regression tasks
4
 
 
 
 
 
 
 
 
 
 
5
  Author: najahpokkiri
6
  Date: 2025-05-17
7
  """
8
  import torch
9
  import torch.nn as nn
 
10
 
11
  class StableResNet(nn.Module):
12
  """Numerically stable ResNet for biomass regression"""
@@ -33,20 +43,27 @@ class StableResNet(nn.Module):
33
  self._init_weights()
34
 
35
  def _make_simple_resblock(self, in_dim, out_dim):
36
- return nn.Sequential(
37
- nn.Linear(in_dim, out_dim),
38
- nn.BatchNorm1d(out_dim),
39
- nn.ReLU(),
40
- nn.Linear(out_dim, out_dim),
41
- nn.BatchNorm1d(out_dim),
42
- nn.ReLU()
43
- ) if in_dim == out_dim else nn.Sequential(
44
- nn.Linear(in_dim, out_dim),
45
- nn.BatchNorm1d(out_dim),
46
- nn.ReLU(),
47
- )
 
 
 
 
 
 
48
 
49
  def _init_weights(self):
 
50
  for m in self.modules():
51
  if isinstance(m, nn.Linear):
52
  nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
@@ -54,14 +71,35 @@ class StableResNet(nn.Module):
54
  nn.init.zeros_(m.bias)
55
 
56
  def forward(self, x):
 
57
  x = self.input_proj(x)
58
 
 
59
  identity = x
60
  out = self.layer1(x)
61
  x = out + identity
62
 
 
63
  x = self.layer2(x)
64
  x = self.layer3(x)
65
 
 
66
  x = self.regressor(x)
67
- return x.squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  StableResNet Model for Biomass Prediction
3
  A numerically stable ResNet architecture for regression tasks
4
 
5
+ Author: najahpokkiri
6
+ Date: 2025-05-17
7
+ """
8
+ """
9
+ StableResNet Model Architecture
10
+
11
+ This module defines the StableResNet architecture used for biomass prediction.
12
+ The model is designed for numerical stability with batch normalization and residual connections.
13
+
14
  Author: najahpokkiri
15
  Date: 2025-05-17
16
  """
17
  import torch
18
  import torch.nn as nn
19
+ import torch.nn.functional as F
20
 
21
  class StableResNet(nn.Module):
22
  """Numerically stable ResNet for biomass regression"""
 
43
  self._init_weights()
44
 
45
  def _make_simple_resblock(self, in_dim, out_dim):
46
+ """Create a simple residual block or downsampling block"""
47
+ if in_dim == out_dim:
48
+ # Residual block
49
+ return nn.Sequential(
50
+ nn.Linear(in_dim, out_dim),
51
+ nn.BatchNorm1d(out_dim),
52
+ nn.ReLU(),
53
+ nn.Linear(out_dim, out_dim),
54
+ nn.BatchNorm1d(out_dim),
55
+ nn.ReLU()
56
+ )
57
+ else:
58
+ # Downsampling block
59
+ return nn.Sequential(
60
+ nn.Linear(in_dim, out_dim),
61
+ nn.BatchNorm1d(out_dim),
62
+ nn.ReLU(),
63
+ )
64
 
65
  def _init_weights(self):
66
+ """Initialize weights for better convergence"""
67
  for m in self.modules():
68
  if isinstance(m, nn.Linear):
69
  nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
 
71
  nn.init.zeros_(m.bias)
72
 
73
  def forward(self, x):
74
+ """Forward pass through the network"""
75
  x = self.input_proj(x)
76
 
77
+ # First residual block
78
  identity = x
79
  out = self.layer1(x)
80
  x = out + identity
81
 
82
+ # Remaining blocks
83
  x = self.layer2(x)
84
  x = self.layer3(x)
85
 
86
+ # Regression output
87
  x = self.regressor(x)
88
+ return x.squeeze()
89
+
90
+ def get_model_info():
91
+ """Return information about the model architecture"""
92
+ return {
93
+ 'name': 'StableResNet',
94
+ 'description': 'Numerically stable ResNet for biomass regression',
95
+ 'parameters': {
96
+ 'n_features': 'Number of input features',
97
+ 'dropout': 'Dropout rate (default: 0.2)'
98
+ },
99
+ 'architecture': [
100
+ 'Input projection with layer normalization',
101
+ 'Residual blocks with batch normalization',
102
+ 'Downsampling blocks',
103
+ 'Regression head'
104
+ ]
105
+ }