Niksa Praljak committed
Commit: eca78a8
Parent(s): 0655b48

update README.md and include PenCL inference script

Browse files:
- README.md +134 -0
- run_PenCL_inference.py +120 -0
README.md
CHANGED
@@ -1,3 +1,137 @@
---
license: apache-2.0
---

# BioM3: Protein Language Model Pipeline

## Citation

If you use this code, please cite:

```bibtex
@article{biom3_2024,
  title   = {Natural Language Prompts Guide the Design of Novel Functional Protein Sequences},
  journal = {bioRxiv},
  year    = {2024},
  doi     = {10.1101/2024.11.11.622734}
}
```

[Read the paper on bioRxiv](https://www.biorxiv.org/content/10.1101/2024.11.11.622734v1)

## Software Requirements

### Required Dependencies

- Python 3.8 or later
- PyTorch (latest stable version)
- PyTorch Lightning
- pandas
- pyyaml

### Installation

Create and activate a conda environment:

```bash
conda create -n BioM3_env python=3.8
conda activate BioM3_env
```

Install the required packages:

```bash
conda install pytorch pytorch-lightning pandas pyyaml -c pytorch -c conda-forge
```
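
To confirm the environment resolves correctly, an optional sanity check (not part of the original instructions) is to import the core dependencies:

```python
# Optional: verify that the core dependencies import cleanly.
import torch
import pytorch_lightning
import pandas
import yaml

print("torch:", torch.__version__)
print("pytorch-lightning:", pytorch_lightning.__version__)
print("pandas:", pandas.__version__)
```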

## Stage 1: PenCL Inference

### Overview

This stage demonstrates how to perform inference with the **BioM3 PenCL model**, which aligns protein sequences with natural-language text descriptions. The model computes a latent embedding for each input and scores sequence-text pairs by their **dot products**, which are then normalized into probabilities with a softmax.
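
To make the scoring concrete, here is a minimal sketch of the similarity computation (illustrative only; random tensors stand in for the embeddings the model actually produces):

```python
import torch
import torch.nn.functional as F

# Stand-ins for the model's joint latents: one row per sample.
z_p = torch.randn(2, 512)  # protein embeddings
z_t = torch.randn(2, 512)  # text embeddings

scores = z_p @ z_t.T                      # (2, 2) dot-product similarity matrix
protein_probs = F.softmax(scores, dim=0)  # normalized over proteins, per text
text_probs = F.softmax(scores, dim=1)     # normalized over texts, per protein
```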

### Model Weights

Before running the model, ensure you have:
- Configuration file: `stage1_config.json`
- Pre-trained weights: `BioM3_PenCL_epoch20.bin`
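
If these files are hosted on the Hugging Face Hub, they can also be fetched programmatically. A minimal sketch (the repo id below is the same placeholder used in the clone step):

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id: substitute the actual BioM3 repository.
config_path = hf_hub_download(repo_id="your_username/BioM3_PenCL",
                              filename="stage1_config.json")
weights_path = hf_hub_download(repo_id="your_username/BioM3_PenCL",
                               filename="BioM3_PenCL_epoch20.bin")
```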

### Running the Model

1. Clone the repository:

```bash
git clone https://huggingface.co/your_username/BioM3_PenCL
cd BioM3_PenCL
```

2. Run inference:

```bash
python run_PenCL_inference.py \
    --json_path "stage1_config.json" \
    --model_path "BioM3_PenCL_epoch20.bin"
```

### Expected Output

The script prints the following:

1. **Latent Embedding Shapes**
   - `z_p`: protein sequence embeddings
   - `z_t`: text description embeddings

2. **Vector Magnitudes**
   - L2 norms of both embedding types

3. **Dot Product Scores**
   - Similarity matrix between embeddings (rows: proteins, columns: texts)

4. **Normalized Probabilities**
   - Protein-normalized (softmax over the protein axis, so each column sums to 1)
   - Text-normalized (softmax over the text axis, so each row sums to 1)
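
As a quick check of the normalization convention, here is a sketch using the score matrix from the sample output below:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[ 7.3152,  1.8080],
                       [ 3.3922, 16.6157]])

print(F.softmax(scores, dim=0).sum(dim=0))  # columns sum to 1 (protein-normalized)
print(F.softmax(scores, dim=1).sum(dim=1))  # rows sum to 1 (text-normalized)
```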

#### Sample Output

```plaintext
=== Inference Results ===
Shape of z_p (protein latent): torch.Size([2, 512])
Shape of z_t (text latent): torch.Size([2, 512])

Magnitudes of z_p vectors: tensor([5.3376, 4.8237])
Magnitudes of z_t vectors: tensor([29.6971, 27.6714])

=== Dot Product Scores Matrix ===
tensor([[ 7.3152,  1.8080],
        [ 3.3922, 16.6157]])

=== Normalized Probabilities ===
Protein-Normalized Probabilities:
tensor([[9.8060e-01, 3.7078e-07],
        [1.9398e-02, 1.0000e+00]])

Text-Normalized Probabilities:
tensor([[9.9596e-01, 4.0412e-03],
        [1.8076e-06, 1.0000e+00]])
```

The large diagonal entries show that each protein scores highest against its own description, which is the expected behavior for a well-aligned joint embedding.

## Stage 2: Facilitator Sampling

🚧 **Coming Soon** 🚧

This stage will contain scripts and models for the Facilitator Sampling process. Check back for:
- Configuration files
- Model weights
- Running instructions
- Output examples

## Stage 3: ProteoScribe

🚧 **Coming Soon** 🚧

This stage will contain scripts and models for the ProteoScribe process. Check back for:
- Configuration files
- Model weights
- Running instructions
- Output examples

## Support

For questions or issues:
- Open an issue in this repository
- Contact: [Your contact information]

---
Repository maintained by the BioM3 Team
run_PenCL_inference.py
ADDED
@@ -0,0 +1,120 @@
import argparse
import json
from argparse import Namespace

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import Stage1_source.preprocess as prep
import Stage1_source.model as mod

# Step 1: Load the JSON configuration file
def load_json_config(json_path):
    with open(json_path, "r") as f:
        config = json.load(f)
    return config

# Step 2: Recursively convert a (possibly nested) config dict to a Namespace
def convert_to_namespace(config_dict):
    for key, value in config_dict.items():
        if isinstance(value, dict):
            config_dict[key] = convert_to_namespace(value)
    return Namespace(**config_dict)
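
# For example (illustrative): convert_to_namespace({"model": {"dim": 512}})
# yields a Namespace `cfg` with cfg.model.dim == 512.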

# Step 3: Build the PenCL model and load the pre-trained weights (mapped to CPU)
def prepare_model(config_args, model_path) -> nn.Module:
    model = mod.pfam_PEN_CL(args=config_args)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    print("Model loaded successfully with weights!")
    return model

# Step 4: Prepare a small demo dataset of paired sequences and text captions
# (sequences and captions appear truncated with "..." in this example)
def load_test_dataset(config_args):
    test_dict = {
        'primary_Accession': ['A0A009IHW8', 'A0A023I7E1'],
        'protein_sequence': [
            "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKL...",
            "MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIFPEIKHP..."
        ],
        '[final]text_caption': [
            "PROTEIN NAME: 2' cyclic ADP-D-ribose synthase AbTIR...",
            "PROTEIN NAME: Glucan endo-1,3-beta-D-glucosidase 1..."
        ],
        'pfam_label': ["['PF13676']", "['PF17652','PF03639']"]
    }
    test_df = pd.DataFrame(test_dict)
    test_dataset = prep.TextSeqPairing_Dataset(args=config_args, df=test_df)
    return test_dataset

# Step 5: Parse command-line arguments
def parse_arguments():
    parser = argparse.ArgumentParser(description="BioM3 Inference Script (Stage 1)")
    parser.add_argument('--json_path', type=str, required=True,
                        help="Path to the JSON configuration file (e.g. stage1_config.json)")
    parser.add_argument('--model_path', type=str, required=True,
                        help="Path to the pre-trained model weights (e.g. BioM3_PenCL_epoch20.bin)")
    return parser.parse_args()
60 |
+
|
61 |
+
# Main Execution
|
62 |
+
if __name__ == '__main__':
|
63 |
+
# Parse arguments
|
64 |
+
config_args_parser = parse_arguments()
|
65 |
+
|
66 |
+
# Load configuration
|
67 |
+
config_dict = load_json_config(config_args_parser.json_path)
|
68 |
+
config_args = convert_to_namespace(config_dict)
|
69 |
+
|
70 |
+
# Load model
|
71 |
+
model = prepare_model(config_args=config_args, model_path=config_args_parser.model_path)
|
72 |
+
|
73 |
+
# Load test dataset
|
74 |
+
test_dataset = load_test_dataset(config_args)
|
75 |
+
|
76 |
+
# Run inference and store z_t, z_p
|
77 |
+
z_t_list = []
|
78 |
+
z_p_list = []
|
79 |
+
|
80 |
+
with torch.no_grad():
|
81 |
+
for idx in range(len(test_dataset)):
|
82 |
+
batch = test_dataset[idx]
|
83 |
+
x_t, x_p = batch
|
84 |
+
outputs = model(x_t, x_p, compute_masked_logits=False) # Infer Joint-Embeddings
|
85 |
+
z_t = outputs['text_joint_latent'] # Text latent
|
86 |
+
z_p = outputs['seq_joint_latent'] # Protein latent
|
87 |
+
z_t_list.append(z_t)
|
88 |
+
z_p_list.append(z_p)
|
89 |
+
|
90 |
+
# Stack all latent vectors
|
91 |
+
z_t_tensor = torch.vstack(z_t_list) # Shape: (num_samples, latent_dim)
|
92 |
+
z_p_tensor = torch.vstack(z_p_list) # Shape: (num_samples, latent_dim)
|
93 |
+
|
94 |
+
# Compute Dot Product scores
|
95 |
+
dot_product_scores = torch.matmul(z_p_tensor, z_t_tensor.T) # Dot product
|
96 |
+
|
97 |
+
# Normalize scores into probabilities
|
98 |
+
protein_given_text_probs = F.softmax(dot_product_scores, dim=0) # Normalize across rows (proteins), for each text
|
99 |
+
text_given_protein_probs = F.softmax(dot_product_scores, dim=1) # Normalize across columns (texts), for each protein
|
100 |
+
|
101 |
+
# Compute magnitudes (L2 norms) for z_t and z_p
|
102 |
+
z_p_magnitude = torch.norm(z_p_tensor, dim=1) # L2 norm for each protein latent vector
|
103 |
+
z_t_magnitude = torch.norm(z_t_tensor, dim=1) # L2 norm for each text latent vector
|
104 |
+
|
105 |
+
# Print results
|
106 |
+
print("\n=== Inference Results ===")
|
107 |
+
print(f"Shape of z_p (protein latent): {z_p_tensor.shape}")
|
108 |
+
print(f"Shape of z_t (text latent): {z_t_tensor.shape}")
|
109 |
+
print(f"\nMagnitudes of z_p vectors: {z_p_magnitude}")
|
110 |
+
print(f"Magnitudes of z_t vectors: {z_t_magnitude}")
|
111 |
+
|
112 |
+
print("\n=== Dot Product Scores Matrix ===")
|
113 |
+
print(dot_product_scores)
|
114 |
+
|
115 |
+
print("\n=== Normalized Probabilities ===")
|
116 |
+
print("Protein-Normalized Probabilities (Softmax across Proteins for each Text):")
|
117 |
+
print(protein_given_text_probs)
|
118 |
+
|
119 |
+
print("\nText-Normalized Probabilities (Softmax across Texts for each Protein):")
|
120 |
+
print(text_given_protein_probs)
|