import copy
import random

import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch
from utils.logging_utils import get_logger
from utils.torsion import modify_conformer_torsion_angles
from utils.utils import crop_beyond


def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
                       initial_noise_std_proportion=-1.0, choose_residue=False):
    """Randomize the starting ligand pose of every complex in ``data_list`` (modified in place).

    For each complex graph the ligand optionally receives random torsion angles, is rotated by a
    uniformly random rotation about its own centroid, is moved onto a pocket centre and is then
    optionally shifted by a random translation.

    Args:
        data_list: complex graphs whose ``['ligand'].pos`` is overwritten.
        no_torsion: if False, every bond flagged in ``['ligand'].edge_mask`` gets a torsion update
            drawn uniformly from [-pi, pi).
        no_random: if True, skip the final random translation.
        tr_sigma_max: translation noise scale used when ``initial_noise_std_proportion < 0``.
        pocket_knowledge: if True, centre on the mean of the receptor nodes lying within
            ``pocket_cutoff`` of the first complex's ``['ligand'].orig_pos[0]`` (shifted by its
            ``original_center``); if no node is that close, use the single closest receptor node.
            Otherwise centre on the mean receptor position of the first complex.
        pocket_cutoff: distance threshold used with ``pocket_knowledge``.
        initial_noise_std_proportion: if >= 0, the translation std is this fraction of the
            receptor's RMS position norm divided by 1.73 (~sqrt(3), presumably turning a 3-D norm
            into a per-axis std); if < 0, the std is ``-initial_noise_std_proportion * tr_sigma_max``.
        choose_residue: if True, the translation is instead drawn around a uniformly chosen
            receptor node with std 0.01.
    """
    center_pocket = data_list[0]['receptor'].pos.mean(dim=0)
    if pocket_knowledge:
        reference = data_list[0]
        reference_ligand = torch.from_numpy(reference['ligand'].orig_pos[0]).float() - reference.original_center
        d = torch.cdist(reference['receptor'].pos, reference_ligand)
        label = torch.any(d < pocket_cutoff, dim=1)

        if torch.any(label):
            center_pocket = reference['receptor'].pos[label].mean(dim=0)
        else:
            get_logger().warning(f"No pocket residue below minimum distance {pocket_cutoff}, "
                                 f"taking closest at {torch.min(d).item()}")
            center_pocket = reference['receptor'].pos[torch.argmin(torch.min(d, dim=1)[0])]

    if not no_torsion:
        # randomize torsion angles
        for complex_graph in data_list:
            torsion_updates = np.random.uniform(low=-np.pi, high=np.pi,
                                                size=int(complex_graph['ligand'].edge_mask.sum()))
            complex_graph['ligand'].pos = modify_conformer_torsion_angles(
                complex_graph['ligand'].pos,
                complex_graph['ligand', 'ligand'].edge_index.T[complex_graph['ligand'].edge_mask],
                complex_graph['ligand'].mask_rotate[0],
                torsion_updates,
            )

    for complex_graph in data_list:
        # random rotation about the ligand centroid, then placement on the pocket centre
        molecule_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        random_rotation = torch.from_numpy(R.random().as_matrix()).float()
        complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T + center_pocket

        if no_random:
            continue
        # note: torsion angles are still randomised even when no_random is set
        if choose_residue:
            idx = random.randint(0, len(complex_graph['receptor'].pos) - 1)
            # NOTE(review): the residue position is added to a pose already centred on center_pocket;
            # this only centres the ligand on the residue when center_pocket is ~0 — confirm intended.
            tr_update = torch.normal(mean=complex_graph['receptor'].pos[idx:idx + 1], std=0.01)
        elif initial_noise_std_proportion >= 0.0:
            std_rec = torch.sqrt(torch.mean(torch.sum(complex_graph['receptor'].pos ** 2, dim=1)))
            tr_update = torch.normal(mean=0, std=std_rec * initial_noise_std_proportion / 1.73, size=(1, 3))
        else:
            # negative proportion: use tr_sigma_max scaled by -initial_noise_std_proportion
            tr_update = torch.normal(mean=0, std=-initial_noise_std_proportion * tr_sigma_max, size=(1, 3))
        complex_graph['ligand'].pos += tr_update


def is_iterable(arr):
    """Return True if ``iter(arr)`` succeeds (strings count as iterable, 0-d tensors/arrays do not)."""
    try:
        iter(arr)
    except TypeError:
        return False
    return True


