|
import sys |
|
from typing import List |
|
from tqdm import tqdm |
|
import pandas as pd |
|
import numpy as np |
|
import threading |
|
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError |
|
import time |
|
import requests |
|
import joblib |
|
|
|
from Bio import SeqIO |
|
import rdkit |
|
from rdkit import Chem, DataStructs |
|
from rdkit.Chem import AllChem |
|
import torch |
|
from typing import * |
|
from rdkit import RDLogger |
|
RDLogger.DisableLog("rdApp.*") |
|
|
|
from xgboost import XGBClassifier, DMatrix |
|
|
|
from model.barlow_twins import BarlowTwins |
|
|
|
|
|
from utils.sequence import uniprot2sequence, encode_sequences |
|
|
|
|
|
class DTIModel: |
|
def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"): |
|
self.bt_model = BarlowTwins() |
|
self.bt_model.load_model(bt_model_path) |
|
|
|
self.gbm_model = XGBClassifier() |
|
self.gbm_model.load_model(gbm_model_path) |
|
|
|
self.encoder = encoder |
|
|
|
self.smiles_cache = {} |
|
self.sequence_cache = {} |
|
|
|
def _encode_smiles(self, smiles: str, radius: int = 2, bits: int = 1024, features: bool = False): |
|
if smiles is None: |
|
return None |
|
|
|
if smiles in self.smiles_cache: |
|
return self.smiles_cache[smiles] |
|
else: |
|
|
|
try: |
|
mol = Chem.MolFromSmiles(smiles) |
|
morgan = AllChem.GetMorganFingerprintAsBitVect( |
|
mol, |
|
radius=radius, |
|
nBits=bits, |
|
useFeatures=features, |
|
) |
|
morgan = np.array(morgan) |
|
self.smiles_cache[smiles] = morgan |
|
return morgan |
|
except Exception as e: |
|
print(f"Failed to encode SMILES: {smiles}") |
|
print(e) |
|
return None |
|
|
|
def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False): |
|
morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles] |
|
return np.array(morgan) |
|
|
|
def _encode_sequence(self, sequence: str): |
|
|
|
torch.cuda.empty_cache() |
|
if sequence is None: |
|
return None |
|
|
|
if sequence in self.sequence_cache: |
|
return self.sequence_cache[sequence] |
|
else: |
|
|
|
try: |
|
encoded_sequence = encode_sequences([sequence], encoder=self.encoder) |
|
self.sequence_cache[sequence] = encoded_sequence |
|
return encoded_sequence |
|
except Exception as e: |
|
print(f"Failed to encode sequence: {sequence}") |
|
print(e) |
|
return None |
|
|
|
def _encode_sequence_mult(self, sequences: List[str]): |
|
seq = [self._encode_sequence(sequence) for sequence in sequences] |
|
return np.array(seq) |
|
|
|
def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool): |
|
if drug_emb.shape[0] < target_emb.shape[0]: |
|
drug_emb = np.tile(drug_emb, (len(target_emb), 1)) |
|
elif len(drug_emb) > len(target_emb): |
|
target_emb = np.tile(target_emb, (len(drug_emb), 1)) |
|
emb = self.bt_model.zero_shot(drug_emb, target_emb) |
|
|
|
if pred_leaf: |
|
d_emb = DMatrix(emb) |
|
return self.gbm_model.get_booster().predict(d_emb, pred_leaf=True) |
|
else: |
|
return self.gbm_model.predict_proba(emb)[:, 1] |
|
|
|
def predict(self, drug: List[str] or str, target: str, pred_leaf: bool = False): |
|
if isinstance(drug, str): |
|
drug_emb = self._encode_smiles(drug) |
|
else: |
|
drug_emb = self._encode_smiles_mult(drug) |
|
target_emb = self._encode_sequence(target) |
|
return self.__predict_pair(drug_emb, target_emb, pred_leaf) |
|
|
|
def get_leaf_weights(self): |
|
return self.gbm_model.get_booster().get_score(importance_type="weight") |
|
|
|
def _predict_fasta(self, drug: str, fasta_path: str): |
|
drug_emb = self._encode_smiles(drug) |
|
|
|
results = [] |
|
|
|
for target in tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets"): |
|
target_emb = self._encode_sequence(str(target.seq)) |
|
pred = self.__predict_pair(drug_emb, target_emb) |
|
results.append( |
|
{ |
|
"drug": drug, |
|
"target": target.id, |
|
"name": target.name, |
|
"description": target.description, |
|
"prediction": pred[0] |
|
} |
|
) |
|
return pd.DataFrame(results) |
|
|
|
def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120): |
|
def process_target(target, results): |
|
target_emb = self._encode_sequence(str(target.seq)) |
|
pred = self.__predict_pair(drug_emb, target_emb) |
|
results.append({ |
|
"drug": drug, |
|
"target": target.id, |
|
"name": target.name, |
|
"description": target.description, |
|
"prediction": pred[0] |
|
}) |
|
|
|
drug_emb = self._encode_smiles(drug) |
|
results = [] |
|
|
|
|
|
total_records = sum(1 for _ in SeqIO.parse(fasta_path, "fasta")) |
|
|
|
|
|
for target in tqdm(SeqIO.parse(fasta_path, "fasta"), total=total_records, desc="Predicting targets"): |
|
thread_results = [] |
|
thread = threading.Thread(target=process_target, args=(target, thread_results)) |
|
thread.start() |
|
thread.join(timeout_seconds) |
|
if thread.is_alive(): |
|
print(f"Skipping target {target.id} due to timeout") |
|
continue |
|
results.extend(thread_results) |
|
|
|
return pd.DataFrame(results) |
|
|
|
def predict_uniprot(self, drug: List[str] or str, uniprot_id: str): |
|
return self.predict(drug, uniprot2sequence(uniprot_id)) |
|
|