Demo750 commited on
Commit
a05a7bd
·
verified ·
1 Parent(s): b4b7ecd

Update DL_models.py

Browse files
Files changed (1) hide show
  1. DL_models.py +17 -2
DL_models.py CHANGED
@@ -33,6 +33,10 @@ class CustomResNet(nn.Module):
33
  self.fc1 = nn.Linear(2 * (512 * 4), 256)
34
  self.fc2 = nn.Linear(256, 1)
35
 
 
 
 
 
36
 
37
  def forward(self, x1, x2, Location):
38
  N = x1.shape[0]
@@ -42,7 +46,9 @@ class CustomResNet(nn.Module):
42
  Location = Location.to(torch.float).to(self.device)
43
 
44
  # Process both images through the same ResNet
45
- f1 = self.resnet(x1)
 
 
46
  f2 = self.resnet(x2)
47
 
48
  # Flatten the features
@@ -59,4 +65,13 @@ class CustomResNet(nn.Module):
59
  x = torch.relu(self.fc1(combined))
60
  x = self.fc2(x)
61
 
62
- return x
 
 
 
 
 
 
 
 
 
 
33
  self.fc1 = nn.Linear(2 * (512 * 4), 256)
34
  self.fc2 = nn.Linear(256, 1)
35
 
36
+ self.gradients = None
37
+
38
+ def activations_hook(self, grad):
39
+ self.gradients = grad
40
 
41
  def forward(self, x1, x2, Location):
42
  N = x1.shape[0]
 
46
  Location = Location.to(torch.float).to(self.device)
47
 
48
  # Process both images through the same ResNet
49
+ f1 = self.resnet[:8](x1)
50
+ h = f1.register_hook(self.activations_hook)
51
+ f1 = self.resnet[8:](f1)
52
  f2 = self.resnet(x2)
53
 
54
  # Flatten the features
 
65
  x = torch.relu(self.fc1(combined))
66
  x = self.fc2(x)
67
 
68
+ return x
69
+
70
+ # method for the gradient extraction
71
+ def get_activations_gradient(self):
72
+ return self.gradients
73
+
74
+ # method for the activation exctraction
75
+ def get_activations(self, x):
76
+ x = self.transform(x.to(torch.float).to(self.device))
77
+ return self.resnet[:8](x)