def _per_component(value, name):
    """Broadcast ``value`` to one entry per (translation, rotation, torsion) component."""
    if not is_iterable(value):
        return [value] * 3
    if len(value) != 3:
        raise ValueError(f"{name} must be a scalar or have exactly 3 entries, got {len(value)}")
    return value


def _gaussian_noise(shape, device, zero):
    """Standard-normal noise of ``shape`` on ``device``, or zeros when ``zero`` is True."""
    if zero:
        return torch.zeros(shape, device=device)
    return torch.normal(mean=0, std=1, size=shape, device=device)


def _replace_non_finite_(score):
    """In place: replace NaN/+inf by +eps and -inf by -eps, eps = 1% of the mean finite magnitude."""
    finite = score[torch.isfinite(score)]
    # Only finite entries feed eps, so infs/NaNs cannot leak back into the replacement value.
    eps = 0.01 * finite.abs().mean().item() if finite.numel() > 0 else 0.0
    score.nan_to_num_(nan=eps, posinf=eps, neginf=-eps)


def _tempered_perturb(g, dt, sigma, score, z, temp, psi, sigma_data_frac, sigma_max, sigma_min):
    """Low-temperature SDE update for one component.

    ``sigma_data`` is log-interpolated between ``sigma_min`` and ``sigma_max`` with weight
    ``sigma_data_frac``; the score term is rescaled by
    ``lambda = (sigma_data + sigma) / (sigma_data + sigma / temp)`` and the noise by ``sqrt(1 + psi)``.
    """
    sigma_data = np.exp(sigma_data_frac * np.log(sigma_max) + (1 - sigma_data_frac) * np.log(sigma_min))
    lam = (sigma_data + sigma) / (sigma_data + sigma / temp)
    return g ** 2 * dt * (lam + temp * psi / 2) * score + g * np.sqrt(dt * (1 + psi)) * z


