NimaKL commited on
Commit
222fb9e
·
verified ·
1 Parent(s): f156242

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -6,7 +6,6 @@ from torch_geometric.data import Data
6
  from torch_geometric.nn import GATConv
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
9
-
10
  # Define the GATConv model architecture
11
  class ModeratelySimplifiedGATConvModel(torch.nn.Module):
12
  def __init__(self, in_channels, hidden_channels, out_channels):
@@ -18,13 +17,29 @@ class ModeratelySimplifiedGATConvModel(torch.nn.Module):
18
  def forward(self, x, edge_index, edge_attr=None):
19
  x = self.conv1(x, edge_index, edge_attr)
20
  x = torch.relu(x)
21
- x = dropout1(x)
22
  x = self.conv2(x, edge_index, edge_attr)
23
  return x
24
 
25
  # Load the dataset and the GATConv model
26
  data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Load the BERT-based sentence transformer model
29
  model_bert = SentenceTransformer("all-mpnet-base-v2")
30
 
 
6
  from torch_geometric.nn import GATConv
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
 
9
  # Define the GATConv model architecture
10
  class ModeratelySimplifiedGATConvModel(torch.nn.Module):
11
  def __init__(self, in_channels, hidden_channels, out_channels):
 
17
  def forward(self, x, edge_index, edge_attr=None):
18
  x = self.conv1(x, edge_index, edge_attr)
19
  x = torch.relu(x)
20
+ x = self.dropout1(x)
21
  x = self.conv2(x, edge_index, edge_attr)
22
  return x
23
 
24
  # Load the dataset and the GATConv model
25
  data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
26
 
27
+ # Correct the state dictionary's key names
28
+ original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))
29
+ corrected_state_dict = {}
30
+ for key, value in original_state_dict.items():
31
+ if "lin.weight" in key:
32
+ corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
33
+ corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
34
+ else:
35
+ corrected_state_dict[key] = value
36
+
37
+ # Initialize the GATConv model with the corrected state dictionary
38
+ gatconv_model = ModeratelySimplifiedGATConvModel(
39
+ in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
40
+ )
41
+ gatconv_model.load_state_dict(corrected_state_dict)
42
+
43
  # Load the BERT-based sentence transformer model
44
  model_bert = SentenceTransformer("all-mpnet-base-v2")
45