Weiyu Liu committed
Commit 3827c6d · 1 Parent(s): 824a79e

compute rot 6d does not depend on cuda

src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc differ
 
src/StructDiffusion/diffusion/pose_conversion.py CHANGED
@@ -45,11 +45,6 @@ def get_diffusion_variables_from_H(poses):
 
 def get_struct_objs_poses(x):
 
-    on_gpu = x.is_cuda
-    if not on_gpu:
-        x = x.cuda()
-
-    # assert x.is_cuda, "compute_rotation_matrix_from_ortho6d requires input to be on gpu"
     device = x.device
 
     # important: the noisy x can go out of bounds
@@ -72,10 +67,6 @@ def get_struct_objs_poses(x):
     struct_pose = x_full[:, 0].unsqueeze(1) # B, 1, 4, 4
     pc_poses_in_struct = x_full[:, 1:] # B, N, 4, 4
 
-    if not on_gpu:
-        struct_pose = struct_pose.cpu()
-        pc_poses_in_struct = pc_poses_in_struct.cpu()
-
     return struct_pose, pc_poses_in_struct
 
 
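With the forced `.cuda()` round-trip removed, `get_struct_objs_poses` stays on whatever device its input tensor already lives on. A minimal sketch of the pattern (the `decode` function and its placeholder math are hypothetical stand-ins, not the real pose decoding):

import torch

def decode(x):
    # Hypothetical stand-in for get_struct_objs_poses: every intermediate
    # tensor is created on x's device, so no .cuda()/.cpu() transfers are
    # needed and the caller gets results back on the device it passed in.
    device = x.device
    eye = torch.eye(4, device=device)  # constants follow the input's device
    return x @ eye                     # placeholder computation

out_cpu = decode(torch.randn(2, 4, 4))                   # runs on a CPU-only machine
# out_gpu = decode(torch.randn(2, 4, 4, device="cuda"))  # and on GPU, if available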
src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc differ
 
src/StructDiffusion/utils/rotation_continuity.py CHANGED
@@ -21,7 +21,7 @@ def compute_pose_from_rotation_matrix(T_pose, r_matrix):
 def normalize_vector( v, return_mag =False):
     batch=v.shape[0]
     v_mag = torch.sqrt(v.pow(2).sum(1))# batch
-    v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda()))
+    v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).to(v.device)))
     v_mag = v_mag.view(batch,1).expand(batch,v.shape[1])
     v = v/v_mag
     if(return_mag==True):
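For reference, the same division-by-zero guard can be written in modern PyTorch without the long-deprecated `torch.autograd.Variable` wrapper. This is an equivalent, device-agnostic rewrite of `normalize_vector`, not the code committed here:

import torch

def normalize_vector(v: torch.Tensor, return_mag: bool = False):
    # Equivalent rewrite: clamp the magnitude from below instead of taking
    # max() against a constant tensor, so no epsilon is pinned to CUDA and
    # the computation runs on whatever device v is on.
    v_mag = v.norm(dim=1, keepdim=True).clamp_min(1e-8)  # (batch, 1)
    v_out = v / v_mag
    if return_mag:
        return v_out, v_mag.squeeze(1)
    return v_out

u = normalize_vector(torch.randn(4, 3))                      # works on CPU...
# u = normalize_vector(torch.randn(4, 3, device="cuda"))     # ...and on GPU unchanged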