diff --git a/.gitattributes b/.gitattributes
index c7d9f3332a950355d5a77d85000f05e6f45435ea..b442ae273b4bfd119296a7602d2cec21cd47ee17 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -32,3 +32,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+figure/figure1.png filter=lfs diff=lfs merge=lfs -text
+figure/figure2.png filter=lfs diff=lfs merge=lfs -text
+figure/figure3.png filter=lfs diff=lfs merge=lfs -text
+ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
+ultralytics/yolo/v8/detect/night_motorbikes.mp4 filter=lfs diff=lfs merge=lfs -text
+YOLOv8_DeepSORT_TRACKING_SCRIPT.ipynb filter=lfs diff=lfs merge=lfs -text
+YOLOv8_Detection_Tracking_CustomData_Complete.ipynb filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ef69c309cee09a99ca81a02b4fa4b0fad608f7a4
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,150 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# datasets and projects
+datasets/
+runs/
+wandb/
+
+.DS_Store
+
+# Neural Network weights -----------------------------------------------------------------------------------------------
+*.weights
+*.pt
+*.pb
+*.onnx
+*.engine
+*.mlmodel
+*.torchscript
+*.tflite
+*.h5
+*_saved_model/
+*_web_model/
+*_openvino_model/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..480127f24b4f28e2dfd794f72c7f26f9168082f9
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,65 @@
+# Define hooks for code formatting
+# These are applied to any updated files in a commit if the user has installed and linked the commit hooks
+
+default_language_version:
+ python: python3.8
+
+exclude: 'docs/'
+# Define bot property if installed via https://github.com/marketplace/pre-commit-ci
+ci:
+ autofix_prs: true
+ autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
+ autoupdate_schedule: monthly
+ # submodules: true
+
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.3.0
+ hooks:
+ # - id: end-of-file-fixer
+ - id: trailing-whitespace
+ - id: check-case-conflict
+ - id: check-yaml
+ - id: check-toml
+ - id: pretty-format-json
+ - id: check-docstring-first
+
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v2.37.3
+ hooks:
+ - id: pyupgrade
+ name: Upgrade code
+ args: [ --py37-plus ]
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+ name: Sort imports
+
+ - repo: https://github.com/pre-commit/mirrors-yapf
+ rev: v0.32.0
+ hooks:
+ - id: yapf
+ name: YAPF formatting
+
+ - repo: https://github.com/executablebooks/mdformat
+ rev: 0.7.16
+ hooks:
+ - id: mdformat
+ name: MD formatting
+ additional_dependencies:
+ - mdformat-gfm
+ - mdformat-black
+ # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
+
+ - repo: https://github.com/PyCQA/flake8
+ rev: 5.0.4
+ hooks:
+ - id: flake8
+ name: PEP8
+
+ #- repo: https://github.com/asottile/yesqa
+ # rev: v1.4.0
+ # hooks:
+ # - id: yesqa
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e9c66afa07f59291cfbec1ca28772501fdf2d76
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,113 @@
+## Contributing to YOLOv8 🚀
+
+We love your input! We want to make contributing to YOLOv8 as easy and transparent as possible, whether it's:
+
+- Reporting a bug
+- Discussing the current state of the code
+- Submitting a fix
+- Proposing a new feature
+- Becoming a maintainer
+
+YOLOv8 works so well due to our combined community effort, and for every small improvement you contribute you will be
+helping push the frontiers of what's possible in AI 🚀!
+
+## Submitting a Pull Request (PR) 🛠️
+
+Submitting a PR is easy! This example shows how to submit a PR for updating `requirements.txt` in 4 steps:
+
+### 1. Select File to Update
+
+Select `requirements.txt` to update by clicking on it in GitHub.
+
+

+
+### 2. Click 'Edit this file'
+
+Button is in top-right corner.
+
+
+
+### 3. Make Changes
+
+Change `matplotlib` version from `3.2.2` to `3.3`.
+
+
+
+### 4. Preview Changes and Submit PR
+
+Click on the **Preview changes** tab to verify your updates. At the bottom of the screen select 'Create a **new branch**
+for this commit', assign your branch a descriptive name such as `fix/matplotlib_version` and click the green **Propose
+changes** button. All done, your PR is now submitted to YOLOv8 for review and approval 😃!
+
+
+
+### PR recommendations
+
+To allow your work to be integrated as seamlessly as possible, we advise you to:
+
+- ✅ Verify your PR is **up-to-date** with `ultralytics/ultralytics` `master` branch. If your PR is behind, you can update
+ your code by clicking the 'Update branch' button or by running `git pull` and `git merge master` locally.
+
+
+
+- ✅ Verify all YOLOv8 Continuous Integration (CI) **checks are passing**.
+
+
+
+- ✅ Reduce changes to the absolute **minimum** required for your bug fix or feature addition. _"It is not daily increase
+ but daily decrease, hack away the unessential. The closer to the source, the less wastage there is."_ β Bruce Lee
+
+### Docstrings
+
+Not all functions or classes require docstrings, but when they do, we follow the [Google-style docstring format](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). Here is an example:
+
+```python
+"""
+ What the function does - performs nms on given detection predictions
+
+ Args:
+ arg1: The description of the 1st argument
+ arg2: The description of the 2nd argument
+
+ Returns:
+ What the function returns. Empty if nothing is returned
+
+ Raises:
+ Exception Class: When and why this exception can be raised by the function.
+"""
+```
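+
+As a concrete illustration, here is what that format looks like attached to a complete function. The function below is
+hypothetical and illustrative only; its name, arguments and behaviour are not part of the Ultralytics API:
+
+```python
+def rescale_boxes(boxes, gain):
+    """
+    Rescale bounding boxes by a constant gain factor (hypothetical example).
+
+    Args:
+        boxes: A list of [x1, y1, x2, y2] box coordinates.
+        gain: Multiplicative scaling factor applied to every coordinate.
+
+    Returns:
+        A list of rescaled boxes. Empty if no boxes are passed in.
+
+    Raises:
+        ValueError: If gain is not a positive number.
+    """
+    if gain <= 0:
+        raise ValueError("gain must be positive")
+    return [[coordinate * gain for coordinate in box] for box in boxes]
+```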
+
+## Submitting a Bug Report 🐛
+
+If you spot a problem with YOLOv8, please submit a Bug Report!
+
+For us to start investigating a possible problem we need to be able to reproduce it ourselves first. We've created a few
+short guidelines below to help users provide what we need in order to get started.
+
+When asking a question, people will be better able to provide help if you provide **code** that they can easily
+understand and use to **reproduce** the problem. This is referred to by community members as creating
+a [minimum reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) (a short sample snippet
+follows the list below). Your code that reproduces the problem should be:
+
+- ✅ **Minimal** – Use as little code as possible that still produces the same problem
+- ✅ **Complete** – Provide **all** parts someone else needs to reproduce your problem in the question itself
+- ✅ **Reproducible** – Test the code you're about to provide to make sure it reproduces the problem
+
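+For example, a minimal reproducible example for this repository can be as small as the snippet below (a sketch based on
+the Python interface shown in `ultralytics/models/README.md`; replace the last call with whatever triggers your issue):
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8n.yaml")  # build a YOLOv8n model from scratch
+model.info()  # replace with the call that demonstrates the problem
+```
+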
+In addition to the above requirements, for [Ultralytics](https://ultralytics.com/) to provide assistance your code
+should be:
+
+- ✅ **Current** – Verify that your code is up-to-date with current
+ GitHub [master](https://github.com/ultralytics/ultralytics/tree/main), and if necessary `git pull` or `git clone` a new
+ copy to ensure your problem has not already been resolved by previous commits.
+- ✅ **Unmodified** – Your problem must be reproducible without any modifications to the codebase in this
+ repository. [Ultralytics](https://ultralytics.com/) does not provide support for custom code ⚠️.
+
+If you believe your problem meets all of the above criteria, please close this issue and raise a new one using the 🐛
+**Bug Report** [template](https://github.com/ultralytics/ultralytics/issues/new/choose) and providing
+a [minimum reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) to help us better
+understand and diagnose your problem.
+
+## License
+
+By contributing, you agree that your contributions will be licensed under
+the [GPL-3.0 license](https://choosealicense.com/licenses/gpl-3.0/)
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..1635ec154384b8857c15841c27f1845f90d6b988
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,5 @@
+include *.md
+include requirements.txt
+include LICENSE
+include setup.py
+recursive-include ultralytics *.yaml
diff --git a/README.md b/README.md
index fdd03d7d6c6c1decda050b04e83df02f473f5805..acec568fa30cba210950bab8b08e7590455a4bf6 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,82 @@
----
-title: YOLOv8 Real Time
-emoji: π’
-colorFrom: pink
-colorTo: green
-sdk: gradio
-sdk_version: 3.29.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# YOLOv8 Object Detection with DeepSORT Tracking (ID + Trails)
+
+## Google Colab File Link (A Single Click Solution)
+The Google Colab file for YOLOv8 object detection and tracking is linked below. It is a single-click implementation: open it in Google Colab, select GPU as the runtime type, and click Run All.
+
+[`Google Colab File`](https://colab.research.google.com/drive/1U6cnTQ0JwCg4kdHxYSl2NAhU4wK18oAu?usp=sharing)
+
+## Object Detection and Tracking (ID + Trails) using YOLOv8 on Custom Data
+## Google Colab File Link (A Single Click Solution)
+[`Google Colab File`](https://colab.research.google.com/drive/1dEpI2k3m1i0vbvB4bNqPRQUO0gSBTz25?usp=sharing)
+
+## YOLOv8 Segmentation with DeepSORT Object Tracking
+
+[`Github Repo Link`](https://github.com/MuhammadMoinFaisal/YOLOv8_Segmentation_DeepSORT_Object_Tracking.git)
+
+## Steps to Run the Code
+
+- Clone the repository
+```
+git clone https://github.com/MuhammadMoinFaisal/YOLOv8-DeepSORT-Object-Tracking.git
+```
+- Go to the cloned folder.
+```
+cd YOLOv8-DeepSORT-Object-Tracking
+```
+- Install the dependencies
+```
+pip install -e '.[dev]'
+
+```
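+
+- (Optional) Verify the editable install by printing the package version; this simply reads the `__version__` string defined in `ultralytics/__init__.py`
+```
+python -c "import ultralytics; print(ultralytics.__version__)"
+```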
+
+- Change into the detect directory.
+```
+cd ultralytics/yolo/v8/detect
+
+```
+- Download the DeepSORT files from Google Drive
+```
+
+https://drive.google.com/drive/folders/1kna8eWGrSfzaR6DtNJ8_GchGgPMv3VC8?usp=sharing
+```
+- After downloading the DeepSORT zip file from the drive, unzip it, go into the subfolders, and place the deep_sort_pytorch folder into the ultralytics/yolo/v8/detect folder
+
+- Download a sample video from Google Drive
+```
+gdown "https://drive.google.com/uc?id=1rjBn8Fl1E_9d0EMVtL24S9aNQOJAveR5&confirm=t"
+```
+
+- Run the code with the command mentioned below.
+
+- For YOLOv8 object detection + tracking
+```
+python predict.py model=yolov8l.pt source="test3.mp4" show=True
+```
+- For YOLOv8 object detection + tracking + vehicle counting:
+- Download the updated predict.py file from Google Drive and place it into the ultralytics/yolo/v8/detect folder
+- Google Drive Link
+```
+https://drive.google.com/drive/folders/1awlzTGHBBAn_2pKCkLFADMd1EN_rJETW?usp=sharing
+```
+- Then run the command below for YOLOv8 object detection + tracking + vehicle counting
+```
+python predict.py model=yolov8l.pt source="test3.mp4" show=True
+```
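+
+- The same command pattern should also work with the other YOLOv8 checkpoints; for example, the smaller `yolov8n.pt` model gives a faster (lower-accuracy) run on the same video
+```
+python predict.py model=yolov8n.pt source="test3.mp4" show=True
+```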
+
+### RESULTS
+
+#### Vehicle Detection, Tracking and Counting
+
+
+#### Vehicle Detection, Tracking and Counting
+
+
+
+### Watch the Complete Step-by-Step Explanation
+
+- Video Tutorial Link [`YouTube Link`](https://www.youtube.com/watch?v=9jRRZ-WL698)
+
+
+[](https://www.youtube.com/watch?v=9jRRZ-WL698)
+
diff --git a/YOLOv8_DeepSORT_TRACKING_SCRIPT.ipynb b/YOLOv8_DeepSORT_TRACKING_SCRIPT.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ba893bf990a58c3ab754e748230a83500f8217fa
--- /dev/null
+++ b/YOLOv8_DeepSORT_TRACKING_SCRIPT.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0918af0bfa0ef2e0e9d26d9a8b06e2d706f5a5685d4e19eb58877a8036092ac
+size 16618677
diff --git a/YOLOv8_Detection_Tracking_CustomData_Complete.ipynb b/YOLOv8_Detection_Tracking_CustomData_Complete.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..7a8786db0d3690d707223a700bc5585bea2a395e
--- /dev/null
+++ b/YOLOv8_Detection_Tracking_CustomData_Complete.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d514434f9b1d8f5f3a7fb72e782f82d3c136523a7fc7bb41c2a2a390f4aa783
+size 22625415
diff --git a/figure/figure1.png b/figure/figure1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5aca066a69033fc59f2a947f762f3253eb04c5c
--- /dev/null
+++ b/figure/figure1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:664385ae049b026377706815fe377fe731aef846900f0343d50a214a269b4707
+size 2814838
diff --git a/figure/figure2.png b/figure/figure2.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f3b640b32eb202e7528ccef1cc1045634ed04ae
--- /dev/null
+++ b/figure/figure2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2c383e5fcfc524b385e41d33b0bbf56320a3a81d3a96f9500f7f254009c8f03
+size 2632436
diff --git a/figure/figure3.png b/figure/figure3.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ef17e500eed453b2c5c2c70c04dec90f36ac0b7
--- /dev/null
+++ b/figure/figure3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb719fd4505ae476bebe57131165764d11e6193d165cdecf70a760100cf6551f
+size 2941484
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f71d4bd28dcff828114c35247a95440685c7c5ad
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,95 @@
+site_name: Ultralytics Docs
+repo_url: https://github.com/ultralytics/ultralytics
+repo_name: Ultralytics
+
+theme:
+ name: "material"
+ logo: https://github.com/ultralytics/assets/raw/main/logo/Ultralytics-logomark-white.png
+ icon:
+ repo: fontawesome/brands/github
+ admonition:
+ note: octicons/tag-16
+ abstract: octicons/checklist-16
+ info: octicons/info-16
+ tip: octicons/squirrel-16
+ success: octicons/check-16
+ question: octicons/question-16
+ warning: octicons/alert-16
+ failure: octicons/x-circle-16
+ danger: octicons/zap-16
+ bug: octicons/bug-16
+ example: octicons/beaker-16
+ quote: octicons/quote-16
+
+ palette:
+ # Palette toggle for light mode
+ - scheme: default
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+
+ # Palette toggle for dark mode
+ - scheme: slate
+ toggle:
+ icon: material/brightness-4
+ name: Switch to light mode
+ features:
+ - content.code.annotate
+ - content.tooltips
+ - search.highlight
+ - search.share
+ - search.suggest
+ - toc.follow
+
+extra_css:
+ - stylesheets/style.css
+
+markdown_extensions:
+ # Div text decorators
+ - admonition
+ - pymdownx.details
+ - pymdownx.superfences
+ - tables
+ - attr_list
+ - def_list
+ # Syntax highlight
+ - pymdownx.highlight:
+ anchor_linenums: true
+ - pymdownx.inlinehilite
+ - pymdownx.snippets
+
+ # Button
+ - attr_list
+
+ # Content tabs
+ - pymdownx.superfences
+ - pymdownx.tabbed:
+ alternate_style: true
+
+ # Highlight
+ - pymdownx.critic
+ - pymdownx.caret
+ - pymdownx.keys
+ - pymdownx.mark
+ - pymdownx.tilde
+plugins:
+ - mkdocstrings
+
+# Primary navigation
+nav:
+ - Quickstart: quickstart.md
+ - CLI: cli.md
+ - Python Interface: sdk.md
+ - Configuration: config.md
+ - Customization Guide: engine.md
+ - Ultralytics HUB: hub.md
+ - iOS and Android App: app.md
+ - Reference:
+ - Python Model interface: reference/model.md
+ - Engine:
+ - Trainer: reference/base_trainer.md
+ - Validator: reference/base_val.md
+ - Predictor: reference/base_pred.md
+ - Exporter: reference/exporter.md
+ - nn Module: reference/nn.md
+ - operations: reference/ops.md
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fbf5aaacfa1c0b665ef67a853d5c18f941b5cd8d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,46 @@
+# Ultralytics requirements
+# Usage: pip install -r requirements.txt
+
+# Base ----------------------------------------
+hydra-core>=1.2.0
+matplotlib>=3.2.2
+numpy>=1.18.5
+opencv-python>=4.1.1
+Pillow>=7.1.2
+PyYAML>=5.3.1
+requests>=2.23.0
+scipy>=1.4.1
+torch>=1.7.0
+torchvision>=0.8.1
+tqdm>=4.64.0
+
+# Logging -------------------------------------
+tensorboard>=2.4.1
+# clearml
+# comet
+
+# Plotting ------------------------------------
+pandas>=1.1.4
+seaborn>=0.11.0
+
+# Export --------------------------------------
+# coremltools>=6.0 # CoreML export
+# onnx>=1.12.0 # ONNX export
+# onnx-simplifier>=0.4.1 # ONNX simplifier
+# nvidia-pyindex # TensorRT export
+# nvidia-tensorrt # TensorRT export
+# scikit-learn==0.19.2 # CoreML quantization
+# tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
+# tensorflowjs>=3.9.0 # TF.js export
+# openvino-dev # OpenVINO export
+
+# Extras --------------------------------------
+ipython # interactive notebook
+psutil # system utilization
+thop>=0.1.1 # FLOPs computation
+# albumentations>=1.0.3
+# pycocotools>=2.0.6 # COCO mAP
+# roboflow
+
+# HUB -----------------------------------------
+GitPython>=3.1.24
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..d7c4cb3e1a4d34291816835b071ba0d75243d79b
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,54 @@
+# Project-wide configuration file, can be used for package metadata and other tool configurations
+# Example usage: global configuration for PEP8 (via flake8) setting or default pytest arguments
+# Local usage: pip install pre-commit, pre-commit run --all-files
+
+[metadata]
+license_file = LICENSE
+description_file = README.md
+
+[tool:pytest]
+norecursedirs =
+ .git
+ dist
+ build
+addopts =
+ --doctest-modules
+ --durations=25
+ --color=yes
+
+[flake8]
+max-line-length = 120
+exclude = .tox,*.egg,build,temp
+select = E,W,F
+doctests = True
+verbose = 2
+# https://pep8.readthedocs.io/en/latest/intro.html#error-codes
+format = pylint
+# see: https://www.flake8rules.com/
+ignore = E731,F405,E402,F401,W504,E127,E231,E501,F403
+ # E731: Do not assign a lambda expression, use a def
+ # F405: name may be undefined, or defined from star imports: module
+ # E402: module level import not at top of file
+ # F401: module imported but unused
+ # W504: line break after binary operator
+ # E127: continuation line over-indented for visual indent
+ # E231: missing whitespace after ',', ';', or ':'
+ # E501: line too long
+ # F403: 'from module import *' used; unable to detect undefined names
+
+[isort]
+# https://pycqa.github.io/isort/docs/configuration/options.html
+line_length = 120
+# see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html
+multi_line_output = 0
+
+[yapf]
+based_on_style = pep8
+spaces_before_comment = 2
+COLUMN_LIMIT = 120
+COALESCE_BRACKETS = True
+SPACES_AROUND_POWER_OPERATOR = True
+SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = False
+SPLIT_BEFORE_CLOSING_BRACKET = False
+SPLIT_BEFORE_FIRST_ARGUMENT = False
+# EACH_DICT_ENTRY_ON_SEPARATE_LINE = False
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5d13d953b4ac461dc8798c16e93dd19796d84bb
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,53 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import re
+from pathlib import Path
+
+import pkg_resources as pkg
+from setuptools import find_packages, setup
+
+# Settings
+FILE = Path(__file__).resolve()
+ROOT = FILE.parent # root directory
+README = (ROOT / "README.md").read_text(encoding="utf-8")
+REQUIREMENTS = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements((ROOT / 'requirements.txt').read_text())]
+
+
+def get_version():
+ file = ROOT / 'ultralytics/__init__.py'
+ return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(), re.M)[1]
+
+
+setup(
+ name="ultralytics", # name of pypi package
+ version=get_version(), # version of pypi package
+ python_requires=">=3.7.0",
+ license='GPL-3.0',
+ description='Ultralytics YOLOv8 and HUB',
+ long_description=README,
+ long_description_content_type="text/markdown",
+ url="https://github.com/ultralytics/ultralytics",
+ project_urls={
+ 'Bug Reports': 'https://github.com/ultralytics/ultralytics/issues',
+ 'Funding': 'https://ultralytics.com',
+ 'Source': 'https://github.com/ultralytics/ultralytics',},
+ author="Ultralytics",
+ author_email='hello@ultralytics.com',
+ packages=find_packages(), # required
+ include_package_data=True,
+ install_requires=REQUIREMENTS,
+ extras_require={
+ 'dev':
+ ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material'],},
+ classifiers=[
+ "Intended Audience :: Developers", "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10",
+ "Topic :: Software Development", "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Image Recognition", "Operating System :: POSIX :: Linux",
+ "Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
+ keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
+ entry_points={
+ 'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli'],})
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dcb7110db6eed6e668354a20d00578feb03e811
--- /dev/null
+++ b/ultralytics/__init__.py
@@ -0,0 +1,9 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+__version__ = "8.0.3"
+
+from ultralytics.hub import checks
+from ultralytics.yolo.engine.model import YOLO
+from ultralytics.yolo.utils import ops
+
+__all__ = ["__version__", "YOLO", "hub", "checks"] # allow simpler import
diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c945d5025071f0e920c8b24eaaf606579f03e1a
--- /dev/null
+++ b/ultralytics/hub/__init__.py
@@ -0,0 +1,133 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import os
+import shutil
+
+import psutil
+import requests
+from IPython import display # to display images and clear console output
+
+from ultralytics.hub.auth import Auth
+from ultralytics.hub.session import HubTrainingSession
+from ultralytics.hub.utils import PREFIX, split_key
+from ultralytics.yolo.utils import LOGGER, emojis, is_colab
+from ultralytics.yolo.utils.torch_utils import select_device
+from ultralytics.yolo.v8.detect import DetectionTrainer
+
+
+def checks(verbose=True):
+ if is_colab():
+ shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
+
+ if verbose:
+ # System info
+ gib = 1 << 30 # bytes per GiB
+ ram = psutil.virtual_memory().total
+ total, used, free = shutil.disk_usage("/")
+ display.clear_output()
+ s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
+ else:
+ s = ''
+
+ select_device(newline=False)
+ LOGGER.info(f'Setup complete ✅ {s}')
+
+
+def start(key=''):
+ # Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
+ def request_api_key(attempts=0):
+ """Prompt the user to input their API key"""
+ import getpass
+
+ max_attempts = 3
+ tries = f"Attempt {str(attempts + 1)} of {max_attempts}" if attempts > 0 else ""
+ LOGGER.info(f"{PREFIX}Login. {tries}")
+ input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
+ auth.api_key, model_id = split_key(input_key)
+ if not auth.authenticate():
+ attempts += 1
+ LOGGER.warning(f"{PREFIX}Invalid API key β οΈ\n")
+ if attempts < max_attempts:
+ return request_api_key(attempts)
+ raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
+ else:
+ return model_id
+
+ try:
+ api_key, model_id = split_key(key)
+ auth = Auth(api_key) # attempts cookie login if no api key is present
+ attempts = 1 if len(key) else 0
+ if not auth.get_state():
+ if len(key):
+ LOGGER.warning(f"{PREFIX}Invalid API key β οΈ\n")
+ model_id = request_api_key(attempts)
+ LOGGER.info(f"{PREFIX}Authenticated β
")
+ if not model_id:
+ raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
+ session = HubTrainingSession(model_id=model_id, auth=auth)
+ session.check_disk_space()
+
+ # TODO: refactor, hardcoded for v8
+ args = session.model.copy()
+ args.pop("id")
+ args.pop("status")
+ args.pop("weights")
+ args["data"] = "coco128.yaml"
+ args["model"] = "yolov8n.yaml"
+ args["batch_size"] = 16
+ args["imgsz"] = 64
+
+ trainer = DetectionTrainer(overrides=args)
+ session.register_callbacks(trainer)
+ setattr(trainer, 'hub_session', session)
+ trainer.train()
+ except Exception as e:
+ LOGGER.warning(f"{PREFIX}{e}")
+
+
+def reset_model(key=''):
+ # Reset a trained model to an untrained state
+ api_key, model_id = split_key(key)
+ r = requests.post('https://api.ultralytics.com/model-reset', json={"apiKey": api_key, "modelId": model_id})
+
+ if r.status_code == 200:
+ LOGGER.info(f"{PREFIX}model reset successfully")
+ return
+ LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")
+
+
+def export_model(key='', format='torchscript'):
+ # Export a model to all formats
+ api_key, model_id = split_key(key)
+ formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
+ 'ultralytics_tflite', 'ultralytics_coreml')
+ assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
+
+ r = requests.post('https://api.ultralytics.com/export',
+ json={
+ "apiKey": api_key,
+ "modelId": model_id,
+ "format": format})
+ assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
+ LOGGER.info(f"{PREFIX}{format} export started β
")
+
+
+def get_export(key='', format='torchscript'):
+ # Get an exported model dictionary with download URL
+ api_key, model_id = split_key(key)
+ formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
+ 'ultralytics_tflite', 'ultralytics_coreml')
+ assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
+
+ r = requests.post('https://api.ultralytics.com/get-export',
+ json={
+ "apiKey": api_key,
+ "modelId": model_id,
+ "format": format})
+ assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
+ return r.json()
+
+
+# temp. For checking
+if __name__ == "__main__":
+ start(key="b3fba421be84a20dbe68644e14436d1cce1b0a0aaa_HeMfHgvHsseMPhdq7Ylz")
diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py
new file mode 100644
index 0000000000000000000000000000000000000000..e38f228bca451467b2a622a378a469ddef1d0c13
--- /dev/null
+++ b/ultralytics/hub/auth.py
@@ -0,0 +1,70 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import requests
+
+from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
+from ultralytics.yolo.utils import is_colab
+
+API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys"
+
+
+class Auth:
+ id_token = api_key = model_key = False
+
+ def __init__(self, api_key=None):
+ self.api_key = self._clean_api_key(api_key)
+ self.authenticate() if self.api_key else self.auth_with_cookies()
+
+ @staticmethod
+ def _clean_api_key(key: str) -> str:
+ """Strip model from key if present"""
+ separator = "_"
+ return key.split(separator)[0] if separator in key else key
+
+ def authenticate(self) -> bool:
+ """Attempt to authenticate with server"""
+ try:
+ header = self.get_auth_header()
+ if header:
+ r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
+ if not r.json().get('success', False):
+ raise ConnectionError("Unable to authenticate.")
+ return True
+ raise ConnectionError("User has not authenticated locally.")
+ except ConnectionError:
+ self.id_token = self.api_key = False # reset invalid
+ return False
+
+ def auth_with_cookies(self) -> bool:
+ """
+ Attempt to fetch authentication via cookies and set id_token.
+ User must be logged in to HUB and running in a supported browser.
+ """
+ if not is_colab():
+ return False # Currently only works with Colab
+ try:
+ authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
+ if authn.get("success", False):
+ self.id_token = authn.get("data", {}).get("idToken", None)
+ self.authenticate()
+ return True
+ raise ConnectionError("Unable to fetch browser authentication details.")
+ except ConnectionError:
+ self.id_token = False # reset invalid
+ return False
+
+ def get_auth_header(self):
+ if self.id_token:
+ return {"authorization": f"Bearer {self.id_token}"}
+ elif self.api_key:
+ return {"x-api-key": self.api_key}
+ else:
+ return None
+
+ def get_state(self) -> bool:
+ """Get the authentication state"""
+ return self.id_token or self.api_key
+
+ def set_api_key(self, key: str):
+ """Get the authentication state"""
+ self.api_key = key
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d268fe85c15360584116daa9e29d17c93b8e9b
--- /dev/null
+++ b/ultralytics/hub/session.py
@@ -0,0 +1,122 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import signal
+import sys
+from pathlib import Path
+from time import sleep
+
+import requests
+
+from ultralytics import __version__
+from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
+from ultralytics.yolo.utils import LOGGER, is_colab, threaded
+
+AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
+
+session = None
+
+
+def signal_handler(signum, frame):
+ """Confirm exit: stop heartbeats on the active session, then exit."""
+ global session
+ LOGGER.info(f'Signal received. {signum} {frame}')
+ if isinstance(session, HubTrainingSession):
+ session.alive = False # stop the heartbeat loop
+ del session
+ sys.exit(signum)
+
+
+signal.signal(signal.SIGTERM, signal_handler)
+signal.signal(signal.SIGINT, signal_handler)
+
+
+class HubTrainingSession:
+
+ def __init__(self, model_id, auth):
+ self.agent_id = None # identifies which instance is communicating with server
+ self.model_id = model_id
+ self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
+ self.auth_header = auth.get_auth_header()
+ self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
+ self.t = {} # rate limit timers (seconds)
+ self.metrics_queue = {} # metrics queue
+ self.alive = True # for heartbeats
+ self.model = self._get_model()
+ self._heartbeats() # start heartbeats
+
+ def __del__(self):
+ # Class destructor
+ self.alive = False
+
+ def upload_metrics(self):
+ payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"}
+ smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
+
+ def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
+ # Upload a model to HUB
+ file = None
+ if Path(weights).is_file():
+ with open(weights, "rb") as f:
+ file = f.read()
+ if final:
+ smart_request(f'{self.api_url}/upload',
+ data={
+ "epoch": epoch,
+ "type": "final",
+ "map": map},
+ files={"best.pt": file},
+ headers=self.auth_header,
+ retry=10,
+ timeout=3600,
+ code=4)
+ else:
+ smart_request(f'{self.api_url}/upload',
+ data={
+ "epoch": epoch,
+ "type": "epoch",
+ "isBest": bool(is_best)},
+ headers=self.auth_header,
+ files={"last.pt": file},
+ code=3)
+
+ def _get_model(self):
+ # Returns model from database by id
+ api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
+ headers = self.auth_header
+
+ try:
+ r = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
+ data = r.json().get("data", None)
+ if not data:
+ return
+ assert data['data'], 'ERROR: Dataset may still be processing. Please wait a minute and try again.' # RF fix
+ self.model_id = data["id"]
+
+ return data
+ except requests.exceptions.ConnectionError as e:
+ raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
+
+ def check_disk_space(self):
+ if not check_dataset_disk_space(self.model['data']):
+ raise MemoryError("Not enough disk space")
+
+ # COMMENT: Should not be needed as HUB is now considered an integration and is in integrations_callbacks
+ # import ultralytics.yolo.utils.callbacks.hub as hub_callbacks
+ # @staticmethod
+ # def register_callbacks(trainer):
+ # for k, v in hub_callbacks.callbacks.items():
+ # trainer.add_callback(k, v)
+
+ @threaded
+ def _heartbeats(self):
+ while self.alive:
+ r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
+ json={
+ "agent": AGENT_NAME,
+ "agentId": self.agent_id},
+ headers=self.auth_header,
+ retry=0,
+ code=5,
+ thread=False)
+ self.agent_id = r.json().get('data', {}).get('agentId', None)
+ sleep(self.rate_limits['heartbeat'])
diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b925399b951a666305684f390d82b3d900515a1
--- /dev/null
+++ b/ultralytics/hub/utils.py
@@ -0,0 +1,150 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import os
+import shutil
+import threading
+import time
+
+import requests
+
+from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
+
+PREFIX = colorstr('Ultralytics: ')
+HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
+HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
+
+
+def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
+ # Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
+ gib = 1 << 30 # bytes per GiB
+ data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
+ total, used, free = (x / gib for x in shutil.disk_usage("/")) # bytes
+ LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
+ if data * sf < free:
+ return True # sufficient space
+ LOGGER.warning(f'{PREFIX}WARNING: Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
+ f'training cancelled ❌. Please free {data * sf - free:.1f} GB additional disk space and try again.')
+ return False # insufficient space
+
+
+def request_with_credentials(url: str) -> any:
+ """ Make an ajax request with cookies attached """
+ from google.colab import output # noqa
+ from IPython import display # noqa
+ display.display(
+ display.Javascript("""
+ window._hub_tmp = new Promise((resolve, reject) => {
+ const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
+ fetch("%s", {
+ method: 'POST',
+ credentials: 'include'
+ })
+ .then((response) => resolve(response.json()))
+ .then((json) => {
+ clearTimeout(timeout);
+ }).catch((err) => {
+ clearTimeout(timeout);
+ reject(err);
+ });
+ });
+ """ % url))
+ return output.eval_js("_hub_tmp")
+
+
+# Deprecated TODO: eliminate this function?
+def split_key(key=''):
+ """
+ Verify and split an 'api_key[sep]model_id' string, where sep is one of '_' or '.'
+
+ Args:
+ key (str): The model key to split. If not provided, the user will be prompted to enter it.
+
+ Returns:
+ Tuple[str, str]: A tuple containing the API key and model ID.
+ """
+
+ import getpass
+
+ error_string = emojis(f'{PREFIX}Invalid API key ⚠️\n') # error string
+ if not key:
+ key = getpass.getpass('Enter model key: ')
+ sep = '_' if '_' in key else '.' if '.' in key else None # separator
+ assert sep, error_string
+ api_key, model_id = key.split(sep)
+ assert len(api_key) and len(model_id), error_string
+ return api_key, model_id
+
+
+def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post", verbose=True, **kwargs):
+ """
+ Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
+
+ Args:
+ *args: Positional arguments to be passed to the requests function specified in method.
+ retry (int, optional): Number of retries to attempt before giving up. Default is 3.
+ timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
+ thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
+ code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
+ method (str, optional): The HTTP method to use for the request. Choices are 'post' and 'get'. Default is 'post'.
+ verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
+ **kwargs: Keyword arguments to be passed to the requests function specified in method.
+
+ Returns:
+ requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
+ """
+ retry_codes = (408, 500) # retry only these codes
+
+ def func(*func_args, **func_kwargs):
+ r = None # response
+ t0 = time.time() # initial time for timer
+ for i in range(retry + 1):
+ if (time.time() - t0) > timeout:
+ break
+ if method == 'post':
+ r = requests.post(*func_args, **func_kwargs) # i.e. post(url, data, json, files)
+ elif method == 'get':
+ r = requests.get(*func_args, **func_kwargs) # i.e. get(url, data, json, files)
+ if r.status_code == 200:
+ break
+ try:
+ m = r.json().get('message', 'No JSON message.')
+ except AttributeError:
+ m = 'Unable to read JSON.'
+ if i == 0:
+ if r.status_code in retry_codes:
+ m += f' Retrying {retry}x for {timeout}s.' if retry else ''
+ elif r.status_code == 429: # rate limit
+ h = r.headers # response headers
+ m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
+ f"Please retry after {h['Retry-After']}s."
+ if verbose:
+ LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
+ if r.status_code not in retry_codes:
+ return r
+ time.sleep(2 ** i) # exponential backoff
+ return r
+
+ if thread:
+ threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
+ else:
+ return func(*args, **kwargs)
+
+
+@TryExcept()
+def sync_analytics(cfg, all_keys=False, enabled=False):
+ """
+ Sync analytics data if enabled in the global settings
+
+ Args:
+ cfg (DictConfig): Configuration for the task and mode.
+ all_keys (bool): Sync all items, not just non-default values.
+ enabled (bool): For debugging.
+ """
+ if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
+ cfg = dict(cfg) # convert type from DictConfig to dict
+ if not all_keys:
+ cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CONFIG_DICT.get(k, None)} # retain non-default values
+ cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
+
+ # Send a request to the HUB API to sync analytics
+ smart_request(f'{HUB_API_ROOT}/v1/usage/anonymous', json=cfg, headers=None, code=3, retry=0, verbose=False)
diff --git a/ultralytics/models/README.md b/ultralytics/models/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d726576fa52c402217d98845d41dec0a7ea969d
--- /dev/null
+++ b/ultralytics/models/README.md
@@ -0,0 +1,36 @@
+## Models
+
+Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
+files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
+and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
+segmentation tasks.
+
+These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
+instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
+from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
+directory provides a great starting point for your custom model development needs.
+
+To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
+selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
+details at the Ultralytics [Docs](https://docs.ultralytics.com), and if you need help or have any questions, feel free
+to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
+
+### Usage
+
+Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
+
+```bash
+yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
+```
+
+They may also be used directly in a Python environment, and accept the same
+[arguments](https://docs.ultralytics.com/config/) as in the CLI example above:
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8n.yaml") # build a YOLOv8n model from scratch
+
+model.info() # display model information
+model.train(data="coco128.yaml", epochs=100) # train the model
+```
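+
+After training, the same CLI pattern can be used for the other modes. Below is a minimal sketch, assuming the
+`mode=predict` and `source` arguments described in the Ultralytics [Docs](https://docs.ultralytics.com):
+
+```bash
+yolo task=detect mode=predict model=yolov8n.pt source="path/to/image.jpg"
+```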
diff --git a/ultralytics/models/v3/yolov3-spp.yaml b/ultralytics/models/v3/yolov3-spp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5d6794f1d02e16563355c26aab7c96967941c92e
--- /dev/null
+++ b/ultralytics/models/v3/yolov3-spp.yaml
@@ -0,0 +1,47 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# darknet53 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [32, 3, 1]], # 0
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
+ [-1, 1, Bottleneck, [64]],
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
+ [-1, 2, Bottleneck, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
+ [-1, 8, Bottleneck, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
+ [-1, 8, Bottleneck, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
+ [-1, 4, Bottleneck, [1024]], # 10
+ ]
+
+# YOLOv3-SPP head
+head:
+ [[-1, 1, Bottleneck, [1024, False]],
+ [-1, 1, SPP, [512, [5, 9, 13]]],
+ [-1, 1, Conv, [1024, 3, 1]],
+ [-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
+
+ [-2, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
+ [-1, 1, Bottleneck, [512, False]],
+ [-1, 1, Bottleneck, [512, False]],
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
+
+ [-2, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P3
+ [-1, 1, Bottleneck, [256, False]],
+ [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
+
+ [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
diff --git a/ultralytics/models/v3/yolov3-tiny.yaml b/ultralytics/models/v3/yolov3-tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d7921d37800c92832c52ba506a902ac82ea96a24
--- /dev/null
+++ b/ultralytics/models/v3/yolov3-tiny.yaml
@@ -0,0 +1,38 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# YOLOv3-tiny backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [16, 3, 1]], # 0
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
+ [-1, 1, Conv, [32, 3, 1]],
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
+ [-1, 1, Conv, [64, 3, 1]],
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
+ [-1, 1, Conv, [128, 3, 1]],
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
+ [-1, 1, Conv, [256, 3, 1]],
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
+ [-1, 1, Conv, [512, 3, 1]],
+ [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
+ [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
+ ]
+
+# YOLOv3-tiny head
+head:
+ [[-1, 1, Conv, [1024, 3, 1]],
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
+
+ [-2, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
+ [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
+
+ [[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
+ ]
diff --git a/ultralytics/models/v3/yolov3.yaml b/ultralytics/models/v3/yolov3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3ecb642457ebfa2dea684ea714837da4a879ada9
--- /dev/null
+++ b/ultralytics/models/v3/yolov3.yaml
@@ -0,0 +1,47 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# darknet53 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [32, 3, 1]], # 0
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
+ [-1, 1, Bottleneck, [64]],
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
+ [-1, 2, Bottleneck, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
+ [-1, 8, Bottleneck, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
+ [-1, 8, Bottleneck, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
+ [-1, 4, Bottleneck, [1024]], # 10
+ ]
+
+# YOLOv3 head
+head:
+ [[-1, 1, Bottleneck, [1024, False]],
+ [-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, Conv, [1024, 3, 1]],
+ [-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
+
+ [-2, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
+ [-1, 1, Bottleneck, [512, False]],
+ [-1, 1, Bottleneck, [512, False]],
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
+
+ [-2, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P3
+ [-1, 1, Bottleneck, [256, False]],
+ [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
+
+ [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
diff --git a/ultralytics/models/v5/yolov5l.yaml b/ultralytics/models/v5/yolov5l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca3a85d364e84c90cf6df6457ecd51791baa7040
--- /dev/null
+++ b/ultralytics/models/v5/yolov5l.yaml
@@ -0,0 +1,44 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# YOLOv5 v6.0 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
+ [-1, 6, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+ [-1, 3, C3, [1024]],
+ [-1, 1, SPPF, [1024, 5]], # 9
+ ]
+
+# YOLOv5 v6.0 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 13
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 17 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 14], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 10], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
+
+ [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/ultralytics/models/v5/yolov5m.yaml b/ultralytics/models/v5/yolov5m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fddcc63b744a717f3e12cebd4508fbd3778bd188
--- /dev/null
+++ b/ultralytics/models/v5/yolov5m.yaml
@@ -0,0 +1,44 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.67 # model depth multiple
+width_multiple: 0.75 # layer channel multiple
+
+# YOLOv5 v6.0 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
+ [-1, 6, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+ [-1, 3, C3, [1024]],
+ [-1, 1, SPPF, [1024, 5]], # 9
+ ]
+
+# YOLOv5 v6.0 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 13
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 17 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 14], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 10], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
+
+ [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/ultralytics/models/v5/yolov5n.yaml b/ultralytics/models/v5/yolov5n.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..259e1cf4bd4ad9f8391f2876273b7c82849c42db
--- /dev/null
+++ b/ultralytics/models/v5/yolov5n.yaml
@@ -0,0 +1,44 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.33 # model depth multiple
+width_multiple: 0.25 # layer channel multiple
+
+# YOLOv5 v6.0 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
+ [-1, 6, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+ [-1, 3, C3, [1024]],
+ [-1, 1, SPPF, [1024, 5]], # 9
+ ]
+
+# YOLOv5 v6.0 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 13
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 17 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 14], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 10], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
+
+ [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/ultralytics/models/v5/yolov5s.yaml b/ultralytics/models/v5/yolov5s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e63349f9aa68d654bc9e55a19102e704ca32de9
--- /dev/null
+++ b/ultralytics/models/v5/yolov5s.yaml
@@ -0,0 +1,45 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.33 # model depth multiple
+width_multiple: 0.50 # layer channel multiple
+
+
+# YOLOv5 v6.0 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
+ [-1, 6, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+ [-1, 3, C3, [1024]],
+ [-1, 1, SPPF, [1024, 5]], # 9
+ ]
+
+# YOLOv5 v6.0 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 13
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 17 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 14], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 10], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
+
+ [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/ultralytics/models/v5/yolov5x.yaml b/ultralytics/models/v5/yolov5x.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8217affcb61800e70f29bd30b08fb980c9254b19
--- /dev/null
+++ b/ultralytics/models/v5/yolov5x.yaml
@@ -0,0 +1,44 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.33 # model depth multiple
+width_multiple: 1.25 # layer channel multiple
+
+# YOLOv5 v6.0 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
+ [-1, 6, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
+ [-1, 3, C3, [1024]],
+ [-1, 1, SPPF, [1024, 5]], # 9
+ ]
+
+# YOLOv5 v6.0 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 13
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 17 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 14], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 10], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
+
+ [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/ultralytics/models/v8/cls/yolov8l-cls.yaml b/ultralytics/models/v8/cls/yolov8l-cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bf981a8434cb9ae0d66f31c539fbfa2cbdda2dc7
--- /dev/null
+++ b/ultralytics/models/v8/cls/yolov8l-cls.yaml
@@ -0,0 +1,23 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 1000 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.00 # scales convolution channels
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, Classify, [nc]]
diff --git a/ultralytics/models/v8/cls/yolov8m-cls.yaml b/ultralytics/models/v8/cls/yolov8m-cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e91894ce7613efc0cf9fbac8b3800825e42a927
--- /dev/null
+++ b/ultralytics/models/v8/cls/yolov8m-cls.yaml
@@ -0,0 +1,23 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 1000 # number of classes
+depth_multiple: 0.67 # scales module repeats
+width_multiple: 0.75 # scales convolution channels
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, Classify, [nc]]
diff --git a/ultralytics/models/v8/cls/yolov8n-cls.yaml b/ultralytics/models/v8/cls/yolov8n-cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..29be2264b9f8e801d8b1dc044dd8c4cdf4cdc700
--- /dev/null
+++ b/ultralytics/models/v8/cls/yolov8n-cls.yaml
@@ -0,0 +1,23 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 1000 # number of classes
+depth_multiple: 0.33 # scales module repeats
+width_multiple: 0.25 # scales convolution channels
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, Classify, [nc]]
diff --git a/ultralytics/models/v8/cls/yolov8s-cls.yaml b/ultralytics/models/v8/cls/yolov8s-cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..00ddc55afe8ae8b4693399af29869ee50e6c50d3
--- /dev/null
+++ b/ultralytics/models/v8/cls/yolov8s-cls.yaml
@@ -0,0 +1,23 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 1000 # number of classes
+depth_multiple: 0.33 # scales module repeats
+width_multiple: 0.50 # scales convolution channels
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, Classify, [nc]]
diff --git a/ultralytics/models/v8/cls/yolov8x-cls.yaml b/ultralytics/models/v8/cls/yolov8x-cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..46c75d5a2a04b59632aa25c52602a9ada162c0e9
--- /dev/null
+++ b/ultralytics/models/v8/cls/yolov8x-cls.yaml
@@ -0,0 +1,23 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 1000 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.25 # scales convolution channels
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, Classify, [nc]]
diff --git a/ultralytics/models/v8/seg/yolov8l-seg.yaml b/ultralytics/models/v8/seg/yolov8l-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..235dc761eb3680955c33887000dc6ba8e32c8cd4
--- /dev/null
+++ b/ultralytics/models/v8/seg/yolov8l-seg.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.00 # scales convolution channels
+
+# YOLOv8.0l backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [512, True]]
+ - [-1, 1, SPPF, [512, 5]] # 9
+
+# YOLOv8.0l head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [512]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/seg/yolov8m-seg.yaml b/ultralytics/models/v8/seg/yolov8m-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..17c07f6bb0089de42190d467676c2434eed64bdc
--- /dev/null
+++ b/ultralytics/models/v8/seg/yolov8m-seg.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.67 # scales module repeats
+width_multiple: 0.75 # scales convolution channels
+
+# YOLOv8.0m backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [768, True]]
+ - [-1, 1, SPPF, [768, 5]] # 9
+
+# YOLOv8.0m head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [768]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/seg/yolov8n-seg.yaml b/ultralytics/models/v8/seg/yolov8n-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ffecc9d9ba366d075a8655dba84ebd117a14afaf
--- /dev/null
+++ b/ultralytics/models/v8/seg/yolov8n-seg.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.33 # scales module repeats
+width_multiple: 0.25 # scales convolution channels
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)
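+# Note: the Segment head arguments are read here as [nc, mask coefficients per box, prototype
+# channels], i.e. 32 coefficients combined with a 256-channel Proto feature map (an assumption
+# based on the Segment module signature, not spelled out in this file).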
diff --git a/ultralytics/models/v8/seg/yolov8s-seg.yaml b/ultralytics/models/v8/seg/yolov8s-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dc828a1cd290871a6304e3cce47289d21c397c93
--- /dev/null
+++ b/ultralytics/models/v8/seg/yolov8s-seg.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.33 # scales module repeats
+width_multiple: 0.50 # scales convolution channels
+
+# YOLOv8.0s backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0s head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/seg/yolov8x-seg.yaml b/ultralytics/models/v8/seg/yolov8x-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0572283c987401b0f1a270a9a129f40fa0f51860
--- /dev/null
+++ b/ultralytics/models/v8/seg/yolov8x-seg.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.25 # scales convolution channels
+
+# YOLOv8.0x backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [512, True]]
+ - [-1, 1, SPPF, [512, 5]] # 9
+
+# YOLOv8.0x head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [512]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/yolov8l.yaml b/ultralytics/models/v8/yolov8l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9ec170c3099bcf2d240fa61e0bf6b682c030d356
--- /dev/null
+++ b/ultralytics/models/v8/yolov8l.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.00 # scales convolution channels
+
+# YOLOv8.0l backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [512, True]]
+ - [-1, 1, SPPF, [512, 5]] # 9
+
+# YOLOv8.0l head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [512]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/yolov8m.yaml b/ultralytics/models/v8/yolov8m.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f97cf052afaa838952eab418fbbad7c15ce8610c
--- /dev/null
+++ b/ultralytics/models/v8/yolov8m.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.67 # scales module repeats
+width_multiple: 0.75 # scales convolution channels
+
+# YOLOv8.0m backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [768, True]]
+ - [-1, 1, SPPF, [768, 5]] # 9
+
+# YOLOv8.0m head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [768]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/yolov8n.yaml b/ultralytics/models/v8/yolov8n.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..83cf0801e257c69cec7fb46c0307b55dcc3acaeb
--- /dev/null
+++ b/ultralytics/models/v8/yolov8n.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.33 # scales module repeats
+width_multiple: 0.25 # scales convolution channels
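+# Scaling example (a rough illustration of how the multiples are applied, not an exact trace of the
+# model parser): with depth_multiple=0.33 a backbone entry such as [-1, 6, C2f, [256, True]] builds
+# round(6 * 0.33) = 2 C2f blocks, and with width_multiple=0.25 its 256 output channels become
+# 256 * 0.25 = 64.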
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/yolov8s.yaml b/ultralytics/models/v8/yolov8s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0c96d945cd33bb9228d09295cafccd45f5be13f7
--- /dev/null
+++ b/ultralytics/models/v8/yolov8s.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 0.33 # scales module repeats
+width_multiple: 0.50 # scales convolution channels
+
+# YOLOv8.0s backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0s head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/yolov8x.yaml b/ultralytics/models/v8/yolov8x.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..20e4070dd3813f808bfb85e307b78e0cc93b36dd
--- /dev/null
+++ b/ultralytics/models/v8/yolov8x.yaml
@@ -0,0 +1,40 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.25 # scales convolution channels
+
+# YOLOv8.0x backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [512, True]]
+ - [-1, 1, SPPF, [512, 5]] # 9
+
+# YOLOv8.0x head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]]  # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2f, [512]]  # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/models/v8/yolov8x6.yaml b/ultralytics/models/v8/yolov8x6.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8ffcdeae3a7831bb0fabb498e4ae580b3f223ee6
--- /dev/null
+++ b/ultralytics/models/v8/yolov8x6.yaml
@@ -0,0 +1,50 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+# Parameters
+nc: 80 # number of classes
+depth_multiple: 1.00 # scales module repeats
+width_multiple: 1.25 # scales convolution channels
+
+# YOLOv8.0x6 backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [512, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 9-P6/64
+ - [-1, 3, C2f, [512, True]]
+ - [-1, 1, SPPF, [512, 5]] # 11
+
+# YOLOv8.0x6 head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 8], 1, Concat, [1]] # cat backbone P5
+ - [-1, 3, C2, [512, False]] # 14
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+ - [-1, 3, C2, [512, False]] # 17
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+ - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 17], 1, Concat, [1]] # cat head P4
+ - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 14], 1, Concat, [1]] # cat head P5
+ - [-1, 3, C2, [512, False]] # 26 (P5/32-large)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 11], 1, Concat, [1]] # cat head P6
+ - [-1, 3, C2, [512, False]] # 29 (P6/64-xlarge)
+
+ - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)
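+# Note: unlike the P5 models above, this x6 variant adds a P6/64 backbone stage and a fourth
+# detection scale, which is generally intended for larger input resolutions.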
diff --git a/ultralytics/nn/__init__.py b/ultralytics/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bbecca057912dba09674433d84bf79502029175
--- /dev/null
+++ b/ultralytics/nn/autobackend.py
@@ -0,0 +1,381 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import json
+import platform
+from collections import OrderedDict, namedtuple
+from pathlib import Path
+from urllib.parse import urlparse
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+
+from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
+from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
+from ultralytics.yolo.utils.downloads import attempt_download, is_url
+from ultralytics.yolo.utils.ops import xywh2xyxy
+
+
+class AutoBackend(nn.Module):
+
+ def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
+ """
+ Ultralytics YOLO MultiBackend class for python inference on various backends
+
+ Args:
+ weights: the path to the weights file. Defaults to yolov8n.pt
+ device: The device to run the model on.
+ dnn: If you want to use OpenCV's DNN module to run the inference, set this to True. Defaults to
+ False
+            data: path to an optional dataset YAML that provides the class names. Defaults to None
+            fp16: If true, will use half precision. Defaults to False
+            fuse: whether to fuse the model or not. Defaults to True
+
+        Supported formats and their usage:
+ | Platform | weights |
+ |-----------------------|------------------|
+ | PyTorch | *.pt |
+ | TorchScript | *.torchscript |
+ | ONNX Runtime | *.onnx |
+ | ONNX OpenCV DNN | *.onnx --dnn |
+ | OpenVINO | *.xml |
+ | CoreML | *.mlmodel |
+ | TensorRT | *.engine |
+ | TensorFlow SavedModel | *_saved_model |
+ | TensorFlow GraphDef | *.pb |
+ | TensorFlow Lite | *.tflite |
+ | TensorFlow Edge TPU | *_edgetpu.tflite |
+ | PaddlePaddle | *_paddle_model |
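+
+        Example (a minimal usage sketch, assuming a local 'yolov8n.pt' checkpoint; not an exhaustive
+        reference):
+            >>> import torch
+            >>> from ultralytics.nn.autobackend import AutoBackend
+            >>> backend = AutoBackend('yolov8n.pt', device=torch.device('cpu'))
+            >>> backend.warmup(imgsz=(1, 3, 640, 640))
+            >>> im = torch.zeros(1, 3, 640, 640)  # BCHW, float32 in [0, 1]
+            >>> preds = backend(im)  # raw predictions in a backend-agnostic format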
+ """
+ super().__init__()
+ w = str(weights[0] if isinstance(weights, list) else weights)
+ nn_module = isinstance(weights, torch.nn.Module)
+ pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
+ fp16 &= pt or jit or onnx or engine or nn_module # FP16
+        nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCHW)
+ stride = 32 # default stride
+ cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
+ if not (pt or triton or nn_module):
+ w = attempt_download(w) # download if not local
+
+ # NOTE: special case: in-memory pytorch model
+ if nn_module:
+ model = weights.to(device)
+ model = model.fuse() if fuse else model
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
+ model.half() if fp16 else model.float()
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
+ pt = True
+ elif pt: # PyTorch
+ from ultralytics.nn.tasks import attempt_load_weights
+ model = attempt_load_weights(weights if isinstance(weights, list) else w,
+ device=device,
+ inplace=True,
+ fuse=fuse)
+ stride = max(int(model.stride.max()), 32) # model stride
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
+ model.half() if fp16 else model.float()
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
+ elif jit: # TorchScript
+ LOGGER.info(f'Loading {w} for TorchScript inference...')
+ extra_files = {'config.txt': ''} # model metadata
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
+ model.half() if fp16 else model.float()
+ if extra_files['config.txt']: # load metadata dict
+ d = json.loads(extra_files['config.txt'],
+ object_hook=lambda d: {int(k) if k.isdigit() else k: v
+ for k, v in d.items()})
+ stride, names = int(d['stride']), d['names']
+ elif dnn: # ONNX OpenCV DNN
+ LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
+ check_requirements('opencv-python>=4.5.4')
+ net = cv2.dnn.readNetFromONNX(w)
+ elif onnx: # ONNX Runtime
+ LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
+ check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
+ import onnxruntime
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+ session = onnxruntime.InferenceSession(w, providers=providers)
+ output_names = [x.name for x in session.get_outputs()]
+ meta = session.get_modelmeta().custom_metadata_map # metadata
+ if 'stride' in meta:
+ stride, names = int(meta['stride']), eval(meta['names'])
+ elif xml: # OpenVINO
+ LOGGER.info(f'Loading {w} for OpenVINO inference...')
+ check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
+ from openvino.runtime import Core, Layout, get_batch # noqa
+ ie = Core()
+ if not Path(w).is_file(): # if not *.xml
+ w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
+ network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
+ if network.get_parameters()[0].get_layout().empty:
+ network.get_parameters()[0].set_layout(Layout("NCHW"))
+ batch_dim = get_batch(network)
+ if batch_dim.is_static:
+ batch_size = batch_dim.get_length()
+ executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
+ stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
+ elif engine: # TensorRT
+ LOGGER.info(f'Loading {w} for TensorRT inference...')
+ import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
+ if device.type == 'cpu':
+ device = torch.device('cuda:0')
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+ logger = trt.Logger(trt.Logger.INFO)
+ with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
+ model = runtime.deserialize_cuda_engine(f.read())
+ context = model.create_execution_context()
+ bindings = OrderedDict()
+ output_names = []
+ fp16 = False # default updated below
+ dynamic = False
+ for i in range(model.num_bindings):
+ name = model.get_binding_name(i)
+ dtype = trt.nptype(model.get_binding_dtype(i))
+ if model.binding_is_input(i):
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
+ dynamic = True
+ context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
+ if dtype == np.float16:
+ fp16 = True
+ else: # output
+ output_names.append(name)
+ shape = tuple(context.get_binding_shape(i))
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
+ batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
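+            # Each Binding above pairs a TensorRT tensor name with a pre-allocated torch tensor on the
+            # GPU; forward() later passes the collected data_ptr() addresses to execute_v2(), so the
+            # engine reads and writes these tensors directly without extra host copies.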
+ elif coreml: # CoreML
+ LOGGER.info(f'Loading {w} for CoreML inference...')
+ import coremltools as ct
+ model = ct.models.MLModel(w)
+ elif saved_model: # TF SavedModel
+ LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+ import tensorflow as tf
+ keras = False # assume TF1 saved_model
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+ elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+ LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+ import tensorflow as tf
+
+ def wrap_frozen_graph(gd, inputs, outputs):
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
+ ge = x.graph.as_graph_element
+ return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
+
+ def gd_outputs(gd):
+ name_list, input_list = [], []
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
+ name_list.append(node.name)
+ input_list.extend(node.input)
+ return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
+
+ gd = tf.Graph().as_graph_def() # TF GraphDef
+ with open(w, 'rb') as f:
+ gd.ParseFromString(f.read())
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+ elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+ try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
+ from tflite_runtime.interpreter import Interpreter, load_delegate
+ except ImportError:
+ import tensorflow as tf
+ Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
+ LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
+ delegate = {
+ 'Linux': 'libedgetpu.so.1',
+ 'Darwin': 'libedgetpu.1.dylib',
+ 'Windows': 'edgetpu.dll'}[platform.system()]
+ interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
+ else: # TFLite
+ LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+ interpreter = Interpreter(model_path=w) # load TFLite model
+ interpreter.allocate_tensors() # allocate
+ input_details = interpreter.get_input_details() # inputs
+ output_details = interpreter.get_output_details() # outputs
+ elif tfjs: # TF.js
+ raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
+ elif paddle: # PaddlePaddle
+ LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
+ check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
+ import paddle.inference as pdi
+ if not Path(w).is_file(): # if not *.pdmodel
+                w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
+ weights = Path(w).with_suffix('.pdiparams')
+ config = pdi.Config(str(w), str(weights))
+ if cuda:
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
+ predictor = pdi.create_predictor(config)
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+ output_names = predictor.get_output_names()
+ elif triton: # NVIDIA Triton Inference Server
+ LOGGER.info('Triton Inference Server not supported...')
+ '''
+ TODO:
+ check_requirements('tritonclient[all]')
+ from utils.triton import TritonRemoteModel
+ model = TritonRemoteModel(url=w)
+ nhwc = model.runtime.startswith("tensorflow")
+ '''
+ else:
+ raise NotImplementedError(f'ERROR: {w} is not a supported format')
+
+ # class names
+ if 'names' not in locals():
+ names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
+ if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
+ names = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['names'] # human-readable names
+
+ self.__dict__.update(locals()) # assign all variables to self
+
+ def forward(self, im, augment=False, visualize=False):
+ """
+ Runs inference on the given model
+
+ Args:
+ im: the image tensor
+ augment: whether to augment the image. Defaults to False
+ visualize: if True, then the network will output the feature maps of the last convolutional layer.
+ Defaults to False
+ """
+ # YOLOv5 MultiBackend inference
+ b, ch, h, w = im.shape # batch, channel, height, width
+ if self.fp16 and im.dtype != torch.float16:
+ im = im.half() # to FP16
+ if self.nhwc:
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
+
+ if self.pt or self.nn_module: # PyTorch
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+ elif self.jit: # TorchScript
+ y = self.model(im)
+ elif self.dnn: # ONNX OpenCV DNN
+ im = im.cpu().numpy() # torch to numpy
+ self.net.setInput(im)
+ y = self.net.forward()
+ elif self.onnx: # ONNX Runtime
+ im = im.cpu().numpy() # torch to numpy
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
+ elif self.xml: # OpenVINO
+ im = im.cpu().numpy() # FP32
+ y = list(self.executable_network([im]).values())
+ elif self.engine: # TensorRT
+ if self.dynamic and im.shape != self.bindings['images'].shape:
+ i = self.model.get_binding_index('images')
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
+ self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
+ for name in self.output_names:
+ i = self.model.get_binding_index(name)
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
+ s = self.bindings['images'].shape
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
+ self.binding_addrs['images'] = int(im.data_ptr())
+ self.context.execute_v2(list(self.binding_addrs.values()))
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
+ elif self.coreml: # CoreML
+ im = im.cpu().numpy()
+ im = Image.fromarray((im[0] * 255).astype('uint8'))
+ # im = im.resize((192, 320), Image.ANTIALIAS)
+ y = self.model.predict({'image': im}) # coordinates are xywh normalized
+ if 'confidence' in y:
+ box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
+                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
+ y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+ else:
+ y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
+ elif self.paddle: # PaddlePaddle
+ im = im.cpu().numpy().astype(np.float32)
+ self.input_handle.copy_from_cpu(im)
+ self.predictor.run()
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
+ elif self.triton: # NVIDIA Triton Inference Server
+ y = self.model(im)
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+ im = im.cpu().numpy()
+ if self.saved_model: # SavedModel
+ y = self.model(im, training=False) if self.keras else self.model(im)
+ elif self.pb: # GraphDef
+ y = self.frozen_func(x=self.tf.constant(im))
+ else: # Lite or Edge TPU
+ input = self.input_details[0]
+ int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
+ if int8:
+ scale, zero_point = input['quantization']
+ im = (im / scale + zero_point).astype(np.uint8) # de-scale
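+                    # e.g. with hypothetical quantization params scale=1/255 and zero_point=0, a
+                    # normalized input value of 0.5 is mapped to uint8 0.5 / (1/255) + 0 ~ 127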
+ self.interpreter.set_tensor(input['index'], im)
+ self.interpreter.invoke()
+ y = []
+ for output in self.output_details:
+ x = self.interpreter.get_tensor(output['index'])
+ if int8:
+ scale, zero_point = output['quantization']
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
+ y.append(x)
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
+ y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
+
+ if isinstance(y, (list, tuple)):
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
+ else:
+ return self.from_numpy(y)
+
+ def from_numpy(self, x):
+ """
+ `from_numpy` converts a numpy array to a tensor
+
+ Args:
+ x: the numpy array to convert
+ """
+ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
+
+ def warmup(self, imgsz=(1, 3, 640, 640)):
+ """
+ Warmup model by running inference once
+
+ Args:
+ imgsz: the size of the image you want to run inference on.
+ """
+ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
+ if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
+ for _ in range(2 if self.jit else 1): #
+ self.forward(im) # warmup
+
+ @staticmethod
+ def _model_type(p='path/to/model.pt'):
+ """
+ This function takes a path to a model file and returns the model type
+
+ Args:
+ p: path to the model file. Defaults to path/to/model.pt
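+
+        Example (illustrative): _model_type('yolov8n.onnx') returns one boolean per export suffix in
+        the order noted below, with only the ONNX flag set, plus a trailing False for Triton.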
+ """
+ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
+ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
+ from ultralytics.yolo.engine.exporter import export_formats
+ sf = list(export_formats().Suffix) # export suffixes
+ if not is_url(p, check=False) and not isinstance(p, str):
+ check_suffix(p, sf) # checks
+ url = urlparse(p) # if url may be Triton inference server
+ types = [s in Path(p).name for s in sf]
+ types[8] &= not types[9] # tflite &= not edgetpu
+ triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
+ return types + [triton]
+
+ @staticmethod
+ def _load_metadata(f=Path('path/to/meta.yaml')):
+ """
+        Loads the metadata from a yaml file
+
+ Args:
+ f: The path to the metadata file.
+ """
+ from ultralytics.yolo.utils.files import yaml_load
+
+ # Load metadata from meta.yaml if it exists
+ if f.exists():
+ d = yaml_load(f)
+ return d['stride'], d['names'] # assign stride, names
+ return None, None
diff --git a/ultralytics/nn/modules.py b/ultralytics/nn/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ca310d40e364abd46fe0aeaf420b887ca8007f
--- /dev/null
+++ b/ultralytics/nn/modules.py
@@ -0,0 +1,688 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Common modules
+"""
+
+import math
+import warnings
+from copy import copy
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+import requests
+import torch
+import torch.nn as nn
+from PIL import Image, ImageOps
+from torch.cuda import amp
+
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.yolo.data.augment import LetterBox
+from ultralytics.yolo.utils import LOGGER, colorstr
+from ultralytics.yolo.utils.files import increment_path
+from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
+from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
+from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
+from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
+
+# from utils.plots import feature_visualization TODO
+
+
+def autopad(k, p=None, d=1): # kernel, padding, dilation
+ # Pad to 'same' shape outputs
+ if d > 1:
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
+ if p is None:
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
+ return p
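+# e.g. autopad(3) -> 1 and autopad(5) -> 2 ('same' padding for odd kernels); with dilation,
+# autopad(3, d=2) first grows the effective kernel to 5 and therefore also returns 2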
+
+
+class Conv(nn.Module):
+ # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
+ default_act = nn.SiLU() # default activation
+
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
+ super().__init__()
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
+
+ def forward(self, x):
+ return self.act(self.bn(self.conv(x)))
+
+ def forward_fuse(self, x):
+ return self.act(self.conv(x))
+
+
+class DWConv(Conv):
+ # Depth-wise convolution
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
+
+
+class DWConvTranspose2d(nn.ConvTranspose2d):
+ # Depth-wise transpose convolution
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
+
+
+class ConvTranspose(nn.Module):
+ # Convolution transpose 2d layer
+ default_act = nn.SiLU() # default activation
+
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
+ super().__init__()
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
+
+ def forward(self, x):
+ return self.act(self.bn(self.conv_transpose(x)))
+
+
+class DFL(nn.Module):
+ # DFL module
+ def __init__(self, c1=16):
+ super().__init__()
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
+ x = torch.arange(c1, dtype=torch.float)
+ self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
+ self.c1 = c1
+
+ def forward(self, x):
+ b, c, a = x.shape # batch, channels, anchors
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
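+        # Note: the conv weights are frozen to the values 0..c1-1, so this layer returns the expected
+        # value of each softmax distribution, decoding 4 * c1 logits per anchor into 4 box distances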
+
+
+class TransformerLayer(nn.Module):
+ # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
+ def __init__(self, c, num_heads):
+ super().__init__()
+ self.q = nn.Linear(c, c, bias=False)
+ self.k = nn.Linear(c, c, bias=False)
+ self.v = nn.Linear(c, c, bias=False)
+ self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
+ self.fc1 = nn.Linear(c, c, bias=False)
+ self.fc2 = nn.Linear(c, c, bias=False)
+
+ def forward(self, x):
+ x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
+ x = self.fc2(self.fc1(x)) + x
+ return x
+
+
+class TransformerBlock(nn.Module):
+ # Vision Transformer https://arxiv.org/abs/2010.11929
+ def __init__(self, c1, c2, num_heads, num_layers):
+ super().__init__()
+ self.conv = None
+ if c1 != c2:
+ self.conv = Conv(c1, c2)
+ self.linear = nn.Linear(c2, c2) # learnable position embedding
+ self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
+ self.c2 = c2
+
+ def forward(self, x):
+ if self.conv is not None:
+ x = self.conv(x)
+ b, _, w, h = x.shape
+ p = x.flatten(2).permute(2, 0, 1)
+ return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
+
+
+class Bottleneck(nn.Module):
+ # Standard bottleneck
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, k[0], 1)
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
+ self.act = nn.SiLU()
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ y1 = self.cv3(self.m(self.cv1(x)))
+ y2 = self.cv2(x)
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
+
+
+class C3(nn.Module):
+ # CSP Bottleneck with 3 convolutions
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
+
+
+class C2(nn.Module):
+ # CSP Bottleneck with 2 convolutions
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ self.c = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+ self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
+ # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
+ self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ a, b = self.cv1(x).split((self.c, self.c), 1)
+ return self.cv2(torch.cat((self.m(a), b), 1))
+
+
+class C2f(nn.Module):
+    # Faster implementation of CSP Bottleneck with 2 convolutions
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ self.c = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
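+        # Note: cv1 yields 2 * self.c channels that forward() splits in half; each of the n bottleneck
+        # outputs is kept as well, so cv2 receives (2 + n) * self.c concatenated channels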
+
+ def forward(self, x):
+ y = list(self.cv1(x).split((self.c, self.c), 1))
+ y.extend(m(y[-1]) for m in self.m)
+ return self.cv2(torch.cat(y, 1))
+
+
+class ChannelAttention(nn.Module):
+ # Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
+ def __init__(self, channels: int) -> None:
+ super().__init__()
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
+ self.act = nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x * self.act(self.fc(self.pool(x)))
+
+
+class SpatialAttention(nn.Module):
+ # Spatial-attention module
+ def __init__(self, kernel_size=7):
+ super().__init__()
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+ padding = 3 if kernel_size == 7 else 1
+ self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+ self.act = nn.Sigmoid()
+
+ def forward(self, x):
+ return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
+
+
+class CBAM(nn.Module):
+    # Convolutional Block Attention Module (CBAM): channel attention followed by spatial attention
+    def __init__(self, c1, ratio=16, kernel_size=7):  # ch_in, reduction ratio, spatial kernel size
+ super().__init__()
+ self.channel_attention = ChannelAttention(c1)
+ self.spatial_attention = SpatialAttention(kernel_size)
+
+ def forward(self, x):
+ return self.spatial_attention(self.channel_attention(x))
+
+
+class C1(nn.Module):
+    # CSP Bottleneck with 1 convolution
+    def __init__(self, c1, c2, n=1):  # ch_in, ch_out, number
+ super().__init__()
+ self.cv1 = Conv(c1, c2, 1, 1)
+ self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
+
+ def forward(self, x):
+ y = self.cv1(x)
+ return self.m(y) + y
+
+
+class C3x(C3):
+ # C3 module with cross-convolutions
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+ super().__init__(c1, c2, n, shortcut, g, e)
+ self.c_ = int(c2 * e)
+ self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
+
+
+class C3TR(C3):
+ # C3 module with TransformerBlock()
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+ super().__init__(c1, c2, n, shortcut, g, e)
+ c_ = int(c2 * e)
+ self.m = TransformerBlock(c_, c_, 4, n)
+
+
+class C3Ghost(C3):
+ # C3 module with GhostBottleneck()
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+ super().__init__(c1, c2, n, shortcut, g, e)
+ c_ = int(c2 * e) # hidden channels
+ self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
+
+
+class SPP(nn.Module):
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
+ def __init__(self, c1, c2, k=(5, 9, 13)):
+ super().__init__()
+ c_ = c1 // 2 # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
+
+ def forward(self, x):
+ x = self.cv1(x)
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class SPPF(nn.Module):
+ # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
+ super().__init__()
+ c_ = c1 // 2 # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
+
+ def forward(self, x):
+ x = self.cv1(x)
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
+ y1 = self.m(x)
+ y2 = self.m(y1)
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+class Focus(nn.Module):
+ # Focus wh information into c-space
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
+ # self.contract = Contract(gain=2)
+
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+ return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
+ # return self.conv(self.contract(x))
+
+
+class GhostConv(nn.Module):
+ # Ghost Convolution https://github.com/huawei-noah/ghostnet
+ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
+ super().__init__()
+ c_ = c2 // 2 # hidden channels
+ self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
+ self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
+
+ def forward(self, x):
+ y = self.cv1(x)
+ return torch.cat((y, self.cv2(y)), 1)
+
+
+class GhostBottleneck(nn.Module):
+ # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
+ super().__init__()
+ c_ = c2 // 2
+ self.conv = nn.Sequential(
+ GhostConv(c1, c_, 1, 1), # pw
+ DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
+ GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
+ self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
+ act=False)) if s == 2 else nn.Identity()
+
+ def forward(self, x):
+ return self.conv(x) + self.shortcut(x)
+
+
+class Concat(nn.Module):
+ # Concatenate a list of tensors along dimension
+ def __init__(self, dimension=1):
+ super().__init__()
+ self.d = dimension
+
+ def forward(self, x):
+ return torch.cat(x, self.d)
+
+
+class AutoShape(nn.Module):
+ # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+ conf = 0.25 # NMS confidence threshold
+ iou = 0.45 # NMS IoU threshold
+ agnostic = False # NMS class-agnostic
+ multi_label = False # NMS multiple labels per box
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
+ max_det = 1000 # maximum number of detections per image
+ amp = False # Automatic Mixed Precision (AMP) inference
+
+ def __init__(self, model, verbose=True):
+ super().__init__()
+ if verbose:
+ LOGGER.info('Adding AutoShape... ')
+ copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
+ self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance
+ self.pt = not self.dmb or model.pt # PyTorch model
+ self.model = model.eval()
+ if self.pt:
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
+ m.inplace = False # Detect.inplace=False for safe multithread inference
+ m.export = True # do not output loss values
+
+ def _apply(self, fn):
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
+ self = super()._apply(fn)
+ if self.pt:
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
+ m.stride = fn(m.stride)
+ m.grid = list(map(fn, m.grid))
+ if isinstance(m.anchor_grid, list):
+ m.anchor_grid = list(map(fn, m.anchor_grid))
+ return self
+
+ @smart_inference_mode()
+ def forward(self, ims, size=640, augment=False, profile=False):
+ # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
+ # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
+ # numpy: = np.zeros((640,1280,3)) # HWC
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
+
+ dt = (Profile(), Profile(), Profile())
+ with dt[0]:
+ if isinstance(size, int): # expand
+ size = (size, size)
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
+ if isinstance(ims, torch.Tensor): # torch
+ with amp.autocast(autocast):
+ return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
+
+ # Pre-process
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
+ for i, im in enumerate(ims):
+ f = f'image{i}' # filename
+ if isinstance(im, (str, Path)): # filename or uri
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
+ im = np.asarray(ImageOps.exif_transpose(im))
+ elif isinstance(im, Image.Image): # PIL Image
+ im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
+ files.append(Path(f).with_suffix('.jpg').name)
+ if im.shape[0] < 5: # image in CHW
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
+ s = im.shape[:2] # HWC
+ shape0.append(s) # image shape
+ g = max(size) / max(s) # gain
+ shape1.append([y * g for y in s])
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
+ x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims] # pad
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
+
+ with amp.autocast(autocast):
+ # Inference
+ with dt[1]:
+ y = self.model(x, augment=augment) # forward
+
+ # Post-process
+ with dt[2]:
+ y = non_max_suppression(y if self.dmb else y[0],
+ self.conf,
+ self.iou,
+ self.classes,
+ self.agnostic,
+ self.multi_label,
+ max_det=self.max_det) # NMS
+ for i in range(n):
+ scale_boxes(shape1, y[i][:, :4], shape0[i])
+
+ return Detections(ims, y, files, dt, self.names, x.shape)
+
+
+class Detections:
+ # YOLOv5 detections class for inference results
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
+ super().__init__()
+ d = pred[0].device # device
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
+ self.ims = ims # list of images as numpy arrays
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
+ self.names = names # class names
+ self.files = files # image filenames
+ self.times = times # profiling times
+ self.xyxy = pred # xyxy pixels
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
+ self.n = len(self.pred) # number of images (batch size)
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
+ self.s = tuple(shape) # inference BCHW shape
+
+ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
+ s, crops = '', []
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
+ s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
+ if pred.shape[0]:
+ for c in pred[:, -1].unique():
+ n = (pred[:, -1] == c).sum() # detections per class
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
+ s = s.rstrip(', ')
+ if show or save or render or crop:
+ annotator = Annotator(im, example=str(self.names))
+ for *box, conf, cls in reversed(pred): # xyxy, confidence, class
+ label = f'{self.names[int(cls)]} {conf:.2f}'
+ if crop:
+ file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
+ crops.append({
+ 'box': box,
+ 'conf': conf,
+ 'cls': cls,
+ 'label': label,
+ 'im': save_one_box(box, im, file=file, save=save)})
+ else: # all others
+ annotator.box_label(box, label if labels else '', color=colors(cls))
+ im = annotator.im
+ else:
+ s += '(no detections)'
+
+ im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
+ if show:
+ im.show(self.files[i]) # show
+ if save:
+ f = self.files[i]
+ im.save(save_dir / f) # save
+ if i == self.n - 1:
+ LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
+ if render:
+ self.ims[i] = np.asarray(im)
+ if pprint:
+ s = s.lstrip('\n')
+ return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
+ if crop:
+ if save:
+ LOGGER.info(f'Saved results to {save_dir}\n')
+ return crops
+
+ def show(self, labels=True):
+ self._run(show=True, labels=labels) # show results
+
+ def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
+ self._run(save=True, labels=labels, save_dir=save_dir) # save results
+
+ def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
+ return self._run(crop=True, save=save, save_dir=save_dir) # crop results
+
+ def render(self, labels=True):
+ self._run(render=True, labels=labels) # render results
+ return self.ims
+
+ def pandas(self):
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+ new = copy(self) # return copy
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+ return new
+
+ def tolist(self):
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
+ r = range(self.n) # iterable
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
+ # for d in x:
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
+ # setattr(d, k, getattr(d, k)[0]) # pop out of list
+ return x
+
+ def print(self):
+ LOGGER.info(self.__str__())
+
+ def __len__(self): # override len(results)
+ return self.n
+
+ def __str__(self): # override print(results)
+ return self._run(pprint=True) # print results
+
+ def __repr__(self):
+ return f'YOLOv5 {self.__class__} instance\n' + self.__str__()
+
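+# Usage sketch for the Detections container above (illustrative only, not part of the
+# original module; `autoshaped_model` stands for any AutoShape-wrapped model and the
+# image path is an assumption):
+#   results = autoshaped_model(['zidane.jpg'])   # returns a Detections instance
+#   results.print()                              # per-image class counts and speed
+#   df = results.pandas().xyxy[0]                # first image as a pandas DataFrame
+#   crops = results.crop(save=False)             # list of per-detection crop dicts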
+
+class Proto(nn.Module):
+ # YOLOv8 mask Proto module for segmentation models
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
+ super().__init__()
+ self.cv1 = Conv(c1, c_, k=3)
+ self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
+ self.cv2 = Conv(c_, c_, k=3)
+ self.cv3 = Conv(c_, c2)
+
+ def forward(self, x):
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
+
+
+class Ensemble(nn.ModuleList):
+ # Ensemble of models
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, augment=False, profile=False, visualize=False):
+ y = [module(x, augment, profile, visualize)[0] for module in self]
+ # y = torch.stack(y).max(0)[0] # max ensemble
+ # y = torch.stack(y).mean(0) # mean ensemble
+ y = torch.cat(y, 1) # nms ensemble
+ return y, None # inference, train output
+
+
+# heads
+class Detect(nn.Module):
+    # YOLOv8 Detect head for detection models
+ dynamic = False # force grid reconstruction
+ export = False # export mode
+ shape = None
+ anchors = torch.empty(0) # init
+ strides = torch.empty(0) # init
+
+ def __init__(self, nc=80, ch=()): # detection layer
+ super().__init__()
+ self.nc = nc # number of classes
+ self.nl = len(ch) # number of detection layers
+ self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
+ self.stride = torch.zeros(self.nl) # strides computed during build
+
+ c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
+ self.cv2 = nn.ModuleList(
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
+ self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
+
+ def forward(self, x):
+ shape = x[0].shape # BCHW
+ for i in range(self.nl):
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
+ if self.training:
+ return x
+ elif self.dynamic or self.shape != shape:
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+ self.shape = shape
+
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
+ y = torch.cat((dbox, cls.sigmoid()), 1)
+ return y if self.export else (y, x)
+
+ def bias_init(self):
+ # Initialize Detect() biases, WARNING: requires stride availability
+ m = self # self.model[-1] # Detect() module
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
+ a[-1].bias.data[:] = 1.0 # box
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+
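+# Shape sketch for the anchor-free head above (assumption: a 640x640 input and the
+# default strides 8/16/32, which give 80*80 + 40*40 + 20*20 = 8400 anchor points):
+#   head = Detect(nc=80, ch=(64, 128, 256)); head.stride = torch.tensor([8., 16., 32.])
+#   feats = [torch.zeros(1, c, 640 // int(s), 640 // int(s)) for c, s in zip((64, 128, 256), head.stride)]
+#   y, x = head.eval()(feats)   # y: (1, 4 + 80, 8400) decoded boxes + class scores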
+
+class Segment(Detect):
+    # YOLOv8 Segment head for segmentation models
+ def __init__(self, nc=80, nm=32, npr=256, ch=()):
+ super().__init__(nc, ch)
+ self.nm = nm # number of masks
+ self.npr = npr # number of protos
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
+ self.detect = Detect.forward
+
+ c4 = max(ch[0] // 4, self.nm)
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
+
+ def forward(self, x):
+ p = self.proto(x[0]) # mask protos
+ bs = p.shape[0] # batch size
+
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
+ x = self.detect(self, x)
+ if self.training:
+ return x, mc, p
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
+
+
+class Classify(nn.Module):
+ # YOLOv5 classification head, i.e. x(b,c1,20,20) to x(b,c2)
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ c_ = 1280 # efficientnet_b0 size
+ self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
+ self.drop = nn.Dropout(p=0.0, inplace=True)
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
+
+ def forward(self, x):
+ if isinstance(x, list):
+ x = torch.cat(x, 1)
+ return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f143c14a57c9f3a9502d2687621579e253dca493
--- /dev/null
+++ b/ultralytics/nn/tasks.py
@@ -0,0 +1,416 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import contextlib
+from copy import deepcopy
+
+import thop
+import torch
+import torch.nn as nn
+
+from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
+ Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
+ GhostBottleneck, GhostConv, Segment)
+from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
+from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
+ model_info, scale_img, time_sync)
+
+
+class BaseModel(nn.Module):
+ '''
+ The BaseModel class is a base class for all the models in the Ultralytics YOLO family.
+ '''
+
+ def forward(self, x, profile=False, visualize=False):
+ """
+        `forward` is a wrapper for `_forward_once` that runs the model at a single scale.
+
+        Args:
+            x: the input image
+            profile: whether to profile the model. Defaults to False
+            visualize: if True, saves the intermediate feature maps (currently a TODO). Defaults to False
+
+ Returns:
+ The output of the network.
+ """
+ return self._forward_once(x, profile, visualize)
+
+ def _forward_once(self, x, profile=False, visualize=False):
+ """
+        Forward pass of the network.
+
+ Args:
+ x: input to the model
+ profile: if True, the time taken for each layer will be printed. Defaults to False
+ visualize: If True, it will save the feature maps of the model. Defaults to False
+
+ Returns:
+            The output of the last layer of the model.
+ """
+ y, dt = [], [] # outputs
+ for m in self.model:
+ if m.f != -1: # if not from previous layer
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
+ if profile:
+ self._profile_one_layer(m, x, dt)
+ x = m(x) # run
+ y.append(x if m.i in self.save else None) # save output
+ if visualize:
+ pass
+ # TODO: feature_visualization(x, m.type, m.i, save_dir=visualize)
+ return x
+
+ def _profile_one_layer(self, m, x, dt):
+ """
+        Profiles a single layer on the given input and appends the measured time to the
+        list of per-layer times.
+
+        Args:
+            m: the layer (module) to profile
+            x: the input to that layer
+ dt: list of time taken for each layer
+ """
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
+ t = time_sync()
+ for _ in range(10):
+ m(x.copy() if c else x)
+ dt.append((time_sync() - t) * 100)
+ if m == self.model[0]:
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
+ if c:
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
+
+ def fuse(self):
+ """
+        Fuses the model's Conv2d() and BatchNorm2d() layers into single fused layers to speed up inference.
+
+        Returns:
+            The fused model.
+ """
+ LOGGER.info('Fusing layers... ')
+ for m in self.model.modules():
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
+ delattr(m, 'bn') # remove batchnorm
+ m.forward = m.forward_fuse # update forward
+ self.info()
+ return self
+
+ def info(self, verbose=False, imgsz=640):
+ """
+ Prints model information
+
+ Args:
+ verbose: if True, prints out the model information. Defaults to False
+ imgsz: the size of the image that the model will be trained on. Defaults to 640
+ """
+ model_info(self, verbose, imgsz)
+
+ def _apply(self, fn):
+ """
+        Applies a function to all tensors in the model that are not parameters or registered
+        buffers, such as the detection head's stride and anchor tensors.
+
+        Args:
+            fn: the function to apply to those tensors
+
+        Returns:
+            The model itself, with `fn` applied to its Detect()/Segment() stride and anchor tensors.
+ """
+ self = super()._apply(fn)
+ m = self.model[-1] # Detect()
+ if isinstance(m, (Detect, Segment)):
+ m.stride = fn(m.stride)
+ m.anchors = fn(m.anchors)
+ m.strides = fn(m.strides)
+ return self
+
+ def load(self, weights):
+ """
+        Loads the weights of the model from a file.
+
+ Args:
+ weights: The weights to load into the model.
+ """
+ # Force all tasks to implement this function
+ raise NotImplementedError("This function needs to be implemented by derived classes!")
+
+
+class DetectionModel(BaseModel):
+    # YOLOv8 detection model
+ def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
+ super().__init__()
+ self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True) # cfg dict
+
+ # Define model
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
+ if nc and nc != self.yaml['nc']:
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+ self.yaml['nc'] = nc # override yaml value
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
+ self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
+ self.inplace = self.yaml.get('inplace', True)
+
+ # Build strides
+ m = self.model[-1] # Detect()
+ if isinstance(m, (Detect, Segment)):
+ s = 256 # 2x min stride
+ m.inplace = self.inplace
+ forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
+ self.stride = m.stride
+ m.bias_init() # only run once
+
+ # Init weights, biases
+ initialize_weights(self)
+ if verbose:
+ self.info()
+ LOGGER.info('')
+
+ def forward(self, x, augment=False, profile=False, visualize=False):
+ if augment:
+ return self._forward_augment(x) # augmented inference, None
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
+
+ def _forward_augment(self, x):
+ img_size = x.shape[-2:] # height, width
+ s = [1, 0.83, 0.67] # scales
+ f = [None, 3, None] # flips (2-ud, 3-lr)
+ y = [] # outputs
+ for si, fi in zip(s, f):
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+ yi = self._forward_once(xi)[0] # forward
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
+ yi = self._descale_pred(yi, fi, si, img_size)
+ y.append(yi)
+ y = self._clip_augmented(y) # clip augmented tails
+ return torch.cat(y, -1), None # augmented inference, train
+
+ @staticmethod
+ def _descale_pred(p, flips, scale, img_size, dim=1):
+ # de-scale predictions following augmented inference (inverse operation)
+ p[:, :4] /= scale # de-scale
+ x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
+ if flips == 2:
+ y = img_size[0] - y # de-flip ud
+ elif flips == 3:
+ x = img_size[1] - x # de-flip lr
+ return torch.cat((x, y, wh, cls), dim)
+
+ def _clip_augmented(self, y):
+ # Clip YOLOv5 augmented inference tails
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
+ g = sum(4 ** x for x in range(nl)) # grid points
+ e = 1 # exclude layer count
+ i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
+ y[0] = y[0][..., :-i] # large
+ i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
+ y[-1] = y[-1][..., i:] # small
+ return y
+
+ def load(self, weights, verbose=True):
+ csd = weights.float().state_dict() # checkpoint state_dict as FP32
+ csd = intersect_dicts(csd, self.state_dict()) # intersect
+ self.load_state_dict(csd, strict=False) # load
+ if verbose:
+ LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
+
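+# Minimal usage sketch for DetectionModel (illustrative assumptions: the bundled
+# 'yolov8n.yaml' config is resolvable by check_yaml and a dummy 640x640 input is used):
+#   model = DetectionModel(cfg='yolov8n.yaml', ch=3, nc=80, verbose=False).eval()
+#   preds, feats = model(torch.zeros(1, 3, 640, 640))                # decoded preds + raw head outputs
+#   preds_tta, _ = model(torch.zeros(1, 3, 640, 640), augment=True)  # multi-scale / flip TTA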
+
+class SegmentationModel(DetectionModel):
+    # YOLOv8 segmentation model
+ def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
+ super().__init__(cfg, ch, nc, verbose)
+
+
+class ClassificationModel(BaseModel):
+    # YOLOv8 classification model
+ def __init__(self,
+ cfg=None,
+ model=None,
+ ch=3,
+ nc=1000,
+ cutoff=10,
+ verbose=True): # yaml, model, number of classes, cutoff index
+ super().__init__()
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
+
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
+ # Create a YOLOv5 classification model from a YOLOv5 detection model
+ from ultralytics.nn.autobackend import AutoBackend
+ if isinstance(model, AutoBackend):
+ model = model.model # unwrap DetectMultiBackend
+ model.model = model.model[:cutoff] # backbone
+ m = model.model[-1] # last layer
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
+ c = Classify(ch, nc) # Classify()
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
+ model.model[-1] = c # replace
+ self.model = model.model
+ self.stride = model.stride
+ self.save = []
+ self.nc = nc
+
+ def _from_yaml(self, cfg, ch, nc, verbose):
+ self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True) # cfg dict
+ # Define model
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
+ if nc and nc != self.yaml['nc']:
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+ self.yaml['nc'] = nc # override yaml value
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
+ self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
+ self.info()
+
+ def load(self, weights):
+ model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
+ csd = model.float().state_dict()
+ csd = intersect_dicts(csd, self.state_dict()) # intersect
+ self.load_state_dict(csd, strict=False) # load
+
+ @staticmethod
+ def reshape_outputs(model, nc):
+        # Update a TorchVision classification model to class count 'nc' if required
+ name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
+ if isinstance(m, Classify): # YOLO Classify() head
+ if m.linear.out_features != nc:
+ m.linear = nn.Linear(m.linear.in_features, nc)
+ elif isinstance(m, nn.Linear): # ResNet, EfficientNet
+ if m.out_features != nc:
+ setattr(model, name, nn.Linear(m.in_features, nc))
+ elif isinstance(m, nn.Sequential):
+ types = [type(x) for x in m]
+ if nn.Linear in types:
+ i = types.index(nn.Linear) # nn.Linear index
+ if m[i].out_features != nc:
+ m[i] = nn.Linear(m[i].in_features, nc)
+ elif nn.Conv2d in types:
+ i = types.index(nn.Conv2d) # nn.Conv2d index
+ if m[i].out_channels != nc:
+ m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
+
+
+# Functions ------------------------------------------------------------------------------------------------------------
+
+
+def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
+ # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+ from ultralytics.yolo.utils.downloads import attempt_download
+
+ model = Ensemble()
+ for w in weights if isinstance(weights, list) else [weights]:
+ ckpt = torch.load(attempt_download(w), map_location='cpu') # load
+ args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
+ ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
+
+ # Model compatibility updates
+ ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
+ ckpt.pt_path = weights # attach *.pt file path to model
+ if not hasattr(ckpt, 'stride'):
+ ckpt.stride = torch.tensor([32.])
+
+ # Append
+ model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
+
+ # Module compatibility updates
+ for m in model.modules():
+ t = type(m)
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
+ m.inplace = inplace # torch 1.7.0 compatibility
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
+
+ # Return model
+ if len(model) == 1:
+ return model[-1]
+
+ # Return ensemble
+ print(f'Ensemble created with {weights}\n')
+ for k in 'names', 'nc', 'yaml':
+ setattr(model, k, getattr(model[0], k))
+ model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
+ assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
+ return model
+
+
+def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
+ # Loads a single model weights
+ from ultralytics.yolo.utils.downloads import attempt_download
+
+ ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
+ args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
+ model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
+
+ # Model compatibility updates
+ model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
+ model.pt_path = weight # attach *.pt file path to model
+ if not hasattr(model, 'stride'):
+ model.stride = torch.tensor([32.])
+
+ model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
+
+ # Module compatibility updates
+ for m in model.modules():
+ t = type(m)
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
+ m.inplace = inplace # torch 1.7.0 compatibility
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
+
+ # Return model and ckpt
+ return model, ckpt
+
+
+def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
+ # Parse a YOLO model.yaml dictionary
+ if verbose:
+ LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
+ nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
+ if act:
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
+ if verbose:
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
+
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
+ m = eval(m) if isinstance(m, str) else m # eval strings
+ for j, a in enumerate(args):
+ with contextlib.suppress(NameError):
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
+
+ n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
+ if m in {
+ Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
+ BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
+ c1, c2 = ch[f], args[0]
+ if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
+ c2 = make_divisible(c2 * gw, 8)
+
+ args = [c1, c2, *args[1:]]
+ if m in {BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x}:
+ args.insert(2, n) # number of repeats
+ n = 1
+ elif m is nn.BatchNorm2d:
+ args = [ch[f]]
+ elif m is Concat:
+ c2 = sum(ch[x] for x in f)
+ elif m in {Detect, Segment}:
+ args.append([ch[x] for x in f])
+ if m is Segment:
+ args[2] = make_divisible(args[2] * gw, 8)
+ else:
+ c2 = ch[f]
+
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
+ t = str(m)[8:-2].replace('__main__.', '') # module type
+ m.np = sum(x.numel() for x in m_.parameters()) # number params
+ m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
+ if verbose:
+ LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
+ layers.append(m_)
+ if i == 0:
+ ch = []
+ ch.append(c2)
+ return nn.Sequential(*layers), sorted(save)
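+
+# Loading sketch for the helpers above (illustrative; 'yolov8n.pt' and 'yolov8s.pt' are
+# assumed local or downloadable checkpoints, not files shipped by this patch):
+#   model, ckpt = attempt_load_one_weight('yolov8n.pt', device='cpu', fuse=True)
+#   ensemble = attempt_load_weights(['yolov8n.pt', 'yolov8s.pt'], device='cpu')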
diff --git a/ultralytics/yolo/cli.py b/ultralytics/yolo/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddcf7c9aa77b109e836ac4233bf0799de4cba568
--- /dev/null
+++ b/ultralytics/yolo/cli.py
@@ -0,0 +1,52 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import shutil
+from pathlib import Path
+
+import hydra
+
+from ultralytics import hub, yolo
+from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr
+
+DIR = Path(__file__).parent
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), config_name=DEFAULT_CONFIG.name)
+def cli(cfg):
+ """
+ Run a specified task and mode with the given configuration.
+
+ Args:
+ cfg (DictConfig): Configuration for the task and mode.
+ """
+ # LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
+ task, mode = cfg.task.lower(), cfg.mode.lower()
+
+ # Special case for initializing the configuration
+ if task == "init":
+ shutil.copy2(DEFAULT_CONFIG, Path.cwd())
+ LOGGER.info(f"""
+ {colorstr("YOLO:")} configuration saved to {Path.cwd() / DEFAULT_CONFIG.name}.
+ To run experiments using custom configuration:
+ yolo task='task' mode='mode' --config-name config_file.yaml
+ """)
+ return
+
+ # Mapping from task to module
+ task_module_map = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
+ module = task_module_map.get(task)
+ if not module:
+ raise SyntaxError(f"task not recognized. Choices are {', '.join(task_module_map.keys())}")
+
+ # Mapping from mode to function
+ mode_func_map = {
+ "train": module.train,
+ "val": module.val,
+ "predict": module.predict,
+ "export": yolo.engine.exporter.export,
+ "checks": hub.checks}
+ func = mode_func_map.get(mode)
+ if not func:
+ raise SyntaxError(f"mode not recognized. Choices are {', '.join(mode_func_map.keys())}")
+
+ func(cfg)
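+
+# Example invocations of this entry point from the shell (assumed typical values; the
+# model and source paths are placeholders):
+#   yolo task=detect  mode=predict model=yolov8n.pt source='bus.jpg'
+#   yolo task=segment mode=train   model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=3
+#   yolo task=init                 # copy the default config to the current directory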
diff --git a/ultralytics/yolo/data/__init__.py b/ultralytics/yolo/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf4293ae14a912b1f4aa93c602c29da880b17f4
--- /dev/null
+++ b/ultralytics/yolo/data/__init__.py
@@ -0,0 +1,6 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from .base import BaseDataset
+from .build import build_classification_dataloader, build_dataloader
+from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
+from .dataset_wrappers import MixAndRectDataset
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5414da64f8bf85a6bf537dd8034f6d81b46c1f1
--- /dev/null
+++ b/ultralytics/yolo/data/augment.py
@@ -0,0 +1,777 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import math
+import random
+from copy import deepcopy
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as T
+
+from ..utils import LOGGER, colorstr
+from ..utils.checks import check_version
+from ..utils.instance import Instances
+from ..utils.metrics import bbox_ioa
+from ..utils.ops import segment2box
+from .utils import IMAGENET_MEAN, IMAGENET_STD, polygons2masks, polygons2masks_overlap
+
+
+# TODO: we might need a BaseTransform to make all these augmentations compatible with both classification and semantic segmentation
+class BaseTransform:
+
+ def __init__(self) -> None:
+ pass
+
+ def apply_image(self, labels):
+ pass
+
+ def apply_instances(self, labels):
+ pass
+
+ def apply_semantic(self, labels):
+ pass
+
+ def __call__(self, labels):
+ self.apply_image(labels)
+ self.apply_instances(labels)
+ self.apply_semantic(labels)
+
+
+class Compose:
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, data):
+ for t in self.transforms:
+ data = t(data)
+ return data
+
+ def append(self, transform):
+ self.transforms.append(transform)
+
+ def tolist(self):
+ return self.transforms
+
+ def __repr__(self):
+ format_string = f"{self.__class__.__name__}("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += f" {t}"
+ format_string += "\n)"
+ return format_string
+
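+# Compose chains label-dict transforms in order; a hedged sketch (RandomHSV, RandomFlip
+# and Format are defined later in this file, and `labels` stands for a dataset label dict):
+#   transforms = Compose([RandomHSV(0.015, 0.7, 0.4), RandomFlip(p=0.5)])
+#   transforms.append(Format(bbox_format='xywh'))
+#   labels = transforms(labels)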
+
+class BaseMixTransform:
+ """This implementation is from mmyolo"""
+
+ def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
+ self.dataset = dataset
+ self.pre_transform = pre_transform
+ self.p = p
+
+ def __call__(self, labels):
+ if random.uniform(0, 1) > self.p:
+ return labels
+
+ # get index of one or three other images
+ indexes = self.get_indexes()
+ if isinstance(indexes, int):
+ indexes = [indexes]
+
+ # get images information will be used for Mosaic or MixUp
+ mix_labels = [self.dataset.get_label_info(i) for i in indexes]
+
+ if self.pre_transform is not None:
+ for i, data in enumerate(mix_labels):
+ mix_labels[i] = self.pre_transform(data)
+ labels["mix_labels"] = mix_labels
+
+ # Mosaic or MixUp
+ labels = self._mix_transform(labels)
+ labels.pop("mix_labels", None)
+ return labels
+
+ def _mix_transform(self, labels):
+ raise NotImplementedError
+
+ def get_indexes(self):
+ raise NotImplementedError
+
+
+class Mosaic(BaseMixTransform):
+ """Mosaic augmentation.
+ Args:
+        imgsz (int): Image size (height and width) of a single image after the
+            mosaic pipeline. Defaults to 640.
+ """
+
+ def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
+ assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
+ super().__init__(dataset=dataset, p=p)
+ self.dataset = dataset
+ self.imgsz = imgsz
+ self.border = border
+
+ def get_indexes(self):
+ return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
+
+ def _mix_transform(self, labels):
+ mosaic_labels = []
+        assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
+ assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
+ s = self.imgsz
+ yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
+ for i in range(4):
+ labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
+ # Load image
+ img = labels_patch["img"]
+ h, w = labels_patch["resized_shape"]
+
+ # place img in img4
+ if i == 0: # top left
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
+ elif i == 1: # top right
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+ elif i == 2: # bottom left
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+ elif i == 3: # bottom right
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
+ padw = x1a - x1b
+ padh = y1a - y1b
+
+ labels_patch = self._update_labels(labels_patch, padw, padh)
+ mosaic_labels.append(labels_patch)
+ final_labels = self._cat_labels(mosaic_labels)
+ final_labels["img"] = img4
+ return final_labels
+
+ def _update_labels(self, labels, padw, padh):
+ """Update labels"""
+ nh, nw = labels["img"].shape[:2]
+ labels["instances"].convert_bbox(format="xyxy")
+ labels["instances"].denormalize(nw, nh)
+ labels["instances"].add_padding(padw, padh)
+ return labels
+
+ def _cat_labels(self, mosaic_labels):
+ if len(mosaic_labels) == 0:
+ return {}
+ cls = []
+ instances = []
+ for labels in mosaic_labels:
+ cls.append(labels["cls"])
+ instances.append(labels["instances"])
+ final_labels = {
+ "ori_shape": mosaic_labels[0]["ori_shape"],
+ "resized_shape": (self.imgsz * 2, self.imgsz * 2),
+ "im_file": mosaic_labels[0]["im_file"],
+ "cls": np.concatenate(cls, 0),
+ "instances": Instances.concatenate(instances, axis=0)}
+ final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
+ return final_labels
+
+
+class MixUp(BaseMixTransform):
+
+ def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
+ super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
+
+ def get_indexes(self):
+ return random.randint(0, len(self.dataset) - 1)
+
+ def _mix_transform(self, labels):
+ # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
+ labels2 = labels["mix_labels"][0]
+ labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
+ labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
+ labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
+ return labels
+
+
+class RandomPerspective:
+
+ def __init__(self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0)):
+ self.degrees = degrees
+ self.translate = translate
+ self.scale = scale
+ self.shear = shear
+ self.perspective = perspective
+ # mosaic border
+ self.border = border
+
+ def affine_transform(self, img):
+ # Center
+ C = np.eye(3)
+
+ C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
+ C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
+
+ # Perspective
+ P = np.eye(3)
+ P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
+ P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
+
+ # Rotation and Scale
+ R = np.eye(3)
+ a = random.uniform(-self.degrees, self.degrees)
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
+ s = random.uniform(1 - self.scale, 1 + self.scale)
+ # s = 2 ** random.uniform(-scale, scale)
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
+
+ # Shear
+ S = np.eye(3)
+ S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
+ S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
+
+ # Translation
+ T = np.eye(3)
+ T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
+ T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
+
+ # Combined rotation matrix
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
+ # affine image
+ if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any(): # image changed
+ if self.perspective:
+ img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
+ else: # affine
+ img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
+ return img, M, s
+
+ def apply_bboxes(self, bboxes, M):
+ """apply affine to bboxes only.
+
+ Args:
+ bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
+ M(ndarray): affine matrix.
+ Returns:
+ new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
+ """
+ n = len(bboxes)
+ if n == 0:
+ return bboxes
+
+ xy = np.ones((n * 4, 3))
+ xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
+ xy = xy @ M.T # transform
+ xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
+
+ # create new boxes
+ x = xy[:, [0, 2, 4, 6]]
+ y = xy[:, [1, 3, 5, 7]]
+ return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+ def apply_segments(self, segments, M):
+ """apply affine to segments and generate new bboxes from segments.
+
+ Args:
+ segments(ndarray): list of segments, [num_samples, 500, 2].
+ M(ndarray): affine matrix.
+ Returns:
+ new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
+ new_bboxes(ndarray): bboxes after affine, [N, 4].
+ """
+ n, num = segments.shape[:2]
+ if n == 0:
+ return [], segments
+
+ xy = np.ones((n * num, 3))
+ segments = segments.reshape(-1, 2)
+ xy[:, :2] = segments
+ xy = xy @ M.T # transform
+ xy = xy[:, :2] / xy[:, 2:3]
+ segments = xy.reshape(n, -1, 2)
+ bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
+ return bboxes, segments
+
+ def apply_keypoints(self, keypoints, M):
+ """apply affine to keypoints.
+
+ Args:
+ keypoints(ndarray): keypoints, [N, 17, 2].
+ M(ndarray): affine matrix.
+ Return:
+ new_keypoints(ndarray): keypoints after affine, [N, 17, 2].
+ """
+ n = len(keypoints)
+ if n == 0:
+ return keypoints
+ new_keypoints = np.ones((n * 17, 3))
+ new_keypoints[:, :2] = keypoints.reshape(n * 17, 2) # num_kpt is hardcoded to 17
+ new_keypoints = new_keypoints @ M.T # transform
+ new_keypoints = (new_keypoints[:, :2] / new_keypoints[:, 2:3]).reshape(n, 34) # perspective rescale or affine
+ new_keypoints[keypoints.reshape(-1, 34) == 0] = 0
+ x_kpts = new_keypoints[:, list(range(0, 34, 2))]
+ y_kpts = new_keypoints[:, list(range(1, 34, 2))]
+
+ x_kpts[np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))] = 0
+ y_kpts[np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))] = 0
+ new_keypoints[:, list(range(0, 34, 2))] = x_kpts
+ new_keypoints[:, list(range(1, 34, 2))] = y_kpts
+ return new_keypoints.reshape(n, 17, 2)
+
+ def __call__(self, labels):
+ """
+ Affine images and targets.
+
+ Args:
+ labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
+ """
+ img = labels["img"]
+ cls = labels["cls"]
+ instances = labels.pop("instances")
+ # make sure the coord formats are right
+ instances.convert_bbox(format="xyxy")
+ instances.denormalize(*img.shape[:2][::-1])
+
+ self.size = img.shape[1] + self.border[1] * 2, img.shape[0] + self.border[0] * 2 # w, h
+ # M is affine matrix
+ # scale for func:`box_candidates`
+ img, M, scale = self.affine_transform(img)
+
+ bboxes = self.apply_bboxes(instances.bboxes, M)
+
+ segments = instances.segments
+ keypoints = instances.keypoints
+ # update bboxes if there are segments.
+ if len(segments):
+ bboxes, segments = self.apply_segments(segments, M)
+
+ if keypoints is not None:
+ keypoints = self.apply_keypoints(keypoints, M)
+ new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
+ # clip
+ new_instances.clip(*self.size)
+
+ # filter instances
+ instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
+ # make the bboxes have the same scale with new_bboxes
+ i = self.box_candidates(box1=instances.bboxes.T,
+ box2=new_instances.bboxes.T,
+ area_thr=0.01 if len(segments) else 0.10)
+ labels["instances"] = new_instances[i]
+ labels["cls"] = cls[i]
+ labels["img"] = img
+ labels["resized_shape"] = img.shape[:2]
+ return labels
+
+ def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
+ # Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
+ w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
+ w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
+ ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
+ return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
+
+
+class RandomHSV:
+
+ def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
+ self.hgain = hgain
+ self.sgain = sgain
+ self.vgain = vgain
+
+ def __call__(self, labels):
+ img = labels["img"]
+ if self.hgain or self.sgain or self.vgain:
+ r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
+ hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
+ dtype = img.dtype # uint8
+
+ x = np.arange(0, 256, dtype=r.dtype)
+ lut_hue = ((x * r[0]) % 180).astype(dtype)
+ lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+ lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+ im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+ cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
+ return labels
+
+
+class RandomFlip:
+
+ def __init__(self, p=0.5, direction="horizontal") -> None:
+ assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
+ assert 0 <= p <= 1.0
+
+ self.p = p
+ self.direction = direction
+
+ def __call__(self, labels):
+ img = labels["img"]
+ instances = labels.pop("instances")
+ instances.convert_bbox(format="xywh")
+ h, w = img.shape[:2]
+ h = 1 if instances.normalized else h
+ w = 1 if instances.normalized else w
+
+ # Flip up-down
+ if self.direction == "vertical" and random.random() < self.p:
+ img = np.flipud(img)
+ instances.flipud(h)
+ if self.direction == "horizontal" and random.random() < self.p:
+ img = np.fliplr(img)
+ instances.fliplr(w)
+ labels["img"] = np.ascontiguousarray(img)
+ labels["instances"] = instances
+ return labels
+
+
+class LetterBox:
+ """Resize image and padding for detection, instance segmentation, pose"""
+
+ def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
+ self.new_shape = new_shape
+ self.auto = auto
+ self.scaleFill = scaleFill
+ self.scaleup = scaleup
+ self.stride = stride
+
+ def __call__(self, labels=None, image=None):
+ if labels is None:
+ labels = {}
+ img = labels.get("img") if image is None else image
+ shape = img.shape[:2] # current shape [height, width]
+ new_shape = labels.pop("rect_shape", self.new_shape)
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not self.scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if self.auto: # minimum rectangle
+ dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
+ elif self.scaleFill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+ if labels.get("ratio_pad"):
+ labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh)) # for evaluation
+
+ if shape[::-1] != new_unpad: # resize
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
+ value=(114, 114, 114)) # add border
+
+ if len(labels):
+ labels = self._update_labels(labels, ratio, dw, dh)
+ labels["img"] = img
+ labels["resized_shape"] = new_shape
+ return labels
+ else:
+ return img
+
+ def _update_labels(self, labels, ratio, padw, padh):
+ """Update labels"""
+ labels["instances"].convert_bbox(format="xyxy")
+ labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
+ labels["instances"].scale(*ratio)
+ labels["instances"].add_padding(padw, padh)
+ return labels
+
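+
+# A small self-contained sketch of LetterBox on a bare image (kept as an unused example
+# helper; the 720x1280 input shape is an arbitrary assumption):
+def _letterbox_example():
+    im = np.full((720, 1280, 3), 114, dtype=np.uint8)  # dummy BGR frame
+    out = LetterBox(new_shape=(640, 640), auto=False)(image=im)
+    return out.shape  # -> (640, 640, 3)
+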
+
+class CopyPaste:
+
+ def __init__(self, p=0.5) -> None:
+ self.p = p
+
+ def __call__(self, labels):
+ # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
+ im = labels["img"]
+ cls = labels["cls"]
+ instances = labels.pop("instances")
+ instances.convert_bbox(format="xyxy")
+ if self.p and len(instances.segments):
+ n = len(instances)
+ _, w, _ = im.shape # height, width, channels
+ im_new = np.zeros(im.shape, np.uint8)
+
+ # calculate ioa first then select indexes randomly
+ ins_flip = deepcopy(instances)
+ ins_flip.fliplr(w)
+
+ ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes) # intersection over area, (N, M)
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
+ n = len(indexes)
+ for j in random.sample(list(indexes), k=round(self.p * n)):
+ cls = np.concatenate((cls, cls[[j]]), axis=0)
+ instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
+ cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
+
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
+ i = cv2.flip(im_new, 1).astype(bool)
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
+
+ labels["img"] = im
+ labels["cls"] = cls
+ labels["instances"] = instances
+ return labels
+
+
+class Albumentations:
+ # YOLOv5 Albumentations class (optional, only used if package is installed)
+ def __init__(self, p=1.0):
+ self.p = p
+ self.transform = None
+ prefix = colorstr("albumentations: ")
+ try:
+ import albumentations as A
+
+ check_version(A.__version__, "1.0.3", hard=True) # version requirement
+
+ T = [
+ A.Blur(p=0.01),
+ A.MedianBlur(p=0.01),
+ A.ToGray(p=0.01),
+ A.CLAHE(p=0.01),
+ A.RandomBrightnessContrast(p=0.0),
+ A.RandomGamma(p=0.0),
+ A.ImageCompression(quality_lower=75, p=0.0),] # transforms
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
+
+ LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
+ except ImportError: # package not installed, skip
+ pass
+ except Exception as e:
+ LOGGER.info(f"{prefix}{e}")
+
+ def __call__(self, labels):
+ im = labels["img"]
+ cls = labels["cls"]
+ if len(cls):
+ labels["instances"].convert_bbox("xywh")
+ labels["instances"].normalize(*im.shape[:2][::-1])
+ bboxes = labels["instances"].bboxes
+ # TODO: add supports of segments and keypoints
+ if self.transform and random.random() < self.p:
+ new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
+ labels["img"] = new["image"]
+ labels["cls"] = np.array(new["class_labels"])
+ labels["instances"].update(bboxes=bboxes)
+ return labels
+
+
+# TODO: technically this is not an augmentation, maybe we should put it in another file
+class Format:
+
+ def __init__(self,
+ bbox_format="xywh",
+ normalize=True,
+ return_mask=False,
+ return_keypoint=False,
+ mask_ratio=4,
+ mask_overlap=True,
+ batch_idx=True):
+ self.bbox_format = bbox_format
+ self.normalize = normalize
+ self.return_mask = return_mask # set False when training detection only
+ self.return_keypoint = return_keypoint
+ self.mask_ratio = mask_ratio
+ self.mask_overlap = mask_overlap
+ self.batch_idx = batch_idx # keep the batch indexes
+
+ def __call__(self, labels):
+ img = labels["img"]
+ h, w = img.shape[:2]
+ cls = labels.pop("cls")
+ instances = labels.pop("instances")
+ instances.convert_bbox(format=self.bbox_format)
+ instances.denormalize(w, h)
+ nl = len(instances)
+
+ if self.return_mask:
+ if nl:
+ masks, instances, cls = self._format_segments(instances, cls, w, h)
+ masks = torch.from_numpy(masks)
+ else:
+ masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
+ img.shape[1] // self.mask_ratio)
+ labels["masks"] = masks
+ if self.normalize:
+ instances.normalize(w, h)
+ labels["img"] = self._format_img(img)
+ labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
+ labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
+ if self.return_keypoint:
+ labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
+ # then we can use collate_fn
+ if self.batch_idx:
+ labels["batch_idx"] = torch.zeros(nl)
+ return labels
+
+ def _format_img(self, img):
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
+ img = torch.from_numpy(img)
+ return img
+
+ def _format_segments(self, instances, cls, w, h):
+ """convert polygon points to bitmap"""
+ segments = instances.segments
+ if self.mask_overlap:
+ masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
+ masks = masks[None] # (640, 640) -> (1, 640, 640)
+ instances = instances[sorted_idx]
+ cls = cls[sorted_idx]
+ else:
+ masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
+
+ return masks, instances, cls
+
+
+def mosaic_transforms(dataset, imgsz, hyp):
+ pre_transform = Compose([
+ Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
+ CopyPaste(p=hyp.copy_paste),
+ RandomPerspective(
+ degrees=hyp.degrees,
+ translate=hyp.translate,
+ scale=hyp.scale,
+ shear=hyp.shear,
+ perspective=hyp.perspective,
+ border=[-imgsz // 2, -imgsz // 2],
+ ),])
+ return Compose([
+ pre_transform,
+ MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
+ Albumentations(p=1.0),
+ RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
+ RandomFlip(direction="vertical", p=hyp.flipud),
+ RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
+
+
+def affine_transforms(imgsz, hyp):
+ return Compose([
+ LetterBox(new_shape=(imgsz, imgsz)),
+ RandomPerspective(
+ degrees=hyp.degrees,
+ translate=hyp.translate,
+ scale=hyp.scale,
+ shear=hyp.shear,
+ perspective=hyp.perspective,
+ border=[0, 0],
+ ),
+ Albumentations(p=1.0),
+ RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
+ RandomFlip(direction="vertical", p=hyp.flipud),
+ RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
+
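+# Hedged sketch of how the two pipelines above are typically built (the hyperparameter
+# values and the SimpleNamespace container are illustrative assumptions; during training
+# the hyp object comes from the loaded config):
+#   from types import SimpleNamespace
+#   hyp = SimpleNamespace(mosaic=1.0, copy_paste=0.0, degrees=0.0, translate=0.1, scale=0.5,
+#                         shear=0.0, perspective=0.0, mixup=0.0, hsv_h=0.015, hsv_s=0.7,
+#                         hsv_v=0.4, flipud=0.0, fliplr=0.5)
+#   train_tf = mosaic_transforms(dataset, imgsz=640, hyp=hyp)   # mosaic/mixup training pipeline
+#   val_tf = affine_transforms(imgsz=640, hyp=hyp)              # letterbox + affine pipeline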
+
+# Classification augmentations -----------------------------------------------------------------------------------------
+def classify_transforms(size=224):
+ # Transforms to apply if albumentations not installed
+ assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
+ # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
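+
+
+# Classification preprocessing sketch (assumption: a BGR uint8 image read with cv2):
+#   tf = classify_transforms(size=224)
+#   x = tf(cv2.imread('img.jpg'))   # CHW float tensor, RGB, ImageNet-normalized
+#   x = x.unsqueeze(0)              # add batch dimension for inference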
+
+
+def classify_albumentations(
+ augment=True,
+ size=224,
+ scale=(0.08, 1.0),
+ hflip=0.5,
+ vflip=0.0,
+ jitter=0.4,
+ mean=IMAGENET_MEAN,
+ std=IMAGENET_STD,
+ auto_aug=False,
+):
+ # YOLOv5 classification Albumentations (optional, only used if package is installed)
+ prefix = colorstr("albumentations: ")
+ try:
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+ check_version(A.__version__, "1.0.3", hard=True) # version requirement
+ if augment: # Resize and crop
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
+ if auto_aug:
+ # TODO: implement AugMix, AutoAug & RandAug in albumentation
+ LOGGER.info(f"{prefix}auto augmentations are currently not supported")
+ else:
+ if hflip > 0:
+ T += [A.HorizontalFlip(p=hflip)]
+ if vflip > 0:
+ T += [A.VerticalFlip(p=vflip)]
+ if jitter > 0:
+ color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, saturation, 0 hue
+ T += [A.ColorJitter(*color_jitter, 0)]
+ else: # Use fixed crop for eval set (reproducibility)
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
+ LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
+ return A.Compose(T)
+
+ except ImportError: # package not installed, skip
+ pass
+ except Exception as e:
+ LOGGER.info(f"{prefix}{e}")
+
+
+class ClassifyLetterBox:
+ # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+ def __init__(self, size=(640, 640), auto=False, stride=32):
+ super().__init__()
+ self.h, self.w = (size, size) if isinstance(size, int) else size
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
+ self.stride = stride # used with auto
+
+ def __call__(self, im): # im = np.array HWC
+ imh, imw = im.shape[:2]
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
+ h, w = round(imh * r), round(imw * r) # resized image
+        hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
+        top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
+        im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
+ return im_out
+
+
+class CenterCrop:
+ # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
+ def __init__(self, size=640):
+ super().__init__()
+ self.h, self.w = (size, size) if isinstance(size, int) else size
+
+ def __call__(self, im): # im = np.array HWC
+ imh, imw = im.shape[:2]
+ m = min(imh, imw) # min dimension
+ top, left = (imh - m) // 2, (imw - m) // 2
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
+
+
+class ToTensor:
+ # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+ def __init__(self, half=False):
+ super().__init__()
+ self.half = half
+
+ def __call__(self, im): # im = np.array HWC in BGR order
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
+ im = torch.from_numpy(im) # to torch
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
+ im /= 255.0 # 0-255 to 0.0-1.0
+ return im
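+
+
+# Preprocessing sketch combining the helper classes above (illustrative; the 512x768
+# dummy frame and the target size are assumptions):
+#   preprocess = T.Compose([ClassifyLetterBox(size=(640, 640)), ToTensor()])
+#   x = preprocess(np.zeros((512, 768, 3), dtype=np.uint8))   # -> torch.Size([3, 640, 640])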
diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1745f8f6f79bb3dafeebcb88f0e07750aa5ada2f
--- /dev/null
+++ b/ultralytics/yolo/data/base.py
@@ -0,0 +1,226 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import glob
+import math
+import os
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from typing import Optional
+
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
+from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
+
+
+class BaseDataset(Dataset):
+ """Base Dataset.
+ Args:
+ img_path (str): image path.
+        imgsz (int): target image size used when loading and resizing images. Defaults to 640.
+        hyp: hyperparameters forwarded to `build_transforms`. Defaults to None.
+ label_path (str): label path, this can also be an ann_file or other custom label path.
+ """
+
+ def __init__(
+ self,
+ img_path,
+ imgsz=640,
+ label_path=None,
+ cache=False,
+ augment=True,
+ hyp=None,
+ prefix="",
+ rect=False,
+ batch_size=None,
+ stride=32,
+ pad=0.5,
+ single_cls=False,
+ ):
+ super().__init__()
+ self.img_path = img_path
+ self.imgsz = imgsz
+ self.label_path = label_path
+ self.augment = augment
+ self.single_cls = single_cls
+ self.prefix = prefix
+
+ self.im_files = self.get_img_files(self.img_path)
+ self.labels = self.get_labels()
+ if self.single_cls:
+ self.update_labels(include_class=[])
+
+ self.ni = len(self.labels)
+
+ # rect stuff
+ self.rect = rect
+ self.batch_size = batch_size
+ self.stride = stride
+ self.pad = pad
+ if self.rect:
+ assert self.batch_size is not None
+ self.set_rectangle()
+
+ # cache stuff
+ self.ims = [None] * self.ni
+ self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
+ if cache:
+ self.cache_images(cache)
+
+ # transforms
+ self.transforms = self.build_transforms(hyp=hyp)
+
+ def get_img_files(self, img_path):
+ """Read image files."""
+ try:
+ f = [] # image files
+ for p in img_path if isinstance(img_path, list) else [img_path]:
+ p = Path(p) # os-agnostic
+ if p.is_dir(): # dir
+ f += glob.glob(str(p / "**" / "*.*"), recursive=True)
+ # f = list(p.rglob('*.*')) # pathlib
+ elif p.is_file(): # file
+ with open(p) as t:
+ t = t.read().strip().splitlines()
+ parent = str(p.parent) + os.sep
+ f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
+ # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
+ else:
+ raise FileNotFoundError(f"{self.prefix}{p} does not exist")
+ im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
+ assert im_files, f"{self.prefix}No images found"
+ except Exception as e:
+ raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}") from e
+ return im_files
+
+    def update_labels(self, include_class: Optional[list]):
+        """Filter labels to keep only the classes listed in `include_class`, if provided (optional)."""
+ include_class_array = np.array(include_class).reshape(1, -1)
+ for i in range(len(self.labels)):
+ if include_class:
+ cls = self.labels[i]["cls"]
+ bboxes = self.labels[i]["bboxes"]
+ segments = self.labels[i]["segments"]
+ j = (cls == include_class_array).any(1)
+ self.labels[i]["cls"] = cls[j]
+ self.labels[i]["bboxes"] = bboxes[j]
+ if segments:
+ self.labels[i]["segments"] = segments[j]
+ if self.single_cls:
+ self.labels[i]["cls"] = 0
+
+ def load_image(self, i):
+ # Loads 1 image from dataset index 'i', returns (im, resized hw)
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
+ if im is None: # not cached in RAM
+ if fn.exists(): # load npy
+ im = np.load(fn)
+ else: # read image
+ im = cv2.imread(f) # BGR
+ assert im is not None, f"Image Not Found {f}"
+ h0, w0 = im.shape[:2] # orig hw
+ r = self.imgsz / max(h0, w0) # ratio
+ if r != 1: # if sizes are not equal
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
+ im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
+ return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
+ return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
+
+ def cache_images(self, cache):
+ # cache images to memory or disk
+ gb = 0 # Gigabytes of cached images
+ self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
+ fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
+ results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
+ pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
+ for i, x in pbar:
+ if cache == "disk":
+ gb += self.npy_files[i].stat().st_size
+ else: # 'ram'
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
+ gb += self.ims[i].nbytes
+ pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
+ pbar.close()
+
+ def cache_images_to_disk(self, i):
+ # Saves an image as an *.npy file for faster loading
+ f = self.npy_files[i]
+ if not f.exists():
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
+
+ def set_rectangle(self):
+ bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
+ nb = bi[-1] + 1 # number of batches
+
+ s = np.array([x.pop("shape") for x in self.labels]) # hw
+ ar = s[:, 0] / s[:, 1] # aspect ratio
+ irect = ar.argsort()
+ self.im_files = [self.im_files[i] for i in irect]
+ self.labels = [self.labels[i] for i in irect]
+ ar = ar[irect]
+
+ # Set training image shapes
+ shapes = [[1, 1]] * nb
+ for i in range(nb):
+ ari = ar[bi == i]
+ mini, maxi = ari.min(), ari.max()
+ if maxi < 1:
+ shapes[i] = [maxi, 1]
+ elif mini > 1:
+ shapes[i] = [1, 1 / mini]
+
+ self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
+ self.batch = bi # batch index of image
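+ # Worked example (hedged, illustrative numbers): with imgsz=640, stride=32, pad=0.5 and a
+ # batch whose largest aspect ratio h/w is 0.5, shapes[i] = [0.5, 1] and
+ # batch_shapes[i] = ceil([0.5, 1] * 640 / 32 + 0.5) * 32 = [352, 672], so those landscape
+ # images are letterboxed to 352x672 instead of a full 640x640 square.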
+
+ def __getitem__(self, index):
+ return self.transforms(self.get_label_info(index))
+
+ def get_label_info(self, index):
+ label = self.labels[index].copy()
+ label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
+ label["ratio_pad"] = (
+ label["resized_shape"][0] / label["ori_shape"][0],
+ label["resized_shape"][1] / label["ori_shape"][1],
+ ) # for evaluation
+ if self.rect:
+ label["rect_shape"] = self.batch_shapes[self.batch[index]]
+ label = self.update_labels_info(label)
+ return label
+
+ def __len__(self):
+ return len(self.im_files)
+
+ def update_labels_info(self, label):
+ """custom your label format here"""
+ return label
+
+ def build_transforms(self, hyp=None):
+ """Users can custom augmentations here
+ like:
+ if self.augment:
+ # training transforms
+ return Compose([])
+ else:
+ # val transforms
+ return Compose([])
+ """
+ raise NotImplementedError
+
+ def get_labels(self):
+ """Users can custom their own format here.
+ Make sure your output is a list with each element like below:
+ dict(
+ im_file=im_file,
+ shape=shape, # format: (height, width)
+ cls=cls,
+ bboxes=bboxes, # xywh
+ segments=segments, # xy
+ keypoints=keypoints, # xy
+ normalized=True, # or False
+ bbox_format="xyxy", # or xywh, ltwh
+ )
+ """
+ raise NotImplementedError
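+
+# A minimal subclassing sketch (hedged; `MyDataset`, the constant shapes and the empty
+# label arrays below are illustrative assumptions, not part of this API):
+#
+#   class MyDataset(BaseDataset):
+#       def get_labels(self):
+#           return [dict(im_file=f, shape=(640, 640), cls=np.zeros((0, 1)),
+#                        bboxes=np.zeros((0, 4)), segments=[], keypoints=None,
+#                        normalized=True, bbox_format="xywh") for f in self.im_files]
+#
+#       def build_transforms(self, hyp=None):
+#           return lambda label: label  # identity transform, for illustration only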
diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..144d01e2562f04a248c761059196b81705d2604f
--- /dev/null
+++ b/ultralytics/yolo/data/build.py
@@ -0,0 +1,125 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import os
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, dataloader, distributed
+
+from ..utils import LOGGER, colorstr
+from ..utils.torch_utils import torch_distributed_zero_first
+from .dataset import ClassificationDataset, YOLODataset
+from .utils import PIN_MEMORY, RANK
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
+ """Dataloader that reuses workers
+
+ Uses same syntax as vanilla DataLoader
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
+ self.iterator = super().__iter__()
+
+ def __len__(self):
+ return len(self.batch_sampler.sampler)
+
+ def __iter__(self):
+ for _ in range(len(self)):
+ yield next(self.iterator)
+
+
+class _RepeatSampler:
+ """Sampler that repeats forever
+
+ Args:
+ sampler (Sampler)
+ """
+
+ def __init__(self, sampler):
+ self.sampler = sampler
+
+ def __iter__(self):
+ while True:
+ yield from iter(self.sampler)
+
+
+def seed_worker(worker_id):
+ # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
+ worker_seed = torch.initial_seed() % 2 ** 32
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+
+
+def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
+ assert mode in ["train", "val"]
+ shuffle = mode == "train"
+ if cfg.rect and shuffle:
+ LOGGER.warning("WARNING β οΈ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
+ shuffle = False
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
+ dataset = YOLODataset(
+ img_path=img_path,
+ label_path=label_path,
+ imgsz=cfg.imgsz,
+ batch_size=batch_size,
+ augment=mode == "train", # augmentation
+ hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
+ rect=cfg.rect if mode == "train" else True, # rectangular batches
+ cache=cfg.get("cache", None),
+ single_cls=cfg.get("single_cls", False),
+ stride=int(stride),
+ pad=0.0 if mode == "train" else 0.5,
+ prefix=colorstr(f"{mode}: "),
+ use_segments=cfg.task == "segment",
+ use_keypoints=cfg.task == "keypoint")
+
+ batch_size = min(batch_size, len(dataset))
+ nd = torch.cuda.device_count() # number of CUDA devices
+ workers = cfg.workers if mode == "train" else cfg.workers * 2
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+ loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
+ generator = torch.Generator()
+ generator.manual_seed(6148914691236517205 + RANK)
+ return loader(dataset=dataset,
+ batch_size=batch_size,
+ shuffle=shuffle and sampler is None,
+ num_workers=nw,
+ sampler=sampler,
+ pin_memory=PIN_MEMORY,
+ collate_fn=getattr(dataset, "collate_fn", None),
+ worker_init_fn=seed_worker,
+ generator=generator), dataset
+
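+# Hedged usage sketch: `cfg` is normally the parsed trainer config; the SimpleNamespace stand-in
+# below only illustrates which attributes this function reads (imgsz, rect, workers, task,
+# image_weights, close_mosaic, plus cfg.get('cache') and cfg.get('single_cls')):
+#
+#   from types import SimpleNamespace
+#   cfg = SimpleNamespace(imgsz=640, rect=False, workers=8, task="detect", image_weights=False,
+#                         close_mosaic=False, get=lambda k, default=None: default)
+#   loader, dataset = build_dataloader(cfg, batch_size=16, img_path="coco128/images/train2017")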
+
+# build classification
+# TODO: using cfg like `build_dataloader`
+def build_classification_dataloader(path,
+ imgsz=224,
+ batch_size=16,
+ augment=True,
+ cache=False,
+ rank=-1,
+ workers=8,
+ shuffle=True):
+ # Returns Dataloader object to be used with YOLOv5 Classifier
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
+ dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
+ batch_size = min(batch_size, len(dataset))
+ nd = torch.cuda.device_count()
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+ generator = torch.Generator()
+ generator.manual_seed(6148914691236517205 + RANK)
+ return InfiniteDataLoader(dataset,
+ batch_size=batch_size,
+ shuffle=shuffle and sampler is None,
+ num_workers=nw,
+ sampler=sampler,
+ pin_memory=PIN_MEMORY,
+ worker_init_fn=seed_worker,
+ generator=generator) # or DataLoader(persistent_workers=True)
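+
+# Hedged usage sketch (path and sizes are illustrative; the root is assumed to contain one
+# sub-folder per class, as for a torchvision ImageFolder-style dataset):
+#   loader = build_classification_dataloader("datasets/imagenette160/train", imgsz=224, batch_size=64)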
diff --git a/ultralytics/yolo/data/dataloaders/__init__.py b/ultralytics/yolo/data/dataloaders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..6365cb0fa50ad1ad3781051c412975cf8f6d0d1c
--- /dev/null
+++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py
@@ -0,0 +1,256 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import glob
+import math
+import os
+import time
+from pathlib import Path
+from threading import Thread
+from urllib.parse import urlparse
+
+import cv2
+import numpy as np
+import torch
+
+from ultralytics.yolo.data.augment import LetterBox
+from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
+from ultralytics.yolo.utils import LOGGER, is_colab, is_kaggle, ops
+from ultralytics.yolo.utils.checks import check_requirements
+
+
+class LoadStreams:
+ # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
+ def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
+ self.mode = 'stream'
+ self.imgsz = imgsz
+ self.stride = stride
+ self.vid_stride = vid_stride # video frame-rate stride
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
+ n = len(sources)
+ self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
+ self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
+ for i, s in enumerate(sources): # index, source
+ # Start thread to read frames from video stream
+ st = f'{i + 1}/{n}: {s}... '
+ if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
+ check_requirements(('pafy', 'youtube_dl==2020.12.2'))
+ import pafy
+ s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
+ if s == 0:
+ assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
+ assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
+ cap = cv2.VideoCapture(s)
+ assert cap.isOpened(), f'{st}Failed to open {s}'
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
+ self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
+
+ _, self.imgs[i] = cap.read() # guarantee first frame
+ self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
+ LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
+ self.threads[i].start()
+ LOGGER.info('') # newline
+
+ # check for common shapes
+ s = np.stack([LetterBox(imgsz, auto, stride=stride)(image=x).shape for x in self.imgs])
+ self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
+ self.auto = auto and self.rect
+ self.transforms = transforms # optional
+ if not self.rect:
+ LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
+
+ def update(self, i, cap, stream):
+ # Read stream `i` frames in daemon thread
+ n, f = 0, self.frames[i] # frame number, total frames
+ while cap.isOpened() and n < f:
+ n += 1
+ cap.grab() # .read() = .grab() followed by .retrieve()
+ if n % self.vid_stride == 0:
+ success, im = cap.retrieve()
+ if success:
+ self.imgs[i] = im
+ else:
+ LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
+ self.imgs[i] = np.zeros_like(self.imgs[i])
+ cap.open(stream) # re-open stream if signal was lost
+ time.sleep(0.0) # wait time
+
+ def __iter__(self):
+ self.count = -1
+ return self
+
+ def __next__(self):
+ self.count += 1
+ if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
+ cv2.destroyAllWindows()
+ raise StopIteration
+
+ im0 = self.imgs.copy()
+ if self.transforms:
+ im = np.stack([self.transforms(x) for x in im0]) # transforms
+ else:
+ im = np.stack([LetterBox(self.imgsz, self.auto, stride=self.stride)(image=x) for x in im0])
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
+ im = np.ascontiguousarray(im) # contiguous
+
+ return self.sources, im, im0, None, ''
+
+ def __len__(self):
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
+
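+# Hedged usage sketch (source '0' means the local webcam; an RTSP/HTTP URL works the same way):
+#   dataset = LoadStreams('0', imgsz=640)
+#   for sources, im, im0s, _, _ in dataset:  # im is a stacked BCHW RGB batch, im0s the raw BGR frames
+#       ...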
+
+class LoadScreenshots:
+ # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
+ def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None):
+ # source = [screen_number left top width height] (pixels)
+ check_requirements('mss')
+ import mss
+
+ source, *params = source.split()
+ self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
+ if len(params) == 1:
+ self.screen = int(params[0])
+ elif len(params) == 4:
+ left, top, width, height = (int(x) for x in params)
+ elif len(params) == 5:
+ self.screen, left, top, width, height = (int(x) for x in params)
+ self.imgsz = imgsz
+ self.stride = stride
+ self.transforms = transforms
+ self.auto = auto
+ self.mode = 'stream'
+ self.frame = 0
+ self.sct = mss.mss()
+
+ # Parse monitor shape
+ monitor = self.sct.monitors[self.screen]
+ self.top = monitor["top"] if top is None else (monitor["top"] + top)
+ self.left = monitor["left"] if left is None else (monitor["left"] + left)
+ self.width = width or monitor["width"]
+ self.height = height or monitor["height"]
+ self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ # mss screen capture: get raw pixels from the screen as np array
+ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
+ s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
+
+ if self.transforms:
+ im = self.transforms(im0) # transforms
+ else:
+ im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ self.frame += 1
+ return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
+
+
+class LoadImages:
+ # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
+ def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
+ if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
+ path = Path(path).read_text().rsplit()
+ files = []
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
+ p = str(Path(p).resolve())
+ if '*' in p:
+ files.extend(sorted(glob.glob(p, recursive=True))) # glob
+ elif os.path.isdir(p):
+ files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
+ elif os.path.isfile(p):
+ files.append(p) # files
+ else:
+ raise FileNotFoundError(f'{p} does not exist')
+
+ images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
+ videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
+ ni, nv = len(images), len(videos)
+
+ self.imgsz = imgsz
+ self.stride = stride
+ self.files = images + videos
+ self.nf = ni + nv # number of files
+ self.video_flag = [False] * ni + [True] * nv
+ self.mode = 'image'
+ self.auto = auto
+ self.transforms = transforms # optional
+ self.vid_stride = vid_stride # video frame-rate stride
+ if any(videos):
+ self._new_video(videos[0]) # new video
+ else:
+ self.cap = None
+ assert self.nf > 0, f'No images or videos found in {p}. ' \
+ f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
+
+ def __iter__(self):
+ self.count = 0
+ return self
+
+ def __next__(self):
+ if self.count == self.nf:
+ raise StopIteration
+ path = self.files[self.count]
+
+ if self.video_flag[self.count]:
+ # Read video
+ self.mode = 'video'
+ for _ in range(self.vid_stride):
+ self.cap.grab()
+ ret_val, im0 = self.cap.retrieve()
+ while not ret_val:
+ self.count += 1
+ self.cap.release()
+ if self.count == self.nf: # last video
+ raise StopIteration
+ path = self.files[self.count]
+ self._new_video(path)
+ ret_val, im0 = self.cap.read()
+
+ self.frame += 1
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
+ s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
+
+ else:
+ # Read image
+ self.count += 1
+ im0 = cv2.imread(path) # BGR
+ assert im0 is not None, f'Image Not Found {path}'
+ s = f'image {self.count}/{self.nf} {path}: '
+
+ if self.transforms:
+ im = self.transforms(im0) # transforms
+ else:
+ im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+
+ return path, im, im0, self.cap, s
+
+ def _new_video(self, path):
+ # Create a new video capture object
+ self.frame = 0
+ self.cap = cv2.VideoCapture(path)
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
+
+ def _cv2_rotate(self, im):
+ # Rotate a cv2 video manually
+ if self.orientation == 0:
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
+ elif self.orientation == 180:
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
+ elif self.orientation == 90:
+ return cv2.rotate(im, cv2.ROTATE_180)
+ return im
+
+ def __len__(self):
+ return self.nf # number of files
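+
+# Hedged usage sketch (the path is illustrative):
+#   for path, im, im0, cap, s in LoadImages('assets/', imgsz=640):
+#       ...  # im is a CHW RGB array ready for the model, im0 the original BGR image or video frame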
diff --git a/ultralytics/yolo/data/dataloaders/v5augmentations.py b/ultralytics/yolo/data/dataloaders/v5augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..0595d7a9298cc961441bebafa5149d5ff5679010
--- /dev/null
+++ b/ultralytics/yolo/data/dataloaders/v5augmentations.py
@@ -0,0 +1,402 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Image augmentation functions
+"""
+
+import math
+import random
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+from ultralytics.yolo.utils import LOGGER, colorstr
+from ultralytics.yolo.utils.checks import check_version
+from ultralytics.yolo.utils.metrics import bbox_ioa
+from ultralytics.yolo.utils.ops import resample_segments, segment2box, xywhn2xyxy
+
+IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
+IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
+
+
+class Albumentations:
+ # YOLOv5 Albumentations class (optional, only used if package is installed)
+ def __init__(self, size=640):
+ self.transform = None
+ prefix = colorstr('albumentations: ')
+ try:
+ import albumentations as A
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
+
+ T = [
+ A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
+ A.Blur(p=0.01),
+ A.MedianBlur(p=0.01),
+ A.ToGray(p=0.01),
+ A.CLAHE(p=0.01),
+ A.RandomBrightnessContrast(p=0.0),
+ A.RandomGamma(p=0.0),
+ A.ImageCompression(quality_lower=75, p=0.0)] # transforms
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
+
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
+ except ImportError: # package not installed, skip
+ pass
+ except Exception as e:
+ LOGGER.info(f'{prefix}{e}')
+
+ def __call__(self, im, labels, p=1.0):
+ if self.transform and random.random() < p:
+ new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
+ im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
+ return im, labels
+
+
+def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
+ # Normalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
+ return TF.normalize(x, mean, std, inplace=inplace)
+
+
+def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
+ # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
+ for i in range(3):
+ x[:, i] = x[:, i] * std[i] + mean[i]
+ return x
+
+
+def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
+ # HSV color-space augmentation
+ if hgain or sgain or vgain:
+ r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
+ hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
+ dtype = im.dtype # uint8
+
+ x = np.arange(0, 256, dtype=r.dtype)
+ lut_hue = ((x * r[0]) % 180).astype(dtype)
+ lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+ lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+ im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+ cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
+
+
+def hist_equalize(im, clahe=True, bgr=False):
+ # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
+ yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
+ if clahe:
+ c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+ yuv[:, :, 0] = c.apply(yuv[:, :, 0])
+ else:
+ yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
+ return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
+
+
+def replicate(im, labels):
+ # Replicate labels
+ h, w = im.shape[:2]
+ boxes = labels[:, 1:].astype(int)
+ x1, y1, x2, y2 = boxes.T
+ s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
+ for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
+ x1b, y1b, x2b, y2b = boxes[i]
+ bh, bw = y2b - y1b, x2b - x1b
+ yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
+ x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
+ im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b] # im4[ymin:ymax, xmin:xmax]
+ labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
+
+ return im, labels
+
+
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+ # Resize and pad image while meeting stride-multiple constraints
+ shape = im.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
+ elif scaleFill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return im, ratio, (dw, dh)
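+ # Worked example (hedged, illustrative numbers): a 720x1280 frame letterboxed to 640 gives
+ # r = 0.5 and new_unpad = (640, 360), so dw, dh = 0, 280; with auto=True and stride=32 the
+ # padding shrinks to dh % 32 = 24, i.e. 12 gray pixels top and bottom (output shape 384x640).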
+
+
+def random_perspective(im,
+ targets=(),
+ segments=(),
+ degrees=10,
+ translate=.1,
+ scale=.1,
+ shear=10,
+ perspective=0.0,
+ border=(0, 0)):
+ # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
+ # targets = [cls, xyxy]
+
+ height = im.shape[0] + border[0] * 2 # shape(h,w,c)
+ width = im.shape[1] + border[1] * 2
+
+ # Center
+ C = np.eye(3)
+ C[0, 2] = -im.shape[1] / 2 # x translation (pixels)
+ C[1, 2] = -im.shape[0] / 2 # y translation (pixels)
+
+ # Perspective
+ P = np.eye(3)
+ P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
+ P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
+
+ # Rotation and Scale
+ R = np.eye(3)
+ a = random.uniform(-degrees, degrees)
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
+ s = random.uniform(1 - scale, 1 + scale)
+ # s = 2 ** random.uniform(-scale, scale)
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
+
+ # Shear
+ S = np.eye(3)
+ S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
+ S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
+
+ # Translation
+ T = np.eye(3)
+ T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
+ T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
+
+ # Combined rotation matrix
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
+ if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
+ if perspective:
+ im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
+ else: # affine
+ im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
+
+ # Visualize
+ # import matplotlib.pyplot as plt
+ # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
+ # ax[0].imshow(im[:, :, ::-1]) # base
+ # ax[1].imshow(im2[:, :, ::-1]) # warped
+
+ # Transform label coordinates
+ n = len(targets)
+ if n:
+ use_segments = any(x.any() for x in segments)
+ new = np.zeros((n, 4))
+ if use_segments: # warp segments
+ segments = resample_segments(segments) # upsample
+ for i, segment in enumerate(segments):
+ xy = np.ones((len(segment), 3))
+ xy[:, :2] = segment
+ xy = xy @ M.T # transform
+ xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine
+
+ # clip
+ new[i] = segment2box(xy, width, height)
+
+ else: # warp boxes
+ xy = np.ones((n * 4, 3))
+ xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
+ xy = xy @ M.T # transform
+ xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
+
+ # create new boxes
+ x = xy[:, [0, 2, 4, 6]]
+ y = xy[:, [1, 3, 5, 7]]
+ new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+ # clip
+ new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
+ new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
+
+ # filter candidates
+ i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
+ targets = targets[i]
+ targets[:, 1:5] = new[i]
+
+ return im, targets
+
+
+def copy_paste(im, labels, segments, p=0.5):
+ # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
+ n = len(segments)
+ if p and n:
+ h, w, c = im.shape # height, width, channels
+ im_new = np.zeros(im.shape, np.uint8)
+
+ # calculate ioa first then select indexes randomly
+ boxes = np.stack([w - labels[:, 3], labels[:, 2], w - labels[:, 1], labels[:, 4]], axis=-1) # (n, 4)
+ ioa = bbox_ioa(boxes, labels[:, 1:5]) # intersection over area
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
+ n = len(indexes)
+ for j in random.sample(list(indexes), k=round(p * n)):
+ l, box, s = labels[j], boxes[j], segments[j]
+ labels = np.concatenate((labels, [[l[0], *box]]), 0)
+ segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
+ cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)
+
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
+ i = cv2.flip(im_new, 1).astype(bool)
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
+
+ return im, labels, segments
+
+
+def cutout(im, labels, p=0.5):
+ # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
+ if random.random() < p:
+ h, w = im.shape[:2]
+ scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
+ for s in scales:
+ mask_h = random.randint(1, int(h * s)) # create random masks
+ mask_w = random.randint(1, int(w * s))
+
+ # box
+ xmin = max(0, random.randint(0, w) - mask_w // 2)
+ ymin = max(0, random.randint(0, h) - mask_h // 2)
+ xmax = min(w, xmin + mask_w)
+ ymax = min(h, ymin + mask_h)
+
+ # apply random color mask
+ im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
+
+ # return unobscured labels
+ if len(labels) and s > 0.03:
+ box = np.array([[xmin, ymin, xmax, ymax]], dtype=np.float32)
+ ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))[0] # intersection over area
+ labels = labels[ioa < 0.60] # remove >60% obscured labels
+
+ return labels
+
+
+def mixup(im, labels, im2, labels2):
+ # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
+ im = (im * r + im2 * (1 - r)).astype(np.uint8)
+ labels = np.concatenate((labels, labels2), 0)
+ return im, labels
+
+
+def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
+ # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
+ w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
+ w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
+ ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
+ return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
+
+
+def classify_albumentations(
+ augment=True,
+ size=224,
+ scale=(0.08, 1.0),
+ ratio=(0.75, 1.0 / 0.75), # 0.75, 1.33
+ hflip=0.5,
+ vflip=0.0,
+ jitter=0.4,
+ mean=IMAGENET_MEAN,
+ std=IMAGENET_STD,
+ auto_aug=False):
+ # YOLOv5 classification Albumentations (optional, only used if package is installed)
+ prefix = colorstr('albumentations: ')
+ try:
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
+ if augment: # Resize and crop
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
+ if auto_aug:
+ # TODO: implement AugMix, AutoAug & RandAug in albumentation
+ LOGGER.info(f'{prefix}auto augmentations are currently not supported')
+ else:
+ if hflip > 0:
+ T += [A.HorizontalFlip(p=hflip)]
+ if vflip > 0:
+ T += [A.VerticalFlip(p=vflip)]
+ if jitter > 0:
+ color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, saturation, 0 hue
+ T += [A.ColorJitter(*color_jitter, 0)]
+ else: # Use fixed crop for eval set (reproducibility)
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
+ return A.Compose(T)
+
+ except ImportError: # package not installed, skip
+ LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
+ except Exception as e:
+ LOGGER.info(f'{prefix}{e}')
+
+
+def classify_transforms(size=224):
+ # Transforms to apply if albumentations not installed
+ assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
+ # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+
+
+class LetterBox:
+ # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+ def __init__(self, size=(640, 640), auto=False, stride=32):
+ super().__init__()
+ self.h, self.w = (size, size) if isinstance(size, int) else size
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
+ self.stride = stride # used with auto
+
+ def __call__(self, im): # im = np.array HWC
+ imh, imw = im.shape[:2]
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
+ h, w = round(imh * r), round(imw * r) # resized image
+ hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w) # pad target
+ top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
+ im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
+ return im_out
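+ # Example (hedged, illustrative numbers): LetterBox(640)(im) on a 480x640 image keeps r = 1.0
+ # and pads 80 gray rows on top and bottom, returning a 640x640 array.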
+
+
+class CenterCrop:
+ # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
+ def __init__(self, size=640):
+ super().__init__()
+ self.h, self.w = (size, size) if isinstance(size, int) else size
+
+ def __call__(self, im): # im = np.array HWC
+ imh, imw = im.shape[:2]
+ m = min(imh, imw) # min dimension
+ top, left = (imh - m) // 2, (imw - m) // 2
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
+
+
+class ToTensor:
+ # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+ def __init__(self, half=False):
+ super().__init__()
+ self.half = half
+
+ def __call__(self, im): # im = np.array HWC in BGR order
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
+ im = torch.from_numpy(im) # to torch
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
+ im /= 255.0 # 0-255 to 0.0-1.0
+ return im
diff --git a/ultralytics/yolo/data/dataloaders/v5loader.py b/ultralytics/yolo/data/dataloaders/v5loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..54aa5e42518cd56cc45f1884bdda1b8e6ebf536e
--- /dev/null
+++ b/ultralytics/yolo/data/dataloaders/v5loader.py
@@ -0,0 +1,1216 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Dataloaders and dataset utils
+"""
+
+import contextlib
+import glob
+import hashlib
+import json
+import math
+import os
+import random
+import shutil
+import time
+from itertools import repeat
+from multiprocessing.pool import Pool, ThreadPool
+from pathlib import Path
+from threading import Thread
+from urllib.parse import urlparse
+
+import cv2
+import numpy as np
+import psutil
+import torch
+import torchvision
+import yaml
+from PIL import ExifTags, Image, ImageOps
+from torch.utils.data import DataLoader, Dataset, dataloader, distributed
+from tqdm import tqdm
+
+from ultralytics.yolo.data.utils import check_dataset, unzip_file
+from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_kaggle
+from ultralytics.yolo.utils.checks import check_requirements, check_yaml
+from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
+from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
+
+from .v5augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
+ letterbox, mixup, random_perspective)
+
+# Parameters
+HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
+IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
+VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
+
+# Get orientation exif tag
+for orientation in ExifTags.TAGS.keys():
+ if ExifTags.TAGS[orientation] == 'Orientation':
+ break
+
+
+def get_hash(paths):
+ # Returns a single hash value of a list of paths (files or dirs)
+ size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
+ h = hashlib.md5(str(size).encode()) # hash sizes
+ h.update(''.join(paths).encode()) # hash paths
+ return h.hexdigest() # return hash
+
+
+def exif_size(img):
+ # Returns exif-corrected PIL size
+ s = img.size # (width, height)
+ with contextlib.suppress(Exception):
+ rotation = dict(img._getexif().items())[orientation]
+ if rotation in [6, 8]: # rotation 270 or 90
+ s = (s[1], s[0])
+ return s
+
+
+def exif_transpose(image):
+ """
+ Transpose a PIL image accordingly if it has an EXIF Orientation tag.
+ Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
+
+ :param image: The image to transpose.
+ :return: An image.
+ """
+ exif = image.getexif()
+ orientation = exif.get(0x0112, 1) # default 1
+ if orientation > 1:
+ method = {
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90}.get(orientation)
+ if method is not None:
+ image = image.transpose(method)
+ del exif[0x0112]
+ image.info["exif"] = exif.tobytes()
+ return image
+
+
+def seed_worker(worker_id):
+ # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
+ worker_seed = torch.initial_seed() % 2 ** 32
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+
+
+def create_dataloader(path,
+ imgsz,
+ batch_size,
+ stride,
+ single_cls=False,
+ hyp=None,
+ augment=False,
+ cache=False,
+ pad=0.0,
+ rect=False,
+ rank=-1,
+ workers=8,
+ image_weights=False,
+ close_mosaic=False,
+ min_items=0,
+ prefix='',
+ shuffle=False,
+ seed=0):
+ if rect and shuffle:
+ LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
+ shuffle = False
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
+ dataset = LoadImagesAndLabels(
+ path,
+ imgsz,
+ batch_size,
+ augment=augment, # augmentation
+ hyp=hyp, # hyperparameters
+ rect=rect, # rectangular batches
+ cache_images=cache,
+ single_cls=single_cls,
+ stride=int(stride),
+ pad=pad,
+ image_weights=image_weights,
+ min_items=min_items,
+ prefix=prefix)
+
+ batch_size = min(batch_size, len(dataset))
+ nd = torch.cuda.device_count() # number of CUDA devices
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+ loader = DataLoader if image_weights or close_mosaic else InfiniteDataLoader # DataLoader allows attribute updates
+ generator = torch.Generator()
+ generator.manual_seed(6148914691236517205 + seed + RANK)
+ return loader(dataset,
+ batch_size=batch_size,
+ shuffle=shuffle and sampler is None,
+ num_workers=nw,
+ sampler=sampler,
+ pin_memory=PIN_MEMORY,
+ collate_fn=LoadImagesAndLabels.collate_fn,
+ worker_init_fn=seed_worker,
+ generator=generator), dataset
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
+ """ Dataloader that reuses workers
+
+ Uses same syntax as vanilla DataLoader
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+ self.iterator = super().__iter__()
+
+ def __len__(self):
+ return len(self.batch_sampler.sampler)
+
+ def __iter__(self):
+ for _ in range(len(self)):
+ yield next(self.iterator)
+
+
+class _RepeatSampler:
+ """ Sampler that repeats forever
+
+ Args:
+ sampler (Sampler)
+ """
+
+ def __init__(self, sampler):
+ self.sampler = sampler
+
+ def __iter__(self):
+ while True:
+ yield from iter(self.sampler)
+
+
+class LoadScreenshots:
+ # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
+ def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
+ # source = [screen_number left top width height] (pixels)
+ check_requirements('mss')
+ import mss
+
+ source, *params = source.split()
+ self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
+ if len(params) == 1:
+ self.screen = int(params[0])
+ elif len(params) == 4:
+ left, top, width, height = (int(x) for x in params)
+ elif len(params) == 5:
+ self.screen, left, top, width, height = (int(x) for x in params)
+ self.img_size = img_size
+ self.stride = stride
+ self.transforms = transforms
+ self.auto = auto
+ self.mode = 'stream'
+ self.frame = 0
+ self.sct = mss.mss()
+
+ # Parse monitor shape
+ monitor = self.sct.monitors[self.screen]
+ self.top = monitor["top"] if top is None else (monitor["top"] + top)
+ self.left = monitor["left"] if left is None else (monitor["left"] + left)
+ self.width = width or monitor["width"]
+ self.height = height or monitor["height"]
+ self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ # mss screen capture: get raw pixels from the screen as np array
+ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
+ s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
+
+ if self.transforms:
+ im = self.transforms(im0) # transforms
+ else:
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ self.frame += 1
+ return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
+
+
+class LoadImages:
+ # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
+ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
+ if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
+ path = Path(path).read_text().rsplit()
+ files = []
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
+ p = str(Path(p).resolve())
+ if '*' in p:
+ files.extend(sorted(glob.glob(p, recursive=True))) # glob
+ elif os.path.isdir(p):
+ files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
+ elif os.path.isfile(p):
+ files.append(p) # files
+ else:
+ raise FileNotFoundError(f'{p} does not exist')
+
+ images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
+ videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
+ ni, nv = len(images), len(videos)
+
+ self.img_size = img_size
+ self.stride = stride
+ self.files = images + videos
+ self.nf = ni + nv # number of files
+ self.video_flag = [False] * ni + [True] * nv
+ self.mode = 'image'
+ self.auto = auto
+ self.transforms = transforms # optional
+ self.vid_stride = vid_stride # video frame-rate stride
+ if any(videos):
+ self._new_video(videos[0]) # new video
+ else:
+ self.cap = None
+ assert self.nf > 0, f'No images or videos found in {p}. ' \
+ f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
+
+ def __iter__(self):
+ self.count = 0
+ return self
+
+ def __next__(self):
+ if self.count == self.nf:
+ raise StopIteration
+ path = self.files[self.count]
+
+ if self.video_flag[self.count]:
+ # Read video
+ self.mode = 'video'
+ for _ in range(self.vid_stride):
+ self.cap.grab()
+ ret_val, im0 = self.cap.retrieve()
+ while not ret_val:
+ self.count += 1
+ self.cap.release()
+ if self.count == self.nf: # last video
+ raise StopIteration
+ path = self.files[self.count]
+ self._new_video(path)
+ ret_val, im0 = self.cap.read()
+
+ self.frame += 1
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
+ s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
+
+ else:
+ # Read image
+ self.count += 1
+ im0 = cv2.imread(path) # BGR
+ assert im0 is not None, f'Image Not Found {path}'
+ s = f'image {self.count}/{self.nf} {path}: '
+
+ if self.transforms:
+ im = self.transforms(im0) # transforms
+ else:
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+
+ return path, im, im0, self.cap, s
+
+ def _new_video(self, path):
+ # Create a new video capture object
+ self.frame = 0
+ self.cap = cv2.VideoCapture(path)
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
+
+ def _cv2_rotate(self, im):
+ # Rotate a cv2 video manually
+ if self.orientation == 0:
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
+ elif self.orientation == 180:
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
+ elif self.orientation == 90:
+ return cv2.rotate(im, cv2.ROTATE_180)
+ return im
+
+ def __len__(self):
+ return self.nf # number of files
+
+
+class LoadStreams:
+ # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
+ def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
+ self.mode = 'stream'
+ self.img_size = img_size
+ self.stride = stride
+ self.vid_stride = vid_stride # video frame-rate stride
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
+ n = len(sources)
+ self.sources = [clean_str(x) for x in sources] # clean source names for later
+ self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
+ for i, s in enumerate(sources): # index, source
+ # Start thread to read frames from video stream
+ st = f'{i + 1}/{n}: {s}... '
+ if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
+ check_requirements(('pafy', 'youtube_dl==2020.12.2'))
+ import pafy
+ s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
+ if s == 0:
+ assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
+ assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
+ cap = cv2.VideoCapture(s)
+ assert cap.isOpened(), f'{st}Failed to open {s}'
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
+ self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
+
+ _, self.imgs[i] = cap.read() # guarantee first frame
+ self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
+ LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
+ self.threads[i].start()
+ LOGGER.info('') # newline
+
+ # check for common shapes
+ s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
+ self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
+ self.auto = auto and self.rect
+ self.transforms = transforms # optional
+ if not self.rect:
+ LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
+
+ def update(self, i, cap, stream):
+ # Read stream `i` frames in daemon thread
+ n, f = 0, self.frames[i] # frame number, total frames
+ while cap.isOpened() and n < f:
+ n += 1
+ cap.grab() # .read() = .grab() followed by .retrieve()
+ if n % self.vid_stride == 0:
+ success, im = cap.retrieve()
+ if success:
+ self.imgs[i] = im
+ else:
+ LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
+ self.imgs[i] = np.zeros_like(self.imgs[i])
+ cap.open(stream) # re-open stream if signal was lost
+ time.sleep(0.0) # wait time
+
+ def __iter__(self):
+ self.count = -1
+ return self
+
+ def __next__(self):
+ self.count += 1
+ if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
+ cv2.destroyAllWindows()
+ raise StopIteration
+
+ im0 = self.imgs.copy()
+ if self.transforms:
+ im = np.stack([self.transforms(x) for x in im0]) # transforms
+ else:
+ im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
+ im = np.ascontiguousarray(im) # contiguous
+
+ return self.sources, im, im0, None, ''
+
+ def __len__(self):
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
+
+
+def img2label_paths(img_paths):
+ # Define label paths as a function of image paths
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
+ return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
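+ # Example (hedged, illustrative path): img2label_paths(['/data/coco/images/train2017/0001.jpg'])
+ # -> ['/data/coco/labels/train2017/0001.txt'] on a POSIX filesystem.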
+
+
+class LoadImagesAndLabels(Dataset):
+ # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
+ cache_version = 0.6 # dataset labels *.cache version
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
+
+ def __init__(self,
+ path,
+ img_size=640,
+ batch_size=16,
+ augment=False,
+ hyp=None,
+ rect=False,
+ image_weights=False,
+ cache_images=False,
+ single_cls=False,
+ stride=32,
+ pad=0.0,
+ min_items=0,
+ prefix=''):
+ self.img_size = img_size
+ self.augment = augment
+ self.hyp = hyp
+ self.image_weights = image_weights
+ self.rect = False if image_weights else rect
+ self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
+ self.mosaic_border = [-img_size // 2, -img_size // 2]
+ self.stride = stride
+ self.path = path
+ self.albumentations = Albumentations(size=img_size) if augment else None
+
+ try:
+ f = [] # image files
+ for p in path if isinstance(path, list) else [path]:
+ p = Path(p) # os-agnostic
+ if p.is_dir(): # dir
+ f += glob.glob(str(p / '**' / '*.*'), recursive=True)
+ # f = list(p.rglob('*.*')) # pathlib
+ elif p.is_file(): # file
+ with open(p) as t:
+ t = t.read().strip().splitlines()
+ parent = str(p.parent) + os.sep
+ f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t] # to global path
+ # f += [p.parent / x.lstrip(os.sep) for x in t] # to global path (pathlib)
+ else:
+ raise FileNotFoundError(f'{prefix}{p} does not exist')
+ self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
+ assert self.im_files, f'{prefix}No images found'
+ except Exception as e:
+ raise FileNotFoundError(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
+
+ # Check cache
+ self.label_files = img2label_paths(self.im_files) # labels
+ cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
+ try:
+ cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
+ assert cache['version'] == self.cache_version # matches current version
+ assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
+ except Exception:
+ cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
+
+ # Display cache
+ nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
+ if exists and LOCAL_RANK in {-1, 0}:
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+ tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
+ if cache['msgs']:
+ LOGGER.info('\n'.join(cache['msgs'])) # display warnings
+ assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
+
+ # Read cache
+ [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
+ labels, shapes, self.segments = zip(*cache.values())
+ nl = len(np.concatenate(labels, 0)) # number of labels
+ assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
+ self.labels = list(labels)
+ self.shapes = np.array(shapes)
+ self.im_files = list(cache.keys()) # update
+ self.label_files = img2label_paths(cache.keys()) # update
+
+ # Filter images
+ if min_items:
+ include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
+ LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
+ self.im_files = [self.im_files[i] for i in include]
+ self.label_files = [self.label_files[i] for i in include]
+ self.labels = [self.labels[i] for i in include]
+ self.segments = [self.segments[i] for i in include]
+ self.shapes = self.shapes[include] # wh
+
+ # Create indices
+ n = len(self.shapes) # number of images
+ bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
+ nb = bi[-1] + 1 # number of batches
+ self.batch = bi # batch index of image
+ self.n = n
+ self.indices = range(n)
+
+ # Update labels
+ include_class = [] # filter labels to include only these classes (optional)
+ include_class_array = np.array(include_class).reshape(1, -1)
+ for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
+ if include_class:
+ j = (label[:, 0:1] == include_class_array).any(1)
+ self.labels[i] = label[j]
+ if segment:
+ self.segments[i] = segment[j]
+ if single_cls: # single-class training, merge all classes into 0
+ self.labels[i][:, 0] = 0
+
+ # Rectangular Training
+ if self.rect:
+ # Sort by aspect ratio
+ s = self.shapes # wh
+ ar = s[:, 1] / s[:, 0] # aspect ratio
+ irect = ar.argsort()
+ self.im_files = [self.im_files[i] for i in irect]
+ self.label_files = [self.label_files[i] for i in irect]
+ self.labels = [self.labels[i] for i in irect]
+ self.segments = [self.segments[i] for i in irect]
+ self.shapes = s[irect] # wh
+ ar = ar[irect]
+
+ # Set training image shapes
+ shapes = [[1, 1]] * nb
+ for i in range(nb):
+ ari = ar[bi == i]
+ mini, maxi = ari.min(), ari.max()
+ if maxi < 1:
+ shapes[i] = [maxi, 1]
+ elif mini > 1:
+ shapes[i] = [1, 1 / mini]
+
+ self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
+
+ # Cache images into RAM/disk for faster training
+ if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
+ cache_images = False
+ self.ims = [None] * n
+ self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
+ if cache_images:
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
+ self.im_hw0, self.im_hw = [None] * n, [None] * n
+ fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
+ results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
+ pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
+ for i, x in pbar:
+ if cache_images == 'disk':
+ b += self.npy_files[i].stat().st_size
+ else: # 'ram'
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
+ b += self.ims[i].nbytes
+ pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
+ pbar.close()
+
+ def check_cache_ram(self, safety_margin=0.1, prefix=''):
+ # Check image caching requirements vs available memory
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
+ n = min(self.n, 30) # extrapolate from 30 random images
+ for _ in range(n):
+ im = cv2.imread(random.choice(self.im_files)) # sample image
+ ratio = self.img_size / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
+ b += im.nbytes * ratio ** 2
+ mem_required = b * self.n / n # GB required to cache dataset into RAM
+ mem = psutil.virtual_memory()
+ cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
+ if not cache:
+ LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
+ f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
+ f"{'caching images β
' if cache else 'not caching images β οΈ'}")
+ return cache
+
+ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
+ # Cache dataset labels, check images and read shapes
+ x = {} # dict
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
+ desc = f"{prefix}Scanning {path.parent / path.stem}..."
+ with Pool(NUM_THREADS) as pool:
+ pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
+ desc=desc,
+ total=len(self.im_files),
+ bar_format=TQDM_BAR_FORMAT)
+ for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
+ nm += nm_f
+ nf += nf_f
+ ne += ne_f
+ nc += nc_f
+ if im_file:
+ x[im_file] = [lb, shape, segments]
+ if msg:
+ msgs.append(msg)
+ pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+
+ pbar.close()
+ if msgs:
+ LOGGER.info('\n'.join(msgs))
+ if nf == 0:
+ LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
+ x['hash'] = get_hash(self.label_files + self.im_files)
+ x['results'] = nf, nm, ne, nc, len(self.im_files)
+ x['msgs'] = msgs # warnings
+ x['version'] = self.cache_version # cache version
+ try:
+ np.save(path, x) # save cache for next time
+ path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
+ LOGGER.info(f'{prefix}New cache created: {path}')
+ except Exception as e:
+ LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable
+ return x
+
+ def __len__(self):
+ return len(self.im_files)
+
+ # def __iter__(self):
+ # self.count = -1
+ # print('ran dataset iter')
+ # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
+ # return self
+
+ def __getitem__(self, index):
+ index = self.indices[index] # linear, shuffled, or image_weights
+
+ hyp = self.hyp
+ mosaic = self.mosaic and random.random() < hyp['mosaic']
+ if mosaic:
+ # Load mosaic
+ img, labels = self.load_mosaic(index)
+ shapes = None
+
+ # MixUp augmentation
+ if random.random() < hyp['mixup']:
+ img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
+
+ else:
+ # Load image
+ img, (h0, w0), (h, w) = self.load_image(index)
+
+ # Letterbox
+ shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
+ img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
+ shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
+
+ labels = self.labels[index].copy()
+ if labels.size: # normalized xywh to pixel xyxy format
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
+
+ if self.augment:
+ img, labels = random_perspective(img,
+ labels,
+ degrees=hyp['degrees'],
+ translate=hyp['translate'],
+ scale=hyp['scale'],
+ shear=hyp['shear'],
+ perspective=hyp['perspective'])
+
+ nl = len(labels) # number of labels
+ if nl:
+ labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
+
+ if self.augment:
+ # Albumentations
+ img, labels = self.albumentations(img, labels)
+ nl = len(labels) # update after albumentations
+
+ # HSV color-space
+ augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
+
+ # Flip up-down
+ if random.random() < hyp['flipud']:
+ img = np.flipud(img)
+ if nl:
+ labels[:, 2] = 1 - labels[:, 2]
+
+ # Flip left-right
+ if random.random() < hyp['fliplr']:
+ img = np.fliplr(img)
+ if nl:
+ labels[:, 1] = 1 - labels[:, 1]
+
+ # Cutouts
+ # labels = cutout(img, labels, p=0.5)
+ # nl = len(labels) # update after cutout
+
+ labels_out = torch.zeros((nl, 6))
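+ # column 0 is left zero here; the collate function later writes the batch image index into it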
+ if nl:
+ labels_out[:, 1:] = torch.from_numpy(labels)
+
+ # Convert
+ img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ img = np.ascontiguousarray(img)
+
+ return torch.from_numpy(img), labels_out, self.im_files[index], shapes
+
+ def load_image(self, i):
+ # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
+ if im is None: # not cached in RAM
+ if fn.exists(): # load npy
+ im = np.load(fn)
+ else: # read image
+ im = cv2.imread(f) # BGR
+ assert im is not None, f'Image Not Found {f}'
+ h0, w0 = im.shape[:2] # orig hw
+ r = self.img_size / max(h0, w0) # ratio
+ if r != 1: # if sizes are not equal
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
+ im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
+ return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
+ return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
+
+ def cache_images_to_disk(self, i):
+ # Saves an image as an *.npy file for faster loading
+ f = self.npy_files[i]
+ if not f.exists():
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
+
+ def load_mosaic(self, index):
+ # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
+ labels4, segments4 = [], []
+ s = self.img_size
+ yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
+ indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
+ random.shuffle(indices)
+ for i, index in enumerate(indices):
+ # Load image
+ img, _, (h, w) = self.load_image(index)
+
+ # place img in img4
+ if i == 0: # top left
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
+ elif i == 1: # top right
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+ elif i == 2: # bottom left
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+ elif i == 3: # bottom right
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
+ padw = x1a - x1b
+ padh = y1a - y1b
+
+ # Labels
+ labels, segments = self.labels[index].copy(), self.segments[index].copy()
+ if labels.size:
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
+ segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
+ labels4.append(labels)
+ segments4.extend(segments)
+
+ # Concat/clip labels
+ labels4 = np.concatenate(labels4, 0)
+ for x in (labels4[:, 1:], *segments4):
+ np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
+ # img4, labels4 = replicate(img4, labels4) # replicate
+
+ # Augment
+ img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
+ img4, labels4 = random_perspective(img4,
+ labels4,
+ segments4,
+ degrees=self.hyp['degrees'],
+ translate=self.hyp['translate'],
+ scale=self.hyp['scale'],
+ shear=self.hyp['shear'],
+ perspective=self.hyp['perspective'],
+ border=self.mosaic_border) # border to remove
+
+ return img4, labels4
+
+ def load_mosaic9(self, index):
+ # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
+ labels9, segments9 = [], []
+ s = self.img_size
+ indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
+ random.shuffle(indices)
+ hp, wp = -1, -1 # height, width previous
+ for i, index in enumerate(indices):
+ # Load image
+ img, _, (h, w) = self.load_image(index)
+
+ # place img in img9
+ if i == 0: # center
+ img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 9 tiles
+ h0, w0 = h, w
+ c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
+ elif i == 1: # top
+ c = s, s - h, s + w, s
+ elif i == 2: # top right
+ c = s + wp, s - h, s + wp + w, s
+ elif i == 3: # right
+ c = s + w0, s, s + w0 + w, s + h
+ elif i == 4: # bottom right
+ c = s + w0, s + hp, s + w0 + w, s + hp + h
+ elif i == 5: # bottom
+ c = s + w0 - w, s + h0, s + w0, s + h0 + h
+ elif i == 6: # bottom left
+ c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
+ elif i == 7: # left
+ c = s - w, s + h0 - h, s, s + h0
+ elif i == 8: # top left
+ c = s - w, s + h0 - hp - h, s, s + h0 - hp
+
+ padx, pady = c[:2]
+ x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
+
+ # Labels
+ labels, segments = self.labels[index].copy(), self.segments[index].copy()
+ if labels.size:
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
+ segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
+ labels9.append(labels)
+ segments9.extend(segments)
+
+ # Image
+ img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
+ hp, wp = h, w # height, width previous
+
+ # Offset
+ yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
+ img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
+
+ # Concat/clip labels
+ labels9 = np.concatenate(labels9, 0)
+ labels9[:, [1, 3]] -= xc
+ labels9[:, [2, 4]] -= yc
+ c = np.array([xc, yc]) # centers
+ segments9 = [x - c for x in segments9]
+
+ for x in (labels9[:, 1:], *segments9):
+ np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
+ # img9, labels9 = replicate(img9, labels9) # replicate
+
+ # Augment
+ img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
+ img9, labels9 = random_perspective(img9,
+ labels9,
+ segments9,
+ degrees=self.hyp['degrees'],
+ translate=self.hyp['translate'],
+ scale=self.hyp['scale'],
+ shear=self.hyp['shear'],
+ perspective=self.hyp['perspective'],
+ border=self.mosaic_border) # border to remove
+
+ return img9, labels9
+
+ @staticmethod
+ def collate_fn(batch):
+ # YOLOv8 collate function, outputs dict
+ im, label, path, shapes = zip(*batch) # transposed
+ for i, lb in enumerate(label):
+ lb[:, 0] = i # add target image index for build_targets()
+ batch_idx, cls, bboxes = torch.cat(label, 0).split((1, 1, 4), dim=1)
+ return {
+ 'ori_shape': tuple((x[0] if x else None) for x in shapes),
+ 'ratio_pad': tuple((x[1] if x else None) for x in shapes),
+ 'im_file': path,
+ 'img': torch.stack(im, 0),
+ 'cls': cls,
+ 'bboxes': bboxes,
+ 'batch_idx': batch_idx.view(-1)}
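+ # Illustrative note: 'batch_idx' maps each box back to its image in 'img'; e.g. a 2-image batch with
+ # 3 and 5 labels yields batch_idx = tensor([0., 0., 0., 1., 1., 1., 1., 1.]).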
+
+ @staticmethod
+ def collate_fn_old(batch):
+ # YOLOv5 original collate function
+ im, label, path, shapes = zip(*batch) # transposed
+ for i, lb in enumerate(label):
+ lb[:, 0] = i # add target image index for build_targets()
+ return torch.stack(im, 0), torch.cat(label, 0), path, shapes
+
+
+# Ancillary functions --------------------------------------------------------------------------------------------------
+def flatten_recursive(path=DATASETS_DIR / 'coco128'):
+ # Flatten a recursive directory by bringing all files to top level
+ new_path = Path(f'{str(path)}_flat')
+ if os.path.exists(new_path):
+ shutil.rmtree(new_path) # delete output folder
+ os.makedirs(new_path) # make new output folder
+ for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
+ shutil.copyfile(file, new_path / Path(file).name)
+
+
+def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
+ # Convert detection dataset into classification dataset, with one directory per class
+ path = Path(path) # images dir
+ shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None # remove existing
+ files = list(path.rglob('*.*'))
+ n = len(files) # number of files
+ for im_file in tqdm(files, total=n):
+ if im_file.suffix[1:] in IMG_FORMATS:
+ # image
+ im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
+ h, w = im.shape[:2]
+
+ # labels
+ lb_file = Path(img2label_paths([str(im_file)])[0])
+ if Path(lb_file).exists():
+ with open(lb_file) as f:
+ lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
+
+ for j, x in enumerate(lb):
+ c = int(x[0]) # class
+ f = (path / 'classification') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
+ if not f.parent.is_dir():
+ f.parent.mkdir(parents=True)
+
+ b = x[1:] * [w, h, w, h] # box
+ # b[2:] = b[2:].max() # rectangle to square
+ b[2:] = b[2:] * 1.2 + 3 # pad
+ b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)
+
+ b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
+ b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
+ assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
+
+
+def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
+ """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
+ Usage: from utils.dataloaders import *; autosplit()
+ Arguments
+ path: Path to images directory
+ weights: Train, val, test weights (list, tuple)
+ annotated_only: Only use images with an annotated txt file
+ """
+ path = Path(path) # images dir
+ files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
+ n = len(files) # number of files
+ random.seed(0) # for reproducibility
+ indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
+
+ txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
+ for x in txt:
+ if (path.parent / x).exists():
+ (path.parent / x).unlink() # remove existing
+
+ print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
+ for i, img in tqdm(zip(indices, files), total=n):
+ if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
+ with open(path.parent / txt[i], 'a') as f:
+ f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
+
+
+def verify_image_label(args):
+ # Verify one image-label pair
+ im_file, lb_file, prefix = args
+ nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
+ try:
+ # verify images
+ im = Image.open(im_file)
+ im.verify() # PIL verify
+ shape = exif_size(im) # image size
+ assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+ assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
+ if im.format.lower() in ('jpg', 'jpeg'):
+ with open(im_file, 'rb') as f:
+ f.seek(-2, 2)
+ if f.read() != b'\xff\xd9': # corrupt JPEG
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
+ msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+
+ # verify labels
+ if os.path.isfile(lb_file):
+ nf = 1 # label found
+ with open(lb_file) as f:
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
+ if any(len(x) > 6 for x in lb): # is segment
+ classes = np.array([x[0] for x in lb], dtype=np.float32)
+ segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
+ lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
+ lb = np.array(lb, dtype=np.float32)
+ nl = len(lb)
+ if nl:
+ assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
+ assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
+ assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
+ _, i = np.unique(lb, axis=0, return_index=True)
+ if len(i) < nl: # duplicate row check
+ lb = lb[i] # remove duplicates
+ if segments:
+ segments = [segments[x] for x in i]
+ msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
+ else:
+ ne = 1 # label empty
+ lb = np.zeros((0, 5), dtype=np.float32)
+ else:
+ nm = 1 # label missing
+ lb = np.zeros((0, 5), dtype=np.float32)
+ return im_file, lb, shape, segments, nm, nf, ne, nc, msg
+ except Exception as e:
+ nc = 1
+ msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+ return [None, None, None, None, nm, nf, ne, nc, msg]
+
+
+class HUBDatasetStats():
+ """ Class for generating HUB dataset JSON and `-hub` dataset directory
+
+ Arguments
+ path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
+ autodownload: Attempt to download dataset if not found locally
+
+ Usage
+ from utils.dataloaders import HUBDatasetStats
+ stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
+ stats = HUBDatasetStats('path/to/coco128.zip') # usage 2
+ stats.get_json(save=False)
+ stats.process_images()
+ """
+
+ def __init__(self, path='coco128.yaml', autodownload=False):
+ # Initialize class
+ zipped, data_dir, yaml_path = self._unzip(Path(path))
+ try:
+ with open(check_yaml(yaml_path), errors='ignore') as f:
+ data = yaml.safe_load(f) # data dict
+ if zipped:
+ data['path'] = data_dir
+ except Exception as e:
+ raise Exception("error/HUB/dataset_stats/yaml_load") from e
+
+ check_dataset(data, autodownload) # download dataset if missing
+ self.hub_dir = Path(data['path'] + '-hub')
+ self.im_dir = self.hub_dir / 'images'
+ self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
+ self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
+ self.data = data
+
+ @staticmethod
+ def _find_yaml(dir):
+ # Return data.yaml file
+ files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
+ assert files, f'No *.yaml file found in {dir}'
+ if len(files) > 1:
+ files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
+ assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
+ assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
+ return files[0]
+
+ def _unzip(self, path):
+ # Unzip data.zip
+ if not str(path).endswith('.zip'): # path is data.yaml
+ return False, None, path
+ assert Path(path).is_file(), f'Error unzipping {path}, file not found'
+ unzip_file(path, path=path.parent)
+ dir = path.with_suffix('') # dataset directory == zip name
+ assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
+ return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
+
+ def _hub_ops(self, f, max_dim=1920):
+ # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
+ f_new = self.im_dir / Path(f).name # dataset-hub image filename
+ try: # use PIL
+ im = Image.open(f)
+ r = max_dim / max(im.height, im.width) # ratio
+ if r < 1.0: # image too large
+ im = im.resize((int(im.width * r), int(im.height * r)))
+ im.save(f_new, 'JPEG', quality=50, optimize=True) # save
+ except Exception as e: # use OpenCV
+ LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
+ im = cv2.imread(f)
+ im_height, im_width = im.shape[:2]
+ r = max_dim / max(im_height, im_width) # ratio
+ if r < 1.0: # image too large
+ im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
+ cv2.imwrite(str(f_new), im)
+
+ def get_json(self, save=False, verbose=False):
+ # Return dataset JSON for Ultralytics HUB
+ def _round(labels):
+ # Update labels to integer class and 4 decimal place floats
+ return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
+
+ for split in 'train', 'val', 'test':
+ if self.data.get(split) is None:
+ self.stats[split] = None # i.e. no test set
+ continue
+ dataset = LoadImagesAndLabels(self.data[split]) # load dataset
+ x = np.array([
+ np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
+ for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80)
+ self.stats[split] = {
+ 'instance_stats': {
+ 'total': int(x.sum()),
+ 'per_class': x.sum(0).tolist()},
+ 'image_stats': {
+ 'total': dataset.n,
+ 'unlabelled': int(np.all(x == 0, 1).sum()),
+ 'per_class': (x > 0).sum(0).tolist()},
+ 'labels': [{
+ str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
+
+ # Save, print and return
+ if save:
+ stats_path = self.hub_dir / 'stats.json'
+ print(f'Saving {stats_path.resolve()}...')
+ with open(stats_path, 'w') as f:
+ json.dump(self.stats, f) # save stats.json
+ if verbose:
+ print(json.dumps(self.stats, indent=2, sort_keys=False))
+ return self.stats
+
+ def process_images(self):
+ # Compress images for Ultralytics HUB
+ for split in 'train', 'val', 'test':
+ if self.data.get(split) is None:
+ continue
+ dataset = LoadImagesAndLabels(self.data[split]) # load dataset
+ desc = f'{split} images'
+ for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
+ pass
+ print(f'Done. All images saved to {self.im_dir}')
+ return self.im_dir
+
+
+# Classification dataloaders -------------------------------------------------------------------------------------------
+class ClassificationDataset(torchvision.datasets.ImageFolder):
+ """
+ YOLOv5 Classification Dataset.
+ Arguments
+ root: Dataset path
+ transform: torchvision transforms, used by default
+ album_transform: Albumentations transforms, used if installed
+ """
+
+ def __init__(self, root, augment, imgsz, cache=False):
+ super().__init__(root=root)
+ self.torch_transforms = classify_transforms(imgsz)
+ self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
+ self.cache_ram = cache is True or cache == 'ram'
+ self.cache_disk = cache == 'disk'
+ self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
+
+ def __getitem__(self, i):
+ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
+ if self.cache_ram and im is None:
+ im = self.samples[i][3] = cv2.imread(f)
+ elif self.cache_disk:
+ if not fn.exists(): # load npy
+ np.save(fn.as_posix(), cv2.imread(f))
+ im = np.load(fn)
+ else: # read image
+ im = cv2.imread(f) # BGR
+ if self.album_transforms:
+ sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
+ else:
+ sample = self.torch_transforms(im)
+ return sample, j
+
+
+def create_classification_dataloader(path,
+ imgsz=224,
+ batch_size=16,
+ augment=True,
+ cache=False,
+ rank=-1,
+ workers=8,
+ shuffle=True):
+ # Returns Dataloader object to be used with YOLOv5 Classifier
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
+ dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
+ batch_size = min(batch_size, len(dataset))
+ nd = torch.cuda.device_count()
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+ generator = torch.Generator()
+ generator.manual_seed(6148914691236517205 + RANK)
+ return InfiniteDataLoader(dataset,
+ batch_size=batch_size,
+ shuffle=shuffle and sampler is None,
+ num_workers=nw,
+ sampler=sampler,
+ pin_memory=PIN_MEMORY,
+ worker_init_fn=seed_worker,
+ generator=generator) # or DataLoader(persistent_workers=True)
diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..387271eb30f2b56be16936b778943e81a0cc98b8
--- /dev/null
+++ b/ultralytics/yolo/data/dataset.py
@@ -0,0 +1,225 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from itertools import repeat
+from multiprocessing.pool import Pool
+from pathlib import Path
+
+import torchvision
+from tqdm import tqdm
+
+from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
+from .augment import *
+from .base import BaseDataset
+from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
+
+
+class YOLODataset(BaseDataset):
+ """YOLO Dataset.
+ Args:
+ img_path (str): image path.
+ prefix (str): prefix.
+ """
+ cache_version = 1.0 # dataset labels *.cache version, >= 1.0 for YOLOv8
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
+
+ def __init__(
+ self,
+ img_path,
+ imgsz=640,
+ label_path=None,
+ cache=False,
+ augment=True,
+ hyp=None,
+ prefix="",
+ rect=False,
+ batch_size=None,
+ stride=32,
+ pad=0.0,
+ single_cls=False,
+ use_segments=False,
+ use_keypoints=False,
+ ):
+ self.use_segments = use_segments
+ self.use_keypoints = use_keypoints
+ assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
+ super().__init__(img_path, imgsz, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
+ single_cls)
+
+ def cache_labels(self, path=Path("./labels.cache")):
+ # Cache dataset labels, check images and read shapes
+ x = {"labels": []}
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
+ desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
+ with Pool(NUM_THREADS) as pool:
+ pbar = tqdm(
+ pool.imap(verify_image_label,
+ zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
+ desc=desc,
+ total=len(self.im_files),
+ bar_format=TQDM_BAR_FORMAT,
+ )
+ for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
+ nm += nm_f
+ nf += nf_f
+ ne += ne_f
+ nc += nc_f
+ if im_file:
+ x["labels"].append(
+ dict(
+ im_file=im_file,
+ shape=shape,
+ cls=lb[:, 0:1], # n, 1
+ bboxes=lb[:, 1:], # n, 4
+ segments=segments,
+ keypoints=keypoint,
+ normalized=True,
+ bbox_format="xywh",
+ ))
+ if msg:
+ msgs.append(msg)
+ pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+
+ pbar.close()
+ if msgs:
+ LOGGER.info("\n".join(msgs))
+ if nf == 0:
+ LOGGER.warning(f"{self.prefix}WARNING β οΈ No labels found in {path}. {HELP_URL}")
+ x["hash"] = get_hash(self.label_files + self.im_files)
+ x["results"] = nf, nm, ne, nc, len(self.im_files)
+ x["msgs"] = msgs # warnings
+ x["version"] = self.cache_version # cache version
+ try:
+ np.save(path, x) # save cache for next time
+ path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
+ LOGGER.info(f"{self.prefix}New cache created: {path}")
+ except Exception as e:
+ LOGGER.warning(
+ f"{self.prefix}WARNING β οΈ Cache directory {path.parent} is not writeable: {e}") # not writeable
+ return x
+
+ def get_labels(self):
+ self.label_files = img2label_paths(self.im_files)
+ cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
+ try:
+ cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
+ assert cache["version"] == self.cache_version # matches current version
+ assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
+ except Exception:
+ cache, exists = self.cache_labels(cache_path), False # run cache ops
+
+ # Display cache
+ nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
+ if exists and LOCAL_RANK in {-1, 0}:
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+ tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
+ if cache["msgs"]:
+ LOGGER.info("\n".join(cache["msgs"])) # display warnings
+ assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
+
+ # Read cache
+ [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
+ labels = cache["labels"]
+ nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
+ assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
+ return labels
+
+ # TODO: use hyp config to set all these augmentations
+ def build_transforms(self, hyp=None):
+ if self.augment:
+ mosaic = self.augment and not self.rect
+ transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
+ else:
+ transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
+ transforms.append(
+ Format(bbox_format="xywh",
+ normalize=True,
+ return_mask=self.use_segments,
+ return_keypoint=self.use_keypoints,
+ batch_idx=True))
+ return transforms
+
+ def close_mosaic(self, hyp):
+ self.transforms = affine_transforms(self.imgsz, hyp)
+ self.transforms.append(
+ Format(bbox_format="xywh",
+ normalize=True,
+ return_mask=self.use_segments,
+ return_keypoint=self.use_keypoints,
+ batch_idx=True))
+
+ def update_labels_info(self, label):
+ """custom your label format here"""
+ # NOTE: cls is no longer stored with bboxes; classification and semantic segmentation need an independent cls label.
+ # This method can also support classification and semantic segmentation by adding or removing the relevant dict keys here.
+ bboxes = label.pop("bboxes")
+ segments = label.pop("segments")
+ keypoints = label.pop("keypoints", None)
+ bbox_format = label.pop("bbox_format")
+ normalized = label.pop("normalized")
+ label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
+ return label
+
+ @staticmethod
+ def collate_fn(batch):
+ # TODO: returning a dict makes things easier and cleaner when using the dataset in training,
+ # but it is unclear whether this introduces a small slowdown.
+ new_batch = {}
+ keys = batch[0].keys()
+ values = list(zip(*[list(b.values()) for b in batch]))
+ for i, k in enumerate(keys):
+ value = values[i]
+ if k == "img":
+ value = torch.stack(value, 0)
+ if k in ["masks", "keypoints", "bboxes", "cls"]:
+ value = torch.cat(value, 0)
+ new_batch[k] = value
+ new_batch["batch_idx"] = list(new_batch["batch_idx"])
+ for i in range(len(new_batch["batch_idx"])):
+ new_batch["batch_idx"][i] += i # add target image index for build_targets()
+ new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
+ return new_batch
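+ # Usage sketch (illustrative; the dataset path is hypothetical and this assumes BaseDataset applies the
+ # transforms built in build_transforms when indexing):
+ # ds = YOLODataset(img_path='../datasets/coco128/images/train2017', imgsz=640, batch_size=4, augment=False)
+ # batch = YOLODataset.collate_fn([ds[i] for i in range(4)]) # dict with 'img', 'cls', 'bboxes', 'batch_idx', ...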
+
+
+# Classification dataloaders -------------------------------------------------------------------------------------------
+class ClassificationDataset(torchvision.datasets.ImageFolder):
+ """
+ YOLOv5 Classification Dataset.
+ Arguments
+ root: Dataset path
+ transform: torchvision transforms, used by default
+ album_transform: Albumentations transforms, used if installed
+ """
+
+ def __init__(self, root, augment, imgsz, cache=False):
+ super().__init__(root=root)
+ self.torch_transforms = classify_transforms(imgsz)
+ self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
+ self.cache_ram = cache is True or cache == "ram"
+ self.cache_disk = cache == "disk"
+ self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
+
+ def __getitem__(self, i):
+ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
+ if self.cache_ram and im is None:
+ im = self.samples[i][3] = cv2.imread(f)
+ elif self.cache_disk:
+ if not fn.exists(): # load npy
+ np.save(fn.as_posix(), cv2.imread(f))
+ im = np.load(fn)
+ else: # read image
+ im = cv2.imread(f) # BGR
+ if self.album_transforms:
+ sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
+ else:
+ sample = self.torch_transforms(im)
+ return {'img': sample, 'cls': j}
+
+ def __len__(self) -> int:
+ return len(self.samples)
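+ # Usage sketch (illustrative; 'path/to/train' is a hypothetical ImageFolder-style directory of class subfolders):
+ # ds = ClassificationDataset(root='path/to/train', augment=False, imgsz=224, cache=False)
+ # sample = ds[0] # {'img': transformed image tensor, 'cls': class index}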
+
+
+# TODO: support semantic segmentation
+class SemanticDataset(BaseDataset):
+
+ def __init__(self):
+ pass
diff --git a/ultralytics/yolo/data/dataset_wrappers.py b/ultralytics/yolo/data/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..46a8eee090ab75c5342ca1375298c55a4ea399d2
--- /dev/null
+++ b/ultralytics/yolo/data/dataset_wrappers.py
@@ -0,0 +1,39 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import collections
+from copy import deepcopy
+
+from .augment import LetterBox
+
+
+class MixAndRectDataset:
+ """A wrapper of multiple images mixed dataset.
+
+ Args:
+ dataset (:obj:`BaseDataset`): The dataset to be mixed.
+ transforms (Sequence[dict]): config dict to be composed.
+ """
+
+ def __init__(self, dataset):
+ self.dataset = dataset
+ self.imgsz = dataset.imgsz
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+ labels = deepcopy(self.dataset[index])
+ for transform in self.dataset.transforms.tolist():
+ # mosaic and mixup
+ if hasattr(transform, "get_indexes"):
+ indexes = transform.get_indexes(self.dataset)
+ if not isinstance(indexes, collections.abc.Sequence):
+ indexes = [indexes]
+ mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
+ labels["mix_labels"] = mix_labels
+ if self.dataset.rect and isinstance(transform, LetterBox):
+ transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
+ labels = transform(labels)
+ if "mix_labels" in labels:
+ labels.pop("mix_labels")
+ return labels
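+# Usage sketch (illustrative): wrap a dataset so transforms that implement get_indexes() (e.g. mosaic,
+# mixup) automatically receive the extra samples they need via 'mix_labels':
+# mixed = MixAndRectDataset(yolo_dataset) # yolo_dataset: any BaseDataset exposing .transforms
+# labels = mixed[0]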
diff --git a/ultralytics/yolo/data/datasets/Argoverse.yaml b/ultralytics/yolo/data/datasets/Argoverse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..35896dae7d20c86a1fa4264b8155607c659f6187
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/Argoverse.yaml
@@ -0,0 +1,74 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI
+# Example usage: python train.py --data Argoverse.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── Argoverse  ← downloads here (31.3 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/Argoverse # dataset root dir
+train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images
+val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images
+test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview
+
+# Classes
+names:
+ 0: person
+ 1: bicycle
+ 2: car
+ 3: motorcycle
+ 4: bus
+ 5: truck
+ 6: traffic_light
+ 7: stop_sign
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ import json
+
+ from tqdm import tqdm
+ from utils.general import download, Path
+
+
+ def argoverse2yolo(set):
+ labels = {}
+ a = json.load(open(set, "rb"))
+ for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."):
+ img_id = annot['image_id']
+ img_name = a['images'][img_id]['name']
+ img_label_name = f'{img_name[:-3]}txt'
+
+ cls = annot['category_id'] # instance class id
+ x_center, y_center, width, height = annot['bbox']
+ x_center = (x_center + width / 2) / 1920.0 # offset and scale
+ y_center = (y_center + height / 2) / 1200.0 # offset and scale
+ width /= 1920.0 # scale
+ height /= 1200.0 # scale
+
+ img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']]
+ if not img_dir.exists():
+ img_dir.mkdir(parents=True, exist_ok=True)
+
+ k = str(img_dir / img_label_name)
+ if k not in labels:
+ labels[k] = []
+ labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n")
+
+ for k in labels:
+ with open(k, "w") as f:
+ f.writelines(labels[k])
+
+
+ # Download
+ dir = Path(yaml['path']) # dataset root dir
+ urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
+ download(urls, dir=dir, delete=False)
+
+ # Convert
+ annotations_dir = 'Argoverse-HD/annotations/'
+ (dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images'
+ for d in "train.json", "val.json":
+ argoverse2yolo(dir / annotations_dir / d) # convert Argoverse annotations to YOLO labels
diff --git a/ultralytics/yolo/data/datasets/GlobalWheat2020.yaml b/ultralytics/yolo/data/datasets/GlobalWheat2020.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..286a9bd9529b611553eb512745620476a5902098
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/GlobalWheat2020.yaml
@@ -0,0 +1,54 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan
+# Example usage: python train.py --data GlobalWheat2020.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── GlobalWheat2020  ← downloads here (7.0 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/GlobalWheat2020 # dataset root dir
+train: # train images (relative to 'path') 3422 images
+ - images/arvalis_1
+ - images/arvalis_2
+ - images/arvalis_3
+ - images/ethz_1
+ - images/rres_1
+ - images/inrae_1
+ - images/usask_1
+val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1)
+ - images/ethz_1
+test: # test images (optional) 1276 images
+ - images/utokyo_1
+ - images/utokyo_2
+ - images/nau_1
+ - images/uq_1
+
+# Classes
+names:
+ 0: wheat_head
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ from utils.general import download, Path
+
+
+ # Download
+ dir = Path(yaml['path']) # dataset root dir
+ urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip',
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip']
+ download(urls, dir=dir)
+
+ # Make Directories
+ for p in 'annotations', 'images', 'labels':
+ (dir / p).mkdir(parents=True, exist_ok=True)
+
+ # Move
+ for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \
+ 'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1':
+ (dir / p).rename(dir / 'images' / p) # move to /images
+ f = (dir / p).with_suffix('.json') # json file
+ if f.exists():
+ f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations
diff --git a/ultralytics/yolo/data/datasets/ImageNet.yaml b/ultralytics/yolo/data/datasets/ImageNet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..264d992cd2cbd7b0f1f975180305f945f3f337f4
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/ImageNet.yaml
@@ -0,0 +1,1022 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University
+# Simplified class names from https://github.com/anishathalye/imagenet-simple-labels
+# Example usage: python classify/train.py --data imagenet
+# parent
+# ├── yolov5
+# └── datasets
+#     └── imagenet  ← downloads here (144 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/imagenet # dataset root dir
+train: train # train images (relative to 'path') 1281167 images
+val: val # val images (relative to 'path') 50000 images
+test: # test images (optional)
+
+# Classes
+names:
+ 0: tench
+ 1: goldfish
+ 2: great white shark
+ 3: tiger shark
+ 4: hammerhead shark
+ 5: electric ray
+ 6: stingray
+ 7: cock
+ 8: hen
+ 9: ostrich
+ 10: brambling
+ 11: goldfinch
+ 12: house finch
+ 13: junco
+ 14: indigo bunting
+ 15: American robin
+ 16: bulbul
+ 17: jay
+ 18: magpie
+ 19: chickadee
+ 20: American dipper
+ 21: kite
+ 22: bald eagle
+ 23: vulture
+ 24: great grey owl
+ 25: fire salamander
+ 26: smooth newt
+ 27: newt
+ 28: spotted salamander
+ 29: axolotl
+ 30: American bullfrog
+ 31: tree frog
+ 32: tailed frog
+ 33: loggerhead sea turtle
+ 34: leatherback sea turtle
+ 35: mud turtle
+ 36: terrapin
+ 37: box turtle
+ 38: banded gecko
+ 39: green iguana
+ 40: Carolina anole
+ 41: desert grassland whiptail lizard
+ 42: agama
+ 43: frilled-necked lizard
+ 44: alligator lizard
+ 45: Gila monster
+ 46: European green lizard
+ 47: chameleon
+ 48: Komodo dragon
+ 49: Nile crocodile
+ 50: American alligator
+ 51: triceratops
+ 52: worm snake
+ 53: ring-necked snake
+ 54: eastern hog-nosed snake
+ 55: smooth green snake
+ 56: kingsnake
+ 57: garter snake
+ 58: water snake
+ 59: vine snake
+ 60: night snake
+ 61: boa constrictor
+ 62: African rock python
+ 63: Indian cobra
+ 64: green mamba
+ 65: sea snake
+ 66: Saharan horned viper
+ 67: eastern diamondback rattlesnake
+ 68: sidewinder
+ 69: trilobite
+ 70: harvestman
+ 71: scorpion
+ 72: yellow garden spider
+ 73: barn spider
+ 74: European garden spider
+ 75: southern black widow
+ 76: tarantula
+ 77: wolf spider
+ 78: tick
+ 79: centipede
+ 80: black grouse
+ 81: ptarmigan
+ 82: ruffed grouse
+ 83: prairie grouse
+ 84: peacock
+ 85: quail
+ 86: partridge
+ 87: grey parrot
+ 88: macaw
+ 89: sulphur-crested cockatoo
+ 90: lorikeet
+ 91: coucal
+ 92: bee eater
+ 93: hornbill
+ 94: hummingbird
+ 95: jacamar
+ 96: toucan
+ 97: duck
+ 98: red-breasted merganser
+ 99: goose
+ 100: black swan
+ 101: tusker
+ 102: echidna
+ 103: platypus
+ 104: wallaby
+ 105: koala
+ 106: wombat
+ 107: jellyfish
+ 108: sea anemone
+ 109: brain coral
+ 110: flatworm
+ 111: nematode
+ 112: conch
+ 113: snail
+ 114: slug
+ 115: sea slug
+ 116: chiton
+ 117: chambered nautilus
+ 118: Dungeness crab
+ 119: rock crab
+ 120: fiddler crab
+ 121: red king crab
+ 122: American lobster
+ 123: spiny lobster
+ 124: crayfish
+ 125: hermit crab
+ 126: isopod
+ 127: white stork
+ 128: black stork
+ 129: spoonbill
+ 130: flamingo
+ 131: little blue heron
+ 132: great egret
+ 133: bittern
+ 134: crane (bird)
+ 135: limpkin
+ 136: common gallinule
+ 137: American coot
+ 138: bustard
+ 139: ruddy turnstone
+ 140: dunlin
+ 141: common redshank
+ 142: dowitcher
+ 143: oystercatcher
+ 144: pelican
+ 145: king penguin
+ 146: albatross
+ 147: grey whale
+ 148: killer whale
+ 149: dugong
+ 150: sea lion
+ 151: Chihuahua
+ 152: Japanese Chin
+ 153: Maltese
+ 154: Pekingese
+ 155: Shih Tzu
+ 156: King Charles Spaniel
+ 157: Papillon
+ 158: toy terrier
+ 159: Rhodesian Ridgeback
+ 160: Afghan Hound
+ 161: Basset Hound
+ 162: Beagle
+ 163: Bloodhound
+ 164: Bluetick Coonhound
+ 165: Black and Tan Coonhound
+ 166: Treeing Walker Coonhound
+ 167: English foxhound
+ 168: Redbone Coonhound
+ 169: borzoi
+ 170: Irish Wolfhound
+ 171: Italian Greyhound
+ 172: Whippet
+ 173: Ibizan Hound
+ 174: Norwegian Elkhound
+ 175: Otterhound
+ 176: Saluki
+ 177: Scottish Deerhound
+ 178: Weimaraner
+ 179: Staffordshire Bull Terrier
+ 180: American Staffordshire Terrier
+ 181: Bedlington Terrier
+ 182: Border Terrier
+ 183: Kerry Blue Terrier
+ 184: Irish Terrier
+ 185: Norfolk Terrier
+ 186: Norwich Terrier
+ 187: Yorkshire Terrier
+ 188: Wire Fox Terrier
+ 189: Lakeland Terrier
+ 190: Sealyham Terrier
+ 191: Airedale Terrier
+ 192: Cairn Terrier
+ 193: Australian Terrier
+ 194: Dandie Dinmont Terrier
+ 195: Boston Terrier
+ 196: Miniature Schnauzer
+ 197: Giant Schnauzer
+ 198: Standard Schnauzer
+ 199: Scottish Terrier
+ 200: Tibetan Terrier
+ 201: Australian Silky Terrier
+ 202: Soft-coated Wheaten Terrier
+ 203: West Highland White Terrier
+ 204: Lhasa Apso
+ 205: Flat-Coated Retriever
+ 206: Curly-coated Retriever
+ 207: Golden Retriever
+ 208: Labrador Retriever
+ 209: Chesapeake Bay Retriever
+ 210: German Shorthaired Pointer
+ 211: Vizsla
+ 212: English Setter
+ 213: Irish Setter
+ 214: Gordon Setter
+ 215: Brittany
+ 216: Clumber Spaniel
+ 217: English Springer Spaniel
+ 218: Welsh Springer Spaniel
+ 219: Cocker Spaniels
+ 220: Sussex Spaniel
+ 221: Irish Water Spaniel
+ 222: Kuvasz
+ 223: Schipperke
+ 224: Groenendael
+ 225: Malinois
+ 226: Briard
+ 227: Australian Kelpie
+ 228: Komondor
+ 229: Old English Sheepdog
+ 230: Shetland Sheepdog
+ 231: collie
+ 232: Border Collie
+ 233: Bouvier des Flandres
+ 234: Rottweiler
+ 235: German Shepherd Dog
+ 236: Dobermann
+ 237: Miniature Pinscher
+ 238: Greater Swiss Mountain Dog
+ 239: Bernese Mountain Dog
+ 240: Appenzeller Sennenhund
+ 241: Entlebucher Sennenhund
+ 242: Boxer
+ 243: Bullmastiff
+ 244: Tibetan Mastiff
+ 245: French Bulldog
+ 246: Great Dane
+ 247: St. Bernard
+ 248: husky
+ 249: Alaskan Malamute
+ 250: Siberian Husky
+ 251: Dalmatian
+ 252: Affenpinscher
+ 253: Basenji
+ 254: pug
+ 255: Leonberger
+ 256: Newfoundland
+ 257: Pyrenean Mountain Dog
+ 258: Samoyed
+ 259: Pomeranian
+ 260: Chow Chow
+ 261: Keeshond
+ 262: Griffon Bruxellois
+ 263: Pembroke Welsh Corgi
+ 264: Cardigan Welsh Corgi
+ 265: Toy Poodle
+ 266: Miniature Poodle
+ 267: Standard Poodle
+ 268: Mexican hairless dog
+ 269: grey wolf
+ 270: Alaskan tundra wolf
+ 271: red wolf
+ 272: coyote
+ 273: dingo
+ 274: dhole
+ 275: African wild dog
+ 276: hyena
+ 277: red fox
+ 278: kit fox
+ 279: Arctic fox
+ 280: grey fox
+ 281: tabby cat
+ 282: tiger cat
+ 283: Persian cat
+ 284: Siamese cat
+ 285: Egyptian Mau
+ 286: cougar
+ 287: lynx
+ 288: leopard
+ 289: snow leopard
+ 290: jaguar
+ 291: lion
+ 292: tiger
+ 293: cheetah
+ 294: brown bear
+ 295: American black bear
+ 296: polar bear
+ 297: sloth bear
+ 298: mongoose
+ 299: meerkat
+ 300: tiger beetle
+ 301: ladybug
+ 302: ground beetle
+ 303: longhorn beetle
+ 304: leaf beetle
+ 305: dung beetle
+ 306: rhinoceros beetle
+ 307: weevil
+ 308: fly
+ 309: bee
+ 310: ant
+ 311: grasshopper
+ 312: cricket
+ 313: stick insect
+ 314: cockroach
+ 315: mantis
+ 316: cicada
+ 317: leafhopper
+ 318: lacewing
+ 319: dragonfly
+ 320: damselfly
+ 321: red admiral
+ 322: ringlet
+ 323: monarch butterfly
+ 324: small white
+ 325: sulphur butterfly
+ 326: gossamer-winged butterfly
+ 327: starfish
+ 328: sea urchin
+ 329: sea cucumber
+ 330: cottontail rabbit
+ 331: hare
+ 332: Angora rabbit
+ 333: hamster
+ 334: porcupine
+ 335: fox squirrel
+ 336: marmot
+ 337: beaver
+ 338: guinea pig
+ 339: common sorrel
+ 340: zebra
+ 341: pig
+ 342: wild boar
+ 343: warthog
+ 344: hippopotamus
+ 345: ox
+ 346: water buffalo
+ 347: bison
+ 348: ram
+ 349: bighorn sheep
+ 350: Alpine ibex
+ 351: hartebeest
+ 352: impala
+ 353: gazelle
+ 354: dromedary
+ 355: llama
+ 356: weasel
+ 357: mink
+ 358: European polecat
+ 359: black-footed ferret
+ 360: otter
+ 361: skunk
+ 362: badger
+ 363: armadillo
+ 364: three-toed sloth
+ 365: orangutan
+ 366: gorilla
+ 367: chimpanzee
+ 368: gibbon
+ 369: siamang
+ 370: guenon
+ 371: patas monkey
+ 372: baboon
+ 373: macaque
+ 374: langur
+ 375: black-and-white colobus
+ 376: proboscis monkey
+ 377: marmoset
+ 378: white-headed capuchin
+ 379: howler monkey
+ 380: titi
+ 381: Geoffroy's spider monkey
+ 382: common squirrel monkey
+ 383: ring-tailed lemur
+ 384: indri
+ 385: Asian elephant
+ 386: African bush elephant
+ 387: red panda
+ 388: giant panda
+ 389: snoek
+ 390: eel
+ 391: coho salmon
+ 392: rock beauty
+ 393: clownfish
+ 394: sturgeon
+ 395: garfish
+ 396: lionfish
+ 397: pufferfish
+ 398: abacus
+ 399: abaya
+ 400: academic gown
+ 401: accordion
+ 402: acoustic guitar
+ 403: aircraft carrier
+ 404: airliner
+ 405: airship
+ 406: altar
+ 407: ambulance
+ 408: amphibious vehicle
+ 409: analog clock
+ 410: apiary
+ 411: apron
+ 412: waste container
+ 413: assault rifle
+ 414: backpack
+ 415: bakery
+ 416: balance beam
+ 417: balloon
+ 418: ballpoint pen
+ 419: Band-Aid
+ 420: banjo
+ 421: baluster
+ 422: barbell
+ 423: barber chair
+ 424: barbershop
+ 425: barn
+ 426: barometer
+ 427: barrel
+ 428: wheelbarrow
+ 429: baseball
+ 430: basketball
+ 431: bassinet
+ 432: bassoon
+ 433: swimming cap
+ 434: bath towel
+ 435: bathtub
+ 436: station wagon
+ 437: lighthouse
+ 438: beaker
+ 439: military cap
+ 440: beer bottle
+ 441: beer glass
+ 442: bell-cot
+ 443: bib
+ 444: tandem bicycle
+ 445: bikini
+ 446: ring binder
+ 447: binoculars
+ 448: birdhouse
+ 449: boathouse
+ 450: bobsleigh
+ 451: bolo tie
+ 452: poke bonnet
+ 453: bookcase
+ 454: bookstore
+ 455: bottle cap
+ 456: bow
+ 457: bow tie
+ 458: brass
+ 459: bra
+ 460: breakwater
+ 461: breastplate
+ 462: broom
+ 463: bucket
+ 464: buckle
+ 465: bulletproof vest
+ 466: high-speed train
+ 467: butcher shop
+ 468: taxicab
+ 469: cauldron
+ 470: candle
+ 471: cannon
+ 472: canoe
+ 473: can opener
+ 474: cardigan
+ 475: car mirror
+ 476: carousel
+ 477: tool kit
+ 478: carton
+ 479: car wheel
+ 480: automated teller machine
+ 481: cassette
+ 482: cassette player
+ 483: castle
+ 484: catamaran
+ 485: CD player
+ 486: cello
+ 487: mobile phone
+ 488: chain
+ 489: chain-link fence
+ 490: chain mail
+ 491: chainsaw
+ 492: chest
+ 493: chiffonier
+ 494: chime
+ 495: china cabinet
+ 496: Christmas stocking
+ 497: church
+ 498: movie theater
+ 499: cleaver
+ 500: cliff dwelling
+ 501: cloak
+ 502: clogs
+ 503: cocktail shaker
+ 504: coffee mug
+ 505: coffeemaker
+ 506: coil
+ 507: combination lock
+ 508: computer keyboard
+ 509: confectionery store
+ 510: container ship
+ 511: convertible
+ 512: corkscrew
+ 513: cornet
+ 514: cowboy boot
+ 515: cowboy hat
+ 516: cradle
+ 517: crane (machine)
+ 518: crash helmet
+ 519: crate
+ 520: infant bed
+ 521: Crock Pot
+ 522: croquet ball
+ 523: crutch
+ 524: cuirass
+ 525: dam
+ 526: desk
+ 527: desktop computer
+ 528: rotary dial telephone
+ 529: diaper
+ 530: digital clock
+ 531: digital watch
+ 532: dining table
+ 533: dishcloth
+ 534: dishwasher
+ 535: disc brake
+ 536: dock
+ 537: dog sled
+ 538: dome
+ 539: doormat
+ 540: drilling rig
+ 541: drum
+ 542: drumstick
+ 543: dumbbell
+ 544: Dutch oven
+ 545: electric fan
+ 546: electric guitar
+ 547: electric locomotive
+ 548: entertainment center
+ 549: envelope
+ 550: espresso machine
+ 551: face powder
+ 552: feather boa
+ 553: filing cabinet
+ 554: fireboat
+ 555: fire engine
+ 556: fire screen sheet
+ 557: flagpole
+ 558: flute
+ 559: folding chair
+ 560: football helmet
+ 561: forklift
+ 562: fountain
+ 563: fountain pen
+ 564: four-poster bed
+ 565: freight car
+ 566: French horn
+ 567: frying pan
+ 568: fur coat
+ 569: garbage truck
+ 570: gas mask
+ 571: gas pump
+ 572: goblet
+ 573: go-kart
+ 574: golf ball
+ 575: golf cart
+ 576: gondola
+ 577: gong
+ 578: gown
+ 579: grand piano
+ 580: greenhouse
+ 581: grille
+ 582: grocery store
+ 583: guillotine
+ 584: barrette
+ 585: hair spray
+ 586: half-track
+ 587: hammer
+ 588: hamper
+ 589: hair dryer
+ 590: hand-held computer
+ 591: handkerchief
+ 592: hard disk drive
+ 593: harmonica
+ 594: harp
+ 595: harvester
+ 596: hatchet
+ 597: holster
+ 598: home theater
+ 599: honeycomb
+ 600: hook
+ 601: hoop skirt
+ 602: horizontal bar
+ 603: horse-drawn vehicle
+ 604: hourglass
+ 605: iPod
+ 606: clothes iron
+ 607: jack-o'-lantern
+ 608: jeans
+ 609: jeep
+ 610: T-shirt
+ 611: jigsaw puzzle
+ 612: pulled rickshaw
+ 613: joystick
+ 614: kimono
+ 615: knee pad
+ 616: knot
+ 617: lab coat
+ 618: ladle
+ 619: lampshade
+ 620: laptop computer
+ 621: lawn mower
+ 622: lens cap
+ 623: paper knife
+ 624: library
+ 625: lifeboat
+ 626: lighter
+ 627: limousine
+ 628: ocean liner
+ 629: lipstick
+ 630: slip-on shoe
+ 631: lotion
+ 632: speaker
+ 633: loupe
+ 634: sawmill
+ 635: magnetic compass
+ 636: mail bag
+ 637: mailbox
+ 638: tights
+ 639: tank suit
+ 640: manhole cover
+ 641: maraca
+ 642: marimba
+ 643: mask
+ 644: match
+ 645: maypole
+ 646: maze
+ 647: measuring cup
+ 648: medicine chest
+ 649: megalith
+ 650: microphone
+ 651: microwave oven
+ 652: military uniform
+ 653: milk can
+ 654: minibus
+ 655: miniskirt
+ 656: minivan
+ 657: missile
+ 658: mitten
+ 659: mixing bowl
+ 660: mobile home
+ 661: Model T
+ 662: modem
+ 663: monastery
+ 664: monitor
+ 665: moped
+ 666: mortar
+ 667: square academic cap
+ 668: mosque
+ 669: mosquito net
+ 670: scooter
+ 671: mountain bike
+ 672: tent
+ 673: computer mouse
+ 674: mousetrap
+ 675: moving van
+ 676: muzzle
+ 677: nail
+ 678: neck brace
+ 679: necklace
+ 680: nipple
+ 681: notebook computer
+ 682: obelisk
+ 683: oboe
+ 684: ocarina
+ 685: odometer
+ 686: oil filter
+ 687: organ
+ 688: oscilloscope
+ 689: overskirt
+ 690: bullock cart
+ 691: oxygen mask
+ 692: packet
+ 693: paddle
+ 694: paddle wheel
+ 695: padlock
+ 696: paintbrush
+ 697: pajamas
+ 698: palace
+ 699: pan flute
+ 700: paper towel
+ 701: parachute
+ 702: parallel bars
+ 703: park bench
+ 704: parking meter
+ 705: passenger car
+ 706: patio
+ 707: payphone
+ 708: pedestal
+ 709: pencil case
+ 710: pencil sharpener
+ 711: perfume
+ 712: Petri dish
+ 713: photocopier
+ 714: plectrum
+ 715: Pickelhaube
+ 716: picket fence
+ 717: pickup truck
+ 718: pier
+ 719: piggy bank
+ 720: pill bottle
+ 721: pillow
+ 722: ping-pong ball
+ 723: pinwheel
+ 724: pirate ship
+ 725: pitcher
+ 726: hand plane
+ 727: planetarium
+ 728: plastic bag
+ 729: plate rack
+ 730: plow
+ 731: plunger
+ 732: Polaroid camera
+ 733: pole
+ 734: police van
+ 735: poncho
+ 736: billiard table
+ 737: soda bottle
+ 738: pot
+ 739: potter's wheel
+ 740: power drill
+ 741: prayer rug
+ 742: printer
+ 743: prison
+ 744: projectile
+ 745: projector
+ 746: hockey puck
+ 747: punching bag
+ 748: purse
+ 749: quill
+ 750: quilt
+ 751: race car
+ 752: racket
+ 753: radiator
+ 754: radio
+ 755: radio telescope
+ 756: rain barrel
+ 757: recreational vehicle
+ 758: reel
+ 759: reflex camera
+ 760: refrigerator
+ 761: remote control
+ 762: restaurant
+ 763: revolver
+ 764: rifle
+ 765: rocking chair
+ 766: rotisserie
+ 767: eraser
+ 768: rugby ball
+ 769: ruler
+ 770: running shoe
+ 771: safe
+ 772: safety pin
+ 773: salt shaker
+ 774: sandal
+ 775: sarong
+ 776: saxophone
+ 777: scabbard
+ 778: weighing scale
+ 779: school bus
+ 780: schooner
+ 781: scoreboard
+ 782: CRT screen
+ 783: screw
+ 784: screwdriver
+ 785: seat belt
+ 786: sewing machine
+ 787: shield
+ 788: shoe store
+ 789: shoji
+ 790: shopping basket
+ 791: shopping cart
+ 792: shovel
+ 793: shower cap
+ 794: shower curtain
+ 795: ski
+ 796: ski mask
+ 797: sleeping bag
+ 798: slide rule
+ 799: sliding door
+ 800: slot machine
+ 801: snorkel
+ 802: snowmobile
+ 803: snowplow
+ 804: soap dispenser
+ 805: soccer ball
+ 806: sock
+ 807: solar thermal collector
+ 808: sombrero
+ 809: soup bowl
+ 810: space bar
+ 811: space heater
+ 812: space shuttle
+ 813: spatula
+ 814: motorboat
+ 815: spider web
+ 816: spindle
+ 817: sports car
+ 818: spotlight
+ 819: stage
+ 820: steam locomotive
+ 821: through arch bridge
+ 822: steel drum
+ 823: stethoscope
+ 824: scarf
+ 825: stone wall
+ 826: stopwatch
+ 827: stove
+ 828: strainer
+ 829: tram
+ 830: stretcher
+ 831: couch
+ 832: stupa
+ 833: submarine
+ 834: suit
+ 835: sundial
+ 836: sunglass
+ 837: sunglasses
+ 838: sunscreen
+ 839: suspension bridge
+ 840: mop
+ 841: sweatshirt
+ 842: swimsuit
+ 843: swing
+ 844: switch
+ 845: syringe
+ 846: table lamp
+ 847: tank
+ 848: tape player
+ 849: teapot
+ 850: teddy bear
+ 851: television
+ 852: tennis ball
+ 853: thatched roof
+ 854: front curtain
+ 855: thimble
+ 856: threshing machine
+ 857: throne
+ 858: tile roof
+ 859: toaster
+ 860: tobacco shop
+ 861: toilet seat
+ 862: torch
+ 863: totem pole
+ 864: tow truck
+ 865: toy store
+ 866: tractor
+ 867: semi-trailer truck
+ 868: tray
+ 869: trench coat
+ 870: tricycle
+ 871: trimaran
+ 872: tripod
+ 873: triumphal arch
+ 874: trolleybus
+ 875: trombone
+ 876: tub
+ 877: turnstile
+ 878: typewriter keyboard
+ 879: umbrella
+ 880: unicycle
+ 881: upright piano
+ 882: vacuum cleaner
+ 883: vase
+ 884: vault
+ 885: velvet
+ 886: vending machine
+ 887: vestment
+ 888: viaduct
+ 889: violin
+ 890: volleyball
+ 891: waffle iron
+ 892: wall clock
+ 893: wallet
+ 894: wardrobe
+ 895: military aircraft
+ 896: sink
+ 897: washing machine
+ 898: water bottle
+ 899: water jug
+ 900: water tower
+ 901: whiskey jug
+ 902: whistle
+ 903: wig
+ 904: window screen
+ 905: window shade
+ 906: Windsor tie
+ 907: wine bottle
+ 908: wing
+ 909: wok
+ 910: wooden spoon
+ 911: wool
+ 912: split-rail fence
+ 913: shipwreck
+ 914: yawl
+ 915: yurt
+ 916: website
+ 917: comic book
+ 918: crossword
+ 919: traffic sign
+ 920: traffic light
+ 921: dust jacket
+ 922: menu
+ 923: plate
+ 924: guacamole
+ 925: consomme
+ 926: hot pot
+ 927: trifle
+ 928: ice cream
+ 929: ice pop
+ 930: baguette
+ 931: bagel
+ 932: pretzel
+ 933: cheeseburger
+ 934: hot dog
+ 935: mashed potato
+ 936: cabbage
+ 937: broccoli
+ 938: cauliflower
+ 939: zucchini
+ 940: spaghetti squash
+ 941: acorn squash
+ 942: butternut squash
+ 943: cucumber
+ 944: artichoke
+ 945: bell pepper
+ 946: cardoon
+ 947: mushroom
+ 948: Granny Smith
+ 949: strawberry
+ 950: orange
+ 951: lemon
+ 952: fig
+ 953: pineapple
+ 954: banana
+ 955: jackfruit
+ 956: custard apple
+ 957: pomegranate
+ 958: hay
+ 959: carbonara
+ 960: chocolate syrup
+ 961: dough
+ 962: meatloaf
+ 963: pizza
+ 964: pot pie
+ 965: burrito
+ 966: red wine
+ 967: espresso
+ 968: cup
+ 969: eggnog
+ 970: alp
+ 971: bubble
+ 972: cliff
+ 973: coral reef
+ 974: geyser
+ 975: lakeshore
+ 976: promontory
+ 977: shoal
+ 978: seashore
+ 979: valley
+ 980: volcano
+ 981: baseball player
+ 982: bridegroom
+ 983: scuba diver
+ 984: rapeseed
+ 985: daisy
+ 986: yellow lady's slipper
+ 987: corn
+ 988: acorn
+ 989: rose hip
+ 990: horse chestnut seed
+ 991: coral fungus
+ 992: agaric
+ 993: gyromitra
+ 994: stinkhorn mushroom
+ 995: earth star
+ 996: hen-of-the-woods
+ 997: bolete
+ 998: ear
+ 999: toilet paper
+
+
+# Download script/URL (optional)
+download: data/scripts/get_imagenet.sh
diff --git a/ultralytics/yolo/data/datasets/Objects365.yaml b/ultralytics/yolo/data/datasets/Objects365.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a4e071550dd4917658b93524dd3a1cf5a330468
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/Objects365.yaml
@@ -0,0 +1,438 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# Objects365 dataset https://www.objects365.org/ by Megvii
+# Example usage: python train.py --data Objects365.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── Objects365  ← downloads here (712 GB = 367G data + 345G zips)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/Objects365 # dataset root dir
+train: images/train # train images (relative to 'path') 1742289 images
+val: images/val # val images (relative to 'path') 80000 images
+test: # test images (optional)
+
+# Classes
+names:
+ 0: Person
+ 1: Sneakers
+ 2: Chair
+ 3: Other Shoes
+ 4: Hat
+ 5: Car
+ 6: Lamp
+ 7: Glasses
+ 8: Bottle
+ 9: Desk
+ 10: Cup
+ 11: Street Lights
+ 12: Cabinet/shelf
+ 13: Handbag/Satchel
+ 14: Bracelet
+ 15: Plate
+ 16: Picture/Frame
+ 17: Helmet
+ 18: Book
+ 19: Gloves
+ 20: Storage box
+ 21: Boat
+ 22: Leather Shoes
+ 23: Flower
+ 24: Bench
+ 25: Potted Plant
+ 26: Bowl/Basin
+ 27: Flag
+ 28: Pillow
+ 29: Boots
+ 30: Vase
+ 31: Microphone
+ 32: Necklace
+ 33: Ring
+ 34: SUV
+ 35: Wine Glass
+ 36: Belt
+ 37: Monitor/TV
+ 38: Backpack
+ 39: Umbrella
+ 40: Traffic Light
+ 41: Speaker
+ 42: Watch
+ 43: Tie
+ 44: Trash bin Can
+ 45: Slippers
+ 46: Bicycle
+ 47: Stool
+ 48: Barrel/bucket
+ 49: Van
+ 50: Couch
+ 51: Sandals
+ 52: Basket
+ 53: Drum
+ 54: Pen/Pencil
+ 55: Bus
+ 56: Wild Bird
+ 57: High Heels
+ 58: Motorcycle
+ 59: Guitar
+ 60: Carpet
+ 61: Cell Phone
+ 62: Bread
+ 63: Camera
+ 64: Canned
+ 65: Truck
+ 66: Traffic cone
+ 67: Cymbal
+ 68: Lifesaver
+ 69: Towel
+ 70: Stuffed Toy
+ 71: Candle
+ 72: Sailboat
+ 73: Laptop
+ 74: Awning
+ 75: Bed
+ 76: Faucet
+ 77: Tent
+ 78: Horse
+ 79: Mirror
+ 80: Power outlet
+ 81: Sink
+ 82: Apple
+ 83: Air Conditioner
+ 84: Knife
+ 85: Hockey Stick
+ 86: Paddle
+ 87: Pickup Truck
+ 88: Fork
+ 89: Traffic Sign
+ 90: Balloon
+ 91: Tripod
+ 92: Dog
+ 93: Spoon
+ 94: Clock
+ 95: Pot
+ 96: Cow
+ 97: Cake
+ 98: Dinning Table
+ 99: Sheep
+ 100: Hanger
+ 101: Blackboard/Whiteboard
+ 102: Napkin
+ 103: Other Fish
+ 104: Orange/Tangerine
+ 105: Toiletry
+ 106: Keyboard
+ 107: Tomato
+ 108: Lantern
+ 109: Machinery Vehicle
+ 110: Fan
+ 111: Green Vegetables
+ 112: Banana
+ 113: Baseball Glove
+ 114: Airplane
+ 115: Mouse
+ 116: Train
+ 117: Pumpkin
+ 118: Soccer
+ 119: Skiboard
+ 120: Luggage
+ 121: Nightstand
+ 122: Tea pot
+ 123: Telephone
+ 124: Trolley
+ 125: Head Phone
+ 126: Sports Car
+ 127: Stop Sign
+ 128: Dessert
+ 129: Scooter
+ 130: Stroller
+ 131: Crane
+ 132: Remote
+ 133: Refrigerator
+ 134: Oven
+ 135: Lemon
+ 136: Duck
+ 137: Baseball Bat
+ 138: Surveillance Camera
+ 139: Cat
+ 140: Jug
+ 141: Broccoli
+ 142: Piano
+ 143: Pizza
+ 144: Elephant
+ 145: Skateboard
+ 146: Surfboard
+ 147: Gun
+ 148: Skating and Skiing shoes
+ 149: Gas stove
+ 150: Donut
+ 151: Bow Tie
+ 152: Carrot
+ 153: Toilet
+ 154: Kite
+ 155: Strawberry
+ 156: Other Balls
+ 157: Shovel
+ 158: Pepper
+ 159: Computer Box
+ 160: Toilet Paper
+ 161: Cleaning Products
+ 162: Chopsticks
+ 163: Microwave
+ 164: Pigeon
+ 165: Baseball
+ 166: Cutting/chopping Board
+ 167: Coffee Table
+ 168: Side Table
+ 169: Scissors
+ 170: Marker
+ 171: Pie
+ 172: Ladder
+ 173: Snowboard
+ 174: Cookies
+ 175: Radiator
+ 176: Fire Hydrant
+ 177: Basketball
+ 178: Zebra
+ 179: Grape
+ 180: Giraffe
+ 181: Potato
+ 182: Sausage
+ 183: Tricycle
+ 184: Violin
+ 185: Egg
+ 186: Fire Extinguisher
+ 187: Candy
+ 188: Fire Truck
+ 189: Billiards
+ 190: Converter
+ 191: Bathtub
+ 192: Wheelchair
+ 193: Golf Club
+ 194: Briefcase
+ 195: Cucumber
+ 196: Cigar/Cigarette
+ 197: Paint Brush
+ 198: Pear
+ 199: Heavy Truck
+ 200: Hamburger
+ 201: Extractor
+ 202: Extension Cord
+ 203: Tong
+ 204: Tennis Racket
+ 205: Folder
+ 206: American Football
+ 207: earphone
+ 208: Mask
+ 209: Kettle
+ 210: Tennis
+ 211: Ship
+ 212: Swing
+ 213: Coffee Machine
+ 214: Slide
+ 215: Carriage
+ 216: Onion
+ 217: Green beans
+ 218: Projector
+ 219: Frisbee
+ 220: Washing Machine/Drying Machine
+ 221: Chicken
+ 222: Printer
+ 223: Watermelon
+ 224: Saxophone
+ 225: Tissue
+ 226: Toothbrush
+ 227: Ice cream
+ 228: Hot-air balloon
+ 229: Cello
+ 230: French Fries
+ 231: Scale
+ 232: Trophy
+ 233: Cabbage
+ 234: Hot dog
+ 235: Blender
+ 236: Peach
+ 237: Rice
+ 238: Wallet/Purse
+ 239: Volleyball
+ 240: Deer
+ 241: Goose
+ 242: Tape
+ 243: Tablet
+ 244: Cosmetics
+ 245: Trumpet
+ 246: Pineapple
+ 247: Golf Ball
+ 248: Ambulance
+ 249: Parking meter
+ 250: Mango
+ 251: Key
+ 252: Hurdle
+ 253: Fishing Rod
+ 254: Medal
+ 255: Flute
+ 256: Brush
+ 257: Penguin
+ 258: Megaphone
+ 259: Corn
+ 260: Lettuce
+ 261: Garlic
+ 262: Swan
+ 263: Helicopter
+ 264: Green Onion
+ 265: Sandwich
+ 266: Nuts
+ 267: Speed Limit Sign
+ 268: Induction Cooker
+ 269: Broom
+ 270: Trombone
+ 271: Plum
+ 272: Rickshaw
+ 273: Goldfish
+ 274: Kiwi fruit
+ 275: Router/modem
+ 276: Poker Card
+ 277: Toaster
+ 278: Shrimp
+ 279: Sushi
+ 280: Cheese
+ 281: Notepaper
+ 282: Cherry
+ 283: Pliers
+ 284: CD
+ 285: Pasta
+ 286: Hammer
+ 287: Cue
+ 288: Avocado
+ 289: Hamimelon
+ 290: Flask
+ 291: Mushroom
+ 292: Screwdriver
+ 293: Soap
+ 294: Recorder
+ 295: Bear
+ 296: Eggplant
+ 297: Board Eraser
+ 298: Coconut
+ 299: Tape Measure/Ruler
+ 300: Pig
+ 301: Showerhead
+ 302: Globe
+ 303: Chips
+ 304: Steak
+ 305: Crosswalk Sign
+ 306: Stapler
+ 307: Camel
+ 308: Formula 1
+ 309: Pomegranate
+ 310: Dishwasher
+ 311: Crab
+ 312: Hoverboard
+ 313: Meat ball
+ 314: Rice Cooker
+ 315: Tuba
+ 316: Calculator
+ 317: Papaya
+ 318: Antelope
+ 319: Parrot
+ 320: Seal
+ 321: Butterfly
+ 322: Dumbbell
+ 323: Donkey
+ 324: Lion
+ 325: Urinal
+ 326: Dolphin
+ 327: Electric Drill
+ 328: Hair Dryer
+ 329: Egg tart
+ 330: Jellyfish
+ 331: Treadmill
+ 332: Lighter
+ 333: Grapefruit
+ 334: Game board
+ 335: Mop
+ 336: Radish
+ 337: Baozi
+ 338: Target
+ 339: French
+ 340: Spring Rolls
+ 341: Monkey
+ 342: Rabbit
+ 343: Pencil Case
+ 344: Yak
+ 345: Red Cabbage
+ 346: Binoculars
+ 347: Asparagus
+ 348: Barbell
+ 349: Scallop
+ 350: Noddles
+ 351: Comb
+ 352: Dumpling
+ 353: Oyster
+ 354: Table Tennis paddle
+ 355: Cosmetics Brush/Eyeliner Pencil
+ 356: Chainsaw
+ 357: Eraser
+ 358: Lobster
+ 359: Durian
+ 360: Okra
+ 361: Lipstick
+ 362: Cosmetics Mirror
+ 363: Curling
+ 364: Table Tennis
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ from tqdm import tqdm
+
+ from utils.general import Path, check_requirements, download, np, xyxy2xywhn
+
+ check_requirements(('pycocotools>=2.0',))
+ from pycocotools.coco import COCO
+
+ # Make Directories
+ dir = Path(yaml['path']) # dataset root dir
+ for p in 'images', 'labels':
+ (dir / p).mkdir(parents=True, exist_ok=True)
+ for q in 'train', 'val':
+ (dir / p / q).mkdir(parents=True, exist_ok=True)
+
+ # Train, Val Splits
+ for split, patches in [('train', 50 + 1), ('val', 43 + 1)]:
+ print(f"Processing {split} in {patches} patches ...")
+ images, labels = dir / 'images' / split, dir / 'labels' / split
+
+ # Download
+ url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
+ if split == 'train':
+ download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir, delete=False) # annotations json
+ download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, delete=False, threads=8)
+ elif split == 'val':
+ download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir, delete=False) # annotations json
+ download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, delete=False, threads=8)
+ download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, delete=False, threads=8)
+
+ # Move
+ for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
+ f.rename(images / f.name) # move to /images/{split}
+
+ # Labels
+ coco = COCO(dir / f'zhiyuan_objv2_{split}.json')
+ names = [x["name"] for x in coco.loadCats(coco.getCatIds())]
+ for cid, cat in enumerate(names):
+ catIds = coco.getCatIds(catNms=[cat])
+ imgIds = coco.getImgIds(catIds=catIds)
+ for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'):
+ width, height = im["width"], im["height"]
+ path = Path(im["file_name"]) # image filename
+ try:
+ with open(labels / path.with_suffix('.txt').name, 'a') as file:
+ annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None)
+ for a in coco.loadAnns(annIds):
+ x, y, w, h = a['bbox'] # bounding box in xywh (xy top-left corner)
+ xyxy = np.array([x, y, x + w, y + h])[None] # pixels(1,4)
+ x, y, w, h = xyxy2xywhn(xyxy, w=width, h=height, clip=True)[0] # normalized and clipped
+ file.write(f"{cid} {x:.5f} {y:.5f} {w:.5f} {h:.5f}\n")
+ except Exception as e:
+ print(e)
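The label-writing loop above relies on xyxy2xywhn from utils.general to turn COCO-style top-left-xywh pixel boxes into normalized center-xywh YOLO rows. A minimal standalone sketch of that arithmetic (the helper name and numbers below are only for illustration):

import numpy as np

def coco_box_to_yolo(x, y, w, h, img_w, img_h):
    # COCO box = (top-left x, top-left y, width, height) in pixels -> normalized (cx, cy, w, h)
    cx = (x + w / 2) / img_w
    cy = (y + h / 2) / img_h
    return np.clip([cx, cy, w / img_w, h / img_h], 0, 1)

print(coco_box_to_yolo(200, 300, 100, 50, 1280, 720))  # ~[0.195 0.451 0.078 0.069]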
diff --git a/ultralytics/yolo/data/datasets/SKU-110K.yaml b/ultralytics/yolo/data/datasets/SKU-110K.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9bfad4f6a549f362b0cbf0384c0d5cd57a7ce6d6
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/SKU-110K.yaml
@@ -0,0 +1,53 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail
+# Example usage: python train.py --data SKU-110K.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── SKU-110K  ← downloads here (13.6 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/SKU-110K # dataset root dir
+train: train.txt # train images (relative to 'path') 8219 images
+val: val.txt # val images (relative to 'path') 588 images
+test: test.txt # test images (optional) 2936 images
+
+# Classes
+names:
+ 0: object
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ import shutil
+ from tqdm import tqdm
+ from utils.general import np, pd, Path, download, xyxy2xywh
+
+
+ # Download
+ dir = Path(yaml['path']) # dataset root dir
+ parent = Path(dir.parent) # download dir
+ urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
+ download(urls, dir=parent, delete=False)
+
+ # Rename directories
+ if dir.exists():
+ shutil.rmtree(dir)
+ (parent / 'SKU110K_fixed').rename(dir) # rename dir
+ (dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir
+
+ # Convert labels
+ names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names
+ for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv':
+ x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations
+ images, unique_images = x[:, 0], np.unique(x[:, 0])
+ with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f:
+ f.writelines(f'./images/{s}\n' for s in unique_images)
+ for im in tqdm(unique_images, desc=f'Converting {dir / d}'):
+ cls = 0 # single-class dataset
+ with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f:
+ for r in x[images == im]:
+ w, h = r[6], r[7] # image width, height
+ xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance
+ f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
diff --git a/ultralytics/yolo/data/datasets/VOC.yaml b/ultralytics/yolo/data/datasets/VOC.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96d232ebb98864a07950050a91d93958cf080545
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/VOC.yaml
@@ -0,0 +1,100 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford
+# Example usage: python train.py --data VOC.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── VOC  ← downloads here (2.8 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/VOC
+train: # train images (relative to 'path') 16551 images
+ - images/train2012
+ - images/train2007
+ - images/val2012
+ - images/val2007
+val: # val images (relative to 'path') 4952 images
+ - images/test2007
+test: # test images (optional)
+ - images/test2007
+
+# Classes
+names:
+ 0: aeroplane
+ 1: bicycle
+ 2: bird
+ 3: boat
+ 4: bottle
+ 5: bus
+ 6: car
+ 7: cat
+ 8: chair
+ 9: cow
+ 10: diningtable
+ 11: dog
+ 12: horse
+ 13: motorbike
+ 14: person
+ 15: pottedplant
+ 16: sheep
+ 17: sofa
+ 18: train
+ 19: tvmonitor
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ import xml.etree.ElementTree as ET
+
+ from tqdm import tqdm
+ from utils.general import download, Path
+
+
+ def convert_label(path, lb_path, year, image_id):
+ def convert_box(size, box):
+ dw, dh = 1. / size[0], 1. / size[1]
+ x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
+ return x * dw, y * dh, w * dw, h * dh
+
+ in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml')
+ out_file = open(lb_path, 'w')
+ tree = ET.parse(in_file)
+ root = tree.getroot()
+ size = root.find('size')
+ w = int(size.find('width').text)
+ h = int(size.find('height').text)
+
+ names = list(yaml['names'].values()) # names list
+ for obj in root.iter('object'):
+ cls = obj.find('name').text
+ if cls in names and int(obj.find('difficult').text) != 1:
+ xmlbox = obj.find('bndbox')
+ bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
+ cls_id = names.index(cls) # class id
+ out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
+
+
+ # Download
+ dir = Path(yaml['path']) # dataset root dir
+ url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
+ urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
+ f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
+ f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
+ download(urls, dir=dir / 'images', delete=False, curl=True, threads=3)
+
+ # Convert
+ path = dir / 'images/VOCdevkit'
+ for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
+ imgs_path = dir / 'images' / f'{image_set}{year}'
+ lbs_path = dir / 'labels' / f'{image_set}{year}'
+ imgs_path.mkdir(exist_ok=True, parents=True)
+ lbs_path.mkdir(exist_ok=True, parents=True)
+
+ with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
+ image_ids = f.read().strip().split()
+ for id in tqdm(image_ids, desc=f'{image_set}{year}'):
+ f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path
+ lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path
+ f.rename(imgs_path / f.name) # move image
+ convert_label(path, lb_path, year, id) # convert labels to YOLO format
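convert_label above maps VOC's (xmin, xmax, ymin, ymax) pixel corners to normalized center-xywh via convert_box. A quick standalone check of the same arithmetic, with made-up pixel values:

def voc_box_to_yolo(size, box):
    # size = (img_w, img_h); box = (xmin, xmax, ymin, ymax), mirroring convert_box above
    dw, dh = 1. / size[0], 1. / size[1]
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w, h = box[1] - box[0], box[3] - box[2]
    return x * dw, y * dh, w * dw, h * dh

print(voc_box_to_yolo((500, 375), (48., 352., 240., 371.)))  # ~(0.398, 0.812, 0.608, 0.349)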
diff --git a/ultralytics/yolo/data/datasets/VisDrone.yaml b/ultralytics/yolo/data/datasets/VisDrone.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..61730613a5b99d680a4925fdca1069e4b1a797ca
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/VisDrone.yaml
@@ -0,0 +1,70 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University
+# Example usage: python train.py --data VisDrone.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── VisDrone  ← downloads here (2.3 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/VisDrone # dataset root dir
+train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images
+val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images
+test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images
+
+# Classes
+names:
+ 0: pedestrian
+ 1: people
+ 2: bicycle
+ 3: car
+ 4: van
+ 5: truck
+ 6: tricycle
+ 7: awning-tricycle
+ 8: bus
+ 9: motor
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ from utils.general import download, os, Path
+
+ def visdrone2yolo(dir):
+ from PIL import Image
+ from tqdm import tqdm
+
+ def convert_box(size, box):
+ # Convert VisDrone box to YOLO xywh box
+ dw = 1. / size[0]
+ dh = 1. / size[1]
+ return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh
+
+ (dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory
+ pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
+ for f in pbar:
+ img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
+ lines = []
+ with open(f, 'r') as file: # read annotation.txt
+ for row in [x.split(',') for x in file.read().strip().splitlines()]:
+ if row[4] == '0': # VisDrone 'ignored regions' class 0
+ continue
+ cls = int(row[5]) - 1
+ box = convert_box(img_size, tuple(map(int, row[:4])))
+ lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
+ with open(str(f).replace(os.sep + 'annotations' + os.sep, os.sep + 'labels' + os.sep), 'w') as fl:
+ fl.writelines(lines) # write label.txt
+
+
+ # Download
+ dir = Path(yaml['path']) # dataset root dir
+ urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip',
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip',
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip',
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip']
+ download(urls, dir=dir, curl=True, threads=4)
+
+ # Convert
+ for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
+ visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels
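Each VisDrone annotation row is bbox_left,bbox_top,bbox_width,bbox_height,score,category,truncation,occlusion; the converter above drops rows whose score column is 0 (ignored regions) and shifts the 1-based category to the 0-based names map. A one-row sketch with made-up values:

row = '684,8,273,116,0,0,0,0'.split(',')  # score column (row[4]) is '0' -> ignored region, skipped
print(row[4] == '0')                      # True

row = '684,8,273,116,1,4,0,0'.split(',')  # score 1, VisDrone category 4
print(int(row[5]) - 1)                    # 3 -> 'car' in the 0-based names map above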
diff --git a/ultralytics/yolo/data/datasets/coco.yaml b/ultralytics/yolo/data/datasets/coco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1498eac4477a59ec652aa81545034788947f688f
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/coco.yaml
@@ -0,0 +1,113 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# COCO 2017 dataset http://cocodataset.org by Microsoft
+# Example usage: python train.py --data coco.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── coco  ← downloads here (20.1 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco # dataset root dir
+train: train2017.txt # train images (relative to 'path') 118287 images
+val: val2017.txt # val images (relative to 'path') 5000 images
+test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
+
+# Classes
+names:
+ 0: person
+ 1: bicycle
+ 2: car
+ 3: motorcycle
+ 4: airplane
+ 5: bus
+ 6: train
+ 7: truck
+ 8: boat
+ 9: traffic light
+ 10: fire hydrant
+ 11: stop sign
+ 12: parking meter
+ 13: bench
+ 14: bird
+ 15: cat
+ 16: dog
+ 17: horse
+ 18: sheep
+ 19: cow
+ 20: elephant
+ 21: bear
+ 22: zebra
+ 23: giraffe
+ 24: backpack
+ 25: umbrella
+ 26: handbag
+ 27: tie
+ 28: suitcase
+ 29: frisbee
+ 30: skis
+ 31: snowboard
+ 32: sports ball
+ 33: kite
+ 34: baseball bat
+ 35: baseball glove
+ 36: skateboard
+ 37: surfboard
+ 38: tennis racket
+ 39: bottle
+ 40: wine glass
+ 41: cup
+ 42: fork
+ 43: knife
+ 44: spoon
+ 45: bowl
+ 46: banana
+ 47: apple
+ 48: sandwich
+ 49: orange
+ 50: broccoli
+ 51: carrot
+ 52: hot dog
+ 53: pizza
+ 54: donut
+ 55: cake
+ 56: chair
+ 57: couch
+ 58: potted plant
+ 59: bed
+ 60: dining table
+ 61: toilet
+ 62: tv
+ 63: laptop
+ 64: mouse
+ 65: remote
+ 66: keyboard
+ 67: cell phone
+ 68: microwave
+ 69: oven
+ 70: toaster
+ 71: sink
+ 72: refrigerator
+ 73: book
+ 74: clock
+ 75: vase
+ 76: scissors
+ 77: teddy bear
+ 78: hair drier
+ 79: toothbrush
+
+
+# Download script/URL (optional)
+download: |
+ from utils.general import download, Path
+ # Download labels
+ segments = True # segment or box labels
+ dir = Path(yaml['path']) # dataset root dir
+ url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
+ urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
+ download(urls, dir=dir.parent)
+ # Download data
+ urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
+ 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
+ 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
+ download(urls, dir=dir / 'images', threads=3)
\ No newline at end of file
diff --git a/ultralytics/yolo/data/datasets/coco128-seg.yaml b/ultralytics/yolo/data/datasets/coco128-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f9ddba4173cfe9c7777791707684d816c10e8ae
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/coco128-seg.yaml
@@ -0,0 +1,101 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
+# Example usage: python train.py --data coco128.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── coco128-seg  ← downloads here (7 MB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco128-seg # dataset root dir
+train: images/train2017 # train images (relative to 'path') 128 images
+val: images/train2017 # val images (relative to 'path') 128 images
+test: # test images (optional)
+
+# Classes
+names:
+ 0: person
+ 1: bicycle
+ 2: car
+ 3: motorcycle
+ 4: airplane
+ 5: bus
+ 6: train
+ 7: truck
+ 8: boat
+ 9: traffic light
+ 10: fire hydrant
+ 11: stop sign
+ 12: parking meter
+ 13: bench
+ 14: bird
+ 15: cat
+ 16: dog
+ 17: horse
+ 18: sheep
+ 19: cow
+ 20: elephant
+ 21: bear
+ 22: zebra
+ 23: giraffe
+ 24: backpack
+ 25: umbrella
+ 26: handbag
+ 27: tie
+ 28: suitcase
+ 29: frisbee
+ 30: skis
+ 31: snowboard
+ 32: sports ball
+ 33: kite
+ 34: baseball bat
+ 35: baseball glove
+ 36: skateboard
+ 37: surfboard
+ 38: tennis racket
+ 39: bottle
+ 40: wine glass
+ 41: cup
+ 42: fork
+ 43: knife
+ 44: spoon
+ 45: bowl
+ 46: banana
+ 47: apple
+ 48: sandwich
+ 49: orange
+ 50: broccoli
+ 51: carrot
+ 52: hot dog
+ 53: pizza
+ 54: donut
+ 55: cake
+ 56: chair
+ 57: couch
+ 58: potted plant
+ 59: bed
+ 60: dining table
+ 61: toilet
+ 62: tv
+ 63: laptop
+ 64: mouse
+ 65: remote
+ 66: keyboard
+ 67: cell phone
+ 68: microwave
+ 69: oven
+ 70: toaster
+ 71: sink
+ 72: refrigerator
+ 73: book
+ 74: clock
+ 75: vase
+ 76: scissors
+ 77: teddy bear
+ 78: hair drier
+ 79: toothbrush
+
+
+# Download script/URL (optional)
+download: https://ultralytics.com/assets/coco128-seg.zip
\ No newline at end of file
diff --git a/ultralytics/yolo/data/datasets/coco128.yaml b/ultralytics/yolo/data/datasets/coco128.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3ef3b8bdfc6fced734cb7e2c45bb8e26160fb78f
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/coco128.yaml
@@ -0,0 +1,101 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
+# Example usage: python train.py --data coco128.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── coco128  ← downloads here (7 MB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco128 # dataset root dir
+train: images/train2017 # train images (relative to 'path') 128 images
+val: images/train2017 # val images (relative to 'path') 128 images
+test: # test images (optional)
+
+# Classes
+names:
+ 0: person
+ 1: bicycle
+ 2: car
+ 3: motorcycle
+ 4: airplane
+ 5: bus
+ 6: train
+ 7: truck
+ 8: boat
+ 9: traffic light
+ 10: fire hydrant
+ 11: stop sign
+ 12: parking meter
+ 13: bench
+ 14: bird
+ 15: cat
+ 16: dog
+ 17: horse
+ 18: sheep
+ 19: cow
+ 20: elephant
+ 21: bear
+ 22: zebra
+ 23: giraffe
+ 24: backpack
+ 25: umbrella
+ 26: handbag
+ 27: tie
+ 28: suitcase
+ 29: frisbee
+ 30: skis
+ 31: snowboard
+ 32: sports ball
+ 33: kite
+ 34: baseball bat
+ 35: baseball glove
+ 36: skateboard
+ 37: surfboard
+ 38: tennis racket
+ 39: bottle
+ 40: wine glass
+ 41: cup
+ 42: fork
+ 43: knife
+ 44: spoon
+ 45: bowl
+ 46: banana
+ 47: apple
+ 48: sandwich
+ 49: orange
+ 50: broccoli
+ 51: carrot
+ 52: hot dog
+ 53: pizza
+ 54: donut
+ 55: cake
+ 56: chair
+ 57: couch
+ 58: potted plant
+ 59: bed
+ 60: dining table
+ 61: toilet
+ 62: tv
+ 63: laptop
+ 64: mouse
+ 65: remote
+ 66: keyboard
+ 67: cell phone
+ 68: microwave
+ 69: oven
+ 70: toaster
+ 71: sink
+ 72: refrigerator
+ 73: book
+ 74: clock
+ 75: vase
+ 76: scissors
+ 77: teddy bear
+ 78: hair drier
+ 79: toothbrush
+
+
+# Download script/URL (optional)
+download: https://ultralytics.com/assets/coco128.zip
\ No newline at end of file
diff --git a/ultralytics/yolo/data/datasets/xView.yaml b/ultralytics/yolo/data/datasets/xView.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f11f13e13adf28bafaf96e74d80076df3173b995
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/xView.yaml
@@ -0,0 +1,153 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA)
+# -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! --------
+# Example usage: python train.py --data xView.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── xView  ← downloads here (20.7 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/xView # dataset root dir
+train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images
+val: images/autosplit_val.txt # val images (relative to 'path') 10% of 847 train images
+
+# Classes
+names:
+ 0: Fixed-wing Aircraft
+ 1: Small Aircraft
+ 2: Cargo Plane
+ 3: Helicopter
+ 4: Passenger Vehicle
+ 5: Small Car
+ 6: Bus
+ 7: Pickup Truck
+ 8: Utility Truck
+ 9: Truck
+ 10: Cargo Truck
+ 11: Truck w/Box
+ 12: Truck Tractor
+ 13: Trailer
+ 14: Truck w/Flatbed
+ 15: Truck w/Liquid
+ 16: Crane Truck
+ 17: Railway Vehicle
+ 18: Passenger Car
+ 19: Cargo Car
+ 20: Flat Car
+ 21: Tank car
+ 22: Locomotive
+ 23: Maritime Vessel
+ 24: Motorboat
+ 25: Sailboat
+ 26: Tugboat
+ 27: Barge
+ 28: Fishing Vessel
+ 29: Ferry
+ 30: Yacht
+ 31: Container Ship
+ 32: Oil Tanker
+ 33: Engineering Vehicle
+ 34: Tower crane
+ 35: Container Crane
+ 36: Reach Stacker
+ 37: Straddle Carrier
+ 38: Mobile Crane
+ 39: Dump Truck
+ 40: Haul Truck
+ 41: Scraper/Tractor
+ 42: Front loader/Bulldozer
+ 43: Excavator
+ 44: Cement Mixer
+ 45: Ground Grader
+ 46: Hut/Tent
+ 47: Shed
+ 48: Building
+ 49: Aircraft Hangar
+ 50: Damaged Building
+ 51: Facility
+ 52: Construction Site
+ 53: Vehicle Lot
+ 54: Helipad
+ 55: Storage Tank
+ 56: Shipping container lot
+ 57: Shipping Container
+ 58: Pylon
+ 59: Tower
+
+
+# Download script/URL (optional) ---------------------------------------------------------------------------------------
+download: |
+ import json
+ import os
+ from pathlib import Path
+
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
+ from utils.dataloaders import autosplit
+ from utils.general import download, xyxy2xywhn
+
+
+ def convert_labels(fname=Path('xView/xView_train.geojson')):
+ # Convert xView geoJSON labels to YOLO format
+ path = fname.parent
+ with open(fname) as f:
+ print(f'Loading {fname}...')
+ data = json.load(f)
+
+ # Make dirs
+ labels = Path(path / 'labels' / 'train')
+ os.system(f'rm -rf {labels}')
+ labels.mkdir(parents=True, exist_ok=True)
+
+ # xView classes 11-94 to 0-59
+ xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11,
+ 12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1,
+ 29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46,
+ 47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59]
+
+ shapes = {}
+ for feature in tqdm(data['features'], desc=f'Converting {fname}'):
+ p = feature['properties']
+ if p['bounds_imcoords']:
+ id = p['image_id']
+ file = path / 'train_images' / id
+ if file.exists(): # 1395.tif missing
+ try:
+ box = np.array([int(num) for num in p['bounds_imcoords'].split(",")])
+ assert box.shape[0] == 4, f'incorrect box shape {box.shape[0]}'
+ cls = p['type_id']
+ cls = xview_class2index[int(cls)] # xView class to 0-60
+ assert 59 >= cls >= 0, f'incorrect class index {cls}'
+
+ # Write YOLO label
+ if id not in shapes:
+ shapes[id] = Image.open(file).size
+ box = xyxy2xywhn(box[None].astype(np.float64), w=shapes[id][0], h=shapes[id][1], clip=True)
+ with open((labels / id).with_suffix('.txt'), 'a') as f:
+ f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n") # write label.txt
+ except Exception as e:
+ print(f'WARNING: skipping one label for {file}: {e}')
+
+
+ # Download manually from https://challenge.xviewdataset.org
+ dir = Path(yaml['path']) # dataset root dir
+ # urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
+ # 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
+ # 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
+ # download(urls, dir=dir, delete=False)
+
+ # Convert labels
+ convert_labels(dir / 'xView_train.geojson')
+
+ # Move images
+ images = Path(dir / 'images')
+ images.mkdir(parents=True, exist_ok=True)
+ Path(dir / 'train_images').rename(dir / 'images' / 'train')
+ Path(dir / 'val_images').rename(dir / 'images' / 'val')
+
+ # Split
+ autosplit(dir / 'images' / 'train')
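xView type_id values run from 11 to 94 with gaps, so xview_class2index collapses them to a contiguous 0-59 range and marks skipped ids with -1. A tiny illustration of that lookup (the table below is truncated, not the real 95-entry one):

xview_class2index = [-1] * 11 + [0, 1, 2, -1, 3]  # truncated illustration; index = xView type_id

def remap(type_id):
    cls = xview_class2index[type_id]
    assert cls >= 0, f'type_id {type_id} is not a trainable class'
    return cls

print(remap(11), remap(15))  # 0 3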
diff --git a/ultralytics/yolo/data/scripts/download_weights.sh b/ultralytics/yolo/data/scripts/download_weights.sh
new file mode 100644
index 0000000000000000000000000000000000000000..59d37faec76674603f4df49e9b7a301f6f7f6c51
--- /dev/null
+++ b/ultralytics/yolo/data/scripts/download_weights.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# Download latest models from https://github.com/ultralytics/yolov5/releases
+# Example usage: bash data/scripts/download_weights.sh
+# parent
+# └── yolov5
+#     ├── yolov5s.pt  ← downloads here
+#     ├── yolov5m.pt
+#     └── ...
+
+python - <<EOF
+ assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+ assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+ if im.format.lower() in ("jpg", "jpeg"):
+ with open(im_file, "rb") as f:
+ f.seek(-2, 2)
+ if f.read() != b"\xff\xd9": # corrupt JPEG
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+ msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
+
+ # verify labels
+ if os.path.isfile(lb_file):
+ nf = 1 # label found
+ with open(lb_file) as f:
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
+ if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
+ classes = np.array([x[0] for x in lb], dtype=np.float32)
+ segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
+ lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
+ lb = np.array(lb, dtype=np.float32)
+ nl = len(lb)
+ if nl:
+ if keypoint:
+ assert lb.shape[1] == 56, "labels require 56 columns each"
+ assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
+ assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
+ kpts = np.zeros((lb.shape[0], 39))
+ for i in range(len(lb)):
+ kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
+ 3)) # remove the occlusion parameter from the GT
+ kpts[i] = np.hstack((lb[i, :5], kpt))
+ lb = kpts
+ assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
+ else:
+ assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
+ assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
+ assert (lb[:, 1:] <=
+ 1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
+ _, i = np.unique(lb, axis=0, return_index=True)
+ if len(i) < nl: # duplicate row check
+ lb = lb[i] # remove duplicates
+ if segments:
+ segments = [segments[x] for x in i]
+ msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
+ else:
+ ne = 1 # label empty
+ lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
+ else:
+ nm = 1 # label missing
+ lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
+ if keypoint:
+ keypoints = lb[:, 5:].reshape(-1, 17, 2)
+ lb = lb[:, :5]
+ return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
+ except Exception as e:
+ nc = 1
+ msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
+ return [None, None, None, None, None, nm, nf, ne, nc, msg]
+
+
+def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
+ """
+ Args:
+ imgsz (tuple): The image size.
+ polygons (np.ndarray): [N, M], N is the number of polygons, M is the number of flattened xy coordinates (divisible by 2).
+ color (int): color
+ downsample_ratio (int): downsample ratio
+ """
+ mask = np.zeros(imgsz, dtype=np.uint8)
+ polygons = np.asarray(polygons)
+ polygons = polygons.astype(np.int32)
+ shape = polygons.shape
+ polygons = polygons.reshape(shape[0], -1, 2)
+ cv2.fillPoly(mask, polygons, color=color)
+ nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
+ # NOTE: fillPoly first and then resize keeps the mask consistent with the loss calculation when mask_ratio=1.
+ mask = cv2.resize(mask, (nw, nh))
+ return mask
+
+
+def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
+ """
+ Args:
+ imgsz (tuple): The image size.
+ polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
+ color (int): color
+ downsample_ratio (int): downsample ratio
+ """
+ masks = []
+ for si in range(len(polygons)):
+ mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
+ masks.append(mask)
+ return np.array(masks)
+
+
+def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
+ """Return a (640, 640) overlap mask."""
+ masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
+ dtype=np.int32 if len(segments) > 255 else np.uint8)
+ areas = []
+ ms = []
+ for si in range(len(segments)):
+ mask = polygon2mask(
+ imgsz,
+ [segments[si].reshape(-1)],
+ downsample_ratio=downsample_ratio,
+ color=1,
+ )
+ ms.append(mask)
+ areas.append(mask.sum())
+ areas = np.asarray(areas)
+ index = np.argsort(-areas)
+ ms = np.array(ms)[index]
+ for i in range(len(segments)):
+ mask = ms[i] * (i + 1)
+ masks = masks + mask
+ masks = np.clip(masks, a_min=0, a_max=i + 1)
+ return masks, index
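A small usage sketch for the mask helpers above, using two made-up overlapping polygons on an 80x80 canvas (numpy and cv2 are assumed imported as in this module):

import numpy as np

segments = [
    np.array([10, 10, 60, 10, 60, 60, 10, 60], dtype=np.float32),  # larger square, flattened xy pairs
    np.array([40, 40, 70, 40, 70, 70, 40, 70], dtype=np.float32),  # smaller square overlapping its corner
]
masks, index = polygons2masks_overlap((80, 80), segments, downsample_ratio=1)
print(masks.shape, masks.max(), index)  # (80, 80) 2 [0 1] -- instances sorted by area, smaller wins the overlap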
+
+
+def check_dataset_yaml(data, autodownload=True):
+ # Download, check and/or unzip dataset if not found locally
+ data = check_file(data)
+ DATASETS_DIR = (Path.cwd() / "../datasets").resolve() # TODO: handle global dataset dir
+ # Download (optional)
+ extract_dir = ''
+ if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
+ download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
+ data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
+ extract_dir, autodownload = data.parent, False
+ # Read yaml (optional)
+ if isinstance(data, (str, Path)):
+ data = yaml_load(data, append_filename=True) # dictionary
+
+ # Checks
+ for k in 'train', 'val', 'names':
+ assert k in data, f"data.yaml '{k}:' field missing ❌"
+ if isinstance(data['names'], (list, tuple)): # old array format
+ data['names'] = dict(enumerate(data['names'])) # convert to dict
+ data['nc'] = len(data['names'])
+
+ # Resolve paths
+ path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
+ if not path.is_absolute():
+ path = (Path.cwd() / path).resolve()
+ data['path'] = path # download scripts
+ for k in 'train', 'val', 'test':
+ if data.get(k): # prepend path
+ if isinstance(data[k], str):
+ x = (path / data[k]).resolve()
+ if not x.exists() and data[k].startswith('../'):
+ x = (path / data[k][3:]).resolve()
+ data[k] = str(x)
+ else:
+ data[k] = [str((path / x).resolve()) for x in data[k]]
+
+ # Parse yaml
+ train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
+ if val:
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
+ if not all(x.exists() for x in val):
+ LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
+ if not s or not autodownload:
+ raise FileNotFoundError('Dataset not found ❌')
+ t = time.time()
+ if s.startswith('http') and s.endswith('.zip'): # URL
+ f = Path(s).name # filename
+ LOGGER.info(f'Downloading {s} to {f}...')
+ torch.hub.download_url_to_file(s, f)
+ Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
+ unzip_file(f, path=DATASETS_DIR) # unzip
+ Path(f).unlink() # remove zip
+ r = None # success
+ elif s.startswith('bash '): # bash script
+ LOGGER.info(f'Running {s} ...')
+ r = os.system(s)
+ else: # python script
+ r = exec(s, {'yaml': data}) # return None
+ dt = f'({round(time.time() - t, 1)}s)'
+ s = f"success β
{dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} β"
+ LOGGER.info(f"Dataset download {s}")
+ check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
+ return data # dictionary
+
+
+def check_dataset(dataset: str):
+ """
+ Check a classification dataset such as Imagenet.
+
+ This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
+ If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
+
+ Args:
+ dataset (str): Name of the dataset.
+
+ Returns:
+ data (dict): A dictionary containing the following keys and values:
+ 'train': Path object for the directory containing the training set of the dataset
+ 'val': Path object for the directory containing the validation set of the dataset
+ 'nc': Number of classes in the dataset
+ 'names': List of class names in the dataset
+ """
+ data_dir = (Path.cwd() / "datasets" / dataset).resolve()
+ if not data_dir.is_dir():
+ LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+ t = time.time()
+ if dataset == 'imagenet':
+ subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
+ else:
+ url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
+ download(url, dir=data_dir.parent)
+ s = f"Dataset download success β
({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
+ LOGGER.info(s)
+ train_set = data_dir / "train"
+ test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
+ nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
+ names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
+ names = dict(enumerate(sorted(names)))
+ return {"train": train_set, "val": test_set, "nc": nc, "names": names}
diff --git a/ultralytics/yolo/engine/__init__.py b/ultralytics/yolo/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3188594ed54b8b5dca0829486fc9025143b7b73
--- /dev/null
+++ b/ultralytics/yolo/engine/exporter.py
@@ -0,0 +1,828 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
+
+Format | `format=argument` | Model
+--- | --- | ---
+PyTorch | - | yolov8n.pt
+TorchScript | `torchscript` | yolov8n.torchscript
+ONNX | `onnx` | yolov8n.onnx
+OpenVINO | `openvino` | yolov8n_openvino_model/
+TensorRT | `engine` | yolov8n.engine
+CoreML | `coreml` | yolov8n.mlmodel
+TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
+TensorFlow GraphDef | `pb` | yolov8n.pb
+TensorFlow Lite | `tflite` | yolov8n.tflite
+TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
+TensorFlow.js | `tfjs` | yolov8n_web_model/
+PaddlePaddle | `paddle` | yolov8n_paddle_model/
+
+Requirements:
+ $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
+ $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
+
+Python:
+ from ultralytics import YOLO
+ model = YOLO('yolov8n.yaml')
+ results = model.export(format='onnx')
+
+CLI:
+ $ yolo mode=export model=yolov8n.pt format=onnx
+
+Inference:
+ $ python detect.py --weights yolov8n.pt # PyTorch
+ yolov8n.torchscript # TorchScript
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
+ yolov8n_openvino_model # OpenVINO
+ yolov8n.engine # TensorRT
+ yolov8n.mlmodel # CoreML (macOS-only)
+ yolov8n_saved_model # TensorFlow SavedModel
+ yolov8n.pb # TensorFlow GraphDef
+ yolov8n.tflite # TensorFlow Lite
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
+ yolov8n_paddle_model # PaddlePaddle
+
+TensorFlow.js:
+ $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
+ $ npm install
+ $ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
+ $ npm start
+"""
+import contextlib
+import json
+import os
+import platform
+import re
+import subprocess
+import time
+import warnings
+from collections import defaultdict
+from copy import deepcopy
+from pathlib import Path
+
+import hydra
+import numpy as np
+import pandas as pd
+import torch
+
+import ultralytics
+from ultralytics.nn.modules import Detect, Segment
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
+from ultralytics.yolo.configs import get_config
+from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
+from ultralytics.yolo.data.utils import check_dataset
+from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
+from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
+from ultralytics.yolo.utils.files import file_size
+from ultralytics.yolo.utils.ops import Profile
+from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode
+
+MACOS = platform.system() == 'Darwin' # macOS environment
+
+
+def export_formats():
+ # YOLOv5 export formats
+ x = [
+ ['PyTorch', '-', '.pt', True, True],
+ ['TorchScript', 'torchscript', '.torchscript', True, True],
+ ['ONNX', 'onnx', '.onnx', True, True],
+ ['OpenVINO', 'openvino', '_openvino_model', True, False],
+ ['TensorRT', 'engine', '.engine', False, True],
+ ['CoreML', 'coreml', '.mlmodel', True, False],
+ ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
+ ['TensorFlow GraphDef', 'pb', '.pb', True, True],
+ ['TensorFlow Lite', 'tflite', '.tflite', True, False],
+ ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
+ ['TensorFlow.js', 'tfjs', '_web_model', False, False],
+ ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
+ return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
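export_formats() returns a plain pandas DataFrame, so the table can be queried directly, e.g. to list only the arguments that produce CPU-runnable artifacts (a sketch, not part of the exporter itself):

fmts = export_formats()
print(fmts[fmts['CPU']]['Argument'].tolist())
# ['-', 'torchscript', 'onnx', 'openvino', 'coreml', 'saved_model', 'pb', 'tflite', 'paddle']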
+
+
+def try_export(inner_func):
+ # YOLOv5 export decorator, i.e. @try_export
+ inner_args = get_default_args(inner_func)
+
+ def outer_func(*args, **kwargs):
+ prefix = inner_args['prefix']
+ try:
+ with Profile() as dt:
+ f, model = inner_func(*args, **kwargs)
+ LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
+ return f, model
+ except Exception as e:
+ LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+ return None, None
+
+ return outer_func
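try_export wraps each exporter method so a failure is logged with its timing instead of raising. A minimal sketch of the pattern with a purely illustrative dummy exporter:

@try_export
def _export_dummy(prefix=colorstr('Dummy:')):
    # an exporter body does its work, then returns (output_path, optional_model_object)
    return 'model.dummy', None

f, model = _export_dummy()  # logs 'Dummy: export success ...' and returns ('model.dummy', None)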
+
+
+class Exporter:
+ """
+ Exporter
+
+ A class for exporting a model.
+
+ Attributes:
+ args (OmegaConf): Configuration for the exporter.
+ save_dir (Path): Directory to save results.
+ """
+
+ def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+ """
+ Initializes the Exporter class.
+
+ Args:
+ config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+ overrides (dict, optional): Configuration overrides. Defaults to None.
+ """
+ if overrides is None:
+ overrides = {}
+ self.args = get_config(config, overrides)
+ self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
+ callbacks.add_integration_callbacks(self)
+
+ @smart_inference_mode()
+ def __call__(self, model=None):
+ self.run_callbacks("on_export_start")
+ t = time.time()
+ format = self.args.format.lower() # to lowercase
+ fmts = tuple(export_formats()['Argument'][1:]) # available export formats
+ flags = [x == format for x in fmts]
+ assert sum(flags), f'ERROR: Invalid format={format}, valid formats are {fmts}'
+ jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
+
+ # Load PyTorch model
+ self.device = select_device('cpu' if self.args.device is None else self.args.device)
+ if self.args.half:
+ if self.device.type == 'cpu' and not coreml:
+ LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
+ self.args.half = False
+ assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
+
+ # Checks
+ # if self.args.batch == model.args['batch_size']: # user has not modified training batch_size
+ self.args.batch = 1
+ self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
+ if self.args.optimize:
+ assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
+
+ # Input
+ im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
+ file = Path(getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml['yaml_file'])
+ if file.suffix == '.yaml':
+ file = Path(file.name)
+
+ # Update model
+ model = deepcopy(model).to(self.device)
+ for p in model.parameters():
+ p.requires_grad = False
+ model.eval()
+ model = model.fuse()
+ for k, m in model.named_modules():
+ if isinstance(m, (Detect, Segment)):
+ m.dynamic = self.args.dynamic
+ m.export = True
+
+ y = None
+ for _ in range(2):
+ y = model(im) # dry runs
+ if self.args.half and not coreml:
+ im, model = im.half(), model.half() # to FP16
+ shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
+ LOGGER.info(
+ f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
+
+ # Warnings
+ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
+ warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning
+ warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
+
+ # Assign
+ self.im = im
+ self.model = model
+ self.file = file
+ self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y)
+ self.metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
+ self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
+
+ # Exports
+ f = [''] * len(fmts) # exported filenames
+ if jit: # TorchScript
+ f[0], _ = self._export_torchscript()
+ if engine: # TensorRT required before ONNX
+ f[1], _ = self._export_engine()
+ if onnx or xml: # OpenVINO requires ONNX
+ f[2], _ = self._export_onnx()
+ if xml: # OpenVINO
+ f[3], _ = self._export_openvino()
+ if coreml: # CoreML
+ f[4], _ = self._export_coreml()
+ if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
+ raise NotImplementedError('YOLOv8 TensorFlow export support is still under development. '
+ 'Please consider contributing to the effort if you have TF expertise. Thank you!')
+ assert not isinstance(model, ClassificationModel), 'ClassificationModel TF exports not yet supported.'
+ nms = False
+ f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
+ agnostic_nms=self.args.agnostic_nms or tfjs)
+ if pb or tfjs: # pb prerequisite to tfjs
+ f[6], _ = self._export_pb(s_model)
+ if tflite or edgetpu:
+ f[7], _ = self._export_tflite(s_model,
+ int8=self.args.int8 or edgetpu,
+ data=self.args.data,
+ nms=nms,
+ agnostic_nms=self.args.agnostic_nms)
+ if edgetpu:
+ f[8], _ = self._export_edgetpu()
+ self._add_tflite_metadata(f[8] or f[7], num_outputs=len(s_model.outputs))
+ if tfjs:
+ f[9], _ = self._export_tfjs()
+ if paddle: # PaddlePaddle
+ f[10], _ = self._export_paddle()
+
+ # Finish
+ f = [str(x) for x in f if x] # filter out '' and None
+ if any(f):
+ task = guess_task_from_head(model.yaml["head"][-1][-2])
+ s = "-WARNING β οΈ not yet supported for YOLOv8 exported models"
+ LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+ f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
+ f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
+ f"\nVisualize: https://netron.app")
+
+ self.run_callbacks("on_export_end")
+ return f # return list of exported files/dirs
+
+ @try_export
+ def _export_torchscript(self, prefix=colorstr('TorchScript:')):
+ # YOLOv8 TorchScript model export
+ LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
+ f = self.file.with_suffix('.torchscript')
+
+ ts = torch.jit.trace(self.model, self.im, strict=False)
+ d = {"shape": self.im.shape, "stride": int(max(self.model.stride)), "names": self.model.names}
+ extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
+ if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
+ LOGGER.info(f'{prefix} optimizing for mobile...')
+ from torch.utils.mobile_optimizer import optimize_for_mobile
+ optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
+ else:
+ ts.save(str(f), _extra_files=extra_files)
+ return f, None
+
+ @try_export
+ def _export_onnx(self, prefix=colorstr('ONNX:')):
+ # YOLOv8 ONNX export
+ check_requirements('onnx>=1.12.0')
+ import onnx # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
+ f = str(self.file.with_suffix('.onnx'))
+
+ output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
+ dynamic = self.args.dynamic
+ if dynamic:
+ dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
+ if isinstance(self.model, SegmentationModel):
+ dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
+ dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
+ elif isinstance(self.model, DetectionModel):
+ dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
+
+ torch.onnx.export(
+ self.model.cpu() if dynamic else self.model, # --dynamic only compatible with cpu
+ self.im.cpu() if dynamic else self.im,
+ f,
+ verbose=False,
+ opset_version=self.args.opset,
+ do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
+ input_names=['images'],
+ output_names=output_names,
+ dynamic_axes=dynamic or None)
+
+ # Checks
+ model_onnx = onnx.load(f) # load onnx model
+ onnx.checker.check_model(model_onnx) # check onnx model
+
+ # Metadata
+ d = {'stride': int(max(self.model.stride)), 'names': self.model.names}
+ for k, v in d.items():
+ meta = model_onnx.metadata_props.add()
+ meta.key, meta.value = k, str(v)
+ onnx.save(model_onnx, f)
+
+ # Simplify
+ if self.args.simplify:
+ try:
+ check_requirements('onnxsim')
+ import onnxsim
+
+ LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
+ subprocess.run(f'onnxsim {f} {f}', shell=True)
+ except Exception as e:
+ LOGGER.info(f'{prefix} simplifier failure: {e}')
+ return f, model_onnx
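Once _export_onnx has written the file, the graph can be sanity-checked outside the exporter; a hedged sketch using onnxruntime (assumes onnxruntime is installed and a default 640x640 export):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('yolov8n.onnx', providers=['CPUExecutionProvider'])
im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy BCHW input in [0, 1]
outputs = session.run(None, {'images': im})        # 'images' matches input_names above
print([o.shape for o in outputs])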
+
+ @try_export
+ def _export_openvino(self, prefix=colorstr('OpenVINO:')):
+ # YOLOv8 OpenVINO export
+ check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
+ import openvino.inference_engine as ie # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
+ f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
+ f_onnx = self.file.with_suffix('.onnx')
+
+ cmd = f"mo --input_model {f_onnx} --output_dir {f} --data_type {'FP16' if self.args.half else 'FP32'}"
+ subprocess.run(cmd.split(), check=True, env=os.environ) # export
+ yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
+ return f, None
+
+ @try_export
+ def _export_paddle(self, prefix=colorstr('PaddlePaddle:')):
+ # YOLOv8 Paddle export
+ check_requirements(('paddlepaddle', 'x2paddle'))
+ import x2paddle # noqa
+ from x2paddle.convert import pytorch2paddle # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
+ f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
+
+ pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
+ yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
+ return f, None
+
+ @try_export
+ def _export_coreml(self, prefix=colorstr('CoreML:')):
+ # YOLOv8 CoreML export
+ check_requirements('coremltools>=6.0')
+ import coremltools as ct # noqa
+
+ class iOSModel(torch.nn.Module):
+ # Wrap an Ultralytics YOLO model for iOS export
+ def __init__(self, model, im):
+ super().__init__()
+ b, c, h, w = im.shape # batch, channel, height, width
+ self.model = model
+ self.nc = len(model.names) # number of classes
+ if w == h:
+ self.normalize = 1.0 / w # scalar
+ else:
+ self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
+
+ def forward(self, x):
+ xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
+ return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
+
+ LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
+ f = self.file.with_suffix('.mlmodel')
+
+ model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model
+ ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model
+ ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])])
+ bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
+ if bits < 32:
+ if MACOS: # quantization only supported on macOS
+ ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+ else:
+ LOGGER.info(f'{prefix} quantization only supported on macOS, skipping...')
+ if self.args.nms:
+ ct_model = self._pipeline_coreml(ct_model)
+
+ ct_model.save(str(f))
+ return f, ct_model
+
+ @try_export
+ def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+ # YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt
+ assert self.im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `device==0`'
+ try:
+ import tensorrt as trt # noqa
+ except ImportError:
+ if platform.system() == 'Linux':
+ check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
+ import tensorrt as trt # noqa
+
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
+ self._export_onnx()
+ onnx = self.file.with_suffix('.onnx')
+
+ LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
+ assert onnx.exists(), f'failed to export ONNX file: {onnx}'
+ f = self.file.with_suffix('.engine') # TensorRT engine file
+ logger = trt.Logger(trt.Logger.INFO)
+ if verbose:
+ logger.min_severity = trt.Logger.Severity.VERBOSE
+
+ builder = trt.Builder(logger)
+ config = builder.create_builder_config()
+ config.max_workspace_size = workspace * 1 << 30
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
+
+ flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+ network = builder.create_network(flag)
+ parser = trt.OnnxParser(network, logger)
+ if not parser.parse_from_file(str(onnx)):
+ raise RuntimeError(f'failed to load ONNX file: {onnx}')
+
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
+ for inp in inputs:
+ LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
+ for out in outputs:
+ LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
+
+ if self.args.dynamic:
+ shape = self.im.shape
+ if shape[0] <= 1:
+ LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
+ profile = builder.create_optimization_profile()
+ for inp in inputs:
+ profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
+ config.add_optimization_profile(profile)
+
+ LOGGER.info(
+ f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
+ if builder.platform_has_fast_fp16 and self.args.half:
+ config.set_flag(trt.BuilderFlag.FP16)
+ with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+ t.write(engine.serialize())
+ return f, None
+
+ @try_export
+ def _export_saved_model(self,
+ nms=False,
+ agnostic_nms=False,
+ topk_per_class=100,
+ topk_all=100,
+ iou_thres=0.45,
+ conf_thres=0.25,
+ prefix=colorstr('TensorFlow SavedModel:')):
+
+ # YOLOv8 TensorFlow SavedModel export
+ try:
+ import tensorflow as tf # noqa
+ except ImportError:
+ check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
+ import tensorflow as tf # noqa
+ check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"),
+ cmds="--extra-index-url https://pypi.ngc.nvidia.com ")
+
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+ f = str(self.file).replace(self.file.suffix, '_saved_model')
+
+ # Export to ONNX
+ self._export_onnx()
+ onnx = self.file.with_suffix('.onnx')
+
+ # Export to TF SavedModel
+ subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
+
+ # Load saved_model
+ keras_model = tf.saved_model.load(f, tags=None, options=None)
+
+ return f, keras_model
+
+ @try_export
+ def _export_saved_model_OLD(self,
+ nms=False,
+ agnostic_nms=False,
+ topk_per_class=100,
+ topk_all=100,
+ iou_thres=0.45,
+ conf_thres=0.25,
+ prefix=colorstr('TensorFlow SavedModel:')):
+ # YOLOv8 TensorFlow SavedModel export
+ try:
+ import tensorflow as tf # noqa
+ except ImportError:
+ check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
+ import tensorflow as tf # noqa
+ # from models.tf import TFModel
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+ f = str(self.file).replace(self.file.suffix, '_saved_model')
+ batch_size, ch, *imgsz = list(self.im.shape) # BCHW
+
+ tf_models = None # TODO: no TF modules available
+ tf_model = tf_models.TFModel(cfg=self.model.yaml, model=self.model.cpu(), nc=self.model.nc, imgsz=imgsz)
+ im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
+ _ = tf_model.predict(im, nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
+ inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if self.args.dynamic else batch_size)
+ outputs = tf_model.predict(inputs, nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
+ keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
+ keras_model.trainable = False
+ keras_model.summary()
+ if self.args.keras:
+ keras_model.save(f, save_format='tf')
+ else:
+ spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
+ m = tf.function(lambda x: keras_model(x)) # full model
+ m = m.get_concrete_function(spec)
+ frozen_func = convert_variables_to_constants_v2(m)
+ tfm = tf.Module()
+ tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if nms else frozen_func(x), [spec])
+ tfm.__call__(im)
+ tf.saved_model.save(tfm,
+ f,
+ options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
+ if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
+ return f, keras_model
+
+ @try_export
+ def _export_pb(self, keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
+ # YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
+ import tensorflow as tf # noqa
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+ f = file.with_suffix('.pb')
+
+ m = tf.function(lambda x: keras_model(x)) # full model
+ m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+ frozen_func = convert_variables_to_constants_v2(m)
+ frozen_func.graph.as_graph_def()
+ tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
+ return f, None
+
+ @try_export
+ def _export_tflite(self, keras_model, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+ # YOLOv8 TensorFlow Lite export
+ import tensorflow as tf # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+ batch_size, ch, *imgsz = list(self.im.shape) # BCHW
+ f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
+
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.target_spec.supported_types = [tf.float16]
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ if int8:
+
+ def representative_dataset_gen(dataset, n_images=100):
+ # Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
+ for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
+ im = np.transpose(img, [1, 2, 0])
+ im = np.expand_dims(im, axis=0).astype(np.float32)
+ im /= 255
+ yield [im]
+ if n >= n_images:
+ break
+
+ dataset = LoadImages(check_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
+ converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.target_spec.supported_types = []
+ converter.inference_input_type = tf.uint8 # or tf.int8
+ converter.inference_output_type = tf.uint8 # or tf.int8
+ converter.experimental_new_quantizer = True
+ f = str(self.file).replace(self.file.suffix, '-int8.tflite')
+ if nms or agnostic_nms:
+ converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
+
+ tflite_model = converter.convert()
+ open(f, "wb").write(tflite_model)
+ return f, None
+
+ @try_export
+ def _export_edgetpu(self, prefix=colorstr('Edge TPU:')):
+ # YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
+ cmd = 'edgetpu_compiler --version'
+ help_url = 'https://coral.ai/docs/edgetpu/compiler/'
+ assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
+ if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
+ LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
+ sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
+ for c in (
+ 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' # no comma
+ 'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
+ 'sudo apt-get update',
+ 'sudo apt-get install edgetpu-compiler'):
+ subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
+
+ LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
+ f = str(self.file).replace(self.file.suffix, '-int8_edgetpu.tflite') # Edge TPU model
+ f_tfl = str(self.file).replace(self.file.suffix, '-int8.tflite') # TFLite model
+
+ cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {f_tfl}"
+ subprocess.run(cmd.split(), check=True)
+ return f, None
+
+ @try_export
+ def _export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
+ # YOLOv8 TensorFlow.js export
+ check_requirements('tensorflowjs')
+ import tensorflowjs as tfjs # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
+ f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
+ f_pb = self.file.with_suffix('.pb') # *.pb path
+ f_json = Path(f) / 'model.json' # *.json path
+
+ cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
+ f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
+ subprocess.run(cmd.split())
+
+ # Read the JSON first: opening the file in 'w' mode truncates it, so read_text() must happen before the write
+ subst = re.sub(
+ r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
+ r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
+ r'"Identity_1": {"name": "Identity_1"}, '
+ r'"Identity_2": {"name": "Identity_2"}, '
+ r'"Identity_3": {"name": "Identity_3"}}}', f_json.read_text())
+ with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
+ j.write(subst)
+ return f, None
+
+ def _add_tflite_metadata(self, file, num_outputs):
+ # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
+ with contextlib.suppress(ImportError):
+ # check_requirements('tflite_support')
+ from tflite_support import flatbuffers # noqa
+ from tflite_support import metadata as _metadata # noqa
+ from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
+
+ tmp_file = Path('/tmp/meta.txt')
+ with open(tmp_file, 'w') as meta_f:
+ meta_f.write(str(self.metadata))
+
+ model_meta = _metadata_fb.ModelMetadataT()
+ label_file = _metadata_fb.AssociatedFileT()
+ label_file.name = tmp_file.name
+ model_meta.associatedFiles = [label_file]
+
+ subgraph = _metadata_fb.SubGraphMetadataT()
+ subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
+ subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
+ model_meta.subgraphMetadata = [subgraph]
+
+ b = flatbuffers.Builder(0)
+ b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
+ metadata_buf = b.Output()
+
+ populator = _metadata.MetadataPopulator.with_model_file(file)
+ populator.load_metadata_buffer(metadata_buf)
+ populator.load_associated_files([str(tmp_file)])
+ populator.populate()
+ tmp_file.unlink()
+
+ def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
+ # YOLOv8 CoreML pipeline
+ import coremltools as ct # noqa
+
+ LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
+ batch_size, ch, h, w = list(self.im.shape) # BCHW
+
+ # Output shapes
+ spec = model.get_spec()
+ out0, out1 = iter(spec.description.output)
+ if MACOS:
+ from PIL import Image
+ img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
+ # img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
+ out = model.predict({'image': img})
+ out0_shape = out[out0.name].shape
+ out1_shape = out[out1.name].shape
+ else: # linux and windows can not run model.predict(), get sizes from pytorch output y
+ out0_shape = self.output_shape[1], self.output_shape[2] - 5 # (3780, 80)
+ out1_shape = self.output_shape[1], 4 # (3780, 4)
+
+ # Checks
+ names = self.metadata['names']
+ nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
+ na, nc = out0_shape
+ # na, nc = out0.type.multiArrayType.shape # number anchors, classes
+ assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
+
+ # Define output shapes (missing)
+ out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
+ out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
+ # spec.neuralNetwork.preprocessing[0].featureName = '0'
+
+ # Flexible input shapes
+ # from coremltools.models.neural_network import flexible_shape_utils
+ # s = [] # shapes
+ # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
+ # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384)) # (height, width)
+ # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
+ # r = flexible_shape_utils.NeuralNetworkImageSizeRange() # shape ranges
+ # r.add_height_range((192, 640))
+ # r.add_width_range((192, 640))
+ # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
+
+ # Print
+ print(spec.description)
+
+ # Model from spec
+ model = ct.models.MLModel(spec)
+
+ # 3. Create NMS protobuf
+ nms_spec = ct.proto.Model_pb2.Model()
+ nms_spec.specificationVersion = 5
+ for i in range(2):
+ decoder_output = model._spec.description.output[i].SerializeToString()
+ nms_spec.description.input.add()
+ nms_spec.description.input[i].ParseFromString(decoder_output)
+ nms_spec.description.output.add()
+ nms_spec.description.output[i].ParseFromString(decoder_output)
+
+ nms_spec.description.output[0].name = 'confidence'
+ nms_spec.description.output[1].name = 'coordinates'
+
+ output_sizes = [nc, 4]
+ for i in range(2):
+ ma_type = nms_spec.description.output[i].type.multiArrayType
+ ma_type.shapeRange.sizeRanges.add()
+ ma_type.shapeRange.sizeRanges[0].lowerBound = 0
+ ma_type.shapeRange.sizeRanges[0].upperBound = -1
+ ma_type.shapeRange.sizeRanges.add()
+ ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
+ ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
+ del ma_type.shape[:]
+
+ nms = nms_spec.nonMaximumSuppression
+ nms.confidenceInputFeatureName = out0.name # 1x507x80
+ nms.coordinatesInputFeatureName = out1.name # 1x507x4
+ nms.confidenceOutputFeatureName = 'confidence'
+ nms.coordinatesOutputFeatureName = 'coordinates'
+ nms.iouThresholdInputFeatureName = 'iouThreshold'
+ nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
+ nms.iouThreshold = 0.45
+ nms.confidenceThreshold = 0.25
+ nms.pickTop.perClass = True
+ nms.stringClassLabels.vector.extend(names.values())
+ nms_model = ct.models.MLModel(nms_spec)
+
+ # 4. Pipeline models together
+ pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
+ ('iouThreshold', ct.models.datatypes.Double()),
+ ('confidenceThreshold', ct.models.datatypes.Double())],
+ output_features=['confidence', 'coordinates'])
+ pipeline.add_model(model)
+ pipeline.add_model(nms_model)
+
+ # Correct datatypes
+ pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
+ pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
+ pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
+
+ # Update metadata
+ pipeline.spec.specificationVersion = 5
+ pipeline.spec.description.metadata.versionString = f'Ultralytics YOLOv{ultralytics.__version__}'
+ pipeline.spec.description.metadata.shortDescription = f'Ultralytics {self.pretty_name} CoreML model'
+ pipeline.spec.description.metadata.author = 'Ultralytics (https://ultralytics.com)'
+ pipeline.spec.description.metadata.license = 'GPL-3.0 license (https://ultralytics.com/license)'
+ pipeline.spec.description.metadata.userDefined.update({
+ 'IoU threshold': str(nms.iouThreshold),
+ 'Confidence threshold': str(nms.confidenceThreshold)})
+
+ # Save the model
+ model = ct.models.MLModel(pipeline.spec)
+ model.input_description['image'] = 'Input image'
+ model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
+ model.input_description['confidenceThreshold'] = \
+ f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
+ model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
+ model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
+ LOGGER.info(f'{prefix} pipeline success')
+ return model
+
+ def run_callbacks(self, event: str):
+ for callback in self.callbacks.get(event, []):
+ callback(self)
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
+def export(cfg):
+ cfg.model = cfg.model or "yolov8n.yaml"
+ cfg.format = cfg.format or "torchscript"
+
+ # exporter = Exporter(cfg)
+ #
+ # model = None
+ # if isinstance(cfg.model, (str, Path)):
+ # if Path(cfg.model).suffix == '.yaml':
+ # model = DetectionModel(cfg.model)
+ # elif Path(cfg.model).suffix == '.pt':
+ # model = attempt_load_weights(cfg.model, fuse=True)
+ # else:
+ # TypeError(f'Unsupported model type {cfg.model}')
+ # exporter(model=model)
+
+ from ultralytics import YOLO
+ model = YOLO(cfg.model)
+ model.export(**cfg)
+
+
+if __name__ == "__main__":
+ """
+ CLI:
+ yolo mode=export model=yolov8n.yaml format=onnx
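+
+ Additional illustrative examples (yolov8n.pt is a placeholder checkpoint; the TensorRT path assumes a CUDA device):
+ yolo mode=export model=yolov8n.pt format=onnx
+ yolo mode=export model=yolov8n.pt format=engine device=0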
+ """
+ export()
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6e9a5a026a5169a37770c304a0d920ec7be2a0
--- /dev/null
+++ b/ultralytics/yolo/engine/model.py
@@ -0,0 +1,221 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from pathlib import Path
+
+from ultralytics import yolo # noqa
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
+from ultralytics.yolo.configs import get_config
+from ultralytics.yolo.engine.exporter import Exporter
+from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
+from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
+from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
+
+# Map head to model, trainer, validator, and predictor classes
+MODEL_MAP = {
+ "classify": [
+ ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
+ 'yolo.TYPE.classify.ClassificationPredictor'],
+ "detect": [
+ DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
+ 'yolo.TYPE.detect.DetectionPredictor'],
+ "segment": [
+ SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
+ 'yolo.TYPE.segment.SegmentationPredictor']}
+
+
+class YOLO:
+ """
+ YOLO
+
+ A Python interface that emulates model-like behaviour by wrapping trainers.
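+
+ Example (illustrative sketch; 'yolov8n.pt' and 'bus.jpg' below are placeholder files):
+ >>> from ultralytics import YOLO
+ >>> model = YOLO("yolov8n.pt") # load a pretrained checkpoint
+ >>> results = model("bus.jpg") # __call__ forwards to predict()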
+ """
+
+ def __init__(self, model='yolov8n.yaml', type="v8") -> None:
+ """
+ > Initializes the YOLO object.
+
+ Args:
+ model (str, Path): model to load or create
+ type (str): Type/version of models to use. Defaults to "v8".
+ """
+ self.type = type
+ self.ModelClass = None # model class
+ self.TrainerClass = None # trainer class
+ self.ValidatorClass = None # validator class
+ self.PredictorClass = None # predictor class
+ self.model = None # model object
+ self.trainer = None # trainer object
+ self.task = None # task type
+ self.ckpt = None # if loaded from *.pt
+ self.cfg = None # if loaded from *.yaml
+ self.ckpt_path = None
+ self.overrides = {} # overrides for trainer object
+
+ # Load or create new YOLO model
+ {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
+
+ def __call__(self, source, **kwargs):
+ return self.predict(source, **kwargs)
+
+ def _new(self, cfg: str, verbose=True):
+ """
+ > Initializes a new model and infers the task type from the model definitions.
+
+ Args:
+ cfg (str): model configuration file
+ verbose (bool): display model info on load
+ """
+ cfg = check_yaml(cfg) # check YAML
+ cfg_dict = yaml_load(cfg, append_filename=True) # model dict
+ self.task = guess_task_from_head(cfg_dict["head"][-1][-2])
+ self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
+ self._guess_ops_from_task(self.task)
+ self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
+ self.cfg = cfg
+
+ def _load(self, weights: str):
+ """
+ > Initializes a new model and infers the task type from the model head.
+
+ Args:
+ weights (str): model checkpoint to be loaded
+ """
+ self.model, self.ckpt = attempt_load_one_weight(weights)
+ self.ckpt_path = weights
+ self.task = self.model.args["task"]
+ self.overrides = self.model.args
+ self._reset_ckpt_args(self.overrides)
+ self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
+ self._guess_ops_from_task(self.task)
+
+ def reset(self):
+ """
+ > Resets the model modules.
+ """
+ for m in self.model.modules():
+ if hasattr(m, 'reset_parameters'):
+ m.reset_parameters()
+ for p in self.model.parameters():
+ p.requires_grad = True
+
+ def info(self, verbose=False):
+ """
+ > Logs model info.
+
+ Args:
+ verbose (bool): Controls verbosity.
+ """
+ self.model.info(verbose=verbose)
+
+ def fuse(self):
+ self.model.fuse()
+
+ @smart_inference_mode()
+ def predict(self, source, **kwargs):
+ """
+ > Performs prediction on the given source and returns the results.
+
+ Args:
+ source (str): Accepts all source types accepted by yolo
+ **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
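+
+ Example (illustrative sketch; 'bus.jpg' is a placeholder image path):
+ >>> model = YOLO("yolov8n.pt")
+ >>> results = model.predict(source="bus.jpg", conf=0.5, save=True)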
+ """
+ overrides = self.overrides.copy()
+ overrides["conf"] = 0.25
+ overrides.update(kwargs)
+ overrides["mode"] = "predict"
+ overrides["save"] = kwargs.get("save", False) # not save files by default
+ predictor = self.PredictorClass(overrides=overrides)
+
+ predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
+ predictor.setup(model=self.model, source=source)
+ return predictor()
+
+ @smart_inference_mode()
+ def val(self, data=None, **kwargs):
+ """
+ > Validates the model on a given dataset.
+
+ Args:
+ data (str): The dataset to validate on. Accepts all formats accepted by yolo
+ **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
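+
+ Example (illustrative sketch; 'coco128.yaml' is assumed to be an available dataset config):
+ >>> model = YOLO("yolov8n.pt")
+ >>> model.val(data="coco128.yaml")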
+ """
+ overrides = self.overrides.copy()
+ overrides.update(kwargs)
+ overrides["mode"] = "val"
+ args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
+ args.data = data or args.data
+ args.task = self.task
+
+ validator = self.ValidatorClass(args=args)
+ validator(model=self.model)
+
+ @smart_inference_mode()
+ def export(self, **kwargs):
+ """
+ > Exports the model to the requested format.
+
+ Args:
+ **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs
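+
+ Example (illustrative sketch):
+ >>> model = YOLO("yolov8n.pt")
+ >>> model.export(format="onnx")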
+ """
+
+ overrides = self.overrides.copy()
+ overrides.update(kwargs)
+ args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
+ args.task = self.task
+
+ exporter = Exporter(overrides=args)
+ exporter(model=self.model)
+
+ def train(self, **kwargs):
+ """
+ > Trains the model on a given dataset.
+
+ Args:
+ **kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
+ You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
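+
+ Example (illustrative sketch; 'coco128.yaml' is assumed to be an available dataset config):
+ >>> model = YOLO("yolov8n.yaml")
+ >>> model.train(data="coco128.yaml", epochs=3, imgsz=640)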
+ """
+ overrides = self.overrides.copy()
+ overrides.update(kwargs)
+ if kwargs.get("cfg"):
+ LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
+ overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
+ overrides["task"] = self.task
+ overrides["mode"] = "train"
+ if not overrides.get("data"):
+ raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
+ if overrides.get("resume"):
+ overrides["resume"] = self.ckpt_path
+
+ self.trainer = self.TrainerClass(overrides=overrides)
+ if not overrides.get("resume"): # manually set model only if not resuming
+ self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
+ self.model = self.trainer.model
+ self.trainer.train()
+
+ def to(self, device):
+ """
+ > Sends the model to the given device.
+
+ Args:
+ device (str): device
+ """
+ self.model.to(device)
+
+ def _guess_ops_from_task(self, task):
+ model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
+ # warning: eval is unsafe. Use with caution
+ trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
+ validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
+ predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
+
+ return model_class, trainer_class, validator_class, predictor_class
+
+ @staticmethod
+ def _reset_ckpt_args(args):
+ args.pop("device", None)
+ args.pop("project", None)
+ args.pop("name", None)
+ args.pop("batch", None)
+ args.pop("epochs", None)
+ args.pop("cache", None)
+ args.pop("save_json", None)
diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..efbd78fb237a27f973f1adf60455b21daefb84f3
--- /dev/null
+++ b/ultralytics/yolo/engine/predictor.py
@@ -0,0 +1,245 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
+Usage - sources:
+ $ yolo task=... mode=predict model=s.pt --source 0 # webcam
+ img.jpg # image
+ vid.mp4 # video
+ screen # screenshot
+ path/ # directory
+ list.txt # list of images
+ list.streams # list of streams
+ 'path/*.jpg' # glob
+ 'https://youtu.be/Zgi9g1ksQHc' # YouTube
+ 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
+Usage - formats:
+ $ yolo task=... mode=predict --weights yolov8n.pt # PyTorch
+ yolov8n.torchscript # TorchScript
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
+ yolov8n_openvino_model # OpenVINO
+ yolov8n.engine # TensorRT
+ yolov8n.mlmodel # CoreML (macOS-only)
+ yolov8n_saved_model # TensorFlow SavedModel
+ yolov8n.pb # TensorFlow GraphDef
+ yolov8n.tflite # TensorFlow Lite
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
+ yolov8n_paddle_model # PaddlePaddle
+ """
+import platform
+from collections import defaultdict
+from pathlib import Path
+
+import cv2
+
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.yolo.configs import get_config
+from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
+from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
+from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops
+from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
+from ultralytics.yolo.utils.files import increment_path
+from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
+
+
+class BasePredictor:
+ """
+ BasePredictor
+
+ A base class for creating predictors.
+
+ Attributes:
+ args (OmegaConf): Configuration for the predictor.
+ save_dir (Path): Directory to save results.
+ done_setup (bool): Whether the predictor has finished setup.
+ model (nn.Module): Model used for prediction.
+ data (dict): Data configuration.
+ device (torch.device): Device used for prediction.
+ dataset (Dataset): Dataset used for prediction.
+ vid_path (str): Path to video file.
+ vid_writer (cv2.VideoWriter): Video writer for saving video output.
+ annotator (Annotator): Annotator used for prediction.
+ data_path (str): Path to data.
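+
+ Example (illustrative sketch; assumes the v8 DetectionPredictor subclass and a local 'bus.jpg' image):
+ >>> from ultralytics.yolo.v8.detect import DetectionPredictor
+ >>> predictor = DetectionPredictor(overrides={"model": "yolov8n.pt", "source": "bus.jpg"})
+ >>> predictor() # runs setup() and then prediction on the configured source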
+ """
+
+ def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+ """
+ Initializes the BasePredictor class.
+
+ Args:
+ config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+ overrides (dict, optional): Configuration overrides. Defaults to None.
+ """
+ if overrides is None:
+ overrides = {}
+ self.args = get_config(config, overrides)
+ project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
+ name = self.args.name or f"{self.args.mode}"
+ self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
+ if self.args.save:
+ (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+ if self.args.conf is None:
+ self.args.conf = 0.25 # default conf=0.25
+ self.done_setup = False
+
+ # Usable if setup is done
+ self.model = None
+ self.data = self.args.data # data_dict
+ self.device = None
+ self.dataset = None
+ self.vid_path, self.vid_writer = None, None
+ self.annotator = None
+ self.data_path = None
+ self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
+ callbacks.add_integration_callbacks(self)
+
+ def preprocess(self, img):
+ pass
+
+ def get_annotator(self, img):
+ raise NotImplementedError("get_annotator function needs to be implemented")
+
+ def write_results(self, pred, batch, print_string):
+ raise NotImplementedError("write_results function needs to be implemented")
+
+ def postprocess(self, preds, img, orig_img):
+ return preds
+
+ def setup(self, source=None, model=None):
+ # source
+ source = str(source if source is not None else self.args.source)
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+ webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
+ screenshot = source.lower().startswith('screen')
+ if is_url and is_file:
+ source = check_file(source) # download
+
+ # model
+ device = select_device(self.args.device)
+ model = model or self.args.model
+ self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
+ model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
+ stride, pt = model.stride, model.pt
+ imgsz = check_imgsz(self.args.imgsz, stride=stride) # check image size
+
+ # Dataloader
+ bs = 1 # batch_size
+ if webcam:
+ self.args.show = check_imshow(warn=True)
+ self.dataset = LoadStreams(source,
+ imgsz=imgsz,
+ stride=stride,
+ auto=pt,
+ transforms=getattr(model.model, 'transforms', None),
+ vid_stride=self.args.vid_stride)
+ bs = len(self.dataset)
+ elif screenshot:
+ self.dataset = LoadScreenshots(source,
+ imgsz=imgsz,
+ stride=stride,
+ auto=pt,
+ transforms=getattr(model.model, 'transforms', None))
+ else:
+ self.dataset = LoadImages(source,
+ imgsz=imgsz,
+ stride=stride,
+ auto=pt,
+ transforms=getattr(model.model, 'transforms', None),
+ vid_stride=self.args.vid_stride)
+ self.vid_path, self.vid_writer = [None] * bs, [None] * bs
+ model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
+
+ self.model = model
+ self.webcam = webcam
+ self.screenshot = screenshot
+ self.imgsz = imgsz
+ self.done_setup = True
+ self.device = device
+
+ return model
+
+ @smart_inference_mode()
+ def __call__(self, source=None, model=None):
+ self.run_callbacks("on_predict_start")
+ model = self.model if self.done_setup else self.setup(source, model)
+ model.eval()
+ self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
+ self.all_outputs = []
+ for batch in self.dataset:
+ self.run_callbacks("on_predict_batch_start")
+ path, im, im0s, vid_cap, s = batch
+ visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
+ with self.dt[0]:
+ im = self.preprocess(im)
+ if len(im.shape) == 3:
+ im = im[None] # expand for batch dim
+
+ # Inference
+ with self.dt[1]:
+ preds = model(im, augment=self.args.augment, visualize=visualize)
+
+ # postprocess
+ with self.dt[2]:
+ preds = self.postprocess(preds, im, im0s)
+
+ for i in range(len(im)):
+ if self.webcam:
+ path, im0s = path[i], im0s[i]
+ p = Path(path)
+ s += self.write_results(i, preds, (p, im, im0s))
+
+ if self.args.show:
+ self.show(p)
+
+ if self.args.save:
+ self.save_preds(vid_cap, i, str(self.save_dir / p.name))
+
+ # Print time (inference-only)
+ LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
+
+ self.run_callbacks("on_predict_batch_end")
+
+ # Print results
+ t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
+ LOGGER.info(
+ f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}'
+ % t)
+ if self.args.save_txt or self.args.save:
+ s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
+ LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
+
+ self.run_callbacks("on_predict_end")
+ return self.all_outputs
+
+ def show(self, p):
+ im0 = self.annotator.result()
+ if platform.system() == 'Linux' and p not in self.windows:
+ self.windows.append(p)
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
+ cv2.imshow(str(p), im0)
+ cv2.waitKey(1) # 1 millisecond
+
+ def save_preds(self, vid_cap, idx, save_path):
+ im0 = self.annotator.result()
+ # save imgs
+ if self.dataset.mode == 'image':
+ cv2.imwrite(save_path, im0)
+ else: # 'video' or 'stream'
+ if self.vid_path[idx] != save_path: # new video
+ self.vid_path[idx] = save_path
+ if isinstance(self.vid_writer[idx], cv2.VideoWriter):
+ self.vid_writer[idx].release() # release previous video writer
+ if vid_cap: # video
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ else: # stream
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
+ self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+ self.vid_writer[idx].write(im0)
+
+ def run_callbacks(self, event: str):
+ for callback in self.callbacks.get(event, []):
+ callback(self)
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6be04703332a768e9ab1680c2f7de3aa991fcc4
--- /dev/null
+++ b/ultralytics/yolo/engine/trainer.py
@@ -0,0 +1,573 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Simple training loop; boilerplate that could apply to any arbitrary neural network.
+"""
+
+import os
+import subprocess
+import time
+from collections import defaultdict
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from omegaconf import OmegaConf # noqa
+from omegaconf import open_dict
+from torch.cuda import amp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import lr_scheduler
+from tqdm import tqdm
+
+import ultralytics.yolo.utils as utils
+from ultralytics import __version__
+from ultralytics.nn.tasks import attempt_load_one_weight
+from ultralytics.yolo.configs import get_config
+from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
+from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
+ yaml_save)
+from ultralytics.yolo.utils.autobatch import check_train_batch_size
+from ultralytics.yolo.utils.checks import check_file, print_args
+from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
+from ultralytics.yolo.utils.files import get_latest_run, increment_path
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
+
+
+class BaseTrainer:
+ """
+ BaseTrainer
+
+ > A base class for creating trainers.
+
+ Attributes:
+ args (OmegaConf): Configuration for the trainer.
+ check_resume (method): Method to check if training should be resumed from a saved checkpoint.
+ console (logging.Logger): Logger instance.
+ validator (BaseValidator): Validator instance.
+ model (nn.Module): Model instance.
+ callbacks (defaultdict): Dictionary of callbacks.
+ save_dir (Path): Directory to save results.
+ wdir (Path): Directory to save weights.
+ last (Path): Path to last checkpoint.
+ best (Path): Path to best checkpoint.
+ batch_size (int): Batch size for training.
+ epochs (int): Number of epochs to train for.
+ start_epoch (int): Starting epoch for training.
+ device (torch.device): Device to use for training.
+ amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+ scaler (amp.GradScaler): Gradient scaler for AMP.
+ data (str): Path to data.
+ trainset (torch.utils.data.Dataset): Training dataset.
+ testset (torch.utils.data.Dataset): Testing dataset.
+ ema (nn.Module): EMA (Exponential Moving Average) of the model.
+ lf (nn.Module): Loss function.
+ scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+ best_fitness (float): The best fitness value achieved.
+ fitness (float): Current fitness value.
+ loss (float): Current loss value.
+ tloss (float): Total loss value.
+ loss_names (list): List of loss names.
+ csv (Path): Path to results CSV file.
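+
+ Example (illustrative sketch; assumes the v8 DetectionTrainer subclass and a 'coco128.yaml' dataset config):
+ >>> from ultralytics.yolo.v8.detect import DetectionTrainer
+ >>> trainer = DetectionTrainer(overrides={"model": "yolov8n.yaml", "data": "coco128.yaml", "epochs": 1})
+ >>> trainer.train()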
+ """
+
+ def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+ """
+ > Initializes the BaseTrainer class.
+
+ Args:
+ config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+ overrides (dict, optional): Configuration overrides. Defaults to None.
+ """
+ if overrides is None:
+ overrides = {}
+ self.args = get_config(config, overrides)
+ self.check_resume()
+ self.console = LOGGER
+ self.validator = None
+ self.model = None
+ self.callbacks = defaultdict(list)
+ init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
+
+ # Dirs
+ project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
+ name = self.args.name or f"{self.args.mode}"
+ self.save_dir = Path(
+ self.args.get(
+ "save_dir",
+ increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)))
+ self.wdir = self.save_dir / 'weights' # weights dir
+ if RANK in {-1, 0}:
+ self.wdir.mkdir(parents=True, exist_ok=True) # make dir
+ with open_dict(self.args):
+ self.args.save_dir = str(self.save_dir)
+ yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
+ self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
+
+ self.batch_size = self.args.batch
+ self.epochs = self.args.epochs
+ self.start_epoch = 0
+ if RANK == -1:
+ print_args(dict(self.args))
+
+ # Device
+ self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
+ self.amp = self.device.type != 'cpu'
+ self.scaler = amp.GradScaler(enabled=self.amp)
+ if self.device.type == 'cpu':
+ self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
+
+ # Model and Dataloaders.
+ self.model = self.args.model
+ self.data = self.args.data
+ if self.data.endswith(".yaml"):
+ self.data = check_dataset_yaml(self.data)
+ else:
+ self.data = check_dataset(self.data)
+ self.trainset, self.testset = self.get_dataset(self.data)
+ self.ema = None
+
+ # Optimization utils init
+ self.lf = None
+ self.scheduler = None
+
+ # Epoch level metrics
+ self.best_fitness = None
+ self.fitness = None
+ self.loss = None
+ self.tloss = None
+ self.loss_names = ['Loss']
+ self.csv = self.save_dir / 'results.csv'
+ self.plot_idx = [0, 1, 2]
+
+ # Callbacks
+ self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
+ if RANK in {0, -1}:
+ callbacks.add_integration_callbacks(self)
+
+ def add_callback(self, event: str, callback):
+ """
+ > Appends the given callback.
+ """
+ self.callbacks[event].append(callback)
+
+ def set_callback(self, event: str, callback):
+ """
+ > Overrides the existing callbacks with the given callback.
+ """
+ self.callbacks[event] = [callback]
+
+ def run_callbacks(self, event: str):
+ for callback in self.callbacks.get(event, []):
+ callback(self)
+
+ def train(self):
+ world_size = torch.cuda.device_count()
+ if world_size > 1 and "LOCAL_RANK" not in os.environ:
+ command = generate_ddp_command(world_size, self)
+ try:
+ subprocess.run(command)
+ except Exception as e:
+ self.console.warning(e) # LOGGER is not callable; log the DDP launch failure instead
+ finally:
+ ddp_cleanup(command, self)
+ else:
+ self._do_train(int(os.getenv("RANK", -1)), world_size)
+
+ def _setup_ddp(self, rank, world_size):
+ # os.environ['MASTER_ADDR'] = 'localhost'
+ # os.environ['MASTER_PORT'] = '9020'
+ torch.cuda.set_device(rank)
+ self.device = torch.device('cuda', rank)
+ self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
+ dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
+
+ def _setup_train(self, rank, world_size):
+ """
+ > Builds dataloaders and optimizer on correct rank process.
+ """
+ # model
+ self.run_callbacks("on_pretrain_routine_start")
+ ckpt = self.setup_model()
+ self.model = self.model.to(self.device)
+ self.set_model_attributes()
+ if world_size > 1:
+ self.model = DDP(self.model, device_ids=[rank])
+
+ # Batch size
+ if self.batch_size == -1:
+ if RANK == -1: # single-GPU only, estimate best batch size
+ self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
+ else:
+ raise SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
+ 'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
+
+ # Optimizer
+ self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
+ self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
+ self.optimizer = self.build_optimizer(model=self.model,
+ name=self.args.optimizer,
+ lr=self.args.lr0,
+ momentum=self.args.momentum,
+ decay=self.args.weight_decay)
+ # Scheduler
+ if self.args.cos_lr:
+ self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
+ else:
+ self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
+ self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+ self.scheduler.last_epoch = self.start_epoch - 1 # do not move
+
+ # dataloaders
+ batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
+ self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
+ if rank in {0, -1}:
+ self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
+ self.validator = self.get_validator()
+ metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
+ self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
+ self.ema = ModelEMA(self.model)
+ self.resume_training(ckpt)
+ self.run_callbacks("on_pretrain_routine_end")
+
+ def _do_train(self, rank=-1, world_size=1):
+ if world_size > 1:
+ self._setup_ddp(rank, world_size)
+
+ self._setup_train(rank, world_size)
+
+ self.epoch_time = None
+ self.epoch_time_start = time.time()
+ self.train_time_start = time.time()
+ nb = len(self.train_loader) # number of batches
+ nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
+ last_opt_step = -1
+ self.run_callbacks("on_train_start")
+ self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+ f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+ f"Logging results to {colorstr('bold', self.save_dir)}\n"
+ f"Starting training for {self.epochs} epochs...")
+ if self.args.close_mosaic:
+ base_idx = (self.epochs - self.args.close_mosaic) * nb
+ self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
+ for epoch in range(self.start_epoch, self.epochs):
+ self.epoch = epoch
+ self.run_callbacks("on_train_epoch_start")
+ self.model.train()
+ if rank != -1:
+ self.train_loader.sampler.set_epoch(epoch)
+ pbar = enumerate(self.train_loader)
+ # Update dataloader attributes (optional)
+ if epoch == (self.epochs - self.args.close_mosaic):
+ self.console.info("Closing dataloader mosaic")
+ if hasattr(self.train_loader.dataset, 'mosaic'):
+ self.train_loader.dataset.mosaic = False
+ if hasattr(self.train_loader.dataset, 'close_mosaic'):
+ self.train_loader.dataset.close_mosaic(hyp=self.args)
+
+ if rank in {-1, 0}:
+ self.console.info(self.progress_string())
+ pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
+ self.tloss = None
+ self.optimizer.zero_grad()
+ for i, batch in pbar:
+ self.run_callbacks("on_train_batch_start")
+ # Warmup
+ ni = i + nb * epoch
+ if ni <= nw:
+ xi = [0, nw] # x interp
+ self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
+ for j, x in enumerate(self.optimizer.param_groups):
+ # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+ x['lr'] = np.interp(
+ ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
+ if 'momentum' in x:
+ x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
+ # Forward
+ with torch.cuda.amp.autocast(self.amp):
+ batch = self.preprocess_batch(batch)
+ preds = self.model(batch["img"])
+ self.loss, self.loss_items = self.criterion(preds, batch)
+ if rank != -1:
+ self.loss *= world_size
+ self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
+ else self.loss_items
+
+ # Backward
+ self.scaler.scale(self.loss).backward()
+
+ # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+ if ni - last_opt_step >= self.accumulate:
+ self.optimizer_step()
+ last_opt_step = ni
+
+ # Log
+ mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
+ loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
+ losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
+ if rank in {-1, 0}:
+ pbar.set_description(
+ ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
+ (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
+ self.run_callbacks('on_batch_end')
+ if self.args.plots and ni in self.plot_idx:
+ self.plot_training_samples(batch, ni)
+
+ self.run_callbacks("on_train_batch_end")
+
+ self.lr = {f"lr/pg{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
+
+ self.scheduler.step()
+ self.run_callbacks("on_train_epoch_end")
+
+ if rank in {-1, 0}:
+
+ # Validation
+ self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
+ final_epoch = (epoch + 1 == self.epochs)
+ if self.args.val or final_epoch:
+ self.metrics, self.fitness = self.validate()
+ self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+
+ # Save model
+ if self.args.save or (epoch + 1 == self.epochs):
+ self.save_model()
+ self.run_callbacks('on_model_save')
+
+ tnow = time.time()
+ self.epoch_time = tnow - self.epoch_time_start
+ self.epoch_time_start = tnow
+ self.run_callbacks("on_fit_epoch_end")
+ # TODO: termination condition
+
+ if rank in {-1, 0}:
+ # Do final val with best.pt
+ self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
+ f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
+ self.final_eval()
+ if self.args.plots:
+ self.plot_metrics()
+ self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
+ self.run_callbacks('on_train_end')
+ torch.cuda.empty_cache()
+ self.run_callbacks('teardown')
+
+ def save_model(self):
+ ckpt = {
+ 'epoch': self.epoch,
+ 'best_fitness': self.best_fitness,
+ 'model': deepcopy(de_parallel(self.model)).half(),
+ 'ema': deepcopy(self.ema.ema).half(),
+ 'updates': self.ema.updates,
+ 'optimizer': self.optimizer.state_dict(),
+ 'train_args': self.args,
+ 'date': datetime.now().isoformat(),
+ 'version': __version__}
+
+ # Save last, best and delete
+ torch.save(ckpt, self.last)
+ if self.best_fitness == self.fitness:
+ torch.save(ckpt, self.best)
+ del ckpt
+
+ def get_dataset(self, data):
+ """
+ > Gets the train and validation paths from the data dict if they exist. Returns None if the data format is not recognized.
+ """
+ return data["train"], data.get("val") or data.get("test")
+
+ def setup_model(self):
+ """
+ > Loads, creates, or downloads the model for any task.
+ """
+ if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
+ return
+
+ model, weights = self.model, None
+ ckpt = None
+ if str(model).endswith(".pt"):
+ weights, ckpt = attempt_load_one_weight(model)
+ cfg = ckpt["model"].yaml
+ else:
+ cfg = model
+ self.model = self.get_model(cfg=cfg, weights=weights) # calls Model(cfg, weights)
+ return ckpt
+
+ def optimizer_step(self):
+ self.scaler.unscale_(self.optimizer) # unscale gradients
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ self.optimizer.zero_grad()
+ if self.ema:
+ self.ema.update(self.model)
+
+ def preprocess_batch(self, batch):
+ """
+ > Allows custom preprocessing of model inputs and ground truths depending on the task type.
+ """
+ return batch
+
+ def validate(self):
+ """
+ > Runs validation on the test set using self.validator. The returned dict is expected to contain the "fitness" key.
+ """
+ metrics = self.validator(self)
+ fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
+ if not self.best_fitness or self.best_fitness < fitness:
+ self.best_fitness = fitness
+ return metrics, fitness
+
+ def log(self, text, rank=-1):
+ """
+ > Logs the given text, but only for the main process (rank -1 or 0).
+
+ Args:
+ text (str): text to log
+ rank (int): process rank
+
+ """
+ if rank in {-1, 0}:
+ self.console.info(text)
+
+ def get_model(self, cfg=None, weights=None, verbose=True):
+ raise NotImplementedError("This task trainer doesn't support loading cfg files")
+
+ def get_validator(self):
+ raise NotImplementedError("get_validator function not implemented in trainer")
+
+ def get_dataloader(self, dataset_path, batch_size=16, rank=0):
+ """
+ > Returns a dataloader derived from torch.utils.data.DataLoader.
+ """
+ raise NotImplementedError("get_dataloader function not implemented in trainer")
+
+ def criterion(self, preds, batch):
+ """
+ > Returns loss and individual loss items as Tensor.
+ """
+ raise NotImplementedError("criterion function not implemented in trainer")
+
+ def label_loss_items(self, loss_items=None, prefix="train"):
+ """
+ Returns a loss dict with labelled training loss item tensors.
+ """
+ # Not needed for classification but necessary for segmentation & detection
+ return {"loss": loss_items} if loss_items is not None else ["loss"]
+
+ def set_model_attributes(self):
+ """
+ To set or update model parameters before training.
+ """
+ self.model.names = self.data["names"]
+
+ def build_targets(self, preds, targets):
+ pass
+
+ def progress_string(self):
+ return ""
+
+ # TODO: may need to put these following functions into callback
+ def plot_training_samples(self, batch, ni):
+ pass
+
+ def save_metrics(self, metrics):
+ keys, vals = list(metrics.keys()), list(metrics.values())
+ n = len(metrics) + 1 # number of cols
+ s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
+ with open(self.csv, 'a') as f:
+ f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
+
+ def plot_metrics(self):
+ pass
+
+ def final_eval(self):
+ for f in self.last, self.best:
+ if f.exists():
+ strip_optimizer(f) # strip optimizers
+ if f is self.best:
+ self.console.info(f'\nValidating {f}...')
+ self.validator.args.save_json = True
+ self.metrics = self.validator(model=f)
+ self.metrics.pop('fitness', None)
+ self.run_callbacks('on_fit_epoch_end')
+
+ def check_resume(self):
+ resume = self.args.resume
+ if resume:
+ last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
+ args_yaml = last.parent.parent / 'args.yaml' # train options yaml
+ if args_yaml.is_file():
+ args = get_config(args_yaml) # replace
+ args.model, resume = str(last), True # reinstate
+ self.args = args
+ self.resume = resume
+
+ def resume_training(self, ckpt):
+ if ckpt is None:
+ return
+ best_fitness = 0.0
+ start_epoch = ckpt['epoch'] + 1
+ if ckpt['optimizer'] is not None:
+ self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
+ best_fitness = ckpt['best_fitness']
+ if self.ema and ckpt.get('ema'):
+ self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
+ self.ema.updates = ckpt['updates']
+ if self.resume:
+ assert start_epoch > 0, \
+ f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
+ f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
+ LOGGER.info(
+ f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
+ if self.epochs < start_epoch:
+ LOGGER.info(
+ f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
+ self.epochs += ckpt['epoch'] # finetune additional epochs
+ self.best_fitness = best_fitness
+ self.start_epoch = start_epoch
+
+ @staticmethod
+ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+ """
+ > Builds an optimizer with the specified parameters and parameter groups.
+
+ Args:
+ model (nn.Module): model to optimize
+ name (str): name of the optimizer to use
+ lr (float): learning rate
+ momentum (float): momentum
+ decay (float): weight decay
+
+ Returns:
+ optimizer (torch.optim.Optimizer): the built optimizer
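+
+ Example (illustrative sketch; `model` is any nn.Module and the hyper-parameter values are placeholders):
+ >>> optimizer = BaseTrainer.build_optimizer(model, name='SGD', lr=0.01, momentum=0.937, decay=5e-4)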
+ """
+ g = [], [], [] # optimizer parameter groups
+ bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
+ for v in model.modules():
+ if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
+ g[2].append(v.bias)
+ if isinstance(v, bn): # weight (no decay)
+ g[1].append(v.weight)
+ elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
+ g[0].append(v.weight)
+
+ if name == 'Adam':
+ optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999)) # adjust beta1 to momentum
+ elif name == 'AdamW':
+ optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+ elif name == 'RMSProp':
+ optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+ elif name == 'SGD':
+ optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+ else:
+ raise NotImplementedError(f'Optimizer {name} not implemented.')
+
+ optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
+ optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
+ LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
+ f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
+ return optimizer
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ba0092101a075c7ee8252a78dd216919173991
--- /dev/null
+++ b/ultralytics/yolo/engine/validator.py
@@ -0,0 +1,224 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import json
+from collections import defaultdict
+from pathlib import Path
+
+import torch
+from omegaconf import OmegaConf # noqa
+from tqdm import tqdm
+
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
+from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
+from ultralytics.yolo.utils.checks import check_imgsz
+from ultralytics.yolo.utils.files import increment_path
+from ultralytics.yolo.utils.ops import Profile
+from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode
+
+
+class BaseValidator:
+ """
+ BaseValidator
+
+ A base class for creating validators.
+
+ Attributes:
+ dataloader (DataLoader): Dataloader to use for validation.
+ pbar (tqdm): Progress bar to update during validation.
+ logger (logging.Logger): Logger to use for validation.
+ args (OmegaConf): Configuration for the validator.
+ model (nn.Module): Model to validate.
+ data (dict): Data dictionary.
+ device (torch.device): Device to use for validation.
+ batch_i (int): Current batch index.
+ training (bool): Whether the model is in training mode.
+ speed (float): Batch processing speed in seconds.
+ jdict (dict): Dictionary to store validation results.
+ save_dir (Path): Directory to save results.
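+
+ Example (illustrative sketch; assumes the v8 DetectionValidator subclass, a 'yolov8n.pt' checkpoint and a 'coco128.yaml' dataset config):
+ >>> from ultralytics.yolo.configs import get_config
+ >>> from ultralytics.yolo.utils import DEFAULT_CONFIG
+ >>> from ultralytics.yolo.v8.detect import DetectionValidator
+ >>> args = get_config(DEFAULT_CONFIG, {"data": "coco128.yaml"})
+ >>> metrics = DetectionValidator(args=args)(model="yolov8n.pt")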
+ """
+
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
+ """
+ Initializes a BaseValidator instance.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+ save_dir (Path): Directory to save results.
+ pbar (tqdm.tqdm): Progress bar for displaying progress.
+ logger (logging.Logger): Logger to log messages.
+ args (OmegaConf): Configuration for the validator.
+ """
+ self.dataloader = dataloader
+ self.pbar = pbar
+ self.logger = logger or LOGGER
+ self.args = args or OmegaConf.load(DEFAULT_CONFIG)
+ self.model = None
+ self.data = None
+ self.device = None
+ self.batch_i = None
+ self.training = True
+ self.speed = None
+ self.jdict = None
+
+ project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
+ name = self.args.name or f"{self.args.mode}"
+ self.save_dir = save_dir or increment_path(Path(project) / name,
+ exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
+ (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+
+ if self.args.conf is None:
+ self.args.conf = 0.001 # default conf=0.001
+
+ self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
+
+ @smart_inference_mode()
+ def __call__(self, trainer=None, model=None):
+ """
+ Supports validation of a pre-trained model if passed or a model being trained
+ if trainer is passed (trainer gets priority).
+ """
+ self.training = trainer is not None
+ if self.training:
+ self.device = trainer.device
+ self.data = trainer.data
+ model = trainer.ema.ema or trainer.model
+ self.args.half = self.device.type != 'cpu' # force FP16 val during training
+ model = model.half() if self.args.half else model.float()
+ self.model = model
+ self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+ self.args.plots = trainer.epoch == trainer.epochs - 1 # always plot final epoch
+ model.eval()
+ else:
+ callbacks.add_integration_callbacks(self)
+ self.run_callbacks('on_val_start')
+ assert model is not None, "Either trainer or model is needed for validation"
+ self.device = select_device(self.args.device, self.args.batch)
+ self.args.half &= self.device.type != 'cpu'
+ model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
+ self.model = model
+ stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+ imgsz = check_imgsz(self.args.imgsz, stride=stride)
+ if engine:
+ self.args.batch = model.batch_size
+ else:
+ self.device = model.device
+ if not pt and not jit:
+ self.args.batch = 1 # export.py models default to batch-size 1
+ self.logger.info(
+ f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+
+ if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
+ self.data = check_dataset_yaml(self.args.data)
+ else:
+ self.data = check_dataset(self.args.data)
+
+ if self.device.type == 'cpu':
+ self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
+ self.dataloader = self.dataloader or \
+ self.get_dataloader(self.data.get("val") or self.data.get("test"), self.args.batch)
+
+ model.eval()
+ model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
+
+ dt = Profile(), Profile(), Profile(), Profile()
+ n_batches = len(self.dataloader)
+ desc = self.get_desc()
+ # NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
+ # which may affect classification task since this arg is in yolov5/classify/val.py.
+ # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
+ bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
+ self.init_metrics(de_parallel(model))
+ self.jdict = [] # empty before each val
+ for batch_i, batch in enumerate(bar):
+ self.run_callbacks('on_val_batch_start')
+ self.batch_i = batch_i
+ # pre-process
+ with dt[0]:
+ batch = self.preprocess(batch)
+
+ # inference
+ with dt[1]:
+ preds = model(batch["img"])
+
+ # loss
+ with dt[2]:
+ if self.training:
+ self.loss += trainer.criterion(preds, batch)[1]
+
+ # pre-process predictions
+ with dt[3]:
+ preds = self.postprocess(preds)
+
+ self.update_metrics(preds, batch)
+ if self.args.plots and batch_i < 3:
+ self.plot_val_samples(batch, batch_i)
+ self.plot_predictions(batch, preds, batch_i)
+
+ self.run_callbacks('on_val_batch_end')
+ stats = self.get_stats()
+ self.check_stats(stats)
+ self.print_results()
+ self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
+ self.run_callbacks('on_val_end')
+ if self.training:
+ model.float()
+ results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+ return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
+ else:
+ self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
+ self.speed)
+ if self.args.save_json and self.jdict:
+ with open(str(self.save_dir / "predictions.json"), 'w') as f:
+ self.logger.info(f"Saving {f.name}...")
+ json.dump(self.jdict, f) # flatten and save
+ stats = self.eval_json(stats) # update stats
+ return stats
+
+ def run_callbacks(self, event: str):
+ for callback in self.callbacks.get(event, []):
+ callback(self)
+
+ def get_dataloader(self, dataset_path, batch_size):
+ raise NotImplementedError("get_dataloader function not implemented for this validator")
+
+ def preprocess(self, batch):
+ return batch
+
+ def postprocess(self, preds):
+ return preds
+
+ def init_metrics(self, model):
+ pass
+
+ def update_metrics(self, preds, batch):
+ pass
+
+ def get_stats(self):
+ return {}
+
+ def check_stats(self, stats):
+ pass
+
+ def print_results(self):
+ pass
+
+ def get_desc(self):
+ pass
+
+ @property
+ def metric_keys(self):
+ return []
+
+ # TODO: may need to put these following functions into callback
+ def plot_val_samples(self, batch, ni):
+ pass
+
+ def plot_predictions(self, batch, preds, ni):
+ pass
+
+ def pred_to_json(self, preds, batch):
+ pass
+
+ def eval_json(self, stats):
+ pass
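+
+
+# Usage sketch (illustrative only; everything except the BaseValidator hook names is an assumption):
+# a task-specific validator overrides just the hooks it needs, e.g.
+#
+#   class MyValidator(BaseValidator):
+#       def get_dataloader(self, dataset_path, batch_size):
+#           return build_my_dataloader(dataset_path, batch_size)  # hypothetical helper
+#
+#       def preprocess(self, batch):
+#           batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
+#           return batch
+#
+#       def update_metrics(self, preds, batch):
+#           ...  # accumulate per-batch statistics
+#
+#       def get_stats(self):
+#           return {"metrics/placeholder": 0.0}
+#
+#   stats = MyValidator(args=my_args)(model="yolov8n.pt")  # standalone validation of a pretrained model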
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6747c5ca147b318fd7243ca7682aa104b4f9815
--- /dev/null
+++ b/ultralytics/yolo/utils/__init__.py
@@ -0,0 +1,411 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import contextlib
+import inspect
+import logging.config
+import os
+import platform
+import subprocess
+import sys
+import tempfile
+import threading
+import uuid
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+import yaml
+
+# Constants
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[2] # YOLO
+DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
+RANK = int(os.getenv('RANK', -1))
+NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
+AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
+FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
+VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
+TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
+LOGGING_NAME = 'yolov5'
+HELP_MSG = \
+ """
+ Usage examples for running YOLOv8:
+
+ 1. Install the ultralytics package:
+
+ pip install ultralytics
+
+ 2. Use the Python SDK:
+
+ from ultralytics import YOLO
+
+ model = YOLO('yolov8n.yaml') # build a new model from scratch
+ model = YOLO('yolov8n.pt') # load a pretrained model (recommended for best training results)
+ results = model.train(data='coco128.yaml') # train the model
+ results = model.val() # evaluate model performance on the validation set
+ results = model.predict(source='bus.jpg') # predict on an image
+ success = model.export(format='onnx') # export the model to ONNX format
+
+ 3. Use the command line interface (CLI):
+
+ yolo task=detect mode=train model=yolov8n.yaml args...
+ classify predict yolov8n-cls.yaml args...
+ segment val yolov8n-seg.yaml args...
+ export yolov8n.pt format=onnx args...
+
+ Docs: https://docs.ultralytics.com
+ Community: https://community.ultralytics.com
+ GitHub: https://github.com/ultralytics/ultralytics
+ """
+
+# Settings
+torch.set_printoptions(linewidth=320, precision=5, profile='long')
+np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
+pd.options.display.max_columns = 10
+cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
+os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
+os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
+
+# Default config dictionary
+with open(DEFAULT_CONFIG, errors='ignore') as f:
+ DEFAULT_CONFIG_DICT = yaml.safe_load(f)
+DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
+
+
+def is_colab():
+ """
+ Check if the current script is running inside a Google Colab notebook.
+
+ Returns:
+ bool: True if running inside a Colab notebook, False otherwise.
+ """
+ # Check if the google.colab module is present in sys.modules
+ return 'google.colab' in sys.modules
+
+
+def is_kaggle():
+ """
+ Check if the current script is running inside a Kaggle kernel.
+
+ Returns:
+ bool: True if running inside a Kaggle kernel, False otherwise.
+ """
+ return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
+
+
+def is_jupyter_notebook():
+ """
+ Check if the current script is running inside a Jupyter Notebook.
+ Verified on Colab, Jupyterlab, Kaggle, Paperspace.
+
+ Returns:
+ bool: True if running inside a Jupyter Notebook, False otherwise.
+ """
+ # Check if the get_ipython function exists
+ # (it does not exist when running as a standalone script)
+ try:
+ from IPython import get_ipython
+ return get_ipython() is not None
+ except ImportError:
+ return False
+
+
+def is_docker() -> bool:
+ """
+ Determine if the script is running inside a Docker container.
+
+ Returns:
+ bool: True if the script is running inside a Docker container, False otherwise.
+ """
+ file = Path('/proc/self/cgroup')
+ if file.exists():
+ with open(file) as f:
+ return 'docker' in f.read()
+ else:
+ return False
+
+
+def is_git_directory() -> bool:
+ """
+ Check if the current working directory is inside a git repository.
+
+ Returns:
+ bool: True if the current working directory is inside a git repository, False otherwise.
+ """
+ from git import Repo
+ try:
+ # Check if the current working directory is a git repository
+ Repo(search_parent_directories=True)
+ return True
+ except Exception:
+ return False
+
+
+def is_pip_package(filepath: str = __name__) -> bool:
+ """
+ Determines if the file at the given filepath is part of a pip package.
+
+ Args:
+ filepath (str): The filepath to check.
+
+ Returns:
+ bool: True if the file is part of a pip package, False otherwise.
+ """
+ import importlib.util
+
+ # Get the spec for the module
+ spec = importlib.util.find_spec(filepath)
+
+ # Return whether the spec is not None and the origin is not None (indicating it is a package)
+ return spec is not None and spec.origin is not None
+
+
+def is_dir_writeable(dir_path: str) -> bool:
+ """
+ Check if a directory is writeable.
+
+ Args:
+ dir_path (str): The path to the directory.
+
+ Returns:
+ bool: True if the directory is writeable, False otherwise.
+ """
+ try:
+ with tempfile.TemporaryFile(dir=dir_path):
+ pass
+ return True
+ except OSError:
+ return False
+
+
+def get_git_root_dir():
+ """
+ Determines whether the current file is part of a git repository and if so, returns the repository root directory.
+ If the current file is not part of a git repository, returns None.
+ """
+ try:
+ output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
+ return Path(output.stdout.strip().decode('utf-8')).parent # parent/.git
+ except subprocess.CalledProcessError:
+ return None
+
+
+def get_default_args(func):
+ # Get func() default arguments
+ signature = inspect.signature(func)
+ return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
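+
+# Quick illustration (parameters without defaults are skipped):
+#   get_default_args(lambda a, b=2, c='x': None) -> {'b': 2, 'c': 'x'}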
+
+
+def get_user_config_dir(sub_dir='Ultralytics'):
+ """
+ Get the user config directory.
+
+ Args:
+ sub_dir (str): The name of the subdirectory to create.
+
+ Returns:
+ Path: The path to the user config directory.
+ """
+ # Get the operating system name
+ os_name = platform.system()
+
+ # Return the appropriate config directory for each operating system
+ if os_name == 'Windows':
+ path = Path.home() / 'AppData' / 'Roaming' / sub_dir
+ elif os_name == 'Darwin': # macOS
+ path = Path.home() / 'Library' / 'Application Support' / sub_dir
+ elif os_name == 'Linux':
+ path = Path.home() / '.config' / sub_dir
+ else:
+ raise ValueError(f'Unsupported operating system: {os_name}')
+
+ # GCP and AWS lambda fix, only /tmp is writeable
+ if not is_dir_writeable(str(path.parent)):
+ path = Path('/tmp') / sub_dir
+
+ # Create the subdirectory if it does not exist
+ path.mkdir(parents=True, exist_ok=True)
+
+ return path
+
+
+USER_CONFIG_DIR = get_user_config_dir() # Ultralytics settings dir
+
+
+def emojis(string=''):
+ # Return platform-dependent emoji-safe version of string
+ return string.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else string
+
+
+def colorstr(*input):
+ # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
+ *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
+ colors = {
+ "black": "\033[30m", # basic colors
+ "red": "\033[31m",
+ "green": "\033[32m",
+ "yellow": "\033[33m",
+ "blue": "\033[34m",
+ "magenta": "\033[35m",
+ "cyan": "\033[36m",
+ "white": "\033[37m",
+ "bright_black": "\033[90m", # bright colors
+ "bright_red": "\033[91m",
+ "bright_green": "\033[92m",
+ "bright_yellow": "\033[93m",
+ "bright_blue": "\033[94m",
+ "bright_magenta": "\033[95m",
+ "bright_cyan": "\033[96m",
+ "bright_white": "\033[97m",
+ "end": "\033[0m", # misc
+ "bold": "\033[1m",
+ "underline": "\033[4m",}
+ return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
+
+
+def set_logging(name=LOGGING_NAME, verbose=True):
+ # sets up logging for the given name
+ rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
+ level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
+ logging.config.dictConfig({
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ name: {
+ "format": "%(message)s"}},
+ "handlers": {
+ name: {
+ "class": "logging.StreamHandler",
+ "formatter": name,
+ "level": level,}},
+ "loggers": {
+ name: {
+ "level": level,
+ "handlers": [name],
+ "propagate": False,}}})
+
+
+class TryExcept(contextlib.ContextDecorator):
+ # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
+ def __init__(self, msg=''):
+ self.msg = msg
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, value, traceback):
+ if value:
+ print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
+ return True
+
+
+def threaded(func):
+ # Multi-threads a target function and returns thread. Usage: @threaded decorator
+ def wrapper(*args, **kwargs):
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
+ thread.start()
+ return thread
+
+ return wrapper
+
+
+def yaml_save(file='data.yaml', data=None):
+ """
+ Save YAML data to a file.
+
+ Args:
+ file (str, optional): File name. Default is 'data.yaml'.
+ data (dict, optional): Data to save in YAML format. Default is None.
+
+ Returns:
+ None: Data is saved to the specified file.
+ """
+ file = Path(file)
+ if not file.parent.exists():
+ # Create parent directories if they don't exist
+ file.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(file, 'w') as f:
+ # Dump data to file in YAML format, converting Path objects to strings
+ yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
+
+
+def yaml_load(file='data.yaml', append_filename=False):
+ """
+ Load YAML data from a file.
+
+ Args:
+ file (str, optional): File name. Default is 'data.yaml'.
+ append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
+
+ Returns:
+ dict: YAML data and file name.
+ """
+ with open(file, errors='ignore') as f:
+ # Add YAML filename to dict and return
+ return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
+
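+# Round-trip sketch (illustrative; the file name and dict contents are made up):
+#   yaml_save('example.yaml', {'imgsz': 640, 'save_dir': Path('runs/val')})  # Path values are saved as strings
+#   cfg = yaml_load('example.yaml', append_filename=True)
+#   cfg['imgsz']     -> 640
+#   cfg['yaml_file'] -> 'example.yaml'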
+
+def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
+ """
+ Loads a global settings YAML file or creates one with default values if it does not exist.
+
+ Args:
+ file (Path): Path to the settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR.
+
+ Returns:
+ dict: Dictionary of settings key-value pairs.
+ """
+ from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
+
+ root = get_git_root_dir() or Path('') # not is_pip_package()
+ defaults = {
+ 'datasets_dir': str(root / 'datasets'), # default datasets directory.
+ 'weights_dir': str(root / 'weights'), # default weights directory.
+ 'runs_dir': str(root / 'runs'), # default runs directory.
+ 'sync': True, # sync analytics to help with YOLO development
+ 'uuid': uuid.getnode()} # device UUID to align analytics
+
+ with torch_distributed_zero_first(RANK):
+ if not file.exists():
+ yaml_save(file, defaults)
+
+ settings = yaml_load(file)
+
+ # Check that settings keys and types match defaults
+ correct = settings.keys() == defaults.keys() and \
+ all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values()))
+ if not correct:
+ LOGGER.warning('WARNING ⚠️ Different global settings detected, resetting to defaults. '
+ 'This may be due to an ultralytics package update. '
+ f'View and update your global settings directly in {file}')
+ settings = defaults # reset to defaults (user settings are discarded)
+ yaml_save(file, settings) # save updated defaults
+
+ return settings
+
+
+# Run below code on utils init -----------------------------------------------------------------------------------------
+
+# Set logger
+set_logging(LOGGING_NAME) # run before defining LOGGER
+LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
+if platform.system() == 'Windows':
+ for fn in LOGGER.info, LOGGER.warning:
+ setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x))) # emoji-safe logging (bind fn per iteration)
+
+# Check first-install steps
+SETTINGS = get_settings()
+DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
+
+
+def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
+ """
+ Updates the global SETTINGS dictionary with the given key-value pairs and writes the result back to the
+ settings YAML file.
+ """
+ SETTINGS.update(kwargs)
+
+ yaml_save(file, SETTINGS)
diff --git a/ultralytics/yolo/utils/autobatch.py b/ultralytics/yolo/utils/autobatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..cac167dccd9b9ea79d43e76a016606ab47e3c120
--- /dev/null
+++ b/ultralytics/yolo/utils/autobatch.py
@@ -0,0 +1,72 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Auto-batch utils
+"""
+
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from ultralytics.yolo.utils import LOGGER, colorstr
+from ultralytics.yolo.utils.torch_utils import profile
+
+
+def check_train_batch_size(model, imgsz=640, amp=True):
+ # Check YOLOv5 training batch size
+ with torch.cuda.amp.autocast(amp):
+ return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
+
+
+def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
+ # Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory
+ # Usage:
+ # import torch
+ # from utils.autobatch import autobatch
+ # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
+ # print(autobatch(model))
+
+ # Check device
+ prefix = colorstr('AutoBatch: ')
+ LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
+ device = next(model.parameters()).device # get model device
+ if device.type == 'cpu':
+ LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
+ return batch_size
+ if torch.backends.cudnn.benchmark:
+ LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
+ return batch_size
+
+ # Inspect CUDA memory
+ gb = 1 << 30 # bytes to GiB (1024 ** 3)
+ d = str(device).upper() # 'CUDA:0'
+ properties = torch.cuda.get_device_properties(device) # device properties
+ t = properties.total_memory / gb # GiB total
+ r = torch.cuda.memory_reserved(device) / gb # GiB reserved
+ a = torch.cuda.memory_allocated(device) / gb # GiB allocated
+ f = t - (r + a) # GiB free
+ LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
+
+ # Profile batch sizes
+ batch_sizes = [1, 2, 4, 8, 16]
+ try:
+ img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
+ results = profile(img, model, n=3, device=device)
+ except Exception as e:
+ LOGGER.warning(f'{prefix}{e}')
+ return batch_size # profiling failed, fall back to the default batch size
+
+ # Fit a solution
+ y = [x[2] for x in results if x] # memory [2]
+ p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
+ b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
+ if None in results: # some sizes failed
+ i = results.index(None) # first fail index
+ if b >= batch_sizes[i]: # y intercept above failure point
+ b = batch_sizes[max(i - 1, 0)] # select prior safe point
+ if b < 1 or b > 1024: # b outside of safe range
+ b = batch_size
+ LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
+
+ fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
+ LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
+ return b
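+
+
+# Worked example of the linear fit above (illustrative numbers only):
+#   measured memory y = [0.5, 0.9, 1.7, 3.3, 6.5] GiB at batch_sizes = [1, 2, 4, 8, 16]
+#   np.polyfit gives p = [0.4, 0.1] (slope in GiB per image, intercept in GiB)
+#   with f = 10 GiB free and fraction = 0.7: b = int((10 * 0.7 - 0.1) / 0.4) = 17
+#   i.e. the largest batch size whose predicted memory use stays within 70% of free CUDA memory.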
diff --git a/ultralytics/yolo/utils/callbacks/__init__.py b/ultralytics/yolo/utils/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64adb3bbfc20ca1e0da147b2c7f5afd182dd050
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/__init__.py
@@ -0,0 +1 @@
+from .base import add_integration_callbacks, default_callbacks
diff --git a/ultralytics/yolo/utils/callbacks/base.py b/ultralytics/yolo/utils/callbacks/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..689bf158cf947e9abcc699a3523a0700756bcbd6
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/base.py
@@ -0,0 +1,149 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Base callbacks
+"""
+
+
+# Trainer callbacks ----------------------------------------------------------------------------------------------------
+def on_pretrain_routine_start(trainer):
+ pass
+
+
+def on_pretrain_routine_end(trainer):
+ pass
+
+
+def on_train_start(trainer):
+ pass
+
+
+def on_train_epoch_start(trainer):
+ pass
+
+
+def on_train_batch_start(trainer):
+ pass
+
+
+def optimizer_step(trainer):
+ pass
+
+
+def on_before_zero_grad(trainer):
+ pass
+
+
+def on_train_batch_end(trainer):
+ pass
+
+
+def on_train_epoch_end(trainer):
+ pass
+
+
+def on_fit_epoch_end(trainer):
+ pass
+
+
+def on_model_save(trainer):
+ pass
+
+
+def on_train_end(trainer):
+ pass
+
+
+def on_params_update(trainer):
+ pass
+
+
+def teardown(trainer):
+ pass
+
+
+# Validator callbacks --------------------------------------------------------------------------------------------------
+def on_val_start(validator):
+ pass
+
+
+def on_val_batch_start(validator):
+ pass
+
+
+def on_val_batch_end(validator):
+ pass
+
+
+def on_val_end(validator):
+ pass
+
+
+# Predictor callbacks --------------------------------------------------------------------------------------------------
+def on_predict_start(predictor):
+ pass
+
+
+def on_predict_batch_start(predictor):
+ pass
+
+
+def on_predict_batch_end(predictor):
+ pass
+
+
+def on_predict_end(predictor):
+ pass
+
+
+# Exporter callbacks ---------------------------------------------------------------------------------------------------
+def on_export_start(exporter):
+ pass
+
+
+def on_export_end(exporter):
+ pass
+
+
+default_callbacks = {
+ # Run in trainer
+ 'on_pretrain_routine_start': on_pretrain_routine_start,
+ 'on_pretrain_routine_end': on_pretrain_routine_end,
+ 'on_train_start': on_train_start,
+ 'on_train_epoch_start': on_train_epoch_start,
+ 'on_train_batch_start': on_train_batch_start,
+ 'optimizer_step': optimizer_step,
+ 'on_before_zero_grad': on_before_zero_grad,
+ 'on_train_batch_end': on_train_batch_end,
+ 'on_train_epoch_end': on_train_epoch_end,
+ 'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
+ 'on_model_save': on_model_save,
+ 'on_train_end': on_train_end,
+ 'on_params_update': on_params_update,
+ 'teardown': teardown,
+
+ # Run in validator
+ 'on_val_start': on_val_start,
+ 'on_val_batch_start': on_val_batch_start,
+ 'on_val_batch_end': on_val_batch_end,
+ 'on_val_end': on_val_end,
+
+ # Run in predictor
+ 'on_predict_start': on_predict_start,
+ 'on_predict_batch_start': on_predict_batch_start,
+ 'on_predict_batch_end': on_predict_batch_end,
+ 'on_predict_end': on_predict_end,
+
+ # Run in exporter
+ 'on_export_start': on_export_start,
+ 'on_export_end': on_export_end}
+
+
+def add_integration_callbacks(instance):
+ from .clearml import callbacks as clearml_callbacks
+ from .comet import callbacks as comet_callbacks
+ from .hub import callbacks as hub_callbacks
+ from .tensorboard import callbacks as tb_callbacks
+
+ for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
+ for k, v in x.items():
+ instance.callbacks[k].append(v) # callback[name].append(func)
diff --git a/ultralytics/yolo/utils/callbacks/clearml.py b/ultralytics/yolo/utils/callbacks/clearml.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a0297995eed8cfa7f08d0f6e0436a7f4d0bef54
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/clearml.py
@@ -0,0 +1,56 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
+
+try:
+ import clearml
+ from clearml import Task
+
+ assert hasattr(clearml, '__version__')
+except (ImportError, AssertionError):
+ clearml = None
+
+
+def _log_images(imgs_dict, group="", step=0):
+ task = Task.current_task()
+ if task:
+ for k, v in imgs_dict.items():
+ task.get_logger().report_image(group, k, step, v)
+
+
+def on_pretrain_routine_start(trainer):
+ # TODO: reuse existing task
+ task = Task.init(project_name=trainer.args.project or "YOLOv8",
+ task_name=trainer.args.name,
+ tags=['YOLOv8'],
+ output_uri=True,
+ reuse_last_task_id=False,
+ auto_connect_frameworks={'pytorch': False})
+ task.connect(dict(trainer.args), name='General')
+
+
+def on_train_epoch_end(trainer):
+ if trainer.epoch == 1:
+ _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
+
+
+def on_fit_epoch_end(trainer):
+ if trainer.epoch == 0:
+ model_info = {
+ "Parameters": get_num_params(trainer.model),
+ "GFLOPs": round(get_flops(trainer.model), 3),
+ "Inference speed (ms/img)": round(trainer.validator.speed[1], 3)}
+ Task.current_task().connect(model_info, name='Model')
+
+
+def on_train_end(trainer):
+ Task.current_task().update_output_model(model_path=str(trainer.best),
+ model_name=trainer.args.name,
+ auto_delete_file=False)
+
+
+callbacks = {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_epoch_end": on_train_epoch_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_end": on_train_end} if clearml else {}
diff --git a/ultralytics/yolo/utils/callbacks/comet.py b/ultralytics/yolo/utils/callbacks/comet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7133cbbe3b367da59f8c5d13ca619a4e8740a1b1
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/comet.py
@@ -0,0 +1,45 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
+
+try:
+ import comet_ml
+
+except (ModuleNotFoundError, ImportError):
+ comet_ml = None
+
+
+def on_pretrain_routine_start(trainer):
+ experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8",)
+ experiment.log_parameters(dict(trainer.args))
+
+
+def on_train_epoch_end(trainer):
+ experiment = comet_ml.get_global_experiment()
+ experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
+ if trainer.epoch == 1:
+ for f in trainer.save_dir.glob('train_batch*.jpg'):
+ experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
+
+
+def on_fit_epoch_end(trainer):
+ experiment = comet_ml.get_global_experiment()
+ experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
+ if trainer.epoch == 0:
+ model_info = {
+ "model/parameters": get_num_params(trainer.model),
+ "model/GFLOPs": round(get_flops(trainer.model), 3),
+ "model/speed(ms)": round(trainer.validator.speed[1], 3)}
+ experiment.log_metrics(model_info, step=trainer.epoch + 1)
+
+
+def on_train_end(trainer):
+ experiment = comet_ml.get_global_experiment()
+ experiment.log_model("YOLOv8", file_or_folder=trainer.best, file_name="best.pt", overwrite=True)
+
+
+callbacks = {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_epoch_end": on_train_epoch_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_end": on_train_end} if comet_ml else {}
diff --git a/ultralytics/yolo/utils/callbacks/hub.py b/ultralytics/yolo/utils/callbacks/hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..47a7e545a474d8eabe5773b4c29a5577078fa1f5
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/hub.py
@@ -0,0 +1,76 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import json
+from time import time
+
+import torch
+
+from ultralytics.hub.utils import PREFIX, sync_analytics
+from ultralytics.yolo.utils import LOGGER
+
+
+def on_pretrain_routine_end(trainer):
+ session = getattr(trainer, 'hub_session', None)
+ if session:
+ # Start timer for upload rate limit
+ LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
+ session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
+
+
+def on_fit_epoch_end(trainer):
+ session = getattr(trainer, 'hub_session', None)
+ if session:
+ session.metrics_queue[trainer.epoch] = json.dumps(trainer.metrics) # json string
+ if time() - session.t['metrics'] > session.rate_limits['metrics']:
+ session.upload_metrics()
+ session.t['metrics'] = time() # reset timer
+ session.metrics_queue = {} # reset queue
+
+
+def on_model_save(trainer):
+ session = getattr(trainer, 'hub_session', None)
+ if session:
+ # Upload checkpoints with rate limiting
+ is_best = trainer.best_fitness == trainer.fitness
+ if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
+ LOGGER.info(f"{PREFIX}Uploading checkpoint {session.model_id}")
+ session.upload_model(trainer.epoch, trainer.last, is_best)
+ session.t['ckpt'] = time() # reset timer
+
+
+def on_train_end(trainer):
+ session = getattr(trainer, 'hub_session', None)
+ if session:
+ # Upload final model and metrics with exponential backoff
+ LOGGER.info(f"{PREFIX}Training completed successfully ✅\n"
+ f"{PREFIX}Uploading final {session.model_id}")
+ session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
+ session.alive = False # stop heartbeats
+ LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
+
+
+def on_train_start(trainer):
+ sync_analytics(trainer.args)
+
+
+def on_val_start(validator):
+ sync_analytics(validator.args)
+
+
+def on_predict_start(predictor):
+ sync_analytics(predictor.args)
+
+
+def on_export_start(exporter):
+ sync_analytics(exporter.args)
+
+
+callbacks = {
+ "on_pretrain_routine_end": on_pretrain_routine_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_model_save": on_model_save,
+ "on_train_end": on_train_end,
+ "on_train_start": on_train_start,
+ "on_val_start": on_val_start,
+ "on_predict_start": on_predict_start,
+ "on_export_start": on_export_start}
diff --git a/ultralytics/yolo/utils/callbacks/tensorboard.py b/ultralytics/yolo/utils/callbacks/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a230e276f77e97738c50c59b0683c392b22c01
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/tensorboard.py
@@ -0,0 +1,29 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from torch.utils.tensorboard import SummaryWriter
+
+writer = None # TensorBoard SummaryWriter instance
+
+
+def _log_scalars(scalars, step=0):
+ for k, v in scalars.items():
+ writer.add_scalar(k, v, step)
+
+
+def on_pretrain_routine_start(trainer):
+ global writer
+ writer = SummaryWriter(str(trainer.save_dir))
+
+
+def on_batch_end(trainer):
+ _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
+
+
+def on_fit_epoch_end(trainer):
+ _log_scalars(trainer.metrics, trainer.epoch + 1)
+
+
+callbacks = {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_batch_end": on_batch_end}
diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f953fa7ffe41a1997c240c13348a819b85a9a72
--- /dev/null
+++ b/ultralytics/yolo/utils/checks.py
@@ -0,0 +1,270 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import glob
+import inspect
+import math
+import platform
+import urllib.parse
+from pathlib import Path
+from subprocess import check_output
+from typing import Optional
+
+import cv2
+import numpy as np
+import pkg_resources as pkg
+import torch
+
+from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
+ is_docker, is_jupyter_notebook)
+
+
+def is_ascii(s) -> bool:
+ """
+ Check if a string is composed of only ASCII characters.
+
+ Args:
+ s (str): String to be checked.
+
+ Returns:
+ bool: True if the string is composed only of ASCII characters, False otherwise.
+ """
+ # Convert list, tuple, None, etc. to string
+ s = str(s)
+
+ # Check if the string is composed of only ASCII characters
+ return all(ord(c) < 128 for c in s)
+
+
+def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
+ """
+ Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
+ stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
+
+ Args:
+ imgsz (int or List[int]): Image size.
+ stride (int): Stride value.
+ min_dim (int): Minimum number of dimensions.
+ floor (int): Minimum allowed value for image size.
+
+ Returns:
+ List[int]: Updated image size.
+ """
+ # Convert stride to integer if it is a tensor
+ stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
+
+ # Convert image size to list if it is an integer
+ if isinstance(imgsz, int):
+ imgsz = [imgsz]
+
+ # Make image size a multiple of the stride
+ sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
+
+ # Print warning message if image size was updated
+ if sz != imgsz:
+ LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
+
+ # Add missing dimensions if necessary
+ sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
+
+ return sz
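+
+# Examples (illustrative):
+#   check_imgsz(630, stride=32)            -> 640         (rounded up to the nearest multiple of 32)
+#   check_imgsz([630, 480], stride=32)     -> [640, 480]
+#   check_imgsz(630, stride=32, min_dim=2) -> [640, 640]  (expanded to two dimensions)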
+
+
+def check_version(current: str = "0.0.0",
+ minimum: str = "0.0.0",
+ name: str = "version ",
+ pinned: bool = False,
+ hard: bool = False,
+ verbose: bool = False) -> bool:
+ """
+ Check current version against the required minimum version.
+
+ Args:
+ current (str): Current version.
+ minimum (str): Required minimum version.
+ name (str): Name to be used in warning message.
+ pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
+ hard (bool): If True, raise an AssertionError if the minimum version is not met.
+ verbose (bool): If True, print warning message if minimum version is not met.
+
+ Returns:
+ bool: True if minimum version is met, False otherwise.
+ """
+ from pkg_resources import parse_version
+ current, minimum = (parse_version(x) for x in (current, minimum))
+ result = (current == minimum) if pinned else (current >= minimum) # bool
+ warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"
+ if hard:
+ assert result, emojis(warning_message) # assert min requirements met
+ if verbose and not result:
+ LOGGER.warning(warning_message)
+ return result
+
+
+def check_font(font: str = FONT, progress: bool = False) -> None:
+ """
+ Download font file to the user's configuration directory if it does not already exist.
+
+ Args:
+ font (str): Path to font file.
+ progress (bool): If True, display a progress bar during the download.
+
+ Returns:
+ None
+ """
+ font = Path(font)
+
+ # Destination path for the font file
+ file = USER_CONFIG_DIR / font.name
+
+ # Check if font file exists at the source or destination path
+ if not font.exists() and not file.exists():
+ # Download font file
+ url = f'https://ultralytics.com/assets/{font.name}'
+ LOGGER.info(f'Downloading {url} to {file}...')
+ torch.hub.download_url_to_file(url, str(file), progress=progress)
+
+
+def check_online() -> bool:
+ """
+ Check internet connectivity by attempting to connect to a known online host.
+
+ Returns:
+ bool: True if connection is successful, False otherwise.
+ """
+ import socket
+ try:
+ # Check host accessibility by attempting to establish a connection
+ socket.create_connection(("1.1.1.1", 443), timeout=5)
+ return True
+ except OSError:
+ return False
+
+
+def check_python(minimum: str = '3.7.0') -> bool:
+ """
+ Check current python version against the required minimum version.
+
+ Args:
+ minimum (str): Required minimum version of python.
+
+ Returns:
+ bool: True if the installed Python version meets the minimum requirement.
+ """
+ return check_version(platform.python_version(), minimum, name='Python ', hard=True)
+
+
+@TryExcept()
+def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
+ # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str)
+ prefix = colorstr('red', 'bold', 'requirements:')
+ check_python() # check python version
+ if isinstance(requirements, Path): # requirements.txt file
+ file = requirements.resolve()
+ assert file.exists(), f"{prefix} {file} not found, check failed."
+ with file.open() as f:
+ requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
+ elif isinstance(requirements, str):
+ requirements = [requirements]
+
+ s = ''
+ n = 0
+ for r in requirements:
+ try:
+ pkg.require(r)
+ except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
+ s += f'"{r}" '
+ n += 1
+
+ if s and install and AUTOINSTALL: # check environment variable
+ LOGGER.info(f"{prefix} YOLOv5 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
+ try:
+ assert check_online(), "AutoUpdate skipped (offline)"
+ LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
+ source = file if 'file' in locals() else requirements
+ s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
+ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
+ LOGGER.info(s)
+ except Exception as e:
+ LOGGER.warning(f'{prefix} ❌ {e}')
+
+
+def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''):
+ # Check file(s) for acceptable suffix
+ if file and suffix:
+ if isinstance(suffix, str):
+ suffix = [suffix]
+ for f in file if isinstance(file, (list, tuple)) else [file]:
+ s = Path(f).suffix.lower() # file suffix
+ if len(s):
+ assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
+
+
+def check_file(file, suffix=''):
+ # Search/download file (if necessary) and return path
+ check_suffix(file, suffix) # optional
+ file = str(file) # convert to str()
+ if Path(file).is_file() or not file: # exists
+ return file
+ elif file.startswith(('http:/', 'https:/')): # download
+ url = file # warning: Pathlib turns :// -> :/
+ file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
+ if Path(file).is_file():
+ LOGGER.info(f'Found {url} locally at {file}') # file already exists
+ else:
+ LOGGER.info(f'Downloading {url} to {file}...')
+ torch.hub.download_url_to_file(url, file)
+ assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
+ return file
+ else: # search
+ files = []
+ for d in 'models', 'yolo/data': # search directories
+ files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
+ assert len(files), f'File not found: {file}' # assert file was found
+ assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
+ return files[0] # return file
+
+
+def check_yaml(file, suffix=('.yaml', '.yml')):
+ # Search/download YAML file (if necessary) and return path, checking suffix
+ return check_file(file, suffix)
+
+
+def check_imshow(warn=False):
+ # Check if environment supports image displays
+ try:
+ assert not is_jupyter_notebook()
+ assert not is_docker()
+ cv2.imshow('test', np.zeros((1, 1, 3)))
+ cv2.waitKey(1)
+ cv2.destroyAllWindows()
+ cv2.waitKey(1)
+ return True
+ except Exception as e:
+ if warn:
+ LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
+ return False
+
+
+def git_describe(path=ROOT): # path must be a directory
+ # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+ try:
+ assert (Path(path) / '.git').is_dir()
+ return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
+ except Exception:
+ return ''
+
+
+def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
+ # Print function arguments (optional args dict)
+ x = inspect.currentframe().f_back # previous frame
+ file, _, func, _, _ = inspect.getframeinfo(x)
+ if args is None: # get args automatically
+ args, _, _, frm = inspect.getargvalues(x)
+ args = {k: v for k, v in frm.items() if k in args}
+ try:
+ file = Path(file).resolve().relative_to(ROOT).with_suffix('')
+ except ValueError:
+ file = Path(file).stem
+ s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
+ LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3b3bbba06ea7acc590ae9165213318db9d410a5
--- /dev/null
+++ b/ultralytics/yolo/utils/dist.py
@@ -0,0 +1,65 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import os
+import shutil
+import socket
+import sys
+import tempfile
+
+from . import USER_CONFIG_DIR
+
+
+def find_free_network_port() -> int:
+ # https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py
+ """Finds a free port on localhost.
+
+ It is useful in single-node training when we don't want to connect to a real main node but have to set the
+ `MASTER_PORT` environment variable.
+ """
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(("", 0))
+ port = s.getsockname()[1]
+ s.close()
+ return port
+
+
+def generate_ddp_file(trainer):
+ import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
+
+ if not trainer.resume:
+ shutil.rmtree(trainer.save_dir) # remove the save_dir
+ content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__":
+ from ultralytics.{import_path} import {trainer.__class__.__name__}
+
+ trainer = {trainer.__class__.__name__}(config=config)
+ trainer.train()'''
+ (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
+ with tempfile.NamedTemporaryFile(prefix="_temp_",
+ suffix=f"{id(trainer)}.py",
+ mode="w+",
+ encoding='utf-8',
+ dir=USER_CONFIG_DIR / 'DDP',
+ delete=False) as file:
+ file.write(content)
+ return file.name
+
+
+def generate_ddp_command(world_size, trainer):
+ import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+ file_name = os.path.abspath(sys.argv[0])
+ using_cli = not file_name.endswith(".py")
+ if using_cli:
+ file_name = generate_ddp_file(trainer)
+ return [
+ sys.executable, "-m", "torch.distributed.run", "--nproc_per_node", f"{world_size}", "--master_port",
+ f"{find_free_network_port()}", file_name] + sys.argv[1:]
+
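+# Resulting command sketch (illustrative values; the port is random and nproc/script depend on the run):
+#   python -m torch.distributed.run --nproc_per_node 2 --master_port 54321 train.py <original CLI args>
+# When launched from the CLI rather than a .py entrypoint, the temporary script written by
+# generate_ddp_file() is used as the target instead.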
+
+def ddp_cleanup(command, trainer):
+ # delete temp file if created
+ tempfile_suffix = f"{id(trainer)}.py"
+ if tempfile_suffix in "".join(command):
+ for chunk in command:
+ if tempfile_suffix in chunk:
+ os.remove(chunk)
+ break
diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfc53258247414fa363790189bee733fb95f43
--- /dev/null
+++ b/ultralytics/yolo/utils/downloads.py
@@ -0,0 +1,146 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import logging
+import os
+import subprocess
+import urllib.parse
+import urllib.request
+from itertools import repeat
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from zipfile import ZipFile
+
+import requests
+import torch
+
+from ultralytics.yolo.utils import LOGGER
+
+
+def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
+ # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+ file = Path(file)
+ assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
+ try: # url1
+ LOGGER.info(f'Downloading {url} to {file}...')
+ torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
+ assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
+ except Exception as e: # url2
+ if file.exists():
+ file.unlink() # remove partial downloads
+ LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
+ os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
+ finally:
+ if not file.exists() or file.stat().st_size < min_bytes: # check
+ if file.exists():
+ file.unlink() # remove partial downloads
+ LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
+ LOGGER.info('')
+
+
+def is_url(url, check=True):
+ # Check if string is URL and check if URL exists
+ try:
+ url = str(url)
+ result = urllib.parse.urlparse(url)
+ assert all([result.scheme, result.netloc]) # check if is url
+ return (urllib.request.urlopen(url).getcode() == 200) if check else True # check if exists online
+ except (AssertionError, urllib.request.HTTPError):
+ return False
+
+
+def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
+ # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
+
+ def github_assets(repository, version='latest'):
+ # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
+ if version != 'latest':
+ version = f'tags/{version}' # i.e. tags/v6.2
+ response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
+ return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
+
+ file = Path(str(file).strip().replace("'", ''))
+ if not file.exists():
+ # URL specified
+ name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
+ if str(file).startswith(('http:/', 'https:/')): # download
+ url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
+ file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
+ if Path(file).is_file():
+ LOGGER.info(f'Found {url} locally at {file}') # file already exists
+ else:
+ safe_download(file=file, url=url, min_bytes=1E5)
+ return file
+
+ # GitHub assets
+ assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
+ try:
+ tag, assets = github_assets(repo, release)
+ except Exception:
+ try:
+ tag, assets = github_assets(repo) # latest release
+ except Exception:
+ try:
+ tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
+ except Exception:
+ tag = release
+
+ file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
+ if name in assets:
+ url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror
+ safe_download(
+ file,
+ url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
+ min_bytes=1E5,
+ error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
+
+ return str(file)
+
+
+def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
+ # Multithreaded file download and unzip function, used in data.yaml for autodownload
+ def download_one(url, dir):
+ # Download 1 file
+ success = True
+ if Path(url).is_file():
+ f = Path(url) # filename
+ else: # does not exist
+ f = dir / Path(url).name
+ LOGGER.info(f'Downloading {url} to {f}...')
+ for i in range(retry + 1):
+ if curl:
+ s = 'sS' if threads > 1 else '' # silent
+ r = os.system(
+ f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
+ success = r == 0
+ else:
+ torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
+ success = f.is_file()
+ if success:
+ break
+ elif i < retry:
+ LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
+ else:
+ LOGGER.warning(f'❌ Failed to download {url}...')
+
+ if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
+ LOGGER.info(f'Unzipping {f}...')
+ if f.suffix == '.zip':
+ ZipFile(f).extractall(path=dir) # unzip
+ elif f.suffix == '.tar':
+ os.system(f'tar xf {f} --directory {f.parent}') # unzip
+ elif f.suffix == '.gz':
+ os.system(f'tar xfz {f} --directory {f.parent}') # unzip
+ if delete:
+ f.unlink() # remove zip
+
+ dir = Path(dir)
+ dir.mkdir(parents=True, exist_ok=True) # make directory
+ if threads > 1:
+ pool = ThreadPool(threads)
+ pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
+ pool.close()
+ pool.join()
+ else:
+ for u in [url] if isinstance(url, (str, Path)) else url:
+ download_one(u, dir)
diff --git a/ultralytics/yolo/utils/files.py b/ultralytics/yolo/utils/files.py
new file mode 100644
index 0000000000000000000000000000000000000000..7360ca77b8189c86a27fb4be786bfe0ca526a1e7
--- /dev/null
+++ b/ultralytics/yolo/utils/files.py
@@ -0,0 +1,103 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import contextlib
+import glob
+import os
+import urllib.parse
+from datetime import datetime
+from pathlib import Path
+from zipfile import ZipFile
+
+
+class WorkingDirectory(contextlib.ContextDecorator):
+ # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
+ def __init__(self, new_dir):
+ self.dir = new_dir # new dir
+ self.cwd = Path.cwd().resolve() # current dir
+
+ def __enter__(self):
+ os.chdir(self.dir)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ os.chdir(self.cwd)
+
+
+def increment_path(path, exist_ok=False, sep='', mkdir=False):
+ """
+ Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+
+ If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
+ the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
+ number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
+ directory if it does not already exist.
+
+ Args:
+ path (str or pathlib.Path): Path to increment.
+ exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
+ sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
+ mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
+
+ Returns:
+ pathlib.Path: Incremented path.
+ """
+ path = Path(path) # os-agnostic
+ if path.exists() and not exist_ok:
+ path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
+
+ # Method 1
+ for n in range(2, 9999):
+ p = f'{path}{sep}{n}{suffix}' # increment path
+ if not os.path.exists(p):
+ break
+ path = Path(p)
+
+ if mkdir:
+ path.mkdir(parents=True, exist_ok=True) # make directory
+
+ return path
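+
+# Examples (illustrative; assumes 'runs/exp' and 'results.txt' already exist on disk):
+#   increment_path('runs/exp')                 -> Path('runs/exp2')
+#   increment_path('runs/exp', exist_ok=True)  -> Path('runs/exp')      # reuse the existing path
+#   increment_path('results.txt')              -> Path('results2.txt')  # file extension is preserved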
+
+
+def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
+ # Unzip a *.zip file to path/, excluding files containing strings in exclude list
+ if path is None:
+ path = Path(file).parent # default path
+ with ZipFile(file) as zipObj:
+ for f in zipObj.namelist(): # list all archived filenames in the zip
+ if all(x not in f for x in exclude):
+ zipObj.extract(f, path=path)
+
+
+def file_age(path=__file__):
+ # Return days since last file update
+ dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
+ return dt.days # + dt.seconds / 86400 # fractional days
+
+
+def file_date(path=__file__):
+ # Return human-readable file modification date, i.e. '2021-3-26'
+ t = datetime.fromtimestamp(Path(path).stat().st_mtime)
+ return f'{t.year}-{t.month}-{t.day}'
+
+
+def file_size(path):
+ # Return file/dir size (MB)
+ mb = 1 << 20 # bytes to MiB (1024 ** 2)
+ path = Path(path)
+ if path.is_file():
+ return path.stat().st_size / mb
+ elif path.is_dir():
+ return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
+ else:
+ return 0.0
+
+
+def url2file(url):
+ # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
+ url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
+ return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
+
+
+def get_latest_run(search_dir='.'):
+ # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
+ last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
+ return max(last_list, key=os.path.getctime) if last_list else ''
diff --git a/ultralytics/yolo/utils/instance.py b/ultralytics/yolo/utils/instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..965a616fd6b917a26c6f932d87b0a87842f9f52b
--- /dev/null
+++ b/ultralytics/yolo/utils/instance.py
@@ -0,0 +1,337 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from collections import abc
+from itertools import repeat
+from numbers import Number
+from typing import List
+
+import numpy as np
+
+from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+ def parse(x):
+ return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
+
+ return parse
+
+
+to_4tuple = _ntuple(4)
+
+# `xyxy` means left top and right bottom
+# `xywh` means center x, center y and width, height (YOLO format)
+# `ltwh` means left top and width, height (COCO format)
+_formats = ["xyxy", "xywh", "ltwh"]
+
+__all__ = ["Bboxes"]
+
+
+class Bboxes:
+ """Now only numpy is supported"""
+
+ def __init__(self, bboxes, format="xyxy") -> None:
+ assert format in _formats
+ bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
+ assert bboxes.ndim == 2
+ assert bboxes.shape[1] == 4
+ self.bboxes = bboxes
+ self.format = format
+ # self.normalized = normalized
+
+ # def convert(self, format):
+ # assert format in _formats
+ # if self.format == format:
+ # bboxes = self.bboxes
+ # elif self.format == "xyxy":
+ # if format == "xywh":
+ # bboxes = xyxy2xywh(self.bboxes)
+ # else:
+ # bboxes = xyxy2ltwh(self.bboxes)
+ # elif self.format == "xywh":
+ # if format == "xyxy":
+ # bboxes = xywh2xyxy(self.bboxes)
+ # else:
+ # bboxes = xywh2ltwh(self.bboxes)
+ # else:
+ # if format == "xyxy":
+ # bboxes = ltwh2xyxy(self.bboxes)
+ # else:
+ # bboxes = ltwh2xywh(self.bboxes)
+ #
+ # return Bboxes(bboxes, format)
+
+ def convert(self, format):
+ assert format in _formats
+ if self.format == format:
+ return
+ elif self.format == "xyxy":
+ bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else xyxy2ltwh(self.bboxes)
+ elif self.format == "xywh":
+ bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else xywh2ltwh(self.bboxes)
+ else:
+ bboxes = ltwh2xyxy(self.bboxes) if format == "xyxy" else ltwh2xywh(self.bboxes)
+ self.bboxes = bboxes
+ self.format = format
+
+ def areas(self):
+ self.convert("xyxy")
+ return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
+
+ # def denormalize(self, w, h):
+ # if not self.normalized:
+ # return
+ # assert (self.bboxes <= 1.0).all()
+ # self.bboxes[:, 0::2] *= w
+ # self.bboxes[:, 1::2] *= h
+ # self.normalized = False
+ #
+ # def normalize(self, w, h):
+ # if self.normalized:
+ # return
+ # assert (self.bboxes > 1.0).any()
+ # self.bboxes[:, 0::2] /= w
+ # self.bboxes[:, 1::2] /= h
+ # self.normalized = True
+
+ def mul(self, scale):
+ """
+ Args:
+ scale (tuple | List | int): the scale for four coords.
+ """
+ if isinstance(scale, Number):
+ scale = to_4tuple(scale)
+ assert isinstance(scale, (tuple, list))
+ assert len(scale) == 4
+ self.bboxes[:, 0] *= scale[0]
+ self.bboxes[:, 1] *= scale[1]
+ self.bboxes[:, 2] *= scale[2]
+ self.bboxes[:, 3] *= scale[3]
+
+ def add(self, offset):
+ """
+ Args:
+ offset (tuple | List | int): the offset for four coords.
+ """
+ if isinstance(offset, Number):
+ offset = to_4tuple(offset)
+ assert isinstance(offset, (tuple, list))
+ assert len(offset) == 4
+ self.bboxes[:, 0] += offset[0]
+ self.bboxes[:, 1] += offset[1]
+ self.bboxes[:, 2] += offset[2]
+ self.bboxes[:, 3] += offset[3]
+
+ def __len__(self):
+ return len(self.bboxes)
+
+ @classmethod
+ def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
+ """
+ Concatenates a list of Boxes into a single Bboxes
+
+ Arguments:
+ boxes_list (list[Bboxes])
+
+ Returns:
+ Bboxes: the concatenated Boxes
+ """
+ assert isinstance(boxes_list, (list, tuple))
+ if not boxes_list:
+ return cls(np.empty(0))
+ assert all(isinstance(box, Bboxes) for box in boxes_list)
+
+ if len(boxes_list) == 1:
+ return boxes_list[0]
+ return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
+
+ def __getitem__(self, index) -> "Bboxes":
+ """
+ Args:
+ index: int, slice, or a BoolArray
+
+ Returns:
+ Bboxes: Create a new :class:`Bboxes` by indexing.
+ """
+ if isinstance(index, int):
+ return Bboxes(self.bboxes[index].reshape(1, -1)) # reshape keeps a 2D array for a single box
+ b = self.bboxes[index]
+ assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
+ return Bboxes(b)
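+
+
+# Usage sketch (illustrative only, not part of the library API): Bboxes converts between
+# formats in place and computes areas in xyxy space, e.g.
+#   boxes = Bboxes(np.array([[10., 20., 50., 80.]]), format="xyxy")
+#   boxes.convert("xywh")   # boxes.bboxes is now [[30., 50., 40., 60.]]
+#   boxes.areas()           # -> array([2400.])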
+
+
+class Instances:
+
+ def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
+ """
+ Args:
+ bboxes (ndarray): bboxes with shape [N, 4].
+ segments (list | ndarray): segments.
+ keypoints (ndarray): keypoints with shape [N, 17, 2].
+ """
+ if segments is None:
+ segments = []
+ self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
+ self.keypoints = keypoints
+ self.normalized = normalized
+
+ if len(segments) > 0:
+ # list[np.array(1000, 2)] * num_samples
+ segments = resample_segments(segments)
+ # (N, 1000, 2)
+ segments = np.stack(segments, axis=0)
+ else:
+ segments = np.zeros((0, 1000, 2), dtype=np.float32)
+ self.segments = segments
+
+ def convert_bbox(self, format):
+ self._bboxes.convert(format=format)
+
+    def bbox_areas(self):
+        return self._bboxes.areas()
+
+ def scale(self, scale_w, scale_h, bbox_only=False):
+        """Similar to denormalize(), but without toggling the normalized flag."""
+ self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
+ if bbox_only:
+ return
+ self.segments[..., 0] *= scale_w
+ self.segments[..., 1] *= scale_h
+ if self.keypoints is not None:
+ self.keypoints[..., 0] *= scale_w
+ self.keypoints[..., 1] *= scale_h
+
+ def denormalize(self, w, h):
+ if not self.normalized:
+ return
+ self._bboxes.mul(scale=(w, h, w, h))
+ self.segments[..., 0] *= w
+ self.segments[..., 1] *= h
+ if self.keypoints is not None:
+ self.keypoints[..., 0] *= w
+ self.keypoints[..., 1] *= h
+ self.normalized = False
+
+ def normalize(self, w, h):
+ if self.normalized:
+ return
+ self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
+ self.segments[..., 0] /= w
+ self.segments[..., 1] /= h
+ if self.keypoints is not None:
+ self.keypoints[..., 0] /= w
+ self.keypoints[..., 1] /= h
+ self.normalized = True
+
+ def add_padding(self, padw, padh):
+ # handle rect and mosaic situation
+ assert not self.normalized, "you should add padding with absolute coordinates."
+ self._bboxes.add(offset=(padw, padh, padw, padh))
+ self.segments[..., 0] += padw
+ self.segments[..., 1] += padh
+ if self.keypoints is not None:
+ self.keypoints[..., 0] += padw
+ self.keypoints[..., 1] += padh
+
+ def __getitem__(self, index) -> "Instances":
+ """
+ Args:
+ index: int, slice, or a BoolArray
+
+ Returns:
+ Instances: Create a new :class:`Instances` by indexing.
+ """
+ segments = self.segments[index] if len(self.segments) else self.segments
+ keypoints = self.keypoints[index] if self.keypoints is not None else None
+ bboxes = self.bboxes[index]
+ bbox_format = self._bboxes.format
+ return Instances(
+ bboxes=bboxes,
+ segments=segments,
+ keypoints=keypoints,
+ bbox_format=bbox_format,
+ normalized=self.normalized,
+ )
+
+ def flipud(self, h):
+ if self._bboxes.format == "xyxy":
+ y1 = self.bboxes[:, 1].copy()
+ y2 = self.bboxes[:, 3].copy()
+ self.bboxes[:, 1] = h - y2
+ self.bboxes[:, 3] = h - y1
+ else:
+ self.bboxes[:, 1] = h - self.bboxes[:, 1]
+ self.segments[..., 1] = h - self.segments[..., 1]
+ if self.keypoints is not None:
+ self.keypoints[..., 1] = h - self.keypoints[..., 1]
+
+ def fliplr(self, w):
+ if self._bboxes.format == "xyxy":
+ x1 = self.bboxes[:, 0].copy()
+ x2 = self.bboxes[:, 2].copy()
+ self.bboxes[:, 0] = w - x2
+ self.bboxes[:, 2] = w - x1
+ else:
+ self.bboxes[:, 0] = w - self.bboxes[:, 0]
+ self.segments[..., 0] = w - self.segments[..., 0]
+ if self.keypoints is not None:
+ self.keypoints[..., 0] = w - self.keypoints[..., 0]
+
+ def clip(self, w, h):
+ ori_format = self._bboxes.format
+ self.convert_bbox(format="xyxy")
+ self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
+ self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
+ if ori_format != "xyxy":
+ self.convert_bbox(format=ori_format)
+ self.segments[..., 0] = self.segments[..., 0].clip(0, w)
+ self.segments[..., 1] = self.segments[..., 1].clip(0, h)
+ if self.keypoints is not None:
+ self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
+ self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
+
+ def update(self, bboxes, segments=None, keypoints=None):
+ new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
+ self._bboxes = new_bboxes
+ if segments is not None:
+ self.segments = segments
+ if keypoints is not None:
+ self.keypoints = keypoints
+
+ def __len__(self):
+ return len(self.bboxes)
+
+ @classmethod
+ def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
+ """
+        Concatenate a list of Instances objects into a single Instances object.
+
+        Arguments:
+            instances_list (list[Instances])
+            axis
+
+        Returns:
+            Instances: the concatenated instances
+ """
+ assert isinstance(instances_list, (list, tuple))
+ if not instances_list:
+ return cls(np.empty(0))
+ assert all(isinstance(instance, Instances) for instance in instances_list)
+
+ if len(instances_list) == 1:
+ return instances_list[0]
+
+ use_keypoint = instances_list[0].keypoints is not None
+ bbox_format = instances_list[0]._bboxes.format
+ normalized = instances_list[0].normalized
+
+ cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
+ cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
+ cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
+ return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
+
+ @property
+ def bboxes(self):
+ return self._bboxes.bboxes
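+
+
+# Usage sketch (illustrative only): a typical augmentation flow denormalizes labels to
+# pixel space, flips them together with the image, then re-normalizes them.
+#   inst = Instances(np.array([[0.5, 0.5, 0.2, 0.4]]), bbox_format="xywh", normalized=True)
+#   inst.denormalize(w=640, h=480)
+#   inst.fliplr(w=640)
+#   inst.normalize(w=640, h=480)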
diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..60dd6b2e7f83806e89976afa18c9360b3bc0b247
--- /dev/null
+++ b/ultralytics/yolo/utils/loss.py
@@ -0,0 +1,55 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .metrics import bbox_iou
+from .tal import bbox2dist
+
+
+class VarifocalLoss(nn.Module):
+ # Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
+ weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
+ with torch.cuda.amp.autocast(enabled=False):
+ loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") *
+ weight).sum()
+ return loss
+
+
+class BboxLoss(nn.Module):
+
+ def __init__(self, reg_max, use_dfl=False):
+ super().__init__()
+ self.reg_max = reg_max
+ self.use_dfl = use_dfl
+
+ def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+ # IoU loss
+ weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
+ iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
+ loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+ # DFL loss
+ if self.use_dfl:
+ target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
+ loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
+ loss_dfl = loss_dfl.sum() / target_scores_sum
+ else:
+ loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+ return loss_iou, loss_dfl
+
+ @staticmethod
+ def _df_loss(pred_dist, target):
+ # Return sum of left and right DFL losses
+ tl = target.long() # target left
+ tr = tl + 1 # target right
+ wl = tr - target # weight left
+ wr = 1 - wl # weight right
+ return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl +
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True)
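+
+
+# Note on _df_loss (illustrative): Distribution Focal Loss treats each box side as a
+# categorical distribution over reg_max + 1 discrete bins. A continuous target t is split
+# between the neighbouring bins floor(t) and floor(t) + 1 with weights (floor(t) + 1 - t)
+# and (t - floor(t)), so e.g. t = 2.3 contributes 0.7 * CE(pred, 2) + 0.3 * CE(pred, 3).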
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a973a969bc0ebc27d6f3a603dc4808dec03a93
--- /dev/null
+++ b/ultralytics/yolo/utils/metrics.py
@@ -0,0 +1,617 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+"""
+Model validation metrics
+"""
+import math
+import warnings
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ultralytics.yolo.utils import TryExcept
+
+
+# boxes
+def box_area(box):
+ # box = xyxy(4,n)
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+
+def bbox_ioa(box1, box2, eps=1e-7):
+ """Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
+ box1: np.array of shape(nx4)
+ box2: np.array of shape(mx4)
+ returns: np.array of shape(nxm)
+ """
+
+ # Get the coordinates of bounding boxes
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
+
+ # Intersection area
+ inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
+ (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
+
+ # box2 area
+ box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
+
+ # Intersection over box2 area
+ return inter_area / box2_area
+
+
+def box_iou(box1, box2, eps=1e-7):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+ (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
+
+ # IoU = inter / (area1 + area2 - inter)
+ return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
+
+
+def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+ # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
+
+ # Get the coordinates of bounding boxes
+ if xywh: # transform from xywh to xyxy
+ (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
+ w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
+ b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
+ b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
+ else: # x1, y1, x2, y2 = box1
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+ # Intersection area
+ inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
+ (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
+
+ # Union Area
+ union = w1 * h1 + w2 * h2 - inter + eps
+
+ # IoU
+ iou = inter / union
+ if CIoU or DIoU or GIoU:
+ cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
+ ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+ c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
+ if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+ v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+ with torch.no_grad():
+ alpha = v / (v - iou + (1 + eps))
+ return iou - (rho2 / c2 + v * alpha) # CIoU
+ return iou - rho2 / c2 # DIoU
+ c_area = cw * ch + eps # convex area
+ return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
+ return iou # IoU
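+
+
+# Usage sketch (illustrative only): element-wise CIoU between matched xyxy boxes.
+#   b1 = torch.tensor([[0., 0., 10., 10.]])
+#   b2 = torch.tensor([[2., 2., 12., 12.]])
+#   bbox_iou(b1, b2, xywh=False, CIoU=True)  # slightly below the plain IoU of ~0.47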
+
+
+def mask_iou(mask1, mask2, eps=1e-7):
+ """
+    mask1: [N, n] N is the number of predicted objects
+    mask2: [M, n] M is the number of ground-truth objects
+ Note: n means image_w x image_h
+ return: masks iou, [N, M]
+ """
+ intersection = torch.matmul(mask1, mask2.t()).clamp(0)
+ union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
+ return intersection / (union + eps)
+
+
+def masks_iou(mask1, mask2, eps=1e-7):
+ """
+    mask1: [N, n] N is the number of predicted objects
+    mask2: [N, n] N is the number of ground-truth objects (paired element-wise with mask1)
+ Note: n means image_w x image_h
+ return: masks iou, (N, )
+ """
+ intersection = (mask1 * mask2).sum(1).clamp(0) # (N, )
+ union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection
+ return intersection / (union + eps)
+
+
+def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
+ # return positive, negative label smoothing BCE targets
+ return 1.0 - 0.5 * eps, 0.5 * eps
+
+
+# losses
+class FocalLoss(nn.Module):
+ # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
+ def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+ super().__init__()
+ self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
+ self.gamma = gamma
+ self.alpha = alpha
+ self.reduction = loss_fcn.reduction
+ self.loss_fcn.reduction = 'none' # required to apply FL to each element
+
+ def forward(self, pred, true):
+ loss = self.loss_fcn(pred, true)
+ # p_t = torch.exp(-loss)
+ # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
+
+ # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+ pred_prob = torch.sigmoid(pred) # prob from logits
+ p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
+ alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+ modulating_factor = (1.0 - p_t) ** self.gamma
+ loss *= alpha_factor * modulating_factor
+
+ if self.reduction == 'mean':
+ return loss.mean()
+ elif self.reduction == 'sum':
+ return loss.sum()
+ else: # 'none'
+ return loss
+
+
+class ConfusionMatrix:
+ # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
+ def __init__(self, nc, conf=0.25, iou_thres=0.45):
+ self.matrix = np.zeros((nc + 1, nc + 1))
+ self.nc = nc # number of classes
+ self.conf = conf
+ self.iou_thres = iou_thres
+
+ def process_batch(self, detections, labels):
+ """
+        Update the confusion matrix with one batch of detections and labels.
+        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ detections (Array[N, 6]), x1, y1, x2, y2, conf, class
+ labels (Array[M, 5]), class, x1, y1, x2, y2
+ Returns:
+ None, updates confusion matrix accordingly
+ """
+ if detections is None:
+ gt_classes = labels.int()
+ for gc in gt_classes:
+ self.matrix[self.nc, gc] += 1 # background FN
+ return
+
+ detections = detections[detections[:, 4] > self.conf]
+ gt_classes = labels[:, 0].int()
+ detection_classes = detections[:, 5].int()
+ iou = box_iou(labels[:, 1:], detections[:, :4])
+
+ x = torch.where(iou > self.iou_thres)
+ if x[0].shape[0]:
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
+ if x[0].shape[0] > 1:
+ matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+ matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+ else:
+ matches = np.zeros((0, 3))
+
+ n = matches.shape[0] > 0
+ m0, m1, _ = matches.transpose().astype(int)
+ for i, gc in enumerate(gt_classes):
+ j = m0 == i
+ if n and sum(j) == 1:
+ self.matrix[detection_classes[m1[j]], gc] += 1 # correct
+ else:
+ self.matrix[self.nc, gc] += 1 # true background
+
+ if n:
+ for i, dc in enumerate(detection_classes):
+ if not any(m1 == i):
+ self.matrix[dc, self.nc] += 1 # predicted background
+
+ def matrix(self):
+ return self.matrix
+
+ def tp_fp(self):
+ tp = self.matrix.diagonal() # true positives
+ fp = self.matrix.sum(1) - tp # false positives
+ # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
+ return tp[:-1], fp[:-1] # remove background class
+
+    @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
+ def plot(self, normalize=True, save_dir='', names=()):
+ import seaborn as sn
+
+ array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
+ array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
+
+ fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
+ nc, nn = self.nc, len(names) # number of classes, names
+ sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
+ labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
+ ticklabels = (names + ['background']) if labels else "auto"
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
+ sn.heatmap(array,
+ ax=ax,
+ annot=nc < 30,
+ annot_kws={
+ "size": 8},
+ cmap='Blues',
+ fmt='.2f',
+ square=True,
+ vmin=0.0,
+ xticklabels=ticklabels,
+ yticklabels=ticklabels).set_facecolor((1, 1, 1))
+        ax.set_xlabel('True')
+        ax.set_ylabel('Predicted')
+ ax.set_title('Confusion Matrix')
+ fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
+ plt.close(fig)
+
+ def print(self):
+ for i in range(self.nc + 1):
+ print(' '.join(map(str, self.matrix[i])))
+
+
+def smooth(y, f=0.05):
+ # Box filter of fraction f
+ nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
+ p = np.ones(nf // 2) # ones padding
+ yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
+ return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
+
+
+def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
+ # Precision-recall curve
+ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+ py = np.stack(py, axis=1)
+
+ if 0 < len(names) < 21: # display per-class legend if < 21 classes
+ for i, y in enumerate(py.T):
+ ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
+ else:
+ ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
+
+ ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
+ ax.set_xlabel('Recall')
+ ax.set_ylabel('Precision')
+ ax.set_xlim(0, 1)
+ ax.set_ylim(0, 1)
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+ ax.set_title('Precision-Recall Curve')
+ fig.savefig(save_dir, dpi=250)
+ plt.close(fig)
+
+
+def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
+ # Metric-confidence curve
+ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+ if 0 < len(names) < 21: # display per-class legend if < 21 classes
+ for i, y in enumerate(py):
+ ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
+ else:
+ ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
+
+ y = smooth(py.mean(0), 0.05)
+ ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
+ ax.set_xlabel(xlabel)
+ ax.set_ylabel(ylabel)
+ ax.set_xlim(0, 1)
+ ax.set_ylim(0, 1)
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+ ax.set_title(f'{ylabel}-Confidence Curve')
+ fig.savefig(save_dir, dpi=250)
+ plt.close(fig)
+
+
+def compute_ap(recall, precision):
+ """ Compute the average precision, given the recall and precision curves
+ # Arguments
+ recall: The recall curve (list)
+ precision: The precision curve (list)
+ # Returns
+ Average precision, precision curve, recall curve
+ """
+
+ # Append sentinel values to beginning and end
+ mrec = np.concatenate(([0.0], recall, [1.0]))
+ mpre = np.concatenate(([1.0], precision, [0.0]))
+
+ # Compute the precision envelope
+ mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
+
+ # Integrate area under curve
+ method = 'interp' # methods: 'continuous', 'interp'
+ if method == 'interp':
+ x = np.linspace(0, 1, 101) # 101-point interp (COCO)
+ ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
+ else: # 'continuous'
+ i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
+
+ return ap, mpre, mrec
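+
+
+# Usage sketch (illustrative only): AP from a toy precision/recall curve.
+#   recall = [0.0, 0.5, 1.0]
+#   precision = [1.0, 0.8, 0.6]
+#   ap, mpre, mrec = compute_ap(recall, precision)  # ap is roughly 0.8 with 101-point interpolation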
+
+
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
+ """ Compute the average precision, given the recall and precision curves.
+ Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
+ # Arguments
+ tp: True positives (nparray, nx1 or nx10).
+ conf: Objectness value from 0-1 (nparray).
+ pred_cls: Predicted object classes (nparray).
+ target_cls: True object classes (nparray).
+ plot: Plot precision-recall curve at mAP@0.5
+ save_dir: Plot save directory
+ # Returns
+ The average precision as computed in py-faster-rcnn.
+ """
+
+ # Sort by objectness
+ i = np.argsort(-conf)
+ tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
+
+ # Find unique classes
+ unique_classes, nt = np.unique(target_cls, return_counts=True)
+    nc = unique_classes.shape[0]  # number of classes
+
+ # Create Precision-Recall curve and compute AP for each class
+ px, py = np.linspace(0, 1, 1000), [] # for plotting
+ ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
+ for ci, c in enumerate(unique_classes):
+ i = pred_cls == c
+ n_l = nt[ci] # number of labels
+ n_p = i.sum() # number of predictions
+ if n_p == 0 or n_l == 0:
+ continue
+
+ # Accumulate FPs and TPs
+ fpc = (1 - tp[i]).cumsum(0)
+ tpc = tp[i].cumsum(0)
+
+ # Recall
+ recall = tpc / (n_l + eps) # recall curve
+ r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
+
+ # Precision
+ precision = tpc / (tpc + fpc) # precision curve
+ p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
+
+ # AP from recall-precision curve
+ for j in range(tp.shape[1]):
+ ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
+ if plot and j == 0:
+ py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
+
+ # Compute F1 (harmonic mean of precision and recall)
+ f1 = 2 * p * r / (p + r + eps)
+ names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
+ names = dict(enumerate(names)) # to dict
+ if plot:
+ plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names)
+ plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1')
+ plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision')
+ plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall')
+
+ i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
+ p, r, f1 = p[:, i], r[:, i], f1[:, i]
+ tp = (r * nt).round() # true positives
+ fp = (tp / (p + eps) - tp).round() # false positives
+ return tp, fp, p, r, f1, ap, unique_classes.astype(int)
+
+
+class Metric:
+
+ def __init__(self) -> None:
+ self.p = [] # (nc, )
+ self.r = [] # (nc, )
+ self.f1 = [] # (nc, )
+ self.all_ap = [] # (nc, 10)
+ self.ap_class_index = [] # (nc, )
+
+ @property
+ def ap50(self):
+ """AP@0.5 of all classes.
+ Return:
+ (nc, ) or [].
+ """
+ return self.all_ap[:, 0] if len(self.all_ap) else []
+
+ @property
+ def ap(self):
+ """AP@0.5:0.95
+ Return:
+ (nc, ) or [].
+ """
+ return self.all_ap.mean(1) if len(self.all_ap) else []
+
+ @property
+ def mp(self):
+ """mean precision of all classes.
+ Return:
+ float.
+ """
+ return self.p.mean() if len(self.p) else 0.0
+
+ @property
+ def mr(self):
+ """mean recall of all classes.
+ Return:
+ float.
+ """
+ return self.r.mean() if len(self.r) else 0.0
+
+ @property
+ def map50(self):
+ """Mean AP@0.5 of all classes.
+ Return:
+ float.
+ """
+ return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
+
+ @property
+ def map(self):
+ """Mean AP@0.5:0.95 of all classes.
+ Return:
+ float.
+ """
+ return self.all_ap.mean() if len(self.all_ap) else 0.0
+
+ def mean_results(self):
+ """Mean of results, return mp, mr, map50, map"""
+ return [self.mp, self.mr, self.map50, self.map]
+
+ def class_result(self, i):
+ """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
+ return self.p[i], self.r[i], self.ap50[i], self.ap[i]
+
+ def get_maps(self, nc):
+ maps = np.zeros(nc) + self.map
+ for i, c in enumerate(self.ap_class_index):
+ maps[c] = self.ap[i]
+ return maps
+
+ def fitness(self):
+ # Model fitness as a weighted combination of metrics
+ w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
+ return (np.array(self.mean_results()) * w).sum()
+
+ def update(self, results):
+ """
+ Args:
+ results: tuple(p, r, ap, f1, ap_class)
+ """
+ self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
+
+
+class DetMetrics:
+
+ def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
+ self.save_dir = save_dir
+ self.plot = plot
+ self.names = names
+ self.metric = Metric()
+
+ def process(self, tp, conf, pred_cls, target_cls):
+ results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
+ names=self.names)[2:]
+ self.metric.update(results)
+
+ @property
+ def keys(self):
+ return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
+
+ def mean_results(self):
+ return self.metric.mean_results()
+
+ def class_result(self, i):
+ return self.metric.class_result(i)
+
+ def get_maps(self, nc):
+ return self.metric.get_maps(nc)
+
+ @property
+ def fitness(self):
+ return self.metric.fitness()
+
+ @property
+ def ap_class_index(self):
+ return self.metric.ap_class_index
+
+ @property
+ def results_dict(self):
+ return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
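+
+
+# Usage sketch (illustrative only): the validator feeds per-prediction arrays into process(),
+# then reads the aggregated results; `names` is expected to be a {class_index: name} dict.
+#   metrics = DetMetrics(names={0: "person", 1: "car"})
+#   metrics.process(tp, conf, pred_cls, target_cls)   # arrays produced during validation
+#   metrics.results_dict                              # {"metrics/precision(B)": ..., "fitness": ...}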
+
+
+class SegmentMetrics:
+
+ def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
+ self.save_dir = save_dir
+ self.plot = plot
+ self.names = names
+ self.metric_box = Metric()
+ self.metric_mask = Metric()
+
+ def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
+ results_mask = ap_per_class(tp_m,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ prefix="Mask")[2:]
+ self.metric_mask.update(results_mask)
+ results_box = ap_per_class(tp_b,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ prefix="Box")[2:]
+ self.metric_box.update(results_box)
+
+ @property
+ def keys(self):
+ return [
+ "metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
+ "metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
+
+ def mean_results(self):
+ return self.metric_box.mean_results() + self.metric_mask.mean_results()
+
+ def class_result(self, i):
+ return self.metric_box.class_result(i) + self.metric_mask.class_result(i)
+
+ def get_maps(self, nc):
+ return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)
+
+ @property
+ def fitness(self):
+ return self.metric_mask.fitness() + self.metric_box.fitness()
+
+ @property
+ def ap_class_index(self):
+ # boxes and masks have the same ap_class_index
+ return self.metric_box.ap_class_index
+
+ @property
+ def results_dict(self):
+ return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+
+
+class ClassifyMetrics:
+
+ def __init__(self) -> None:
+ self.top1 = 0
+ self.top5 = 0
+
+ def process(self, targets, pred):
+ # target classes and predicted classes
+ pred, targets = torch.cat(pred), torch.cat(targets)
+ correct = (targets[:, None] == pred).float()
+ acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
+ self.top1, self.top5 = acc.mean(0).tolist()
+
+ @property
+ def fitness(self):
+ return self.top5
+
+ @property
+ def results_dict(self):
+ return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
+
+ @property
+ def keys(self):
+ return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e5d523618fe506fa2278c0d9eadfdf80964547
--- /dev/null
+++ b/ultralytics/yolo/utils/ops.py
@@ -0,0 +1,674 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import contextlib
+import math
+import re
+import time
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+
+from ultralytics.yolo.utils import LOGGER
+
+from .metrics import box_iou
+
+
+class Profile(contextlib.ContextDecorator):
+ # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
+ def __init__(self, t=0.0):
+ self.t = t
+ self.cuda = torch.cuda.is_available()
+
+ def __enter__(self):
+ self.start = self.time()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.dt = self.time() - self.start # delta-time
+ self.t += self.dt # accumulate dt
+
+ def time(self):
+ if self.cuda:
+ torch.cuda.synchronize()
+ return time.time()
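+
+
+# Usage sketch (illustrative only): time a block of code.
+#   dt = Profile()
+#   with dt:
+#       run_workload()  # placeholder name for any workload
+#   print(f'{dt.dt:.3f}s elapsed')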
+
+
+def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
+ # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
+ # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
+ # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
+ # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
+ return [
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
+ 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
+
+
+def segment2box(segment, width=640, height=640):
+ """
+ > Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to
+ (xyxy)
+ Args:
+ segment: the segment label
+ width: the width of the image. Defaults to 640
+ height: The height of the image. Defaults to 640
+
+ Returns:
+ the minimum and maximum x and y values of the segment.
+ """
+ x, y = segment.T # segment xy
+ inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
+    x, y = x[inside], y[inside]
+ return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4) # xyxy
+
+
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
+ """
+ > Rescale boxes (xyxy) from img1_shape to img0_shape
+ Args:
+ img1_shape: The shape of the image that the bounding boxes are for.
+ boxes: the bounding boxes of the objects in the image
+ img0_shape: the shape of the original image
+ ratio_pad: a tuple of (ratio, pad)
+
+ Returns:
+ The boxes are being returned.
+ """
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ boxes[..., [0, 2]] -= pad[0] # x padding
+ boxes[..., [1, 3]] -= pad[1] # y padding
+ boxes[..., :4] /= gain
+ clip_boxes(boxes, img0_shape)
+ return boxes
+
+
+def make_divisible(x, divisor):
+ # Returns nearest x divisible by divisor
+ if isinstance(divisor, torch.Tensor):
+ divisor = int(divisor.max()) # to int
+ return math.ceil(x / divisor) * divisor
+
+
+def non_max_suppression(
+ prediction,
+ conf_thres=0.25,
+ iou_thres=0.45,
+ classes=None,
+ agnostic=False,
+ multi_label=False,
+ labels=(),
+ max_det=300,
+ nm=0, # number of masks
+):
+ """
+ > Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+
+ Arguments:
+ prediction (torch.Tensor): A tensor of shape (batch_size, num_boxes, num_classes + 4 + num_masks)
+ containing the predicted boxes, classes, and masks. The tensor should be in the format
+ output by a model, such as YOLO.
+ conf_thres (float): The confidence threshold below which boxes will be filtered out.
+ Valid values are between 0.0 and 1.0.
+ iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
+ Valid values are between 0.0 and 1.0.
+ classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
+ agnostic (bool): If True, the model is agnostic to the number of classes, and all
+ classes will be considered as one.
+ multi_label (bool): If True, each box may have multiple labels.
+ labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
+ list contains the apriori labels for a given image. The list should be in the format
+ output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
+ max_det (int): The maximum number of boxes to keep after NMS.
+ nm (int): The number of masks output by the model.
+
+ Returns:
+ List[torch.Tensor]: A list of length batch_size, where each element is a tensor of
+ shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
+ (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
+ """
+
+ # Checks
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
+ prediction = prediction[0] # select only inference output
+
+ device = prediction.device
+ mps = 'mps' in device.type # Apple MPS
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
+ prediction = prediction.cpu()
+ bs = prediction.shape[0] # batch size
+ nc = prediction.shape[1] - nm - 4 # number of classes
+ mi = 4 + nc # mask start index
+ xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
+
+ # Settings
+ # min_wh = 2 # (pixels) minimum box width and height
+ max_wh = 7680 # (pixels) maximum box width and height
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
+ time_limit = 0.5 + 0.05 * bs # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
+ x = x.transpose(0, -1)[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ lb = labels[xi]
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
+ v[:, :4] = lb[:, 1:5] # box
+ v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ box, cls, mask = x.split((4, nc, nm), 1)
+        box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
+ if multi_label:
+ i, j = (cls > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
+ else: # best class only
+ conf, j = cls.max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Apply finite constraint
+ # if not torch.isfinite(x).all():
+ # x = x[torch.isfinite(x).all(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ i = i[:max_det] # limit detections
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if mps:
+ output[xi] = output[xi].to(device)
+ if (time.time() - t) > time_limit:
+            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
+ break # time limit exceeded
+
+ return output
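+
+
+# Usage sketch (illustrative only): filter raw detection-head output of shape
+# (batch, 4 + num_classes + num_masks, num_anchors); `preds` is a placeholder tensor here.
+#   dets = non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45, max_det=300)
+#   # dets[i] is an (n, 6 + nm) tensor of (x1, y1, x2, y2, conf, cls, mask coeffs...) for image i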
+
+
+def clip_boxes(boxes, shape):
+ """
+ > It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the
+ shape
+
+ Args:
+ boxes: the bounding boxes to clip
+ shape: the shape of the image
+ """
+ if isinstance(boxes, torch.Tensor): # faster individually
+ boxes[..., 0].clamp_(0, shape[1]) # x1
+ boxes[..., 1].clamp_(0, shape[0]) # y1
+ boxes[..., 2].clamp_(0, shape[1]) # x2
+ boxes[..., 3].clamp_(0, shape[0]) # y2
+ else: # np.array (faster grouped)
+ boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
+ boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
+
+
+def clip_coords(boxes, shape):
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
+ if isinstance(boxes, torch.Tensor): # faster individually
+ boxes[:, 0].clamp_(0, shape[1]) # x1
+ boxes[:, 1].clamp_(0, shape[0]) # y1
+ boxes[:, 2].clamp_(0, shape[1]) # x2
+ boxes[:, 3].clamp_(0, shape[0]) # y2
+ else: # np.array (faster grouped)
+ boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
+ boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
+
+
+def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
+ """
+ > It takes a mask, and resizes it to the original image size
+
+ Args:
+ im1_shape: model input shape, [h, w]
+ masks: [h, w, num]
+ im0_shape: the original image shape
+ ratio_pad: the ratio of the padding to the original image.
+
+ Returns:
+ The masks are being returned.
+ """
+ # Rescale coordinates (xyxy) from im1_shape to im0_shape
+ if ratio_pad is None: # calculate from im0_shape
+ gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
+ pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
+ else:
+ pad = ratio_pad[1]
+ top, left = int(pad[1]), int(pad[0]) # y, x
+ bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
+
+ if len(masks.shape) < 2:
+        raise ValueError(f'masks should have 2 or 3 dimensions, but got {len(masks.shape)}')
+ masks = masks[top:bottom, left:right]
+ # masks = masks.permute(2, 0, 1).contiguous()
+ # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
+ # masks = masks.permute(1, 2, 0).contiguous()
+ masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
+
+ if len(masks.shape) == 2:
+ masks = masks[:, :, None]
+ return masks
+
+
+def xyxy2xywh(x):
+ """
+ > It takes a list of bounding boxes, and converts them from the format [x1, y1, x2, y2] to [x, y, w,
+ h] where xy1=top-left, xy2=bottom-right
+
+ Args:
+ x: the input tensor
+
+ Returns:
+ the center of the box, the width and the height of the box.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
+ y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
+ y[..., 2] = x[..., 2] - x[..., 0] # width
+ y[..., 3] = x[..., 3] - x[..., 1] # height
+ return y
+
+
+def xywh2xyxy(x):
+ """
+ > It converts the bounding box from x,y,w,h to x1,y1,x2,y2 where xy1=top-left, xy2=bottom-right
+
+ Args:
+ x: the input tensor
+
+ Returns:
+ the top left and bottom right coordinates of the bounding box.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
+ y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
+ y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
+ y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
+ return y
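+
+
+# Worked example (illustrative): a centre-format box (cx=30, cy=50, w=40, h=60)
+# maps to corners (10, 20, 50, 80):
+#   xywh2xyxy(np.array([[30., 50., 40., 60.]]))  # -> [[10., 20., 50., 80.]]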
+
+
+def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+ """
+ > It converts the normalized coordinates to the actual coordinates [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+
+ Args:
+ x: the bounding box coordinates
+ w: width of the image. Defaults to 640
+ h: height of the image. Defaults to 640
+ padw: padding width. Defaults to 0
+ padh: height of the padding. Defaults to 0
+
+ Returns:
+ the xyxy coordinates of the bounding box.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
+ return y
+
+
+def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+ """
+ > It takes in a list of bounding boxes, and returns a list of bounding boxes, but with the x and y
+ coordinates normalized to the width and height of the image
+
+ Args:
+ x: the bounding box coordinates
+ w: width of the image. Defaults to 640
+ h: height of the image. Defaults to 640
+ clip: If True, the boxes will be clipped to the image boundaries. Defaults to False
+ eps: the minimum value of the box's width and height.
+
+ Returns:
+ the xywhn format of the bounding boxes.
+ """
+ if clip:
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
+ return y
+
+
+def xyn2xy(x, w=640, h=640, padw=0, padh=0):
+ """
+ > It converts normalized segments into pixel segments of shape (n,2)
+
+ Args:
+ x: the normalized coordinates of the bounding box
+ w: width of the image. Defaults to 640
+ h: height of the image. Defaults to 640
+ padw: padding width. Defaults to 0
+ padh: padding height. Defaults to 0
+
+ Returns:
+ the x and y coordinates of the top left corner of the bounding box.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = w * x[..., 0] + padw # top left x
+ y[..., 1] = h * x[..., 1] + padh # top left y
+ return y
+
+
+def xywh2ltwh(x):
+ """
+ > It converts the bounding box from [x, y, w, h] to [x1, y1, w, h] where xy1=top-left
+
+ Args:
+        x: the input tensor with boxes in xywh format
+
+ Returns:
+ the top left x and y coordinates of the bounding box.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ return y
+
+
+def xyxy2ltwh(x):
+ """
+ > Convert nx4 boxes from [x1, y1, x2, y2] to [x1, y1, w, h] where xy1=top-left, xy2=bottom-right
+
+ Args:
+ x: the input tensor
+
+ Returns:
+        the ltwh coordinates of the bounding boxes.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 2] = x[:, 2] - x[:, 0] # width
+ y[:, 3] = x[:, 3] - x[:, 1] # height
+ return y
+
+
+def ltwh2xywh(x):
+ """
+ > Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
+
+ Args:
+ x: the input tensor
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
+ y[:, 1] = x[:, 1] + x[:, 3] / 2 # center y
+ return y
+
+
+def ltwh2xyxy(x):
+ """
+ > It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left,
+ xy2=bottom-right
+
+ Args:
+        x: the input tensor
+
+ Returns:
+ the xyxy coordinates of the bounding boxes.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 2] = x[:, 2] + x[:, 0]  # bottom right x
+    y[:, 3] = x[:, 3] + x[:, 1]  # bottom right y
+ return y
+
+
+def segments2boxes(segments):
+ """
+ > It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
+
+ Args:
+ segments: list of segments, each segment is a list of points, each point is a list of x, y
+ coordinates
+
+ Returns:
+ the xywh coordinates of the bounding boxes.
+ """
+ boxes = []
+ for s in segments:
+ x, y = s.T # segment xy
+ boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
+ return xyxy2xywh(np.array(boxes)) # cls, xywh
+
+
+def resample_segments(segments, n=1000):
+ """
+ > It takes a list of segments (n,2) and returns a list of segments (n,2) where each segment has been
+ up-sampled to n points
+
+ Args:
+ segments: a list of (n,2) arrays, where n is the number of points in the segment.
+ n: number of points to resample the segment to. Defaults to 1000
+
+ Returns:
+ the resampled segments.
+ """
+ for i, s in enumerate(segments):
+ s = np.concatenate((s, s[0:1, :]), axis=0)
+ x = np.linspace(0, len(s) - 1, n)
+ xp = np.arange(len(s))
+ segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
+ return segments
+
+
+def crop_mask(masks, boxes):
+ """
+ > It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
+
+ Args:
+ masks: [h, w, n] tensor of masks
+ boxes: [n, 4] tensor of bbox coords in relative point form
+
+ Returns:
+ The masks are being cropped to the bounding box.
+ """
+ n, h, w = masks.shape
+ x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n)
+    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
+    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
+
+ return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
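+
+
+# Usage sketch (illustrative only): zero out mask pixels outside each box.
+#   masks = torch.ones(2, 4, 4)
+#   boxes = torch.tensor([[0., 0., 2., 2.], [1., 1., 3., 3.]])
+#   crop_mask(masks, boxes)  # each mask keeps only its own 2x2 box region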
+
+
+def process_mask_upsample(protos, masks_in, bboxes, shape):
+ """
+ > It takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
+ quality but is slower.
+
+ Args:
+ protos: [mask_dim, mask_h, mask_w]
+ masks_in: [n, mask_dim], n is number of masks after nms
+ bboxes: [n, 4], n is number of masks after nms
+ shape: the size of the input image
+
+ Returns:
+ mask
+ """
+ c, mh, mw = protos.shape # CHW
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ masks = crop_mask(masks, bboxes) # CHW
+ return masks.gt_(0.5)
+
+
+def process_mask(protos, masks_in, bboxes, shape, upsample=False):
+ """
+    > It takes the output of the mask head, and applies the mask to the bounding boxes. This is faster but
+    produces lower-quality (downsampled) masks.
+
+ Args:
+ protos: [mask_dim, mask_h, mask_w]
+ masks_in: [n, mask_dim], n is number of masks after nms
+ bboxes: [n, 4], n is number of masks after nms
+ shape: the size of the input image
+
+ Returns:
+ mask
+ """
+
+ c, mh, mw = protos.shape # CHW
+ ih, iw = shape
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
+
+ downsampled_bboxes = bboxes.clone()
+ downsampled_bboxes[:, 0] *= mw / iw
+ downsampled_bboxes[:, 2] *= mw / iw
+ downsampled_bboxes[:, 3] *= mh / ih
+ downsampled_bboxes[:, 1] *= mh / ih
+
+ masks = crop_mask(masks, downsampled_bboxes) # CHW
+ if upsample:
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ return masks.gt_(0.5)
+
+
+def process_mask_native(protos, masks_in, bboxes, shape):
+ """
+ > It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
+
+ Args:
+ protos: [mask_dim, mask_h, mask_w]
+ masks_in: [n, mask_dim], n is number of masks after nms
+ bboxes: [n, 4], n is number of masks after nms
+ shape: input_image_size, (h, w)
+
+ Returns:
+ masks: [h, w, n]
+ """
+ c, mh, mw = protos.shape # CHW
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+ gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
+ pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
+ top, left = int(pad[1]), int(pad[0]) # y, x
+ bottom, right = int(mh - pad[1]), int(mw - pad[0])
+ masks = masks[:, top:bottom, left:right]
+
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ masks = crop_mask(masks, bboxes) # CHW
+ return masks.gt_(0.5)
+
+
+def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
+ """
+ > Rescale segment coords (xyxy) from img1_shape to img0_shape
+
+ Args:
+ img1_shape: The shape of the image that the segments are from.
+ segments: the segments to be scaled
+ img0_shape: the shape of the image that the segmentation is being applied to
+ ratio_pad: the ratio of the image size to the padded image size.
+ normalize: If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
+
+ Returns:
+ the segmented image.
+ """
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ segments[:, 0] -= pad[0] # x padding
+ segments[:, 1] -= pad[1] # y padding
+ segments /= gain
+ clip_segments(segments, img0_shape)
+ if normalize:
+ segments[:, 0] /= img0_shape[1] # width
+ segments[:, 1] /= img0_shape[0] # height
+ return segments
+
+
+def masks2segments(masks, strategy='largest'):
+ """
+ > It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
+
+ Args:
+ masks: the output of the model, which is a tensor of shape (batch_size, 160, 160)
+ strategy: 'concat' or 'largest'. Defaults to largest
+
+ Returns:
+ segments (List): list of segment masks
+ """
+ segments = []
+ for x in masks.int().cpu().numpy().astype('uint8'):
+ c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+ if c:
+ if strategy == 'concat': # concatenate all segments
+ c = np.concatenate([x.reshape(-1, 2) for x in c])
+ elif strategy == 'largest': # select largest segment
+ c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+ else:
+ c = np.zeros((0, 2)) # no segments found
+ segments.append(c.astype('float32'))
+ return segments
+
+
+def clip_segments(segments, shape):
+ """
+    > It clips segment coordinates (x, y) to the image shape (height, width)
+
+ Args:
+ segments: a list of segments, each segment is a list of points, each point is a list of x,y
+ coordinates
+ shape: the shape of the image
+ """
+ if isinstance(segments, torch.Tensor): # faster individually
+ segments[:, 0].clamp_(0, shape[1]) # x
+ segments[:, 1].clamp_(0, shape[0]) # y
+ else: # np.array (faster grouped)
+ segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
+ segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
+
+
+def clean_str(s):
+ # Cleans a string by replacing special characters with underscore _
+    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..f03a22a5746a1d04ee126cdccb95e26303d2c1af
--- /dev/null
+++ b/ultralytics/yolo/utils/plotting.py
@@ -0,0 +1,319 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import contextlib
+import math
+from pathlib import Path
+from urllib.error import URLError
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+from PIL import Image, ImageDraw, ImageFont
+
+from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
+
+from .checks import check_font, check_requirements, is_ascii
+from .files import increment_path
+from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
+
+
+class Colors:
+ # Ultralytics color palette https://ultralytics.com/
+ def __init__(self):
+ # hex = matplotlib.colors.TABLEAU_COLORS.values()
+ hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+ '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+ self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
+ self.n = len(self.palette)
+
+ def __call__(self, i, bgr=False):
+ c = self.palette[int(i) % self.n]
+ return (c[2], c[1], c[0]) if bgr else c
+
+ @staticmethod
+ def hex2rgb(h): # rgb order (PIL)
+ return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
+
+
+colors = Colors() # create instance for 'from utils.plots import colors'
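+
+
+# Usage sketch (illustrative only): stable per-class colours for drawing.
+#   colors(0)            # -> (255, 56, 56) RGB for class index 0
+#   colors(0, bgr=True)  # -> (56, 56, 255) for OpenCV drawing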
+
+
+class Annotator:
+ # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
+ def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
+ assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
+ non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
+ self.pil = pil or non_ascii
+ if self.pil: # use PIL
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
+ self.draw = ImageDraw.Draw(self.im)
+ self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
+ size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
+ else: # use cv2
+ self.im = im
+ self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
+
+ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
+ # Add one xyxy box to image with label
+ if self.pil or not is_ascii(label):
+ self.draw.rectangle(box, width=self.lw, outline=color) # box
+ if label:
+ w, h = self.font.getsize(label) # text width, height
+ outside = box[1] - h >= 0 # label fits outside box
+ self.draw.rectangle(
+ (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
+ box[1] + 1 if outside else box[1] + h + 1),
+ fill=color,
+ )
+ # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
+ self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
+ else: # cv2
+ p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+ cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
+ if label:
+ tf = max(self.lw - 1, 1) # font thickness
+ w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
+ outside = p1[1] - h >= 3
+ p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+ cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
+ cv2.putText(self.im,
+ label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+ 0,
+ self.lw / 3,
+ txt_color,
+ thickness=tf,
+ lineType=cv2.LINE_AA)
+
+ def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
+ """Plot masks at once.
+ Args:
+ masks (tensor): predicted masks on cuda, shape: [n, h, w]
+ colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
+ im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
+ alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
+ """
+ if self.pil:
+ # convert to numpy first
+ self.im = np.asarray(self.im).copy()
+ if len(masks) == 0:
+ self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+ colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
+ colors = colors[:, None, None] # shape(n,1,1,3)
+ masks = masks.unsqueeze(3) # shape(n,h,w,1)
+ masks_color = masks * (colors * alpha) # shape(n,h,w,3)
+
+ inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
+ mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
+
+ im_gpu = im_gpu.flip(dims=[0]) # flip channel
+ im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
+ im_gpu = im_gpu * inv_alph_masks[-1] + mcs
+ im_mask = (im_gpu * 255)
+ im_mask_np = im_mask.byte().cpu().numpy()
+ self.im[:] = im_mask_np if retina_masks else scale_image(im_gpu.shape, im_mask_np, self.im.shape)
+ if self.pil:
+ # convert im back to PIL and update draw
+ self.fromarray(self.im)
+
+ def rectangle(self, xy, fill=None, outline=None, width=1):
+ # Add rectangle to image (PIL-only)
+ self.draw.rectangle(xy, fill, outline, width)
+
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
+ # Add text to image (PIL-only)
+ if anchor == 'bottom': # start y from font bottom
+ w, h = self.font.getsize(text) # text width, height
+ xy[1] += 1 - h
+ self.draw.text(xy, text, fill=txt_color, font=self.font)
+
+ def fromarray(self, im):
+ # Update self.im from a numpy array
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
+ self.draw = ImageDraw.Draw(self.im)
+
+ def result(self):
+ # Return annotated image as array
+ return np.asarray(self.im)
+
+
+def check_pil_font(font=FONT, size=10):
+ # Return a PIL TrueType Font, downloading to USER_CONFIG_DIR if necessary
+ font = Path(font)
+ font = font if font.exists() else (USER_CONFIG_DIR / font.name)
+ try:
+ return ImageFont.truetype(str(font) if font.exists() else font.name, size)
+ except Exception: # download if missing
+ try:
+ check_font(font)
+ return ImageFont.truetype(str(font), size)
+ except TypeError:
+ check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
+ except URLError: # not online
+ return ImageFont.load_default()
+
+
+def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
+ # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
+ xyxy = torch.tensor(xyxy).view(-1, 4)
+ b = xyxy2xywh(xyxy) # boxes
+ if square:
+ b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
+ b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
+ xyxy = xywh2xyxy(b).long()
+ clip_coords(xyxy, im.shape)
+ crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
+ if save:
+ file.parent.mkdir(parents=True, exist_ok=True) # make directory
+ f = str(increment_path(file).with_suffix('.jpg'))
+ # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
+ Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
+ return crop
+
+
+@threaded
+def plot_images(images,
+ batch_idx,
+ cls,
+ bboxes,
+ masks=np.zeros(0, dtype=np.uint8),
+ paths=None,
+ fname='images.jpg',
+ names=None):
+ # Plot image grid with labels
+ if isinstance(images, torch.Tensor):
+ images = images.cpu().float().numpy()
+ if isinstance(cls, torch.Tensor):
+ cls = cls.cpu().numpy()
+ if isinstance(bboxes, torch.Tensor):
+ bboxes = bboxes.cpu().numpy()
+ if isinstance(masks, torch.Tensor):
+ masks = masks.cpu().numpy().astype(int)
+ if isinstance(batch_idx, torch.Tensor):
+ batch_idx = batch_idx.cpu().numpy()
+
+ max_size = 1920 # max image size
+ max_subplots = 16 # max image subplots, i.e. 4x4
+ bs, _, h, w = images.shape # batch size, _, height, width
+ bs = min(bs, max_subplots) # limit plot images
+ ns = np.ceil(bs ** 0.5) # number of subplots (square)
+ if np.max(images[0]) <= 1:
+ images *= 255 # de-normalise (optional)
+
+ # Build Image
+ mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
+ for i, im in enumerate(images):
+ if i == max_subplots: # stop once the subplot limit is reached
+ break
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
+ im = im.transpose(1, 2, 0)
+ mosaic[y:y + h, x:x + w, :] = im
+
+ # Resize (optional)
+ scale = max_size / ns / max(h, w)
+ if scale < 1:
+ h = math.ceil(scale * h)
+ w = math.ceil(scale * w)
+ mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
+
+ # Annotate
+ fs = int((h + w) * ns * 0.01) # font size
+ annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+ for i in range(i + 1):
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
+ annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
+ if paths:
+ annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
+ if len(cls) > 0:
+ idx = batch_idx == i
+
+ boxes = xywh2xyxy(bboxes[idx, :4]).T
+ classes = cls[idx].astype('int')
+ labels = bboxes.shape[1] == 4 # labels if no conf column
+ conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred)
+
+ if boxes.shape[1]:
+ if boxes.max() <= 1.01: # if normalized with tolerance 0.01
+ boxes[[0, 2]] *= w # scale to pixels
+ boxes[[1, 3]] *= h
+ elif scale < 1: # absolute coords need scale if image scales
+ boxes *= scale
+ boxes[[0, 2]] += x
+ boxes[[1, 3]] += y
+ for j, box in enumerate(boxes.T.tolist()):
+ c = classes[j]
+ color = colors(c)
+ c = names[c] if names else c
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
+ label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
+ annotator.box_label(box, label, color=color)
+
+ # Plot masks
+ if len(masks):
+ if masks.max() > 1.0: # means that masks overlap (each pixel stores a mask index)
+ image_masks = masks[[i]] # (1, 640, 640)
+ nl = idx.sum()
+ index = np.arange(nl).reshape(nl, 1, 1) + 1
+ image_masks = np.repeat(image_masks, nl, axis=0)
+ image_masks = np.where(image_masks == index, 1.0, 0.0)
+ else:
+ image_masks = masks[idx]
+
+ im = np.asarray(annotator.im).copy()
+ for j, box in enumerate(boxes.T.tolist()):
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
+ color = colors(classes[j])
+ mh, mw = image_masks[j].shape
+ if mh != h or mw != w:
+ mask = image_masks[j].astype(np.uint8)
+ mask = cv2.resize(mask, (w, h))
+ mask = mask.astype(bool)
+ else:
+ mask = image_masks[j].astype(bool)
+ with contextlib.suppress(Exception):
+ im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
+ annotator.fromarray(im)
+ annotator.im.save(fname) # save
+
+
+def plot_results(file='path/to/results.csv', dir='', segment=False):
+ # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+ save_dir = Path(file).parent if file else Path(dir)
+ if segment:
+ fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
+ index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
+ else:
+ fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
+ index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
+ ax = ax.ravel()
+ files = list(save_dir.glob('results*.csv'))
+ assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
+ for f in files:
+ try:
+ data = pd.read_csv(f)
+ s = [x.strip() for x in data.columns]
+ x = data.values[:, 0]
+ for i, j in enumerate(index):
+ y = data.values[:, j].astype('float')
+ # y[y == 0] = np.nan # don't show zero values
+ ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
+ ax[i].set_title(s[j], fontsize=12)
+ # if j in [8, 9, 10]: # share train and val loss y axes
+ # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
+ except Exception as e:
+ print(f'Warning: Plotting error for {f}: {e}')
+ ax[1].legend()
+ fig.savefig(save_dir / 'results.png', dpi=200)
+ plt.close()
+
+
+def output_to_target(output, max_det=300):
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
+ targets = []
+ for i, o in enumerate(output):
+ box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
+ j = torch.full((conf.shape[0], 1), i)
+ targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
+ targets = torch.cat(targets, 0).numpy()
+ return targets[:, 0], targets[:, 1], targets[:, 2:]
diff --git a/ultralytics/yolo/utils/tal.py b/ultralytics/yolo/utils/tal.py
new file mode 100644
index 0000000000000000000000000000000000000000..98481ad53af41e431d2809281cf253462bc8cc1e
--- /dev/null
+++ b/ultralytics/yolo/utils/tal.py
@@ -0,0 +1,211 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .checks import check_version
+from .metrics import bbox_iou
+
+TORCH_1_10 = check_version(torch.__version__, '1.10.0')
+
+
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+ """select the positive anchor center in gt
+
+ Args:
+ xy_centers (Tensor): shape(h*w, 4)
+ gt_bboxes (Tensor): shape(b, n_boxes, 4)
+ Return:
+ (Tensor): shape(b, n_boxes, h*w)
+ """
+ n_anchors = xy_centers.shape[0]
+ bs, n_boxes, _ = gt_bboxes.shape
+ lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
+ bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
+ # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
+ return bbox_deltas.amin(3).gt_(eps)
+
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+ """if an anchor box is assigned to multiple gts,
+ the one with the highest iou will be selected.
+
+ Args:
+ mask_pos (Tensor): shape(b, n_max_boxes, h*w)
+ overlaps (Tensor): shape(b, n_max_boxes, h*w)
+ Return:
+ target_gt_idx (Tensor): shape(b, h*w)
+ fg_mask (Tensor): shape(b, h*w)
+ mask_pos (Tensor): shape(b, n_max_boxes, h*w)
+ """
+ # (b, n_max_boxes, h*w) -> (b, h*w)
+ fg_mask = mask_pos.sum(-2)
+ if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
+ mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w)
+ max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
+ is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) # (b, h*w, n_max_boxes)
+ is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w)
+ mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w)
+ fg_mask = mask_pos.sum(-2)
+ # find which gt each grid cell (anchor) is assigned to (index)
+ target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
+ return target_gt_idx, fg_mask, mask_pos
+
+
+class TaskAlignedAssigner(nn.Module):
+
+ def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
+ super().__init__()
+ self.topk = topk
+ self.num_classes = num_classes
+ self.bg_idx = num_classes
+ self.alpha = alpha
+ self.beta = beta
+ self.eps = eps
+
+ @torch.no_grad()
+ def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
+ """This code referenced to
+ https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
+
+ Args:
+ pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+ pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+ anc_points (Tensor): shape(num_total_anchors, 2)
+ gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+ gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+ mask_gt (Tensor): shape(bs, n_max_boxes, 1)
+ Returns:
+ target_labels (Tensor): shape(bs, num_total_anchors)
+ target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+ target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+ fg_mask (Tensor): shape(bs, num_total_anchors)
+ """
+ self.bs = pd_scores.size(0)
+ self.n_max_boxes = gt_bboxes.size(1)
+
+ if self.n_max_boxes == 0:
+ device = gt_bboxes.device
+ return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
+ torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
+ torch.zeros_like(pd_scores[..., 0]).to(device))
+
+ mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
+ mask_gt)
+
+ target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
+
+ # assigned target
+ target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+ # normalize
+ align_metric *= mask_pos
+ pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj
+ pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj
+ norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+ target_scores = target_scores * norm_align_metric
+
+ return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+ def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
+ # get anchor_align metric, (b, max_num_obj, h*w)
+ align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
+ # get in_gts mask, (b, max_num_obj, h*w)
+ mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+ # get topk_metric mask, (b, max_num_obj, h*w)
+ mask_topk = self.select_topk_candidates(align_metric * mask_in_gts,
+ topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
+ # merge all mask to a final mask, (b, max_num_obj, h*w)
+ mask_pos = mask_topk * mask_in_gts * mask_gt
+
+ return mask_pos, align_metric, overlaps
+
+ def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
+ ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
+ ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj
+ ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj
+ # get the scores of each grid for each gt cls
+ bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w
+
+ overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, CIoU=True).squeeze(3).clamp(0)
+ align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+ return align_metric, overlaps
+
+ def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
+ """
+ Args:
+ metrics: (b, max_num_obj, h*w).
+ topk_mask: (b, max_num_obj, topk) or None
+ """
+
+ num_anchors = metrics.shape[-1] # h*w
+ # (b, max_num_obj, topk)
+ topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
+ if topk_mask is None:
+ topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk]) # compare max values, not the (values, indices) tuple
+ # (b, max_num_obj, topk)
+ topk_idxs = torch.where(topk_mask, topk_idxs, 0)
+ # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+ is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+ # filter invalid bboxes
+ is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
+ return is_in_topk.to(metrics.dtype)
+
+ def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+ """
+ Args:
+ gt_labels: (b, max_num_obj, 1)
+ gt_bboxes: (b, max_num_obj, 4)
+ target_gt_idx: (b, h*w)
+ fg_mask: (b, h*w)
+ """
+
+ # assigned target labels, (b, 1)
+ batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+ target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
+ target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
+
+ # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+ target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+ # assigned target scores
+ target_labels.clamp_(0) # in-place clamp to valid (non-negative) class indices
+ target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80)
+ fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
+ target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+ return target_labels, target_bboxes, target_scores
+
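+# Usage sketch (shapes follow the forward() docstring; topk/alpha/beta values are illustrative):
+#   assigner = TaskAlignedAssigner(topk=10, num_classes=80, alpha=0.5, beta=6.0)
+#   target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = assigner(
+#       pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)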
+
+def make_anchors(feats, strides, grid_cell_offset=0.5):
+ """Generate anchors from features."""
+ anchor_points, stride_tensor = [], []
+ assert feats is not None
+ dtype, device = feats[0].dtype, feats[0].device
+ for i, stride in enumerate(strides):
+ _, _, h, w = feats[i].shape
+ sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
+ sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
+ sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
+ anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
+ stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
+ return torch.cat(anchor_points), torch.cat(stride_tensor)
+
+
+def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
+ """Transform distance(ltrb) to box(xywh or xyxy)."""
+ lt, rb = torch.split(distance, 2, dim)
+ x1y1 = anchor_points - lt
+ x2y2 = anchor_points + rb
+ if xywh:
+ c_xy = (x1y1 + x2y2) / 2
+ wh = x2y2 - x1y1
+ return torch.cat((c_xy, wh), dim) # xywh bbox
+ return torch.cat((x1y1, x2y2), dim) # xyxy bbox
+
+
+def bbox2dist(anchor_points, bbox, reg_max):
+ """Transform bbox(xyxy) to dist(ltrb)."""
+ x1y1, x2y2 = torch.split(bbox, 2, -1)
+ return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp(0, reg_max - 0.01) # dist (lt, rb)
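+
+# Round-trip sketch: bbox2dist followed by dist2bbox(xywh=False) recovers the original xyxy
+# boxes for distances within the (0, reg_max - 0.01) clamp range, e.g.:
+#   dist = bbox2dist(anchor_points, boxes_xyxy, reg_max=16)
+#   boxes_xyxy_back = dist2bbox(dist, anchor_points, xywh=False)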
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e06224c1f81d678403be7f53cd4d27630541d5d
--- /dev/null
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -0,0 +1,369 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import math
+import os
+import platform
+import random
+import time
+from contextlib import contextmanager
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+import thop
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import ultralytics
+from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
+from ultralytics.yolo.utils.checks import git_describe
+
+from .checks import check_version
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+ # Context manager to make all processes in distributed training wait for each local master to do something
+ initialized = torch.distributed.is_initialized() # prevent 'Default process group has not been initialized' errors
+ if initialized and local_rank not in {-1, 0}:
+ dist.barrier(device_ids=[local_rank])
+ yield
+ if initialized and local_rank == 0:
+ dist.barrier(device_ids=[0])
+
+
+def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
+ # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
+ def decorate(fn):
+ return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
+
+ return decorate
+
+
+def DDP_model(model):
+ # Model DDP creation with checks
+ assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+ 'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+ 'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+ if check_version(torch.__version__, '1.11.0'):
+ return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+ else:
+ return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+
+
+def select_device(device='', batch_size=0, newline=False):
+ # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
+ ver = git_describe() or ultralytics.__version__ # git commit or pip package version
+ s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
+ device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
+ cpu = device == 'cpu'
+ mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
+ if cpu or mps:
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
+ elif device: # non-cpu device requested
+ os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
+ assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
+ f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
+
+ if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
+ devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
+ n = len(devices) # device count
+ if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
+ assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+ space = ' ' * (len(s) + 1)
+ for i, d in enumerate(devices):
+ p = torch.cuda.get_device_properties(i)
+ s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
+ arg = 'cuda:0'
+ elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available(): # prefer MPS if available
+ s += 'MPS\n'
+ arg = 'mps'
+ else: # revert to CPU
+ s += 'CPU\n'
+ arg = 'cpu'
+
+ if RANK == -1:
+ LOGGER.info(s if newline else s.rstrip())
+ return torch.device(arg)
+
+
+def time_sync():
+ # PyTorch-accurate time
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.time()
+
+
+def fuse_conv_and_bn(conv, bn):
+ # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+ fusedconv = nn.Conv2d(conv.in_channels,
+ conv.out_channels,
+ kernel_size=conv.kernel_size,
+ stride=conv.stride,
+ padding=conv.padding,
+ dilation=conv.dilation,
+ groups=conv.groups,
+ bias=True).requires_grad_(False).to(conv.weight.device)
+
+ # Prepare filters
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+
+ # Prepare spatial bias
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+ return fusedconv
+
+
+def model_info(model, verbose=False, imgsz=640):
+ # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
+ n_p = get_num_params(model)
+ n_g = get_num_gradients(model) # number gradients
+ if verbose:
+ print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
+ for i, (name, p) in enumerate(model.named_parameters()):
+ name = name.replace('module_list.', '')
+ print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
+ (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
+
+ flops = get_flops(model, imgsz)
+ fs = f', {flops:.1f} GFLOPs' if flops else ''
+ m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
+ LOGGER.info(f"{m} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
+
+
+def get_num_params(model):
+ return sum(x.numel() for x in model.parameters())
+
+
+def get_num_gradients(model):
+ return sum(x.numel() for x in model.parameters() if x.requires_grad)
+
+
+def get_flops(model, imgsz=640):
+ try:
+ model = de_parallel(model)
+ p = next(model.parameters())
+ stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
+ im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
+ flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
+ imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
+ flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
+ return flops
+ except Exception:
+ return 0
+
+
+def initialize_weights(model):
+ for m in model.modules():
+ t = type(m)
+ if t is nn.Conv2d:
+ pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif t is nn.BatchNorm2d:
+ m.eps = 1e-3
+ m.momentum = 0.03
+ elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+ m.inplace = True
+
+
+def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
+ # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
+ if ratio == 1.0:
+ return img
+ h, w = img.shape[2:]
+ s = (int(h * ratio), int(w * ratio)) # new size
+ img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
+ if not same_shape: # pad/crop img
+ h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
+ return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
+
+
+def make_divisible(x, divisor):
+ # Returns nearest x divisible by divisor
+ if isinstance(divisor, torch.Tensor):
+ divisor = int(divisor.max()) # to int
+ return math.ceil(x / divisor) * divisor
+
+
+def copy_attr(a, b, include=(), exclude=()):
+ # Copy attributes from b to a, options to only include [...] and to exclude [...]
+ for k, v in b.__dict__.items():
+ if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+ continue
+ else:
+ setattr(a, k, v)
+
+
+def intersect_dicts(da, db, exclude=()):
+ # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
+ return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
+
+
+def is_parallel(model):
+ # Returns True if model is of type DP or DDP
+ return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
+def de_parallel(model):
+ # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+ return model.module if is_parallel(model) else model
+
+
+def one_cycle(y1=0.0, y2=1.0, steps=100):
+ # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
+ return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
+
+
+def init_seeds(seed=0, deterministic=False):
+ # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
+ # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
+ if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
+ torch.use_deterministic_algorithms(True)
+ torch.backends.cudnn.deterministic = True
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+ os.environ['PYTHONHASHSEED'] = str(seed)
+
+
+class ModelEMA:
+ """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+ Keeps a moving average of everything in the model state_dict (parameters and buffers)
+ For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+ """
+
+ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+ # Create EMA
+ self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
+ self.updates = updates # number of EMA updates
+ self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
+ for p in self.ema.parameters():
+ p.requires_grad_(False)
+
+ def update(self, model):
+ # Update EMA parameters
+ self.updates += 1
+ d = self.decay(self.updates)
+
+ msd = de_parallel(model).state_dict() # model state_dict
+ for k, v in self.ema.state_dict().items():
+ if v.dtype.is_floating_point: # true for FP16 and FP32
+ v *= d
+ v += (1 - d) * msd[k].detach()
+ # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+ def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+ # Update EMA attributes
+ copy_attr(self.ema, model, include, exclude)
+
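+# Typical training-loop usage (a sketch; surrounding names are illustrative):
+#   ema = ModelEMA(model)
+#   for batch in dataloader:
+#       ...  # forward, loss.backward(), optimizer.step()
+#       ema.update(model)  # refresh EMA weights after each optimizer step
+#   ema.update_attr(model)  # optionally sync non-tensor attributes at epoch end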
+
+def strip_optimizer(f='best.pt', s=''):
+ """
+ Strip optimizer from 'f' to finalize training, optionally save as 's'.
+
+ Usage:
+ from ultralytics.yolo.utils.torch_utils import strip_optimizer
+ from pathlib import Path
+ for f in Path('/Users/glennjocher/Downloads/weights').glob('*.pt'):
+ strip_optimizer(f)
+
+ Args:
+ f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
+ s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided, the original file will be overwritten.
+
+ Returns:
+ None
+ """
+ x = torch.load(f, map_location=torch.device('cpu'))
+ args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
+ if x.get('ema'):
+ x['model'] = x['ema'] # replace model with ema
+ for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
+ x[k] = None
+ x['epoch'] = -1
+ x['model'].half() # to FP16
+ for p in x['model'].parameters():
+ p.requires_grad = False
+ x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
+ torch.save(x, s or f)
+ mb = os.path.getsize(s or f) / 1E6 # filesize
+ LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
+
+
+def guess_task_from_head(head):
+ task = None
+ if head.lower() in ["classify", "classifier", "cls", "fc"]:
+ task = "classify"
+ if head.lower() in ["detect"]:
+ task = "detect"
+ if head.lower() in ["segment"]:
+ task = "segment"
+
+ if not task:
+ raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
+
+ return task
+
+
+def profile(input, ops, n=10, device=None):
+ """ YOLOv5 speed/memory/FLOPs profiler
+ Usage:
+ input = torch.randn(16, 3, 640, 640)
+ m1 = lambda x: x * torch.sigmoid(x)
+ m2 = nn.SiLU()
+ profile(input, [m1, m2], n=100) # profile over 100 iterations
+ """
+ results = []
+ if not isinstance(device, torch.device):
+ device = select_device(device)
+ print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+ f"{'input':>24s}{'output':>24s}")
+
+ for x in input if isinstance(input, list) else [input]:
+ x = x.to(device)
+ x.requires_grad = True
+ for m in ops if isinstance(ops, list) else [ops]:
+ m = m.to(device) if hasattr(m, 'to') else m # device
+ m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+ tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
+ try:
+ flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
+ except Exception:
+ flops = 0
+
+ try:
+ for _ in range(n):
+ t[0] = time_sync()
+ y = m(x)
+ t[1] = time_sync()
+ try:
+ _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
+ t[2] = time_sync()
+ except Exception: # no backward method
+ # print(e) # for debug
+ t[2] = float('nan')
+ tf += (t[1] - t[0]) * 1000 / n # ms per op forward
+ tb += (t[2] - t[1]) * 1000 / n # ms per op backward
+ mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
+ s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
+ p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
+ print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
+ results.append([p, flops, mem, tf, tb, s_in, s_out])
+ except Exception as e:
+ print(e)
+ results.append(None)
+ torch.cuda.empty_cache()
+ return results
diff --git a/ultralytics/yolo/v8/detect/__init__.py b/ultralytics/yolo/v8/detect/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aaa30d4d9df9cc07e51c29697de99fae5086595
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/__init__.py
@@ -0,0 +1,5 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from .predict import DetectionPredictor, predict
+from .train import DetectionTrainer, train
+from .val import DetectionValidator, val
diff --git a/ultralytics/yolo/v8/detect/configs/__init__.py b/ultralytics/yolo/v8/detect/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28093039288327fdcf73435b3b52a0dbf0ff78c3
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/configs/__init__.py
@@ -0,0 +1,36 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from pathlib import Path
+from typing import Dict, Union
+
+from omegaconf import DictConfig, OmegaConf
+
+from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
+
+
+def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = None):
+ """
+ Load and merge configuration data from a file or dictionary.
+
+ Args:
+ config (Union[str, DictConfig]): Configuration data in the form of a file name or a DictConfig object.
+ overrides (Union[str, Dict], optional): Overrides in the form of a file name or a dictionary. Default is None.
+
+ Returns:
+ OmegaConf.Namespace: Training arguments namespace.
+ """
+ if overrides is None:
+ overrides = {}
+ if isinstance(config, (str, Path)):
+ config = OmegaConf.load(config)
+ elif isinstance(config, Dict):
+ config = OmegaConf.create(config)
+ # override
+ if isinstance(overrides, str):
+ overrides = OmegaConf.load(overrides)
+ elif isinstance(overrides, Dict):
+ overrides = OmegaConf.create(overrides)
+
+ check_config_mismatch(dict(overrides).keys(), dict(config).keys())
+
+ return OmegaConf.merge(config, overrides)
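+
+# Usage sketch (the path and override keys are illustrative; both keys exist in default.yaml):
+#   cfg = get_config('configs/default.yaml', overrides={'epochs': 50, 'imgsz': 512})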
diff --git a/ultralytics/yolo/v8/detect/configs/default.yaml b/ultralytics/yolo/v8/detect/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..669ec6ef7454698d458b7c519b6f0c14d47bf117
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/configs/default.yaml
@@ -0,0 +1,110 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+# Default training settings and hyperparameters for medium-augmentation COCO training
+
+task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
+mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
+
+# Train settings -------------------------------------------------------------------------------------------------------
+model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
+data: null # i.e. coco128.yaml. Path to data file
+epochs: 100 # number of epochs to train for
+patience: 50 # TODO: epochs to wait for no observable improvement for early stopping of training
+batch: 16 # number of images per batch
+imgsz: 640 # size of input images
+save: True # save checkpoints
+cache: False # True/ram, disk or False. Use cache for data loading
+device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
+workers: 8 # number of worker threads for data loading
+project: null # project name
+name: null # experiment name
+exist_ok: False # whether to overwrite existing experiment
+pretrained: False # whether to use a pretrained model
+optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+verbose: False # whether to print verbose output
+seed: 0 # random seed for reproducibility
+deterministic: True # whether to enable deterministic mode
+single_cls: False # train multi-class data as single-class
+image_weights: False # use weighted image selection for training
+rect: False # support rectangular training
+cos_lr: False # use cosine learning rate scheduler
+close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
+resume: False # resume training from last checkpoint
+# Segmentation
+overlap_mask: True # masks should overlap during training
+mask_ratio: 4 # mask downsample ratio
+# Classification
+dropout: 0.0 # use dropout regularization
+
+# Val/Test settings ----------------------------------------------------------------------------------------------------
+val: True # validate/test during training
+save_json: False # save results to JSON file
+save_hybrid: False # save hybrid version of labels (labels + additional predictions)
+conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
+iou: 0.7 # intersection over union (IoU) threshold for NMS
+max_det: 300 # maximum number of detections per image
+half: False # use half precision (FP16)
+dnn: False # use OpenCV DNN for ONNX inference
+plots: True # show plots during training
+
+# Prediction settings --------------------------------------------------------------------------------------------------
+source: null # source directory for images or videos
+show: False # show results if possible
+save_txt: False # save results as .txt file
+save_conf: False # save results with confidence scores
+save_crop: False # save cropped images with results
+hide_labels: False # hide labels
+hide_conf: False # hide confidence scores
+vid_stride: 1 # video frame-rate stride
+line_thickness: 3 # bounding box thickness (pixels)
+visualize: False # visualize results
+augment: False # apply data augmentation to images
+agnostic_nms: False # class-agnostic NMS
+retina_masks: False # use retina masks for object detection
+
+# Export settings ------------------------------------------------------------------------------------------------------
+format: torchscript # format to export to
+keras: False # use Keras
+optimize: False # TorchScript: optimize for mobile
+int8: False # CoreML/TF INT8 quantization
+dynamic: False # ONNX/TF/TensorRT: dynamic axes
+simplify: False # ONNX: simplify model
+opset: 17 # ONNX: opset version
+workspace: 4 # TensorRT: workspace size (GB)
+nms: False # CoreML: add NMS
+
+# Hyperparameters ------------------------------------------------------------------------------------------------------
+lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937 # SGD momentum/Adam beta1
+weight_decay: 0.0005 # optimizer weight decay 5e-4
+warmup_epochs: 3.0 # warmup epochs (fractions ok)
+warmup_momentum: 0.8 # warmup initial momentum
+warmup_bias_lr: 0.1 # warmup initial bias lr
+box: 7.5 # box loss gain
+cls: 0.5 # cls loss gain (scale with pixels)
+dfl: 1.5 # dfl loss gain
+fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
+label_smoothing: 0.0 # label smoothing (fraction)
+nbs: 64 # nominal batch size
+hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4 # image HSV-Value augmentation (fraction)
+degrees: 0.0 # image rotation (+/- deg)
+translate: 0.1 # image translation (+/- fraction)
+scale: 0.5 # image scale (+/- gain)
+shear: 0.0 # image shear (+/- deg)
+perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0 # image flip up-down (probability)
+fliplr: 0.5 # image flip left-right (probability)
+mosaic: 1.0 # image mosaic (probability)
+mixup: 0.0 # image mixup (probability)
+copy_paste: 0.0 # segment copy-paste (probability)
+
+# Hydra configs --------------------------------------------------------------------------------------------------------
+hydra:
+ output_subdir: null # disable hydra directory creation
+ run:
+ dir: .
+
+# Debug, do not modify -------------------------------------------------------------------------------------------------
+v5loader: False # use legacy YOLOv5 dataloader
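+
+# Example override from the command line (a sketch; the `yolo` entrypoint and Hydra-style
+# key=value overrides are assumptions based on this config layout):
+#   yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 imgsz=640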
diff --git a/ultralytics/yolo/v8/detect/configs/hydra_patch.py b/ultralytics/yolo/v8/detect/configs/hydra_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d68d72576cd6340a1dc9c64e6759d8feec287c2a
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/configs/hydra_patch.py
@@ -0,0 +1,77 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import sys
+from difflib import get_close_matches
+from textwrap import dedent
+
+import hydra
+from hydra.errors import ConfigCompositionException
+from omegaconf import OmegaConf, open_dict # noqa
+from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa
+
+from ultralytics.yolo.utils import LOGGER, colorstr
+
+
+def override_config(overrides, cfg):
+ override_keys = [override.key_or_group for override in overrides]
+ check_config_mismatch(override_keys, cfg.keys())
+ for override in overrides:
+ if override.package is not None:
+ raise ConfigCompositionException(f"Override {override.input_line} looks like a config group"
+ f" override, but config group '{override.key_or_group}' does not exist.")
+
+ key = override.key_or_group
+ value = override.value()
+ try:
+ if override.is_delete():
+ config_val = OmegaConf.select(cfg, key, throw_on_missing=False)
+ if config_val is None:
+ raise ConfigCompositionException(f"Could not delete from config. '{override.key_or_group}'"
+ " does not exist.")
+ elif value is not None and value != config_val:
+ raise ConfigCompositionException("Could not delete from config. The value of"
+ f" '{override.key_or_group}' is {config_val} and not"
+ f" {value}.")
+
+ last_dot = key.rfind(".")
+ with open_dict(cfg):
+ if last_dot == -1:
+ del cfg[key]
+ else:
+ node = OmegaConf.select(cfg, key[:last_dot])
+ del node[key[last_dot + 1:]]
+
+ elif override.is_add():
+ if OmegaConf.select(cfg, key, throw_on_missing=False) is None or isinstance(value, (dict, list)):
+ OmegaConf.update(cfg, key, value, merge=True, force_add=True)
+ else:
+ assert override.input_line is not None
+ raise ConfigCompositionException(
+ dedent(f"""\
+ Could not append to config. An item is already at '{override.key_or_group}'.
+ Either remove + prefix: '{override.input_line[1:]}'
+ Or add a second + to add or override '{override.key_or_group}': '+{override.input_line}'
+ """))
+ elif override.is_force_add():
+ OmegaConf.update(cfg, key, value, merge=True, force_add=True)
+ else:
+ try:
+ OmegaConf.update(cfg, key, value, merge=True)
+ except (ConfigAttributeError, ConfigKeyError) as ex:
+ raise ConfigCompositionException(f"Could not override '{override.key_or_group}'."
+ f"\nTo append to your config use +{override.input_line}") from ex
+ except OmegaConfBaseException as ex:
+ raise ConfigCompositionException(f"Error merging override {override.input_line}").with_traceback(
+ sys.exc_info()[2]) from ex
+
+
+def check_config_mismatch(overrides, cfg):
+ mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option]
+
+ for option in mismatched:
+ LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}")
+ if mismatched:
+ exit()
+
+
+hydra._internal.config_loader_impl.ConfigLoaderImpl._apply_overrides_to_config = override_config
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/.gitignore b/ultralytics/yolo/v8/detect/deep_sort_pytorch/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..37ed2f4dc4a1ca945a0d807274bfe2f6cc7e2fec
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/.gitignore
@@ -0,0 +1,13 @@
+# Folders
+__pycache__/
+build/
+*.egg-info
+
+
+# Files
+*.weights
+*.t7
+*.mp4
+*.avi
+*.so
+*.txt
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/LICENSE b/ultralytics/yolo/v8/detect/deep_sort_pytorch/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..92a1ed5dc27676f33e306463d532e4969fbc42ae
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Ziqiang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/README.md b/ultralytics/yolo/v8/detect/deep_sort_pytorch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6073f8064faeaa5dfe6ec9642830b5506d02276f
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/README.md
@@ -0,0 +1,137 @@
+# Deep Sort with PyTorch
+
+
+
+## Update(1-1-2020)
+Changes
+- fix bugs
+- refactor code
+- accelerate detection by adding NMS on GPU
+
+## Latest Update(07-22)
+Changes
+- bug fix (Thanks @JieChen91 and @yingsen1 for bug reporting).
+- batched feature extraction for each frame, which leads to a small speed-up.
+- code improvement.
+
+Further improvement directions
+- Train the detector on a task-specific dataset rather than the official one.
+- Retrain the RE-ID model on a pedestrian dataset for better performance.
+- Replace the YOLOv3 detector with a more advanced one.
+
+**Any contributions to this repository are welcome!**
+
+
+## Introduction
+This is an implementation of the MOT tracking algorithm Deep SORT. Deep SORT is basically SORT with an added CNN model that extracts appearance features from the image regions of people bounded by a detector. This CNN model is in fact a RE-ID model; the detector used in the [PAPER](https://arxiv.org/abs/1703.07402) is Faster R-CNN, and the original source code is [HERE](https://github.com/nwojke/deep_sort).
+However, in the original code the CNN model is implemented in TensorFlow, which I'm not familiar with, so I re-implemented the CNN feature extraction model in PyTorch and changed the CNN model a little. I also use **YOLOv3** to generate bboxes instead of Faster R-CNN.
+
+## Dependencies
+- python 3 (python 2 untested)
+- numpy
+- scipy
+- opencv-python
+- sklearn
+- torch >= 0.4
+- torchvision >= 0.1
+- pillow
+- vizer
+- edict
+
+## Quick Start
+0. Check that all dependencies are installed
+```bash
+pip install -r requirements.txt
+```
+For users in China, you can specify a PyPI mirror to speed up the install, for example:
+```bash
+pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+```
+
+1. Clone this repository
+```
+git clone git@github.com:ZQPei/deep_sort_pytorch.git
+```
+
+2. Download YOLOv3 parameters
+```
+cd detector/YOLOv3/weight/
+wget https://pjreddie.com/media/files/yolov3.weights
+wget https://pjreddie.com/media/files/yolov3-tiny.weights
+cd ../../../
+```
+
+3. Download deepsort parameters ckpt.t7
+```
+cd deep_sort/deep/checkpoint
+# download ckpt.t7 from the link below and place it in this folder:
+# https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6
+cd ../../../
+```
+
+4. Compile nms module
+```bash
+cd detector/YOLOv3/nms
+sh build.sh
+cd ../../..
+```
+
+Notice:
+If compiling fails, the simplest fix is to **upgrade your pytorch to >= 1.1 and torchvision to >= 0.3**, which lets you avoid the troublesome compilation problems most likely caused by either a `gcc version too low` or `libraries missing`.
+
+5. Run demo
+```
+usage: python yolov3_deepsort.py VIDEO_PATH
+ [--help]
+ [--frame_interval FRAME_INTERVAL]
+ [--config_detection CONFIG_DETECTION]
+ [--config_deepsort CONFIG_DEEPSORT]
+ [--display]
+ [--display_width DISPLAY_WIDTH]
+ [--display_height DISPLAY_HEIGHT]
+ [--save_path SAVE_PATH]
+ [--cpu]
+
+# yolov3 + deepsort
+python yolov3_deepsort.py [VIDEO_PATH]
+
+# yolov3_tiny + deepsort
+python yolov3_deepsort.py [VIDEO_PATH] --config_detection ./configs/yolov3_tiny.yaml
+
+# yolov3 + deepsort on webcam
+python3 yolov3_deepsort.py /dev/video0 --camera 0
+
+# yolov3_tiny + deepsort on webcam
+python3 yolov3_deepsort.py /dev/video0 --config_detection ./configs/yolov3_tiny.yaml --camera 0
+```
+Use `--display` to enable display.
+Results will be saved to `./output/results.avi` and `./output/results.txt`.
+
+All files above can also be accessed from BaiduDisk!
+link: [BaiduDisk](https://pan.baidu.com/s/1YJ1iPpdFTlUyLFoonYvozg)
+password: fbuw
+
+## Training the RE-ID model
+The original model used in the paper is in original_model.py, and its parameters are here: [original_ckpt.t7](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6).
+
+To train the model, you first need to download the [Market1501](http://www.liangzheng.com.cn/Project/project_reid.html) dataset or the [Mars](http://www.liangzheng.com.cn/Project/project_mars.html) dataset.
+
+Then you can use [train.py](deep_sort/deep/train.py) to train your own parameters and evaluate them using [test.py](deep_sort/deep/test.py) and [evaluate.py](deep_sort/deep/evaluate.py).
+
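+A minimal training invocation might look like the sketch below (the `--data-dir` flag is an assumption here; check `python train.py --help` for the actual arguments):
+```bash
+cd deep_sort/deep
+python train.py --data-dir /path/to/Market-1501
+```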
+
+## Demo videos and images
+[demo.avi](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6)
+[demo2.avi](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6)
+
+
+
+
+
+## References
+- paper: [Simple Online and Realtime Tracking with a Deep Association Metric](https://arxiv.org/abs/1703.07402)
+
+- code: [nwojke/deep_sort](https://github.com/nwojke/deep_sort)
+
+- paper: [YOLOv3](https://pjreddie.com/media/files/papers/YOLOv3.pdf)
+
+- code: [Joseph Redmon/yolov3](https://pjreddie.com/darknet/yolo/)
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/configs/deep_sort.yaml b/ultralytics/yolo/v8/detect/deep_sort_pytorch/configs/deep_sort.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f91a9c1ac5d2b1d4e3030d42f4f9d4c3476738e3
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/configs/deep_sort.yaml
@@ -0,0 +1,10 @@
+DEEPSORT:
+ REID_CKPT: "deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7"
+ MAX_DIST: 0.2
+ MIN_CONFIDENCE: 0.3
+ NMS_MAX_OVERLAP: 0.5
+ MAX_IOU_DISTANCE: 0.7
+ MAX_AGE: 70
+ N_INIT: 3
+ NN_BUDGET: 100
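+ # Parameter notes: MAX_DIST gates the appearance (cosine) distance for matching,
+ # MIN_CONFIDENCE drops low-score detections before tracking, MAX_IOU_DISTANCE gates the
+ # IoU matching stage, MAX_AGE is the number of frames a lost track is kept alive,
+ # N_INIT is the number of consecutive hits needed to confirm a track, and NN_BUDGET
+ # caps the number of appearance features stored per track.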
+
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/README.md b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e89c9b3ea08691210046fbb9184bf8e44e88f29e
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/README.md
@@ -0,0 +1,3 @@
+# Deep Sort
+
+This is the implementation of Deep SORT in PyTorch.
\ No newline at end of file
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__init__.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fe5d0fd796ec4f46dc4141f5e4f9f5092f7d321
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__init__.py
@@ -0,0 +1,21 @@
+from .deep_sort import DeepSort
+
+
+__all__ = ['DeepSort', 'build_tracker']
+
+
+def build_tracker(cfg, use_cuda):
+ return DeepSort(cfg.DEEPSORT.REID_CKPT,
+ max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
+ nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
+ max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda)
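+# Usage sketch (cfg is assumed to be an attribute-style config loaded from
+# configs/deep_sort.yaml, e.g. via easydict or OmegaConf):
+#   tracker = build_tracker(cfg, use_cuda=True)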
+
+
+
+
+
+
+
+
+
+
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65407b119093265960f8b0a0c021aea8477d25af
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93bae5efb0352f1e46b1ee4194acade53dcf255e
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e46516e5c0c5ea2a4ea4327040644f51ee2bd27
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..905f2089f6539c14b9936422bd3fd58e36c00f9c
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a46575c73b42b9053dea6b47679a381cc83709a
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae93de27e3e41531bf41e4d59c582c518f50a503
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/__pycache__/deep_sort.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__init__.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70e634c0f7d6241562350060140dbdc38adbe1b1
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35756354e5d8595db590a360927a77c49c51165e
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..825bafb3b0b7046492895e7f3c257e673a3187e9
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14ae9c7f36f16b8a86fbd145b8984ac493c164f4
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c6dd4efc4cfdc9d096afd80f2a51be82c42a709
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3a5e91359b5302123721da475e46da865368306
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b68ac0eff6267923b6a388449d21a03f5b1dd06
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6f62629453b5eeeb47305b617b2aaeb585e00e0
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d21330765d33b3a83f2cb903b1173da8f8924394
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/__pycache__/model.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/checkpoint/.gitkeep b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/checkpoint/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7 b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7
new file mode 100644
index 0000000000000000000000000000000000000000..cd7ceebe86bdfaea299f31994844488295f00ca8
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df75ddef42c3d1bda67bc94b093e7ce61de7f75a89f36a8f868a428462198316
+size 46034619
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/evaluate.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0458ace6993dcae9f820e076f8c5dcc62d592ca
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/evaluate.py
@@ -0,0 +1,13 @@
+import torch
+
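+# Evaluate re-ID top-1 accuracy from the query/gallery features saved by test.py
+# (assumes "features.pth" exists in the working directory).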
+features = torch.load("features.pth")
+qf = features["qf"]
+ql = features["ql"]
+gf = features["gf"]
+gl = features["gl"]
+
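+# query-gallery similarity matrix: dot products between feature vectors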
+scores = qf.mm(gf.t())
+res = scores.topk(5, dim=1)[1][:, 0]
+top1correct = gl[res].eq(ql).sum().item()
+
+print("Acc top1:{:.3f}".format(top1correct / ql.size(0)))
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/feature_extractor.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a342cf5b6021dcc009ea7e4d35f6f28e298bda65
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/feature_extractor.py
@@ -0,0 +1,54 @@
+import torch
+import torchvision.transforms as transforms
+import numpy as np
+import cv2
+import logging
+
+from .model import Net
+
+
+class Extractor(object):
+ def __init__(self, model_path, use_cuda=True):
+ self.net = Net(reid=True)
+ self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+ state_dict = torch.load(model_path, map_location=torch.device(self.device))[
+ 'net_dict']
+ self.net.load_state_dict(state_dict)
+ logger = logging.getLogger("root.tracker")
+ logger.info("Loading weights from {}... Done!".format(model_path))
+ self.net.to(self.device)
+ self.size = (64, 128)
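+        # crops are resized to 64x128 (w x h) and normalized with ImageNet statistics below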
+ self.norm = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ])
+
+ def _preprocess(self, im_crops):
+        """
+        Preprocess a list of image crops:
+            1. to float with scale from 0 to 1
+            2. resize to (64, 128) as the Market-1501 dataset did
+            3. concatenate to a numpy array
+            4. to torch Tensor
+            5. normalize
+        """
+ def _resize(im, size):
+ return cv2.resize(im.astype(np.float32)/255., size)
+
+ im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(
+ 0) for im in im_crops], dim=0).float()
+ return im_batch
+
+ def __call__(self, im_crops):
+ im_batch = self._preprocess(im_crops)
+ with torch.no_grad():
+ im_batch = im_batch.to(self.device)
+ features = self.net(im_batch)
+ return features.cpu().numpy()
+
+
+if __name__ == '__main__':
+ img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
+ extr = Extractor("checkpoint/ckpt.t7")
+    feature = extr([img])  # __call__ expects a list of crops, so wrap the single image
+ print(feature.shape)
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/model.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b99247489627df09276b52f6d47ef866e0e5bd4a
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/model.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, c_in, c_out, is_downsample=False):
+ super(BasicBlock, self).__init__()
+ self.is_downsample = is_downsample
+ if is_downsample:
+ self.conv1 = nn.Conv2d(
+ c_in, c_out, 3, stride=2, padding=1, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(
+ c_in, c_out, 3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(c_out)
+ self.relu = nn.ReLU(True)
+ self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(c_out)
+ if is_downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+ nn.BatchNorm2d(c_out)
+ )
+ elif c_in != c_out:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+ nn.BatchNorm2d(c_out)
+ )
+ self.is_downsample = True
+
+ def forward(self, x):
+ y = self.conv1(x)
+ y = self.bn1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.bn2(y)
+ if self.is_downsample:
+ x = self.downsample(x)
+ return F.relu(x.add(y), True)
+
+
+def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+ blocks = []
+ for i in range(repeat_times):
+ if i == 0:
+ blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+ else:
+ blocks += [BasicBlock(c_out, c_out), ]
+ return nn.Sequential(*blocks)
+
+
+class Net(nn.Module):
+ def __init__(self, num_classes=751, reid=False):
+ super(Net, self).__init__()
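+        # the "# C H W" comments below track the feature-map shape (channels, height,
+        # width) through the network for a 3x128x64 input crop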
+ # 3 128 64
+ self.conv = nn.Sequential(
+ nn.Conv2d(3, 64, 3, stride=1, padding=1),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ # nn.Conv2d(32,32,3,stride=1,padding=1),
+ # nn.BatchNorm2d(32),
+ # nn.ReLU(inplace=True),
+ nn.MaxPool2d(3, 2, padding=1),
+ )
+        # 64 64 32
+ self.layer1 = make_layers(64, 64, 2, False)
+        # 64 64 32
+ self.layer2 = make_layers(64, 128, 2, True)
+        # 128 32 16
+ self.layer3 = make_layers(128, 256, 2, True)
+        # 256 16 8
+ self.layer4 = make_layers(256, 512, 2, True)
+        # 512 8 4
+ self.avgpool = nn.AvgPool2d((8, 4), 1)
+        # 512 1 1
+ self.reid = reid
+ self.classifier = nn.Sequential(
+ nn.Linear(512, 256),
+ nn.BatchNorm1d(256),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(256, num_classes),
+ )
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+        # B x 512
+ if self.reid:
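+            # L2-normalize the embedding so cosine distance can be used for matching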
+ x = x.div(x.norm(p=2, dim=1, keepdim=True))
+ return x
+ # classifier
+ x = self.classifier(x)
+ return x
+
+
+if __name__ == '__main__':
+ net = Net()
+ x = torch.randn(4, 3, 128, 64)
+ y = net(x)
+ import ipdb
+ ipdb.set_trace()
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/original_model.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/original_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..27734ad52b3b02d815416d998bae145a93dbf519
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/original_model.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, c_in, c_out, is_downsample=False):
+ super(BasicBlock, self).__init__()
+ self.is_downsample = is_downsample
+ if is_downsample:
+ self.conv1 = nn.Conv2d(
+ c_in, c_out, 3, stride=2, padding=1, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(
+ c_in, c_out, 3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(c_out)
+ self.relu = nn.ReLU(True)
+ self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(c_out)
+ if is_downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+ nn.BatchNorm2d(c_out)
+ )
+ elif c_in != c_out:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+ nn.BatchNorm2d(c_out)
+ )
+ self.is_downsample = True
+
+ def forward(self, x):
+ y = self.conv1(x)
+ y = self.bn1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.bn2(y)
+ if self.is_downsample:
+ x = self.downsample(x)
+ return F.relu(x.add(y), True)
+
+
+def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+ blocks = []
+ for i in range(repeat_times):
+ if i == 0:
+ blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+ else:
+ blocks += [BasicBlock(c_out, c_out), ]
+ return nn.Sequential(*blocks)
+
+
+class Net(nn.Module):
+ def __init__(self, num_classes=625, reid=False):
+ super(Net, self).__init__()
+ # 3 128 64
+ self.conv = nn.Sequential(
+ nn.Conv2d(3, 32, 3, stride=1, padding=1),
+ nn.BatchNorm2d(32),
+ nn.ELU(inplace=True),
+ nn.Conv2d(32, 32, 3, stride=1, padding=1),
+ nn.BatchNorm2d(32),
+ nn.ELU(inplace=True),
+ nn.MaxPool2d(3, 2, padding=1),
+ )
+ # 32 64 32
+ self.layer1 = make_layers(32, 32, 2, False)
+ # 32 64 32
+ self.layer2 = make_layers(32, 64, 2, True)
+ # 64 32 16
+ self.layer3 = make_layers(64, 128, 2, True)
+ # 128 16 8
+ self.dense = nn.Sequential(
+ nn.Dropout(p=0.6),
+ nn.Linear(128*16*8, 128),
+ nn.BatchNorm1d(128),
+ nn.ELU(inplace=True)
+ )
+        # 128 (dense output)
+ self.reid = reid
+ self.batch_norm = nn.BatchNorm1d(128)
+ self.classifier = nn.Sequential(
+ nn.Linear(128, num_classes),
+ )
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ x = x.view(x.size(0), -1)
+ if self.reid:
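+            # apply only Dropout and Linear from self.dense, then L2-normalize the 128-d embedding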
+ x = self.dense[0](x)
+ x = self.dense[1](x)
+ x = x.div(x.norm(p=2, dim=1, keepdim=True))
+ return x
+ x = self.dense(x)
+ # B x 128
+ # classifier
+ x = self.classifier(x)
+ return x
+
+
+if __name__ == '__main__':
+ net = Net(reid=True)
+ x = torch.randn(4, 3, 128, 64)
+ y = net(x)
+ import ipdb
+ ipdb.set_trace()
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/test.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba3050cb441e6419112604657797c78b6aa9b74
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/test.py
@@ -0,0 +1,80 @@
+import torch
+import torch.backends.cudnn as cudnn
+import torchvision
+
+import argparse
+import os
+
+from model import Net
+
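+# Extract re-ID features for the query/gallery splits and save them to "features.pth"
+# for evaluate.py to score.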
+parser = argparse.ArgumentParser(description="Train on market1501")
+parser.add_argument("--data-dir", default='data', type=str)
+parser.add_argument("--no-cuda", action="store_true")
+parser.add_argument("--gpu-id", default=0, type=int)
+args = parser.parse_args()
+
+# device
+device = "cuda:{}".format(
+ args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
+if torch.cuda.is_available() and not args.no_cuda:
+ cudnn.benchmark = True
+
+# data loader
+root = args.data_dir
+query_dir = os.path.join(root, "query")
+gallery_dir = os.path.join(root, "gallery")
+transform = torchvision.transforms.Compose([
+ torchvision.transforms.Resize((128, 64)),
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Normalize(
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+queryloader = torch.utils.data.DataLoader(
+ torchvision.datasets.ImageFolder(query_dir, transform=transform),
+ batch_size=64, shuffle=False
+)
+galleryloader = torch.utils.data.DataLoader(
+ torchvision.datasets.ImageFolder(gallery_dir, transform=transform),
+ batch_size=64, shuffle=False
+)
+
+# net definition
+net = Net(reid=True)
+assert os.path.isfile(
+ "./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+print('Loading from checkpoint/ckpt.t7')
+checkpoint = torch.load("./checkpoint/ckpt.t7")
+net_dict = checkpoint['net_dict']
+net.load_state_dict(net_dict, strict=False)
+net.eval()
+net.to(device)
+
+# compute features
+query_features = torch.tensor([]).float()
+query_labels = torch.tensor([]).long()
+gallery_features = torch.tensor([]).float()
+gallery_labels = torch.tensor([]).long()
+
+with torch.no_grad():
+ for idx, (inputs, labels) in enumerate(queryloader):
+ inputs = inputs.to(device)
+ features = net(inputs).cpu()
+ query_features = torch.cat((query_features, features), dim=0)
+ query_labels = torch.cat((query_labels, labels))
+
+ for idx, (inputs, labels) in enumerate(galleryloader):
+ inputs = inputs.to(device)
+ features = net(inputs).cpu()
+ gallery_features = torch.cat((gallery_features, features), dim=0)
+ gallery_labels = torch.cat((gallery_labels, labels))
+
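+# shift gallery labels to line up with query labels (the gallery folder presumably
+# contains two extra junk/distractor classes that ImageFolder indexes first)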
+gallery_labels -= 2
+
+# save features
+features = {
+ "qf": query_features,
+ "ql": query_labels,
+ "gf": gallery_features,
+ "gl": gallery_labels
+}
+torch.save(features, "features.pth")
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/train.jpg b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/train.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3635a614738828b880aa862bc52423848ac8e472
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/train.jpg differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/train.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f475634cea1997212ee37917397134c5c4173b
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep/train.py
@@ -0,0 +1,206 @@
+import argparse
+import os
+import time
+
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.backends.cudnn as cudnn
+import torchvision
+
+from model import Net
+
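+# Train the re-ID embedding network on Market-1501-style 128x64 crops; the best
+# checkpoint is saved to ./checkpoint/ckpt.t7, which the DeepSORT feature extractor loads.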
+parser = argparse.ArgumentParser(description="Train on market1501")
+parser.add_argument("--data-dir", default='data', type=str)
+parser.add_argument("--no-cuda", action="store_true")
+parser.add_argument("--gpu-id", default=0, type=int)
+parser.add_argument("--lr", default=0.1, type=float)
+parser.add_argument("--interval", '-i', default=20, type=int)
+parser.add_argument('--resume', '-r', action='store_true')
+args = parser.parse_args()
+
+# device
+device = "cuda:{}".format(
+ args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
+if torch.cuda.is_available() and not args.no_cuda:
+ cudnn.benchmark = True
+
+# data loading
+root = args.data_dir
+train_dir = os.path.join(root, "train")
+test_dir = os.path.join(root, "test")
+transform_train = torchvision.transforms.Compose([
+ torchvision.transforms.RandomCrop((128, 64), padding=4),
+ torchvision.transforms.RandomHorizontalFlip(),
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Normalize(
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+transform_test = torchvision.transforms.Compose([
+ torchvision.transforms.Resize((128, 64)),
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Normalize(
+ [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+trainloader = torch.utils.data.DataLoader(
+ torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
+ batch_size=64, shuffle=True
+)
+testloader = torch.utils.data.DataLoader(
+ torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
+ batch_size=64, shuffle=True
+)
+num_classes = max(len(trainloader.dataset.classes),
+ len(testloader.dataset.classes))
+
+# net definition
+start_epoch = 0
+net = Net(num_classes=num_classes)
+if args.resume:
+ assert os.path.isfile(
+ "./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+ print('Loading from checkpoint/ckpt.t7')
+ checkpoint = torch.load("./checkpoint/ckpt.t7")
+ # import ipdb; ipdb.set_trace()
+ net_dict = checkpoint['net_dict']
+ net.load_state_dict(net_dict)
+ best_acc = checkpoint['acc']
+ start_epoch = checkpoint['epoch']
+net.to(device)
+
+# loss and optimizer
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(
+ net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
+best_acc = best_acc if args.resume else 0.  # keep the best accuracy restored from the checkpoint when resuming
+
+# train function for each epoch
+
+
+def train(epoch):
+ print("\nEpoch : %d" % (epoch+1))
+ net.train()
+ training_loss = 0.
+ train_loss = 0.
+ correct = 0
+ total = 0
+ interval = args.interval
+ start = time.time()
+ for idx, (inputs, labels) in enumerate(trainloader):
+ # forward
+ inputs, labels = inputs.to(device), labels.to(device)
+ outputs = net(inputs)
+ loss = criterion(outputs, labels)
+
+ # backward
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+        # accumulate running loss and accuracy
+ training_loss += loss.item()
+ train_loss += loss.item()
+ correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+ total += labels.size(0)
+
+ # print
+ if (idx+1) % interval == 0:
+ end = time.time()
+ print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+ 100.*(idx+1)/len(trainloader), end-start, training_loss /
+ interval, correct, total, 100.*correct/total
+ ))
+ training_loss = 0.
+ start = time.time()
+
+ return train_loss/len(trainloader), 1. - correct/total
+
+
+def test(epoch):
+ global best_acc
+ net.eval()
+ test_loss = 0.
+ correct = 0
+ total = 0
+ start = time.time()
+ with torch.no_grad():
+ for idx, (inputs, labels) in enumerate(testloader):
+ inputs, labels = inputs.to(device), labels.to(device)
+ outputs = net(inputs)
+ loss = criterion(outputs, labels)
+
+ test_loss += loss.item()
+ correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+ total += labels.size(0)
+
+ print("Testing ...")
+ end = time.time()
+ print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+ 100.*(idx+1)/len(testloader), end-start, test_loss /
+ len(testloader), correct, total, 100.*correct/total
+ ))
+
+ # saving checkpoint
+ acc = 100.*correct/total
+ if acc > best_acc:
+ best_acc = acc
+ print("Saving parameters to checkpoint/ckpt.t7")
+ checkpoint = {
+ 'net_dict': net.state_dict(),
+ 'acc': acc,
+ 'epoch': epoch,
+ }
+ if not os.path.isdir('checkpoint'):
+ os.mkdir('checkpoint')
+ torch.save(checkpoint, './checkpoint/ckpt.t7')
+
+ return test_loss/len(testloader), 1. - correct/total
+
+
+# plot figure
+x_epoch = []
+record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
+fig = plt.figure()
+ax0 = fig.add_subplot(121, title="loss")
+ax1 = fig.add_subplot(122, title="top1err")
+
+
+def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
+ global record
+ record['train_loss'].append(train_loss)
+ record['train_err'].append(train_err)
+ record['test_loss'].append(test_loss)
+ record['test_err'].append(test_err)
+
+ x_epoch.append(epoch)
+ ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
+ ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
+ ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
+ ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
+ if epoch == 0:
+ ax0.legend()
+ ax1.legend()
+ fig.savefig("train.jpg")
+
+# lr decay
+
+
+def lr_decay():
+ global optimizer
+ for params in optimizer.param_groups:
+ params['lr'] *= 0.1
+ lr = params['lr']
+ print("Learning rate adjusted to {}".format(lr))
+
+
+def main():
+ for epoch in range(start_epoch, start_epoch+40):
+ train_loss, train_err = train(epoch)
+ test_loss, test_err = test(epoch)
+ draw_curve(epoch, train_loss, train_err, test_loss, test_err)
+ if (epoch+1) % 20 == 0:
+ lr_decay()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep_sort.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep_sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..2402e54f1f504d7fed82bc94c15c82b8d79aef99
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/deep_sort.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+
+from .deep.feature_extractor import Extractor
+from .sort.nn_matching import NearestNeighborDistanceMetric
+from .sort.detection import Detection
+from .sort.tracker import Tracker
+
+
+__all__ = ['DeepSort']
+
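+# Minimal usage sketch (bbox_xywh holds (xc, yc, w, h) detector boxes):
+#   tracker = DeepSort(model_path="path/to/ckpt.t7")
+#   outputs = tracker.update(bbox_xywh, confidences, class_ids, frame)
+#   # each output row is [x1, y1, x2, y2, track_id, class_id]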
+
+class DeepSort(object):
+ def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, max_age=70, n_init=3, nn_budget=100, use_cuda=True):
+ self.min_confidence = min_confidence
+ self.nms_max_overlap = nms_max_overlap
+
+ self.extractor = Extractor(model_path, use_cuda=use_cuda)
+
+ max_cosine_distance = max_dist
+ metric = NearestNeighborDistanceMetric(
+ "cosine", max_cosine_distance, nn_budget)
+ self.tracker = Tracker(
+ metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)
+
+ def update(self, bbox_xywh, confidences, oids, ori_img):
+ self.height, self.width = ori_img.shape[:2]
+ # generate detections
+ features = self._get_features(bbox_xywh, ori_img)
+ bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
+ detections = [Detection(bbox_tlwh[i], conf, features[i],oid) for i, (conf,oid) in enumerate(zip(confidences,oids)) if conf > self.min_confidence]
+
+        # candidate boxes and scores (non-maximum suppression itself is not applied here)
+ boxes = np.array([d.tlwh for d in detections])
+ scores = np.array([d.confidence for d in detections])
+
+ # update tracker
+ self.tracker.predict()
+ self.tracker.update(detections)
+
+ # output bbox identities
+ outputs = []
+ for track in self.tracker.tracks:
+ if not track.is_confirmed() or track.time_since_update > 1:
+ continue
+ box = track.to_tlwh()
+ x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
+ track_id = track.track_id
+ track_oid = track.oid
+            outputs.append(np.array([x1, y1, x2, y2, track_id, track_oid], dtype=np.int64))  # np.int is removed in NumPy >= 1.24
+ if len(outputs) > 0:
+ outputs = np.stack(outputs, axis=0)
+ return outputs
+
+    """
+    Convert bbox from xc_yc_w_h to xtl_ytl_w_h.
+    Thanks JieChen91@github.com for reporting this bug!
+    """
+ @staticmethod
+ def _xywh_to_tlwh(bbox_xywh):
+ if isinstance(bbox_xywh, np.ndarray):
+ bbox_tlwh = bbox_xywh.copy()
+ elif isinstance(bbox_xywh, torch.Tensor):
+ bbox_tlwh = bbox_xywh.clone()
+ bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2.
+ bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2.
+ return bbox_tlwh
+
+ def _xywh_to_xyxy(self, bbox_xywh):
+ x, y, w, h = bbox_xywh
+ x1 = max(int(x - w / 2), 0)
+ x2 = min(int(x + w / 2), self.width - 1)
+ y1 = max(int(y - h / 2), 0)
+ y2 = min(int(y + h / 2), self.height - 1)
+ return x1, y1, x2, y2
+
+ def _tlwh_to_xyxy(self, bbox_tlwh):
+        """
+        Convert bbox from xtl_ytl_w_h to x1_y1_x2_y2 (top-left and bottom-right corners).
+        Thanks JieChen91@github.com for reporting this bug!
+        """
+ x, y, w, h = bbox_tlwh
+ x1 = max(int(x), 0)
+ x2 = min(int(x+w), self.width - 1)
+ y1 = max(int(y), 0)
+ y2 = min(int(y+h), self.height - 1)
+ return x1, y1, x2, y2
+
+ def increment_ages(self):
+ self.tracker.increment_ages()
+
+ def _xyxy_to_tlwh(self, bbox_xyxy):
+ x1, y1, x2, y2 = bbox_xyxy
+
+ t = x1
+ l = y1
+ w = int(x2 - x1)
+ h = int(y2 - y1)
+ return t, l, w, h
+
+ def _get_features(self, bbox_xywh, ori_img):
+ im_crops = []
+ for box in bbox_xywh:
+ x1, y1, x2, y2 = self._xywh_to_xyxy(box)
+ im = ori_img[y1:y2, x1:x2]
+ im_crops.append(im)
+ if im_crops:
+ features = self.extractor(im_crops)
+ else:
+ features = np.array([])
+ return features
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__init__.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/__init__.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6204224f223c6cd198f45b616310699e83c551b
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/__init__.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1c6756cac22ddf242087888fdaaec99a2698f03
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/detection.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/detection.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77e768bdff7fb23b384702acd75354d6a413651f
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/detection.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/detection.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/detection.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55dd0a151c34cb560db583a0d5f627502beab456
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/detection.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/iou_matching.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/iou_matching.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1257a1231c1e9c2b78aaf60b6378b9c2acde5e9
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/iou_matching.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/iou_matching.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/iou_matching.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc025dec73f90da4efa385c8fbc010ddc358223a
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/iou_matching.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/kalman_filter.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/kalman_filter.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c748513a8fbb21ed18e42f0ee36b6c29ebbe1b3e
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/kalman_filter.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/kalman_filter.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/kalman_filter.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35c48bcdd4a1129427c65068edb6673688ac41c7
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/kalman_filter.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/linear_assignment.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/linear_assignment.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0538ca0740413bd57e415fde38dea46c79530277
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/linear_assignment.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/linear_assignment.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/linear_assignment.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a29d6746fe1b7eb3b29407b6f4213de73fadd33f
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/linear_assignment.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/nn_matching.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/nn_matching.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f145a8f22534047d9e2d0637dbc28ce0ab6d35b3
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/nn_matching.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/nn_matching.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/nn_matching.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ee19ca9c77684004eb6fdc5fa8638399c636bdd
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/nn_matching.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/track.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/track.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d0c10f8a6769fa3f96704386a871c5056bd3b2e
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/track.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/track.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/track.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6250abc92a33dd95552b7ba5c2c8c4e03d58a03
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/track.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/tracker.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/tracker.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b3ed928b6d1241475dc18cca023292f9f843447
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/tracker.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/tracker.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/tracker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c13b762cebe6cb85b6900a7b52f80603dde40ffd
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/__pycache__/tracker.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/iou_matching.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/iou_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..62d5a3f63b70db5e322b6f8766444dd824c010ae
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/iou_matching.py
@@ -0,0 +1,82 @@
+# vim: expandtab:ts=4:sw=4
+from __future__ import absolute_import
+import numpy as np
+from . import linear_assignment
+
+
+def iou(bbox, candidates):
+    """Compute intersection over union.
+
+ Parameters
+ ----------
+ bbox : ndarray
+ A bounding box in format `(top left x, top left y, width, height)`.
+ candidates : ndarray
+ A matrix of candidate bounding boxes (one per row) in the same format
+ as `bbox`.
+
+ Returns
+ -------
+ ndarray
+ The intersection over union in [0, 1] between the `bbox` and each
+ candidate. A higher score means a larger fraction of the `bbox` is
+ occluded by the candidate.
+
+ """
+ bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
+ candidates_tl = candidates[:, :2]
+ candidates_br = candidates[:, :2] + candidates[:, 2:]
+
+ tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
+ np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
+ br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
+ np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
+ wh = np.maximum(0., br - tl)
+
+ area_intersection = wh.prod(axis=1)
+ area_bbox = bbox[2:].prod()
+ area_candidates = candidates[:, 2:].prod(axis=1)
+ return area_intersection / (area_bbox + area_candidates - area_intersection)
+
+
+def iou_cost(tracks, detections, track_indices=None,
+ detection_indices=None):
+ """An intersection over union distance metric.
+
+ Parameters
+ ----------
+ tracks : List[deep_sort.track.Track]
+ A list of tracks.
+ detections : List[deep_sort.detection.Detection]
+ A list of detections.
+ track_indices : Optional[List[int]]
+ A list of indices to tracks that should be matched. Defaults to
+ all `tracks`.
+ detection_indices : Optional[List[int]]
+ A list of indices to detections that should be matched. Defaults
+ to all `detections`.
+
+ Returns
+ -------
+ ndarray
+ Returns a cost matrix of shape
+ len(track_indices), len(detection_indices) where entry (i, j) is
+ `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
+
+ """
+ if track_indices is None:
+ track_indices = np.arange(len(tracks))
+ if detection_indices is None:
+ detection_indices = np.arange(len(detections))
+
+ cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
+ for row, track_idx in enumerate(track_indices):
+ if tracks[track_idx].time_since_update > 1:
+ cost_matrix[row, :] = linear_assignment.INFTY_COST
+ continue
+
+ bbox = tracks[track_idx].to_tlwh()
+ candidates = np.asarray(
+ [detections[i].tlwh for i in detection_indices])
+ cost_matrix[row, :] = 1. - iou(bbox, candidates)
+ return cost_matrix
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/kalman_filter.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/kalman_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..787a76e6a43870a9538647b51fda6a5254ce2d43
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/kalman_filter.py
@@ -0,0 +1,229 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+import scipy.linalg
+
+
+"""
+Table for the 0.95 quantile of the chi-square distribution with N degrees of
+freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
+function and used as Mahalanobis gating threshold.
+"""
+chi2inv95 = {
+ 1: 3.8415,
+ 2: 5.9915,
+ 3: 7.8147,
+ 4: 9.4877,
+ 5: 11.070,
+ 6: 12.592,
+ 7: 14.067,
+ 8: 15.507,
+ 9: 16.919}
+
+
+class KalmanFilter(object):
+ """
+ A simple Kalman filter for tracking bounding boxes in image space.
+
+ The 8-dimensional state space
+
+ x, y, a, h, vx, vy, va, vh
+
+ contains the bounding box center position (x, y), aspect ratio a, height h,
+ and their respective velocities.
+
+ Object motion follows a constant velocity model. The bounding box location
+ (x, y, a, h) is taken as direct observation of the state space (linear
+ observation model).
+
+ """
+
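+    # Typical per-track cycle (a sketch): initiate() once from the first detection, then
+    # predict() every frame and update() whenever a detection is associated with the track.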
+ def __init__(self):
+ ndim, dt = 4, 1.
+
+ # Create Kalman filter model matrices.
+ self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+ for i in range(ndim):
+ self._motion_mat[i, ndim + i] = dt
+ self._update_mat = np.eye(ndim, 2 * ndim)
+
+ # Motion and observation uncertainty are chosen relative to the current
+ # state estimate. These weights control the amount of uncertainty in
+ # the model. This is a bit hacky.
+ self._std_weight_position = 1. / 20
+ self._std_weight_velocity = 1. / 160
+
+ def initiate(self, measurement):
+ """Create track from unassociated measurement.
+
+ Parameters
+ ----------
+ measurement : ndarray
+ Bounding box coordinates (x, y, a, h) with center position (x, y),
+ aspect ratio a, and height h.
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the mean vector (8 dimensional) and covariance matrix (8x8
+ dimensional) of the new track. Unobserved velocities are initialized
+ to 0 mean.
+
+ """
+ mean_pos = measurement
+ mean_vel = np.zeros_like(mean_pos)
+ mean = np.r_[mean_pos, mean_vel]
+
+ std = [
+ 2 * self._std_weight_position * measurement[3],
+ 2 * self._std_weight_position * measurement[3],
+ 1e-2,
+ 2 * self._std_weight_position * measurement[3],
+ 10 * self._std_weight_velocity * measurement[3],
+ 10 * self._std_weight_velocity * measurement[3],
+ 1e-5,
+ 10 * self._std_weight_velocity * measurement[3]]
+ covariance = np.diag(np.square(std))
+ return mean, covariance
+
+ def predict(self, mean, covariance):
+ """Run Kalman filter prediction step.
+
+ Parameters
+ ----------
+ mean : ndarray
+ The 8 dimensional mean vector of the object state at the previous
+ time step.
+ covariance : ndarray
+ The 8x8 dimensional covariance matrix of the object state at the
+ previous time step.
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the mean vector and covariance matrix of the predicted
+ state. Unobserved velocities are initialized to 0 mean.
+
+ """
+ std_pos = [
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[3],
+ 1e-2,
+ self._std_weight_position * mean[3]]
+ std_vel = [
+ self._std_weight_velocity * mean[3],
+ self._std_weight_velocity * mean[3],
+ 1e-5,
+ self._std_weight_velocity * mean[3]]
+ motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+ mean = np.dot(self._motion_mat, mean)
+ covariance = np.linalg.multi_dot((
+ self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+ return mean, covariance
+
+ def project(self, mean, covariance):
+ """Project state distribution to measurement space.
+
+ Parameters
+ ----------
+ mean : ndarray
+ The state's mean vector (8 dimensional array).
+ covariance : ndarray
+ The state's covariance matrix (8x8 dimensional).
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the projected mean and covariance matrix of the given state
+ estimate.
+
+ """
+ std = [
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[3],
+ 1e-1,
+ self._std_weight_position * mean[3]]
+ innovation_cov = np.diag(np.square(std))
+
+ mean = np.dot(self._update_mat, mean)
+ covariance = np.linalg.multi_dot((
+ self._update_mat, covariance, self._update_mat.T))
+ return mean, covariance + innovation_cov
+
+ def update(self, mean, covariance, measurement):
+ """Run Kalman filter correction step.
+
+ Parameters
+ ----------
+ mean : ndarray
+ The predicted state's mean vector (8 dimensional).
+ covariance : ndarray
+ The state's covariance matrix (8x8 dimensional).
+ measurement : ndarray
+ The 4 dimensional measurement vector (x, y, a, h), where (x, y)
+ is the center position, a the aspect ratio, and h the height of the
+ bounding box.
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the measurement-corrected state distribution.
+
+ """
+ projected_mean, projected_cov = self.project(mean, covariance)
+
+ chol_factor, lower = scipy.linalg.cho_factor(
+ projected_cov, lower=True, check_finite=False)
+ kalman_gain = scipy.linalg.cho_solve(
+ (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
+ check_finite=False).T
+ innovation = measurement - projected_mean
+
+ new_mean = mean + np.dot(innovation, kalman_gain.T)
+ new_covariance = covariance - np.linalg.multi_dot((
+ kalman_gain, projected_cov, kalman_gain.T))
+ return new_mean, new_covariance
+
+ def gating_distance(self, mean, covariance, measurements,
+ only_position=False):
+ """Compute gating distance between state distribution and measurements.
+
+ A suitable distance threshold can be obtained from `chi2inv95`. If
+ `only_position` is False, the chi-square distribution has 4 degrees of
+ freedom, otherwise 2.
+
+ Parameters
+ ----------
+ mean : ndarray
+ Mean vector over the state distribution (8 dimensional).
+ covariance : ndarray
+ Covariance of the state distribution (8x8 dimensional).
+ measurements : ndarray
+ An Nx4 dimensional matrix of N measurements, each in
+ format (x, y, a, h) where (x, y) is the bounding box center
+ position, a the aspect ratio, and h the height.
+ only_position : Optional[bool]
+ If True, distance computation is done with respect to the bounding
+ box center position only.
+
+ Returns
+ -------
+ ndarray
+ Returns an array of length N, where the i-th element contains the
+ squared Mahalanobis distance between (mean, covariance) and
+ `measurements[i]`.
+
+ """
+ mean, covariance = self.project(mean, covariance)
+ if only_position:
+ mean, covariance = mean[:2], covariance[:2, :2]
+ measurements = measurements[:, :2]
+
+ cholesky_factor = np.linalg.cholesky(covariance)
+ d = measurements - mean
+ z = scipy.linalg.solve_triangular(
+ cholesky_factor, d.T, lower=True, check_finite=False,
+ overwrite_b=True)
+ squared_maha = np.sum(z * z, axis=0)
+ return squared_maha
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/linear_assignment.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/linear_assignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..858b71a4ae32ca39f03ff5d0ca0fdcc5963171b0
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/linear_assignment.py
@@ -0,0 +1,192 @@
+# vim: expandtab:ts=4:sw=4
+from __future__ import absolute_import
+import numpy as np
+# from sklearn.utils.linear_assignment_ import linear_assignment
+from scipy.optimize import linear_sum_assignment as linear_assignment
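+# scipy's linear_sum_assignment returns separate (row_indices, col_indices) arrays,
+# which min_cost_matching below unpacks accordingly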
+from . import kalman_filter
+
+
+INFTY_COST = 1e+5
+
+
+def min_cost_matching(
+ distance_metric, max_distance, tracks, detections, track_indices=None,
+ detection_indices=None):
+ """Solve linear assignment problem.
+
+ Parameters
+ ----------
+    distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
+ The distance metric is given a list of tracks and detections as well as
+ a list of N track indices and M detection indices. The metric should
+ return the NxM dimensional cost matrix, where element (i, j) is the
+ association cost between the i-th track in the given track indices and
+ the j-th detection in the given detection_indices.
+ max_distance : float
+ Gating threshold. Associations with cost larger than this value are
+ disregarded.
+ tracks : List[track.Track]
+ A list of predicted tracks at the current time step.
+ detections : List[detection.Detection]
+ A list of detections at the current time step.
+ track_indices : List[int]
+ List of track indices that maps rows in `cost_matrix` to tracks in
+ `tracks` (see description above).
+ detection_indices : List[int]
+ List of detection indices that maps columns in `cost_matrix` to
+ detections in `detections` (see description above).
+
+ Returns
+ -------
+ (List[(int, int)], List[int], List[int])
+ Returns a tuple with the following three entries:
+ * A list of matched track and detection indices.
+ * A list of unmatched track indices.
+ * A list of unmatched detection indices.
+
+ """
+ if track_indices is None:
+ track_indices = np.arange(len(tracks))
+ if detection_indices is None:
+ detection_indices = np.arange(len(detections))
+
+ if len(detection_indices) == 0 or len(track_indices) == 0:
+ return [], track_indices, detection_indices # Nothing to match.
+
+ cost_matrix = distance_metric(
+ tracks, detections, track_indices, detection_indices)
+ cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
+
+ row_indices, col_indices = linear_assignment(cost_matrix)
+
+ matches, unmatched_tracks, unmatched_detections = [], [], []
+ for col, detection_idx in enumerate(detection_indices):
+ if col not in col_indices:
+ unmatched_detections.append(detection_idx)
+ for row, track_idx in enumerate(track_indices):
+ if row not in row_indices:
+ unmatched_tracks.append(track_idx)
+ for row, col in zip(row_indices, col_indices):
+ track_idx = track_indices[row]
+ detection_idx = detection_indices[col]
+ if cost_matrix[row, col] > max_distance:
+ unmatched_tracks.append(track_idx)
+ unmatched_detections.append(detection_idx)
+ else:
+ matches.append((track_idx, detection_idx))
+ return matches, unmatched_tracks, unmatched_detections
+
+
+def matching_cascade(
+ distance_metric, max_distance, cascade_depth, tracks, detections,
+ track_indices=None, detection_indices=None):
+ """Run matching cascade.
+
+ Parameters
+ ----------
+    distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
+ The distance metric is given a list of tracks and detections as well as
+ a list of N track indices and M detection indices. The metric should
+ return the NxM dimensional cost matrix, where element (i, j) is the
+ association cost between the i-th track in the given track indices and
+ the j-th detection in the given detection indices.
+ max_distance : float
+ Gating threshold. Associations with cost larger than this value are
+ disregarded.
+ cascade_depth: int
+        The cascade depth; it should be set to the maximum track age.
+ tracks : List[track.Track]
+ A list of predicted tracks at the current time step.
+ detections : List[detection.Detection]
+ A list of detections at the current time step.
+ track_indices : Optional[List[int]]
+ List of track indices that maps rows in `cost_matrix` to tracks in
+ `tracks` (see description above). Defaults to all tracks.
+ detection_indices : Optional[List[int]]
+ List of detection indices that maps columns in `cost_matrix` to
+ detections in `detections` (see description above). Defaults to all
+ detections.
+
+ Returns
+ -------
+ (List[(int, int)], List[int], List[int])
+ Returns a tuple with the following three entries:
+ * A list of matched track and detection indices.
+ * A list of unmatched track indices.
+ * A list of unmatched detection indices.
+
+ """
+ if track_indices is None:
+ track_indices = list(range(len(tracks)))
+ if detection_indices is None:
+ detection_indices = list(range(len(detections)))
+
+ unmatched_detections = detection_indices
+ matches = []
+ for level in range(cascade_depth):
+ if len(unmatched_detections) == 0: # No detections left
+ break
+
+ track_indices_l = [
+ k for k in track_indices
+ if tracks[k].time_since_update == 1 + level
+ ]
+ if len(track_indices_l) == 0: # Nothing to match at this level
+ continue
+
+ matches_l, _, unmatched_detections = \
+ min_cost_matching(
+ distance_metric, max_distance, tracks, detections,
+ track_indices_l, unmatched_detections)
+ matches += matches_l
+ unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
+ return matches, unmatched_tracks, unmatched_detections
+
+
+def gate_cost_matrix(
+ kf, cost_matrix, tracks, detections, track_indices, detection_indices,
+ gated_cost=INFTY_COST, only_position=False):
+ """Invalidate infeasible entries in cost matrix based on the state
+ distributions obtained by Kalman filtering.
+
+ Parameters
+ ----------
+ kf : The Kalman filter.
+ cost_matrix : ndarray
+ The NxM dimensional cost matrix, where N is the number of track indices
+ and M is the number of detection indices, such that entry (i, j) is the
+ association cost between `tracks[track_indices[i]]` and
+ `detections[detection_indices[j]]`.
+ tracks : List[track.Track]
+ A list of predicted tracks at the current time step.
+ detections : List[detection.Detection]
+ A list of detections at the current time step.
+ track_indices : List[int]
+ List of track indices that maps rows in `cost_matrix` to tracks in
+ `tracks` (see description above).
+ detection_indices : List[int]
+ List of detection indices that maps columns in `cost_matrix` to
+ detections in `detections` (see description above).
+ gated_cost : Optional[float]
+        Entries in the cost matrix corresponding to infeasible associations are
+        set to this value. Defaults to a very large value.
+ only_position : Optional[bool]
+ If True, only the x, y position of the state distribution is considered
+ during gating. Defaults to False.
+
+ Returns
+ -------
+ ndarray
+ Returns the modified cost matrix.
+
+ """
+ gating_dim = 2 if only_position else 4
+ gating_threshold = kalman_filter.chi2inv95[gating_dim]
+ measurements = np.asarray(
+ [detections[i].to_xyah() for i in detection_indices])
+ for row, track_idx in enumerate(track_indices):
+ track = tracks[track_idx]
+ gating_distance = kf.gating_distance(
+ track.mean, track.covariance, measurements, only_position)
+ cost_matrix[row, gating_distance > gating_threshold] = gated_cost
+ return cost_matrix
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/nn_matching.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/nn_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e5b4f478fead21d38227ce2eac34556bd1179e
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/nn_matching.py
@@ -0,0 +1,176 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+
+
+def _pdist(a, b):
+ """Compute pair-wise squared distance between points in `a` and `b`.
+
+ Parameters
+ ----------
+ a : array_like
+ An NxM matrix of N samples of dimensionality M.
+ b : array_like
+ An LxM matrix of L samples of dimensionality M.
+
+ Returns
+ -------
+ ndarray
+        Returns a matrix of size len(a), len(b) such that element (i, j)
+ contains the squared distance between `a[i]` and `b[j]`.
+
+ """
+ a, b = np.asarray(a), np.asarray(b)
+ if len(a) == 0 or len(b) == 0:
+ return np.zeros((len(a), len(b)))
+ a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
+ r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
+ r2 = np.clip(r2, 0., float(np.inf))
+ return r2
+
+
+def _cosine_distance(a, b, data_is_normalized=False):
+ """Compute pair-wise cosine distance between points in `a` and `b`.
+
+ Parameters
+ ----------
+ a : array_like
+ An NxM matrix of N samples of dimensionality M.
+ b : array_like
+ An LxM matrix of L samples of dimensionality M.
+ data_is_normalized : Optional[bool]
+ If True, assumes rows in a and b are unit length vectors.
+        Otherwise, a and b are explicitly normalized to length 1.
+
+ Returns
+ -------
+ ndarray
+        Returns a matrix of size len(a), len(b) such that element (i, j)
+        contains the cosine distance between `a[i]` and `b[j]`.
+
+ """
+ if not data_is_normalized:
+ a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
+ b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
+ return 1. - np.dot(a, b.T)
+
+
+def _nn_euclidean_distance(x, y):
+ """ Helper function for nearest neighbor distance metric (Euclidean).
+
+ Parameters
+ ----------
+ x : ndarray
+ A matrix of N row-vectors (sample points).
+ y : ndarray
+ A matrix of M row-vectors (query points).
+
+ Returns
+ -------
+ ndarray
+ A vector of length M that contains for each entry in `y` the
+ smallest Euclidean distance to a sample in `x`.
+
+ """
+ distances = _pdist(x, y)
+ return np.maximum(0.0, distances.min(axis=0))
+
+
+def _nn_cosine_distance(x, y):
+ """ Helper function for nearest neighbor distance metric (cosine).
+
+ Parameters
+ ----------
+ x : ndarray
+ A matrix of N row-vectors (sample points).
+ y : ndarray
+ A matrix of M row-vectors (query points).
+
+ Returns
+ -------
+ ndarray
+ A vector of length M that contains for each entry in `y` the
+ smallest cosine distance to a sample in `x`.
+
+ """
+ distances = _cosine_distance(x, y)
+ return distances.min(axis=0)
+
+
+class NearestNeighborDistanceMetric(object):
+ """
+ A nearest neighbor distance metric that, for each target, returns
+ the closest distance to any sample that has been observed so far.
+
+ Parameters
+ ----------
+ metric : str
+ Either "euclidean" or "cosine".
+ matching_threshold: float
+ The matching threshold. Samples with larger distance are considered an
+ invalid match.
+ budget : Optional[int]
+ If not None, fix samples per class to at most this number. Removes
+ the oldest samples when the budget is reached.
+
+ Attributes
+ ----------
+ samples : Dict[int -> List[ndarray]]
+ A dictionary that maps from target identities to the list of samples
+ that have been observed so far.
+
+ """
+
+ def __init__(self, metric, matching_threshold, budget=None):
+
+ if metric == "euclidean":
+ self._metric = _nn_euclidean_distance
+ elif metric == "cosine":
+ self._metric = _nn_cosine_distance
+ else:
+ raise ValueError(
+ "Invalid metric; must be either 'euclidean' or 'cosine'")
+ self.matching_threshold = matching_threshold
+ self.budget = budget
+ self.samples = {}
+
+ def partial_fit(self, features, targets, active_targets):
+ """Update the distance metric with new data.
+
+ Parameters
+ ----------
+ features : ndarray
+ An NxM matrix of N features of dimensionality M.
+ targets : ndarray
+ An integer array of associated target identities.
+ active_targets : List[int]
+ A list of targets that are currently present in the scene.
+
+ """
+ for feature, target in zip(features, targets):
+ self.samples.setdefault(target, []).append(feature)
+ if self.budget is not None:
+ self.samples[target] = self.samples[target][-self.budget:]
+ self.samples = {k: self.samples[k] for k in active_targets}
+
+ def distance(self, features, targets):
+ """Compute distance between features and targets.
+
+ Parameters
+ ----------
+ features : ndarray
+ An NxM matrix of N features of dimensionality M.
+ targets : List[int]
+ A list of targets to match the given `features` against.
+
+ Returns
+ -------
+ ndarray
+            Returns a cost matrix of shape len(targets), len(features), where
+            element (i, j) contains the smallest distance, under the chosen
+            metric, between `targets[i]` and `features[j]`.
+
+ """
+ cost_matrix = np.zeros((len(targets), len(features)))
+ for i, target in enumerate(targets):
+ cost_matrix[i, :] = self._metric(self.samples[target], features)
+ return cost_matrix
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/preprocessing.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..5493b127f602dec398efac4269c00d31a3650ce9
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort - Copy/preprocessing.py
@@ -0,0 +1,73 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+import cv2
+
+
+def non_max_suppression(boxes, max_bbox_overlap, scores=None):
+ """Suppress overlapping detections.
+
+ Original code from [1]_ has been adapted to include confidence score.
+
+ .. [1] http://www.pyimagesearch.com/2015/02/16/
+ faster-non-maximum-suppression-python/
+
+ Examples
+ --------
+
+ >>> boxes = [d.roi for d in detections]
+ >>> scores = [d.confidence for d in detections]
+ >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
+ >>> detections = [detections[i] for i in indices]
+
+ Parameters
+ ----------
+ boxes : ndarray
+ Array of ROIs (x, y, width, height).
+ max_bbox_overlap : float
+        ROIs that overlap more than this value are suppressed.
+ scores : Optional[array_like]
+ Detector confidence score.
+
+ Returns
+ -------
+ List[int]
+ Returns indices of detections that have survived non-maxima suppression.
+
+ """
+ if len(boxes) == 0:
+ return []
+
+    boxes = boxes.astype(float)  # np.float was removed in NumPy 1.24
+ pick = []
+
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2] + boxes[:, 0]
+ y2 = boxes[:, 3] + boxes[:, 1]
+
+ area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ if scores is not None:
+ idxs = np.argsort(scores)
+ else:
+ idxs = np.argsort(y2)
+
+ while len(idxs) > 0:
+ last = len(idxs) - 1
+ i = idxs[last]
+ pick.append(i)
+
+ xx1 = np.maximum(x1[i], x1[idxs[:last]])
+ yy1 = np.maximum(y1[i], y1[idxs[:last]])
+ xx2 = np.minimum(x2[i], x2[idxs[:last]])
+ yy2 = np.minimum(y2[i], y2[idxs[:last]])
+
+ w = np.maximum(0, xx2 - xx1 + 1)
+ h = np.maximum(0, yy2 - yy1 + 1)
+
+ overlap = (w * h) / area[idxs[:last]]
+
+ idxs = np.delete(
+ idxs, np.concatenate(
+ ([last], np.where(overlap > max_bbox_overlap)[0])))
+
+ return pick
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__init__.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b50ec4d5d63870341a9e2c66ab9d48a725d2fe65
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6204224f223c6cd198f45b616310699e83c551b
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a51b8a02b979fa9d0d1b165fcee2564530c8cda
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a85c4505412bd4db68a9fc92c523f1351d49b287
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77e768bdff7fb23b384702acd75354d6a413651f
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caa8a4041c25ea8fbc9766271b281601f96dfcca
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/detection.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9a2cd52ff09f47a7a3e48d7c0c8058377eb07e2
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1257a1231c1e9c2b78aaf60b6378b9c2acde5e9
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d588b7ae77f90488911c8de26bd82a3812a0812f
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d34d7a79586d654abaa2ff9cb5b3ef33b02181ee
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c748513a8fbb21ed18e42f0ee36b6c29ebbe1b3e
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96f39540e1d01772c129c31a61288d2f24af5f95
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..179e11c8f916e5c05165511901ad91d8a09b4829
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0538ca0740413bd57e415fde38dea46c79530277
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bedca06436aba947a087d3f5205cfbf82fd34372
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c26d3d04ff760decdf4631683ebda2ccf884f780
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f145a8f22534047d9e2d0637dbc28ce0ab6d35b3
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1396459181738ddb8cc09cb8fa88b6185d2daa0
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b50926dcaf7dd9ae18f35b8fb97cb21ae54e16d8
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d0c10f8a6769fa3f96704386a871c5056bd3b2e
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..816fd551c69d9bfa6fcabd06714c2dffe8f96086
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/track.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee22dfddbc954e4c59f9b45bdda8195ae2c5f2fb
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b3ed928b6d1241475dc18cca023292f9f843447
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03d2504b1a9aed6d6266d03f701ab1e6a14a60cf
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/__pycache__/tracker.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/detection.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..de29a7ac5d2d66aae328f2a5a0cd7ef822cc0403
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/detection.py
@@ -0,0 +1,50 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+
+
+class Detection(object):
+ """
+ This class represents a bounding box detection in a single image.
+
+ Parameters
+ ----------
+ tlwh : array_like
+ Bounding box in format `(x, y, w, h)`.
+ confidence : float
+ Detector confidence score.
+ feature : array_like
+ A feature vector that describes the object contained in this image.
+
+ Attributes
+ ----------
+ tlwh : ndarray
+ Bounding box in format `(top left x, top left y, width, height)`.
+ confidence : float
+ Detector confidence score.
+ feature : ndarray | NoneType
+ A feature vector that describes the object contained in this image.
+
+ """
+
+ def __init__(self, tlwh, confidence, feature, oid):
+ self.tlwh = np.asarray(tlwh, dtype=float)
+ self.confidence = float(confidence)
+ self.feature = np.asarray(feature, dtype=np.float32)
+ self.oid = oid
+
+ def to_tlbr(self):
+ """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
+ `(top left, bottom right)`.
+ """
+ ret = self.tlwh.copy()
+ ret[2:] += ret[:2]
+ return ret
+
+ def to_xyah(self):
+ """Convert bounding box to format `(center x, center y, aspect ratio,
+ height)`, where the aspect ratio is `width / height`.
+ """
+ ret = self.tlwh.copy()
+ ret[:2] += ret[2:] / 2
+ ret[2] /= ret[3]
+ return ret
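A minimal sketch of how the `Detection` container above is used. The import path, the box values, and the 128-D zero feature vector are assumptions made for illustration; `oid=2` follows this fork's convention of passing the YOLO class id (2 = car) through to the tracker.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort.detection import Detection  # hypothetical import path

# A 40x80 box whose top-left corner sits at (100, 50); the zero vector stands in
# for a real 128-D appearance embedding, and oid=2 is the YOLO "car" class id.
det = Detection(tlwh=[100., 50., 40., 80.], confidence=0.87,
                feature=np.zeros(128, dtype=np.float32), oid=2)

print(det.to_tlbr())  # [100.  50. 140. 130.]  -> (min x, min y, max x, max y)
print(det.to_xyah())  # [120.  90.   0.5  80.] -> (center x, center y, w/h, h)
```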
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/iou_matching.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/iou_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..62d5a3f63b70db5e322b6f8766444dd824c010ae
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/iou_matching.py
@@ -0,0 +1,82 @@
+# vim: expandtab:ts=4:sw=4
+from __future__ import absolute_import
+import numpy as np
+from . import linear_assignment
+
+
+def iou(bbox, candidates):
+ """Computer intersection over union.
+
+ Parameters
+ ----------
+ bbox : ndarray
+ A bounding box in format `(top left x, top left y, width, height)`.
+ candidates : ndarray
+ A matrix of candidate bounding boxes (one per row) in the same format
+ as `bbox`.
+
+ Returns
+ -------
+ ndarray
+ The intersection over union in [0, 1] between the `bbox` and each
+ candidate. A higher score means a larger fraction of the `bbox` is
+ occluded by the candidate.
+
+ """
+ bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
+ candidates_tl = candidates[:, :2]
+ candidates_br = candidates[:, :2] + candidates[:, 2:]
+
+ tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
+ np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
+ br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
+ np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
+ wh = np.maximum(0., br - tl)
+
+ area_intersection = wh.prod(axis=1)
+ area_bbox = bbox[2:].prod()
+ area_candidates = candidates[:, 2:].prod(axis=1)
+ return area_intersection / (area_bbox + area_candidates - area_intersection)
+
+
+def iou_cost(tracks, detections, track_indices=None,
+ detection_indices=None):
+ """An intersection over union distance metric.
+
+ Parameters
+ ----------
+ tracks : List[deep_sort.track.Track]
+ A list of tracks.
+ detections : List[deep_sort.detection.Detection]
+ A list of detections.
+ track_indices : Optional[List[int]]
+ A list of indices to tracks that should be matched. Defaults to
+ all `tracks`.
+ detection_indices : Optional[List[int]]
+ A list of indices to detections that should be matched. Defaults
+ to all `detections`.
+
+ Returns
+ -------
+ ndarray
+ Returns a cost matrix of shape
+ len(track_indices), len(detection_indices) where entry (i, j) is
+ `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
+
+ """
+ if track_indices is None:
+ track_indices = np.arange(len(tracks))
+ if detection_indices is None:
+ detection_indices = np.arange(len(detections))
+
+ cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
+ for row, track_idx in enumerate(track_indices):
+ if tracks[track_idx].time_since_update > 1:
+ cost_matrix[row, :] = linear_assignment.INFTY_COST
+ continue
+
+ bbox = tracks[track_idx].to_tlwh()
+ candidates = np.asarray(
+ [detections[i].tlwh for i in detection_indices])
+ cost_matrix[row, :] = 1. - iou(bbox, candidates)
+ return cost_matrix
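A quick check of the `iou()` helper above. The boxes are made up and the import path is an assumption, but the expected values follow directly from the `(x, y, w, h)` convention described in the docstring.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort.iou_matching import iou  # hypothetical import path

bbox = np.array([10., 10., 50., 80.])            # (top-left x, top-left y, w, h)
candidates = np.array([[10., 10., 50., 80.],     # identical box      -> IoU 1.0
                       [35., 10., 50., 80.],     # half-width overlap -> IoU 1/3
                       [200., 200., 20., 20.]])  # disjoint box       -> IoU 0.0
print(iou(bbox, candidates))  # approximately [1.0, 0.333, 0.0]
```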
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/kalman_filter.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/kalman_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..787a76e6a43870a9538647b51fda6a5254ce2d43
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/kalman_filter.py
@@ -0,0 +1,229 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+import scipy.linalg
+
+
+"""
+Table for the 0.95 quantile of the chi-square distribution with N degrees of
+freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
+function and used as Mahalanobis gating threshold.
+"""
+chi2inv95 = {
+ 1: 3.8415,
+ 2: 5.9915,
+ 3: 7.8147,
+ 4: 9.4877,
+ 5: 11.070,
+ 6: 12.592,
+ 7: 14.067,
+ 8: 15.507,
+ 9: 16.919}
+
+
+class KalmanFilter(object):
+ """
+ A simple Kalman filter for tracking bounding boxes in image space.
+
+ The 8-dimensional state space
+
+ x, y, a, h, vx, vy, va, vh
+
+ contains the bounding box center position (x, y), aspect ratio a, height h,
+ and their respective velocities.
+
+ Object motion follows a constant velocity model. The bounding box location
+ (x, y, a, h) is taken as direct observation of the state space (linear
+ observation model).
+
+ """
+
+ def __init__(self):
+ ndim, dt = 4, 1.
+
+ # Create Kalman filter model matrices.
+ self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+ for i in range(ndim):
+ self._motion_mat[i, ndim + i] = dt
+ self._update_mat = np.eye(ndim, 2 * ndim)
+
+ # Motion and observation uncertainty are chosen relative to the current
+ # state estimate. These weights control the amount of uncertainty in
+ # the model. This is a bit hacky.
+ self._std_weight_position = 1. / 20
+ self._std_weight_velocity = 1. / 160
+
+ def initiate(self, measurement):
+ """Create track from unassociated measurement.
+
+ Parameters
+ ----------
+ measurement : ndarray
+ Bounding box coordinates (x, y, a, h) with center position (x, y),
+ aspect ratio a, and height h.
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the mean vector (8 dimensional) and covariance matrix (8x8
+ dimensional) of the new track. Unobserved velocities are initialized
+ to 0 mean.
+
+ """
+ mean_pos = measurement
+ mean_vel = np.zeros_like(mean_pos)
+ mean = np.r_[mean_pos, mean_vel]
+
+ std = [
+ 2 * self._std_weight_position * measurement[3],
+ 2 * self._std_weight_position * measurement[3],
+ 1e-2,
+ 2 * self._std_weight_position * measurement[3],
+ 10 * self._std_weight_velocity * measurement[3],
+ 10 * self._std_weight_velocity * measurement[3],
+ 1e-5,
+ 10 * self._std_weight_velocity * measurement[3]]
+ covariance = np.diag(np.square(std))
+ return mean, covariance
+
+ def predict(self, mean, covariance):
+ """Run Kalman filter prediction step.
+
+ Parameters
+ ----------
+ mean : ndarray
+ The 8 dimensional mean vector of the object state at the previous
+ time step.
+ covariance : ndarray
+ The 8x8 dimensional covariance matrix of the object state at the
+ previous time step.
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the mean vector and covariance matrix of the predicted
+ state. Unobserved velocities are initialized to 0 mean.
+
+ """
+ std_pos = [
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[3],
+ 1e-2,
+ self._std_weight_position * mean[3]]
+ std_vel = [
+ self._std_weight_velocity * mean[3],
+ self._std_weight_velocity * mean[3],
+ 1e-5,
+ self._std_weight_velocity * mean[3]]
+ motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+ mean = np.dot(self._motion_mat, mean)
+ covariance = np.linalg.multi_dot((
+ self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+ return mean, covariance
+
+ def project(self, mean, covariance):
+ """Project state distribution to measurement space.
+
+ Parameters
+ ----------
+ mean : ndarray
+ The state's mean vector (8 dimensional array).
+ covariance : ndarray
+ The state's covariance matrix (8x8 dimensional).
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the projected mean and covariance matrix of the given state
+ estimate.
+
+ """
+ std = [
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[3],
+ 1e-1,
+ self._std_weight_position * mean[3]]
+ innovation_cov = np.diag(np.square(std))
+
+ mean = np.dot(self._update_mat, mean)
+ covariance = np.linalg.multi_dot((
+ self._update_mat, covariance, self._update_mat.T))
+ return mean, covariance + innovation_cov
+
+ def update(self, mean, covariance, measurement):
+ """Run Kalman filter correction step.
+
+ Parameters
+ ----------
+ mean : ndarray
+ The predicted state's mean vector (8 dimensional).
+ covariance : ndarray
+ The state's covariance matrix (8x8 dimensional).
+ measurement : ndarray
+ The 4 dimensional measurement vector (x, y, a, h), where (x, y)
+ is the center position, a the aspect ratio, and h the height of the
+ bounding box.
+
+ Returns
+ -------
+ (ndarray, ndarray)
+ Returns the measurement-corrected state distribution.
+
+ """
+ projected_mean, projected_cov = self.project(mean, covariance)
+
+ chol_factor, lower = scipy.linalg.cho_factor(
+ projected_cov, lower=True, check_finite=False)
+ kalman_gain = scipy.linalg.cho_solve(
+ (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
+ check_finite=False).T
+ innovation = measurement - projected_mean
+
+ new_mean = mean + np.dot(innovation, kalman_gain.T)
+ new_covariance = covariance - np.linalg.multi_dot((
+ kalman_gain, projected_cov, kalman_gain.T))
+ return new_mean, new_covariance
+
+ def gating_distance(self, mean, covariance, measurements,
+ only_position=False):
+ """Compute gating distance between state distribution and measurements.
+
+ A suitable distance threshold can be obtained from `chi2inv95`. If
+ `only_position` is False, the chi-square distribution has 4 degrees of
+ freedom, otherwise 2.
+
+ Parameters
+ ----------
+ mean : ndarray
+ Mean vector over the state distribution (8 dimensional).
+ covariance : ndarray
+ Covariance of the state distribution (8x8 dimensional).
+ measurements : ndarray
+ An Nx4 dimensional matrix of N measurements, each in
+ format (x, y, a, h) where (x, y) is the bounding box center
+ position, a the aspect ratio, and h the height.
+ only_position : Optional[bool]
+ If True, distance computation is done with respect to the bounding
+ box center position only.
+
+ Returns
+ -------
+ ndarray
+ Returns an array of length N, where the i-th element contains the
+ squared Mahalanobis distance between (mean, covariance) and
+ `measurements[i]`.
+
+ """
+ mean, covariance = self.project(mean, covariance)
+ if only_position:
+ mean, covariance = mean[:2], covariance[:2, :2]
+ measurements = measurements[:, :2]
+
+ cholesky_factor = np.linalg.cholesky(covariance)
+ d = measurements - mean
+ z = scipy.linalg.solve_triangular(
+ cholesky_factor, d.T, lower=True, check_finite=False,
+ overwrite_b=True)
+ squared_maha = np.sum(z * z, axis=0)
+ return squared_maha
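A minimal sketch of one initiate/predict/update cycle with the filter above, followed by chi-square gating. The measurements are invented and the import path is an assumption.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort.kalman_filter import KalmanFilter, chi2inv95  # hypothetical path

kf = KalmanFilter()

# (x, y, a, h): box centred at (320, 240) with aspect ratio 0.5 and height 120 px.
mean, cov = kf.initiate(np.array([320., 240., 0.5, 120.]))

# One constant-velocity prediction, then a correction with a slightly shifted observation.
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([324., 241., 0.5, 121.]))

# Gate candidate measurements at the 0.95 chi-square quantile for 4 degrees of freedom.
candidates = np.array([[323., 242., 0.5, 120.],
                       [600., 400., 0.5, 120.]])
print(kf.gating_distance(mean, cov, candidates) <= chi2inv95[4])  # [ True False]
```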
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/linear_assignment.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/linear_assignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..858b71a4ae32ca39f03ff5d0ca0fdcc5963171b0
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/linear_assignment.py
@@ -0,0 +1,192 @@
+# vim: expandtab:ts=4:sw=4
+from __future__ import absolute_import
+import numpy as np
+# from sklearn.utils.linear_assignment_ import linear_assignment
+from scipy.optimize import linear_sum_assignment as linear_assignment
+from . import kalman_filter
+
+
+INFTY_COST = 1e+5
+
+
+def min_cost_matching(
+ distance_metric, max_distance, tracks, detections, track_indices=None,
+ detection_indices=None):
+ """Solve linear assignment problem.
+
+ Parameters
+ ----------
+ distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
+ The distance metric is given a list of tracks and detections as well as
+ a list of N track indices and M detection indices. The metric should
+ return the NxM dimensional cost matrix, where element (i, j) is the
+ association cost between the i-th track in the given track indices and
+ the j-th detection in the given detection_indices.
+ max_distance : float
+ Gating threshold. Associations with cost larger than this value are
+ disregarded.
+ tracks : List[track.Track]
+ A list of predicted tracks at the current time step.
+ detections : List[detection.Detection]
+ A list of detections at the current time step.
+ track_indices : List[int]
+ List of track indices that maps rows in `cost_matrix` to tracks in
+ `tracks` (see description above).
+ detection_indices : List[int]
+ List of detection indices that maps columns in `cost_matrix` to
+ detections in `detections` (see description above).
+
+ Returns
+ -------
+ (List[(int, int)], List[int], List[int])
+ Returns a tuple with the following three entries:
+ * A list of matched track and detection indices.
+ * A list of unmatched track indices.
+ * A list of unmatched detection indices.
+
+ """
+ if track_indices is None:
+ track_indices = np.arange(len(tracks))
+ if detection_indices is None:
+ detection_indices = np.arange(len(detections))
+
+ if len(detection_indices) == 0 or len(track_indices) == 0:
+ return [], track_indices, detection_indices # Nothing to match.
+
+ cost_matrix = distance_metric(
+ tracks, detections, track_indices, detection_indices)
+ cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
+
+ row_indices, col_indices = linear_assignment(cost_matrix)
+
+ matches, unmatched_tracks, unmatched_detections = [], [], []
+ for col, detection_idx in enumerate(detection_indices):
+ if col not in col_indices:
+ unmatched_detections.append(detection_idx)
+ for row, track_idx in enumerate(track_indices):
+ if row not in row_indices:
+ unmatched_tracks.append(track_idx)
+ for row, col in zip(row_indices, col_indices):
+ track_idx = track_indices[row]
+ detection_idx = detection_indices[col]
+ if cost_matrix[row, col] > max_distance:
+ unmatched_tracks.append(track_idx)
+ unmatched_detections.append(detection_idx)
+ else:
+ matches.append((track_idx, detection_idx))
+ return matches, unmatched_tracks, unmatched_detections
+
+
+def matching_cascade(
+ distance_metric, max_distance, cascade_depth, tracks, detections,
+ track_indices=None, detection_indices=None):
+ """Run matching cascade.
+
+ Parameters
+ ----------
+ distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
+ The distance metric is given a list of tracks and detections as well as
+ a list of N track indices and M detection indices. The metric should
+ return the NxM dimensional cost matrix, where element (i, j) is the
+ association cost between the i-th track in the given track indices and
+ the j-th detection in the given detection indices.
+ max_distance : float
+ Gating threshold. Associations with cost larger than this value are
+ disregarded.
+ cascade_depth: int
+ The cascade depth; this should be set to the maximum track age.
+ tracks : List[track.Track]
+ A list of predicted tracks at the current time step.
+ detections : List[detection.Detection]
+ A list of detections at the current time step.
+ track_indices : Optional[List[int]]
+ List of track indices that maps rows in `cost_matrix` to tracks in
+ `tracks` (see description above). Defaults to all tracks.
+ detection_indices : Optional[List[int]]
+ List of detection indices that maps columns in `cost_matrix` to
+ detections in `detections` (see description above). Defaults to all
+ detections.
+
+ Returns
+ -------
+ (List[(int, int)], List[int], List[int])
+ Returns a tuple with the following three entries:
+ * A list of matched track and detection indices.
+ * A list of unmatched track indices.
+ * A list of unmatched detection indices.
+
+ """
+ if track_indices is None:
+ track_indices = list(range(len(tracks)))
+ if detection_indices is None:
+ detection_indices = list(range(len(detections)))
+
+ unmatched_detections = detection_indices
+ matches = []
+ for level in range(cascade_depth):
+ if len(unmatched_detections) == 0: # No detections left
+ break
+
+ track_indices_l = [
+ k for k in track_indices
+ if tracks[k].time_since_update == 1 + level
+ ]
+ if len(track_indices_l) == 0: # Nothing to match at this level
+ continue
+
+ matches_l, _, unmatched_detections = \
+ min_cost_matching(
+ distance_metric, max_distance, tracks, detections,
+ track_indices_l, unmatched_detections)
+ matches += matches_l
+ unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
+ return matches, unmatched_tracks, unmatched_detections
+
+
+def gate_cost_matrix(
+ kf, cost_matrix, tracks, detections, track_indices, detection_indices,
+ gated_cost=INFTY_COST, only_position=False):
+ """Invalidate infeasible entries in cost matrix based on the state
+ distributions obtained by Kalman filtering.
+
+ Parameters
+ ----------
+ kf : The Kalman filter.
+ cost_matrix : ndarray
+ The NxM dimensional cost matrix, where N is the number of track indices
+ and M is the number of detection indices, such that entry (i, j) is the
+ association cost between `tracks[track_indices[i]]` and
+ `detections[detection_indices[j]]`.
+ tracks : List[track.Track]
+ A list of predicted tracks at the current time step.
+ detections : List[detection.Detection]
+ A list of detections at the current time step.
+ track_indices : List[int]
+ List of track indices that maps rows in `cost_matrix` to tracks in
+ `tracks` (see description above).
+ detection_indices : List[int]
+ List of detection indices that maps columns in `cost_matrix` to
+ detections in `detections` (see description above).
+ gated_cost : Optional[float]
+ Entries in the cost matrix corresponding to infeasible associations are
+ set to this value. Defaults to a very large value.
+ only_position : Optional[bool]
+ If True, only the x, y position of the state distribution is considered
+ during gating. Defaults to False.
+
+ Returns
+ -------
+ ndarray
+ Returns the modified cost matrix.
+
+ """
+ gating_dim = 2 if only_position else 4
+ gating_threshold = kalman_filter.chi2inv95[gating_dim]
+ measurements = np.asarray(
+ [detections[i].to_xyah() for i in detection_indices])
+ for row, track_idx in enumerate(track_indices):
+ track = tracks[track_idx]
+ gating_distance = kf.gating_distance(
+ track.mean, track.covariance, measurements, only_position)
+ cost_matrix[row, gating_distance > gating_threshold] = gated_cost
+ return cost_matrix
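A small sketch of `min_cost_matching()` driven by the IoU cost from `iou_matching.py`. The stub classes, box values, and import paths are assumptions chosen to keep the example self-contained: `iou_cost` only reads `to_tlwh()` and `time_since_update` from a track and `.tlwh` from a detection-like object.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort import iou_matching, linear_assignment  # hypothetical paths

class StubTrack:
    """Bare-bones stand-in exposing only what iou_cost() reads."""
    def __init__(self, tlwh):
        self._tlwh = np.asarray(tlwh, dtype=float)
        self.time_since_update = 1  # recently updated, so the IoU cost is actually computed
    def to_tlwh(self):
        return self._tlwh

class StubDetection:
    """Detection-like object exposing only the .tlwh attribute."""
    def __init__(self, tlwh):
        self.tlwh = np.asarray(tlwh, dtype=float)

tracks = [StubTrack([0, 0, 50, 80]), StubTrack([200, 200, 40, 60])]
detections = [StubDetection([5, 5, 50, 80]), StubDetection([195, 205, 40, 60])]

matches, unmatched_tracks, unmatched_dets = linear_assignment.min_cost_matching(
    iou_matching.iou_cost, 0.7, tracks, detections)
print(matches)  # track 0 <-> detection 0, track 1 <-> detection 1
```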
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/nn_matching.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/nn_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e5b4f478fead21d38227ce2eac34556bd1179e
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/nn_matching.py
@@ -0,0 +1,176 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+
+
+def _pdist(a, b):
+ """Compute pair-wise squared distance between points in `a` and `b`.
+
+ Parameters
+ ----------
+ a : array_like
+ An NxM matrix of N samples of dimensionality M.
+ b : array_like
+ An LxM matrix of L samples of dimensionality M.
+
+ Returns
+ -------
+ ndarray
+ Returns a matrix of size len(a), len(b) such that element (i, j)
+ contains the squared distance between `a[i]` and `b[j]`.
+
+ """
+ a, b = np.asarray(a), np.asarray(b)
+ if len(a) == 0 or len(b) == 0:
+ return np.zeros((len(a), len(b)))
+ a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
+ r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
+ r2 = np.clip(r2, 0., float(np.inf))
+ return r2
+
+
+def _cosine_distance(a, b, data_is_normalized=False):
+ """Compute pair-wise cosine distance between points in `a` and `b`.
+
+ Parameters
+ ----------
+ a : array_like
+ An NxM matrix of N samples of dimensionality M.
+ b : array_like
+ An LxM matrix of L samples of dimensionality M.
+ data_is_normalized : Optional[bool]
+ If True, assumes rows in a and b are unit length vectors.
+ Otherwise, a and b are explicitly normalized to length 1.
+
+ Returns
+ -------
+ ndarray
+ Returns a matrix of size len(a), len(b) such that element (i, j)
+ contains the cosine distance between `a[i]` and `b[j]`.
+
+ """
+ if not data_is_normalized:
+ a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
+ b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
+ return 1. - np.dot(a, b.T)
+
+
+def _nn_euclidean_distance(x, y):
+ """ Helper function for nearest neighbor distance metric (Euclidean).
+
+ Parameters
+ ----------
+ x : ndarray
+ A matrix of N row-vectors (sample points).
+ y : ndarray
+ A matrix of M row-vectors (query points).
+
+ Returns
+ -------
+ ndarray
+ A vector of length M that contains for each entry in `y` the
+ smallest squared Euclidean distance to a sample in `x`.
+
+ """
+ distances = _pdist(x, y)
+ return np.maximum(0.0, distances.min(axis=0))
+
+
+def _nn_cosine_distance(x, y):
+ """ Helper function for nearest neighbor distance metric (cosine).
+
+ Parameters
+ ----------
+ x : ndarray
+ A matrix of N row-vectors (sample points).
+ y : ndarray
+ A matrix of M row-vectors (query points).
+
+ Returns
+ -------
+ ndarray
+ A vector of length M that contains for each entry in `y` the
+ smallest cosine distance to a sample in `x`.
+
+ """
+ distances = _cosine_distance(x, y)
+ return distances.min(axis=0)
+
+
+class NearestNeighborDistanceMetric(object):
+ """
+ A nearest neighbor distance metric that, for each target, returns
+ the closest distance to any sample that has been observed so far.
+
+ Parameters
+ ----------
+ metric : str
+ Either "euclidean" or "cosine".
+ matching_threshold: float
+ The matching threshold. Samples with larger distance are considered an
+ invalid match.
+ budget : Optional[int]
+ If not None, fix samples per class to at most this number. Removes
+ the oldest samples when the budget is reached.
+
+ Attributes
+ ----------
+ samples : Dict[int -> List[ndarray]]
+ A dictionary that maps from target identities to the list of samples
+ that have been observed so far.
+
+ """
+
+ def __init__(self, metric, matching_threshold, budget=None):
+
+ if metric == "euclidean":
+ self._metric = _nn_euclidean_distance
+ elif metric == "cosine":
+ self._metric = _nn_cosine_distance
+ else:
+ raise ValueError(
+ "Invalid metric; must be either 'euclidean' or 'cosine'")
+ self.matching_threshold = matching_threshold
+ self.budget = budget
+ self.samples = {}
+
+ def partial_fit(self, features, targets, active_targets):
+ """Update the distance metric with new data.
+
+ Parameters
+ ----------
+ features : ndarray
+ An NxM matrix of N features of dimensionality M.
+ targets : ndarray
+ An integer array of associated target identities.
+ active_targets : List[int]
+ A list of targets that are currently present in the scene.
+
+ """
+ for feature, target in zip(features, targets):
+ self.samples.setdefault(target, []).append(feature)
+ if self.budget is not None:
+ self.samples[target] = self.samples[target][-self.budget:]
+ self.samples = {k: self.samples[k] for k in active_targets}
+
+ def distance(self, features, targets):
+ """Compute distance between features and targets.
+
+ Parameters
+ ----------
+ features : ndarray
+ An NxM matrix of N features of dimensionality M.
+ targets : List[int]
+ A list of targets to match the given `features` against.
+
+ Returns
+ -------
+ ndarray
+ Returns a cost matrix of shape len(targets), len(features), where
+ element (i, j) contains the closest squared distance between
+ `targets[i]` and `features[j]`.
+
+ """
+ cost_matrix = np.zeros((len(targets), len(features)))
+ for i, target in enumerate(targets):
+ cost_matrix[i, :] = self._metric(self.samples[target], features)
+ return cost_matrix
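A minimal sketch of the gallery/query workflow for `NearestNeighborDistanceMetric`. The 4-D features and the import path are assumptions; real Deep SORT features are 128-D re-ID embeddings.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort.nn_matching import NearestNeighborDistanceMetric  # hypothetical path

metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)

# Two confirmed identities, each with one stored appearance feature.
gallery = np.array([[1.0, 0.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0, 0.0]])
metric.partial_fit(gallery, targets=np.array([1, 2]), active_targets=[1, 2])

# Query features from the current frame; row i of the cost matrix belongs to targets[i].
queries = np.array([[0.9, 0.1, 0.0, 0.0],
                    [0.1, 0.9, 0.0, 0.0]])
print(metric.distance(queries, targets=[1, 2]))  # small costs on the diagonal
```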
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/preprocessing.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..5493b127f602dec398efac4269c00d31a3650ce9
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/preprocessing.py
@@ -0,0 +1,73 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+import cv2
+
+
+def non_max_suppression(boxes, max_bbox_overlap, scores=None):
+ """Suppress overlapping detections.
+
+ Original code from [1]_ has been adapted to include confidence score.
+
+ .. [1] http://www.pyimagesearch.com/2015/02/16/
+ faster-non-maximum-suppression-python/
+
+ Examples
+ --------
+
+ >>> boxes = [d.roi for d in detections]
+ >>> scores = [d.confidence for d in detections]
+ >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
+ >>> detections = [detections[i] for i in indices]
+
+ Parameters
+ ----------
+ boxes : ndarray
+ Array of ROIs (x, y, width, height).
+ max_bbox_overlap : float
+ ROIs that overlap more than this value are suppressed.
+ scores : Optional[array_like]
+ Detector confidence score.
+
+ Returns
+ -------
+ List[int]
+ Returns indices of detections that have survived non-maxima suppression.
+
+ """
+ if len(boxes) == 0:
+ return []
+
+ boxes = boxes.astype(float)
+ pick = []
+
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2] + boxes[:, 0]
+ y2 = boxes[:, 3] + boxes[:, 1]
+
+ area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ if scores is not None:
+ idxs = np.argsort(scores)
+ else:
+ idxs = np.argsort(y2)
+
+ while len(idxs) > 0:
+ last = len(idxs) - 1
+ i = idxs[last]
+ pick.append(i)
+
+ xx1 = np.maximum(x1[i], x1[idxs[:last]])
+ yy1 = np.maximum(y1[i], y1[idxs[:last]])
+ xx2 = np.minimum(x2[i], x2[idxs[:last]])
+ yy2 = np.minimum(y2[i], y2[idxs[:last]])
+
+ w = np.maximum(0, xx2 - xx1 + 1)
+ h = np.maximum(0, yy2 - yy1 + 1)
+
+ overlap = (w * h) / area[idxs[:last]]
+
+ idxs = np.delete(
+ idxs, np.concatenate(
+ ([last], np.where(overlap > max_bbox_overlap)[0])))
+
+ return pick
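A quick sketch of the `non_max_suppression()` routine above on three made-up boxes; the import path is an assumption.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort.preprocessing import non_max_suppression  # hypothetical path

# Three (x, y, w, h) ROIs: the first two overlap heavily, the third is far away.
boxes = np.array([[10., 10., 50., 80.],
                  [12., 12., 50., 80.],
                  [200., 200., 40., 60.]])
scores = np.array([0.9, 0.6, 0.8])

keep = non_max_suppression(boxes, max_bbox_overlap=0.5, scores=scores)
print(keep)  # surviving detection indices: 0 and 2 (the lower-scoring overlap is dropped)
```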
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/track.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/track.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b9814a5afa44c22ac54a7c4f8084025d4443fab
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/track.py
@@ -0,0 +1,170 @@
+# vim: expandtab:ts=4:sw=4
+
+
+class TrackState:
+ """
+ Enumeration type for the single target track state. Newly created tracks are
+ classified as `tentative` until enough evidence has been collected. Then,
+ the track state is changed to `confirmed`. Tracks that are no longer alive
+ are classified as `deleted` to mark them for removal from the set of active
+ tracks.
+
+ """
+
+ Tentative = 1
+ Confirmed = 2
+ Deleted = 3
+
+
+class Track:
+ """
+ A single target track with state space `(x, y, a, h)` and associated
+ velocities, where `(x, y)` is the center of the bounding box, `a` is the
+ aspect ratio and `h` is the height.
+
+ Parameters
+ ----------
+ mean : ndarray
+ Mean vector of the initial state distribution.
+ covariance : ndarray
+ Covariance matrix of the initial state distribution.
+ track_id : int
+ A unique track identifier.
+ n_init : int
+ Number of consecutive detections before the track is confirmed. The
+ track state is set to `Deleted` if a miss occurs within the first
+ `n_init` frames.
+ max_age : int
+ The maximum number of consecutive misses before the track state is
+ set to `Deleted`.
+ feature : Optional[ndarray]
+ Feature vector of the detection this track originates from. If not None,
+ this feature is added to the `features` cache.
+
+ Attributes
+ ----------
+ mean : ndarray
+ Mean vector of the initial state distribution.
+ covariance : ndarray
+ Covariance matrix of the initial state distribution.
+ track_id : int
+ A unique track identifier.
+ hits : int
+ Total number of measurement updates.
+ age : int
+ Total number of frames since first occurrence.
+ time_since_update : int
+ Total number of frames since last measurement update.
+ state : TrackState
+ The current track state.
+ features : List[ndarray]
+ A cache of features. On each measurement update, the associated feature
+ vector is added to this list.
+
+ """
+
+ def __init__(self, mean, covariance, track_id, n_init, max_age, oid,
+ feature=None):
+ self.mean = mean
+ self.covariance = covariance
+ self.track_id = track_id
+ self.hits = 1
+ self.age = 1
+ self.time_since_update = 0
+ self.oid = oid
+
+ self.state = TrackState.Tentative
+ self.features = []
+ if feature is not None:
+ self.features.append(feature)
+
+ self._n_init = n_init
+ self._max_age = max_age
+
+ def to_tlwh(self):
+ """Get current position in bounding box format `(top left x, top left y,
+ width, height)`.
+
+ Returns
+ -------
+ ndarray
+ The bounding box.
+
+ """
+ ret = self.mean[:4].copy()
+ ret[2] *= ret[3]
+ ret[:2] -= ret[2:] / 2
+ return ret
+
+ def to_tlbr(self):
+ """Get current position in bounding box format `(min x, miny, max x,
+ max y)`.
+
+ Returns
+ -------
+ ndarray
+ The bounding box.
+
+ """
+ ret = self.to_tlwh()
+ ret[2:] = ret[:2] + ret[2:]
+ return ret
+
+ def increment_age(self):
+ self.age += 1
+ self.time_since_update += 1
+
+ def predict(self, kf):
+ """Propagate the state distribution to the current time step using a
+ Kalman filter prediction step.
+
+ Parameters
+ ----------
+ kf : kalman_filter.KalmanFilter
+ The Kalman filter.
+
+ """
+ self.mean, self.covariance = kf.predict(self.mean, self.covariance)
+ self.increment_age()
+
+ def update(self, kf, detection):
+ """Perform Kalman filter measurement update step and update the feature
+ cache.
+
+ Parameters
+ ----------
+ kf : kalman_filter.KalmanFilter
+ The Kalman filter.
+ detection : Detection
+ The associated detection.
+
+ """
+ self.mean, self.covariance = kf.update(
+ self.mean, self.covariance, detection.to_xyah())
+ self.features.append(detection.feature)
+
+ self.hits += 1
+ self.time_since_update = 0
+ if self.state == TrackState.Tentative and self.hits >= self._n_init:
+ self.state = TrackState.Confirmed
+
+ def mark_missed(self):
+ """Mark this track as missed (no association at the current time step).
+ """
+ if self.state == TrackState.Tentative:
+ self.state = TrackState.Deleted
+ elif self.time_since_update > self._max_age:
+ self.state = TrackState.Deleted
+
+ def is_tentative(self):
+ """Returns True if this track is tentative (unconfirmed).
+ """
+ return self.state == TrackState.Tentative
+
+ def is_confirmed(self):
+ """Returns True if this track is confirmed."""
+ return self.state == TrackState.Confirmed
+
+ def is_deleted(self):
+ """Returns True if this track is dead and should be deleted."""
+ return self.state == TrackState.Deleted
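A minimal sketch of a `Track` life cycle driven by the Kalman filter from this package. The numbers and import paths are assumptions.

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort.kalman_filter import KalmanFilter  # hypothetical path
from deep_sort_pytorch.deep_sort.sort.track import Track

kf = KalmanFilter()
mean, cov = kf.initiate(np.array([320., 240., 0.5, 120.]))  # (x, y, a, h)

track = Track(mean, cov, track_id=1, n_init=3, max_age=70, oid=2,
              feature=np.zeros(128, dtype=np.float32))
print(track.is_tentative())  # True until n_init consecutive hits confirm it
print(track.to_tlwh())       # back to (top-left x, top-left y, w, h)

track.predict(kf)            # one motion-only step (increments time_since_update)
track.mark_missed()          # a miss while still tentative deletes the track
print(track.is_deleted())    # True
```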
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/tracker.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..89a57afe950f95033f093f3055d293edb6775a38
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/deep_sort/sort/tracker.py
@@ -0,0 +1,144 @@
+# vim: expandtab:ts=4:sw=4
+from __future__ import absolute_import
+import numpy as np
+from . import kalman_filter
+from . import linear_assignment
+from . import iou_matching
+from .track import Track
+
+
+class Tracker:
+ """
+ This is the multi-target tracker.
+
+ Parameters
+ ----------
+ metric : nn_matching.NearestNeighborDistanceMetric
+ A distance metric for measurement-to-track association.
+ max_age : int
+ Maximum number of consecutive misses before a track is deleted.
+ n_init : int
+ Number of consecutive detections before the track is confirmed. The
+ track state is set to `Deleted` if a miss occurs within the first
+ `n_init` frames.
+
+ Attributes
+ ----------
+ metric : nn_matching.NearestNeighborDistanceMetric
+ The distance metric used for measurement to track association.
+ max_age : int
+ Maximum number of consecutive misses before a track is deleted.
+ n_init : int
+ Number of frames that a track remains in initialization phase.
+ kf : kalman_filter.KalmanFilter
+ A Kalman filter to filter target trajectories in image space.
+ tracks : List[Track]
+ The list of active tracks at the current time step.
+
+ """
+
+ def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
+ self.metric = metric
+ self.max_iou_distance = max_iou_distance
+ self.max_age = max_age
+ self.n_init = n_init
+
+ self.kf = kalman_filter.KalmanFilter()
+ self.tracks = []
+ self._next_id = 1
+
+ def predict(self):
+ """Propagate track state distributions one time step forward.
+
+ This function should be called once every time step, before `update`.
+ """
+ for track in self.tracks:
+ track.predict(self.kf)
+
+ def increment_ages(self):
+ for track in self.tracks:
+ track.increment_age()
+ track.mark_missed()
+
+ def update(self, detections, obj_id):
+ """Perform measurement update and track management.
+
+ Parameters
+ ----------
+ detections : List[deep_sort.detection.Detection]
+ A list of detections at the current time step.
+
+ """
+ # Run matching cascade.
+ matches, unmatched_tracks, unmatched_detections = \
+ self._match(detections)
+
+ # Update track set.
+ for track_idx, detection_idx in matches:
+ self.tracks[track_idx].update(
+ self.kf, detections[detection_idx])
+ for track_idx in unmatched_tracks:
+ self.tracks[track_idx].mark_missed()
+ for detection_idx in unmatched_detections:
+ if obj_id in [2, 3, 5, 7]:
+ self._initiate_track(detections[detection_idx])
+ self.tracks = [t for t in self.tracks if not t.is_deleted()]
+
+ # Update distance metric.
+ active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
+ features, targets = [], []
+ for track in self.tracks:
+ if not track.is_confirmed():
+ continue
+ features += track.features
+ targets += [track.track_id for _ in track.features]
+ track.features = []
+ self.metric.partial_fit(
+ np.asarray(features), np.asarray(targets), active_targets)
+
+ def _match(self, detections):
+
+ def gated_metric(tracks, dets, track_indices, detection_indices):
+ features = np.array([dets[i].feature for i in detection_indices])
+ targets = np.array([tracks[i].track_id for i in track_indices])
+ cost_matrix = self.metric.distance(features, targets)
+ cost_matrix = linear_assignment.gate_cost_matrix(
+ self.kf, cost_matrix, tracks, dets, track_indices,
+ detection_indices)
+
+ return cost_matrix
+
+ # Split track set into confirmed and unconfirmed tracks.
+ confirmed_tracks = [
+ i for i, t in enumerate(self.tracks) if t.is_confirmed()]
+ unconfirmed_tracks = [
+ i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
+
+ # Associate confirmed tracks using appearance features.
+ matches_a, unmatched_tracks_a, unmatched_detections = \
+ linear_assignment.matching_cascade(
+ gated_metric, self.metric.matching_threshold, self.max_age,
+ self.tracks, detections, confirmed_tracks)
+
+ # Associate remaining tracks together with unconfirmed tracks using IOU.
+ iou_track_candidates = unconfirmed_tracks + [
+ k for k in unmatched_tracks_a if
+ self.tracks[k].time_since_update == 1]
+ unmatched_tracks_a = [
+ k for k in unmatched_tracks_a if
+ self.tracks[k].time_since_update != 1]
+ matches_b, unmatched_tracks_b, unmatched_detections = \
+ linear_assignment.min_cost_matching(
+ iou_matching.iou_cost, self.max_iou_distance, self.tracks,
+ detections, iou_track_candidates, unmatched_detections)
+
+ matches = matches_a + matches_b
+ unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
+ return matches, unmatched_tracks, unmatched_detections
+
+ def _initiate_track(self, detection):
+ mean, covariance = self.kf.initiate(detection.to_xyah())
+ self.tracks.append(Track(
+ mean, covariance, self._next_id, self.n_init, self.max_age, detection.oid,
+ detection.feature))
+ self._next_id += 1
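A minimal end-to-end sketch of the `Tracker` above over a few synthetic frames. The import paths and the constant 128-D feature are assumptions; note that this fork's `update()` takes an extra `obj_id` and only spawns new tracks for YOLO class ids 2, 3, 5 and 7 (car, motorcycle, bus, truck).

```python
import numpy as np
from deep_sort_pytorch.deep_sort.sort import nn_matching  # hypothetical paths
from deep_sort_pytorch.deep_sort.sort.detection import Detection
from deep_sort_pytorch.deep_sort.sort.tracker import Tracker

metric = nn_matching.NearestNeighborDistanceMetric("cosine", 0.2, budget=100)
tracker = Tracker(metric, max_iou_distance=0.7, max_age=70, n_init=3)

feature = np.random.rand(128).astype(np.float32)  # stand-in for a re-ID embedding
for frame in range(3):
    # One slowly moving box per frame, reported as YOLO class 2 ("car").
    detections = [Detection([100. + 2 * frame, 50., 40., 80.], 0.9, feature, oid=2)]
    tracker.predict()
    tracker.update(detections, obj_id=2)

for t in tracker.tracks:
    print(t.track_id, t.to_tlbr(), t.is_confirmed())  # confirmed after n_init=3 hits
```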
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__init__.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07edf46307ee2874484960eb17ce63550316e280
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d6bbd89ccc52618477fd6e95fc7c9f3fc5a9a70
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d15acc453295b27c3c0dac0f666716c208c0ec7
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-310.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..381162a0615bda1038a50ef8397e98645d7642a0
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-310.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-37.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1625add17b72de0386e69b7ba0f97037f3298a5b
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-37.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-38.pyc b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee1ae32f392675e799fa05bb6867d01b0d8d3ef0
Binary files /dev/null and b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/__pycache__/parser.cpython-38.pyc differ
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/asserts.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/asserts.py
new file mode 100644
index 0000000000000000000000000000000000000000..59a73cc04025762d6490fcd2945a747d963def32
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/asserts.py
@@ -0,0 +1,13 @@
+from os import environ
+
+
+def assert_in(file, files_to_check):
+ if file not in files_to_check:
+ raise AssertionError("{} does not exist in the list".format(str(file)))
+ return True
+
+
+def assert_in_env(check_list: list):
+ for item in check_list:
+ assert_in(item, environ.keys())
+ return True
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/draw.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/draw.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc7cb537978e86805d5d9789785a8afe67df9030
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/draw.py
@@ -0,0 +1,36 @@
+import numpy as np
+import cv2
+
+palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
+
+
+def compute_color_for_labels(label):
+ """
+ Simple function that adds fixed color depending on the class
+ """
+ color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
+ return tuple(color)
+
+
+def draw_boxes(img, bbox, identities=None, offset=(0,0)):
+ for i,box in enumerate(bbox):
+ x1,y1,x2,y2 = [int(i) for i in box]
+ x1 += offset[0]
+ x2 += offset[0]
+ y1 += offset[1]
+ y2 += offset[1]
+ # box text and bar
+ id = int(identities[i]) if identities is not None else 0
+ color = compute_color_for_labels(id)
+ label = '{}{:d}'.format("", id)
+ t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2 , 2)[0]
+ cv2.rectangle(img,(x1, y1),(x2,y2),color,3)
+ cv2.rectangle(img,(x1, y1),(x1+t_size[0]+3,y1+t_size[1]+4), color,-1)
+ cv2.putText(img,label,(x1,y1+t_size[1]+4), cv2.FONT_HERSHEY_PLAIN, 2, [255,255,255], 2)
+ return img
+
+
+
+if __name__ == '__main__':
+ for i in range(82):
+ print(compute_color_for_labels(i))
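A quick sketch of `draw_boxes()` on a blank canvas; the import path, box coordinates, and output file name are illustrative assumptions.

```python
import numpy as np
import cv2
from deep_sort_pytorch.utils.draw import draw_boxes  # hypothetical import path

frame = np.zeros((480, 640, 3), dtype=np.uint8)                   # blank BGR canvas
bbox_xyxy = np.array([[50, 60, 200, 300], [300, 100, 420, 280]])  # (x1, y1, x2, y2) boxes
identities = np.array([1, 2])                                     # DeepSORT track ids

frame = draw_boxes(frame, bbox_xyxy, identities)
cv2.imwrite("annotated.jpg", frame)                               # hypothetical output file
```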
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/evaluation.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..100179407181933d59809b25400d115cfa789867
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/evaluation.py
@@ -0,0 +1,103 @@
+import os
+import numpy as np
+import copy
+import motmetrics as mm
+mm.lap.default_solver = 'lap'
+from utils.io import read_results, unzip_objs
+
+
+class Evaluator(object):
+
+ def __init__(self, data_root, seq_name, data_type):
+ self.data_root = data_root
+ self.seq_name = seq_name
+ self.data_type = data_type
+
+ self.load_annotations()
+ self.reset_accumulator()
+
+ def load_annotations(self):
+ assert self.data_type == 'mot'
+
+ gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
+ self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
+ self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
+
+ def reset_accumulator(self):
+ self.acc = mm.MOTAccumulator(auto_id=True)
+
+ def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
+ # results
+ trk_tlwhs = np.copy(trk_tlwhs)
+ trk_ids = np.copy(trk_ids)
+
+ # gts
+ gt_objs = self.gt_frame_dict.get(frame_id, [])
+ gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
+
+ # ignore boxes
+ ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
+ ignore_tlwhs = unzip_objs(ignore_objs)[0]
+
+
+ # remove ignored results
+ keep = np.ones(len(trk_tlwhs), dtype=bool)
+ iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
+ if len(iou_distance) > 0:
+ match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
+ match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
+ match_ious = iou_distance[match_is, match_js]
+
+ match_js = np.asarray(match_js, dtype=int)
+ match_js = match_js[np.logical_not(np.isnan(match_ious))]
+ keep[match_js] = False
+ trk_tlwhs = trk_tlwhs[keep]
+ trk_ids = trk_ids[keep]
+
+ # get distance matrix
+ iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
+
+ # acc
+ self.acc.update(gt_ids, trk_ids, iou_distance)
+
+ if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
+ events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
+ else:
+ events = None
+ return events
+
+ def eval_file(self, filename):
+ self.reset_accumulator()
+
+ result_frame_dict = read_results(filename, self.data_type, is_gt=False)
+ frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
+ for frame_id in frames:
+ trk_objs = result_frame_dict.get(frame_id, [])
+ trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
+ self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
+
+ return self.acc
+
+ @staticmethod
+ def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
+ names = copy.deepcopy(names)
+ if metrics is None:
+ metrics = mm.metrics.motchallenge_metrics
+ metrics = copy.deepcopy(metrics)
+
+ mh = mm.metrics.create()
+ summary = mh.compute_many(
+ accs,
+ metrics=metrics,
+ names=names,
+ generate_overall=True
+ )
+
+ return summary
+
+ @staticmethod
+ def save_summary(summary, filename):
+ import pandas as pd
+ writer = pd.ExcelWriter(filename)
+ summary.to_excel(writer)
+ writer.save()
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/io.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dc9afd24019cd930eef6c21ab9f579313dd3b3a
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/io.py
@@ -0,0 +1,133 @@
+import os
+from typing import Dict
+import numpy as np
+
+# from utils.log import get_logger
+
+
+def write_results(filename, results, data_type):
+ if data_type == 'mot':
+ save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
+ elif data_type == 'kitti':
+ save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
+ else:
+ raise ValueError(data_type)
+
+ with open(filename, 'w') as f:
+ for frame_id, tlwhs, track_ids in results:
+ if data_type == 'kitti':
+ frame_id -= 1
+ for tlwh, track_id in zip(tlwhs, track_ids):
+ if track_id < 0:
+ continue
+ x1, y1, w, h = tlwh
+ x2, y2 = x1 + w, y1 + h
+ line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
+ f.write(line)
+
+
+# def write_results(filename, results_dict: Dict, data_type: str):
+# if not filename:
+# return
+# path = os.path.dirname(filename)
+# if not os.path.exists(path):
+# os.makedirs(path)
+
+# if data_type in ('mot', 'mcmot', 'lab'):
+# save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
+# elif data_type == 'kitti':
+# save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
+# else:
+# raise ValueError(data_type)
+
+# with open(filename, 'w') as f:
+# for frame_id, frame_data in results_dict.items():
+# if data_type == 'kitti':
+# frame_id -= 1
+# for tlwh, track_id in frame_data:
+# if track_id < 0:
+# continue
+# x1, y1, w, h = tlwh
+# x2, y2 = x1 + w, y1 + h
+# line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
+# f.write(line)
+# logger.info('Save results to {}'.format(filename))
+
+
+def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
+ if data_type in ('mot', 'lab'):
+ read_fun = read_mot_results
+ else:
+ raise ValueError('Unknown data type: {}'.format(data_type))
+
+ return read_fun(filename, is_gt, is_ignore)
+
+
+"""
+labels={'ped', ... % 1
+'person_on_vhcl', ... % 2
+'car', ... % 3
+'bicycle', ... % 4
+'mbike', ... % 5
+'non_mot_vhcl', ... % 6
+'static_person', ... % 7
+'distractor', ... % 8
+'occluder', ... % 9
+'occluder_on_grnd', ... %10
+'occluder_full', ... % 11
+'reflection', ... % 12
+'crowd' ... % 13
+};
+"""
+
+
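+# Column layout assumed by read_mot_results below (standard MOTChallenge CSV):
+#   frame, id, bb_left, bb_top, bb_width, bb_height, conf/mark, class, visibility, ...
+# e.g. a ground-truth line such as "1,1,912,484,97,109,1,1,0.86" is parsed as fid=1,
+# target_id=1, tlwh=(912.0, 484.0, 97.0, 109.0), mark=1, label=1 (pedestrian), vis_ratio=0.86.
+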
+def read_mot_results(filename, is_gt, is_ignore):
+ valid_labels = {1}
+ ignore_labels = {2, 7, 8, 12}
+ results_dict = dict()
+ if os.path.isfile(filename):
+ with open(filename, 'r') as f:
+ for line in f.readlines():
+ linelist = line.split(',')
+ if len(linelist) < 7:
+ continue
+ fid = int(linelist[0])
+ if fid < 1:
+ continue
+ results_dict.setdefault(fid, list())
+
+ if is_gt:
+ if 'MOT16-' in filename or 'MOT17-' in filename:
+ label = int(float(linelist[7]))
+ mark = int(float(linelist[6]))
+ if mark == 0 or label not in valid_labels:
+ continue
+ score = 1
+ elif is_ignore:
+ if 'MOT16-' in filename or 'MOT17-' in filename:
+ label = int(float(linelist[7]))
+ vis_ratio = float(linelist[8])
+ if label not in ignore_labels and vis_ratio >= 0:
+ continue
+ else:
+ continue
+ score = 1
+ else:
+ score = float(linelist[6])
+
+ tlwh = tuple(map(float, linelist[2:6]))
+ target_id = int(linelist[1])
+
+ results_dict[fid].append((tlwh, target_id, score))
+
+ return results_dict
+
+
+def unzip_objs(objs):
+ if len(objs) > 0:
+ tlwhs, ids, scores = zip(*objs)
+ else:
+ tlwhs, ids, scores = [], [], []
+ tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
+
+ return tlwhs, ids, scores
\ No newline at end of file
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/json_logger.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/json_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..0afd0b45df736866c49473db78286685d77660ac
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/json_logger.py
@@ -0,0 +1,383 @@
+"""
+References:
+ https://medium.com/analytics-vidhya/creating-a-custom-logging-mechanism-for-real-time-object-detection-using-tdd-4ca2cfcd0a2f
+"""
+import json
+from os import makedirs
+from os.path import exists, join
+from datetime import datetime
+
+
+class JsonMeta(object):
+ HOURS = 3
+ MINUTES = 59
+ SECONDS = 59
+ PATH_TO_SAVE = 'LOGS'
+ DEFAULT_FILE_NAME = 'remaining'
+
+
+class BaseJsonLogger(object):
+ """
+    Base class that serializes its own __dict__ to a plain dictionary.
+    Attributes that are list instances are converted element-wise as well.
+
+ """
+
+ def dic(self):
+ # returns dicts of objects
+ out = {}
+ for k, v in self.__dict__.items():
+ if hasattr(v, 'dic'):
+ out[k] = v.dic()
+ elif isinstance(v, list):
+ out[k] = self.list(v)
+ else:
+ out[k] = v
+ return out
+
+ @staticmethod
+ def list(values):
+ # applies the dic method on items in the list
+ return [v.dic() if hasattr(v, 'dic') else v for v in values]
+
+
+class Label(BaseJsonLogger):
+ """
+ For each bounding box there are various categories with confidences. Label class keeps track of that information.
+ """
+
+ def __init__(self, category: str, confidence: float):
+ self.category = category
+ self.confidence = confidence
+
+
+class Bbox(BaseJsonLogger):
+ """
+    Stores the information of a single bounding box; BboxToJsonLogger collects these per frame.
+ Attributes:
+ labels (list): List of label module.
+ top (int):
+ left (int):
+ width (int):
+ height (int):
+
+ Args:
+ bbox_id (float):
+ top (int):
+ left (int):
+ width (int):
+ height (int):
+
+ References:
+ Check Label module for better understanding.
+
+
+ """
+
+ def __init__(self, bbox_id, top, left, width, height):
+ self.labels = []
+ self.bbox_id = bbox_id
+ self.top = top
+ self.left = left
+ self.width = width
+ self.height = height
+
+ def add_label(self, category, confidence):
+        # appends a Label; the top_k limit is enforced by the caller (BboxToJsonLogger.add_label_to_bbox)
+ self.labels.append(Label(category, confidence))
+
+ def labels_full(self, value):
+ return len(self.labels) == value
+
+
+class Frame(BaseJsonLogger):
+ """
+    Stores the per-frame information (frame id, timestamp, bounding boxes) that BboxToJsonLogger serializes.
+ Attributes:
+ timestamp (float): The elapsed time of captured frame
+ frame_id (int): The frame number of the captured video
+ bboxes (list of Bbox objects): Stores the list of bbox objects.
+
+ References:
+ Check Bbox class for better information
+
+ Args:
+ timestamp (float):
+ frame_id (int):
+
+ """
+
+ def __init__(self, frame_id: int, timestamp: float = None):
+ self.frame_id = frame_id
+ self.timestamp = timestamp
+ self.bboxes = []
+
+ def add_bbox(self, bbox_id: int, top: int, left: int, width: int, height: int):
+ bboxes_ids = [bbox.bbox_id for bbox in self.bboxes]
+ if bbox_id not in bboxes_ids:
+ self.bboxes.append(Bbox(bbox_id, top, left, width, height))
+ else:
+ raise ValueError("Frame with id: {} already has a Bbox with id: {}".format(self.frame_id, bbox_id))
+
+ def add_label_to_bbox(self, bbox_id: int, category: str, confidence: float):
+        bboxes = {bbox.bbox_id: bbox for bbox in self.bboxes}
+ if bbox_id in bboxes.keys():
+ res = bboxes.get(bbox_id)
+ res.add_label(category, confidence)
+ else:
+            raise ValueError('the bbox with id: {} does not exist!'.format(bbox_id))
+
+
+class BboxToJsonLogger(BaseJsonLogger):
+ """
+    This module is designed to automate the task of logging jsons. An example json is used
+ to show the contents of json file shortly
+ Example:
+ {
+ "video_details": {
+ "frame_width": 1920,
+ "frame_height": 1080,
+ "frame_rate": 20,
+ "video_name": "/home/gpu/codes/MSD/pedestrian_2/project/public/camera1.avi"
+ },
+ "frames": [
+ {
+ "frame_id": 329,
+ "timestamp": 3365.1254
+ "bboxes": [
+ {
+ "labels": [
+ {
+ "category": "pedestrian",
+ "confidence": 0.9
+ }
+ ],
+ "bbox_id": 0,
+ "top": 1257,
+ "left": 138,
+ "width": 68,
+ "height": 109
+ }
+ ]
+            }]
+        }
+
+ Attributes:
+ frames (dict): It's a dictionary that maps each frame_id to json attributes.
+ video_details (dict): information about video file.
+ top_k_labels (int): shows the allowed number of labels
+ start_time (datetime object): we use it to automate the json output by time.
+
+ Args:
+ top_k_labels (int): shows the allowed number of labels
+
+ """
+
+ def __init__(self, top_k_labels: int = 1):
+ self.frames = {}
+        self.video_details = dict(frame_width=None, frame_height=None, frame_rate=None,
+                                  video_name=None)
+ self.top_k_labels = top_k_labels
+ self.start_time = datetime.now()
+
+ def set_top_k(self, value):
+ self.top_k_labels = value
+
+ def frame_exists(self, frame_id: int) -> bool:
+ """
+ Args:
+ frame_id (int):
+
+ Returns:
+ bool: true if frame_id is recognized
+ """
+ return frame_id in self.frames.keys()
+
+ def add_frame(self, frame_id: int, timestamp: float = None) -> None:
+ """
+ Args:
+ frame_id (int):
+ timestamp (float): opencv captured frame time property
+
+ Raises:
+            ValueError: if frame_id already exists in the frames attribute
+
+ Returns:
+ None
+
+ """
+ if not self.frame_exists(frame_id):
+ self.frames[frame_id] = Frame(frame_id, timestamp)
+ else:
+ raise ValueError("Frame id: {} already exists".format(frame_id))
+
+ def bbox_exists(self, frame_id: int, bbox_id: int) -> bool:
+ """
+ Args:
+ frame_id:
+ bbox_id:
+
+ Returns:
+ bool: if bbox exists in frame bboxes list
+ """
+ bboxes = []
+ if self.frame_exists(frame_id=frame_id):
+ bboxes = [bbox.bbox_id for bbox in self.frames[frame_id].bboxes]
+ return bbox_id in bboxes
+
+ def find_bbox(self, frame_id: int, bbox_id: int):
+ """
+
+ Args:
+ frame_id:
+ bbox_id:
+
+ Returns:
+ bbox_id (int):
+
+ Raises:
+ ValueError: if bbox_id does not exist in the bbox list of specific frame.
+ """
+ if not self.bbox_exists(frame_id, bbox_id):
+ raise ValueError("frame with id: {} does not contain bbox with id: {}".format(frame_id, bbox_id))
+ bboxes = {bbox.bbox_id: bbox for bbox in self.frames[frame_id].bboxes}
+ return bboxes.get(bbox_id)
+
+ def add_bbox_to_frame(self, frame_id: int, bbox_id: int, top: int, left: int, width: int, height: int) -> None:
+ """
+
+ Args:
+ frame_id (int):
+ bbox_id (int):
+ top (int):
+ left (int):
+ width (int):
+ height (int):
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: if bbox_id already exist in frame information with frame_id
+ ValueError: if frame_id does not exist in frames attribute
+ """
+ if self.frame_exists(frame_id):
+ frame = self.frames[frame_id]
+ if not self.bbox_exists(frame_id, bbox_id):
+ frame.add_bbox(bbox_id, top, left, width, height)
+ else:
+ raise ValueError(
+ "frame with frame_id: {} already contains the bbox with id: {} ".format(frame_id, bbox_id))
+ else:
+ raise ValueError("frame with frame_id: {} does not exist".format(frame_id))
+
+ def add_label_to_bbox(self, frame_id: int, bbox_id: int, category: str, confidence: float):
+ """
+ Args:
+ frame_id:
+ bbox_id:
+ category:
+ confidence: the confidence value returned from yolo detection
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: if labels quota (top_k_labels) exceeds.
+ """
+ bbox = self.find_bbox(frame_id, bbox_id)
+ if not bbox.labels_full(self.top_k_labels):
+ bbox.add_label(category, confidence)
+ else:
+ raise ValueError("labels in frame_id: {}, bbox_id: {} is fulled".format(frame_id, bbox_id))
+
+ def add_video_details(self, frame_width: int = None, frame_height: int = None, frame_rate: int = None,
+ video_name: str = None):
+ self.video_details['frame_width'] = frame_width
+ self.video_details['frame_height'] = frame_height
+ self.video_details['frame_rate'] = frame_rate
+ self.video_details['video_name'] = video_name
+
+ def output(self):
+ output = {'video_details': self.video_details}
+ result = list(self.frames.values())
+ output['frames'] = [item.dic() for item in result]
+ return output
+
+ def json_output(self, output_name):
+ """
+ Args:
+ output_name:
+
+ Returns:
+ None
+
+ Notes:
+ It creates the json output with `output_name` name.
+ """
+ if not output_name.endswith('.json'):
+ output_name += '.json'
+ with open(output_name, 'w') as file:
+ json.dump(self.output(), file)
+ file.close()
+
+ def set_start(self):
+ self.start_time = datetime.now()
+
+ def schedule_output_by_time(self, output_dir=JsonMeta.PATH_TO_SAVE, hours: int = 0, minutes: int = 0,
+ seconds: int = 60) -> None:
+ """
+ Notes:
+            Creates the output folder if needed and periodically stores the JSON files there.
+
+ Args:
+ output_dir (str): the directory where output files will be stored
+ hours (int):
+ minutes (int):
+ seconds (int):
+
+ Returns:
+ None
+
+ """
+ end = datetime.now()
+ interval = 0
+ interval += abs(min([hours, JsonMeta.HOURS]) * 3600)
+ interval += abs(min([minutes, JsonMeta.MINUTES]) * 60)
+ interval += abs(min([seconds, JsonMeta.SECONDS]))
+ diff = (end - self.start_time).seconds
+
+ if diff > interval:
+ output_name = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '.json'
+ if not exists(output_dir):
+ makedirs(output_dir)
+ output = join(output_dir, output_name)
+ self.json_output(output_name=output)
+ self.frames = {}
+ self.start_time = datetime.now()
+
+ def schedule_output_by_frames(self, frames_quota, frame_counter, output_dir=JsonMeta.PATH_TO_SAVE):
+ """
+        Saves the output once frame_counter reaches frames_quota (not implemented yet).
+ :param frames_quota:
+ :param frame_counter:
+ :param output_dir:
+ :return:
+ """
+ pass
+
+ def flush(self, output_dir):
+ """
+ Notes:
+            Writes whatever frames are still buffered to a JSON file,
+            e.g. when the OpenCV capture loop exits.
+
+ Args:
+ output_dir:
+
+ Returns:
+ None
+
+ """
+ filename = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '-remaining.json'
+ output = join(output_dir, filename)
+ self.json_output(output_name=output)
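+
+
+# Usage sketch (illustrative only; the file names and values below are made up):
+#   logger = BboxToJsonLogger(top_k_labels=1)
+#   logger.add_video_details(frame_width=1920, frame_height=1080, frame_rate=20, video_name='camera1.avi')
+#   logger.add_frame(frame_id=0, timestamp=0.0)
+#   logger.add_bbox_to_frame(frame_id=0, bbox_id=0, top=1257, left=138, width=68, height=109)
+#   logger.add_label_to_bbox(frame_id=0, bbox_id=0, category='pedestrian', confidence=0.9)
+#   logger.json_output('detections')  # writes detections.json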
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/log.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d48757dca88f35e9ea2cd1ca16e41bac9976a45
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/log.py
@@ -0,0 +1,17 @@
+import logging
+
+
+def get_logger(name='root'):
+ formatter = logging.Formatter(
+ # fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')
+ fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+
+ handler = logging.StreamHandler()
+ handler.setFormatter(formatter)
+
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ logger.addHandler(handler)
+ return logger
+
+
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/parser.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..449d1aaac85c917e223f61535e0f24bd9e197489
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/parser.py
@@ -0,0 +1,41 @@
+import os
+import yaml
+from easydict import EasyDict as edict
+
+
+class YamlParser(edict):
+ """
+    A YAML configuration parser based on EasyDict.
+ """
+
+ def __init__(self, cfg_dict=None, config_file=None):
+ if cfg_dict is None:
+ cfg_dict = {}
+
+ if config_file is not None:
+ assert(os.path.isfile(config_file))
+ with open(config_file, 'r') as fo:
+ yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader)
+ cfg_dict.update(yaml_)
+
+ super(YamlParser, self).__init__(cfg_dict)
+
+ def merge_from_file(self, config_file):
+ with open(config_file, 'r') as fo:
+ yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader)
+ self.update(yaml_)
+
+ def merge_from_dict(self, config_dict):
+ self.update(config_dict)
+
+
+def get_config(config_file=None):
+ return YamlParser(config_file=config_file)
+
+
+if __name__ == "__main__":
+ cfg = YamlParser(config_file="../configs/yolov3.yaml")
+ cfg.merge_from_file("../configs/deep_sort.yaml")
+
+ import ipdb
+ ipdb.set_trace()
diff --git a/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/tools.py b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..965fb69c2df41510fd740a4ab57d8fc7b81012de
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/deep_sort_pytorch/utils/tools.py
@@ -0,0 +1,39 @@
+from functools import wraps
+from time import time
+
+
+def is_video(ext: str):
+ """
+ Returns true if ext exists in
+ allowed_exts for video files.
+
+ Args:
+ ext:
+
+ Returns:
+
+ """
+
+ allowed_exts = ('.mp4', '.webm', '.ogg', '.avi', '.wmv', '.mkv', '.3gp')
+ return any((ext.endswith(x) for x in allowed_exts))
+
+
+def tik_tok(func):
+ """
+    Decorator that reports the wall-clock time (and resulting FPS) of each call to the wrapped function.
+ Args:
+ func:
+
+ Returns:
+
+ """
+ @wraps(func)
+ def _time_it(*args, **kwargs):
+ start = time()
+ try:
+ return func(*args, **kwargs)
+ finally:
+ end_ = time()
+ print("time: {:.03f}s, fps: {:.03f}".format(end_ - start, 1 / (end_ - start)))
+
+ return _time_it
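+
+
+# Usage sketch (illustrative only; `process_frame` is a hypothetical function):
+#   @tik_tok
+#   def process_frame(frame):
+#       ...  # any per-frame work
+#   # every call then prints e.g. "time: 0.031s, fps: 32.258"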
diff --git a/ultralytics/yolo/v8/detect/dozhd-mashina-trollejbus-ulica-doroga.jpg b/ultralytics/yolo/v8/detect/dozhd-mashina-trollejbus-ulica-doroga.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e6da21642cecdee7c0c546f849003144282ed944
Binary files /dev/null and b/ultralytics/yolo/v8/detect/dozhd-mashina-trollejbus-ulica-doroga.jpg differ
diff --git a/ultralytics/yolo/v8/detect/night_motorbikes.mp4 b/ultralytics/yolo/v8/detect/night_motorbikes.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..96bfdb9fed1bd21a85a2cbb09b6f65c23e34f203
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/night_motorbikes.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6eb91fe4eb724e1e07555583f19b108457a80218e6f3e0075f97454a68a8f3e4
+size 8356796
diff --git a/ultralytics/yolo/v8/detect/predict.log b/ultralytics/yolo/v8/detect/predict.log
new file mode 100644
index 0000000000000000000000000000000000000000..86b6db3f7e9f0db978b150734caf974421580304
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/predict.log
@@ -0,0 +1,3 @@
+[2023-05-08 02:21:28,475][root.tracker][INFO] - Loading weights from deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7... Done!
+[2023-05-08 02:28:45,123][root.tracker][INFO] - Loading weights from deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7... Done!
+[2023-05-08 02:32:02,713][root.tracker][INFO] - Loading weights from deep_sort_pytorch/deep_sort/deep/checkpoint/ckpt.t7... Done!
diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..26a8eace467dfb8bc13e01bf8bbabc6b9de22346
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/predict.py
@@ -0,0 +1,303 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import hydra
+import torch
+import argparse
+import time
+from pathlib import Path
+import math
+import cv2
+import torch.backends.cudnn as cudnn
+from numpy import random
+from ultralytics.yolo.engine.predictor import BasePredictor
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
+from ultralytics.yolo.utils.checks import check_imgsz
+from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
+
+from deep_sort_pytorch.utils.parser import get_config
+from deep_sort_pytorch.deep_sort import DeepSort
+from collections import deque
+import numpy as np
+palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
+deq = {}
+indices = [0] * 100
+c = 0
+
+
+
+
+deepsort = None
+
+object_counter = {}
+
+speed_line_queue = {}
+def estimatespeed(Location1, Location2, h, w):
+    # Euclidean distance (in pixels) between the two tracked points
+ d_pixel = math.sqrt(math.pow(Location2[0] - Location1[0], 2) + math.pow(Location2[1] - Location1[1], 2))
+    # rough pixels-per-meter estimate derived from the bounding-box size
+ ppm = max(h, w) // 10
+ d_meters = d_pixel/ppm
+ time_constant = 15*3.6
+    # speed = distance / time
+ speed = d_meters * time_constant
+
+ return int(speed)
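+
+# Worked example (interpretation of the constants above, not from the original author):
+# for a 200x100 px box, ppm = max(200, 100) // 10 = 20; if the tracked point moved 50 px
+# between samples, d_meters = 50 / 20 = 2.5 and speed = 2.5 * 15 * 3.6 = 135, where 15 is
+# presumably the assumed sampling rate (frames/s) and 3.6 converts m/s to km/h.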
+def init_tracker():
+ global deepsort
+ cfg_deep = get_config()
+ cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
+
+ deepsort= DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
+ max_dist=cfg_deep.DEEPSORT.MAX_DIST, min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
+ nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
+ max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT, nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
+ use_cuda=True)
+##########################################################################################
+def xyxy_to_xywh(*xyxy):
+ """" Calculates the relative bounding box from absolute pixel values. """
+ bbox_left = min([xyxy[0].item(), xyxy[2].item()])
+ bbox_top = min([xyxy[1].item(), xyxy[3].item()])
+ bbox_w = abs(xyxy[0].item() - xyxy[2].item())
+ bbox_h = abs(xyxy[1].item() - xyxy[3].item())
+ x_c = (bbox_left + bbox_w / 2)
+ y_c = (bbox_top + bbox_h / 2)
+ w = bbox_w
+ h = bbox_h
+ return x_c, y_c, w, h
+
+
+def compute_color_for_labels(label):
+ """
+    Simple function that returns a fixed color depending on the class
+ """
+ if label == 7: #truck
+ color = (85,45,255)
+ elif label == 2: # Car
+ color = (222,82,175)
+ elif label == 3: # Motorcycle
+ color = (0, 204, 255)
+ elif label == 5: # Bus
+ color = (0, 149, 255)
+ else:
+ color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
+ return tuple(color)
+
+def draw_border(img, pt1, pt2, color, thickness, r, d):
+ x1,y1 = pt1
+ x2,y2 = pt2
+ # Top left
+ cv2.line(img, (x1 + r, y1), (x1 + r + d, y1), color, thickness)
+ cv2.line(img, (x1, y1 + r), (x1, y1 + r + d), color, thickness)
+ cv2.ellipse(img, (x1 + r, y1 + r), (r, r), 180, 0, 90, color, thickness)
+ # Top right
+ cv2.line(img, (x2 - r, y1), (x2 - r - d, y1), color, thickness)
+ cv2.line(img, (x2, y1 + r), (x2, y1 + r + d), color, thickness)
+ cv2.ellipse(img, (x2 - r, y1 + r), (r, r), 270, 0, 90, color, thickness)
+ # Bottom left
+ cv2.line(img, (x1 + r, y2), (x1 + r + d, y2), color, thickness)
+ cv2.line(img, (x1, y2 - r), (x1, y2 - r - d), color, thickness)
+ cv2.ellipse(img, (x1 + r, y2 - r), (r, r), 90, 0, 90, color, thickness)
+ # Bottom right
+ cv2.line(img, (x2 - r, y2), (x2 - r - d, y2), color, thickness)
+ cv2.line(img, (x2, y2 - r), (x2, y2 - r - d), color, thickness)
+ cv2.ellipse(img, (x2 - r, y2 - r), (r, r), 0, 0, 90, color, thickness)
+
+ cv2.rectangle(img, (x1 + r, y1), (x2 - r, y2), color, -1, cv2.LINE_AA)
+ cv2.rectangle(img, (x1, y1 + r), (x2, y2 - r - d), color, -1, cv2.LINE_AA)
+
+ cv2.circle(img, (x1 +r, y1+r), 2, color, 12)
+ cv2.circle(img, (x2 -r, y1+r), 2, color, 12)
+ cv2.circle(img, (x1 +r, y2-r), 2, color, 12)
+ cv2.circle(img, (x2 -r, y2-r), 2, color, 12)
+
+ return img
+
+def UI_box(x, img, color=None, label=None, line_thickness=None):
+ # Plots one bounding box on image img
+ tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
+ color = color or [random.randint(0, 255) for _ in range(3)]
+ c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+ cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+ if label:
+ tf = max(tl - 1, 1) # font thickness
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+
+ img = draw_border(img, (c1[0], c1[1] - t_size[1] -3), (c1[0] + t_size[0], c1[1]+3), color, 1, 8, 2)
+
+ cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+
+
+def ccw(A,B,C):
+ return (C[1]-A[1]) * (B[0]-A[0]) > (B[1]-A[1]) * (C[0]-A[0])
+
+
+def draw_boxes(img, bbox, names,object_id, identities=None, offset=(0, 0)):
+ height, width, _ = img.shape
+ # remove tracked point from buffer if object is lost
+ global c
+
+
+ for key in list(deq):
+ if key not in identities:
+ deq.pop(key)
+
+ weights = [0,0,int(6.72),int(1.638),0,30,0,int(18.75)]
+ speeds1 = [[0],[0],[0],[0],[0],[0],[0],[0]]
+ speeds = [0] * 8
+
+ for i, box in enumerate(bbox):
+ obj_name = names[object_id[i]]
+ x1, y1, x2, y2 = [int(i) for i in box]
+ x1 += offset[0]
+ x2 += offset[0]
+ y1 += offset[1]
+ y2 += offset[1]
+
+ # code to find center of bottom edge
+ center = (int((x2+x1)/ 2), int((y2+y2)/2))
+
+ # get ID of object
+
+ id = int(identities[i]) if identities is not None else 0
+
+ # create new buffer for new object
+ if id not in deq:
+ deq[id] = deque(maxlen= 64)
+ if object_id[i] in [2, 3, 5, 7]:
+ c +=1
+ indices[id] = c
+ speed_line_queue[id] = []
+ color = compute_color_for_labels(object_id[i])
+
+
+ label = '{}{:d}'.format("", indices[id]) + ":"+ '%s' % (obj_name)
+
+
+ # add center to buffer
+ deq[id].appendleft(center)
+ if len(deq[id]) >= 2:
+ object_speed = estimatespeed(deq[id][1], deq[id][0], x2-x1, y2-y1)
+ speed_line_queue[id].append(object_speed)
+ if obj_name not in object_counter:
+ object_counter[obj_name] = 1
+
+ #motorcycle_weight = 1.638
+ #car_weight = 6.72
+ #truck_weight = 18.75
+ #bus_weight = 30
+
+ try:
+ spd = sum(speed_line_queue[id])//len(speed_line_queue[id])
+ speeds[object_id[i]] += spd
+ label = label + " v=" + str(spd) + " m=" + str(weights[object_id[i]])
+
+ except:
+ pass
+ UI_box(box, img, label=label, color=color, line_thickness=2)
+ #cv2.putText(img, f"{speeds}", (500, 50), 0, 1, [0, 255, 0], thickness=2, lineType=cv2.LINE_AA)
+ cv2.putText(img, f"pulse: {sum(np.multiply(speeds, weights))}", (500, 50), 0, 1, [0, 255, 0], thickness=2, lineType=cv2.LINE_AA)
+ #for i, object_speed in enumerate(speeds):
+ # object_speed = sum(object_speed)*weights[i]
+
+
+
+ return img
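+
+
+# Note (descriptive, added for clarity): the "pulse" value drawn above is
+# sum(speeds * weights): the per-class accumulated average speeds multiplied by the
+# hard-coded per-class mass weights (car=6, motorcycle=1, bus=30, truck=18 after the
+# int() truncation above), presumably intended as a rough traffic-load indicator.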
+
+
+class DetectionPredictor(BasePredictor):
+
+ def get_annotator(self, img):
+ return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))
+
+ def preprocess(self, img):
+ img = torch.from_numpy(img).to(self.model.device)
+ img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
+ img /= 255 # 0 - 255 to 0.0 - 1.0
+ return img
+
+ def postprocess(self, preds, img, orig_img):
+ preds = ops.non_max_suppression(preds,
+ self.args.conf,
+ self.args.iou,
+ classes = [2, 3, 5, 7],
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det)
+
+ for i, pred in enumerate(preds):
+ shape = orig_img[i].shape if self.webcam else orig_img.shape
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+
+ return preds
+
+ def write_results(self, idx, preds, batch):
+ p, im, im0 = batch
+ all_outputs = []
+ log_string = ""
+ if len(im.shape) == 3:
+ im = im[None] # expand for batch dim
+ self.seen += 1
+ im0 = im0.copy()
+ if self.webcam: # batch_size >= 1
+ log_string += f'{idx}: '
+ frame = self.dataset.count
+ else:
+ frame = getattr(self.dataset, 'frame', 0)
+
+ self.data_path = p
+ save_path = str(self.save_dir / p.name) # im.jpg
+ self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
+ log_string += '%gx%g ' % im.shape[2:] # print string
+ self.annotator = self.get_annotator(im0)
+
+ det = preds[idx]
+ all_outputs.append(det)
+ if len(det) == 0:
+ return log_string
+
+ count = 0
+ for c in det[:, 5].unique():
+ count += 1
+ n = (det[:, 5] == c).sum() # detections per class
+ cv2.putText(im0, f"{n} {self.model.names[int(c)]}", (11, count*50), 0, 1, [0, 255, 0], thickness=2, lineType=cv2.LINE_AA)
+ log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
+ # write
+ gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
+ xywh_bboxs = []
+ confs = []
+ oids = []
+ outputs = []
+ for *xyxy, conf, cls in reversed(det):
+ x_c, y_c, bbox_w, bbox_h = xyxy_to_xywh(*xyxy)
+ xywh_obj = [x_c, y_c, bbox_w, bbox_h]
+ xywh_bboxs.append(xywh_obj)
+ confs.append([conf.item()])
+ oids.append(int(cls))
+ xywhs = torch.Tensor(xywh_bboxs)
+ confss = torch.Tensor(confs)
+
+ outputs = deepsort.update(xywhs, confss, oids, im0)
+
+ if len(outputs) > 0:
+ bbox_xyxy = outputs[:, :4]
+ identities = outputs[:, -2]
+ object_id = outputs[:, -1]
+
+ draw_boxes(im0, bbox_xyxy, self.model.names, object_id,identities)
+
+ return log_string
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
+def predict(cfg):
+ init_tracker()
+ cfg.model = cfg.model or "yolov8n.pt"
+ cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
+ cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
+ predictor = DetectionPredictor(cfg)
+ predictor()
+
+
+if __name__ == "__main__":
+ predict()
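+
+
+# Usage sketch (Hydra-style overrides; the weights and video below are files added in this repo,
+# other arguments follow the default Ultralytics config):
+#   python predict.py model=yolov8x6.pt source=night_motorbikes.mp4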
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b0322c8dc15adb02033b86120e8a2f1df72b7da
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -0,0 +1,217 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from copy import copy
+
+import hydra
+import torch
+import torch.nn as nn
+
+from ultralytics.nn.tasks import DetectionModel
+from ultralytics.yolo import v8
+from ultralytics.yolo.data import build_dataloader
+from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
+from ultralytics.yolo.engine.trainer import BaseTrainer
+from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr
+from ultralytics.yolo.utils.loss import BboxLoss
+from ultralytics.yolo.utils.ops import xywh2xyxy
+from ultralytics.yolo.utils.plotting import plot_images, plot_results
+from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
+from ultralytics.yolo.utils.torch_utils import de_parallel
+
+
+# BaseTrainer python usage
+class DetectionTrainer(BaseTrainer):
+
+ def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
+ # TODO: manage splits differently
+ # calculate stride - check if model is initialized
+ gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+ return create_dataloader(path=dataset_path,
+ imgsz=self.args.imgsz,
+ batch_size=batch_size,
+ stride=gs,
+ hyp=dict(self.args),
+ augment=mode == "train",
+ cache=self.args.cache,
+ pad=0 if mode == "train" else 0.5,
+ rect=self.args.rect,
+ rank=rank,
+ workers=self.args.workers,
+ close_mosaic=self.args.close_mosaic != 0,
+ prefix=colorstr(f'{mode}: '),
+ shuffle=mode == "train",
+ seed=self.args.seed)[0] if self.args.v5loader else \
+ build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
+
+ def preprocess_batch(self, batch):
+ batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
+ return batch
+
+ def set_model_attributes(self):
+ nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
+ self.args.box *= 3 / nl # scale to layers
+ # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
+ self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
+ self.model.nc = self.data["nc"] # attach number of classes to model
+ self.model.args = self.args # attach hyperparameters to model
+ # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
+ self.model.names = self.data["names"]
+
+ def get_model(self, cfg=None, weights=None, verbose=True):
+ model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
+ if weights:
+ model.load(weights)
+
+ return model
+
+ def get_validator(self):
+ self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
+ return v8.detect.DetectionValidator(self.test_loader,
+ save_dir=self.save_dir,
+ logger=self.console,
+ args=copy(self.args))
+
+ def criterion(self, preds, batch):
+ if not hasattr(self, 'compute_loss'):
+ self.compute_loss = Loss(de_parallel(self.model))
+ return self.compute_loss(preds, batch)
+
+ def label_loss_items(self, loss_items=None, prefix="train"):
+ """
+ Returns a loss dict with labelled training loss items tensor
+ """
+ # Not needed for classification but necessary for segmentation & detection
+ keys = [f"{prefix}/{x}" for x in self.loss_names]
+ if loss_items is not None:
+ loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
+ return dict(zip(keys, loss_items))
+ else:
+ return keys
+
+ def progress_string(self):
+ return ('\n' + '%11s' *
+ (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
+
+ def plot_training_samples(self, batch, ni):
+ plot_images(images=batch["img"],
+ batch_idx=batch["batch_idx"],
+ cls=batch["cls"].squeeze(-1),
+ bboxes=batch["bboxes"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"train_batch{ni}.jpg")
+
+ def plot_metrics(self):
+ plot_results(file=self.csv) # save results.png
+
+
+# Criterion class for computing training losses
+class Loss:
+
+ def __init__(self, model): # model must be de-paralleled
+
+ device = next(model.parameters()).device # get model device
+ h = model.args # hyperparameters
+
+ m = model.model[-1] # Detect() module
+ self.bce = nn.BCEWithLogitsLoss(reduction='none')
+ self.hyp = h
+ self.stride = m.stride # model strides
+ self.nc = m.nc # number of classes
+ self.no = m.no
+ self.reg_max = m.reg_max
+ self.device = device
+
+ self.use_dfl = m.reg_max > 1
+ self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+ self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
+ self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
+
+ def preprocess(self, targets, batch_size, scale_tensor):
+ if targets.shape[0] == 0:
+ out = torch.zeros(batch_size, 0, 5, device=self.device)
+ else:
+ i = targets[:, 0] # image index
+ _, counts = i.unique(return_counts=True)
+ out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
+ for j in range(batch_size):
+ matches = i == j
+ n = matches.sum()
+ if n:
+ out[j, :n] = targets[matches, 1:]
+ out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
+ return out
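+
+    # Note (descriptive, added for clarity): preprocess() pads the flat (n, 6) targets
+    # [img_idx, cls, x, y, w, h] into a dense (batch, max_targets, 5) tensor of
+    # [cls, x1, y1, x2, y2] in pixel coordinates, zero-filled where an image has fewer targets.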
+
+ def bbox_decode(self, anchor_points, pred_dist):
+ if self.use_dfl:
+ b, a, c = pred_dist.shape # batch, anchors, channels
+ pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+ # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+ # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
+ return dist2bbox(pred_dist, anchor_points, xywh=False)
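+
+    # Note (descriptive, added for clarity): with DFL (reg_max > 1) each box side is predicted
+    # as a distribution over reg_max bins; softmax followed by the dot product with
+    # proj = arange(reg_max) gives the expected distance, and dist2bbox() turns the four
+    # per-anchor distances into xyxy boxes.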
+
+ def __call__(self, preds, batch):
+ loss = torch.zeros(3, device=self.device) # box, cls, dfl
+ feats = preds[1] if isinstance(preds, tuple) else preds
+ pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+ (self.reg_max * 4, self.nc), 1)
+
+ pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+ pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+
+ dtype = pred_scores.dtype
+ batch_size = pred_scores.shape[0]
+ imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
+ anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+ # targets
+ targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+ gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+
+ # pboxes
+ pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
+
+ _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+ pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+
+ target_bboxes /= stride_tensor
+ target_scores_sum = target_scores.sum()
+
+ # cls loss
+ # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
+ loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
+
+ # bbox loss
+ if fg_mask.sum():
+ loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
+ target_scores_sum, fg_mask)
+
+ loss[0] *= self.hyp.box # box gain
+ loss[1] *= self.hyp.cls # cls gain
+ loss[2] *= self.hyp.dfl # dfl gain
+
+ return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
+def train(cfg):
+ cfg.model = cfg.model or "yolov8n.yaml"
+ cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
+ # trainer = DetectionTrainer(cfg)
+ # trainer.train()
+ from ultralytics import YOLO
+ model = YOLO(cfg.model)
+ model.train(**cfg)
+
+
+if __name__ == "__main__":
+ """
+ CLI usage:
+ python ultralytics/yolo/v8/detect/train.py model=yolov8n.yaml data=coco128 epochs=100 imgsz=640
+
+ TODO:
+ yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
+ """
+ train()
diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb4f9039ad0b23186ea1f0bf71dd9e1cf1c3726
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/val.py
@@ -0,0 +1,241 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import os
+from pathlib import Path
+
+import hydra
+import numpy as np
+import torch
+
+from ultralytics.yolo.data import build_dataloader
+from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
+from ultralytics.yolo.engine.validator import BaseValidator
+from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr, ops, yaml_load
+from ultralytics.yolo.utils.checks import check_file, check_requirements
+from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.yolo.utils.plotting import output_to_target, plot_images
+from ultralytics.yolo.utils.torch_utils import de_parallel
+
+
+class DetectionValidator(BaseValidator):
+
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
+ super().__init__(dataloader, save_dir, pbar, logger, args)
+ self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
+ self.is_coco = False
+ self.class_map = None
+ self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots)
+ self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
+ self.niou = self.iouv.numel()
+
+ def preprocess(self, batch):
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
+ batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+ for k in ["batch_idx", "cls", "bboxes"]:
+ batch[k] = batch[k].to(self.device)
+
+ nb, _, height, width = batch["img"].shape
+ batch["bboxes"] *= torch.tensor((width, height, width, height), device=self.device) # to pixels
+ self.lb = [torch.cat([batch["cls"], batch["bboxes"]], dim=-1)[batch["batch_idx"] == i]
+ for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
+
+ return batch
+
+ def init_metrics(self, model):
+ head = model.model[-1] if self.training else model.model.model[-1]
+ val = self.data.get('val', '') # validation path
+ self.is_coco = isinstance(val, str) and val.endswith(f'coco{os.sep}val2017.txt') # is COCO dataset
+ self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+ self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
+ self.nc = head.nc
+ self.names = model.names
+ self.metrics.names = self.names
+ self.confusion_matrix = ConfusionMatrix(nc=self.nc)
+ self.seen = 0
+ self.jdict = []
+ self.stats = []
+
+ def get_desc(self):
+ return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
+
+ def postprocess(self, preds):
+ preds = ops.non_max_suppression(preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det)
+ return preds
+
+ def update_metrics(self, preds, batch):
+ # Metrics
+ for si, pred in enumerate(preds):
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx]
+ bbox = batch["bboxes"][idx]
+ nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
+ shape = batch["ori_shape"][si]
+ correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
+ self.seen += 1
+
+ if npr == 0:
+ if nl:
+ self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
+ if self.args.plots:
+ self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
+ continue
+
+ # Predictions
+ if self.args.single_cls:
+ pred[:, 5] = 0
+ predn = pred.clone()
+ ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
+ ratio_pad=batch["ratio_pad"][si]) # native-space pred
+
+ # Evaluate
+ if nl:
+ tbox = ops.xywh2xyxy(bbox) # target boxes
+ ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape,
+ ratio_pad=batch["ratio_pad"][si]) # native-space labels
+ labelsn = torch.cat((cls, tbox), 1) # native-space labels
+ correct_bboxes = self._process_batch(predn, labelsn)
+ # TODO: maybe remove these `self.` arguments as they already are member variable
+ if self.args.plots:
+ self.confusion_matrix.process_batch(predn, labelsn)
+ self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
+
+ # Save
+ if self.args.save_json:
+ self.pred_to_json(predn, batch["im_file"][si])
+ # if self.args.save_txt:
+ # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
+
+ def get_stats(self):
+ stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
+ if len(stats) and stats[0].any():
+ self.metrics.process(*stats)
+ self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
+ return self.metrics.results_dict
+
+ def print_results(self):
+ pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
+ self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+ if self.nt_per_class.sum() == 0:
+ self.logger.warning(
+                f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
+
+ # Print results per class
+ if (self.args.verbose or not self.training) and self.nc > 1 and len(self.stats):
+ for i, c in enumerate(self.metrics.ap_class_index):
+ self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
+
+ if self.args.plots:
+ self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
+
+ def _process_batch(self, detections, labels):
+ """
+ Return correct prediction matrix
+ Arguments:
+ detections (array[N, 6]), x1, y1, x2, y2, conf, class
+ labels (array[M, 5]), class, x1, y1, x2, y2
+ Returns:
+ correct (array[N, 10]), for 10 IoU levels
+ """
+ iou = box_iou(labels[:, 1:], detections[:, :4])
+ correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
+ correct_class = labels[:, 0:1] == detections[:, 5]
+ for i in range(len(self.iouv)):
+ x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
+ if x[0].shape[0]:
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
+ 1).cpu().numpy() # [label, detect, iou]
+ if x[0].shape[0] > 1:
+ matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+ # matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+ correct[matches[:, 1].astype(int), i] = True
+ return torch.tensor(correct, dtype=torch.bool, device=detections.device)
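+
+    # Note (descriptive, added for clarity): for each of the 10 IoU thresholds the loop above
+    # keeps class-consistent (label, detection) pairs, sorts them by IoU and removes duplicate
+    # detections and labels, so every detection is matched to at most one ground-truth box.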
+
+ def get_dataloader(self, dataset_path, batch_size):
+ # TODO: manage splits differently
+ # calculate stride - check if model is initialized
+ gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
+ return create_dataloader(path=dataset_path,
+ imgsz=self.args.imgsz,
+ batch_size=batch_size,
+ stride=gs,
+ hyp=dict(self.args),
+ cache=False,
+ pad=0.5,
+ rect=True,
+ workers=self.args.workers,
+ prefix=colorstr(f'{self.args.mode}: '),
+ shuffle=False,
+ seed=self.args.seed)[0] if self.args.v5loader else \
+ build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
+
+ def plot_val_samples(self, batch, ni):
+ plot_images(batch["img"],
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+ names=self.names)
+
+ def plot_predictions(self, batch, preds, ni):
+ plot_images(batch["img"],
+ *output_to_target(preds, max_det=15),
+ paths=batch["im_file"],
+ fname=self.save_dir / f'val_batch{ni}_pred.jpg',
+ names=self.names) # pred
+
+ def pred_to_json(self, predn, filename):
+ stem = Path(filename).stem
+ image_id = int(stem) if stem.isnumeric() else stem
+ box = ops.xyxy2xywh(predn[:, :4]) # xywh
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
+ for p, b in zip(predn.tolist(), box.tolist()):
+ self.jdict.append({
+ 'image_id': image_id,
+ 'category_id': self.class_map[int(p[5])],
+ 'bbox': [round(x, 3) for x in b],
+ 'score': round(p[4], 5)})
+
+ def eval_json(self, stats):
+ if self.args.save_json and self.is_coco and len(self.jdict):
+ anno_json = self.data['path'] / "annotations/instances_val2017.json" # annotations
+ pred_json = self.save_dir / "predictions.json" # predictions
+ self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+ try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+ check_requirements('pycocotools>=2.0.6')
+ from pycocotools.coco import COCO # noqa
+ from pycocotools.cocoeval import COCOeval # noqa
+
+ for x in anno_json, pred_json:
+ assert x.is_file(), f"{x} file not found"
+ anno = COCO(str(anno_json)) # init annotations api
+ pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
+ eval = COCOeval(anno, pred, 'bbox')
+ if self.is_coco:
+ eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
+ eval.evaluate()
+ eval.accumulate()
+ eval.summarize()
+ stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
+ except Exception as e:
+ self.logger.warning(f'pycocotools unable to run: {e}')
+ return stats
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
+def val(cfg):
+ cfg.data = cfg.data or "coco128.yaml"
+ validator = DetectionValidator(args=cfg)
+ validator(model=cfg.model)
+
+
+if __name__ == "__main__":
+ val()
diff --git a/ultralytics/yolo/v8/detect/yolov8x6.pt b/ultralytics/yolo/v8/detect/yolov8x6.pt
new file mode 100644
index 0000000000000000000000000000000000000000..aadd05db746bb9f5714712de3c8494f761884348
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/yolov8x6.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:565a1ee7c0b3d230cd63ecf37ca8de1b752c8c14661ecbc72788986401904535
+size 195529012