def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_schedule, device, t_to_sigma, model_args,
             no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None,
             confidence_model_args=None, t_schedule=None, batch_size=32, no_final_step_noise=False, pivot=None,
             return_full_trajectory=False, temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5,
             return_features=False):
    """Run reverse diffusion over ligand translation, rotation and torsions.

    The ligand positions of ``data_list`` are overwritten with the final poses.

    Args:
        data_list: complex graphs to denoise; must fit into a single batch (``batch_size >= len``).
            All ligands are assumed to have the same atom count — TODO confirm with callers.
        model: score model; the first three outputs are used as (tr, rot, tor) scores.
        inference_steps: number of reverse steps.
        tr_schedule, rot_schedule, tor_schedule: per-step diffusion times. The step size is the
            difference to the next entry, or the time itself on the last step.
        device: torch device the batch is moved to.
        t_to_sigma: maps ``(t_tr, t_rot, t_tor)`` to ``(tr_sigma, rot_sigma, tor_sigma)``.
        model_args: reads ``{tr,rot,tor}_sigma_{min,max}``, ``no_torsion`` and, if present,
            ``crop_beyond`` / ``all_atoms``.
        no_random: if True, no noise is injected (SDE drift only).
        ode: if True, use the probability-flow update (half drift, no noise).
        visualization_list: optional objects whose ``add(pos, part=..., order=...)`` receives poses.
        confidence_model, confidence_data_list, confidence_model_args: optional confidence scoring
            of the final poses.
        t_schedule: optional per-step value forwarded to ``set_time``.
        batch_size: loader batch size.
        no_final_step_noise: if True, the last SDE step injects no noise.
        pivot, return_full_trajectory, return_features: not implemented; must be falsy.
        temp_sampling, temp_psi, temp_sigma_data: scalars or 3-sequences (tr, rot, tor) controlling
            low-temperature sampling. A component is tempered when its ``temp_sampling != 1.0``.
            Only supported with ``ode=False``.

    Returns:
        ``(data_list, confidence)``. ``confidence`` is the concatenated confidence-model output
        (NaN -> -1000), or None without a confidence model.

    Raises:
        ValueError: if temperature sampling is requested with ``ode=True``, or a temperature
            argument is neither a scalar nor of length 3.
    """
    N = len(data_list)
    logger = get_logger()
    assert batch_size >= N, "Not implemented yet"
    assert not (return_full_trajectory or return_features or pivot), "Not implemented yet in new inference version"

    temp_sampling = _per_component(temp_sampling, 'temp_sampling')
    temp_psi = _per_component(temp_psi, 'temp_psi')
    temp_sigma_data = _per_component(temp_sigma_data, 'temp_sigma_data')
    if ode and any(temp != 1.0 for temp in temp_sampling):
        # The tempered update is an SDE step and needs the noise draws that ODE mode never makes.
        raise ValueError("temp_sampling != 1.0 is only supported with ode=False")

    loader = DataLoader(data_list, batch_size=batch_size)
    mask_rotate = torch.from_numpy(data_list[0]['ligand'].mask_rotate[0]).to(device)

    confidence = None
    confidence_loader = None
    if confidence_model is not None:
        confidence = []
        if confidence_data_list is not None:
            confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))

    # Diffusion coefficients g = sigma * sqrt(2 log(sigma_max / sigma_min)); the sqrt factor is step-invariant.
    tr_g_scale = torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
    rot_g_scale = torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
    tor_g_scale = None
    if not model_args.no_torsion:
        tor_g_scale = torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))

    with torch.no_grad():
        for batch_id, complex_graph_batch in enumerate(loader):
            b = complex_graph_batch.num_graphs
            n = len(complex_graph_batch['ligand'].pos) // b  # ligand atoms per graph
            complex_graph_batch = complex_graph_batch.to(device)

            for t_idx in range(inference_steps):
                is_last_step = t_idx == inference_steps - 1
                t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
                dt_tr = t_tr if is_last_step else t_tr - tr_schedule[t_idx + 1]
                dt_rot = t_rot if is_last_step else t_rot - rot_schedule[t_idx + 1]
                dt_tor = t_tor if is_last_step else t_tor - tor_schedule[t_idx + 1]

                tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)

                if getattr(model_args, 'crop_beyond', None) is not None:
                    cropped = copy.deepcopy(complex_graph_batch).to_data_list()
                    for graph in cropped:
                        crop_beyond(graph, tr_sigma * 3 + model_args.crop_beyond, model_args.all_atoms)
                    mod_complex_graph_batch = Batch.from_data_list(cropped)
                else:
                    mod_complex_graph_batch = complex_graph_batch

                set_time(mod_complex_graph_batch, t_schedule[t_idx] if t_schedule is not None else None,
                         t_tr, t_rot, t_tor, b, 'all_atoms' in model_args and model_args.all_atoms, device)

                tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]

                # Detect non-finite scores in every component, not only translation.
                sample_failed = ~torch.isfinite(tr_score).all(dim=-1) | ~torch.isfinite(rot_score).all(dim=-1)
                tor_failed = tor_score is not None and not bool(torch.isfinite(tor_score).all())
                num_failed = int(sample_failed.sum())
                if num_failed > 0 or tor_failed:
                    name = complex_graph_batch['name']
                    if isinstance(name, list):
                        name = name[0]
                    logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: "
                                   f"{num_failed} / {sample_failed.numel()} samples failed")
                    # Replace bad values by a small value: just disturb slightly and hope the next step is clean.
                    for score in (tr_score, rot_score, tor_score):
                        if score is not None:
                            _replace_non_finite_(score)

                tr_g = tr_sigma * tr_g_scale
                rot_g = rot_sigma * rot_g_scale
                zero_noise = no_random or (no_final_step_noise and is_last_step)

                if ode:
                    tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score)
                    rot_perturb = (0.5 * rot_score * dt_rot * rot_g ** 2)
                else:
                    # draw order (tr, rot, tor) is kept so seeded runs are reproducible
                    tr_z = _gaussian_noise((b, 3), device, zero_noise)
                    tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z)
                    rot_z = _gaussian_noise((b, 3), device, zero_noise)
                    rot_perturb = (rot_score * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z)

                tor_perturb = None
                if not model_args.no_torsion:
                    tor_g = tor_sigma * tor_g_scale
                    if ode:
                        tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score)
                    else:
                        tor_z = _gaussian_noise(tor_score.shape, device, zero_noise)
                        tor_perturb = (tor_g ** 2 * dt_tor * tor_score + tor_g * np.sqrt(dt_tor) * tor_z)

                # Low-temperature sampling (ode is False here, so the *_z draws above exist).
                if temp_sampling[0] != 1.0:
                    tr_perturb = _tempered_perturb(tr_g, dt_tr, tr_sigma, tr_score, tr_z, temp_sampling[0],
                                                   temp_psi[0], temp_sigma_data[0],
                                                   model_args.tr_sigma_max, model_args.tr_sigma_min)
                if temp_sampling[1] != 1.0:
                    rot_perturb = _tempered_perturb(rot_g, dt_rot, rot_sigma, rot_score, rot_z, temp_sampling[1],
                                                    temp_psi[1], temp_sigma_data[1],
                                                    model_args.rot_sigma_max, model_args.rot_sigma_min)
                if temp_sampling[2] != 1.0 and not model_args.no_torsion:
                    tor_perturb = _tempered_perturb(tor_g, dt_tor, tor_sigma, tor_score, tor_z, temp_sampling[2],
                                                    temp_psi[2], temp_sigma_data[2],
                                                    model_args.tor_sigma_max, model_args.tor_sigma_min)

                # Apply the update (tor_perturb is None when torsions are disabled)
                complex_graph_batch['ligand'].pos = modify_conformer_batch(
                    complex_graph_batch['ligand'].pos, complex_graph_batch,
                    tr_perturb, rot_perturb, tor_perturb, mask_rotate)

                if visualization_list is not None:
                    for idx_b in range(b):
                        sample = batch_id * batch_size + idx_b
                        visualization_list[sample].add((
                            complex_graph_batch['ligand'].pos[idx_b * n:n * (idx_b + 1)].detach().cpu()
                            + data_list[sample].original_center.detach().cpu()),
                            part=1, order=t_idx + 2)

            for i in range(b):
                data_list[batch_id * batch_size + i]['ligand'].pos = complex_graph_batch['ligand'].pos[i * n:n * (i + 1)]

            if visualization_list is not None:
                for idx, visualization in enumerate(visualization_list):
                    visualization.add((data_list[idx]['ligand'].pos.detach().cpu()
                                       + data_list[idx].original_center.detach().cpu()),
                                      part=1, order=2)

            if confidence_model is not None:
                if confidence_loader is not None:
                    confidence_batch = next(confidence_loader)
                    confidence_batch['ligand'].pos = complex_graph_batch['ligand'].pos.cpu()
                    if getattr(confidence_model_args, 'crop_beyond', None) is not None:
                        graphs = confidence_batch.to_data_list()
                        for graph in graphs:
                            crop_beyond(graph, confidence_model_args.crop_beyond, confidence_model_args.all_atoms)
                        confidence_batch = Batch.from_data_list(graphs)
                    confidence_batch = confidence_batch.to(device)
                    set_time(confidence_batch, 0, 0, 0, 0, b, confidence_model_args.all_atoms, device)
                    out = confidence_model(confidence_batch)
                else:
                    out = confidence_model(complex_graph_batch)

                if isinstance(out, tuple):
                    out = out[0]
                confidence.append(out)

    if confidence_model is not None:
        confidence = torch.nan_to_num(torch.cat(confidence, dim=0), nan=-1000)
    return data_list, confidence


