Niksa Praljak commited on
Commit
c11fc4d
1 Parent(s): 0354528

update PenCL test inference script

Browse files
Files changed (2) hide show
  1. README.md +5 -0
  2. run_PenCL_inference.py +21 -0
README.md CHANGED
@@ -118,6 +118,11 @@ tensor([[9.8060e-01, 3.7078e-07],
118
  Text-Normalized Probabilities:
119
  tensor([[9.9596e-01, 4.0412e-03],
120
  [1.8076e-06, 1.0000e+00]])
 
 
 
 
 
121
  ```
122
 
123
  ## Stage 2: Facilitator Sampling
 
118
  Text-Normalized Probabilities:
119
  tensor([[9.9596e-01, 4.0412e-03],
120
  [1.8076e-06, 1.0000e+00]])
121
+
122
+ === Homology Matrix (Dot Product of Normalized z_p) ===
123
+ tensor([[1.0000, 0.1840],
124
+ [0.1840, 1.0000]])
125
+
126
  ```
127
 
128
  ## Stage 2: Facilitator Sampling
run_PenCL_inference.py CHANGED
@@ -58,6 +58,20 @@ def parse_arguments():
58
  help="Path to the pre-trained model weights (pytorch_model.bin)")
59
  return parser.parse_args()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Main Execution
62
  if __name__ == '__main__':
63
  # Parse arguments
@@ -101,6 +115,9 @@ if __name__ == '__main__':
101
  # Compute magnitudes (L2 norms) for z_t and z_p
102
  z_p_magnitude = torch.norm(z_p_tensor, dim=1) # L2 norm for each protein latent vector
103
  z_t_magnitude = torch.norm(z_t_tensor, dim=1) # L2 norm for each text latent vector
 
 
 
104
 
105
  # Print results
106
  print("\n=== Inference Results ===")
@@ -118,3 +135,7 @@ if __name__ == '__main__':
118
 
119
  print("\nText-Normalized Probabilities (Softmax across Texts for each Protein):")
120
  print(text_given_protein_probs)
 
 
 
 
 
58
  help="Path to the pre-trained model weights (pytorch_model.bin)")
59
  return parser.parse_args()
60
 
61
+ # Step 6: Compute Homology Probabilities
62
+ def compute_homology_matrix(z_p_tensor):
63
+ """
64
+ Compute the homology matrix as cosine similarities between protein latent vectors.
65
+ """
66
+ # Normalize z_p to unit vectors
67
+ z_p_normalized = F.normalize(z_p_tensor, p=2, dim=1) # L2 normalization
68
+
69
+ # Compute cosine similarity matrix
70
+ homology_matrix = torch.matmul(z_p_normalized, z_p_normalized.T) # (num_samples x num_samples)
71
+
72
+ return homology_matrix
73
+
74
+
75
  # Main Execution
76
  if __name__ == '__main__':
77
  # Parse arguments
 
115
  # Compute magnitudes (L2 norms) for z_t and z_p
116
  z_p_magnitude = torch.norm(z_p_tensor, dim=1) # L2 norm for each protein latent vector
117
  z_t_magnitude = torch.norm(z_t_tensor, dim=1) # L2 norm for each text latent vector
118
+
119
+ # Compute homology probabilities
120
+ homology_matrix = compute_homology_matrix(z_p_tensor)
121
 
122
  # Print results
123
  print("\n=== Inference Results ===")
 
135
 
136
  print("\nText-Normalized Probabilities (Softmax across Texts for each Protein):")
137
  print(text_given_protein_probs)
138
+
139
+ print("\n=== Homology Matrix (Dot Product of Normalized z_p) ===")
140
+ print(homology_matrix)
141
+