|
""" |
|
Train a new model. |
|
""" |
|
|
|
import sys |
|
import argparse |
|
import h5py |
|
import datetime |
|
import subprocess as sp |
|
import numpy as np |
|
import pandas as pd |
|
import gzip as gz |
|
from tqdm import tqdm |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
from torch.autograd import Variable |
|
from torch.utils.data import IterableDataset, DataLoader |
|
from sklearn.metrics import average_precision_score as average_precision |
|
|
|
import dscript |
|
from dscript.utils import PairedDataset, collate_paired_sequences |
|
from dscript.models.embedding import ( |
|
IdentityEmbed, |
|
FullyConnectedEmbed, |
|
) |
|
from dscript.models.contact import ContactCNN |
|
from dscript.models.interaction import ModelInteraction |
|
|
|
|
|
def add_args(parser):
    """
    Create parser for command line utility.

    Registers the data, model-architecture, training, and output/device options
    on ``parser`` in six argument groups and returns the same parser object.

    :param parser: Parser to add the training options to
    :type parser: argparse.ArgumentParser
    :return: ``parser`` with all training arguments registered
    :rtype: argparse.ArgumentParser

    :meta private:
    """

    data_grp = parser.add_argument_group("Data")
    proj_grp = parser.add_argument_group("Projection Module")
    contact_grp = parser.add_argument_group("Contact Module")
    inter_grp = parser.add_argument_group("Interaction Module")
    train_grp = parser.add_argument_group("Training")
    misc_grp = parser.add_argument_group("Output and Device")

    # Data
    data_grp.add_argument("--train", help="Training data", required=True)
    data_grp.add_argument("--val", help="Validation data", required=True)
    data_grp.add_argument("--embedding", help="h5 file with embedded sequences", required=True)
    data_grp.add_argument(
        "--augment",
        action="store_true",
        help="Set flag to augment data by adding (B A) for all pairs (A B)",
    )

    # Projection module
    proj_grp.add_argument(
        "--projection-dim",
        type=int,
        default=100,
        help="Dimension of embedding projection layer (default: 100)",
    )
    proj_grp.add_argument(
        "--dropout-p",
        type=float,
        default=0.5,
        help="Parameter p for embedding dropout layer (default: 0.5)",
    )

    # Contact module
    contact_grp.add_argument(
        "--hidden-dim",
        type=int,
        default=50,
        help="Number of hidden units for comparison layer in contact prediction (default: 50)",
    )
    contact_grp.add_argument(
        "--kernel-width",
        type=int,
        default=7,
        help="Width of convolutional filter for contact prediction (default: 7)",
    )

    # Interaction module
    inter_grp.add_argument(
        "--use-w",
        action="store_true",
        help="Use weight matrix in interaction prediction model",
    )
    inter_grp.add_argument(
        "--pool-width",
        type=int,
        default=9,
        help="Size of max-pool in interaction model (default: 9)",
    )

    # Training
    # NOTE(review): --negative-ratio is accepted but never read by main() in this file.
    train_grp.add_argument(
        "--negative-ratio",
        type=int,
        default=10,
        help="Number of negative training samples for each positive training sample (default: 10)",
    )
    train_grp.add_argument(
        "--epoch-scale",
        type=int,
        default=1,
        help="Report heldout performance every this many epochs (default: 1)",
    )
    train_grp.add_argument("--num-epochs", type=int, default=10, help="Number of epochs (default: 10)")
    train_grp.add_argument("--batch-size", type=int, default=25, help="Minibatch size (default: 25)")
    train_grp.add_argument("--weight-decay", type=float, default=0, help="L2 regularization (default: 0)")
    train_grp.add_argument("--lr", type=float, default=0.001, help="Learning rate (default: 0.001)")
    # lambda multiplies the binary cross-entropy interaction loss; the contact
    # map magnitude loss receives 1 - lambda (see interaction_grad).
    train_grp.add_argument(
        "--lambda",
        dest="lambda_",
        type=float,
        default=0.35,
        help="Weight on the interaction (binary cross-entropy) objective; "
        "the contact map magnitude objective is weighted by 1 - lambda (default: 0.35)",
    )

    # Output and device
    misc_grp.add_argument("-o", "--outfile", help="Output file path (default: stdout)")
    misc_grp.add_argument("--save-prefix", help="Path prefix for saving models")
    misc_grp.add_argument("-d", "--device", type=int, default=-1, help="Compute device to use")
    misc_grp.add_argument("--checkpoint", help="Checkpoint model to start training from")

    return parser
