# Standard library
import logging
import shutil
import sys
import tempfile
from argparse import ArgumentParser, Namespace, FileType
import copy
import itertools
import os
import subprocess
from datetime import datetime
from pathlib import Path
from functools import partial, cache
import warnings

# Third-party chemistry / structural-biology / ML libraries
import yaml
from Bio.PDB import PDBParser
from prody import parsePDB, parsePQR
from sklearn.cluster import DBSCAN
from openbabel import openbabel as ob

# Linker-generation models and helpers from the ``src`` package (DDPM / SizeClassifier checkpoints)
from src import const
from src.datasets import (
    collate_with_fragment_without_pocket_edges, collate_with_fragment_edges, get_dataloader, get_one_hot,
    parse_molecule
)
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier
from src.utils import set_deterministic, FoundNaNException

# Ignore pandas deprecation warning around pyarrow.
# The filter is registered before `import pandas` below so it is already active when pandas is imported.
warnings.filterwarnings("ignore", category=DeprecationWarning,
                        message="(?s).*Pyarrow will become a required dependency of pandas.*")
import numpy as np
import pandas as pd
from pandarallel import pandarallel
import torch
from torch_geometric.loader import DataLoader
from Bio import SeqIO
from rdkit import RDLogger, Chem
from rdkit.Chem import RemoveAllHs, PandasTools

# Docking-side modules (``utils`` / ``datasets`` packages: inference dataset, diffusion sampling, model factory).
# TODO imports are a little odd, utils seems to shadow things
from utils.logging_utils import configure_logger, get_logger
from datasets.process_mols import create_mol_with_coords, read_molecule
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
from utils.inference_utils import InferenceDataset
from utils.sampling import randomize_position, sampling
from utils.utils import get_model
from tqdm import tqdm

# Module-wide logger used by every function below
configure_logger()
log = get_logger()

# Silence RDKit and OpenBabel console messages
RDLogger.DisableLog('rdApp.*')
ob.obErrorLog.SetOutputLevel(0)
warnings.filterwarnings("ignore", category=UserWarning,
                        message="The TorchScript type system doesn't support instance-level annotations on"
                                " empty non-base types in `__init__`")
