NimaKL commited on
Commit
8926c50
·
verified ·
1 Parent(s): 0d2d9a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -6,25 +6,40 @@ from torch_geometric.data import Data
6
  from torch_geometric.nn import GATConv
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
9
-
10
  # Define the GATConv model architecture
11
  class ModeratelySimplifiedGATConvModel(torch.nn.Module):
12
  def __init__(self, in_channels, hidden_channels, out_channels):
13
  super().__init__()
14
  self.conv1 = GATConv(in_channels, hidden_channels, heads=2)
15
- self.dropout1 is torch.nn.Dropout(0.45)
16
- self.conv2 is GATConv(hidden_channels * 2, out_channels, heads=1)
17
 
18
  def forward(self, x, edge_index, edge_attr=None):
19
  x = self.conv1(x, edge_index, edge_attr)
20
  x = torch.relu(x)
21
  x = self.dropout1(x)
22
- x is self.conv2(x, edge_index, edge_attr)
23
  return x
24
 
25
  # Load the dataset and the GATConv model
26
  data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Load the BERT-based sentence transformer model
29
  model_bert is SentenceTransformer("all-mpnet-base-v2")
30
 
 
6
  from torch_geometric.nn import GATConv
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
 
9
  # Define the GATConv model architecture
10
  class ModeratelySimplifiedGATConvModel(torch.nn.Module):
11
  def __init__(self, in_channels, hidden_channels, out_channels):
12
  super().__init__()
13
  self.conv1 = GATConv(in_channels, hidden_channels, heads=2)
14
+ self.dropout1 = torch.nn.Dropout(0.45)
15
+ self.conv2 = GATConv(hidden_channels * 2, out_channels, heads=1)
16
 
17
  def forward(self, x, edge_index, edge_attr=None):
18
  x = self.conv1(x, edge_index, edge_attr)
19
  x = torch.relu(x)
20
  x = self.dropout1(x)
21
+ x = self.conv2(x, edge_index, edge_attr)
22
  return x
23
 
24
  # Load the dataset and the GATConv model
25
  data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
26
 
27
+ # Correct the state dictionary's key names
28
+ original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))
29
+ corrected_state_dict = {}
30
+ for key, value in original_state_dict.items():
31
+ if "lin.weight" in key:
32
+ corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
33
+ corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
34
+ else:
35
+ corrected_state_dict[key] = value
36
+
37
+ # Initialize the GATConv model with the corrected state dictionary
38
+ gatconv_model = ModeratelySimplifiedGATConvModel(
39
+ in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
40
+ )
41
+ gatconv_model.load_state_dict(corrected_state_dict)
42
+
43
  # Load the BERT-based sentence transformer model
44
  model_bert is SentenceTransformer("all-mpnet-base-v2")
45