|
|
|
|
|
def predict_interaction(model, n0, n1, tensors, use_cuda):
    """
    Predict the interaction probability of each protein pair ``(n0[i], n1[i])``.

    :param model: Model exposing ``predict(z_a, z_b)``
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names (indexed in lockstep with ``n0``)
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to move each embedding to the GPU before predicting
    :type use_cuda: bool
    :return: One prediction per pair, stacked along a new leading dimension
    :rtype: torch.Tensor
    """

    def _fetch(name):
        # Look up a single embedding, moving it to the current CUDA device if requested.
        emb = tensors[name]
        return emb.cuda() if use_cuda else emb

    predictions = [model.predict(_fetch(n0[idx]), _fetch(n1[idx])) for idx in range(len(n0))]
    return torch.stack(predictions, 0)
|
|
|
|
|
def predict_cmap_interaction(model, n0, n1, tensors, use_cuda):
    """
    Predict interaction probabilities and contact-map magnitudes for protein pairs.

    For each pair ``(n0[i], n1[i])`` the model's ``map_predict`` is called; the
    returned contact map is summarized by its mean value.

    :param model: Model exposing ``map_predict(z_a, z_b) -> (contact_map, p_hat)``
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names (indexed in lockstep with ``n0``)
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to move each embedding to the GPU before predicting
    :type use_cuda: bool
    :return: (mean contact-map value per pair, interaction prediction per pair)
    :rtype: (torch.Tensor, torch.Tensor)
    """
    probs, magnitudes = [], []
    for idx in range(len(n0)):
        first, second = tensors[n0[idx]], tensors[n1[idx]]
        if use_cuda:
            first, second = first.cuda(), second.cuda()

        contact_map, prob = model.map_predict(first, second)
        probs.append(prob)
        # Each contact map is reduced to a single scalar: its mean entry.
        magnitudes.append(torch.mean(contact_map))

    stacked_probs = torch.stack(probs, 0)
    stacked_magnitudes = torch.stack(magnitudes, 0)
    return stacked_magnitudes, stacked_probs
|
|
|
|
|
def interaction_grad(model, n0, n1, y, tensors, use_cuda, weight=0.35):
    """
    Compute the combined loss for a batch and backpropagate it.

    The loss is ``weight * BCE(p_hat, y) + (1 - weight) * mean(c_map_mag)``:
    ``weight`` multiplies the binary cross-entropy interaction objective and
    ``1 - weight`` multiplies the contact map magnitude objective.

    Gradients are accumulated into the model parameters via ``loss.backward()``;
    stepping the optimizer and zeroing gradients is left to the caller.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param y: Interaction labels
    :type y: torch.Tensor
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param weight: Weight on the BCE interaction objective. The contact map
        magnitude objective is weighted by ``1 - weight``.
    :type weight: float

    :return: (Loss, number correct at a 0.5 cutoff, mean square error, batch size)
    :rtype: (torch.Tensor, int, float, int)
    """
    c_map_mag, p_hat = predict_cmap_interaction(model, n0, n1, tensors, use_cuda)
    if use_cuda:
        y = y.cuda()

    bce_loss = F.binary_cross_entropy(p_hat.float(), y.float())
    cmap_loss = torch.mean(c_map_mag)
    loss = (weight * bce_loss) + ((1 - weight) * cmap_loss)
    b = len(p_hat)

    loss.backward()

    # Batch statistics are computed on a detached CPU copy of the predictions.
    with torch.no_grad():
        guess_cutoff = 0.5
        p_hat = p_hat.detach().cpu().float()
        y = y.cpu().float()
        # Elementwise comparison; avoids accidental broadcasting against a (b,) tensor.
        p_guess = (p_hat > guess_cutoff).float()
        correct = torch.sum(p_guess == y).item()
        mse = torch.mean((y - p_hat) ** 2).item()

    return loss, correct, mse, b
