Update check_requirements() with `cmds=()` argument (#7543)
Browse files- export.py +2 -8
- utils/general.py +3 -3
export.py
CHANGED
@@ -218,14 +218,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
|
218 |
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
|
219 |
try:
|
220 |
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
|
221 |
-
|
222 |
-
|
223 |
-
except Exception:
|
224 |
-
s = f"\n{prefix} tensorrt not found and is required by YOLOv5"
|
225 |
-
LOGGER.info(f"{s}, attempting auto-update...")
|
226 |
-
r = '-U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com'
|
227 |
-
LOGGER.info(subprocess.check_output(f"pip install {r}", shell=True).decode())
|
228 |
-
import tensorrt as trt
|
229 |
|
230 |
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
231 |
grid = model.model[-1].anchor_grid
|
|
|
218 |
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
|
219 |
try:
|
220 |
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
|
221 |
+
check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
|
222 |
+
import tensorrt as trt
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
225 |
grid = model.model[-1].anchor_grid
|
utils/general.py
CHANGED
@@ -321,7 +321,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
|
|
321 |
|
322 |
|
323 |
@try_except
|
324 |
-
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
|
325 |
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
|
326 |
prefix = colorstr('red', 'bold', 'requirements:')
|
327 |
check_python() # check python version
|
@@ -334,7 +334,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
|
|
334 |
requirements = [x for x in requirements if x not in exclude]
|
335 |
|
336 |
n = 0 # number of packages updates
|
337 |
-
for r in requirements:
|
338 |
try:
|
339 |
pkg.require(r)
|
340 |
except Exception: # DistributionNotFound or VersionConflict if requirements not met
|
@@ -343,7 +343,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
|
|
343 |
LOGGER.info(f"{s}, attempting auto-update...")
|
344 |
try:
|
345 |
assert check_online(), f"'pip install {r}' skipped (offline)"
|
346 |
-
LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
|
347 |
n += 1
|
348 |
except Exception as e:
|
349 |
LOGGER.warning(f'{prefix} {e}')
|
|
|
321 |
|
322 |
|
323 |
@try_except
|
324 |
+
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
325 |
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
|
326 |
prefix = colorstr('red', 'bold', 'requirements:')
|
327 |
check_python() # check python version
|
|
|
334 |
requirements = [x for x in requirements if x not in exclude]
|
335 |
|
336 |
n = 0 # number of packages updates
|
337 |
+
for i, r in enumerate(requirements):
|
338 |
try:
|
339 |
pkg.require(r)
|
340 |
except Exception: # DistributionNotFound or VersionConflict if requirements not met
|
|
|
343 |
LOGGER.info(f"{s}, attempting auto-update...")
|
344 |
try:
|
345 |
assert check_online(), f"'pip install {r}' skipped (offline)"
|
346 |
+
LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
|
347 |
n += 1
|
348 |
except Exception as e:
|
349 |
LOGGER.warning(f'{prefix} {e}')
|