def compute_affinity(data_list, affinity_model, affinity_data_list, device, parallel, all_atoms,
                     include_miscellaneous_atoms):
    """Score the first ``parallel`` poses of ``data_list`` with ``affinity_model``.

    A copy of ``affinity_data_list[0]`` serves as template. Its ligand node features, edge mask,
    edges (offset per copy) and edge attributes are tiled ``parallel`` times. Its ligand positions
    are replaced by the stacked ligand positions of the first ``parallel`` entries of ``data_list``.
    All poses are then evaluated in a single forward pass.

    Returns:
        The second output of ``affinity_model``, or None when ``affinity_model`` is None.
    """
    if affinity_model is None:
        return None

    with torch.no_grad():
        assert parallel <= len(data_list)
        assert affinity_data_list is not None
        complex_graph_batch = next(iter(DataLoader(data_list, batch_size=parallel))).to(device)
        positions = complex_graph_batch['ligand'].pos

        # Copy so repeated calls don't keep re-tiling (and corrupting) the caller's template graph.
        complex_graph = copy.deepcopy(affinity_data_list[0])
        num_atoms = complex_graph['ligand'].num_nodes
        complex_graph['ligand'].x = complex_graph['ligand'].x.repeat(parallel, 1)
        complex_graph['ligand'].edge_mask = complex_graph['ligand'].edge_mask.repeat(parallel)
        complex_graph['ligand', 'ligand'].edge_index = torch.cat(
            [num_atoms * i + complex_graph['ligand', 'ligand'].edge_index for i in range(parallel)], dim=1)
        complex_graph['ligand', 'ligand'].edge_attr = complex_graph['ligand', 'ligand'].edge_attr.repeat(parallel, 1)
        complex_graph['ligand'].pos = positions

        affinity_batch = next(iter(DataLoader([complex_graph], batch_size=1))).to(device)
        set_time(affinity_batch, 0, 0, 0, 0, 1, all_atoms, device, include_miscellaneous_atoms=include_miscellaneous_atoms)
        _, affinity = affinity_model(affinity_batch)
    return affinity