# Prody logging is very verbose by default
prody_logger = logging.getLogger(".prody")
prody_logger.setLevel(logging.ERROR) # Pandarallel initialization nb_workers = os.cpu_count() progress_bar = False if hasattr(sys, 'gettrace') and sys.gettrace() is not None: # Debug mode nb_workers = 1 progress_bar = True pandarallel.initialize(nb_workers=nb_workers, progress_bar=progress_bar) def read_fragment_library(file_path): if file_path is None: return pd.DataFrame(columns=['X1', 'ID1', 'mol']) file_path = Path(file_path) if file_path.suffix == '.csv': df = pd.read_csv(file_path) # Validate columns for col in ['X1', 'ID1']: if col not in df.columns: raise ValueError(f"Column '{col}' not found in CSV file.") PandasTools.AddMoleculeColumnToFrame(df, smilesCol='X1', molCol='mol') elif file_path.suffix == '.sdf': df = PandasTools.LoadSDF(file_path, smilesName='X1', molColName='mol') id_cols = [col for col in df.columns if 'ID' in col] if id_cols: df['ID1'] = df[id_cols[0]] else: raise ValueError(f"Unsupported file format: {file_path.suffix}") if 'ID1' not in df.columns: df['ID1'] = None # Use InChiKey as ID1 if None df.loc[df['ID1'].isna(), 'ID1'] = df.loc[ df['ID1'].isna(), 'mol' ].apply(Chem.MolToInchiKey) return df[['X1', 'ID1', 'mol']] def read_protein_library(file_path): df = None if file_path.suffix == '.csv': df = pd.read_csv(file_path) elif file_path.suffix == '.fasta': records = list(SeqIO.parse(file_path, 'fasta')) df = pd.DataFrame([{'X2': str(record.seq), 'ID2': record.id} for record in records]) return df def remove_halogens(mol): if mol is None: return None halogens = ['F', 'Cl', 'Br', 'I', 'At'] # Enable editing rw_mol = Chem.RWMol(mol) for atom in rw_mol.GetAtoms(): if atom.GetSymbol() in halogens: # Replace with hydrogen atom.SetAtomicNum(1) mol_no_halogens = Chem.Mol(rw_mol) # Make hydrogen implicit mol_no_halogens = Chem.RemoveHs(mol_no_halogens) return mol_no_halogens def process_fragment_library(df, dehalogenate=True, discard_inorganic=True): """ SMILES strings with separators (e.g., .) 
represent distinct molecular entities, such as ligands, ions, or co-crystallized molecules. Splitting them ensures that each entity is treated individually, allowing focused analysis of their roles in binding. Single atom fragments (e.g., counterions like [I-] or [Cl-] are irrelevant in docking and are to be removed. This filtering focuses on structurally relevant fragments. """ # Remove fragments with invalid SMILES df['mol'] = df['X1'].apply(read_molecule, remove_confs=True) df = df.dropna(subset=['mol']) df['X1'] = df['mol'].apply(Chem.MolToSmiles) # Get subset of rows with SMILES containing separators fragmented_rows = df['X1'].str.contains('.', regex=False) df_fragmented = df[fragmented_rows].copy() # Split SMILES into lists and expand df_fragmented['X1'] = df_fragmented['X1'].str.split('.') df_fragmented = df_fragmented.explode('X1').reset_index(drop=True) # Append fragment index as alphabet (A, B, C... AA, AB...) to ID1 for rows with fragmented SMILES df_fragmented['ID1'] = df_fragmented.groupby('ID1').cumcount().apply(num_to_letter_code).radd( df_fragmented['ID1'] + '_') df = pd.concat([df[~fragmented_rows], df_fragmented]) # Remove single atom fragments df = df[df['mol'].apply(lambda mol: mol.GetNumAtoms() > 1)] if discard_inorganic: df = df[df['mol'].apply(lambda mol: any(atom.GetSymbol() == 'C' for atom in mol.GetAtoms()))] if dehalogenate: df['mol'] = df['mol'].apply(remove_halogens) # Deduplicate fragments and canonicalize SMILES df = df.groupby(['X1']).first().reset_index() df['X1'] = df['mol'].apply(lambda x: Chem.MolToSmiles(x)) return df def check_one_to_one(df, ID_column, X_column): # Check for multiple X values for the same ID id_to_x_conflicts = df.groupby(ID_column)[X_column].nunique() conflicting_ids = id_to_x_conflicts[id_to_x_conflicts > 1] # Check for multiple ID values for the same X x_to_id_conflicts = df.groupby(X_column)[ID_column].nunique() conflicting_xs = x_to_id_conflicts[x_to_id_conflicts > 1] # Print conflicting mappings if not 
conflicting_ids.empty: print(f"Conflicting {ID_column} -> multiple {X_column}:") for idx in conflicting_ids.index: print(f"{ID_column}: {idx}, {X_column} values: {df[df[ID_column] == idx][X_column].unique()}") if not conflicting_xs.empty: print(f"Conflicting {X_column} -> multiple {ID_column}:") for x in conflicting_xs.index: print(f"{X_column}: {x}, {ID_column} values: {df[df[X_column] == x][ID_column].unique()}") # Return whether the mappings are one-to-one return conflicting_ids.empty and conflicting_xs.empty def save_sdf(path, one_hot, positions, node_mask, is_geom): # Select atom mapping based on whether geometry or generic atoms are used idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM # Identify valid atoms based on the mask mask = node_mask.squeeze() atom_indices = torch.where(mask)[0] obMol = ob.OBMol() # Add atoms to OpenBabel molecule atoms = torch.argmax(one_hot, dim=1) for atom_i in atom_indices: atom = atoms[atom_i].item() atom_symbol = idx2atom[atom] obAtom = obMol.NewAtom() obAtom.SetAtomicNum(Chem.GetPeriodicTable().GetAtomicNumber(atom_symbol)) # Set atomic number # Set atomic positions pos = positions[atom_i] obAtom.SetVector(pos[0].item(), pos[1].item(), pos[2].item()) # Infer bonds with OpenBabel obMol.ConnectTheDots() obMol.PerceiveBondOrders() # Convert OpenBabel molecule to SDF obConversion = ob.OBConversion() obConversion.SetOutFormat("sdf") sdf_string = obConversion.WriteString(obMol) # Save SDF file with open(path, "w") as f: f.write(sdf_string) # Generate SMILES rdkit_mol = Chem.MolFromMolBlock(sdf_string) if rdkit_mol is not None: smiles = Chem.MolToSmiles(rdkit_mol) else: # Use OpenBabel to generate SMILES if RDKit fails obConversion.SetOutFormat("can") smiles = obConversion.WriteString(obMol).strip() return smiles def num_to_letter_code(n): result = '' while n >= 0: result = chr(65 + (n % 26)) + result n = n // 26 - 1 return result def dock_fragments( out_dir, score_ckpt, confidence_ckpt, device, inference_steps, n_poses, 
initial_noise_std_proportion, docking_batch_size, no_final_step_noise, temp_sampling_tr, temp_sampling_rot, temp_sampling_tor, temp_psi_tr, temp_psi_rot, temp_psi_tor, temp_sigma_data_tr, temp_sigma_data_rot,temp_sigma_data_tor, save_docking, df=None, protein_ligand_csv=None, fragment_library=None, protein_library=None, ): with open(Path(score_ckpt).parent / 'model_parameters.yml') as f: score_model_args = Namespace(**yaml.full_load(f)) with open(Path(confidence_ckpt).parent / 'model_parameters.yml') as f: confidence_args = Namespace(**yaml.full_load(f)) docking_out_dir = Path(out_dir, 'docking') docking_out_dir.mkdir(parents=True, exist_ok=True) if df is None: if protein_ligand_csv is not None: csv_path = Path(protein_ligand_csv) assert csv_path.is_file(), f"File {protein_ligand_csv} does not exist" df = pd.read_csv(csv_path) df = process_fragment_library(df) else: assert fragment_library is not None and protein_library is not None, "Either a .csv file or `X1` and `X2` must be provided." 
compound_df = pd.DataFrame(columns=['X1', 'ID1']) if Path(fragment_library).is_file(): compound_path = Path(fragment_library) if compound_path.suffix in ['.csv', '.sdf']: compound_df[['X1', 'ID1']] = read_fragment_library(compound_path)[['X1', 'ID1']] else: compound_df['X1'] = [compound_path] compound_df['ID1'] = [compound_path.stem] else: compound_df['X1'] = [fragment_library] compound_df['ID1'] = 'compound_0' compound_df.dropna(subset=['X1'], inplace=True) compound_df.loc[compound_df['ID1'].isna(), 'ID1'] = compound_df.loc[compound_df['ID1'].isna(), 'X1'].apply( lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x)) ) protein_df = pd.DataFrame(columns=['X2', 'ID2']) if Path(protein_library).is_file(): protein_path = Path(protein_library) if protein_path.suffix in ['.csv', '.fasta']: protein_df[['X2', 'ID2']] = read_protein_library(protein_path)[['X2', 'ID2']] else: protein_df['X2'] = [protein_path] protein_df['ID2'] = [protein_path.stem] else: protein_df['X2'] = [protein_library] protein_df['ID2'] = 'protein_0' protein_df.dropna(subset=['X2'], inplace=True) protein_df.loc[protein_df['ID2'].isna(), 'ID2'] = [ f"protein_{i}" for i in range(protein_df['ID2'].isna().sum()) ] compound_df = process_fragment_library(compound_df) df = compound_df.merge(protein_df, how='cross') # Identify duplicates based on 'X1' and 'X2' duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)] if not duplicates.empty: print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['ID1', 'X1', 'ID2', 'X2']]) print("Keeping the first occurrence of each duplicate.") df.drop_duplicates(subset=['X1', 'X2'], inplace=True) df['name'] = df['ID2'] + '-' + df['ID1'] df = df.replace({pd.NA: None}) # Check unique mappings between IDn and Xn assert check_one_to_one(df, 'ID1', 'X1'), "ID1-X1 mapping is not one-to-one." assert check_one_to_one(df, 'ID2', 'X2'), "ID2-X2 mapping is not one-to-one." 
""" Docking phase """ # preprocessing of complexes into geometric graphs test_dataset = InferenceDataset( df=df, out_dir=out_dir, lm_embeddings=True, receptor_radius=score_model_args.receptor_radius, remove_hs=True, # score_model_args.remove_hs, c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors, all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius, atom_max_neighbors=score_model_args.atom_max_neighbors, knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph ) test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) confidence_test_dataset = InferenceDataset( df=df, out_dir=out_dir, lm_embeddings=True, receptor_radius=confidence_args.receptor_radius, remove_hs=True, # confidence_args.remove_hs, c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors, all_atoms=confidence_args.all_atoms, atom_radius=confidence_args.atom_radius, atom_max_neighbors=confidence_args.atom_max_neighbors, precomputed_lm_embeddings=test_dataset.lm_embeddings, knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph ) t_to_sigma = partial(t_to_sigma_compl, args=score_model_args) model = get_model( score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True ) state_dict = torch.load(Path(score_ckpt), map_location='cpu', weights_only=True) model.load_state_dict(state_dict, strict=True) model = model.to(device) model.eval() confidence_model = get_model( confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True, confidence_mode=True, old=True ) state_dict = torch.load(Path(confidence_ckpt), map_location='cpu', weights_only=True) confidence_model.load_state_dict(state_dict, strict=True) confidence_model = confidence_model.to(device) confidence_model.eval() tr_schedule = get_t_schedule(inference_steps=inference_steps, sigma_schedule='expbeta') failures, skipped = 0, 0 samples_per_complex = n_poses 
test_ds_size = len(test_dataset) df = test_loader.dataset.df docking_dfs = [] log.info(f'Size of fragment dataset: {test_ds_size}') for idx, orig_complex_graph in tqdm(enumerate(test_loader), total=test_ds_size): if not orig_complex_graph.success[0]: skipped += 1 log.warning( f"The test dataset did not contain {df['name'].iloc[idx]}" f" for {df['X1'].iloc[idx]} and {df['X2'].iloc[idx]}. We are skipping this complex.") continue try: if confidence_test_dataset is not None: confidence_complex_graph = confidence_test_dataset[idx] if not confidence_complex_graph.success: skipped += 1 log.warning( f"The confidence dataset did not contain {orig_complex_graph.name}. We are skipping this complex.") continue confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(samples_per_complex)] else: confidence_data_list = None data_list = [copy.deepcopy(orig_complex_graph) for _ in range(samples_per_complex)] randomize_position( data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max, initial_noise_std_proportion=initial_noise_std_proportion, choose_residue=False ) # run reverse diffusion # TODO How to make full use of VRAM? 
seems the best way to create another loop for each fragment ''' File "DiffFragDock/utils/sampling.py", line 142, in sampling tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 0 ''' # TODO It seems molecules of different sizes cannot be in the same batch in inference if n_poses <= docking_batch_size: batch_size = n_poses elif n_poses % docking_batch_size == 0: batch_size = docking_batch_size else: raise ValueError data_list, confidence = sampling( data_list=data_list, model=model, inference_steps=inference_steps, tr_schedule=tr_schedule, rot_schedule=tr_schedule, tor_schedule=tr_schedule, device=device, t_to_sigma=t_to_sigma, model_args=score_model_args, visualization_list=None, confidence_model=confidence_model, confidence_data_list=confidence_data_list, confidence_model_args=confidence_args, batch_size=batch_size, no_final_step_noise=no_final_step_noise, temp_sampling=[temp_sampling_tr, temp_sampling_rot, temp_sampling_tor], temp_psi=[temp_psi_tr, temp_psi_rot, temp_psi_tor], temp_sigma_data=[temp_sigma_data_tr, temp_sigma_data_rot,temp_sigma_data_tor] ) ligand_pos = np.asarray( [complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for complex_graph in data_list] ) # save predictions n_samples = len(confidence) sample_df = pd.DataFrame([df.iloc[idx]] * n_samples) sample_df['ID1'] = [f"{df['ID1'].iloc[idx]}_{i}" for i in range(n_samples)] confidence = confidence[:, 0].cpu().numpy() sample_df['confidence'] = confidence lig = orig_complex_graph.mol[0] # TODO Use index instead of confidence in filename if save_docking: sample_df['ligand_conf_path'] = [ f"{df['name'].iloc[idx]}_{i}-confidence{confidence[i]:.2f}.sdf" for i in range(n_samples) ] sample_df['ligand_mol']= [ create_mol_with_coords( mol=RemoveAllHs(copy.deepcopy(lig)), new_coords=pos, 
path=Path(docking_out_dir, sample_df['ligand_conf_path'].iloc[i]) if save_docking else None ) for i, pos in enumerate(ligand_pos) ] # sample_df['ligand_pos'] = list(ligand_pos) docking_dfs.append(sample_df) # write_dir = f"{args.out_dir}/{df['name'].iloc[idx]}" # for rank, pos in enumerate(ligand_pos): # mol_pred = copy.deepcopy(lig) # if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred) # if rank == 0: write_mol_with_coords(mol_pred, pos, Path(write_dir, f'rank{rank + 1}.sdf')) # write_mol_with_coords(mol_pred, pos, # Path(write_dir, f'rank{rank + 1}_confidence{confidence[rank]:.2f}.sdf')) except Exception as e: log.warning("Failed on", orig_complex_graph["name"], e) failures += 1 # Teardown del model if confidence_model is not None: del confidence_model del test_dataset if confidence_test_dataset is not None: del confidence_test_dataset del test_loader torch.cuda.empty_cache() docking_df = pd.concat(docking_dfs, ignore_index=True) # Save intermediate docking results if save_docking: docking_df[ ['name', 'ID2', 'protein_path', 'ID1', 'X1', 'confidence', 'ligand_conf_path'] ].to_csv(Path(out_dir, 'docking_summary.csv'), index=False) result_msg = f""" Failed for {failures} / {test_ds_size} complexes. Skipped {skipped} / {test_ds_size} complexes. 
""" if failures or skipped: log.warning(result_msg) else: log.info(result_msg) log.info(f"Docking results saved to {docking_out_dir}") return docking_df def calculate_mol_atomic_distances(mol1, mol2, distance_type='min'): mol1_coords = [ mol1.GetConformer().GetAtomPosition(i) for i in range(mol1.GetNumAtoms()) ] mol2_coords = [ mol2.GetConformer().GetAtomPosition(i) for i in range(mol2.GetNumAtoms()) ] # Ensure numpy arrays mol1_coords = np.array(mol1_coords) mol2_coords = np.array(mol2_coords) # Compute pairwise distances between carbon atoms atom_pairwise_distances = np.linalg.norm(mol1_coords[:, None, :] - mol2_coords[None, :, :], axis=-1) # if np.any(np.isnan(atom_pairwise_distances)): # import pdb # pdb.set_trace() # Trigger a breakpoint if NaN is found if distance_type == 'min': return atom_pairwise_distances.min() elif distance_type == 'mean': return atom_pairwise_distances.mean() elif distance_type is None: return atom_pairwise_distances else: raise ValueError(f"Unsupported distance_type: {distance_type}") def process_docking_results( df, eps=5, # Distance threshold for DBSCAN clustering min_samples=5, # Minimum number of samples for a cluster (enrichment) frag_dist_range=(2, 5), # Distance range for fragment linking distance_type='min', # Type of distance to compute between fragments ): assert len(frag_dist_range) == 2, 'Distance range must be a tuple of two values in Angstroms (Å).' frag_dist_range = sorted(frag_dist_range) # The mols in df should have been processed to have no explicit hydrogens, except heavy hydrogen isotopes. docking_summaries = [] # For saving intermediate docking results fragment_combos = [] # Fragment pairs for the linking step # 1. 
Cluster docking poses # Compute pairwise distances of molecules defined by the closest non-heavy atoms for protein, protein_df in df.groupby('X2'): protein_id = protein_df['ID2'].iloc[0] protein_path = protein_df['protein_path'].iloc[0] protein_df['index'] = protein_df.index log.info(f'Processing docking results for {protein_id}...') dist_matrix = np.stack( protein_df['ligand_mol'].parallel_apply( lambda mol1: [ calculate_mol_atomic_distances(mol1, mol2, distance_type=distance_type) for mol2 in protein_df['ligand_mol'] ] ) ) # Perform DBSCAN clustering dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed') protein_df['cluster'] = dbscan.fit_predict(dist_matrix) protein_df = protein_df.sort_values( by=['X1', 'cluster', 'confidence'], ascending=[True, True, False] ) # Add conformer number to ID1 protein_df['ID1'] = protein_df.groupby('ID1').cumcount().astype(str).radd(protein_df['ID1'] + '_') if args.save_docking: docking_summaries.append( protein_df[['name', 'ID2', 'X2', 'ID1', 'X1', 'cluster', 'confidence', 'ligand_conf_path']] ) # Filter out outlier poses protein_df = protein_df[protein_df['cluster'] != -1] # Keep only the highest confidence pose per protein per ligand per cluster protein_df = protein_df.groupby(['X1', 'cluster']).first().reset_index() # 2. 
Create fragment-linking pairs fragment_path = None protein_fragment_combos = [] for cluster, cluster_df in protein_df.groupby('cluster'): if len(cluster_df) > 1: # Skip clusters with only one pose pairs = list(itertools.combinations(cluster_df['index'], 2)) for i, j in pairs: row1 = cluster_df[cluster_df['index'] == i].iloc[0] row2 = cluster_df[cluster_df['index'] == j].iloc[0] dist = dist_matrix[i, j] # Check if intermolecular distance is within the range if frag_dist_range[0] < dist < frag_dist_range[1]: combined_smiles = f"{row1['X1']}.{row2['X1']}" combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol']) complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}" if 'ligand_conf_path' in row1 and 'ligand_conf_path' in row2: fragment_path = [str(row1['ligand_conf_path']), str(row2['ligand_conf_path'])] protein_fragment_combos.append( (complex_name, protein, protein_path, combined_smiles, fragment_path, combined_mol, dist) ) log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_combos)}.') fragment_combos.extend(protein_fragment_combos) # Save intermediate docking results if args.save_docking: docking_summary_df = pd.concat(docking_summaries, ignore_index=True) docking_summary_df.to_csv(Path(args.out_dir, 'docking_summary.csv'), index=False) log.info(f'Saved intermediate docking results to {args.out_dir}') # Convert fragment pair results to DataFrame if fragment_combos: linking_df = pd.DataFrame( fragment_combos, columns=['name', 'X2', 'protein_path', 'X1', 'fragment_path', 'fragment_mol', 'distance'] ) if linking_df['fragment_path'].isnull().all(): linking_df.drop(columns=['fragment_path'], inplace=True) linking_df.drop(columns=['fragment_mol']).to_csv(Path(args.out_dir, 'linking_summary.csv'), index=False) return linking_df else: raise ValueError('No eligible fragment pose pairs found for linking.') def extract_pockets(protein_path, ligand_residue=None, top_pockets=None): protein_path = Path(protein_path) if ligand_residue: 
top_pockets = 1 # Copy the protein file to a temporary directory to avoid overwriting pocket files in different runs tmp_dir = tempfile.mkdtemp() tmp_protein_path = Path(tmp_dir) / protein_path.name shutil.copy(protein_path, tmp_protein_path) # Run fpocket distance = 2.5 min_size = 30 args = ['./fpocket', '-d', '-f', tmp_protein_path, '-D', str(distance), '-i', str(min_size)] if ligand_residue is not None: args += ['-r', ligand_residue] print(args) subprocess.run(args, stdout=subprocess.DEVNULL) fpocket_out_path = Path(str(tmp_protein_path.with_suffix('')) + '_out') if not fpocket_out_path.is_dir(): raise ValueError(f"fpocket output directory not found: {fpocket_out_path}") pocket_alpha_sphere_path_dict = {} if top_pockets is not None: pocket_names = [f'pocket{i}' for i in range(1, top_pockets + 1)] for name in pocket_names: pocket_path = Path(fpocket_out_path, f'{name}_vert.pqr').resolve() if pocket_path.is_file(): pocket_alpha_sphere_path_dict[name] = str(pocket_path) else: # use fpocket_out_path.glob('*_vert.pqr') pocket_alpha_sphere_path_dict = { pocket_path.stem.split('_')[0]: str(pocket_path) for pocket_path in fpocket_out_path.glob('*_vert.pqr') } return pocket_alpha_sphere_path_dict def check_pocket_overlap(mol, pocket_as): mol_coords = [ mol.GetConformer().GetAtomPosition(i) for i in range(mol.GetNumAtoms()) ] for as_coords, as_radii in zip(pocket_as['coord'], pocket_as['radii']): for atom_coord in mol_coords: if np.linalg.norm(as_coords - atom_coord) < as_radii: return True return False def deduplicate_conformers(fragment_df, rmsd_threshold=1.5): if len(fragment_df) > 1: mol_list = fragment_df['ligand_mol'].tolist() indices_to_drop = set() for i, mol1 in enumerate(mol_list): if i in indices_to_drop: # Skip already marked duplicates continue for j, mol2 in enumerate(mol_list): if i < j and j not in indices_to_drop: # Not comparing already removed molecules rmsd = Chem.rdMolAlign.CalcRMS(mol1, mol2) if rmsd < rmsd_threshold: 
indices_to_drop.add(fragment_df.index[j]) # Mark duplicate for removal fragment_df.drop(indices_to_drop, inplace=True) return fragment_df def select_fragment_pairs( df, pocket_path_dict=None, top_pockets=3, frag_dist_range=(2, 5), # Distance range for fragment linking confidence_threshold=-1.5, rmsd_threshold=1.5, method='fpocket', out_dir=Path('.'), ligand_residue=None, ): df = df[df['confidence'] > confidence_threshold].copy() if 'ligand_conf_path' in df.columns: df['ligand_conf_path'] = df['ligand_conf_path'].apply(Path) if 'ligand_mol' not in df.columns: df['ligand_mol'] = df['ligand_conf_path'].apply(read_molecule) # Given pocket_path_dict for single protein case if pocket_path_dict is not None: pocket_names = list(pocket_path_dict.keys()) top_pockets = len(pocket_names) else: pocket_names = [f'pocket{i}' for i in range(1, top_pockets + 1)] # Add pocket columns to DataFrame for name in pocket_names: df[name] = False fragment_conf_pairs = [] for protein_path, protein_df in df.groupby('protein_path'): protein_path = Path(protein_path) protein_fragment_conf_pairs = [] fragment_path = None protein_id = protein_df['ID2'].iloc[0] match method: case 'fpocket': # TODO: avoid reruning fpocket when proper job management is implemented if pocket_path_dict is None: pocket_path_dict = extract_pockets(protein_path, ligand_residue, top_pockets) # Read pocket PQRs for name in pocket_names: pocket_as = read_pocket_alpha_spheres(pocket_path_dict[name]) # Check if any atom in a fragment conformer falls within pocket volume of alpha spheres protein_df[name] = protein_df['ligand_mol'].parallel_apply( check_pocket_overlap, pocket_as=pocket_as ) case 'clustering': # Clustering-based pocket finding pass # Filter out fragment conformers that do not overlap with any pocket protein_df = protein_df[protein_df[pocket_names].any(axis=1)] # Select fragment conformer pairs for linking per pocket based on distance range for name in pocket_names: pocket_df = protein_df[protein_df[name] == 
True].copy() if len(pocket_df) > 1: # pocket_path = pocket_alpha_sphere_path_dict[name] # Deduplicate similar conformers with RDKit Chem.rdMolAlign.CalcRMS pocket_df = pocket_df.groupby('X1', group_keys=False).parallel_apply( deduplicate_conformers, rmsd_threshold=rmsd_threshold ).reset_index(drop=True) pairs = list(itertools.combinations(pocket_df.index, 2)) dist_matrix = np.stack( pocket_df['ligand_mol'].parallel_apply( lambda mol1: [ calculate_mol_atomic_distances(mol1, mol2, distance_type='min') for mol2 in pocket_df['ligand_mol'] ] ) ) for i, j in pairs: dist = dist_matrix[i, j] if frag_dist_range[0] < dist < frag_dist_range[1]: row1 = pocket_df.loc[i] row2 = pocket_df.loc[j] combined_smiles = f"{row1['X1']}.{row2['X1']}" combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol']) complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}" if 'ligand_conf_path' in row1 and 'ligand_conf_path' in row2: fragment_path = [row1['ligand_conf_path'].name, row2['ligand_conf_path'].name] protein_fragment_conf_pairs.append( (complex_name, protein_path, # pocket_path, combined_smiles, fragment_path, combined_mol, dist) ) log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_conf_pairs)}.') fragment_conf_pairs.extend(protein_fragment_conf_pairs) # Convert fragment pair results to DataFrame if fragment_conf_pairs: linking_df = pd.DataFrame( fragment_conf_pairs, columns=[ 'name', 'protein_path', # 'pocket_path', 'X1', 'fragment_path', 'fragment_mol', 'distance' ] ) if linking_df['fragment_path'].isnull().all(): linking_df.drop(columns=['fragment_path'], inplace=True) linking_df.drop(columns=['fragment_mol']).to_csv(Path(out_dir, 'linking_summary.csv'), index=False) return linking_df else: return None def process_linking_results(): pass def get_pocket(mol, pdb_path, backbone_atoms_only=False): struct = PDBParser().get_structure('', pdb_path) residue_ids = [] atom_coords = [] for residue in struct.get_residues(): resid = residue.get_id()[1] for 
atom in residue.get_atoms(): atom_coords.append(atom.get_coord()) residue_ids.append(resid) residue_ids = np.array(residue_ids) atom_coords = np.array(atom_coords) mol_atom_coords = mol.GetConformer().GetPositions() distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1) contact_residues = np.unique(residue_ids[np.where(distances.min(1) <= 6)[0]]) pocket_coords = [] pocket_types = [] for residue in struct.get_residues(): resid = residue.get_id()[1] if resid not in contact_residues: continue for atom in residue.get_atoms(): atom_name = atom.get_name() atom_type = atom.element.upper() atom_coord = atom.get_coord() if backbone_atoms_only and atom_name not in {'N', 'CA', 'C', 'O'}: continue pocket_coords.append(atom_coord.tolist()) pocket_types.append(atom_type) pocket_pos = [] pocket_one_hot = [] pocket_charges = [] for coord, atom_type in zip(pocket_coords, pocket_types): if atom_type not in const.GEOM_ATOM2IDX.keys(): continue pocket_pos.append(coord) pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX)) pocket_charges.append(const.GEOM_CHARGES[atom_type]) pocket_pos = np.array(pocket_pos) pocket_one_hot = np.array(pocket_one_hot) pocket_charges = np.array(pocket_charges) return pocket_pos, pocket_one_hot, pocket_charges def read_pocket(path, backbone_atoms_only): pocket_coords = [] pocket_types = [] struct = PDBParser().get_structure('', path) for residue in struct.get_residues(): for atom in residue.get_atoms(): atom_name = atom.get_name() atom_type = atom.element.upper() atom_coord = atom.get_coord() if backbone_atoms_only and atom_name not in {'N', 'CA', 'C', 'O'}: continue pocket_coords.append(atom_coord.tolist()) pocket_types.append(atom_type) return { 'coord': np.array(pocket_coords), 'types': np.array(pocket_types), } def read_pocket_alpha_spheres(path): ag = parsePQR(path) as_coords = [] as_radii = [] for atom in ag: as_coords.append(atom.getCoords()) as_radii.append(atom.getRadius()) return { 'coord': 
np.array(as_coords), 'radii': np.array(as_radii), } def generate_linkers( df, backbone_atoms_only, output_dir, n_samples, n_steps, linker_size, anchors, max_batch_size, random_seed, robust, linker_ckpt, size_ckpt, linker_condition, device, ): # Model setup pocket_conditioned = linker_condition in ['protein', 'pocket'] if 'X2' in df.columns and pocket_conditioned: if backbone_atoms_only: linker_ckpt = linker_ckpt['pocket_bb'] else: linker_ckpt = linker_ckpt['pocket_full'] else: linker_ckpt = linker_ckpt['geom'] ddpm = DDPM.load_from_checkpoint( linker_ckpt, robust=robust, torch_device=device, map_location=device ).eval().to(device) is_geom = ddpm.is_geom if random_seed is not None: set_deterministic(random_seed) output_dir = Path(output_dir, 'linking') output_dir.mkdir(exist_ok=True, parents=True) linker_size = str(linker_size) if linker_size == '0': log.info(f'Will generate linkers with sampled numbers of atoms') size_classifier = SizeClassifier.load_from_checkpoint(size_ckpt, map_location=device).eval().to(device) def sample_fn(_data): # TODO Improve efficiency: do not repeat sampling for the same fragment(-pocket) samples out, _ = size_classifier.forward( _data, return_loss=False, with_pocket=pocket_conditioned, adjust_shape=True ) probabilities = torch.softmax(out, dim=1) distribution = torch.distributions.Categorical(probs=probabilities) samples = distribution.sample() sizes = [] for label in samples.detach().cpu().numpy(): sizes.append(size_classifier.linker_id2size[label]) sizes = torch.tensor(sizes, device=samples.device, dtype=const.TORCH_INT) return sizes elif linker_size.isdigit(): log.info(f'Will generate linkers with {linker_size} atoms') linker_size = int(linker_size) def sample_fn(_data): return torch.ones(_data['positions'].shape[0], device=device, dtype=const.TORCH_INT) * linker_size else: boundaries = [x.strip() for x in linker_size.split(',')] if len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit(): left = 
int(boundaries[0]) right = int(boundaries[1]) log.info(f'Will generate linkers with numbers of atoms sampled from U({left}, {right})') def sample_fn(_data): shape = len(_data['positions']), return torch.randint(left, right + 1, shape, device=device, dtype=const.TORCH_INT) if n_steps is not None: ddpm.edm.T = n_steps if ddpm.center_of_mass == 'anchors' and anchors is None: log.warning( "Using a anchor-conditioned DiffLinker checkpoint without providing anchors. " "Forcing model's `center_of_mass` to 'fragments'." ) ddpm.center_of_mass = 'fragments' # # Apply the mapping to fill NaN values in ID1 and ID2 # if 'ID1' not in df.columns: # df['ID1'] = None # if 'ID2' not in df.columns: # df['ID2'] = None # df.loc[df['ID1'].isna(), 'ID1'] = df.loc[df['ID1'].isna(), 'X1'].apply( # lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x)) # ) # df.loc[df['ID2'].isna(), 'ID2'] = df.loc[df['ID2'].isna(), 'X2'].map({ # x2_value: f"protein_{i}" # for i, x2_value in enumerate(df.loc[df['ID2'].isna(), 'X2'].unique()) # }) # # Identify duplicates based on 'X1' and 'X2' # duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)] # if not duplicates.empty: # print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['X1', 'X2']]) # print("Keeping the first occurrence of each duplicate.") # df = df.drop_duplicates(subset=['X1', 'X2']) # Dataset setup if 'fragment_path' not in df.columns: df['fragment_path'] = df['X1'] if 'fragment_mol' not in df.columns: df['fragment_mol'] = df['fragment_path'].parallel_apply(read_molecule, remove_hs=True, remove_confs=False) if 'protein_path' not in df.columns: df['protein_path'] = df['X2'] if 'name' not in df.columns and 'ID1' in df.columns and 'ID2' in df.columns: df['name'] = df['ID1'] + '-' + df['ID2'] df.dropna(subset=['fragment_mol', 'protein_path'], inplace=True) cached_parse_molecule = cache(parse_molecule) dataset = [] optional_keys = ['X2', 'protein_path'] for row in df.itertuples(): mol = row.fragment_mol # Hs already removed # 
Parsing fragments data frag_pos, frag_one_hot, frag_charges = cached_parse_molecule(mol, is_geom=is_geom) # Parsing pocket data if pocket_conditioned: if linker_condition == 'protein': pocket_pos, pocket_one_hot, pocket_charges = get_pocket(mol, row.protein_path, backbone_atoms_only) elif linker_condition == 'pocket': pocket_data = read_pocket(row.protein_path, backbone_atoms_only) pocket_pos = pocket_data['coord'] pocket_one_hot = [] pocket_charges = [] for atom_type in pocket_data['types']: pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX)) pocket_charges.append(const.GEOM_CHARGES[atom_type]) pocket_one_hot = np.array(pocket_one_hot) pocket_charges = np.array(pocket_charges) positions = np.concatenate([frag_pos, pocket_pos], axis=0) one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0) charges = np.concatenate([frag_charges, pocket_charges], axis=0) fragment_only_mask = np.concatenate([ np.ones_like(frag_charges), np.zeros_like(pocket_charges), ]) pocket_mask = np.concatenate([ np.zeros_like(frag_charges), np.ones_like(pocket_charges), ]) linker_mask = np.concatenate([ np.zeros_like(frag_charges), np.zeros_like(pocket_charges), ]) fragment_mask = np.concatenate([ np.ones_like(frag_charges), np.ones_like(pocket_charges), ]) else: positions = frag_pos one_hot = frag_one_hot charges = frag_charges fragment_only_mask = np.ones_like(charges) pocket_mask = np.zeros_like(charges) linker_mask = np.zeros_like(charges) fragment_mask = np.ones_like(charges) anchor_flags = np.zeros_like(charges) if anchors is not None: for anchor in anchors.split(','): anchor_flags[int(anchor.strip()) - 1] = 1 data = { 'name': row.name, 'X1': row.X1, 'fragment_path': row.fragment_path, 'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device), 'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device), 'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device), 'anchors': torch.tensor(anchor_flags, 
dtype=const.TORCH_FLOAT, device=device), 'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device), 'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device), 'num_atoms': len(positions) } for k in optional_keys: if hasattr(row, k): data[k] = getattr(row, k) if pocket_conditioned: data |= { 'X2': row.X2, 'protein_path': row.protein_path, 'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device), 'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device), } dataset.extend([data] * n_samples) ddpm.val_dataset = dataset global_batch_size = min(n_samples, max_batch_size) log.info(f'DiffLinker global batch size: {global_batch_size}') dataloader = get_dataloader( dataset, batch_size=global_batch_size, collate_fn=collate_with_fragment_without_pocket_edges if pocket_conditioned else collate_with_fragment_edges ) # df.drop(columns=['ligand_mol', 'protein_path'], inplace=True) linking_dfs = [] # Sampling print('Sampling...') # TODO: update linking_summary.csv per batch for batch_i, data in tqdm(enumerate(dataloader), total=len(dataloader)): effective_batch_size = len(data['positions']) batch_data = { 'name': data['name'], 'X1': data['X1'], 'fragment_path': data['fragment_path'], } for k in optional_keys: if k in data: batch_data[k] = data[k] if pocket_conditioned: batch_data |= { 'X2': data['X2'], 'protein_path': data['protein_path'], } batch_df = pd.DataFrame(batch_data) chain = None node_mask = None for i in range(5): try: chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1) break except FoundNaNException: continue if chain is None: log.warning(f'Could not generate linker for batch {batch_i} in 5 attempts') continue x = chain[0][:, :, :ddpm.n_dims] h = chain[0][:, :, ddpm.n_dims:] # Put the molecule back to the initial orientation if ddpm.center_of_mass == 'fragments': if pocket_conditioned: com_mask = data['fragment_only_mask'] else: 
com_mask = data['fragment_mask'] else: com_mask = data['anchors'] pos_masked = data['positions'] * com_mask N = com_mask.sum(1, keepdims=True) mean = torch.sum(pos_masked, dim=1, keepdim=True) / N x = x + mean * node_mask if pocket_conditioned: node_mask[torch.where(data['pocket_mask'])] = 0 batch_df['one_hot'] = list(h.cpu()) batch_df['positions'] = list(x.cpu()) batch_df['node_mask'] = list(node_mask.cpu()) linking_dfs.append(batch_df) # for i in range(effective_batch_size): # # # Save XYZ file and generate SMILES # # out_xyz = Path(output_dir, f'{name}_{offset_idx+i}.xyz') # # smiles = save_xyz_files(out_xyz, h[i], x[i], node_mask[i], is_geom=is_geom) # # # Convert XYZ to SDF # # out_sdf = Path(output_dir, name, f'output_{offset_idx+i}.sdf') # # with open(os.devnull, 'w') as devnull: # # subprocess.run(f'obabel {out_xyz} -O {out_sdf} -q', shell=True, stdout=devnull) # # Save SDF file and generate SMILES # out_sdf = Path(output_dir, f'{data["name"][i]}.sdf') # smiles = save_sdf(out_sdf, h[i], x[i], node_mask[i], is_geom=is_geom) # # # Add experiment summary info # batch_df['X1^'] = smiles # batch_df['out_path'] = str(out_sdf) # linking_dfs.append(batch_df) # Teardown del ddpm torch.cuda.empty_cache() if linking_dfs: linking_summary_df = pd.concat(linking_dfs, ignore_index=True) linking_summary_df['out_path'] = linking_summary_df.groupby('name').cumcount().apply( lambda x: f"{x:0{len(str(linking_summary_df.groupby('name').cumcount().max()))}d}" ).radd(linking_summary_df['name'] + '_') + '.sdf' linking_summary_df['X1^'] = linking_summary_df.parallel_apply( # parallel_apply bug lambda x: save_sdf( output_dir / x['out_path'], x['one_hot'], x['positions'], x['node_mask'], is_geom=is_geom ), axis=1 ) # TODO add 'pocket_path' and 'distance' linking_summary_df[ ['name', 'protein_path', 'fragment_path', 'X1', 'X1^', 'out_path'] ].to_csv(Path(output_dir.parent, 'linking_summary.csv'), index=False) print(f'Saved experiment summary and generated molecules to {output_dir}') 
else: raise ValueError('No linkers generated.') if __name__ == "__main__": parser = ArgumentParser() # Fragment docking settings parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml') parser.add_argument('--protein_ligand_csv', type=str, default=None, help='Path to a .csv file specifying the input as described in the README. ' 'If this is not None, it will be used instead of the `X1` and `X2` parameters') parser.add_argument('-n', '--name', type=str, default=None, help='Name that the experiment will be saved with') parser.add_argument('--X1', type=str, help='Either a SMILES string or the path of a molecule file that rdkit can read') parser.add_argument('--X2', type=str, help='Either a FASTA sequence or the path of a protein for ESMFold') parser.add_argument('-l', '--log', '--loglevel', type=str, default='INFO', dest="loglevel", help='Log level. Default %(default)s') parser.add_argument('--out_dir', type=str, default='results/', help='Directory where the outputs will be written to') parser.add_argument('--save_docking', action='store_true', default=True, help='Save the intermediate docking results including SDF files and a summary CSV.') parser.add_argument('--save_visualisation', action='store_true', default=False, help='Save a pdb file with all of the steps of the reverse diffusion') parser.add_argument('--samples_per_complex', type=int, default=10, help='Number of samples to generate') # parser.add_argument('--model_dir', type=str, default=None, # help='Path to folder with trained score model and hyperparameters') parser.add_argument('--score_ckpt', type=str, default='best_ema_inference_epoch_model.pt', help='Checkpoint to use for the score model') # parser.add_argument('--confidence_model_dir', type=str, default=None, # help='Path to folder with trained confidence model and hyperparameters') parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', help='Checkpoint to use for the confidence model') 
parser.add_argument('--n_poses', type=int, default=10, help='') parser.add_argument('--no_final_step_noise', action='store_true', default=True, help='Use no noise in the final step of the reverse diffusion') parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps') parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, help='Initial noise std proportion') parser.add_argument('--choose_residue', action='store_true', default=False, help='') parser.add_argument('--temp_sampling_tr', type=float, default=1.0) parser.add_argument('--temp_psi_tr', type=float, default=0.0) parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5) parser.add_argument('--temp_sampling_rot', type=float, default=1.0) parser.add_argument('--temp_psi_rot', type=float, default=0.0) parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5) parser.add_argument('--temp_sampling_tor', type=float, default=1.0) parser.add_argument('--temp_psi_tor', type=float, default=0.0) parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5) parser.add_argument('--gnina_minimize', action='store_true', default=False, help='') parser.add_argument('--gnina_path', type=str, default='gnina', help='') parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', help='') # To redirect gnina subprocesses stdouts from the terminal window parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='') parser.add_argument('--gnina_autobox_add', type=float, default=4.0) parser.add_argument('--gnina_poses_to_optimize', type=int, default=1) # Linker generation settings # parser.add_argument('--fragments', action='store', type=str, required=True, # help='Path to the file with input fragments' # ) # parser.add_argument( # '--protein', action='store', type=str, required=True, # help='Path to the file with the target protein' # ) parser.add_argument( '--backbone_atoms_only', action='store_true', 
required=False, default=False, help='Flag if to use only protein backbone atoms' ) parser.add_argument( '--linker_ckpt', action='store', type=str, help='Path to the DiffLinker model' ) parser.add_argument( '--linker_size', action='store', type=str, default='0', help='Linker size (int) or allowed size boundaries (comma-separated) or path to the size prediction model' ) parser.add_argument( '--n_linkers', action='store', type=int, required=False, default=5, help='Number of linkers to generate' ) parser.add_argument( '--linker_steps', action='store', type=int, required=False, default=1000, help='Number of denoising steps' ) parser.add_argument( '--anchors', action='store', type=str, required=False, default=None, help='Comma-separated indices of anchor atoms ' '(according to the order of atoms in the input fragments file, enumeration starts with 1)' ) parser.add_argument( '--linker_batch_size', action='store', type=int, required=False, help='Max batch size for linker generation model' ) parser.add_argument( '--docking_batch_size', action='store', type=int, required=False, help='Max batch size for fragment docking model' ) parser.add_argument( '--random_seed', action='store', type=int, required=False, default=None, help='Random seed' ) parser.add_argument( '--robust', action='store_true', required=False, default=False, help='Robust sampling modification' ) parser.add_argument( '--dock', action='store_true', default=False, help='Fragment docking with DiffDock' ) parser.add_argument( '--link', action='store_true', default=False, help='Linker generation with DiffLinker' ) args = parser.parse_args() if args.config: config_dict = yaml.load(args.config, Loader=yaml.FullLoader) arg_dict = args.__dict__ for key, value in config_dict.items(): # if isinstance(value, list): # for v in value: # arg_dict[key].append(v) # else: arg_dict[key] = value device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 
experiment_name = f"{date_time}_{args.name}" args.out_dir = Path(args.out_dir, experiment_name) args.out_dir.mkdir(exist_ok=True, parents=True) configure_logger(args.loglevel, logfile=args.out_dir / 'inference.log') log = get_logger() log.info(f"DiffFBDD will run on {device}") docking_df = None linking_df = None if args.dock: docking_df = dock_fragments( protein_ligand_csv=args.protein_ligand_csv, fragment_library=args.X1, protein_library=args.X2, out_dir=args.out_dir, score_ckpt=args.score_ckpt, confidence_ckpt=args.confidence_ckpt, inference_steps=args.inference_steps, n_poses=args.n_poses, docking_batch_size=args.docking_batch_size, initial_noise_std_proportion=args.initial_noise_std_proportion, no_final_step_noise=args.no_final_step_noise, temp_sampling_tr=args.temp_sampling_tr, temp_sampling_rot=args.temp_sampling_rot, temp_sampling_tor=args.temp_sampling_tor, temp_psi_tr=args.temp_psi_tr, temp_psi_rot=args.temp_psi_rot, temp_psi_tor=args.temp_psi_tor, temp_sigma_data_tr=args.temp_sigma_data_tr, temp_sigma_data_rot=args.temp_sigma_data_rot, temp_sigma_data_tor=args.temp_sigma_data_tor, save_docking=args.save_docking, device=device, ) # linking_df = process_docking_results( # docking_df, # eps=args.eps, min_samples=args.min_samples, # frag_dist_range=args.frag_dist_range, distance_type=args.distance_type # ) else: df = pd.read_csv(args.protein_ligand_csv) if 'ligand_conf_path' in df.columns: docking_df = df else: linking_df = df if args.link: if docking_df is not None and linking_df is None: linking_df = select_fragment_pairs( docking_df, top_pockets=args.top_pockets, frag_dist_range=args.frag_dist_range, confidence_threshold=args.confidence_threshold, rmsd_threshold=args.rmsd_threshold, out_dir=args.out_dir, ) if linking_df is None or len(linking_df) == 0: log.error('No eligible fragment pose pairs found for linking.') sys.exit() generate_linkers( linking_df, backbone_atoms_only=args.backbone_atoms_only, output_dir=args.out_dir, n_samples=args.n_linkers, 
n_steps=args.linker_steps, linker_size=args.linker_size, anchors=args.anchors, max_batch_size=args.linker_batch_size, random_seed=args.random_seed, robust=args.robust, linker_ckpt=args.linker_ckpt, size_ckpt=args.size_ckpt, linker_condition=args.linker_condition, device=device, )