glenn-jocher commited on
Commit
f62609e
·
unverified ·
1 Parent(s): 4b284a1

Update check_requirements() with `cmds=()` argument (#7543)

Browse files
Files changed (2) hide show
  1. export.py +2 -8
  2. utils/general.py +3 -3
export.py CHANGED
@@ -218,14 +218,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
218
  # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
219
  try:
220
  assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
221
- try:
222
- import tensorrt as trt
223
- except Exception:
224
- s = f"\n{prefix} tensorrt not found and is required by YOLOv5"
225
- LOGGER.info(f"{s}, attempting auto-update...")
226
- r = '-U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com'
227
- LOGGER.info(subprocess.check_output(f"pip install {r}", shell=True).decode())
228
- import tensorrt as trt
229
 
230
  if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
231
  grid = model.model[-1].anchor_grid
 
218
  # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
219
  try:
220
  assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
221
+ check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
222
+ import tensorrt as trt
 
 
 
 
 
 
223
 
224
  if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
225
  grid = model.model[-1].anchor_grid
utils/general.py CHANGED
@@ -321,7 +321,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
321
 
322
 
323
  @try_except
324
- def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
325
  # Check installed dependencies meet requirements (pass *.txt file or list of packages)
326
  prefix = colorstr('red', 'bold', 'requirements:')
327
  check_python() # check python version
@@ -334,7 +334,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
334
  requirements = [x for x in requirements if x not in exclude]
335
 
336
  n = 0 # number of packages updates
337
- for r in requirements:
338
  try:
339
  pkg.require(r)
340
  except Exception: # DistributionNotFound or VersionConflict if requirements not met
@@ -343,7 +343,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
343
  LOGGER.info(f"{s}, attempting auto-update...")
344
  try:
345
  assert check_online(), f"'pip install {r}' skipped (offline)"
346
- LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
347
  n += 1
348
  except Exception as e:
349
  LOGGER.warning(f'{prefix} {e}')
 
321
 
322
 
323
  @try_except
324
+ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
325
  # Check installed dependencies meet requirements (pass *.txt file or list of packages)
326
  prefix = colorstr('red', 'bold', 'requirements:')
327
  check_python() # check python version
 
334
  requirements = [x for x in requirements if x not in exclude]
335
 
336
  n = 0 # number of packages updates
337
+ for i, r in enumerate(requirements):
338
  try:
339
  pkg.require(r)
340
  except Exception: # DistributionNotFound or VersionConflict if requirements not met
 
343
  LOGGER.info(f"{s}, attempting auto-update...")
344
  try:
345
  assert check_online(), f"'pip install {r}' skipped (offline)"
346
+ LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
347
  n += 1
348
  except Exception as e:
349
  LOGGER.warning(f'{prefix} {e}')