"""
Make new predictions with a pre-trained model. One of --seqs or --embeddings is required.
"""
import sys
import argparse
import datetime

import torch
import h5py
import numpy as np
import pandas as pd
from tqdm import tqdm

from dscript.fasta import parse
from dscript.language_model import lm_embed
from dscript.utils import log

def add_args(parser):
    """
    Create parser for command line utility

    :meta private:
    """

    parser.add_argument("--pairs", help="Candidate protein pairs to predict", required=True)
    parser.add_argument("--model", help="Pretrained Model", required=True)
    parser.add_argument("--seqs", help="Protein sequences in .fasta format")
    parser.add_argument("--embeddings", help="h5 file with embedded sequences")
    parser.add_argument("-o", "--outfile", help="File for predictions")
    parser.add_argument("-d", "--device", type=int, default=-1, help="Compute device to use")
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.5,
        help="Positive prediction threshold - used to store contact maps and predictions in a separate file. [default: 0.5]",
    )
    return parser


def main(args):
    """
    Run new prediction from arguments.

    :meta private:
    """
    if args.seqs is None and args.embeddings is None:
        print("One of --seqs or --embeddings is required.", file=sys.stderr)
        sys.exit(1)

    csvPath = args.pairs
    modelPath = args.model
    outPath = args.outfile
    seqPath = args.seqs
    embPath = args.embeddings
    device = args.device
    threshold = args.thresh

    # Set Outpath
    if outPath is None:
        outPath = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M.predictions")

    logFilePath = outPath + ".log"
    logFile = open(logFilePath, "w+")

    # Set Device
    use_cuda = (device >= 0) and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        print(f"# Using CUDA device {device} - {torch.cuda.get_device_name(device)}")
        log(f"Using CUDA device {device} - {torch.cuda.get_device_name(device)}", file=logFile)
    else:
        print("# Using CPU")
        log("# Using CPU", file=logFile)

    # Load Model
    try:
        if use_cuda:
            model = torch.load(modelPath).cuda()
        else:
            # map_location ensures checkpoints saved from a CUDA device still load on CPU
            model = torch.load(modelPath, map_location=torch.device("cpu")).cpu()
            model.use_cuda = False
    except FileNotFoundError:
        print(f"# Model {modelPath} not found")
        log(f"Model {modelPath} not found", file=logFile)
        logFile.close()
        sys.exit(1)

    # Load Pairs
    try:
        pairs = pd.read_csv(csvPath, sep="\t", header=None)
        all_prots = set(pairs.iloc[:, 0]).union(set(pairs.iloc[:, 1]))
    except FileNotFoundError:
        print(f"# Pairs File {csvPath} not found")
        log(f"Pairs File {csvPath} not found", file=logFile)
        logFile.close()
        sys.exit(1)

    # Load Sequences or Embeddings
    if embPath is None:
        try:
            with open(seqPath, "r") as seqFile:
                names, seqs = parse(seqFile)
            seqDict = {n: s for n, s in zip(names, seqs)}
        except FileNotFoundError:
            print(f"# Sequence File {seqPath} not found")
            log(f"Sequence File {seqPath} not found", file=logFile)
            logFile.close()
            sys.exit(1)
        print("# Generating Embeddings...")
        log("Generating Embeddings...", file=logFile)
        embeddings = {}
        for n in tqdm(all_prots):
            embeddings[n] = lm_embed(seqDict[n], use_cuda)
    else:
        print("# Loading Embeddings...")
        log("Loading Embeddings...", file=logFile)
        embedH5 = h5py.File(embPath, "r")
        embeddings = {}
        for n in tqdm(all_prots):
            embeddings[n] = torch.from_numpy(embedH5[n][:])
        embedH5.close()

    # Make Predictions
    print("# Making Predictions...")
    log("Making Predictions...", file=logFile)
    n = 0
    outPathAll = f"{outPath}.tsv"
    outPathPos = f"{outPath}.positive.tsv"
    cmap_file = h5py.File(f"{outPath}.cmaps.h5", "w")
    model.eval()
    with open(outPathAll, "w+") as f:
        with open(outPathPos, "w+") as pos_f:
            with torch.no_grad():
                for _, (n0, n1) in tqdm(pairs.iloc[:, :2].iterrows(), total=len(pairs)):
                    n0 = str(n0)
                    n1 = str(n1)
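                    # Flush buffered results every 50 pairs so partial output survives an interruption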
                    if n % 50 == 0:
                        f.flush()
                    n += 1
                    p0 = embeddings[n0]
                    p1 = embeddings[n1]
                    if use_cuda:
                        p0 = p0.cuda()
                        p1 = p1.cuda()
                    try:
                        cm, p = model.map_predict(p0, p1)
                        p = p.item()
                        f.write(f"{n0}\t{n1}\t{p}\n")
                        if p >= threshold:
                            pos_f.write(f"{n0}\t{n1}\t{p}\n")
                            cm_np = cm.squeeze().cpu().numpy()
                            dset = cmap_file.require_dataset(f"{n0}x{n1}", cm_np.shape, np.float32)
                            dset[:] = cm_np
                    except RuntimeError as e:
                        log(f"{n0} x {n1} skipped - {e}", file=logFile)

    logFile.close()
    cmap_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    add_args(parser)
    main(parser.parse_args())