M3Site / model /egnn /utils.py
anonymousforpaper's picture
Upload 103 files
224a33f verified
raw
history blame contribute delete
821 Bytes
import torch
from torch import sin, cos, atan2, acos
def rot_z(gamma):
    """Return the 3x3 matrix of a rotation by ``gamma`` about the z axis.

    Args:
        gamma: Rotation angle in radians. Either a scalar tensor or a plain
            Python number; a number is converted to a tensor of the current
            default dtype before use.

    Returns:
        A ``(3, 3)`` tensor whose dtype is ``gamma.dtype``.
    """
    if not torch.is_tensor(gamma):
        # torch.sin / torch.cos only accept tensors, so promote plain numbers.
        gamma = torch.tensor(gamma, dtype=torch.get_default_dtype())
    # NOTE: torch.tensor copies the entries into a new leaf tensor, so the
    # result carries no autograd history back to ``gamma``.
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1],
    ], dtype=gamma.dtype)
def rot_y(beta):
    """Return the 3x3 matrix of a rotation by ``beta`` about the y axis.

    Args:
        beta: Rotation angle in radians. Either a scalar tensor or a plain
            Python number; a number is converted to a tensor of the current
            default dtype before use.

    Returns:
        A ``(3, 3)`` tensor whose dtype is ``beta.dtype``.
    """
    if not torch.is_tensor(beta):
        # torch.sin / torch.cos only accept tensors, so promote plain numbers.
        beta = torch.tensor(beta, dtype=torch.get_default_dtype())
    # NOTE: torch.tensor copies the entries into a new leaf tensor, so the
    # result carries no autograd history back to ``beta``.
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)],
    ], dtype=beta.dtype)
def rot(alpha, beta, gamma):
    """Build a rotation matrix from three angles in Z-Y-Z order.

    The result is the matrix product
    ``rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)``, so ``gamma`` is the
    rotation applied first to a column vector and ``alpha`` the last.

    Args:
        alpha: Angle passed to the left-most z rotation.
        beta: Angle passed to the middle y rotation.
        gamma: Angle passed to the right-most z rotation.

    Returns:
        The product of the three rotation matrices.
    """
    left_z = rot_z(alpha)
    middle_y = rot_y(beta)
    right_z = rot_z(gamma)
    return (left_z @ middle_y) @ right_z
def get_node_feature_dims():
    """Return the widths of the individual node-feature groups.

    The list holds six widths that sum to 671. A fresh list is returned on
    every call, so callers may mutate it freely.

    NOTE(review): an earlier version of this docstring described a 25-dim
    node feature made of residue type, SASA, dihedral and mu_r_norm. That no
    longer matches the six entries returned here; which feature each width
    belongs to should be confirmed against the featurization code.

    Returns:
        A list of ints, one width per node-feature group.
    """
    return [20, 1, 1, 4, 5, 640]
def get_edge_feature_dims():
    """Return the widths of the individual edge-feature groups.

    The four widths sum to 93, the total edge-feature dimension. A fresh
    list is returned on every call.

    NOTE(review): the features were previously described as one-hot sequence
    distance, interatomic distance and local frame orientation; how those
    three names map onto the four widths is not visible here -- confirm
    against the featurization code.

    Returns:
        A list of ints, one width per edge-feature group.
    """
    edge_widths = [65, 1, 15, 12]
    return edge_widths