|
|
|
|
|
def interaction_eval(model, test_iterator, tensors, use_cuda):
    """
    Evaluate test data set performance.

    Predictions for every batch yielded by ``test_iterator`` are gathered and
    moved to CPU, so all metrics are computed on CPU regardless of ``use_cuda``.

    Accuracy (``correct``) uses a hard 0.5 cutoff. Precision and recall are
    *soft*, probability-weighted values: ``tp = sum(y * p_hat)``,
    ``precision = tp / sum(p_hat)``, ``recall = tp / sum(y)``. Ratios whose
    denominator is zero are reported as 0.0 instead of raising.

    :param model: Model to be evaluated
    :type model: dscript.models.interaction.ModelInteraction
    :param test_iterator: Test data iterator yielding ``(n0, n1, y)`` batches
    :type test_iterator: torch.utils.data.DataLoader
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool

    :return: (Loss, number correct, mean square error, precision, recall, F1 Score, AUPR)
    :rtype: (float, int, float, float, float, float, float)
    :raises ValueError: If ``test_iterator`` yields no batches
    """
    p_hat = []
    true_y = []

    for n0, n1, y in test_iterator:
        p_hat.append(predict_interaction(model, n0, n1, tensors, use_cuda))
        true_y.append(y)

    if not p_hat:
        raise ValueError("test_iterator yielded no batches to evaluate")

    # Predictions live on the GPU when use_cuda is set; bring everything to CPU once.
    y = torch.cat(true_y, 0).cpu().float()
    p_hat = torch.cat(p_hat, 0).detach().cpu().float()

    with torch.no_grad():
        loss = F.binary_cross_entropy(p_hat, y).item()

        guess_cutoff = 0.5
        p_guess = (p_hat > guess_cutoff).float()
        correct = torch.sum(p_guess == y).item()
        mse = torch.mean((y - p_hat) ** 2).item()

        tp = torch.sum(y * p_hat).item()
        predicted_pos = torch.sum(p_hat).item()
        actual_pos = torch.sum(y).item()
        pr = tp / predicted_pos if predicted_pos > 0 else 0.0
        re = tp / actual_pos if actual_pos > 0 else 0.0
        f1 = 2 * pr * re / (pr + re) if (pr + re) > 0 else 0.0

    aupr = average_precision(y.numpy(), p_hat.numpy())

    return loss, correct, mse, pr, re, f1, aupr
|
|
|
|
|
def _load_pairs(path, augment=False):
    """
    Read a headerless tab-separated pair file with columns (protein A, protein B, label).

    :param path: Path to the pair file
    :type path: str
    :param augment: If True, also include every pair in swapped order (B, A) with the same label
    :type augment: bool
    :return: (first protein names, second protein names, labels)
    :rtype: (pandas.Series, pandas.Series, torch.Tensor)
    """
    df = pd.read_csv(path, sep="\t", header=None)
    if augment:
        # ignore_index gives the doubled Series a fresh 0..2n-1 index; presumably
        # PairedDataset looks items up by index label, so duplicates would break it.
        n0 = pd.concat((df[0], df[1]), axis=0, ignore_index=True)
        n1 = pd.concat((df[1], df[0]), axis=0, ignore_index=True)
        y = torch.from_numpy(pd.concat((df[2], df[2]), ignore_index=True).values)
    else:
        n0, n1 = df[0], df[1]
        y = torch.from_numpy(df[2].values)
    return n0, n1, y


