File size: 3,794 Bytes
ce7bf5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for integrating Stochastic Differential Equations (SDEs).

Numerical integration of stochastic differential equations.
"""


from typing import Callable, Tuple

import torch
from tqdm.autonotebook import tqdm


def sde_integrate(
    sde_func: Callable,
    y0: torch.Tensor,
    tspan: Tuple,
    N: int,
    project_func: Callable = None,
    T_grid: torch.Tensor = None,
) -> list:
    """Integrate an Ito SDE with the fixed-step Euler-Maruyama scheme.

    Every step over ``[t_start, t_end]`` calls ``sde_func(t_start, y)`` to get a
    drift ``f`` and a diffusion term ``gZ``, then advances
    ``y <- y + dt * f + sqrt(|dt|) * gZ`` with ``dt = t_end - t_start``.
    The whole integration runs under ``torch.no_grad()``.

    args:
        sde_func (function): called as ``sde_func(t, y)``; must return the pair
            ``(f, gZ)``. This function only scales ``gZ`` by ``sqrt(|dt|)``, so
            any random sample is presumably already folded into ``gZ`` by the
            caller (TODO confirm against sde_func implementations).
        y0 (torch.tensor): the initial value of y, e.g. a noised protein structure tensor
        tspan (tuple): ``(t_i, t_f)``; used only to build a uniform grid when
            ``T_grid`` is not supplied
        N (int): number of integration steps
        project_func (function, optional): if given, called as
            ``project_func(t_start, y)`` after each step; its result replaces y
        T_grid (torch.tensor, optional): explicit time grid with N + 1 entries.
            When omitted, ``torch.linspace(t_i, t_f, N + 1)`` is built and moved
            to ``y0.device``.

    returns:
        y_trajectory (list): N + 1 snapshots, ``y0`` followed by the state
            after each step

    raises:
        AssertionError: if ``T_grid`` is supplied and its first dimension is
            not N + 1
    """
    with torch.no_grad():
        if T_grid is not None:
            assert T_grid.shape[0] == N + 1
        else:
            T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(y0.device)

        trajectory = [y0]
        state = y0
        # Consecutive (start, end) pairs of the time grid: N steps in total.
        step_bounds = zip(T_grid[:-1], T_grid[1:])
        for t_start, t_end in tqdm(step_bounds, total=N, desc="Integrating SDE"):
            step = t_end - t_start
            drift, noise = sde_func(t_start, state)
            # abs() keeps the noise scale real when integrating backward in time.
            state = state + step * drift + step.abs().sqrt() * noise
            if project_func is not None:
                state = project_func(t_start, state)
            trajectory.append(state)

    return trajectory


def sde_integrate_heun(
    sde_func: Callable,
    y0: torch.Tensor,
    tspan: Tuple,
    N: int,
    project_func: Callable = None,
    T_grid: torch.Tensor = None,
) -> list:
    """Integrate an Ito SDE with Heun's (predictor-corrector) method.

    For each step over ``[t0, t1]`` with ``dT = t1 - t0``:

    * predictor: ``f, gZ = sde_func(t0, y)`` and
      ``y_pred = y + dT * f + sqrt(|dT|) * gZ``;
    * corrector: ``f_pred = sde_func(t1, y_pred)[0]``, the drift at the end
      of the step;
    * update: the average of the two stages, i.e.
      ``y + dT * (f + f_pred) / 2 + sqrt(|dT|) * gZ``.

    The whole integration runs under ``torch.no_grad()``.

    args:
        sde_func (function): called as ``sde_func(t, y)``; must return the pair
            ``(f, gZ)``. Only ``f`` from the corrector call is used. This
            function only scales ``gZ`` by ``sqrt(|dT|)``, so any random sample
            is presumably already folded into ``gZ`` (TODO confirm). It is
            called at the final grid time ``T_grid[-1]`` on the last step, so
            it must be well-defined there.
        y0 (torch.tensor): the initial value of y, e.g. a noised protein structure tensor
        tspan (tuple): ``(t_i, t_f)``; used only to build a uniform grid when
            ``T_grid`` is not supplied
        N (int): number of integration steps
        project_func (function, optional): if given, called as
            ``project_func(t0, y)`` after each step; its result replaces y
        T_grid (torch.tensor, optional): explicit time grid with N + 1 entries.
            When omitted, ``torch.linspace(t_i, t_f, N + 1)`` is built and moved
            to ``y0.device``.

    returns:
        y_trajectory (list): N + 1 snapshots, ``y0`` followed by the state
            after each step

    raises:
        AssertionError: if ``T_grid`` is supplied and its first dimension is
            not N + 1
    """

    with torch.no_grad():
        # Integrate SDE
        y_trajectory = [y0]

        if T_grid is None:
            T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(y0.device)
        else:
            assert T_grid.shape[0] == N + 1

        y = y0
        for t0, t1 in tqdm(
            zip(T_grid[:-1], T_grid[1:]), total=N, desc="Integrating SDE"
        ):
            dT = t1 - t0
            # abs() keeps the noise scale real when integrating backward in time.
            noise_scale = dT.abs().sqrt()

            # Predictor: a plain Euler-Maruyama step starting at (t0, y).
            f, gZ = sde_func(t0, y)
            y_pred = y + dT * f + noise_scale * gZ

            # Corrector: y_pred approximates the state at t1, so its drift must
            # be evaluated at t1 (previously t0, which is not Heun's method).
            # The corrector's diffusion term is discarded and the predictor's
            # gZ is reused, presumably so both stages share one noise sample.
            f_pred, _ = sde_func(t1, y_pred)
            y_correct = y + dT * f_pred + noise_scale * gZ

            y = (y_pred + y_correct) / 2.0
            y = y if project_func is None else project_func(t0, y)
            y_trajectory.append(y)

    return y_trajectory