alexnasa commited on
Commit
de5ce53
·
verified ·
1 Parent(s): 638d939

Update src/pixel3dmm/tracking/flame/lbs.py

Browse files
src/pixel3dmm/tracking/flame/lbs.py CHANGED
@@ -210,12 +210,7 @@ def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
210
  batch_size = max(betas.shape[0], pose.shape[0])
211
 
212
  # after: get the device index (an int), which Dynamo can capture
213
- device_idx = betas.get_device() # returns -1 on CPU, ≥0 for cuda:<idx>
214
- if device_idx >= 0:
215
- device = torch.device("cuda", device_idx)
216
- else:
217
- device = torch.device("cpu")
218
-
219
 
220
  # Add shape contribution
221
  v_shaped = v_template + blend_shapes(betas, shapedirs)
 
210
  batch_size = max(betas.shape[0], pose.shape[0])
211
 
212
  # after: get the device index (an int), which Dynamo can capture
213
+ device = betas.device
 
 
 
 
 
214
 
215
  # Add shape contribution
216
  v_shaped = v_template + blend_shapes(betas, shapedirs)