def _build_model(args, output):
    """
    Construct a fresh ModelInteraction from command line hyperparameters.

    Logs each component's settings, and the assembled model, to ``output``.
    """
    projection_dim = args.projection_dim
    dropout_p = args.dropout_p
    # NOTE(review): 6165 is a hard-coded input width; presumably it must equal the
    # feature dimension of the embeddings stored in the h5 file.
    embedding = FullyConnectedEmbed(6165, projection_dim, dropout=dropout_p)
    print("# Initializing embedding model with:", file=output)
    print(f"\tprojection_dim: {projection_dim}", file=output)
    print(f"\tdropout_p: {dropout_p}", file=output)

    hidden_dim = args.hidden_dim
    kernel_width = args.kernel_width
    print("# Initializing contact model with:", file=output)
    print(f"\thidden_dim: {hidden_dim}", file=output)
    print(f"\tkernel_width: {kernel_width}", file=output)
    contact = ContactCNN(projection_dim, hidden_dim, kernel_width)

    use_W = args.use_w
    pool_width = args.pool_width
    print("# Initializing interaction model with:", file=output)
    print(f"\tpool_width: {pool_width}", file=output)
    print(f"\tuse_w: {use_W}", file=output)
    model = ModelInteraction(embedding, contact, use_W=use_W, pool_size=pool_width)

    print(model, file=output)
    return model


def _save_model(model, save_path, use_cuda):
    """Serialize ``model`` from CPU to ``save_path``, then move it back to the GPU if in use."""
    model.cpu()
    torch.save(model, save_path)
    if use_cuda:
        model.cuda()


