File size: 2,741 Bytes
3ad8be1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pandas as pd
from torch.utils.data import DataLoader
import torch
from UltraFlow import dataset, commons, losses, models
import numpy as np

def trans_device(batch, device):
    """Move every non-list element of ``batch`` onto ``device``.

    List elements are kept as the very same objects (not copied or moved);
    every other element must expose a ``.to(device)`` method. The result is
    a new list in the same order as ``batch``.
    """
    moved = []
    for item in batch:
        if isinstance(item, list):
            moved.append(item)
        else:
            moved.append(item.to(device))
    return moved

def _cat_onto_empty(chunks, device):
    """Concatenate ``chunks`` (in order) onto an empty tensor on ``device``.

    Seeding with ``torch.tensor([])`` keeps the result identical to the
    original one-``torch.cat``-per-batch accumulation:
    - the float32 seed takes part in dtype promotion;
    - ``torch.cat`` accepts a 1-D ``(0,)`` tensor next to tensors of any shape;
    - an empty ``chunks`` yields an empty tensor.
    A single ``cat`` at the end avoids re-copying the accumulated values on
    every batch, which made the original quadratic in the number of batches.
    """
    return torch.cat([torch.tensor([]).to(device), *chunks])


def _to_csv_column(values):
    """Return ``values`` as a Python list suitable for a DataFrame column.

    ``squeeze()`` drops singleton dimensions (e.g. an ``(N, 1)`` column).
    ``atleast_1d`` keeps a one-sample result as a one-element list instead of
    a bare scalar, which ``pd.DataFrame`` rejects without an explicit index.
    """
    return torch.atleast_1d(values.cpu().squeeze()).tolist()


def _regression_metrics(y_true, y_pred):
    """Return ``(metric_dict, formatted_str)`` from the project's metric helpers."""
    metrics = commons.get_sbap_regression_metric_dict(np.array(y_true.cpu()), np.array(y_pred.cpu()))
    return metrics, commons.get_matric_output_str(metrics)


@torch.no_grad()
def reproduce_result(config, test_set, model, device, pred_csv_path='pred_result.csv', verbose=False):
    """Evaluate ``model`` on ``test_set`` and return its regression metrics.

    Parameters:
        config: object providing ``config.train.batch_size`` and
            ``config.train.num_workers`` for the evaluation ``DataLoader``.
        test_set: dataset batched with
            ``dataset.collate_pdbbind_affinity_multi_task_v2``.
        model: called as ``model(batch, ASRP=False)``. It must return three
            pairs: ``(_, _), (pred_IC50, pred_K), (true_IC50, true_K)``.
            The first pair is ignored here.
        device (torch.device or str): device the model runs on. For any
            non-CPU device, the batch is moved there with ``trans_device``.
        pred_csv_path (str or None): where to write the combined predictions
            (``pred_values``) and ground truth (``gd_values``) as CSV.
            Defaults to ``'pred_result.csv'`` in the working directory.
            ``None`` skips writing the file.
        verbose (bool): if True, print the formatted metric summary: overall,
            then the IC50 and K subsets when they are non-empty.

    Returns:
        tuple: ``(RMSE, MAE, SD, Pearson)`` computed over the IC50 and K
        predictions combined.

    Note:
        Puts ``model`` into eval mode and leaves it there.
    """
    device = torch.device(device)
    dataloader = DataLoader(test_set, batch_size=config.train.batch_size,
                            shuffle=False, collate_fn=dataset.collate_pdbbind_affinity_multi_task_v2,
                            num_workers=config.train.num_workers)

    # Per-batch chunks; each list is concatenated once after the loop.
    preds_all, preds_ic50, preds_k = [], [], []
    trues_all, trues_ic50, trues_k = [], [], []
    model.eval()
    for batch in dataloader:
        # Any accelerator (CUDA, MPS, ...) needs the batch on the model's device;
        # on CPU the batch is used as produced by the collate function.
        if device.type != "cpu":
            batch = trans_device(batch, device)

        _, (pred_ic50, pred_k), (true_ic50, true_k) = model(batch, ASRP=False)

        # The combined order per batch is IC50 rows first, then K rows.
        preds_all += [pred_ic50, pred_k]
        trues_all += [true_ic50, true_k]
        preds_ic50.append(pred_ic50)
        preds_k.append(pred_k)
        trues_ic50.append(true_ic50)
        trues_k.append(true_k)

    y_preds = _cat_onto_empty(preds_all, device)
    y = _cat_onto_empty(trues_all, device)
    y_preds_ic50 = _cat_onto_empty(preds_ic50, device)
    y_ic50 = _cat_onto_empty(trues_ic50, device)
    y_preds_k = _cat_onto_empty(preds_k, device)
    y_k = _cat_onto_empty(trues_k, device)

    metrics, result_str = _regression_metrics(y, y_preds)

    if pred_csv_path is not None:
        pd.DataFrame({'pred_values': _to_csv_column(y_preds),
                      'gd_values': _to_csv_column(y)}
                     ).to_csv(pred_csv_path)

    # Append per-task metrics only for tasks that actually had samples.
    for label, y_sub, pred_sub in (('IC50', y_ic50, y_preds_ic50), ('K', y_k, y_preds_k)):
        if len(y_sub) > 0:
            _, sub_str = _regression_metrics(y_sub, pred_sub)
            result_str += f'| {label} ' + sub_str

    if verbose:
        print(result_str)

    return metrics['RMSE'], metrics['MAE'], metrics['SD'], metrics['Pearson']