File size: 5,716 Bytes
18ee20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
039c51a
 
 
18ee20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
##############################################################################
# Copyright (c) 2024, Oak Ridge National Laboratory                          #
# All rights reserved.                                                       #
#                                                                            #
# This file is part of HydraGNN and is distributed under a BSD 3-clause      #
# license. For the licensing terms see the LICENSE file in the top-level     #
# directory.                                                                 #
#                                                                            #
# SPDX-License-Identifier: BSD-3-Clause                                      #
##############################################################################


##############################################################################
########################  INSTRUCTIONS TO RUN THE CODE #######################
##############################################################################

"""
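1. Follow the run instructions on the HydraGNN Wiki page: https://github.com/ORNL/HydraGNN/wiki/Run
2. Copy the "Ensemble_of_models" folder containing the trained models into "examples/ensemble_learning".
3. Copy the script "inference_example_with_dummy_data_object.py" into "examples/ensemble_learning".
4. From "examples/ensemble_learning", run:
   python inference_example_with_dummy_data_object.py --models_dir_folder=./Ensemble_of_models
   (--models_dir_folder accepts one folder, or two comma-separated folders, whose sub-directories
   each hold one trained model together with its config.json)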
"""

import json, os
from mpi4py import MPI
import argparse

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import Distance

import hydragnn

try:
    from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset
except ImportError:
    pass

from hydragnn.preprocess.utils import (
    RadiusGraph,
)

# Graph-construction transforms applied to the dummy sample below:
# a radius graph (cutoff 5.0, no self-loops, at most 50 neighbours per node)
# followed by un-normalised edge lengths computed from the node positions.
_RADIUS_CUTOFF = 5.0
_MAX_NEIGHBOURS = 50
radius_graph = RadiusGraph(
    _RADIUS_CUTOFF,
    loop=False,
    max_num_neighbors=_MAX_NEIGHBOURS,
)
transform_coordinates = Distance(cat=False, norm=False)

from ensemble_utils import model_ensemble

######################################################
###### PYTORCH GEOMETRIC DATA OBJECT DEFINITION ######
######################################################

# Define the dummy inputs: a cubic 10 x 10 x 10 cell holding two atoms
# (atomic numbers 1 and 6, i.e. H and C).
lattice_mat = torch.tensor(
    [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]], dtype=torch.float32
)  # 3x3 supercell size matrix
positions = torch.tensor(
    [[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], dtype=torch.float32
)  # Nx3 positions matrix
# Atomic numbers as an Nx1 integer column (unsqueeze adds the feature axis).
atomic_numbers = torch.tensor([1, 6], dtype=torch.int64).unsqueeze(1)
# Node features are [Z, x, y, z] per atom. Z is cast to the position dtype
# explicitly instead of relying on torch.cat's implicit int64/float32 type
# promotion, so x is unambiguously float32.
x = torch.cat([atomic_numbers.to(positions.dtype), positions], dim=1)

# Creating the Data object
data = Data(
    supercell_size=lattice_mat,
    pos=positions,
    atomic_numbers=atomic_numbers,  # already Nx1, no reshape needed
    x=x,
)

# Connect neighbouring atoms, then store the edge lengths
# (Distance(norm=False, cat=False) writes raw distances to edge_attr).
data = radius_graph(data)
data = transform_coordinates(data)

######################################################
###### LOAD MODELS AND USE THEM FOR PREDICTIONS ######
######################################################


def list_model_dirs(models_dir_folder):
    """Return the trained-model sub-directories found in the given folder(s).

    Args:
        models_dir_folder: one folder, or two comma-separated folders, whose
            immediate sub-directories each hold one trained model.

    Returns:
        List of ``<folder>/<subdir>`` paths. Folders are visited in the order
        given and each folder's sub-directories are sorted by name, so the
        ensemble (and the printed output) has a reproducible order; plain
        ``os.listdir`` order is arbitrary.

    Raises:
        ValueError: if more than two folders are given.
        FileNotFoundError: if no model sub-directory is found.
    """
    folders = models_dir_folder.split(",")
    if len(folders) not in (1, 2):
        raise ValueError(
            f"Expected one or two comma-separated model folders, got {len(folders)}"
        )
    model_dirs = []
    for folder in folders:
        candidates = (os.path.join(folder, name) for name in sorted(os.listdir(folder)))
        model_dirs.extend(path for path in candidates if os.path.isdir(path))
    if not model_dirs:
        raise FileNotFoundError(f"No model sub-directories found in {models_dir_folder}")
    return model_dirs


def load_model_configs(model_dirs):
    """Load every model's ``config.json`` and check that they are compatible.

    All configs must share the same ``NeuralNetwork.Variables_of_interest``
    section.

    Args:
        model_dirs: non-empty list of model directories, each containing a
            ``config.json`` file.

    Returns:
        The parsed config of the last model in ``model_dirs``; it is used for
        settings shared by the whole run, such as the verbosity level.

    Raises:
        ValueError: if ``model_dirs`` is empty, or if a config's
            ``Variables_of_interest`` differs from the first one.
    """
    if not model_dirs:
        raise ValueError("No model directories given")
    var_config = None
    config = None
    for modeldir in model_dirs:
        input_filename = os.path.join(modeldir, "config.json")
        with open(input_filename, "r", encoding="utf-8") as f:
            config = json.load(f)
        current = config["NeuralNetwork"]["Variables_of_interest"]
        if var_config is None:
            var_config = current
        elif current != var_config:
            # Explicit raise rather than assert: asserts vanish under ``python -O``.
            raise ValueError(f"Inconsistent variable config in {input_filename}")
    return config


if __name__ == "__main__":
    ##################################################################################################################
    print("gfm starting")
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--models_dir_folder",
        help="folder of trained models",
        type=str,
        default="./Ensemble_of_models",
    )
    args = parser.parse_args()
    ##################################################################################################################
    modeldirlist = list_model_dirs(args.models_dir_folder)
    config = load_model_configs(modeldirlist)
    verbosity = config["Verbosity"]["level"]

    ##################################################################################################################
    # Always initialize for multi-rank training.
    comm_size, rank = hydragnn.utils.setup_ddp()
    ##################################################################################################################
    model_ens = model_ensemble(modeldirlist)
    model_ens = hydragnn.utils.get_distributed_model(model_ens, verbosity)
    model_ens.eval()
    ##################################################################################################################
    # Put the sample on the same device as the weights (a no-op on CPU), in
    # case get_distributed_model moved the ensemble to an accelerator.
    first_param = next(model_ens.parameters(), None)
    sample = data if first_param is None else data.to(first_param.device)

    # NOTE(review): no torch.no_grad() here - confirm the ensemble does not need
    # autograd (e.g. forces computed as energy gradients) before adding one.
    pred_ens = model_ens(sample)

    # Every rank feeds the same sample to an ensemble built from the same model
    # directories, so print only once to avoid duplicated output.
    if rank == 0:
        for model_id, modeldir in enumerate(modeldirlist):
            print(f"Model {modeldir} - Prediction of energy: \n", pred_ens[model_id][0])
            print(f"Model {modeldir} - Prediction of forces: \n", pred_ens[model_id][1])