File size: 1,291 Bytes
0fdcb79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn

from dockformerpp.model.primitives import Linear
from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformerpp.utils.geometry.rotation_matrix import Rot3Array
from dockformerpp.utils.geometry.vector import Vec3Array


class QuatRigid(nn.Module):
    """Project activations to a rigid transform (rotation + translation).

    A single ``Linear`` layer emits the raw rigid parameters. ``forward``
    splits them into quaternion components and a 3-vector translation and
    packs them into a ``Rigid3Array``.

    Args:
        c_hidden: Input feature size of the projection (assumed to be the
            trailing dimension of ``activations`` -- TODO confirm).
        full_quat: If truthy, predict all four quaternion components
            (7 outputs). Otherwise predict only (qx, qy, qz) with qw fixed
            to 1 (6 outputs).
    """

    def __init__(self, c_hidden, full_quat):
        super().__init__()
        self.full_quat = full_quat
        # 4 quaternion + 3 translation outputs, or 3 + 3 when qw is implicit.
        rigid_dim = 7 if self.full_quat else 6
        # init="final" is interpreted by the project's Linear; presumably
        # zero-style init for an output head -- confirm in primitives.
        self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)

    def forward(self, activations: torch.Tensor) -> Rigid3Array:
        """Predict a ``Rigid3Array`` from ``activations``.

        Args:
            activations: Features fed to ``self.linear``.

        Returns:
            A ``Rigid3Array`` built from the predicted quaternion (passed to
            ``Rot3Array.from_quaternion`` with ``normalize=True``) and the
            predicted translation vector.
        """
        # NOTE: During training, this needs to be run in higher precision
        components = torch.unbind(self.linear(activations), dim=-1)

        if not self.full_quat:
            # Only the vector part is predicted; the scalar part is pinned to 1
            # before normalization.
            qx, qy, qz, *trans_xyz = components
            qw = torch.ones_like(qx)
        else:
            qw, qx, qy, qz, *trans_xyz = components

        rotation = Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
        return Rigid3Array(rotation, Vec3Array(*trans_xyz))