henry000 commited on
Commit
8942278
·
1 Parent(s): c653e7e

🚑️ [Fix] when creating onnx, trt force use cuda

Browse files
Files changed (1) hide show
  1. yolo/utils/deploy_utils.py +4 -4
yolo/utils/deploy_utils.py CHANGED
@@ -51,8 +51,8 @@ class FastModelLoader:
51
  from onnxruntime import InferenceSession
52
  from torch.onnx import export
53
 
54
- model = create_model(self.cfg).eval().cuda()
55
- dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
56
  export(
57
  model,
58
  dummy_input,
@@ -80,8 +80,8 @@ class FastModelLoader:
80
  def _create_trt_weight(self):
81
  from torch2trt import torch2trt
82
 
83
- model = create_model(self.cfg).eval().cuda()
84
- dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
85
  logger.info(f"♻️ Creating TensorRT model")
86
  model_trt = torch2trt(model, [dummy_input])
87
  torch.save(model_trt.state_dict(), self.weight)
 
51
  from onnxruntime import InferenceSession
52
  from torch.onnx import export
53
 
54
+ model = create_model(self.cfg).eval()
55
+ dummy_input = torch.ones((1, 3, *self.cfg.image_size))
56
  export(
57
  model,
58
  dummy_input,
 
80
  def _create_trt_weight(self):
81
  from torch2trt import torch2trt
82
 
83
+ model = create_model(self.cfg).eval()
84
+ dummy_input = torch.ones((1, 3, *self.cfg.image_size))
85
  logger.info(f"♻️ Creating TensorRT model")
86
  model_trt = torch2trt(model, [dummy_input])
87
  torch.save(model_trt.state_dict(), self.weight)