def _run_training(args, output):
    """
    Load data and embeddings, build or restore the model, and run the training loop.

    All progress messages are written to ``output``; the caller owns ``output``.
    """
    print(f'# Called as: {" ".join(sys.argv)}', file=output)
    if output is not sys.stdout:
        print(f'Called as: {" ".join(sys.argv)}')

    device = args.device
    use_cuda = (device >= 0) and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        print(
            f"# Using CUDA device {device} - {torch.cuda.get_device_name(device)}",
            file=output,
        )
    else:
        print("# Using CPU", file=output)

    batch_size = args.batch_size
    train_fi = args.train
    test_fi = args.val

    print(f"# Loading training pairs from {train_fi}...", file=output)
    output.flush()
    train_n0, train_n1, train_y = _load_pairs(train_fi, augment=args.augment)

    print(f"# Loading testing pairs from {test_fi}...", file=output)
    output.flush()
    test_n0, test_n1, test_y = _load_pairs(test_fi)
    output.flush()

    train_pairs = PairedDataset(train_n0, train_n1, train_y)
    pairs_train_iterator = torch.utils.data.DataLoader(
        train_pairs,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=True,
    )

    test_pairs = PairedDataset(test_n0, test_n1, test_y)
    # Held-out metrics are order-independent, so the validation set is not shuffled.
    pairs_test_iterator = torch.utils.data.DataLoader(
        test_pairs,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=False,
    )
    output.flush()

    print("# Loading embeddings", file=output)
    all_proteins = set(train_n0).union(train_n1, test_n0, test_n1)
    tensors = {}
    # Embeddings are copied into memory, so the h5 file can be closed right after.
    with h5py.File(args.embedding, "r") as h5fi:
        for prot_name in tqdm(all_proteins):
            tensors[prot_name] = torch.from_numpy(h5fi[prot_name][:, :])

    if args.checkpoint is None:
        model = _build_model(args, output)
    else:
        print("# Loading model from checkpoint {}".format(args.checkpoint), file=output)
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        model = torch.load(args.checkpoint)
        model.use_cuda = use_cuda

    if use_cuda:
        model = model.cuda()

    lr = args.lr
    wd = args.weight_decay
    num_epochs = args.num_epochs
    report_steps = args.epoch_scale
    inter_weight = args.lambda_
    cmap_weight = 1 - inter_weight
    # Zero-pad epoch numbers in checkpoint names to the width of num_epochs.
    digits = len(str(num_epochs))
    save_prefix = args.save_prefix
    if save_prefix is None:
        save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=wd)

    print(f'# Using save prefix "{save_prefix}"', file=output)
    print(f"# Training with Adam: lr={lr}, weight_decay={wd}", file=output)
    print(f"\tnum_epochs: {num_epochs}", file=output)
    print(f"\tepoch_scale: {report_steps}", file=output)
    print(f"\tbatch_size: {batch_size}", file=output)
    print(f"\tinteraction weight: {inter_weight}", file=output)
    print(f"\tcontact map weight: {cmap_weight}", file=output)
    output.flush()

    batch_report_fmt = "# [{}/{}] training {:.1%}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}"
    epoch_report_fmt = "# Finished Epoch {}/{}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}, Precision={:.6}, Recall={:.6}, F1={:.6}, AUPR={:.6}"

    # Exact sample counts (the last batch of an epoch may be partial).
    n_train = len(train_pairs)
    n_test = len(test_pairs)

    for epoch in range(num_epochs):
        model.train()

        n = 0
        loss_accum = 0.0
        acc_accum = 0.0
        mse_accum = 0.0

        for z0, z1, y in tqdm(
            pairs_train_iterator,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            total=len(pairs_train_iterator),
        ):
            loss, correct, mse, b = interaction_grad(model, z0, z1, y, tensors, use_cuda, weight=inter_weight)
            # Convert to a Python float so the running mean holds no autograd history.
            loss = float(loss)

            # Batch-size-weighted running means over the epoch so far.
            n += b
            loss_accum += b * (loss - loss_accum) / n
            acc_accum += (correct - b * acc_accum) / n
            mse_accum += b * (mse - mse_accum) / n

            # True whenever the cumulative sample count crosses a multiple of 100.
            report = (n - b) // 100 < n // 100

            optimizer.step()
            optimizer.zero_grad()
            model.clip()

            # Batch reports go only to a log file, to avoid interleaving with tqdm on stdout.
            if report and output is not sys.stdout:
                tokens = [
                    epoch + 1,
                    num_epochs,
                    n / n_train,
                    loss_accum,
                    acc_accum,
                    mse_accum,
                ]
                print(batch_report_fmt.format(*tokens), file=output)
                output.flush()

        if (epoch + 1) % report_steps == 0:
            model.eval()

            with torch.no_grad():
                (
                    inter_loss,
                    inter_correct,
                    inter_mse,
                    inter_pr,
                    inter_re,
                    inter_f1,
                    inter_aupr,
                ) = interaction_eval(model, pairs_test_iterator, tensors, use_cuda)

            tokens = [
                epoch + 1,
                num_epochs,
                inter_loss,
                inter_correct / n_test,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ]
            print(epoch_report_fmt.format(*tokens), file=output)
            output.flush()

            # Checkpoint whenever held-out performance is reported.
            save_path = save_prefix + "_epoch" + str(epoch + 1).zfill(digits) + ".sav"
            print(f"# Saving model to {save_path}", file=output)
            _save_model(model, save_path, use_cuda)

        output.flush()

    save_path = save_prefix + "_final.sav"
    print(f"# Saving final model to {save_path}", file=output)
    _save_model(model, save_path, use_cuda)


def main(args):
    """
    Run training from arguments.

    Log output goes to ``args.outfile`` when given, otherwise to stdout. A file
    opened here is always closed on exit (including on error); ``sys.stdout``
    is never closed.

    :meta private:
    """
    output = sys.stdout if args.outfile is None else open(args.outfile, "w")
    try:
        _run_training(args, output)
    finally:
        if output is not sys.stdout:
            output.close()
|
|
|
|
|
if __name__ == "__main__":
    # add_args returns the parser it populated, so parsing can be chained directly.
    main(add_args(argparse.ArgumentParser(description=__doc__)).parse_args())
|
|