alexnasa commited on
Commit
638d939
·
verified ·
1 Parent(s): be22660

Update src/pixel3dmm/tracking/flame/lbs.py

Browse files
src/pixel3dmm/tracking/flame/lbs.py CHANGED
@@ -208,8 +208,15 @@ def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
208
  '''
209
 
210
  batch_size = max(betas.shape[0], pose.shape[0])
211
- device = betas.device
212
 
 
 
 
 
 
 
 
 
213
  # Add shape contribution
214
  v_shaped = v_template + blend_shapes(betas, shapedirs)
215
 
 
208
  '''
209
 
210
  batch_size = max(betas.shape[0], pose.shape[0])
 
211
 
212
+ # after: get the device index (an int), which Dynamo can capture
213
+ device_idx = betas.get_device() # returns -1 on CPU, ≥0 for cuda:<idx>
214
+ if device_idx >= 0:
215
+ device = torch.device("cuda", device_idx)
216
+ else:
217
+ device = torch.device("cpu")
218
+
219
+
220
  # Add shape contribution
221
  v_shaped = v_template + blend_shapes(betas, shapedirs)
222