Niksa Praljak committed
Commit eca78a8
1 Parent(s): 0655b48

update README.md and include PenCL inference script

Files changed (2):
  1. README.md +134 -0
  2. run_PenCL_inference.py +120 -0
README.md CHANGED
@@ -1,3 +1,137 @@
  ---
  license: apache-2.0
  ---
+
+ # BioM3: Protein Language Model Pipeline
+
+ ## Citation
+
+ If you use this code, please cite:
+
+ ```bibtex
+ @article{biom3_2024,
+   title   = {Natural Language Prompts Guide the Design of Novel Functional Protein Sequences},
+   journal = {bioRxiv},
+   year    = {2024},
+   doi     = {10.1101/2024.11.11.622734}
+ }
+ ```
+
+ [Read the paper on bioRxiv](https://www.biorxiv.org/content/10.1101/2024.11.11.622734v1)
+
+ ## Software Requirements
+
+ ### Required Dependencies
+ - Python 3.8 or later
+ - PyTorch (latest stable version)
+ - PyTorch Lightning
+ - pandas
+ - pyyaml
+
+ ### Installation
+
+ Create and activate a conda environment:
+ ```bash
+ conda create -n BioM3_env python=3.8
+ conda activate BioM3_env
+ ```
+
+ Install the required packages:
+ ```bash
+ conda install pytorch pytorch-lightning pandas pyyaml -c pytorch -c conda-forge
+ ```
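+
+ Optionally, run a quick sanity check that the core dependencies import cleanly:
+
+ ```python
+ # Verify the environment by importing the required packages and printing versions.
+ import torch
+ import pytorch_lightning as pl
+ import pandas as pd
+
+ print(f"PyTorch: {torch.__version__}")
+ print(f"PyTorch Lightning: {pl.__version__}")
+ print(f"pandas: {pd.__version__}")
+ ```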
+
+ ## Stage 1: PenCL Inference
+
+ ### Overview
+
+ This stage performs inference with the **BioM3 PenCL model**, which aligns protein sequences with natural-language text descriptions. The model embeds both inputs into a shared latent space, computes **dot product scores** (similarities) between every protein-text pair, and normalizes those scores into probabilities.
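+
+ In code, the scoring step amounts to a matrix of dot products followed by a softmax along each axis. The sketch below mirrors what `run_PenCL_inference.py` does once the latents are computed; the random tensors here stand in for real model outputs:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Stand-ins for the model's joint latents: one row per sample.
+ z_p = torch.randn(2, 512)  # protein embeddings
+ z_t = torch.randn(2, 512)  # text embeddings
+
+ scores = z_p @ z_t.T                           # [2, 2] similarity matrix
+ protein_given_text = F.softmax(scores, dim=0)  # columns sum to 1
+ text_given_protein = F.softmax(scores, dim=1)  # rows sum to 1
+ ```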
+
+ ### Model Weights
+
+ Before running the model, ensure you have both files locally (one way to fetch them is sketched below):
+ - Configuration file: `stage1_config.json`
+ - Pre-trained weights: `BioM3_PenCL_epoch20.bin`
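+
+ If the files are hosted in the model repository on the Hugging Face Hub, they can be fetched with `huggingface_hub` (a sketch; the repo id is the same placeholder used in the clone step below and must be replaced with the actual repository):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ repo_id = "your_username/BioM3_PenCL"  # placeholder -- substitute the real repo id
+ config_path = hf_hub_download(repo_id=repo_id, filename="stage1_config.json")
+ weights_path = hf_hub_download(repo_id=repo_id, filename="BioM3_PenCL_epoch20.bin")
+ ```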
+
+ ### Running the Model
+
+ 1. Clone the repository:
+
+    ```bash
+    git clone https://huggingface.co/your_username/BioM3_PenCL
+    cd BioM3_PenCL
+    ```
+
+ 2. Run inference:
+
+    ```bash
+    python run_PenCL_inference.py \
+        --json_path "stage1_config.json" \
+        --model_path "BioM3_PenCL_epoch20.bin"
+    ```
+
+ ### Expected Output
+
+ The script reports:
+
+ 1. **Latent Embedding Shapes**
+    - `z_p`: protein sequence embeddings
+    - `z_t`: text description embeddings
+
+ 2. **Vector Magnitudes**
+    - L2 norms of both embedding types
+
+ 3. **Dot Product Scores**
+    - Similarity matrix between all protein-text pairs
+
+ 4. **Normalized Probabilities**
+    - Protein-normalized: softmax over the protein axis (`dim=0`), so each column sums to 1
+    - Text-normalized: softmax over the text axis (`dim=1`), so each row sums to 1 (verified in the snippet after the sample output)
+
+ #### Sample Output
+ ```plaintext
+ === Inference Results ===
+ Shape of z_p (protein latent): torch.Size([2, 512])
+ Shape of z_t (text latent): torch.Size([2, 512])
+
+ Magnitudes of z_p vectors: tensor([5.3376, 4.8237])
+ Magnitudes of z_t vectors: tensor([29.6971, 27.6714])
+
+ === Dot Product Scores Matrix ===
+ tensor([[ 7.3152,  1.8080],
+         [ 3.3922, 16.6157]])
+
+ === Normalized Probabilities ===
+ Protein-Normalized Probabilities:
+ tensor([[9.8060e-01, 3.7078e-07],
+         [1.9398e-02, 1.0000e+00]])
+
+ Text-Normalized Probabilities:
+ tensor([[9.9596e-01, 4.0412e-03],
+         [1.8076e-06, 1.0000e+00]])
+ ```
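+
+ The normalization axes can be checked directly from the scores matrix above: each column of the protein-normalized table and each row of the text-normalized table sums to 1.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Dot-product scores copied from the sample output above.
+ scores = torch.tensor([[7.3152, 1.8080],
+                        [3.3922, 16.6157]])
+
+ print(F.softmax(scores, dim=0).sum(dim=0))  # tensor([1., 1.]) -- columns
+ print(F.softmax(scores, dim=1).sum(dim=1))  # tensor([1., 1.]) -- rows
+ ```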
+
+ ## Stage 2: Facilitator Sampling
+
+ 🚧 **Coming Soon** 🚧
+
+ This stage will contain scripts and models for the Facilitator Sampling process. Check back for:
+ - Configuration files
+ - Model weights
+ - Running instructions
+ - Output examples
+
+ ## Stage 3: ProteoScribe
+
+ 🚧 **Coming Soon** 🚧
+
+ This stage will contain scripts and models for the ProteoScribe process. Check back for:
+ - Configuration files
+ - Model weights
+ - Running instructions
+ - Output examples
+
+ ## Support
+
+ For questions or issues:
+ - Open an issue in this repository
+ - Contact: [Your contact information]
+
+ ---
+ Repository maintained by the BioM3 Team
+
run_PenCL_inference.py ADDED
@@ -0,0 +1,120 @@
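+ """
+ BioM3 Stage 1 (PenCL) inference script.
+
+ Loads a pre-trained PenCL checkpoint, embeds a small built-in demo set of
+ protein sequences and text captions, and prints dot-product similarity
+ scores with protein- and text-normalized probabilities.
+
+ Usage:
+     python run_PenCL_inference.py --json_path stage1_config.json \
+         --model_path BioM3_PenCL_epoch20.bin
+ """
+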
+ import argparse
+ import json
+ from argparse import Namespace
+
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import yaml
+
+ # Local modules shipped with the BioM3 repository; run this script from the
+ # repository root so that Stage1_source is importable.
+ import Stage1_source.preprocess as prep
+ import Stage1_source.model as mod
+ import Stage1_source.PL_wrapper as PL_wrap
+
+ # Step 1: Load the JSON configuration file
+ def load_json_config(json_path):
+     with open(json_path, "r") as f:
+         config = json.load(f)
+     return config
+
+ # Step 2: Recursively convert the JSON dictionary to a Namespace
+ def convert_to_namespace(config_dict):
+     for key, value in config_dict.items():
+         if isinstance(value, dict):
+             config_dict[key] = convert_to_namespace(value)
+     return Namespace(**config_dict)
+
+ # Step 3: Build the model and load pre-trained weights
+ def prepare_model(config_args, model_path) -> nn.Module:
+     model = mod.pfam_PEN_CL(args=config_args)
+     model.load_state_dict(torch.load(model_path, map_location="cpu"))
+     model.eval()  # inference mode: disables dropout and other training-only behavior
+     print("Model loaded successfully with weights!")
+     return model
+
+ # Step 4: Prepare a small demo dataset of paired captions and sequences
+ def load_test_dataset(config_args):
+     test_dict = {
+         'primary_Accession': ['A0A009IHW8', 'A0A023I7E1'],
+         'protein_sequence': [
+             "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKL...",
+             "MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIFPEIKHP..."
+         ],
+         '[final]text_caption': [
+             "PROTEIN NAME: 2' cyclic ADP-D-ribose synthase AbTIR...",
+             "PROTEIN NAME: Glucan endo-1,3-beta-D-glucosidase 1..."
+         ],
+         'pfam_label': ["['PF13676']", "['PF17652','PF03639']"]
+     }
+     test_df = pd.DataFrame(test_dict)
+     test_dataset = prep.TextSeqPairing_Dataset(args=config_args, df=test_df)
+     return test_dataset
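+
+ # NOTE: to run on your own data, build a DataFrame with the same four columns
+ # ('primary_Accession', 'protein_sequence', '[final]text_caption', 'pfam_label')
+ # and pass it to prep.TextSeqPairing_Dataset as above.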
+
+ # Step 5: Parse command-line arguments
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description="BioM3 Inference Script (Stage 1)")
+     parser.add_argument('--json_path', type=str, required=True,
+                         help="Path to the JSON configuration file (stage1_config.json)")
+     parser.add_argument('--model_path', type=str, required=True,
+                         help="Path to the pre-trained model weights (BioM3_PenCL_epoch20.bin)")
+     return parser.parse_args()
+
+ # Main Execution
+ if __name__ == '__main__':
+     # Parse arguments
+     config_args_parser = parse_arguments()
+
+     # Load configuration
+     config_dict = load_json_config(config_args_parser.json_path)
+     config_args = convert_to_namespace(config_dict)
+
+     # Load model
+     model = prepare_model(config_args=config_args, model_path=config_args_parser.model_path)
+
+     # Load test dataset
+     test_dataset = load_test_dataset(config_args)
+
+     # Run inference and collect the joint latents z_t (text) and z_p (protein)
+     z_t_list = []
+     z_p_list = []
+
+     with torch.no_grad():
+         for idx in range(len(test_dataset)):
+             x_t, x_p = test_dataset[idx]
+             outputs = model(x_t, x_p, compute_masked_logits=False)  # joint embeddings only
+             z_t_list.append(outputs['text_joint_latent'])  # text latent
+             z_p_list.append(outputs['seq_joint_latent'])   # protein latent
+
+     # Stack per-sample latents into (num_samples, latent_dim) tensors
+     z_t_tensor = torch.vstack(z_t_list)
+     z_p_tensor = torch.vstack(z_p_list)
+
+     # Pairwise dot-product scores: rows index proteins, columns index texts
+     dot_product_scores = torch.matmul(z_p_tensor, z_t_tensor.T)
+
+     # Normalize scores into probabilities
+     protein_given_text_probs = F.softmax(dot_product_scores, dim=0)  # softmax over proteins; each column sums to 1
+     text_given_protein_probs = F.softmax(dot_product_scores, dim=1)  # softmax over texts; each row sums to 1
+
+     # L2 norms (magnitudes) of the latent vectors
+     z_p_magnitude = torch.norm(z_p_tensor, dim=1)
+     z_t_magnitude = torch.norm(z_t_tensor, dim=1)
+
+     # Print results
+     print("\n=== Inference Results ===")
+     print(f"Shape of z_p (protein latent): {z_p_tensor.shape}")
+     print(f"Shape of z_t (text latent): {z_t_tensor.shape}")
+     print(f"\nMagnitudes of z_p vectors: {z_p_magnitude}")
+     print(f"Magnitudes of z_t vectors: {z_t_magnitude}")
+
+     print("\n=== Dot Product Scores Matrix ===")
+     print(dot_product_scores)
+
+     print("\n=== Normalized Probabilities ===")
+     print("Protein-Normalized Probabilities (softmax over proteins, for each text):")
+     print(protein_given_text_probs)
+
+     print("\nText-Normalized Probabilities (softmax over texts, for each protein):")
+     print(text_given_protein_probs)