PiKaHa committed
Commit 2a2a7a4 · 1 Parent(s): d61adc7

Add saved models and requirements

KC.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9985473ee19c6be7ba5777b12bc1babe60746298345e64e0426ab54f8e8e92d0
+ size 188721
KC/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c35ad708d8b793c7d4b9867ee4225445cd9c2648a3589067a7814da7a660f28
+ size 1209629
KC/variables/variables.data-00000-of-00001 ADDED
Binary file (676 Bytes).
 
KC/variables/variables.index ADDED
Binary file (387 Bytes).
 
Specificity.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c7508f5fec50afa81a43d85917b947b367f20da122c6fc9635d6cd202bc0b51
+ size 271953
Specificity/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8bef26d451689448d1469d1865bf91b9f342c9258d059dbc8deff006690a3189
+ size 2207910
Specificity/variables/variables.data-00000-of-00001 ADDED
Binary file (676 Bytes).
 
Specificity/variables/variables.index ADDED
Binary file (387 Bytes).
 
app.py ADDED
@@ -0,0 +1,171 @@
+ import gradio as gr
+ import joblib
+ from concurrent.futures import ThreadPoolExecutor
+ from transformers import AutoTokenizer, AutoModel, EsmModel
+ import torch
+ import numpy as np
+ import random
+ import tensorflow as tf
+ import os
+ from keras.layers import TFSMLayer
+ import pandas as pd
+
+ base_dir = "."
+
+ # Set random seed
+ SEED = 42
+ np.random.seed(SEED)
+ random.seed(SEED)
+ torch.manual_seed(SEED)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(SEED)
+     torch.cuda.manual_seed_all(SEED)
+
+ # Ensure deterministic behavior
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+ def load_model(model_path):
+     print(f"Loading model from {model_path}...")
+     return tf.saved_model.load(model_path)
+
+
+ print("Loading models...")
+ plant_models = {
+     "Specificity": {"model": joblib.load("Specificity.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 6},
+     "kcatC": {"model": joblib.load("kcatC.pkl"), "esm_model": "facebook/esm2_t36_3B_UR50D", "layer": 11},
+     "KC": {"model": joblib.load("KC.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 4},
+ }
+
+ general_models = {
+     "Specificity": {"model": load_model("Specificity"), "esm_model": "facebook/esm2_t33_650M_UR50D", "layer": 33},
+     "kcatC": {"model": load_model("kcatC"), "esm_model": "facebook/esm2_t12_35M_UR50D", "layer": 7},
+     "KC": {"model": load_model("KC"), "esm_model": "facebook/esm2_t30_150M_UR50D", "layer": 26},
+ }
+
+
+ # Function to generate embeddings
+ def get_embedding(sequence, esm_model_name, layer):
+     print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
+     tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
+     model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
+
+     # Tokenize the sequence
+     inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
+
+     # Generate embeddings
+     with torch.no_grad():
+         outputs = model(**inputs)
+         hidden_states = outputs.hidden_states  # Retrieve all hidden states
+         embedding = hidden_states[layer].mean(dim=1).numpy()  # Average pooling over the sequence dimension
+
+     # Convert to DataFrame with named columns
+     feature_columns = {f"D{i+1}": embedding[0, i] for i in range(embedding.shape[1])}
+     embedding_df = pd.DataFrame([feature_columns])
+     print(embedding_df)
+     return embedding_df.values, embedding_df
+
+
+ def predict_with_gpflow(model, X):
+     print(model.signatures)
+     # Convert input to a TensorFlow tensor
+     X_tensor = tf.convert_to_tensor(X, dtype=tf.float64)
+     print(X_tensor.shape)
+     # Get predictions
+     # predict_fn = model.predict_f_compiled
+     predict_fn = model.signatures["serving_default"]
+     result = predict_fn(Xnew=X_tensor)  # Pass Xnew explicitly
+     # mean, variance = predict_fn(Xnew=X_tensor)
+     mean = result["output_0"].numpy()  # Adjust output key names if needed
+     variance = result["output_1"].numpy()
+
+     # Return mean and variance as flat numpy arrays
+     # return mean.numpy().flatten(), variance.numpy().flatten()
+     return mean.flatten(), variance.flatten()
+
+
+
+ def process_target(target, selected_models, sequence, prediction_type):
+     """
+     Process a single target for prediction using transformer embeddings and the specified model.
+     """
+     # Get model and embedding details
+     esm_model_name = selected_models[target]["esm_model"]
+     layer = selected_models[target]["layer"]
+     model = selected_models[target]["model"]
+
+     # Generate embeddings in the required format
+     embedding, _ = get_embedding(sequence, esm_model_name, layer)
+     embedding = embedding.astype(np.float64)
+     np.save(f"hf_embedding_{target}.npy", embedding)
+     if prediction_type == "Plant-Specific":
+         # Random Forest prediction
+         y_pred = model.predict(embedding)[0]
+         return target, round(y_pred, 2)
+     else:
+         # GPflow prediction
+         print(esm_model_name)
+         print(layer)
+         print(model)
+         y_pred, y_uncertainty = predict_with_gpflow(model, embedding)
+         return target, round(y_pred[0], 2), round(y_uncertainty[0], 2)
+
+
+ def predict(sequence, prediction_type):
+     """
+     Predicts Specificity, kcatC, and KC for the given sequence and prediction type.
+     """
+     # Select the appropriate model set
+     selected_models = plant_models if prediction_type == "Plant-Specific" else general_models
+
+     # Predict for all targets in parallel
+     with ThreadPoolExecutor() as executor:
+         results = list(
+             executor.map(
+                 lambda target: process_target(target, selected_models, sequence, prediction_type),
+                 selected_models.keys()
+             )
+         )
+
+     # Format results
+     if prediction_type == "Plant-Specific":
+         formatted_results = [
+             ["Specificity", results[0][1]],
+             ["kcat\u1d9c", results[1][1]],
+             ["K\u1d9c", results[2][1]],
+         ]
+     else:
+         formatted_results = [
+             ["Specificity", results[0][1], results[0][2]],
+             ["kcat\u1d9c", results[1][1], results[1][2]],
+             ["K\u1d9c", results[2][1], results[2][2]],
+         ]
+
+     return formatted_results
+
+
+ # Define Gradio interface
+ print("Creating Gradio interface...")
+ interface = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Textbox(label="Input Protein Sequence",
+             value="MSPQTETKASVGFKAGVKEYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGTWTTVWTDGLTSLDRYKGRCYHIEPVPGEETQFIAYVAYPLDLFEEGSVTNMFTSIVGNVFGFKALAALRLEDLRIPPAYTKTFQGPPHGIQVERDKLNKYGRPLLGCTIKPKLGLSAKNYGRAVYECLRGGLDFTKDDENVNSQPFMRWRDRFLFCAEAIYKSQAETGEIKGHYLNATAGTCEEMIKRAVFARELGVPIVMHDYLTGGFTANTSLSHYCRDNGLLLHIHRAMHAVIDRQKNHGMHFRVLAKALRLSGGDHIHAGTVVGKLEGDRESTLGFVDLLRDDYVEKDRSRGIFFTQDWVSLPGVLPVASGGIHVWHMPALTEIFGDDSVLQFGGGTLGHPWGNAPGAVANRVALEACVQARNEGRDLAVEGNEIIREACKWSPELAAACEVWKEITFNFPTIDKLDGQE",
+             lines=10,
+         ),  # Input: text box for the protein sequence
+         gr.Radio(choices=["Plant-Specific", "General"], label="Prediction Type", value="Plant-Specific"),  # Radio buttons for the prediction type
+     ],
+     outputs=gr.Dataframe(
+         headers=["Target", "Prediction", "Uncertainty (for General)"],
+         type="array"
+     ),  # Output: table of predictions
+     title="Rubisco Kinetics Prediction",
+     description=(
+         "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
+         "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
+     ),
+ )
+
+ if __name__ == "__main__":
+     interface.launch()
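
For context, once this app.py runs as a Space, the interface above could be queried programmatically. The sketch below uses gradio_client; the Space identifier and the input sequence are placeholders for illustration, not values taken from this commit.

from gradio_client import Client

# Connect to the running Space (identifier below is a hypothetical placeholder).
client = Client("PiKaHa/rubisco-kinetics")

# The interface takes a protein sequence and a prediction type and returns
# rows of [Target, Prediction] or [Target, Prediction, Uncertainty].
result = client.predict(
    "MSPQTETKASVG",     # placeholder sequence for illustration
    "Plant-Specific",   # or "General"
    api_name="/predict",
)
print(result)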
kcatC.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33e1131e59808c8c23f910502730c40569f71946322fbc6f6c9f0236c11a8c6a
+ size 181809
kcatC/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a3fedc1e7fc41a15fa41e5fa04efe015fcadd2395e1ac07da60fb5328a8a401
+ size 1003709
kcatC/variables/variables.data-00000-of-00001 ADDED
Binary file (676 Bytes).
 
kcatC/variables/variables.index ADDED
Binary file (387 Bytes).
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ torch
+ gradio
+ joblib
+ numpy
+ scikit-learn
+ gpflow