diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..16e9bf2e960e0afb27e2396d102dc961a311d46f --- /dev/null +++ b/config.py @@ -0,0 +1,18 @@ +from transformers import PretrainedConfig + +class SSLConfig(PretrainedConfig): + model_type = "ssl-aasist" + def __init__( + self, + filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]], + gat_dims = [64, 32], + pool_ratios = [0.5, 0.5, 0.5, 0.5], + temperatures = [2.0, 2.0, 100.0, 100.0], + **kwargs, + ): + + self.filts = filts + self.gat_dims = gat_dims + self.pool_ratios = pool_ratios + self.temperatures = temperatures + super().__init__(**kwargs) \ No newline at end of file diff --git a/fairseq/.github/CODEOWNERS b/fairseq/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..b79aa2ff06e1ffc0607dce614847b097f703cba3 --- /dev/null +++ b/fairseq/.github/CODEOWNERS @@ -0,0 +1,21 @@ +# Setting up CODEOWNERS for UST related codebase +# Documentation for open sourced models relevant to UST +examples/speech_to_text @kahne @sravyapopuri388 @jmp84 +examples/speech_to_speech @an918tw @sravyapopuri388 @jmp84 +examples/speech_synthesis @kahne @jmp84 +examples/simultaneous_translation @kahne @jmp84 +examples/speech_text_joint_to_text @yuntang @jmp84 + +# Speech related models relevant to UST +fairseq/models/speech_to_speech @sravyapopuri388 @jmp84 +fairseq/models/speech_to_text @kahne @sravyapopuri388 @jmp84 +fairseq/models/text_to_speech @kahne @jmp84 + +# CONFORMER IMPLEMENTATION +fairseq/modules/conformer_layer.py @sravyapopuri388 @jmp84 +fairseq/modules/espnet_multihead_attention.py @sravyapopuri388 @jmp84 +fairseq/modules/rotary_positional_embedding.py @sravyapopuri388 @jmp84 +fairseq/modules/positional_encoding.py @sravyapopuri388 @jmp84 + +# Machine Translation/NLLB +fairseq/tasks/translation.py @gwenzek diff --git a/fairseq/.github/ISSUE_TEMPLATE.md b/fairseq/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..5c4c4493e4a8e5386b927e4f4554df925955d129 --- /dev/null +++ b/fairseq/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,3 @@ +## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈 + +Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates. diff --git a/fairseq/.github/ISSUE_TEMPLATE/bug_report.md b/fairseq/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000000000000000000000000000000000..aa15123d8ef25c2de745572563505cf0ddc4e351 --- /dev/null +++ b/fairseq/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,43 @@ +--- +name: 🐛 Bug Report +about: Submit a bug report to help us improve +labels: 'bug, needs triage' +--- + +## 🐛 Bug + + + +### To Reproduce + +Steps to reproduce the behavior (**always include the command you ran**): + +1. Run cmd '....' +2. 
See error + + + + +#### Code sample + + +### Expected behavior + + + +### Environment + + - fairseq Version (e.g., 1.0 or main): + - PyTorch Version (e.g., 1.0) + - OS (e.g., Linux): + - How you installed fairseq (`pip`, source): + - Build command you used (if compiling from source): + - Python version: + - CUDA/cuDNN version: + - GPU models and configuration: + - Any other relevant information: + +### Additional context + + diff --git a/fairseq/.github/ISSUE_TEMPLATE/documentation.md b/fairseq/.github/ISSUE_TEMPLATE/documentation.md new file mode 100644 index 0000000000000000000000000000000000000000..3a6e2e9ea4bb71102122c17ff53051eb3770cb5e --- /dev/null +++ b/fairseq/.github/ISSUE_TEMPLATE/documentation.md @@ -0,0 +1,15 @@ +--- +name: 📚 Documentation/Typos +about: Report an issue related to documentation or a typo +labels: 'documentation, needs triage' +--- + +## 📚 Documentation + +For typos and doc fixes, please go ahead and: + +1. Create an issue. +2. Fix the typo. +3. Submit a PR. + +Thanks! diff --git a/fairseq/.github/ISSUE_TEMPLATE/feature_request.md b/fairseq/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000000000000000000000000000000000..93c8668041f8a7af29e4c11e905d8b56b946dd51 --- /dev/null +++ b/fairseq/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,24 @@ +--- +name: 🚀 Feature Request +about: Submit a proposal/request for a new feature +labels: 'enhancement, help wanted, needs triage' +--- + +## 🚀 Feature Request + + +### Motivation + + + +### Pitch + + + +### Alternatives + + + +### Additional context + + diff --git a/fairseq/.github/ISSUE_TEMPLATE/how-to-question.md b/fairseq/.github/ISSUE_TEMPLATE/how-to-question.md new file mode 100644 index 0000000000000000000000000000000000000000..04f3f15d3ed391e26ca87f726ae88f30d1d414ab --- /dev/null +++ b/fairseq/.github/ISSUE_TEMPLATE/how-to-question.md @@ -0,0 +1,33 @@ +--- +name: ❓ Questions/Help +about: If you have questions, please first search existing issues and docs +labels: 'question, needs triage' +--- + +## ❓ Questions and Help + +### Before asking: +1. search the issues. +2. search the docs. + + + +#### What is your question? + +#### Code + + + +#### What have you tried? + +#### What's your environment? + + - fairseq Version (e.g., 1.0 or main): + - PyTorch Version (e.g., 1.0) + - OS (e.g., Linux): + - How you installed fairseq (`pip`, source): + - Build command you used (if compiling from source): + - Python version: + - CUDA/cuDNN version: + - GPU models and configuration: + - Any other relevant information: diff --git a/fairseq/.github/PULL_REQUEST_TEMPLATE.md b/fairseq/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..d005e2df4f717ea4844a8320981d77d96e425a52 --- /dev/null +++ b/fairseq/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,16 @@ +# Before submitting + +- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) +- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)? +- [ ] Did you make sure to update the docs? +- [ ] Did you write any new necessary tests? + +## What does this PR do? +Fixes # (issue). + +## PR review +Anyone in the community is free to review the PR once the tests have passed. +If we didn't discuss your PR in Github issues there's a high chance it will not be merged. + +## Did you have fun? 
+Make sure you had fun coding 🙃 diff --git a/fairseq/.github/stale.yml b/fairseq/.github/stale.yml new file mode 100644 index 0000000000000000000000000000000000000000..b12867dab005e7a7608d4c7138a67d409c76f7ae --- /dev/null +++ b/fairseq/.github/stale.yml @@ -0,0 +1,30 @@ +# Configuration for probot-stale - https://github.com/probot/stale +# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 90 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - bug +# Label to use when marking an issue as stale +staleLabel: stale +issues: + # Comment to post when marking an issue as stale. + markComment: > + This issue has been automatically marked as stale. + **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment! + # Comment to post when closing a stale issue. + closeComment: > + Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you! +pulls: + # Comment to post when marking a pull request as stale. + markComment: > + This pull request has been automatically marked as stale. + **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated. + # Comment to post when closing a stale pull request. + closeComment: > + Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you! + diff --git a/fairseq/.github/workflows/build.yml b/fairseq/.github/workflows/build.yml new file mode 100644 index 0000000000000000000000000000000000000000..036233d8cf65f0e2b6162d662e9c2c62d1e0ef4f --- /dev/null +++ b/fairseq/.github/workflows/build.yml @@ -0,0 +1,81 @@ +name: build + +on: + # Trigger the workflow on push to main or any pull request + push: + branches: + - main + pull_request: + +jobs: + build: + + strategy: + max-parallel: 4 + matrix: + platform: [ubuntu-latest, macos-latest] + python-version: [3.8, 3.9] + + runs-on: ${{ matrix.platform }} + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Conditionally install pytorch + if: matrix.platform == 'windows-latest' + run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html + + - name: Install locally + run: | + python -m pip install --upgrade pip + git submodule update --init --recursive + python -m pip install . 
+ + - name: Check installation + working-directory: /tmp + run: python $GITHUB_WORKSPACE/scripts/check_installation.py + + - name: Install optional test requirements + run: | + python -m pip install '.[dev,docs]' + python -m pip install iopath transformers pyarrow + python -m pip install git+https://github.com/facebookresearch/fairscale.git@main + python -m pip install pygit2 pgzip + + - name: Install xformers for Macos + if: matrix.platform == 'macos-latest' + run: | + brew install llvm libomp + CC=/usr/local/opt/llvm/bin/clang CXX=clang++ pip install git+https://github.com/facebookresearch/xformers.git@main + + - name: Install xformers for non-MacOS + if: matrix.platform != 'macos-latest' + run: | + python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main + + - name: Lint with black + run: black --check --diff . + + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + - name: Build doc + run: make singlehtml + working-directory: docs/ + + - name: Run tests + # When installing in non-editable mode, the .so files will be generated in 'site-packages/fairseq'. + # But by default, pytest import machinery will load local fairseq, and won't see the .so. + # Use --import-mode=append to favorize the 'site-packages/fairseq'. + # https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html + run: pytest --import-mode=append -vvv tests/ + diff --git a/fairseq/.github/workflows/depreview.yml b/fairseq/.github/workflows/depreview.yml new file mode 100644 index 0000000000000000000000000000000000000000..032eddef5fbbe20fc4c3d99dc277239e536c2bc3 --- /dev/null +++ b/fairseq/.github/workflows/depreview.yml @@ -0,0 +1,14 @@ +name: 'Dependency Review' +on: [pull_request] + +permissions: + contents: read + +jobs: + dependency-review: + runs-on: ubuntu-latest + steps: + - name: 'Checkout Repository' + uses: actions/checkout@v4 + - name: Dependency Review + uses: actions/dependency-review-action@v4 diff --git a/fairseq/.github/workflows/release.yml b/fairseq/.github/workflows/release.yml new file mode 100644 index 0000000000000000000000000000000000000000..241b74b32d8072229c969ef55ec85067340fe140 --- /dev/null +++ b/fairseq/.github/workflows/release.yml @@ -0,0 +1,161 @@ +name: Fairseq Release + +on: + workflow_dispatch: + inputs: + name: + description: 'Release Type' + default: 'patch' + required: true + +jobs: + + get_next_version: + runs-on: ubuntu-latest + steps: + - name: checkout-repo-content + uses: actions/checkout@v2 + + - name: setup-python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: get next version and tag + id: get-next-version-and-tag + run: | + output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }}) + echo $output + new_version=$(echo $output | awk '{print $1}') + new_tag=$(echo $output | awk '{print $2}') + echo "new version is $new_version" + echo "new tag is $new_tag" + echo ::set-output name=version::$new_version + echo ::set-output name=tag::$new_tag + echo ::set-output name=branch_name::$new_version-release + echo "NEW_TAG=$new_tag" >> $GITHUB_ENV + echo "NEW_BRANCH=$new_version-release" >> $GITHUB_ENV + + + # update the version number in version.txt + - name: update version + id: update-version + run : | + 
echo "current folder = $PWD" + echo "current branch = $(git branch --show-current)" + output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }} --update-version) + + - name: add and commit + uses: EndBug/add-and-commit@v9 + with: + author_name: ${{ secrets.AUTHOR_NAME }} + author_email: ${{ secrets.AUTHOR_EMAIL }} + + # TODO: change this to main once shipit is disabled. + new_branch: '${{ env.NEW_BRANCH }}' + default_author: github_actor + message: '${{ env.NEW_TAG }} release' + pathspec_error_handling: exitAtEnd + + # Arguments for the git pull command. Use NO-PULL to avoid the action pulling at all. + # pull: 'NO-PULL' + tag: '${{ env.NEW_TAG }}' + + outputs: + new_version: ${{ steps.get-next-version-and-tag.outputs.version }} + new_tag: ${{ steps.get-next-version-and-tag.outputs.tag }} + branch_name: ${{ steps.get-next-version-and-tag.outputs.branch_name }} + + create_sdist: + runs-on: ubuntu-latest + name: Create Source Distribution + needs: get_next_version + steps: + - uses: actions/checkout@v3 + with: + ref: ${{ needs.get_next_version.outputs.branch_name }} + + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - name: Upgrade pip + run: | + python3 -m pip install --upgrade pip + + - name: Create Source Distribution + run: | + python3 -m pip install setuptools wheel twine torch + python3 setup.py sdist + + - uses: actions/upload-artifact@v2 + with: + path: dist/*.tar.gz + + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + needs: get_next_version + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v3 + with: + ref: ${{ needs.get_next_version.outputs.branch_name }} + + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - name: Upgrade pip + run: | + python3 -m pip install --upgrade pip + + - name: Install cibuildwheel + run: | + python3 -m pip install cibuildwheel + + - name: Build wheels for CPython + run: | + python3 -m cibuildwheel --output-dir dist + env: + CIBW_BUILD: "cp38-*64" + CIBW_MANYLINUX_X86_64_IMAGE: manylinux1 + CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install . 
+ # Install system library + CIBW_BEFORE_BUILD_LINUX: (yum install -y libffi-devel || apt-get install -y libffi-devel || apk add --update --no-cache libffi-devel || true) && (yum install -y libc6 || apt-get install -y libc6 || apk add --update --no-cache libc6 || true) + CIBW_ENVIRONMENT: "PIP_ONLY_BINARY=numpy" + CIBW_SKIP: "*musllinux*" + + - uses: actions/upload-artifact@v2 + with: + path: dist + + upload: + name: Upload to PyPi and create release + runs-on: ubuntu-latest + needs: [build_wheels, create_sdist, get_next_version] + steps: + - uses: actions/download-artifact@v2 + with: + name: artifact + path: dist + + # build the PyPI package and upload it + - name: upload + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + pip install setuptools wheel twine + python3 -m twine upload --repository pypi dist/* + + # create the release on github + - name: create release on github + uses: ncipollo/release-action@v1 + with: + tag: '${{ needs.get_next_version.outputs.new_tag }}' diff --git a/fairseq/.gitignore b/fairseq/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4be13638de637fe994e601f3de8e215109b30322 --- /dev/null +++ b/fairseq/.gitignore @@ -0,0 +1,141 @@ +# JetBrains PyCharm IDE +.idea/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# macOS dir files +.DS_Store + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Checkpoints +checkpoints + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# Generated files +/fairseq/temporal_convolution_tbc +/fairseq/modules/*_layer/*_forward.cu +/fairseq/modules/*_layer/*_backward.cu +/fairseq/version.py + +# data +data-bin/ + +# reranking +/examples/reranking/rerank_data + +# Cython-generated C++ source files +/fairseq/data/data_utils_fast.cpp +/fairseq/data/token_block_utils_fast.cpp + +# VSCODE +.vscode/ftp-sync.json +.vscode/settings.json + +# Experimental Folder +experimental/* + +# Weights and Biases logs +wandb/ + +# Hydra artifacts +nohup.out +multirun +outputs diff --git a/fairseq/.gitmodules b/fairseq/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..07a55d45d4f0bed755dbfc1f440f214ed43d206a --- /dev/null +++ b/fairseq/.gitmodules @@ -0,0 +1,4 @@ +[submodule "fairseq/model_parallel/megatron"] + path = fairseq/model_parallel/megatron + url = https://github.com/ngoyal2707/Megatron-LM + branch = fairseq diff --git a/fairseq/.pre-commit-config.yaml b/fairseq/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b1d6aed8c7a4eb2d0c9b4eba7104ed9dc41b591 --- /dev/null +++ b/fairseq/.pre-commit-config.yaml @@ -0,0 +1,40 @@ +exclude: 'build|stubs' + +default_language_version: + python: python3 + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: trailing-whitespace + - id: check-ast + - id: check-merge-conflict + - id: no-commit-to-branch + args: ['--branch=master'] + - id: check-added-large-files + args: ['--maxkb=500'] + - id: end-of-file-fixer + +- repo: https://github.com/ambv/black + rev: 22.3.0 + hooks: + - id: black + language_version: python3.8 + +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.9.2 + hooks: + - id: flake8 + args: [ + # only error for syntax errors and undefined names + "--select=E9,F63,F7,F82", + ] + +- repo: https://github.com/pycqa/isort + rev: 5.10.1 + hooks: + - id: isort + exclude: README.md + additional_dependencies: [toml] + args: ["--profile", "black"] diff --git a/fairseq/CODE_OF_CONDUCT.md b/fairseq/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..a0cbeaab7650bf08267fbdbc9bb54e845c88f392 --- /dev/null +++ b/fairseq/CODE_OF_CONDUCT.md @@ -0,0 +1,77 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + diff --git a/fairseq/CONTRIBUTING.md b/fairseq/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..60e90258877423bb458fafbc9d35781484dbe9c6 --- /dev/null +++ b/fairseq/CONTRIBUTING.md @@ -0,0 +1,82 @@ +# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq) +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. 
Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +## License +By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq), +you agree that your contributions will be licensed under the LICENSE file in +the root directory of this source tree. + +## Pre-commit hooks +In order to ensure your code lints, there are pre-commit hooks configured in the repository which you can install. +After installation, they will automatically run each time you commit. +An abbreviated guide is given below; for more information, refer to [the offical pre-commit documentation](https://pre-commit.com/). + +### Installation +``` +pip install pre-commit +pre-commit install +``` + +### Usage +Just commit your changes: +``` +git commit -m "My informative commit message" +``` + +If there was a failure, you will get feedback +``` +[INFO] Initializing environment for https://github.com/PyCQA/flake8. +[INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks. +[INFO] Once installed this environment will be reused. +[INFO] This may take a few minutes... +[INFO] Installing environment for https://github.com/PyCQA/flake8. +[INFO] Once installed this environment will be reused. +[INFO] This may take a few minutes... +Trim Trailing Whitespace.................................................Failed +- hook id: trailing-whitespace +- exit code: 1 +- files were modified by this hook +Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh +Fix End of Files.........................................................Failed +- hook id: end-of-file-fixer +- exit code: 1 +- files were modified by this hook +Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py +flake8...................................................................Passed +``` + +Certain hooks modify your files to comply. +To include these modifications, you will need to add them (i.e. `git add ...`) and commit again. + +If all is well, you should see something like: +``` +Trim Trailing Whitespace.................................................Passed +Fix End of Files.........................................................Passed +flake8...................................................................Passed +[gshard-fix-ci 8698644e1] Fix lint, add pre-commit hooks + 10 files changed, 148 insertions(+), 110 deletions(-) + create mode 100644 .flake8 + create mode 100644 .pre-commit-config.yaml + rename examples/nllb/modeling/wmt15_benchmark/{eval_langs2.py => eval_langs2.sh} (99%) + ``` diff --git a/fairseq/LICENSE b/fairseq/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b96dcb0480a0b0be0727976e5202a1e7b23edc3f --- /dev/null +++ b/fairseq/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/fairseq/MANIFEST.in b/fairseq/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..4f719da85c737a216ad4f373837c0e96c647ec36 --- /dev/null +++ b/fairseq/MANIFEST.in @@ -0,0 +1 @@ +include fairseq/version.txt diff --git a/fairseq/README.md b/fairseq/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1150c66cbeaa92aaa3a062b91369914a8a191306 --- /dev/null +++ b/fairseq/README.md @@ -0,0 +1,242 @@ +

+ +
+
+ Support Ukraine + MIT License + Latest Release + Build Status + Documentation Status + CircleCI Status +

+ +-------------------------------------------------------------------------------- + +Fairseq(-py) is a sequence modeling toolkit that allows researchers and +developers to train custom models for translation, summarization, language +modeling and other text generation tasks. + +We provide reference implementations of various sequence modeling papers: + +
List of implemented papers

+ +* **Convolutional Neural Networks (CNN)** + + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) + + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) + + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) + + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) + + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* **LightConv and DynamicConv models** + + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* **Long Short-Term Memory (LSTM) networks** + + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) +* **Transformer (self-attention) networks** + + Attention Is All You Need (Vaswani et al., 2017) + + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) + + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) + + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) + + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) + + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) + + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) + + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) + + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430) + + [Robust wav2vec 2.0: Analyzing Domain Shift in 
Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027) + + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084) + + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680) + + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf) + + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf) + + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md) +* **Non-autoregressive Transformers** + + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018) + + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019) + + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* **Finetuning** + + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md) + +

+ +### What's New: +* May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages (Pratap, et al., 2023)](examples/mms/README.md) +* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu, et al., 2022)](examples/wav2vec/unsupervised/README.md) +* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers) +* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md) +* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md) +* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md) +* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming). +* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md) +* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) +* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) +* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md) +* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) +* February 2021 [Added LASER training code](examples/laser/README.md) +* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) +* December 2020: [GottBERT model and code released](examples/gottbert/README.md) +* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework + * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) +* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) +* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) +* October 2020: [Added CRISS models and code](examples/criss/README.md) + +
Previous updates

+ +* September 2020: [Added Linformer code](examples/linformer/README.md) +* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) +* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) +* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) +* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +* February 2020: [mBART model and code released](examples/mbart/README.md) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german) +* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +* November 2019: [CamemBERT model and code released](examples/camembert/README.md) +* November 2019: [BART model and code released](examples/bart/README.md) +* November 2019: [XLM-R models and code released](examples/xlmr/README.md) +* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +* August 2019: [WMT'19 models released](examples/wmt19/README.md) +* July 2019: fairseq relicensed under MIT license +* July 2019: [RoBERTa models and code released](examples/roberta/README.md) +* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) + +

+ +### Features: + +* multi-GPU training on one machine or across multiple machines (data and model parallel) +* fast generation on both CPU and GPU with multiple search algorithms implemented: + + beam search + + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + + sampling (unconstrained, top-k and top-p/nucleus) + + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) +* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU +* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers +* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration +* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md) +* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md) + +We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) +with a convenient `torch.hub` interface: + +``` python +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') +en2de.translate('Hello world', beam=5) +# 'Hallo Welt' +``` + +See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) +and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. + +# Requirements and Installation + +* [PyTorch](http://pytorch.org/) version >= 1.10.0 +* Python version >= 3.8 +* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) +* **To install fairseq** and develop locally: + +``` bash +git clone https://github.com/pytorch/fairseq +cd fairseq +pip install --editable ./ + +# on MacOS: +# CFLAGS="-stdlib=libc++" pip install --editable ./ + +# to install the latest stable release (0.10.x) +# pip install fairseq +``` + +* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: + +``` bash +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ + --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ + --global-option="--fast_multihead_attn" ./ +``` + +* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` +* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` + as command line options to `nvidia-docker run` . + +# Getting Started + +The [full documentation](https://fairseq.readthedocs.io/) contains instructions +for getting started, training new models and extending fairseq with new model +types and tasks. + +# Pre-trained models and examples + +We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, +as well as example training and evaluation commands. 
+ +* [Translation](examples/translation/README.md): convolutional and transformer models are available +* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available + +We also have more detailed READMEs to reproduce results from specific papers: + +* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md) +* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) +* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) +* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) +* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) +* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) +* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) +* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) +* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) +* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) +* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) +* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) +* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) +* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) +* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) +* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) +* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) +* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) + +# Join the fairseq community + +* Twitter: https://twitter.com/fairseq +* Facebook page: https://www.facebook.com/groups/fairseq.users +* Google group: https://groups.google.com/forum/#!forum/fairseq-users + +# License + +fairseq(-py) is MIT-licensed. +The license applies to the pre-trained models as well. 
+ +# Citation + +Please cite as: + +``` bibtex +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` diff --git a/fairseq/RELEASE.md b/fairseq/RELEASE.md new file mode 100644 index 0000000000000000000000000000000000000000..79480a11c58a91178e9fa189a8a8e0ceea13efc6 --- /dev/null +++ b/fairseq/RELEASE.md @@ -0,0 +1,13 @@ +# Creating a New Release + +In order to create a new release: + +1. Navigate to the [Fairseq Workflows](https://github.com/facebookresearch/fairseq/actions) and find the one named _Fairseq Release_. + +2. Under _Run Workflow_ choose the branch `main` and for _Release Type_ enter either `major`, `minor`, or `patch`. + +3. A branch named `$new_version-release` will be created where the `version.txt` file is updated. Merge those changes into `main`. + +4. Make sure that a [new PYPI package](https://pypi.org/project/fairseq/) has been uploaded. + +5. Make sure that a [new github release](https://github.com/facebookresearch/fairseq/releases) has been created. diff --git a/fairseq/docs/Makefile b/fairseq/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..c2f5b1a89cfc9e02d1bb09027d9e1e520ba53d53 --- /dev/null +++ b/fairseq/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = python -msphinx +SPHINXPROJ = fairseq +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/fairseq/docs/command_line_tools.rst b/fairseq/docs/command_line_tools.rst new file mode 100644 index 0000000000000000000000000000000000000000..c16300ff5cd42d9a6c0070c2d9bec3a802eacfad --- /dev/null +++ b/fairseq/docs/command_line_tools.rst @@ -0,0 +1,85 @@ +.. _Command-line Tools: + +Command-line Tools +================== + +Fairseq provides several command-line tools for training and evaluating models: + +- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data +- :ref:`fairseq-train`: Train a new model on one or multiple GPUs +- :ref:`fairseq-generate`: Translate pre-processed data with a trained model +- :ref:`fairseq-interactive`: Translate raw text with a trained model +- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations +- :ref:`fairseq-eval-lm`: Language model evaluation + + +.. _fairseq-preprocess: + +fairseq-preprocess +~~~~~~~~~~~~~~~~~~ +.. automodule:: fairseq_cli.preprocess + + .. argparse:: + :module: fairseq.options + :func: get_preprocessing_parser + :prog: fairseq-preprocess + + +.. _fairseq-train: + +fairseq-train +~~~~~~~~~~~~~ +.. automodule:: fairseq_cli.train + + .. argparse:: + :module: fairseq.options + :func: get_training_parser + :prog: fairseq-train + + +.. _fairseq-generate: + +fairseq-generate +~~~~~~~~~~~~~~~~ +.. automodule:: fairseq_cli.generate + + .. 
argparse:: + :module: fairseq.options + :func: get_generation_parser + :prog: fairseq-generate + + +.. _fairseq-interactive: + +fairseq-interactive +~~~~~~~~~~~~~~~~~~~ +.. automodule:: fairseq_cli.interactive + + .. argparse:: + :module: fairseq.options + :func: get_interactive_generation_parser + :prog: fairseq-interactive + + +.. _fairseq-score: + +fairseq-score +~~~~~~~~~~~~~ +.. automodule:: fairseq_cli.score + + .. argparse:: + :module: fairseq_cli.score + :func: get_parser + :prog: fairseq-score + + +.. _fairseq-eval-lm: + +fairseq-eval-lm +~~~~~~~~~~~~~~~ +.. automodule:: fairseq_cli.eval_lm + + .. argparse:: + :module: fairseq.options + :func: get_eval_lm_parser + :prog: fairseq-eval-lm diff --git a/fairseq/docs/conf.py b/fairseq/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc049f8028bb3a4d28059f03e59cd368c5b95d4 --- /dev/null +++ b/fairseq/docs/conf.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# fairseq documentation build configuration file, created by +# sphinx-quickstart on Fri Aug 17 21:45:30 2018. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import os +import sys +from fairseq import __version__ + + +# source code directory, relative to this file, for sphinx-autobuild +sys.path.insert(0, os.path.abspath("..")) + +source_suffix = [".rst"] + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinxarg.ext", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "fairseq" +copyright = "Facebook AI Research (FAIR)" +author = "Facebook AI Research (FAIR)" + +github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = __version__ +# The full version, including alpha/beta/rc tags. +release = __version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. 
+# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" +highlight_language = "python" + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +html_theme = "classic" + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "python": ("https://docs.python.org/", None), + "torch": ("https://pytorch.org/docs/master/", None), +} diff --git a/fairseq/docs/criterions.rst b/fairseq/docs/criterions.rst new file mode 100644 index 0000000000000000000000000000000000000000..d6b8ca6b671a32d0da4aca7b18626e0df58a7258 --- /dev/null +++ b/fairseq/docs/criterions.rst @@ -0,0 +1,31 @@ +.. role:: hidden + :class: hidden-section + +.. _Criterions: + +Criterions +========== + +Criterions compute the loss function given the model and batch, roughly:: + + loss = criterion(model, batch) + +.. automodule:: fairseq.criterions + :members: + +.. autoclass:: fairseq.criterions.FairseqCriterion + :members: + :undoc-members: + +.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss + :members: + :undoc-members: +.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss + :members: + :undoc-members: +.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion + :members: + :undoc-members: +.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion + :members: + :undoc-members: diff --git a/fairseq/docs/data.rst b/fairseq/docs/data.rst new file mode 100644 index 0000000000000000000000000000000000000000..6a390cb336ab3c5fb28edec7448abc35a8e22bbb --- /dev/null +++ b/fairseq/docs/data.rst @@ -0,0 +1,58 @@ +.. role:: hidden + :class: hidden-section + +.. module:: fairseq.data + +Data Loading and Utilities +========================== + +.. _datasets: + +Datasets +-------- + +**Datasets** define the data format and provide helpers for creating +mini-batches. + +.. autoclass:: fairseq.data.FairseqDataset + :members: +.. autoclass:: fairseq.data.LanguagePairDataset + :members: +.. autoclass:: fairseq.data.MonolingualDataset + :members: + +**Helper Datasets** + +These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and +provide additional functionality: + +.. autoclass:: fairseq.data.BacktranslationDataset + :members: +.. autoclass:: fairseq.data.ConcatDataset + :members: +.. autoclass:: fairseq.data.ResamplingDataset + :members: +.. autoclass:: fairseq.data.RoundRobinZipDatasets + :members: +.. autoclass:: fairseq.data.TransformEosDataset + :members: + + +Dictionary +---------- + +.. autoclass:: fairseq.data.Dictionary + :members: + + +Iterators +--------- + +.. autoclass:: fairseq.data.CountingIterator + :members: +.. autoclass:: fairseq.data.EpochBatchIterator + :members: +.. autoclass:: fairseq.data.GroupedIterator + :members: +.. 
autoclass:: fairseq.data.ShardedIterator + :members: diff --git a/fairseq/docs/docutils.conf b/fairseq/docs/docutils.conf new file mode 100644 index 0000000000000000000000000000000000000000..526acffd32d16217160aee917db2b120354f20f0 --- /dev/null +++ b/fairseq/docs/docutils.conf @@ -0,0 +1,2 @@ +[writers] +option-limit=0 diff --git a/fairseq/docs/fairseq_logo.png b/fairseq/docs/fairseq_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..75472cbb5ff78acc8716ad9121ed421f17f96c9a Binary files /dev/null and b/fairseq/docs/fairseq_logo.png differ diff --git a/fairseq/docs/getting_started.rst b/fairseq/docs/getting_started.rst new file mode 100644 index 0000000000000000000000000000000000000000..745ad7763cee67a8dec25bdd7ba7b79cbe0b7754 --- /dev/null +++ b/fairseq/docs/getting_started.rst @@ -0,0 +1,216 @@ +Evaluating Pre-trained Models +============================= + +First, download a pre-trained model along with its vocabularies: + +.. code-block:: console + + > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - + +This model uses a `Byte Pair Encoding (BPE) +vocabulary `__, so we'll have to apply +the encoding to the source text before it can be translated. This can be +done with the +`apply\_bpe.py `__ +script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is +used as a continuation marker and the original text can be easily +recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe`` +flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized +using ``tokenizer.perl`` from +`mosesdecoder `__. + +Let's use :ref:`fairseq-interactive` to generate translations interactively. +Here, we use a beam size of 5 and preprocess the input with the Moses +tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically +remove the BPE continuation markers and detokenize the output. + +.. code-block:: console + + > MODEL_DIR=wmt14.en-fr.fconv-py + > fairseq-interactive \ + --path $MODEL_DIR/model.pt $MODEL_DIR \ + --beam 5 --source-lang en --target-lang fr \ + --tokenizer moses \ + --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes + | loading model(s) from wmt14.en-fr.fconv-py/model.pt + | [en] dictionary: 44206 types + | [fr] dictionary: 44463 types + | Type the input sentence and press return: + Why is it rare to discover new marine mammal species? + S-0 Why is it rare to discover new marine mam@@ mal species ? + H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins? + P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015 + +This generation script produces three types of outputs: a line prefixed +with *O* is a copy of the original source sentence; *H* is the +hypothesis along with an average log-likelihood; and *P* is the +positional score per token position, including the +end-of-sentence marker which is omitted from the text. + +Other types of output lines you might see are *D*, the detokenized hypothesis, +*T*, the reference target, *A*, alignment info, *E* the history of generation steps. + +See the `README `__ for a +full list of pre-trained models available. + +Training a New Model +==================== + +The following tutorial is for machine translation. For an example of how +to use Fairseq for other tasks, such as :ref:`language modeling`, please see the +``examples/`` directory. 
+ +Data Pre-processing +------------------- + +Fairseq contains example pre-processing scripts for several translation +datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT +2014 (English-German). To pre-process and binarize the IWSLT dataset: + +.. code-block:: console + + > cd examples/translation/ + > bash prepare-iwslt14.sh + > cd ../.. + > TEXT=examples/translation/iwslt14.tokenized.de-en + > fairseq-preprocess --source-lang de --target-lang en \ + --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ + --destdir data-bin/iwslt14.tokenized.de-en + +This will write binarized data that can be used for model training to +``data-bin/iwslt14.tokenized.de-en``. + +Training +-------- + +Use :ref:`fairseq-train` to train a new model. Here a few example settings that work +well for the IWSLT 2014 dataset: + +.. code-block:: console + + > mkdir -p checkpoints/fconv + > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \ + --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ + --arch fconv_iwslt_de_en --save-dir checkpoints/fconv + +By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the +``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to +change the number of GPU devices that will be used. + +Also note that the batch size is specified in terms of the maximum +number of tokens per batch (``--max-tokens``). You may need to use a +smaller value depending on the available GPU memory on your system. + +Generation +---------- + +Once your model is trained, you can generate translations using +:ref:`fairseq-generate` **(for binarized data)** or +:ref:`fairseq-interactive` **(for raw text)**: + +.. code-block:: console + + > fairseq-generate data-bin/iwslt14.tokenized.de-en \ + --path checkpoints/fconv/checkpoint_best.pt \ + --batch-size 128 --beam 5 + | [de] dictionary: 35475 types + | [en] dictionary: 24739 types + | data-bin/iwslt14.tokenized.de-en test 6750 examples + | model fconv + | loaded checkpoint trainings/fconv/checkpoint_best.pt + S-721 danke . + T-721 thank you . + ... + +To generate translations with only a CPU, use the ``--cpu`` flag. BPE +continuation markers can be removed with the ``--remove-bpe`` flag. + +Advanced Training Options +========================= + +Large mini-batch training with delayed updates +---------------------------------------------- + +The ``--update-freq`` option can be used to accumulate gradients from +multiple mini-batches and delay updating, creating a larger effective +batch size. Delayed updates can also improve training speed by reducing +inter-GPU communication costs and by saving idle time caused by variance +in workload across GPUs. See `Ott et al. +(2018) `__ for more details. + +To train on a single GPU with an effective batch size that is equivalent +to training on 8 GPUs: + +.. code-block:: console + + > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...) + +Training with half precision floating point (FP16) +-------------------------------------------------- + +.. note:: + + FP16 training requires a Volta GPU and CUDA 9.1 or greater + +Recent GPUs enable efficient half precision floating point computation, +e.g., using `Nvidia Tensor Cores +`__. +Fairseq supports FP16 training with the ``--fp16`` flag: + +.. code-block:: console + + > fairseq-train --fp16 (...) + +Distributed training +-------------------- + +Distributed training in fairseq is implemented on top of ``torch.distributed``. 
+The easiest way to launch jobs is with the `torch.distributed.launch +`__ tool. + +For example, to train a large English-German Transformer model on 2 nodes each +with 8 GPUs (in total 16 GPUs), run the following command on each node, +replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making +sure to update ``--master_addr`` to the IP address of the first node: + +.. code-block:: console + + > python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \ + --master_port=12345 \ + $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \ + --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ + --lr 0.0005 \ + --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --max-tokens 3584 \ + --max-epoch 70 \ + --fp16 + +On SLURM clusters, fairseq will automatically detect the number of nodes and +GPUs, but a port number must be provided: + +.. code-block:: console + + > salloc --gpus=16 --nodes 2 (...) + > srun fairseq-train --distributed-port 12345 (...). + +Sharding very large datasets +---------------------------- + +It can be challenging to train over very large datasets, particularly if your +machine does not have much system RAM. Most tasks in fairseq support training +over "sharded" datasets, in which the original dataset has been preprocessed +into non-overlapping chunks (or "shards"). + +For example, instead of preprocessing all your data into a single "data-bin" +directory, you can split the data and create "data-bin1", "data-bin2", etc. +Then you can adapt your training command like so: + +.. code-block:: console + + > fairseq-train data-bin1:data-bin2:data-bin3 (...) + +Training will now iterate over each shard, one by one, with each shard +corresponding to an "epoch", thus reducing system memory usage. diff --git a/fairseq/docs/hydra_integration.md b/fairseq/docs/hydra_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..6a15298382a6a16dfc4c5a4a812ea1cd0477ed52 --- /dev/null +++ b/fairseq/docs/hydra_integration.md @@ -0,0 +1,284 @@ +## Hydra + +[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python +framework that simplifies the development of research and other complex +applications. The key feature is the ability to dynamically create a +hierarchical configuration by composition and override it through config files +and the command line. The name Hydra comes from its ability to run multiple +similar jobs - much like a Hydra with multiple heads. + +## Motivation + +Until recently, all components in fairseq were configured through a shared +`args` namespace that was created at application startup. Components declared +their own `add_args` method to update the argparse parser, hoping that the names +would not clash with arguments from other components. While this model works for +smaller applications, as fairseq grew and became integrated into other +applications, this became problematic. In order to determine how to configure +each component, one needed to a) examine what args were added by this component, +and b) read the code to figure out what shared arguments it is using that were +added in other places. Reproducing models involved sharing commands that often +contained dozens of command line switches. 
+
+The model described above is still supported by fairseq for backward
+compatibility, but will be deprecated some time in the future.
+
+New components in fairseq should now create a dataclass that encapsulates all
+parameters required to configure this component. The dataclass is registered
+along with the component, and fairseq takes care of constructing and providing
+this configuration object to the component's constructor. Note that sharing
+parameters can optionally still work, but one has to explicitly point to the
+"source of truth" (see inheritance example below). These changes make components
+in fairseq more independent and re-usable by other applications: all that is
+needed to create a component is to initialize its dataclass and overwrite some
+of the defaults.
+
+While configuring fairseq through the command line (using either the legacy
+argparse based or the new Hydra based entry points) is still fully supported, you
+can now take advantage of configuring fairseq completely or piece-by-piece through
+hierarchical YAML configuration files. These files can also be shipped as
+examples that others can use to run an identically configured job.
+
+Additionally, Hydra has a rich and growing [library of
+plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
+provide functionality such as hyperparameter sweeping (including using Bayesian
+optimization through the [Ax](https://github.com/facebook/Ax) library), job
+launching across various platforms, and more.
+
+## Creating or migrating components
+
+In general, each new (or updated) component should provide a companion
+[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are
+typically located in the same file as the component and are passed as arguments
+to the `register_*()` functions. Top-level configs that should be present in
+every fairseq application are placed in the
+[global](fairseq/dataclass/configs.py) config file and added to the
+`FairseqConfig` object.
+
+Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
+classes are decorated with a `@dataclass` decorator, and typically inherit from
+`FairseqDataclass` (which adds some functionality for backward compatibility).
+Each field must have a type, and generally has metadata (such as a help string)
+and a default value. Only primitive types or other config objects are allowed as
+data types for each field.
+
+#### Example:
+
+```python
+from dataclasses import dataclass, field
+from fairseq.dataclass import FairseqDataclass
+
+@dataclass
+class InteractiveConfig(FairseqDataclass):
+    buffer_size: int = field(
+        default=0,
+        metadata={
+            "help": "read this many sentences into a buffer before processing them"
+        },
+    )
+    input: str = field(
+        default="-",
+        metadata={"help": "file to read from; use - for stdin"},
+    )
+```
+
+### Inheriting values
+
+Some components require sharing a value. For example, a learning rate scheduler
+and an optimizer may both need to know the initial learning rate value. One can
+declare a field that, by default, will inherit its value from another config
+node in the same hierarchy:
+
+```python
+@dataclass
+class FairseqAdamConfig(FairseqDataclass):
+    ...
+    lr: List[float] = II("optimization.lr")
+    ...
+```
+
+`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
+the value one can use in a YAML config file or on the command line to achieve
+the same effect.
Note that this assumes that there is an "optimization" config +object in the root config and it has a field called "lr". + +### Tasks and Models + +Creating Tasks and Models works same as before, except that legacy +implementations now inherit from `LegacyFairseq*` base classes, while new +components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass +to the `register_*()` functions. + +#### Task example: + +```python +@dataclass +class LanguageModelingConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + ... + +@register_task("language_modeling", dataclass=LanguageModelingConfig) +class LanguageModelingTask(FairseqTask): + ... + @classmethod + def setup_task(cls, cfg: LanguageModelingConfig): + ... +``` + +#### Model example: + +```python +@dataclass +class TransformerLanguageModelConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", metadata={"help": "activation function to use"} + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + ... + +@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) +class TransformerLanguageModel(FairseqLanguageModel): + ... + @classmethod + def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask): + ... +``` + +### Other components + +Other components work as before, but they now take their configuration dataclass +as the only constructor argument: + +```python +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + ... + +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) +class MosesTokenizer(object): + def __init__(self, cfg: MosesTokenizerConfig): + ... +``` + +Note that if you are adding a new registry for a new set of components, you need +to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`: + +```python +@dataclass +class FairseqConfig(object): + ... + my_new_registry: Any = None +``` + +## Training with `fairseq-hydra-train` + +To fully take advantage of configuration flexibility offered by Hydra, you may +want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI +tools such as `fairseq-train` will remain supported for the foreseeable future +but will be deprecated eventually. + +On startup, Hydra will create a configuration object that contains a hierarchy +of all the necessary dataclasses populated with their default values in the +code. The default values are overwritten by values found in YAML files in +`fairseq/config` directory (which currently sets minimal defaults) and then +further overwritten by values provided through command line arguments. + +Some of the most common use cases are shown below: + +### 1. Override default values through command line: + +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_world_size=1 \ + dataset.batch_size=2 \ + task.data=data-bin \ + model=transformer_lm/transformer_lm_gpt \ + task=language_modeling \ + optimization.max_update=5000 +``` + +Note that along with explicitly providing values for parameters such as +`dataset.batch_size`, this also tells Hydra to overlay configuration found in +`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default +values in the dataclass. If you want to train a model without specifying a +particular architecture you can simply specify `model=transformer_lm`. 
This only +works for migrated tasks and models. + +### 2. Replace bundled configs with an external config: + +```shell script +$ fairseq-hydra-train \ + --config-dir /path/to/external/configs \ + --config-name wiki103 +``` + +where `/path/to/external/configs/wiki103.yaml` contains: + +```yaml +# @package _group_ + +model: + _name: transformer_lm +distributed_training: + distributed_world_size: 1 +dataset: + batch_size: 2 +task: + _name: language_modeling + data: /path/to/data + add_bos_token: false + max_target_positions: 1024 +optimization: + max_update: 50000 + lr: [ 0.25 ] +criterion: cross_entropy +optimizer: adam +lr_scheduler: + _name: cosine +``` + +Note that here bundled configs from `fairseq/config` directory are not used, +however the defaults from each dataclass will still be used (unless overwritten +by your external config). + +Additionally you can choose to break up your configs by creating a directory +structure in the same location as your main config file, with the names of the +top-level fields (such as "model", "dataset", etc), and placing config files +with meaningful names that would populate that specific section of your +top-level config file (for example, you might have +`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You +can then specify the correct configuration via command line, defaults in the +main config, or even launch all of them as a sweep (see Hydra documentation on +how to do this). + +### 3. Add an external config directory to Hydra search path: + +This allows combining default configuration (including using any bundled config +files), while specifying your own config files for some parts of the +configuration. + +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_world_size=1 \ + dataset.batch_size=2 \ + task.data=/path/to/data/ \ + model=transformer_lm/2_layers \ + task=language_modeling \ + optimization.max_update=5000 \ + --config-dir /path/to/external/configs +``` + +where `/path/to/external/configs` has the following structure: +``` +. ++-- model +| +-- transformer_lm +| | +-- 2_layers.yaml +``` + +and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with +`decoder_layers` set to 2. You can add other configs to configure other +components as well. diff --git a/fairseq/docs/index.rst b/fairseq/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..591db86cdf49e6f0a7a6686df2150f11418e90d0 --- /dev/null +++ b/fairseq/docs/index.rst @@ -0,0 +1,49 @@ +.. fairseq documentation master file, created by + sphinx-quickstart on Fri Aug 17 21:45:30 2018. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +:github_url: https://github.com/pytorch/fairseq + + +fairseq documentation +===================== + +Fairseq is a sequence modeling toolkit written in `PyTorch +`_ that allows researchers and developers to +train custom models for translation, summarization, language modeling and other +text generation tasks. + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + + getting_started + command_line_tools + +.. toctree:: + :maxdepth: 1 + :caption: Extending Fairseq + + overview + tutorial_simple_lstm + tutorial_classifying_names + +.. 
toctree:: + :maxdepth: 2 + :caption: Library Reference + + tasks + models + criterions + optim + lr_scheduler + data + modules + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`search` diff --git a/fairseq/docs/lr_scheduler.rst b/fairseq/docs/lr_scheduler.rst new file mode 100644 index 0000000000000000000000000000000000000000..bbc09dc22e6a7ac05137954e0b9c80ca030f62f4 --- /dev/null +++ b/fairseq/docs/lr_scheduler.rst @@ -0,0 +1,34 @@ +.. role:: hidden + :class: hidden-section + +.. _Learning Rate Schedulers: + +Learning Rate Schedulers +======================== + +Learning Rate Schedulers update the learning rate over the course of training. +Learning rates can be updated after each update via :func:`step_update` or at +epoch boundaries via :func:`step`. + +.. automodule:: fairseq.optim.lr_scheduler + :members: + +.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler + :members: + :undoc-members: + +.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau + :members: + :undoc-members: +.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule + :members: + :undoc-members: diff --git a/fairseq/docs/make.bat b/fairseq/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..35c5085de318190514ee3b48d10060aa57a4fa50 --- /dev/null +++ b/fairseq/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=python -msphinx +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=fairseq + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The Sphinx module was not found. Make sure you have Sphinx installed, + echo.then set the SPHINXBUILD environment variable to point to the full + echo.path of the 'sphinx-build' executable. Alternatively you may add the + echo.Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/fairseq/docs/modules.rst b/fairseq/docs/modules.rst new file mode 100644 index 0000000000000000000000000000000000000000..9631c93d4682286e1cea1ddd961d3f6ab06f2589 --- /dev/null +++ b/fairseq/docs/modules.rst @@ -0,0 +1,9 @@ +Modules +======= + +Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may +be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`. + +.. automodule:: fairseq.modules + :members: + :undoc-members: diff --git a/fairseq/docs/optim.rst b/fairseq/docs/optim.rst new file mode 100644 index 0000000000000000000000000000000000000000..c3326456bd9291a1d05bd3316bef5c9fb25c6c49 --- /dev/null +++ b/fairseq/docs/optim.rst @@ -0,0 +1,38 @@ +.. role:: hidden + :class: hidden-section + +.. _optimizers: + +Optimizers +========== + +Optimizers update the Model parameters based on the gradients. + +.. automodule:: fairseq.optim + :members: + +.. autoclass:: fairseq.optim.FairseqOptimizer + :members: + :undoc-members: + +.. 
autoclass:: fairseq.optim.adadelta.Adadelta + :members: + :undoc-members: +.. autoclass:: fairseq.optim.adagrad.Adagrad + :members: + :undoc-members: +.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor + :members: + :undoc-members: +.. autoclass:: fairseq.optim.adam.FairseqAdam + :members: + :undoc-members: +.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer + :members: + :undoc-members: +.. autoclass:: fairseq.optim.nag.FairseqNAG + :members: + :undoc-members: +.. autoclass:: fairseq.optim.sgd.SGD + :members: + :undoc-members: diff --git a/fairseq/docs/overview.rst b/fairseq/docs/overview.rst new file mode 100644 index 0000000000000000000000000000000000000000..026b3b5c7b21d071d8b8a3405898977c760d05b8 --- /dev/null +++ b/fairseq/docs/overview.rst @@ -0,0 +1,74 @@ +Overview +======== + +Fairseq can be extended through user-supplied `plug-ins +`_. We support five kinds of +plug-ins: + +- :ref:`Models` define the neural network architecture and encapsulate all of the + learnable parameters. +- :ref:`Criterions` compute the loss function given the model outputs and targets. +- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over + Datasets, initializing the Model/Criterion and calculating the loss. +- :ref:`Optimizers` update the Model parameters based on the gradients. +- :ref:`Learning Rate Schedulers` update the learning rate over the course of + training. + +**Training Flow** + +Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``, +fairseq implements the following high-level training flow:: + + for epoch in range(num_epochs): + itr = task.get_batch_iterator(task.dataset('train')) + for num_updates, batch in enumerate(itr): + task.train_step(batch, model, criterion, optimizer) + average_and_clip_gradients() + optimizer.step() + lr_scheduler.step_update(num_updates) + lr_scheduler.step(epoch) + +where the default implementation for ``task.train_step`` is roughly:: + + def train_step(self, batch, model, criterion, optimizer, **unused): + loss = criterion(model, batch) + optimizer.backward(loss) + return loss + +**Registering new plug-ins** + +New plug-ins are *registered* through a set of ``@register`` function +decorators, for example:: + + @register_model('my_lstm') + class MyLSTM(FairseqEncoderDecoderModel): + (...) + +Once registered, new plug-ins can be used with the existing :ref:`Command-line +Tools`. See the Tutorial sections for more detailed walkthroughs of how to add +new plug-ins. + +**Loading plug-ins from another directory** + +New plug-ins can be defined in a custom module stored in the user system. In +order to import the module, and make the plugin available to *fairseq*, the +command line supports the ``--user-dir`` flag that can be used to specify a +custom location for additional modules to load into *fairseq*. + +For example, assuming this directory tree:: + + /home/user/my-module/ + └── __init__.py + +with ``__init__.py``:: + + from fairseq.models import register_model_architecture + from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big + + @register_model_architecture('transformer', 'my_transformer') + def transformer_mmt_big(args): + transformer_vaswani_wmt_en_de_big(args) + +it is possible to invoke the :ref:`fairseq-train` script with the new architecture with:: + + fairseq-train ... 
--user-dir /home/user/my-module -a my_transformer --task translation diff --git a/fairseq/docs/tasks.rst b/fairseq/docs/tasks.rst new file mode 100644 index 0000000000000000000000000000000000000000..5f65c3c866865e50332d8e6ca012a4a81e7bea74 --- /dev/null +++ b/fairseq/docs/tasks.rst @@ -0,0 +1,61 @@ +.. role:: hidden + :class: hidden-section + +.. module:: fairseq.tasks + +.. _Tasks: + +Tasks +===== + +Tasks store dictionaries and provide helpers for loading/iterating over +Datasets, initializing the Model/Criterion and calculating the loss. + +Tasks can be selected via the ``--task`` command-line argument. Once selected, a +task may expose additional command-line arguments for further configuration. + +Example usage:: + + # setup the task (e.g., load dictionaries) + task = fairseq.tasks.setup_task(args) + + # build model and criterion + model = task.build_model(args) + criterion = task.build_criterion(args) + + # load datasets + task.load_dataset('train') + task.load_dataset('valid') + + # iterate over mini-batches of data + batch_itr = task.get_batch_iterator( + task.dataset('train'), max_tokens=4096, + ) + for batch in batch_itr: + # compute the loss + loss, sample_size, logging_output = task.get_loss( + model, criterion, batch, + ) + loss.backward() + + +Translation +----------- + +.. autoclass:: fairseq.tasks.translation.TranslationTask + +.. _language modeling: + +Language Modeling +----------------- + +.. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask + + +Adding new tasks +---------------- + +.. autofunction:: fairseq.tasks.register_task +.. autoclass:: fairseq.tasks.FairseqTask + :members: + :undoc-members: diff --git a/fairseq/docs/tutorial_simple_lstm.rst b/fairseq/docs/tutorial_simple_lstm.rst new file mode 100644 index 0000000000000000000000000000000000000000..f52988507c5da5125668e143bd2bfe4df117b41c --- /dev/null +++ b/fairseq/docs/tutorial_simple_lstm.rst @@ -0,0 +1,518 @@ +Tutorial: Simple LSTM +===================== + +In this tutorial we will extend fairseq by adding a new +:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source +sentence with an LSTM and then passes the final hidden state to a second LSTM +that decodes the target sentence (without attention). + +This tutorial covers: + +1. **Writing an Encoder and Decoder** to encode/decode the source/target + sentence, respectively. +2. **Registering a new Model** so that it can be used with the existing + :ref:`Command-line tools`. +3. **Training the Model** using the existing command-line tools. +4. **Making generation faster** by modifying the Decoder to use + :ref:`Incremental decoding`. + + +1. Building an Encoder and Decoder +---------------------------------- + +In this section we'll define a simple LSTM Encoder and Decoder. All Encoders +should implement the :class:`~fairseq.models.FairseqEncoder` interface and +Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface. +These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders +and FairseqDecoders can be written and used in the same ways as ordinary PyTorch +Modules. + + +Encoder +~~~~~~~ + +Our Encoder will embed the tokens in the source sentence, feed them to a +:class:`torch.nn.LSTM` and return the final hidden state. 
To create our encoder +save the following in a new file named :file:`fairseq/models/simple_lstm.py`:: + + import torch.nn as nn + from fairseq import utils + from fairseq.models import FairseqEncoder + + class SimpleLSTMEncoder(FairseqEncoder): + + def __init__( + self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1, + ): + super().__init__(dictionary) + self.args = args + + # Our encoder will embed the inputs before feeding them to the LSTM. + self.embed_tokens = nn.Embedding( + num_embeddings=len(dictionary), + embedding_dim=embed_dim, + padding_idx=dictionary.pad(), + ) + self.dropout = nn.Dropout(p=dropout) + + # We'll use a single-layer, unidirectional LSTM for simplicity. + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=hidden_dim, + num_layers=1, + bidirectional=False, + batch_first=True, + ) + + def forward(self, src_tokens, src_lengths): + # The inputs to the ``forward()`` function are determined by the + # Task, and in particular the ``'net_input'`` key in each + # mini-batch. We discuss Tasks in the next tutorial, but for now just + # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths* + # has shape `(batch)`. + + # Note that the source is typically padded on the left. This can be + # configured by adding the `--left-pad-source "False"` command-line + # argument, but here we'll make the Encoder handle either kind of + # padding by converting everything to be right-padded. + if self.args.left_pad_source: + # Convert left-padding to right-padding. + src_tokens = utils.convert_padding_direction( + src_tokens, + padding_idx=self.dictionary.pad(), + left_to_right=True + ) + + # Embed the source. + x = self.embed_tokens(src_tokens) + + # Apply dropout. + x = self.dropout(x) + + # Pack the sequence into a PackedSequence object to feed to the LSTM. + x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True) + + # Get the output from the LSTM. + _outputs, (final_hidden, _final_cell) = self.lstm(x) + + # Return the Encoder's output. This can be any object and will be + # passed directly to the Decoder. + return { + # this will have shape `(bsz, hidden_dim)` + 'final_hidden': final_hidden.squeeze(0), + } + + # Encoders are required to implement this method so that we can rearrange + # the order of the batch elements during inference (e.g., beam search). + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to `new_order`. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + `encoder_out` rearranged according to `new_order` + """ + final_hidden = encoder_out['final_hidden'] + return { + 'final_hidden': final_hidden.index_select(0, new_order), + } + + +Decoder +~~~~~~~ + +Our Decoder will predict the next word, conditioned on the Encoder's final +hidden state and an embedded representation of the previous target word -- which +is sometimes called *teacher forcing*. More specifically, we'll use a +:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project +to the size of the output vocabulary to predict each target word. + +:: + + import torch + from fairseq.models import FairseqDecoder + + class SimpleLSTMDecoder(FairseqDecoder): + + def __init__( + self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128, + dropout=0.1, + ): + super().__init__(dictionary) + + # Our decoder will embed the inputs before feeding them to the LSTM. 
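+            # (``padding_idx`` tells ``nn.Embedding`` to initialize the pad
+            # symbol's embedding to zeros and leave it unchanged during training.)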
+ self.embed_tokens = nn.Embedding( + num_embeddings=len(dictionary), + embedding_dim=embed_dim, + padding_idx=dictionary.pad(), + ) + self.dropout = nn.Dropout(p=dropout) + + # We'll use a single-layer, unidirectional LSTM for simplicity. + self.lstm = nn.LSTM( + # For the first layer we'll concatenate the Encoder's final hidden + # state with the embedded target tokens. + input_size=encoder_hidden_dim + embed_dim, + hidden_size=hidden_dim, + num_layers=1, + bidirectional=False, + ) + + # Define the output projection. + self.output_projection = nn.Linear(hidden_dim, len(dictionary)) + + # During training Decoders are expected to take the entire target sequence + # (shifted right by one position) and produce logits over the vocabulary. + # The *prev_output_tokens* tensor begins with the end-of-sentence symbol, + # ``dictionary.eos()``, followed by the target sequence. + def forward(self, prev_output_tokens, encoder_out): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + + Returns: + tuple: + - the last decoder layer's output of shape + `(batch, tgt_len, vocab)` + - the last decoder layer's attention weights of shape + `(batch, tgt_len, src_len)` + """ + bsz, tgt_len = prev_output_tokens.size() + + # Extract the final hidden state from the Encoder. + final_encoder_hidden = encoder_out['final_hidden'] + + # Embed the target sequence, which has been shifted right by one + # position and now starts with the end-of-sentence symbol. + x = self.embed_tokens(prev_output_tokens) + + # Apply dropout. + x = self.dropout(x) + + # Concatenate the Encoder's final hidden state to *every* embedded + # target token. + x = torch.cat( + [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)], + dim=2, + ) + + # Using PackedSequence objects in the Decoder is harder than in the + # Encoder, since the targets are not sorted in descending length order, + # which is a requirement of ``pack_padded_sequence()``. Instead we'll + # feed nn.LSTM directly. + initial_state = ( + final_encoder_hidden.unsqueeze(0), # hidden + torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell + ) + output, _ = self.lstm( + x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)` + initial_state, + ) + x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)` + + # Project the outputs to the size of the vocabulary. + x = self.output_projection(x) + + # Return the logits and ``None`` for the attention weights + return x, None + + +2. Registering the Model +------------------------ + +Now that we've defined our Encoder and Decoder we must *register* our model with +fairseq using the :func:`~fairseq.models.register_model` function decorator. +Once the model is registered we'll be able to use it with the existing +:ref:`Command-line Tools`. + +All registered models must implement the +:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence +models (i.e., any model with a single Encoder and Decoder), we can instead +implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface. + +Create a small wrapper class in the same file and register it in fairseq with +the name ``'simple_lstm'``:: + + from fairseq.models import FairseqEncoderDecoderModel, register_model + + # Note: the register_model "decorator" should immediately precede the + # definition of the Model class. 
+ + @register_model('simple_lstm') + class SimpleLSTMModel(FairseqEncoderDecoderModel): + + @staticmethod + def add_args(parser): + # Models can override this method to add new command-line arguments. + # Here we'll add some new command-line arguments to configure dropout + # and the dimensionality of the embeddings and hidden states. + parser.add_argument( + '--encoder-embed-dim', type=int, metavar='N', + help='dimensionality of the encoder embeddings', + ) + parser.add_argument( + '--encoder-hidden-dim', type=int, metavar='N', + help='dimensionality of the encoder hidden state', + ) + parser.add_argument( + '--encoder-dropout', type=float, default=0.1, + help='encoder dropout probability', + ) + parser.add_argument( + '--decoder-embed-dim', type=int, metavar='N', + help='dimensionality of the decoder embeddings', + ) + parser.add_argument( + '--decoder-hidden-dim', type=int, metavar='N', + help='dimensionality of the decoder hidden state', + ) + parser.add_argument( + '--decoder-dropout', type=float, default=0.1, + help='decoder dropout probability', + ) + + @classmethod + def build_model(cls, args, task): + # Fairseq initializes models by calling the ``build_model()`` + # function. This provides more flexibility, since the returned model + # instance can be of a different type than the one that was called. + # In this case we'll just return a SimpleLSTMModel instance. + + # Initialize our Encoder and Decoder. + encoder = SimpleLSTMEncoder( + args=args, + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + hidden_dim=args.encoder_hidden_dim, + dropout=args.encoder_dropout, + ) + decoder = SimpleLSTMDecoder( + dictionary=task.target_dictionary, + encoder_hidden_dim=args.encoder_hidden_dim, + embed_dim=args.decoder_embed_dim, + hidden_dim=args.decoder_hidden_dim, + dropout=args.decoder_dropout, + ) + model = SimpleLSTMModel(encoder, decoder) + + # Print the model architecture. + print(model) + + return model + + # We could override the ``forward()`` if we wanted more control over how + # the encoder and decoder interact, but it's not necessary for this + # tutorial since we can inherit the default implementation provided by + # the FairseqEncoderDecoderModel base class, which looks like: + # + # def forward(self, src_tokens, src_lengths, prev_output_tokens): + # encoder_out = self.encoder(src_tokens, src_lengths) + # decoder_out = self.decoder(prev_output_tokens, encoder_out) + # return decoder_out + +Finally let's define a *named architecture* with the configuration for our +model. This is done with the :func:`~fairseq.models.register_model_architecture` +function decorator. Thereafter this named architecture can be used with the +``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``:: + + from fairseq.models import register_model_architecture + + # The first argument to ``register_model_architecture()`` should be the name + # of the model we registered above (i.e., 'simple_lstm'). The function we + # register here should take a single argument *args* and modify it in-place + # to match the desired architecture. + + @register_model_architecture('simple_lstm', 'tutorial_simple_lstm') + def tutorial_simple_lstm(args): + # We use ``getattr()`` to prioritize arguments that are explicitly given + # on the command-line, so that the defaults defined below are only used + # when no other value has been specified. 
+ args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) + args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) + args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256) + + +3. Training the Model +--------------------- + +Now we're ready to train the model. We can use the existing :ref:`fairseq-train` +command-line tool for this, making sure to specify our new Model architecture +(``--arch tutorial_simple_lstm``). + +.. note:: + + Make sure you've already preprocessed the data from the IWSLT example in the + :file:`examples/translation/` directory. + +.. code-block:: console + + > fairseq-train data-bin/iwslt14.tokenized.de-en \ + --arch tutorial_simple_lstm \ + --encoder-dropout 0.2 --decoder-dropout 0.2 \ + --optimizer adam --lr 0.005 --lr-shrink 0.5 \ + --max-tokens 12000 + (...) + | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396 + | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954 + +The model files should appear in the :file:`checkpoints/` directory. While this +model architecture is not very good, we can use the :ref:`fairseq-generate` script to +generate translations and compute our BLEU score over the test set: + +.. code-block:: console + + > fairseq-generate data-bin/iwslt14.tokenized.de-en \ + --path checkpoints/checkpoint_best.pt \ + --beam 5 \ + --remove-bpe + (...) + | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s) + | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146) + + +4. Making generation faster +--------------------------- + +While autoregressive generation from sequence-to-sequence models is inherently +slow, our implementation above is especially slow because it recomputes the +entire sequence of Decoder hidden states for every output token (i.e., it is +``O(n^2)``). We can make this significantly faster by instead caching the +previous hidden states. + +In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a +special mode at inference time where the Model only receives a single timestep +of input corresponding to the immediately previous output token (for teacher +forcing) and must produce the next output incrementally. Thus the model must +cache any long-term state that is needed about the sequence, e.g., hidden +states, convolutional states, etc. + +To implement incremental decoding we will modify our model to implement the +:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the +standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental +decoder interface allows ``forward()`` methods to take an extra keyword argument +(*incremental_state*) that can be used to cache state across time-steps. + +Let's replace our ``SimpleLSTMDecoder`` with an incremental one:: + + import torch + from fairseq.models import FairseqIncrementalDecoder + + class SimpleLSTMDecoder(FairseqIncrementalDecoder): + + def __init__( + self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128, + dropout=0.1, + ): + # This remains the same as before. 
+ super().__init__(dictionary) + self.embed_tokens = nn.Embedding( + num_embeddings=len(dictionary), + embedding_dim=embed_dim, + padding_idx=dictionary.pad(), + ) + self.dropout = nn.Dropout(p=dropout) + self.lstm = nn.LSTM( + input_size=encoder_hidden_dim + embed_dim, + hidden_size=hidden_dim, + num_layers=1, + bidirectional=False, + ) + self.output_projection = nn.Linear(hidden_dim, len(dictionary)) + + # We now take an additional kwarg (*incremental_state*) for caching the + # previous hidden and cell states. + def forward(self, prev_output_tokens, encoder_out, incremental_state=None): + if incremental_state is not None: + # If the *incremental_state* argument is not ``None`` then we are + # in incremental inference mode. While *prev_output_tokens* will + # still contain the entire decoded prefix, we will only use the + # last step and assume that the rest of the state is cached. + prev_output_tokens = prev_output_tokens[:, -1:] + + # This remains the same as before. + bsz, tgt_len = prev_output_tokens.size() + final_encoder_hidden = encoder_out['final_hidden'] + x = self.embed_tokens(prev_output_tokens) + x = self.dropout(x) + x = torch.cat( + [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)], + dim=2, + ) + + # We will now check the cache and load the cached previous hidden and + # cell states, if they exist, otherwise we will initialize them to + # zeros (as before). We will use the ``utils.get_incremental_state()`` + # and ``utils.set_incremental_state()`` helpers. + initial_state = utils.get_incremental_state( + self, incremental_state, 'prev_state', + ) + if initial_state is None: + # first time initialization, same as the original version + initial_state = ( + final_encoder_hidden.unsqueeze(0), # hidden + torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell + ) + + # Run one step of our LSTM. + output, latest_state = self.lstm(x.transpose(0, 1), initial_state) + + # Update the cache with the latest hidden and cell states. + utils.set_incremental_state( + self, incremental_state, 'prev_state', latest_state, + ) + + # This remains the same as before + x = output.transpose(0, 1) + x = self.output_projection(x) + return x, None + + # The ``FairseqIncrementalDecoder`` interface also requires implementing a + # ``reorder_incremental_state()`` method, which is used during beam search + # to select and reorder the incremental state. + def reorder_incremental_state(self, incremental_state, new_order): + # Load the cached state. + prev_state = utils.get_incremental_state( + self, incremental_state, 'prev_state', + ) + + # Reorder batches according to *new_order*. + reordered_state = ( + prev_state[0].index_select(1, new_order), # hidden + prev_state[1].index_select(1, new_order), # cell + ) + + # Update the cached state. + utils.set_incremental_state( + self, incremental_state, 'prev_state', reordered_state, + ) + +Finally, we can rerun generation and observe the speedup: + +.. code-block:: console + + # Before + + > fairseq-generate data-bin/iwslt14.tokenized.de-en \ + --path checkpoints/checkpoint_best.pt \ + --beam 5 \ + --remove-bpe + (...) + | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s) + | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146) + + # After + + > fairseq-generate data-bin/iwslt14.tokenized.de-en \ + --path checkpoints/checkpoint_best.pt \ + --beam 5 \ + --remove-bpe + (...) 
+ | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s) + | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146) diff --git a/fairseq/examples/.gitignore b/fairseq/examples/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1ef816f2cd7b4a9aa7adf8bd5635a644834738f1 --- /dev/null +++ b/fairseq/examples/.gitignore @@ -0,0 +1,2 @@ +!*/*.sh +!*/*.md diff --git a/fairseq/examples/MMPT/CONFIG.md b/fairseq/examples/MMPT/CONFIG.md new file mode 100644 index 0000000000000000000000000000000000000000..bbd1403dfafc79ab581c691281ca49de28a4656f --- /dev/null +++ b/fairseq/examples/MMPT/CONFIG.md @@ -0,0 +1,41 @@ +### Config Files Explained + +Taking `projects/mfmmlm.yaml` for example, which run pretraining using masked frame model (MFM) and masked language model (MLM) on a single BERT: + +```yaml +project_dir: mfmmlm # specify the project dir for this baseline. +run_task: + - how2.yaml # run pretraining on how2 when launching `projects/taskmfmmlm.yaml` + - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml] # run fine-tuning tasks. +base_dir: task # a global template folder to specify each training task. +task_group: + pretrain: # section for pretraining. Most baselines differs in this section. + task_list: + - how2.yaml # reconfig `projects/task/how2.yaml` + dataset: + aligner: MFMMLMAligner # overwrite the aligner for MFMMLM training task. + model: + model_cls: MMFusionMFMMLM # overwrite the model, which constructs negative examples for MFM on-the-fly. + loss: + loss_cls: MFMMLM # overwrite the loss as MFMMLM, which combines MFM and MLM together. + fairseq: # all fairseq args can be expecified under this name. + dataset: + batch_size: 128 + finetune: # section for fine-tuning tasks, we don't need to change anything here mostly since we want to see how pretraining can contribute to finetuning. + task_list: # specify the list of downstream tasks, e.g., copy `projects/task/vtt.yaml` to `projects/mfmmlm`. + - vtt.yaml + - vttqa.yaml + - youcook.yaml + - youcookcap.yaml + - crosstask.yaml + - coin.yaml + test: # section for testing. + task_list: + - test_vtt.yaml + - test_vttqa.yaml + - test_youcook.yaml + - test_youcookcap.yaml + - test_crosstask.yaml + - test_crosstask_zs.yaml + - test_coin.yaml +``` diff --git a/fairseq/examples/MMPT/DATASET.md b/fairseq/examples/MMPT/DATASET.md new file mode 100644 index 0000000000000000000000000000000000000000..930403eb3651a14dfcd015d094383faeb052c8ac --- /dev/null +++ b/fairseq/examples/MMPT/DATASET.md @@ -0,0 +1,34 @@ +# Dataset + +We understand video data are challenging to download and process. For videos, we provide our preprocessing scripts under `scripts/video_feature_extractor` (deeply adapted from `https://github.com/antoine77340/video_feature_extractor`); for text, we pre-tokenizing scripts under `scripts/text_token_extractor`. + +### S3D Feature Extraction +We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`. + +We implement a `PathBuilder` to automatically track video ids, source video paths to their feature locations (you may need `conda install -c anaconda pandas`). Decoding may need `pip install ffmpeg-python`. + +### Howto100M +[Howto100M](https://www.di.ens.fr/willow/research/howto100m/) is a large-scale video pre-training datasets. 
You may download videos by yourself and run preprocessing of our scripts. + +Several key differences of our preprocessing from existing papers: (1) we use `raw_caption.json` instead of `caption.json` to have pure self-supervision on text (`caption.json` has manual removal of stop words); (2) we remove partially duplicated texts that are originally designed for real-time readability (see `mmpt/processors/dedupprocessor.py`); (3) then we shard video/text features using `SharedTensor` in `mmpt/utils/shardedtensor.py` for fast loading during training (faster than `h5py`). + +#### Steps +##### video +To extract video features: edit and run `bash scripts/video_feature_extractor/how2/s3d.sh`. (consider to run this on multiple machines; by default, we store features in fp16 to save space and also for faster training). + +Split available video ids as `data/how2/how2_s3d_train.lst` and `data/how2/how2_s3d_val.lst`. + +Lastly, pack video features into `ShardedTensor` using `python scripts/video_feature_extractor/shard_feature.py`. + +##### text +Clean captions using `python -m mmpt.processors.dedupprocessor`. + +Tokenize dedupped captions `data/how2/raw_caption_dedup.pkl` into sharded numpy arrays: +``` +python scripts/text_token_extractor/pretokenization.py scripts/text_token_extractor/configs/bert-base-uncased.yaml +``` + +### Youcook, MSRVTT etc. +We use the version of Youcook and MSRVTT come with Howto100M and MILNCE. Please download the data to `data/youcook` and `data/msrvtt` accordingly, you can also check `projects/task/youcook.yaml` and `projects/task/vtt.yaml` etc. in details. +We extract features for Youcook, MSRVTT similar to the first step of Howto100M but we read text from meta data directly and perform on-the-fly tokenization. + diff --git a/fairseq/examples/MMPT/locallaunch.py b/fairseq/examples/MMPT/locallaunch.py new file mode 100644 index 0000000000000000000000000000000000000000..e20fd816fa3bed8b1af8f6a4d1a07ccb69a1fffa --- /dev/null +++ b/fairseq/examples/MMPT/locallaunch.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import os + +from omegaconf import OmegaConf + +from mmpt.utils import recursive_config, overwrite_dir +from mmpt_cli.localjob import LocalJob + + +class JobLauncher(object): + JOB_CONFIG = { + "local": LocalJob, + } + + def __init__(self, yaml_file): + self.yaml_file = yaml_file + job_key = "local" + + if yaml_file.endswith(".yaml"): + config = recursive_config(yaml_file) + if config.task_type is not None: + job_key = config.task_type.split("_")[0] + else: + raise ValueError("unknown extension of job file:", yaml_file) + self.job_key = job_key + + def __call__(self, job_type=None, dryrun=False): + if job_type is not None: + self.job_key = job_type.split("_")[0] + print("[JobLauncher] job_key", self.job_key) + job = JobLauncher.JOB_CONFIG[self.job_key]( + self.yaml_file, job_type=job_type, dryrun=dryrun) + return job.submit() + + +class Pipeline(object): + """a job that loads yaml config.""" + + def __init__(self, fn): + """ + load a yaml config of a job and save generated configs as yaml for each task. + return: a list of files to run as specified by `run_task`. + """ + if fn.endswith(".py"): + # a python command. + self.backend = "python" + self.run_yamls = [fn] + return + + job_config = recursive_config(fn) + if job_config.base_dir is None: # single file job config. 
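+            # Without a base_dir there is no template project to compose from,
+            # so the given yaml is run directly as a single job.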
+ self.run_yamls = [fn] + return + + self.project_dir = os.path.join("projects", job_config.project_dir) + self.run_dir = os.path.join("runs", job_config.project_dir) + + if job_config.run_task is not None: + run_yamls = [] + for stage in job_config.run_task: + # each stage can have multiple tasks running in parallel. + if OmegaConf.is_list(stage): + stage_yamls = [] + for task_file in stage: + stage_yamls.append( + os.path.join(self.project_dir, task_file)) + run_yamls.append(stage_yamls) + else: + run_yamls.append(os.path.join(self.project_dir, stage)) + self.run_yamls = run_yamls + configs_to_save = self._overwrite_task(job_config) + self._save_configs(configs_to_save) + + def __getitem__(self, idx): + yaml_files = self.run_yamls[idx] + if isinstance(yaml_files, list): + return [JobLauncher(yaml_file) for yaml_file in yaml_files] + return [JobLauncher(yaml_files)] + + def __len__(self): + return len(self.run_yamls) + + def _save_configs(self, configs_to_save: dict): + # save + os.makedirs(self.project_dir, exist_ok=True) + for config_file in configs_to_save: + config = configs_to_save[config_file] + print("saving", config_file) + OmegaConf.save(config=config, f=config_file) + + def _overwrite_task(self, job_config): + configs_to_save = {} + self.base_project_dir = os.path.join("projects", job_config.base_dir) + self.base_run_dir = os.path.join("runs", job_config.base_dir) + + for config_sets in job_config.task_group: + overwrite_config = job_config.task_group[config_sets] + if ( + overwrite_config.task_list is None + or len(overwrite_config.task_list) == 0 + ): + print( + "[warning]", + job_config.task_group, + "has no task_list specified.") + # we don't want this added to a final config. + task_list = overwrite_config.pop("task_list", None) + for config_file in task_list: + config_file_path = os.path.join( + self.base_project_dir, config_file) + config = recursive_config(config_file_path) + # overwrite it. + if overwrite_config: + config = OmegaConf.merge(config, overwrite_config) + overwrite_dir(config, self.run_dir, basedir=self.base_run_dir) + save_file_path = os.path.join(self.project_dir, config_file) + configs_to_save[save_file_path] = config + return configs_to_save + + +def main(args): + job_type = args.jobtype if args.jobtype else None + # parse multiple pipelines. + pipelines = [Pipeline(fn) for fn in args.yamls.split(",")] + + for pipe_id, pipeline in enumerate(pipelines): + if not hasattr(pipeline, "project_dir"): + for job in pipeline[0]: + job(job_type=job_type, dryrun=args.dryrun) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("yamls", type=str) + parser.add_argument( + "--dryrun", + action="store_true", + help="run config and prepare to submit without launch the job.", + ) + parser.add_argument( + "--jobtype", type=str, default="", + help="force to run jobs as specified.") + args = parser.parse_args() + main(args) diff --git a/fairseq/examples/MMPT/mmpt/__init__.py b/fairseq/examples/MMPT/mmpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff86ddd5ce0c454281e6568c628bd3c49ea5024 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
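+# The imports below are optional: if fairseq (or one of its dependencies) is not
+# installed, the ImportError is swallowed and only the core mmpt package loads.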
+try: + # fairseq user dir + from .datasets import FairseqMMDataset + from .losses import FairseqCriterion + from .models import FairseqMMModel + from .tasks import FairseqMMTask +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/datasets/__init__.py b/fairseq/examples/MMPT/mmpt/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2578235e1771fdc7e6fcfb66a519cbe891d7e254 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/datasets/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .mmdataset import * + +try: + from .fairseqmmdataset import * +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/evaluators/metric.py b/fairseq/examples/MMPT/mmpt/evaluators/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..163724bb250cb1b7057b3fa4d75a9fafa9c181f5 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/evaluators/metric.py @@ -0,0 +1,313 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import json + + +class Metric(object): + def __init__(self, config, metric_names): + self.metric_names = metric_names + + def best_metric(self, metric): + return metric[self.metric_names[0]] + + def save_metrics(self, fn, metrics): + with open(fn, "w") as fw: + json.dump(fw, metrics) + + def print_computed_metrics(self, metrics): + raise NotImplementedError + + +class RetrievalMetric(Metric): + """ + this is modified from `howto100m/metrics.py`. + History of changes: + refactor as a class. + add metric_key in __init__ + """ + + def __init__(self, config, metric_names=["R1", "R5", "R10", "MR"]): + super().__init__(config, metric_names) + self.error = False # TODO(huxu): add to config to print error. + + def compute_metrics(self, outputs, texts, **kwargs): + x = outputs + sx = np.sort(-x, axis=1) + d = np.diag(-x) + d = d[:, np.newaxis] + ind = sx - d + ind = np.where(ind == 0) + ind = ind[1] + metrics = {} + metrics["R1"] = float(np.sum(ind == 0)) / len(ind) + metrics["R5"] = float(np.sum(ind < 5)) / len(ind) + metrics["R10"] = float(np.sum(ind < 10)) / len(ind) + metrics["MR"] = np.median(ind) + 1 + + max_idx = np.argmax(outputs, axis=1) + if self.error: + # print top-20 errors. + error = [] + for ex_idx in range(20): + error.append((texts[ex_idx], texts[max_idx[ex_idx]])) + metrics["error"] = error + return metrics + + def print_computed_metrics(self, metrics): + r1 = metrics["R1"] + r5 = metrics["R5"] + r10 = metrics["R10"] + mr = metrics["MR"] + print( + "R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}".format( + r1, r5, r10, mr + ) + ) + if "error" in metrics: + print(metrics["error"]) + + +class DiDeMoMetric(Metric): + """ + History of changes: + python 2.x to python 3.x. + merge utils.py into eval to save one file. + reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py + Code to evaluate your results on the DiDeMo dataset. 
+ """ + def __init__(self, config, metric_names=["rank1", "rank5", "miou"]): + super().__init__(config, metric_names) + + def compute_metrics(self, outputs, targets, **kwargs): + assert len(outputs) == len(targets) + rank1, rank5, miou = self._eval_predictions(outputs, targets) + metrics = { + "rank1": rank1, + "rank5": rank5, + "miou": miou + } + return metrics + + def print_computed_metrics(self, metrics): + rank1 = metrics["rank1"] + rank5 = metrics["rank5"] + miou = metrics["miou"] + # print("Average rank@1: %f" % rank1) + # print("Average rank@5: %f" % rank5) + # print("Average iou: %f" % miou) + + print( + "Average rank@1: {:.4f} Average rank@5: {:.4f} Average iou: {:.4f}".format( + rank1, rank5, miou + ) + ) + + def _iou(self, pred, gt): + intersection = max(0, min(pred[1], gt[1]) + 1 - max(pred[0], gt[0])) + union = max(pred[1], gt[1]) + 1 - min(pred[0], gt[0]) + return float(intersection)/union + + def _rank(self, pred, gt): + return pred.index(tuple(gt)) + 1 + + def _eval_predictions(self, segments, data): + ''' + Inputs: + segments: For each item in the ground truth data, rank possible video segments given the description and video. + In DiDeMo, there are 21 posible moments extracted for each video so the list of video segments will be of length 21. + The first video segment should be the video segment that best corresponds to the text query. + There are 4180 sentence in the validation data, so when evaluating a model on the val dataset, + segments should be a list of lenght 4180, and each item in segments should be a list of length 21. + data: ground truth data + ''' + average_ranks = [] + average_iou = [] + for s, d in zip(segments, data): + pred = s[0] + ious = [self._iou(pred, t) for t in d['times']] + average_iou.append(np.mean(np.sort(ious)[-3:])) + ranks = [self._rank(s, t) for t in d['times'] if tuple(t) in s] # if t in s] is added for s, e not in prediction. 
+ average_ranks.append(np.mean(np.sort(ranks)[:3])) + rank1 = np.sum(np.array(average_ranks) <= 1)/float(len(average_ranks)) + rank5 = np.sum(np.array(average_ranks) <= 5)/float(len(average_ranks)) + miou = np.mean(average_iou) + + # print("Average rank@1: %f" % rank1) + # print("Average rank@5: %f" % rank5) + # print("Average iou: %f" % miou) + return rank1, rank5, miou + + +class NLGMetric(Metric): + def __init__( + self, + config, + metric_names=[ + "Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", + "METEOR", "ROUGE_L", "CIDEr" + ] + ): + super().__init__(config, metric_names) + # please install NLGEval from `https://github.com/Maluuba/nlg-eval` + from nlgeval import NLGEval + self.nlg = NLGEval() + + def compute_metrics(self, outputs, targets, **kwargs): + return self.nlg.compute_metrics( + hyp_list=outputs, ref_list=targets) + + def print_computed_metrics(self, metrics): + Bleu_1 = metrics["Bleu_1"] + Bleu_2 = metrics["Bleu_2"] + Bleu_3 = metrics["Bleu_3"] + Bleu_4 = metrics["Bleu_4"] + METEOR = metrics["METEOR"] + ROUGE_L = metrics["ROUGE_L"] + CIDEr = metrics["CIDEr"] + + print( + "Bleu_1: {:.4f} - Bleu_2: {:.4f} - Bleu_3: {:.4f} - Bleu_4: {:.4f} - METEOR: {:.4f} - ROUGE_L: {:.4f} - CIDEr: {:.4f}".format( + Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr + ) + ) + + +class QAMetric(Metric): + def __init__( + self, + config, + metric_names=["acc"] + ): + super().__init__(config, metric_names) + + def compute_metrics(self, outputs, targets, **kwargs): + from sklearn.metrics import accuracy_score + return {"acc": accuracy_score(targets, outputs)} + + def print_computed_metrics(self, metrics): + print("acc: {:.4f}".format(metrics["acc"])) + + +class COINActionSegmentationMetric(Metric): + """ + COIN dataset listed 3 repos for Action Segmentation. + Action Sets, NeuralNetwork-Viterbi, TCFPN-ISBA. + The first and second are the same. + https://github.com/alexanderrichard/action-sets/blob/master/eval.py + + Future reference for the third: + `https://github.com/Zephyr-D/TCFPN-ISBA/blob/master/utils/metrics.py` + """ + def __init__(self, config, metric_name=["frame_acc"]): + super().__init__(config, metric_name) + + def compute_metrics(self, outputs, targets): + n_frames = 0 + n_errors = 0 + n_errors = sum(outputs != targets) + n_frames = len(targets) + return {"frame_acc": 1.0 - float(n_errors) / n_frames} + + def print_computed_metrics(self, metrics): + fa = metrics["frame_acc"] + print("frame accuracy:", fa) + + +class CrossTaskMetric(Metric): + def __init__(self, config, metric_names=["recall"]): + super().__init__(config, metric_names) + + def compute_metrics(self, outputs, targets, **kwargs): + """refactored from line 166: + https://github.com/DmZhukov/CrossTask/blob/master/train.py""" + + recalls = self._get_recalls(Y_true=targets, Y_pred=outputs) + results = {} + for task, rec in recalls.items(): + results[str(task)] = rec + + avg_recall = np.mean(list(recalls.values())) + results["recall"] = avg_recall + return results + + def print_computed_metrics(self, metrics): + print('Recall: {0:0.3f}'.format(metrics["recall"])) + for task in metrics: + if task != "recall": + print('Task {0}. 
Recall = {1:0.3f}'.format( + task, metrics[task])) + + def _get_recalls(self, Y_true, Y_pred): + """refactored from + https://github.com/DmZhukov/CrossTask/blob/master/train.py""" + + step_match = {task: 0 for task in Y_true.keys()} + step_total = {task: 0 for task in Y_true.keys()} + for task, ys_true in Y_true.items(): + ys_pred = Y_pred[task] + for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())): + y_true = ys_true[vid] + y_pred = ys_pred[vid] + step_total[task] += (y_true.sum(axis=0) > 0).sum() + step_match[task] += (y_true*y_pred).sum() + recalls = { + task: step_match[task] / n for task, n in step_total.items()} + return recalls + + +class ActionRecognitionMetric(Metric): + def __init__( + self, + config, + metric_names=["acc", "acc_splits", "r1_splits", "r5_splits", "r10_splits"] + ): + super().__init__(config, metric_names) + + def compute_metrics(self, outputs, targets, splits, **kwargs): + all_video_embd = outputs + labels = targets + split1, split2, split3 = splits + accs = [] + r1s = [] + r5s = [] + r10s = [] + for split in range(3): + if split == 0: + s = split1 + elif split == 1: + s = split2 + else: + s = split3 + + X_pred = all_video_embd[np.where(s == 2)[0]] + label_test = labels[np.where(s == 2)[0]] + logits = X_pred + X_pred = np.argmax(X_pred, axis=1) + acc = np.sum(X_pred == label_test) / float(len(X_pred)) + accs.append(acc) + # compute recall. + sorted_pred = (-logits).argsort(axis=-1) + label_test_sp = label_test.reshape(-1, 1) + + r1 = np.mean((sorted_pred[:, :1] == label_test_sp).sum(axis=1), axis=0) + r5 = np.mean((sorted_pred[:, :5] == label_test_sp).sum(axis=1), axis=0) + r10 = np.mean((sorted_pred[:, :10] == label_test_sp).sum(axis=1), axis=0) + r1s.append(r1) + r5s.append(r5) + r10s.append(r10) + + return {"acc": accs[0], "acc_splits": accs, "r1_splits": r1s, "r5_splits": r5s, "r10_splits": r10s} + + def print_computed_metrics(self, metrics): + for split, acc in enumerate(metrics["acc_splits"]): + print("Top 1 accuracy on split {}: {}; r1 {}; r5 {}; r10 {}".format( + split + 1, acc, + metrics["r1_splits"][split], + metrics["r5_splits"][split], + metrics["r10_splits"][split], + ) + ) diff --git a/fairseq/examples/MMPT/mmpt/evaluators/predictor.py b/fairseq/examples/MMPT/mmpt/evaluators/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..2ffef6ab474a7a275d90784250d04222bd6dc70f --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/evaluators/predictor.py @@ -0,0 +1,595 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os +import random +import json +import numpy as np +import torch +import pickle +import math + +from tqdm import tqdm + + +class Predictor(object): + """this base class is used to save predictions to disk + (and being called by a evaluator later). + Predictor has minimum support of single gpu prediction. + """ + def __init__(self, config): + self.pred_dir = None # on-the-fly eval does not save the results. 
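import numpy as np

# Minimal sketch (toy logits) of the per-split top-k recall computed in
# ActionRecognitionMetric above: sort class logits descending and check whether
# the true label appears among the top k predictions.
logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])
labels = np.array([1, 2])

sorted_pred = (-logits).argsort(axis=-1)       # class indices, best first
labels_col = labels.reshape(-1, 1)
r1 = np.mean((sorted_pred[:, :1] == labels_col).sum(axis=1))   # -> 0.5
r2 = np.mean((sorted_pred[:, :2] == labels_col).sum(axis=1))   # -> 0.5
print(r1, r2)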
+ if hasattr(config, "eval") and config.eval is not None: + self.pred_dir = config.eval.save_path + os.makedirs(self.pred_dir, exist_ok=True) + + def __call__(self, outputs): + """extract the prediction and save it.""" + raise NotImplementedError + + def predict_loop(self, model, eval_dataloader, output_file=None): + """on-the-fly prediction on a single gpu.""" + self.full_scores = [] + model.eval() + model = model.to(0) + with torch.no_grad(): + for data in eval_dataloader: + data = self.to_ctx(data) + outputs = model(**data) + outputs.update(data) + self(outputs) + return self.finalize(output_file) + + def finalize(self, output_file): + pass + + def to_ctx(self, data, ctx=0, dtype=None): + if isinstance(data, dict): + for key in data: + if torch.is_tensor(data[key]): + if dtype is not None and data[key].dtype == torch.float32: + data[key] = data[key].to(dtype) + data[key] = data[key].to(ctx) + return data + else: + raise ValueError("non-dict type of batch is not supported yet.") + + +class NLGPredictor(Predictor): + """Predicting Text from MMFusion models.""" + """TODO: make a context.""" + def __init__(self, config): + super().__init__(config) + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained( + config.dataset.bert_name, + bos_token="[CLS]", eos_token="[SEP]") + self.bos_token_id = self.tokenizer.bos_token_id + self.eos_token_id = self.tokenizer.eos_token_id + + def predict_loop(self, model, eval_dataloader, output_file=None): + """TODO: refactor base classes.""" + ctx = 0 + outputs = {"outputs": [], "targets": [[]]} + model.eval() + model = model.to(ctx) + with torch.no_grad(): + for data in tqdm(eval_dataloader): + data = self.to_ctx(data, ctx) + self(data, model, outputs) + return self.finalize(outputs, output_file) + + def __call__(self, data, model, outputs): + data.update({ + "bos_token_id": self.bos_token_id, + "eos_token_id": self.eos_token_id + }) + + output = model.generate(**data) + assert len(output) == len(data["ref"]) + for idx, _output in enumerate(output): + generated_text = self.tokenizer.decode( + _output, skip_special_tokens=True) + if generated_text == "": + generated_text = "none" + outputs["outputs"].append(generated_text) + outputs["targets"][0].append(data["ref"][idx]) + if random.random() < 0.001: + print("_output", _output) + print("generated_text", generated_text) + print("ref", data["ref"][idx]) + + def finalize(self, outputs, output_file=None): + if output_file is not None: + with open(os.path.join( + self.pred_dir, output_file + ".json"), "w") as fw: + json.dump(outputs, fw, indent=4) + return outputs + + +class RetrievalPredictor(Predictor): + """generated `pooled_video` and `pooled_text`.""" + def __init__(self, config): + super().__init__(config) + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + config.dataset.bert_name) + + def predict_loop( + self, + model, + eval_dataloader, + output_file="retrieval.npy" + ): + """on-the-fly prediction on a single gpu.""" + full_scores = [] + texts = [] + model.eval() + model = model.cuda() + with torch.no_grad(): + for data in eval_dataloader: + # convert to dict. 
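import numpy as np

# Sketch (toy shapes) of what RetrievalPredictor hands to RetrievalMetric:
# per-batch pooled embeddings are concatenated and a text-by-video score matrix
# is built with one matmul, as _aggregate_scores does below.
video_batches = [np.random.rand(4, 8), np.random.rand(4, 8)]
text_batches = [np.random.rand(4, 8), np.random.rand(4, 8)]

video_hidden = np.concatenate(video_batches, axis=0)   # (num_videos, dim)
text_hidden = np.concatenate(text_batches, axis=0)     # (num_texts, dim)
scores = np.matmul(text_hidden, video_hidden.T)        # (num_texts, num_videos)
print(scores.shape)                                    # -> (8, 8)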
+ if not isinstance(data, dict): + data = { + "caps": data[0], + "cmasks": data[1], + "vfeats": data[2], + "vmasks": data[3], + "video_id": data[4] + } + data = self.to_ctx(data) + outputs = model(**data) + outputs.update(data) + self(outputs, full_scores) + for _cap in data["caps"]: + texts.append( + self.tokenizer.decode(_cap, skip_special_tokens=True) + ) + + return self.finalize(full_scores, texts, output_file) + + def __call__(self, sample, full_scores): + scores = self._get_pooled_outputs(sample) + self._append_scores(scores, full_scores) + + def finalize(self, full_scores, texts, output_file=None): + outputs = self._aggregate_scores(full_scores) + if output_file is not None: + np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs) + return {"outputs": outputs, "texts": texts} + + def _get_pooled_outputs(self, outputs): + if "pooled_video" in outputs: + return outputs["pooled_video"], outputs["pooled_text"] + else: + raise ValueError("unknown format of outputs.") + + def _append_scores(self, scores, full_scores): + assert len(scores) == 2 + if len(full_scores) == 0: + full_scores.append([]) + full_scores.append([]) + full_scores[0].append(scores[0].cpu().detach().numpy()) + full_scores[1].append(scores[1].cpu().detach().numpy()) + + def _aggregate_scores(self, scores): + assert len(scores) == 2 + video_hidden = np.concatenate(scores[0], axis=0) + text_hidden = np.concatenate(scores[1], axis=0) + # clear up. + self.full_scores = [] + return np.matmul(text_hidden, video_hidden.T) + + +class QAPredictor(Predictor): + """generated `pooled_video` and `pooled_text`.""" + def __init__(self, config): + super().__init__(config) + """predictor maintains scores and aggregate them.""" + + def predict_loop(self, model, eval_dataloader, output_file="qa.npy"): + """on-the-fly prediction on a single gpu.""" + self.full_scores = [] + model.eval() + model = model.cuda() + with torch.no_grad(): + for data in eval_dataloader: + # reshape ans and dup video 5 times. + v_len = data["vfeats"].size(1) + hidden_size = data["vfeats"].size(2) + data["vfeats"] = data["vfeats"].unsqueeze(1).repeat(1, 5, 1, 1).view(-1, v_len, hidden_size) + data["vmasks"] = data["vmasks"].unsqueeze(1).repeat(1, 5, 1).view(-1, v_len) + + t_len = data["caps"].size(-1) + data["caps"] = data["caps"].view(-1, t_len) + data["cmasks"] = data["cmasks"].view(-1, t_len) + + data = self.to_ctx(data) + outputs = model(**data) + outputs.update(data) + self(outputs) + return self.finalize(output_file) + + def __call__(self, sample): + hidden_size = sample["pooled_video"].size(-1) + pooled_video = sample["pooled_video"].view(-1, 5, hidden_size) + pooled_text = sample["pooled_text"].view(-1, 5, hidden_size) + scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1)) + scores = scores.argmax(-1) + self._append_scores(scores[:, 0], sample["answers"], self.full_scores) + + def finalize(self, output_file=None): + outputs, targets = self._aggregate_scores(self.full_scores) + if output_file is not None: + np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs) + return {"outputs": outputs, "targets": targets} + + def _append_scores(self, scores, answers, full_scores): + if len(full_scores) == 0: + full_scores.append([]) + full_scores.append([]) + full_scores[0].append(scores.cpu().detach().numpy()) + full_scores[1].append(answers.cpu().detach().numpy()) + + def _aggregate_scores(self, scores): + assert len(scores) == 2 + outputs = np.concatenate(scores[0], axis=0) + targets = np.concatenate(scores[1], axis=0) + # clear up. 
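import torch

# Sketch (toy tensors) of QAPredictor's 5-way multiple-choice scoring: each
# video is paired with 5 candidate answers and the predicted answer is the
# argmax over the 5 video-text similarity scores.
batch, hidden = 2, 8
pooled_video = torch.randn(batch * 5, hidden).view(batch, 5, hidden)
pooled_text = torch.randn(batch * 5, hidden).view(batch, 5, hidden)

scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1))   # (batch, 5, 5)
pred = scores.argmax(-1)[:, 0]   # row 0 of each block: the video vs. its 5 answers
print(pred.shape)                # -> torch.Size([2])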
+ self.full_scores = [] + return outputs, targets + + +class CrossTaskPredictor(Predictor): + """ + CrossTaskPredictor needs to compute the average of logits + for overlapped sliding-window. + """ + def __init__(self, config): + super().__init__(config) + self.lsm = torch.nn.LogSoftmax(dim=1) + self.max_video_len = config.dataset.max_video_len + self.sliding_window = config.dataset.sliding_window + self.sliding_window_size = config.dataset.sliding_window_size + self.annotation_path = config.dataset.annotation_path + + def predict_loop(self, model, eval_dataloader, output_file="result.pkl"): + """refactored from line 144: + https://github.com/DmZhukov/CrossTask/blob/master/train.py + """ + ctx = 0 + model.eval() + model = model.to(ctx) + # this is not a loss but just compute neg_log_prob. + Y_pred = {} + Y_true = {} + with torch.no_grad(): + for batch in eval_dataloader: + self(batch, model, Y_pred, Y_true) + return self.finalize(Y_pred, Y_true, output_file) + + def __call__(self, sample, model, Y_pred, Y_true): + # please install dp from `https://github.com/DmZhukov/CrossTask` + from dp import dp + vid, task = sample['video_id'][0], sample['task'][0] + sample = self.to_ctx(sample) + # compute the average logits over sliding windows. + output = model(**sample) + batch_logits = output["logits"].cpu() + + video_len = sample["video_len"][0] + + # the following version is slow. + logits = torch.zeros((video_len, batch_logits.size(1))) + logits_counts = torch.zeros((video_len, 1), dtype=torch.long) + # use the same loop as aligner to recover. + batch_logit_idx = 0 + for window_start in range(0, video_len, self.sliding_window): + video_end = min(video_len - window_start, self.sliding_window_size) + logits[window_start: window_start + video_end] += batch_logits[ + batch_logit_idx: batch_logit_idx + video_end] + batch_logit_idx += video_end + logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long) + + if (video_len - window_start) <= self.sliding_window_size: + break + + logits /= logits_counts + assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len) + + O = self.lsm(logits) + y = np.zeros(O.size(), dtype=np.float32) + dp(y, -O.detach().cpu().numpy()) + if task not in Y_pred: + Y_pred[task] = {} + Y_pred[task][vid] = y + annot_path = os.path.join( + self.annotation_path, task+'_'+vid+'.csv') + if os.path.exists(annot_path): + if task not in Y_true: + Y_true[task] = {} + Y_true[task][vid] = self._read_assignment( + *y.shape, annot_path) + + def finalize(self, Y_pred, Y_true, output_file=None): + if output_file is not None: + with open( + os.path.join(self.pred_dir, output_file + ".pkl"), + "wb") as fw: + pickle.dump( + {"Y_pred": Y_pred, "Y_true": Y_true}, fw, + protocol=pickle.HIGHEST_PROTOCOL) + return {"outputs": Y_pred, "targets": Y_true} + + def _read_assignment(self, T, K, path): + """ + refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py + Howto interpret contraints on loss that is going to be minimized: + lambd is a big number; + self.lambd * C is a big number for all valid position (csv stores invalids) + + def forward(self, O, Y, C): + return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum() + + This will load the csv file and fill-in the step col from start to end rows. 
+ """ + + Y = np.zeros([T, K], dtype=np.uint8) + with open(path, 'r') as f: + for line in f: + step, start, end = line.strip().split(',') + start = int(math.floor(float(start))) + end = int(math.ceil(float(end))) + step = int(step) - 1 + Y[start:end, step] = 1 + return Y + + +class COINPredictor(Predictor): + """ + COINPredictor is similar to CrossTask on sliding windows. + """ + def __init__(self, config): + super().__init__(config) + self.max_video_len = config.dataset.max_video_len + self.sliding_window = config.dataset.sliding_window + self.sliding_window_size = config.dataset.sliding_window_size + + def predict_loop(self, model, eval_dataloader, output_file="result.pkl"): + """refactored from line 144: + https://github.com/DmZhukov/CrossTask/blob/master/train.py + """ + ctx = 0 + model.eval() + model = model.to(ctx) + # this is not a loss but just compute neg_log_prob. + Y_pred = [] + Y_true = [] + with torch.no_grad(): + for batch in eval_dataloader: + self(batch, model, Y_pred, Y_true) + return self.finalize(Y_pred, Y_true, output_file) + + def __call__(self, sample, model, Y_pred, Y_true): + sample = self.to_ctx(sample) + # compute the average logits over sliding windows. + output = model(**sample) + logits = self._merge_windows(sample, output) + Y_pred.append(logits.argmax(dim=1)) + Y_true.append(sample["video_targets"].squeeze(0).cpu()) + + def _merge_windows(self, sample, output): + targets = sample["targets"].reshape(-1).cpu() + valid_mask = targets != -100 + targets = targets[valid_mask] + batch_logits = output["logits"].cpu() + batch_logits = batch_logits.reshape(-1, batch_logits.size(-1)) + batch_logits = batch_logits[valid_mask] + + video_len = sample["video_len"][0] + + # the following version is slow. + logits = torch.zeros((video_len, batch_logits.size(1))) + logits_counts = torch.zeros((video_len, 1), dtype=torch.long) + # use the same loop as aligner to recover. + batch_logit_idx = 0 + for window_start in range(0, video_len, self.sliding_window): + video_end = min(video_len - window_start, self.sliding_window_size) + logits[window_start: window_start + video_end] += batch_logits[ + batch_logit_idx: batch_logit_idx + video_end] + batch_logit_idx += video_end + logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long) + if (video_len - window_start) <= self.sliding_window_size: + break + logits /= logits_counts + assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len) + return logits + + def finalize(self, Y_pred, Y_true, output_file=None): + Y_pred = torch.cat(Y_pred, dim=0).numpy() + Y_true = torch.cat(Y_true, dim=0).numpy() + assert len(Y_pred) == len(Y_true) + + error_mask = Y_pred != Y_true + print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10]) + print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20]) + + if output_file is not None: + with open( + os.path.join(self.pred_dir, output_file + ".pkl"), + "wb") as fw: + pickle.dump( + {"Y_pred": Y_pred, "Y_true": Y_true}, fw, + protocol=pickle.HIGHEST_PROTOCOL) + return {"outputs": Y_pred, "targets": Y_true} + + +class COINZSPredictor(COINPredictor): + """ + COINZSPredictor for COIN zero-shot prediction. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.dataset_config = config.dataset + + def predict_loop(self, model, eval_dataloader, output_file="result.pkl"): + """refactored from line 144: + https://github.com/DmZhukov/CrossTask/blob/master/train.py + """ + ctx = 0 + model.eval() + model = model.to(ctx) + + with torch.no_grad(): + outputs = eval_dataloader.dataset.meta_processor.meta_text_labels( + self.dataset_config) + outputs = self.to_ctx(outputs, ctx) + label_hidden_states = model.forward_text(**outputs).cpu() + label_sim = label_hidden_states @ label_hidden_states.t() + num_labels = label_sim.size(0) + eye_mask = ~torch.eye(num_labels, dtype=torch.bool) + label_sim = label_sim.masked_select(eye_mask).view(num_labels, num_labels - 1) + lbd = label_sim.max() + + # this is not a loss but just compute neg_log_prob. + Y_pred = [] + Y_true = [] + with torch.no_grad(): + for batch in eval_dataloader: + self(batch, label_hidden_states, model, lbd, Y_pred, Y_true) + return self.finalize(Y_pred, Y_true, output_file) + + def reshape_subsample(self, sample): + for key in sample: + if torch.is_tensor(sample[key]): + sample[key] = self.flat_subsample(sample[key]) + return sample + + def flat_subsample(self, tensor): + if len(tensor.size()) > 1 and tensor.size(0) == 1: + tensor = tensor.squeeze(0) + return tensor + + def __call__(self, sample, label_hidden_states, model, lbd, Y_pred, Y_true): + sample = self.reshape_subsample(sample) + sample = self.to_ctx(sample) + # compute the average logits over sliding windows. + sample["output_hidden_states"] = True + video_outputs = model.forward_video(**sample).cpu() + output = {"logits": video_outputs[:, 1:sample["vmasks"].size(1)+1] @ label_hidden_states.t()} + logits = self._merge_windows(sample, output) + # logic of zero-shot for sequence labeling. + logits_argmax = logits.argmax(dim=1) + 1 # 0 is "O" label. + logits_max = logits.max(dim=1)[0] + + pred = torch.zeros_like(logits_argmax) + label_select = logits_max > lbd # 73 or 74 + pred[label_select] = logits_argmax[label_select] + + Y_pred.append(pred) + Y_true.append(sample["video_targets"].squeeze(0).cpu()) + + def finalize(self, Y_pred, Y_true, output_file=None): + Y_pred = torch.cat(Y_pred, dim=0).numpy() + Y_true = torch.cat(Y_true, dim=0).numpy() + assert len(Y_pred) == len(Y_true) + + error_mask = Y_pred != Y_true + print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10]) + print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20]) + + if output_file is not None: + with open( + os.path.join(self.pred_dir, output_file + ".pkl"), + "wb") as fw: + pickle.dump( + {"Y_pred": Y_pred, "Y_true": Y_true}, fw, + protocol=pickle.HIGHEST_PROTOCOL) + return {"outputs": Y_pred, "targets": Y_true} + + +class DiDeMoPredictor(Predictor): + """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py + https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py + """ + def __init__(self, config): + super().__init__(config) + # load targets. + with open(config.dataset.test_path) as data_file: + self.test_data = json.load(data_file) + + def predict_loop(self, model, eval_dataloader, output_file="didemo.npy"): + """ + TODO: two solutions here. + """ + import itertools + # 21 chunks. + self.possible_segments = [(0,0), (1,1), (2,2), (3,3), (4,4), (5,5)] + for i in itertools.combinations(range(6), 2): + self.possible_segments.append(i) + # pick segments from a video. 
+ + """on-the-fly prediction on a single gpu.""" + self.full_scores = [] + model.eval() + model = model.cuda() + with torch.no_grad(): + for data in eval_dataloader: + # TODO special forwarding logic here. + data = self.to_ctx(data) + data["output_hidden_states"] = True + hidden_video = model.forward_video(**data) + data["output_hidden_states"] = False + pooled_text = model.forward_text(**data) + outputs = { + "hidden_video": hidden_video, + "pooled_text": pooled_text + } + outputs.update(data) + self(outputs) + return self.finalize(output_file) + + def __call__(self, sample): + # TODO: make an index select from self.possible_segments. + hidden_video = sample["hidden_video"] + pooled_text = sample["pooled_text"] + vmasks = sample["vmasks"] + # probably maintain valid results here. + + hidden_video = hidden_video[:, 1:-1, :] + # probably maintain valid results here. + pooled_video = [] + for s, e in self.possible_segments: + pooled_video.append( + torch.mean( + hidden_video[:, int(s*5):int((e+1)*5), :], + dim=1, keepdim=True) + ) + pooled_video = torch.cat(pooled_video, dim=1) + scores = torch.bmm( + pooled_video, pooled_text.unsqueeze(-1)).squeeze(-1).cpu() + + ranks = scores.argsort(dim=-1, descending=True) + + for batch_idx, rank in enumerate(ranks): + rank_of_moment = [] + for m_idx, moment in enumerate(rank): + s, e = self.possible_segments[moment.item()] + if torch.any( + vmasks[batch_idx, int(s*5):int((e+1)*5)] + ): + rank_of_moment.append((s, e)) + self.full_scores.append(rank_of_moment) + + def finalize(self, output_file=None): + outputs = self._aggregate_scores(self.full_scores) + if output_file is not None: + np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs) + return {"outputs": outputs, "targets": self.test_data} + + def _aggregate_scores(self, scores): + self.full_scores = [] + return scores diff --git a/fairseq/examples/MMPT/mmpt/losses/__init__.py b/fairseq/examples/MMPT/mmpt/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc32c96d2d8aed25c59e69e9d9a2ff24a9a2a47 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/losses/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .loss import * +from .nce import * + +try: + from .fairseqmmloss import * +except ImportError: + pass + +try: + from .expnce import * +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py b/fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py new file mode 100644 index 0000000000000000000000000000000000000000..a95e5ecf45d90098c1719487bb9a11c36be7c507 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +TODO (huxu): a general fairseq criterion for all your pre-defined losses. +""" + +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.logging import metrics + + +@register_criterion("mmloss") +class MMCriterion(FairseqCriterion): + def __init__(self, task): + super().__init__(task) + # TODO (huxu): wrap forward call of loss_fn and eval_fn into task. + self.mmtask = task.mmtask + + def forward(self, model, sample): + """Compute the loss for the given sample. 
+ Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + outputs = self.mmtask(model, sample) + + loss, loss_scalar, max_len, batch_size, sample_size = ( + outputs["loss"], + outputs["loss_scalar"], + outputs["max_len"], + outputs["batch_size"], + outputs["sample_size"], + ) + + logging_output = { + "loss": loss_scalar, + "ntokens": max_len * batch_size, # dummy report. + "nsentences": batch_size, # dummy report. + "sample_size": sample_size, + } + + return loss, 1, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + """since we use NCE, our actual batch_size is 1 per GPU. + Then we take the mean of each worker.""" + loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + metrics.log_scalar("loss", loss_sum / sample_size, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/examples/MMPT/mmpt/losses/loss.py b/fairseq/examples/MMPT/mmpt/losses/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..99c05d067edac220f9e53080f09f0b40d7dc1e8d --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/losses/loss.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. All Rights Reserved + +import torch + +from torch import nn + + +class Loss(object): + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +# Dummy Loss for testing. 
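# Sketch of the aggregation in MMCriterion.reduce_metrics above: each worker
# logs one scalar loss with sample_size 1, so the reduced value is the mean
# loss across workers.
logging_outputs = [
    {"loss": 2.0, "sample_size": 1},
    {"loss": 4.0, "sample_size": 1},
]
loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
print(loss_sum / sample_size)   # -> 3.0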
+class DummyLoss(Loss): + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__(self, logits, targets, **kwargs): + return self.loss(logits, targets) + + +class DummyK400Loss(Loss): + """dummy k400 loss for MViT.""" + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__(self, logits, targets, **kwargs): + return self.loss( + logits, torch.randint(0, 400, (logits.size(0),), device=logits.device)) + + +class CrossEntropy(Loss): + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__(self, logits, targets, **kwargs): + return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) + + +class ArgmaxCrossEntropy(Loss): + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__(self, logits, targets, **kwargs): + return self.loss(logits, targets.argmax(dim=1)) + + +class BCE(Loss): + def __init__(self): + self.loss = nn.BCEWithLogitsLoss() + + def __call__(self, logits, targets, **kwargs): + targets = targets.squeeze(0) + return self.loss(logits, targets) + + +class NLGLoss(Loss): + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__(self, logits, text_label, **kwargs): + targets = text_label[text_label != -100] + return self.loss(logits, targets) + + +class MSE(Loss): + def __init__(self): + self.loss = nn.MSELoss() + + def __call__(self, logits, targets, **kwargs): + return self.loss(logits, targets) + + +class L1(Loss): + def __init__(self): + self.loss = nn.L1Loss() + + def __call__(self, logits, targets, **kwargs): + return self.loss(logits, targets) + + +class SmoothL1(Loss): + def __init__(self): + self.loss = nn.SmoothL1Loss() + + def __call__(self, logits, targets, **kwargs): + return self.loss(logits, targets) diff --git a/fairseq/examples/MMPT/mmpt/losses/nce.py b/fairseq/examples/MMPT/mmpt/losses/nce.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7be8d372e371bb0e0d6166f76e01d3466d2306 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/losses/nce.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +softmax-based NCE loss, used by this project. +""" + +import torch + +from torch import nn + +from .loss import Loss + + +class NCE(Loss): + def __init__(self): + # TODO (huxu): define temperature. + self.loss = nn.CrossEntropyLoss() + + def __call__(self, align_scores, **kargs): + # note: we reuse the same shape as cls head in BERT (batch_size, 2) + # but NCE only needs one logits. + # (so we drop all weights in the second neg logits.) + align_scores = align_scores[:, :1] + # duplicate negative examples + batch_size = align_scores.size(0) // 2 + pos_scores = align_scores[:batch_size] + neg_scores = align_scores[batch_size:].view(1, batch_size).repeat( + batch_size, 1) + scores = torch.cat([pos_scores, neg_scores], dim=1) + return self.loss( + scores, + torch.zeros( + (batch_size,), + dtype=torch.long, + device=align_scores.device), + ) + + +class T2VContraLoss(Loss): + """NCE for MM joint space, on softmax text2video matrix. + """ + def __init__(self): + # TODO (huxu): define temperature. 
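import torch
from torch import nn

# Minimal sketch (toy embeddings) of the in-batch softmax NCE used by
# T2VContraLoss: entry (i, i) of the text-video similarity matrix is the
# positive pair and every other column in row i is a negative.
batch, dim = 4, 8
pooled_video = torch.randn(batch, dim)
pooled_text = torch.randn(batch, dim)

logits = torch.mm(pooled_text, pooled_video.t())   # (batch, batch)
targets = torch.arange(batch)                      # the diagonal is the positive
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item())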
+ self.loss = nn.CrossEntropyLoss() + + def __call__(self, pooled_video, pooled_text, **kargs): + batch_size = pooled_video.size(0) + logits = torch.mm(pooled_text, pooled_video.transpose(1, 0)) + targets = torch.arange( + batch_size, + dtype=torch.long, + device=pooled_video.device) + return self.loss(logits, targets) + + +class V2TContraLoss(Loss): + """NCE for MM joint space, with softmax on video2text matrix.""" + + def __init__(self): + # TODO (huxu): define temperature. + self.loss = nn.CrossEntropyLoss() + + def __call__(self, pooled_video, pooled_text, **kargs): + batch_size = pooled_video.size(0) + logits = torch.mm(pooled_video, pooled_text.transpose(1, 0)) + targets = torch.arange( + batch_size, + dtype=torch.long, + device=pooled_video.device) + return self.loss(logits, targets) + + +class MMContraLoss(Loss): + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__(self, pooled_video, pooled_text, **kwargs): + logits_per_video = pooled_video @ pooled_text.t() + logits_per_text = pooled_text @ pooled_video.t() + + targets = torch.arange( + pooled_video.size(0), + dtype=torch.long, + device=pooled_video.device) + loss_video = self.loss(logits_per_video, targets) + loss_text = self.loss(logits_per_text, targets) + return loss_video + loss_text + + +class MTM(Loss): + """Combination of MFM and MLM.""" + + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__( + self, + video_logits, + text_logits, + video_label, + text_label, + **kwargs + ): + text_logits = torch.cat([ + text_logits, + torch.zeros( + (text_logits.size(0), 1), device=text_logits.device) + ], dim=1) + vt_logits = torch.cat([video_logits, text_logits], dim=0) + # loss for video. + video_label = torch.zeros( + (video_logits.size(0),), + dtype=torch.long, + device=video_logits.device + ) + + # loss for text. + text_label = text_label.reshape(-1) + labels_mask = text_label != -100 + selected_text_label = text_label[labels_mask] + + vt_label = torch.cat([video_label, selected_text_label], dim=0) + return self.loss(vt_logits, vt_label) + + +class MFMMLM(Loss): + """Combination of MFM and MLM.""" + + def __init__(self): + self.loss = nn.CrossEntropyLoss() + + def __call__( + self, + video_logits, + text_logits, + video_label, + text_label, + **kwargs + ): + # loss for video. + video_label = torch.zeros( + (video_logits.size(0),), + dtype=torch.long, + device=video_logits.device + ) + masked_frame_loss = self.loss(video_logits, video_label) + + # loss for text. + text_label = text_label.reshape(-1) + labels_mask = text_label != -100 + selected_text_label = text_label[labels_mask] + masked_lm_loss = self.loss(text_logits, selected_text_label) + return masked_frame_loss + masked_lm_loss diff --git a/fairseq/examples/MMPT/mmpt/models/__init__.py b/fairseq/examples/MMPT/mmpt/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..825250cd007f5e072b6c9d2376445b955a7aa71e --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/models/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
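import torch
from torch import nn

# Sketch (toy sizes) of the label handling in MFMMLM above: masked-frame logits
# are scored as a classification whose positive is always index 0, and
# masked-LM logits are scored only where text_label != -100.
ce = nn.CrossEntropyLoss()

video_logits = torch.randn(3, 5)                   # 3 masked frames, 5 candidates
video_label = torch.zeros(3, dtype=torch.long)     # positive candidate at index 0
masked_frame_loss = ce(video_logits, video_label)

text_label = torch.tensor([[-100, 7, -100, 2]])    # only 2 positions were masked
selected = text_label.reshape(-1)
selected = selected[selected != -100]              # -> tensor([7, 2])
text_logits = torch.randn(2, 30522)                # one row per masked position
masked_lm_loss = ce(text_logits, selected)
print((masked_frame_loss + masked_lm_loss).item())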
+from .mmfusion import * +from .transformermodel import * +from .mmfusionnlg import * + +try: + from .fairseqmmmodel import * +except ImportError: + pass + +try: + from .expmmfusion import * +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/models/fairseqmmmodel.py b/fairseq/examples/MMPT/mmpt/models/fairseqmmmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..b7dd643693dee8cfc20ca77d6cea798d07eaf15a --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/models/fairseqmmmodel.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.models import ( + BaseFairseqModel, + register_model, + register_model_architecture +) + + +@register_model("mmmodel") +class FairseqMMModel(BaseFairseqModel): + """a fairseq wrapper of model built by `task`.""" + + @classmethod + def build_model(cls, args, task): + return FairseqMMModel(task.mmtask.model) + + def __init__(self, mmmodel): + super().__init__() + self.mmmodel = mmmodel + + def forward(self, *args, **kwargs): + return self.mmmodel(*args, **kwargs) + + def upgrade_state_dict_named(self, state_dict, name): + + super().upgrade_state_dict_named(state_dict, name) + + keys_to_delete = [] + + for key in state_dict: + if key not in self.state_dict(): + keys_to_delete.append(key) + for key in keys_to_delete: + print("[INFO]", key, "not used anymore.") + del state_dict[key] + + # copy any newly defined parameters. + for key in self.state_dict(): + if key not in state_dict: + print("[INFO] adding", key) + state_dict[key] = self.state_dict()[key] + + +# a dummy arch, we config the model. +@register_model_architecture("mmmodel", "mmarch") +def mmarch(args): + pass diff --git a/fairseq/examples/MMPT/mmpt/models/mmfusion.py b/fairseq/examples/MMPT/mmpt/models/mmfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2509e26b67b467c3b18c76630881e66cf334a350 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/models/mmfusion.py @@ -0,0 +1,926 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Facebook, Inc. All Rights Reserved + + +import torch + +from torch import nn + +try: + from transformers import AutoConfig, AutoTokenizer +except ImportError: + pass + +from . import transformermodel + + +class MMPTModel(nn.Module): + """An e2e wrapper of inference model. + """ + @classmethod + def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"): + import os + from ..utils import recursive_config + from ..tasks import Task + config = recursive_config(config) + mmtask = Task.config_task(config) + checkpoint_path = os.path.join(config.eval.save_path, checkpoint) + mmtask.build_model(checkpoint=checkpoint_path) + # TODO(huxu): make the video encoder configurable. 
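# Hedged usage sketch for MMPTModel.from_pretrained above. The config path, the
# Aligner helper `_build_text_seq`, and the S3D weights under pretrained_models/
# are assumptions taken from the MMPT example setup, not defined in this patch.
import torch
from mmpt.models import MMPTModel

model, tokenizer, aligner = MMPTModel.from_pretrained(
    "projects/retri/videoclip/how2.yaml")
model.eval()

video_frames = torch.randn(1, 2, 30, 224, 224, 3)   # (bsz, sec, fps, H, W, C)
caps, cmasks = aligner._build_text_seq(
    tokenizer("a person is cooking", add_special_tokens=False)["input_ids"])
caps, cmasks = caps[None, :], cmasks[None, :]        # add the bsz=1 dim
with torch.no_grad():
    output = model(video_frames, caps, cmasks, return_score=True)
print(output["score"])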
+ from ..processors.models.s3dg import S3D + video_encoder = S3D('pretrained_models/s3d_dict.npy', 512) + video_encoder.load_state_dict( + torch.load('pretrained_models/s3d_howto100m.pth')) + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( + config.dataset.bert_name, use_fast=config.dataset.use_fast + ) + from ..processors import Aligner + aligner = Aligner(config.dataset) + return ( + MMPTModel(config, mmtask.model, video_encoder), + tokenizer, + aligner + ) + + def __init__(self, config, model, video_encoder, **kwargs): + super().__init__() + self.max_video_len = config.dataset.max_video_len + self.video_encoder = video_encoder + self.model = model + + def forward(self, video_frames, caps, cmasks, return_score=False): + bsz = video_frames.size(0) + assert bsz == 1, "only bsz=1 is supported now." + seq_len = video_frames.size(1) + video_frames = video_frames.view(-1, *video_frames.size()[2:]) + vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3)) + vfeats = vfeats['video_embedding'] + vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1)) + padding = torch.zeros( + bsz, self.max_video_len - seq_len, vfeats.size(-1)) + vfeats = torch.cat([vfeats, padding], dim=1) + vmasks = torch.cat([ + torch.ones((bsz, seq_len), dtype=torch.bool), + torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool) + ], + dim=1 + ) + output = self.model(caps, cmasks, vfeats, vmasks) + if return_score: + output = {"score": torch.bmm( + output["pooled_video"][:, None, :], + output["pooled_text"][:, :, None] + ).squeeze(-1).squeeze(-1)} + return output + + +class MMFusion(nn.Module): + """a MMPT wrapper class for MMBert style models. + TODO: move isolated mask to a subclass. + """ + def __init__(self, config, **kwargs): + super().__init__() + transformer_config = AutoConfig.from_pretrained( + config.dataset.bert_name) + self.hidden_size = transformer_config.hidden_size + self.is_train = False + if config.dataset.train_path is not None: + self.is_train = True + # 0 means no iso; 1-12 means iso up to that layer. + self.num_hidden_layers = transformer_config.num_hidden_layers + self.last_iso_layer = 0 + if config.dataset.num_iso_layer is not None: + self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1 + + if config.model.mm_encoder_cls is not None: + mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls) + model_config = AutoConfig.from_pretrained(config.dataset.bert_name) + model_config.max_video_len = config.dataset.max_video_len + # TODO: a general way to add parameter for a model. + model_config.use_seg_emb = config.model.use_seg_emb + self.mm_encoder = mm_encoder_cls.from_pretrained( + config.dataset.bert_name, config=model_config) + elif config.model.video_encoder_cls is not None\ + and config.model.text_encoder_cls is not None: + video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls) + model_config = AutoConfig.from_pretrained(config.dataset.bert_name) + model_config.max_video_len = config.dataset.max_video_len + # TODO: make each model a set of config class. + if hasattr(model_config, "num_layers"): + model_config.num_layers = config.model.num_hidden_video_layers + else: + model_config.num_hidden_layers = config.model.num_hidden_video_layers + self.video_encoder = video_encoder_cls.from_pretrained( + config.dataset.bert_name, config=model_config) + # exact same NLP model from Huggingface. 
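import torch

# Sketch (toy sizes) of the padding done in MMPTModel.forward above: S3D clip
# features of length seq_len are right-padded to max_video_len and a boolean
# vmask records which positions are real.
bsz, seq_len, max_video_len, dim = 1, 2, 32, 512
vfeats = torch.randn(bsz, seq_len, dim)

padding = torch.zeros(bsz, max_video_len - seq_len, dim)
vfeats = torch.cat([vfeats, padding], dim=1)
vmasks = torch.cat([
    torch.ones(bsz, seq_len, dtype=torch.bool),
    torch.zeros(bsz, max_video_len - seq_len, dtype=torch.bool),
], dim=1)
print(vfeats.shape, int(vmasks.sum()))   # -> torch.Size([1, 32, 512]) 2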
+ text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls) + self.text_encoder = text_encoder_cls.from_pretrained( + config.dataset.bert_name) + else: + raise ValueError("the encoder must be either MM or two backbones.") + + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + **kwargs + ): + raise NotImplementedError( + "Please derive MMFusion module." + ) + + def _mm_on_the_fly( + self, + cmasks, + vmasks, + attention_mask + ): + """helper function for mask, seg_ids and token_type_ids.""" + if attention_mask is None: + attention_mask = self._mm_attention_mask(cmasks, vmasks) + + """ + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + """ + token_type_ids = torch.cat( + [ + torch.zeros( + (vmasks.size(0), vmasks.size(1) + 2), + dtype=torch.long, + device=vmasks.device, + ), + torch.ones( + (cmasks.size(0), cmasks.size(1) - 2), + dtype=torch.long, + device=cmasks.device, + ), + ], + dim=1, + ) + return attention_mask, token_type_ids + + def _mm_attention_mask(self, cmasks, vmasks): + assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format( + str(cmasks.size()), + str(vmasks.size()), + str(cmasks.size(0)), + str(vmasks.size(0)), + ) + + mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1) + if self.last_iso_layer == 0: + # hard attention mask. + return mm_mask + else: + # a gpu iso mask; 0 : num_iso_layer is isolated; + # num_iso_layer: are MM-fused. + # make an iso layer + batch_size = cmasks.size(0) + iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks) + mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1) + iso_mm_masks = [] + # hard attention mask. + iso_mask = iso_mask[:, None, :, :].repeat( + 1, self.last_iso_layer, 1, 1) + iso_mm_masks.append(iso_mask) + if self.last_iso_layer < self.num_hidden_layers: + mm_mask = mm_mask[:, None, :, :].repeat( + 1, self.num_hidden_layers - self.last_iso_layer, 1, 1 + ) + iso_mm_masks.append(mm_mask) + iso_mm_masks = torch.cat(iso_mm_masks, dim=1) + return iso_mm_masks + + def _make_iso_mask(self, batch_size, cmasks, vmasks): + cls_self_mask = torch.cat( + [ + torch.ones( + (batch_size, 1), dtype=torch.bool, device=cmasks.device), + torch.zeros( + (batch_size, cmasks.size(1) + vmasks.size(1) - 1), + dtype=torch.bool, device=cmasks.device) + ], dim=1) + + iso_video_mask = torch.cat( + [ + # [CLS] is not used. + torch.zeros( + (batch_size, 1), dtype=torch.bool, device=cmasks.device + ), + vmasks, + # assume to be 1. + cmasks[:, 1:2], + # 2 means [CLS] + [SEP] + torch.zeros( + (batch_size, cmasks.size(1) - 2), + dtype=torch.bool, + device=cmasks.device, + ), + ], + dim=1, + ) + iso_text_mask = torch.cat( + [ + torch.zeros( + (batch_size, 2 + vmasks.size(1)), + dtype=torch.bool, + device=cmasks.device, + ), # [CLS] is not used. + cmasks[:, 2:], # assume to be 1. + ], + dim=1, + ) + cls_self_mask = cls_self_mask[:, None, :] + iso_video_mask = iso_video_mask[:, None, :].repeat( + 1, vmasks.size(1) + 1, 1) + iso_text_mask = iso_text_mask[:, None, :].repeat( + 1, cmasks.size(1) - 2, 1) + return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1) + + def _pooling_vt_layer( + self, + layered_sequence_output, + cmasks, + vmasks + ): + layer_idx = self.last_iso_layer \ + if self.last_iso_layer > 0 else self.num_hidden_layers + hidden_state = layered_sequence_output[layer_idx] + # also output pooled_video and pooled_text. + batch_size = cmasks.size(0) + # pool the modality. 
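import torch

# Sketch of the fused token layout assumed by _mm_on_the_fly above:
#   [CLS]  v_1 ... v_T  [SEP]  t_1 ... t_L (+ final [SEP])
# token_type_ids are 0 for the [CLS] + video + [SEP] block and 1 for the text part.
vmasks = torch.ones(1, 3, dtype=torch.bool)   # 3 video tokens
cmasks = torch.ones(1, 4, dtype=torch.bool)   # [CLS], [SEP], then 2 text tokens

token_type_ids = torch.cat([
    torch.zeros(1, vmasks.size(1) + 2, dtype=torch.long),   # [CLS] + video + [SEP]
    torch.ones(1, cmasks.size(1) - 2, dtype=torch.long),    # text side
], dim=1)
print(token_type_ids.tolist())   # -> [[0, 0, 0, 0, 0, 1, 1]]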
+ text_offset = vmasks.size(1) + 2 # [CLS] + [SEP] + # video tokens + [SEP] + video_outputs = hidden_state[:, 1:text_offset] + video_attention_mask = torch.cat( + [ + vmasks, + torch.ones( + (batch_size, 1), dtype=torch.bool, device=vmasks.device), + ], + dim=1, + ) + assert video_outputs.size(1) == video_attention_mask.size(1) + pooled_video = torch.sum( + video_outputs * video_attention_mask.unsqueeze(-1), dim=1 + ) / video_attention_mask.sum(1, keepdim=True) + # pooled_video = torch.mean(video_outputs[0], dim=1) + + # text tokens + [SEP] + text_attention_mask = cmasks[:, 2:] + text_outputs = hidden_state[:, text_offset:] + assert text_outputs.size(1) == text_attention_mask.size(1) + pooled_text = torch.sum( + text_outputs * text_attention_mask.unsqueeze(-1), dim=1 + ) / text_attention_mask.sum(1, keepdim=True) + return pooled_video, pooled_text + + +class MMFusionMFMMLM(MMFusion): + """forward function for MFM and MLM.""" + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask=None, + video_label=None, + text_label=None, + **kwargs + ): + output_hidden_states = False if self.is_train else True + + target_vfeats, non_masked_frame_mask = None, None + if video_label is not None: + target_vfeats = vfeats.masked_select( + video_label.unsqueeze(-1)).view( + -1, vfeats.size(-1) + ) + # mask video token. + vfeats[video_label] = 0.0 + non_masked_frame_mask = vmasks.clone() + non_masked_frame_mask[video_label] = False + + attention_mask, token_type_ids = self._mm_on_the_fly( + cmasks, vmasks, attention_mask) + + outputs = self.mm_encoder( + input_ids=caps, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + masked_frame_labels=video_label, + target_video_hidden_states=target_vfeats, + non_masked_frame_mask=non_masked_frame_mask, + masked_lm_labels=text_label, + output_hidden_states=output_hidden_states, + ) + + video_logits, text_logits = outputs[0], outputs[1] + + if self.is_train: # return earlier for training. + return { + "video_logits": video_logits, + "text_logits": text_logits, + } + + pooled_video, pooled_text = self._pooling_vt_layer( + outputs[2], cmasks, vmasks) + return {"pooled_video": pooled_video, "pooled_text": pooled_text} + + +class MMFusionMTM(MMFusionMFMMLM): + def __init__(self, config, **kwargs): + super().__init__(config) + """ + For reproducibility: + self.mm_encoder will be initialized then discarded. + """ + from .transformermodel import MMBertForMTM + model_config = AutoConfig.from_pretrained(config.dataset.bert_name) + model_config.max_video_len = config.dataset.max_video_len + model_config.use_seg_emb = config.model.use_seg_emb + self.mm_encoder = MMBertForMTM.from_pretrained( + config.dataset.bert_name, config=model_config) + + +class MMFusionShare(MMFusion): + """A retrival wrapper using mm_encoder as both video/text backbone. + TODO: move formally. 
+ """ + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask=None, + video_label=None, + text_label=None, + output_hidden_states=False, + **kwargs + ): + pooled_video = self.forward_video( + vfeats, + vmasks, + caps, + cmasks, + output_hidden_states + ) + + pooled_text = self.forward_text( + caps, + cmasks, + output_hidden_states + ) + + return {"pooled_video": pooled_video, "pooled_text": pooled_text} + + def forward_video( + self, + vfeats, + vmasks, + caps, + cmasks, + output_hidden_states=False, + **kwargs + ): + input_ids = caps[:, :2] + + attention_mask = torch.cat([ + cmasks[:, :1], + vmasks, + cmasks[:, 1:2] + ], dim=1) + + token_type_ids = torch.zeros( + (vmasks.size(0), vmasks.size(1) + 2), + dtype=torch.long, + device=vmasks.device) + + outputs = self.mm_encoder( + input_ids=input_ids, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=True + ) + video_outputs = outputs[0] + + if output_hidden_states: + return video_outputs + + batch_size = cmasks.size(0) + + video_attention_mask = torch.cat( + [ + torch.zeros( + (batch_size, 1), dtype=torch.bool, device=vmasks.device), + vmasks, + torch.ones( + (batch_size, 1), dtype=torch.bool, device=vmasks.device), + ], + dim=1, + ) + assert video_outputs.size(1) == video_attention_mask.size(1) + + video_attention_mask = video_attention_mask.type(video_outputs.dtype) \ + / video_attention_mask.sum(1, keepdim=True) + + pooled_video = torch.bmm( + video_outputs.transpose(2, 1), + video_attention_mask.unsqueeze(2) + ).squeeze(-1) + return pooled_video # video_outputs + + def forward_text( + self, + caps, + cmasks, + output_hidden_states=False, + **kwargs + ): + input_ids = torch.cat([ + caps[:, :1], caps[:, 2:], + ], dim=1) + + attention_mask = torch.cat([ + cmasks[:, :1], + cmasks[:, 2:] + ], dim=1) + + token_type_ids = torch.cat([ + torch.zeros( + (cmasks.size(0), 1), + dtype=torch.long, + device=cmasks.device), + torch.ones( + (cmasks.size(0), cmasks.size(1) - 2), + dtype=torch.long, + device=cmasks.device) + ], dim=1) + + outputs = self.mm_encoder( + input_ids=input_ids, + input_video_embeds=None, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=True + ) + text_outputs = outputs[0] + + if output_hidden_states: + return text_outputs + + batch_size = caps.size(0) + # text tokens + [SEP] + text_attention_mask = torch.cat([ + torch.zeros( + (batch_size, 1), dtype=torch.bool, device=cmasks.device), + cmasks[:, 2:] + ], dim=1) + + assert text_outputs.size(1) == text_attention_mask.size(1) + + text_attention_mask = text_attention_mask.type(text_outputs.dtype) \ + / text_attention_mask.sum(1, keepdim=True) + + pooled_text = torch.bmm( + text_outputs.transpose(2, 1), + text_attention_mask.unsqueeze(2) + ).squeeze(-1) + return pooled_text # text_outputs + + +class MMFusionSeparate(MMFusionShare): + def forward_video( + self, + vfeats, + vmasks, + caps, + cmasks, + output_hidden_states=False, + **kwargs + ): + input_ids = caps[:, :2] + + attention_mask = torch.cat([ + cmasks[:, :1], + vmasks, + cmasks[:, 1:2] + ], dim=1) + + token_type_ids = torch.zeros( + (vmasks.size(0), vmasks.size(1) + 2), + dtype=torch.long, + device=vmasks.device) + + outputs = self.video_encoder( + input_ids=input_ids, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=True + ) + video_outputs = outputs[0] + + if output_hidden_states: + return video_outputs + + batch_size = 
cmasks.size(0) + + video_attention_mask = torch.cat( + [ + torch.zeros( + (batch_size, 1), dtype=torch.bool, device=vmasks.device), + vmasks, + torch.ones( + (batch_size, 1), dtype=torch.bool, device=vmasks.device), + ], + dim=1, + ) + assert video_outputs.size(1) == video_attention_mask.size(1) + + video_attention_mask = video_attention_mask.type(video_outputs.dtype) \ + / video_attention_mask.sum(1, keepdim=True) + + pooled_video = torch.bmm( + video_outputs.transpose(2, 1), + video_attention_mask.unsqueeze(2) + ).squeeze(-1) + return pooled_video # video_outputs + + def forward_text( + self, + caps, + cmasks, + output_hidden_states=False, + **kwargs + ): + input_ids = torch.cat([ + caps[:, :1], caps[:, 2:], + ], dim=1) + + attention_mask = torch.cat([ + cmasks[:, :1], + cmasks[:, 2:] + ], dim=1) + # different from sharing, we use all-0 type. + token_type_ids = torch.zeros( + (cmasks.size(0), cmasks.size(1) - 1), + dtype=torch.long, + device=cmasks.device) + + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=True + ) + text_outputs = outputs[0] + + if output_hidden_states: + return text_outputs + + batch_size = caps.size(0) + # text tokens + [SEP] + text_attention_mask = torch.cat([ + torch.zeros( + (batch_size, 1), dtype=torch.bool, device=cmasks.device), + cmasks[:, 2:] + ], dim=1) + + assert text_outputs.size(1) == text_attention_mask.size(1) + + text_attention_mask = text_attention_mask.type(text_outputs.dtype) \ + / text_attention_mask.sum(1, keepdim=True) + + pooled_text = torch.bmm( + text_outputs.transpose(2, 1), + text_attention_mask.unsqueeze(2) + ).squeeze(-1) + return pooled_text # text_outputs + + +class MMFusionJoint(MMFusion): + """fine-tuning wrapper for retrival task.""" + + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask=None, + video_label=None, + text_label=None, + **kwargs + ): + # TODO (huxu): other ways to do negative examples; move the following + # into your criterion forward. + output_hidden_states = True + + attention_mask, token_type_ids = self._mm_on_the_fly( + cmasks, vmasks, attention_mask) + + separate_forward_split = ( + None if self.is_train else vmasks.size(1) + 2 + ) # [CLS] + [SEP] + + outputs = self.mm_encoder( + input_ids=caps, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=output_hidden_states, + separate_forward_split=separate_forward_split, + ) + + pooled_video, pooled_text = self._pooling_vt_layer( + outputs[2], cmasks, vmasks) + return {"pooled_video": pooled_video, "pooled_text": pooled_text} + + +class MMFusionActionSegmentation(MMFusion): + """Fine-tuning wrapper for action segmentation. + TODO: rename this for VLM. + """ + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask=None, + **kwargs + ): + # ActionLocalization assume of batch_size=1, squeeze it. + caps = caps.view(-1, caps.size(-1)) + cmasks = cmasks.view(-1, cmasks.size(-1)) + vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3)) + vmasks = vmasks.view(-1, vmasks.size(-1)) + + # this may not cover all shapes of attention_mask. + attention_mask = attention_mask.view( + -1, attention_mask.size(2), attention_mask.size(3)) \ + if attention_mask is not None else None + + # TODO (huxu): other ways to do negative examples; move the following + # into your criterion forward. + output_hidden_states = True + + # video forwarding, text is dummy; never use attention_mask. 
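import torch

# Sketch (toy tensors) of the mask-weighted mean pooling used by forward_video /
# forward_text above: the boolean attention mask is normalized so that a single
# bmm averages over the valid positions only.
hidden = torch.randn(2, 5, 8)                        # (batch, seq, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)

weights = mask.type(hidden.dtype) / mask.sum(1, keepdim=True)   # rows sum to 1
pooled = torch.bmm(hidden.transpose(2, 1), weights.unsqueeze(2)).squeeze(-1)

# same thing written as an explicit masked mean:
ref = (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
print(pooled.shape, bool(torch.allclose(pooled, ref)))   # -> torch.Size([2, 8]) True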
+ attention_mask, token_type_ids = self._mm_on_the_fly( + cmasks, vmasks, attention_mask) + + logits = self.mm_encoder( + input_ids=caps, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=output_hidden_states, + ) + return {"logits": logits[0][:, 1:vmasks.size(1)+1]} + + +class MMFusionActionLocalization(MMFusion): + """fine-tuning model for retrival task.""" + + def __init__(self, config, **kwargs): + super().__init__(config) + tokenizer = AutoTokenizer.from_pretrained( + config.dataset.bert_name) + self.cls_token_id = tokenizer.cls_token_id + self.sep_token_id = tokenizer.sep_token_id + self.pad_token_id = tokenizer.pad_token_id + + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask=None, + **kwargs + ): + # ActionLocalization assume of batch_size=1, squeeze it. + caps = caps.squeeze(0) + cmasks = cmasks.squeeze(0) + vfeats = vfeats.squeeze(0) + vmasks = vmasks.squeeze(0) + attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None + + # TODO (huxu): other ways to do negative examples; move the following + # into your criterion forward. + output_hidden_states = True + + # a len1 dummy video token. + dummy_vfeats = torch.zeros( + (caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype) + dummy_vmasks = torch.ones( + (caps.size(0), 1), dtype=torch.bool, + device=vfeats.device) + + dummy_caps = torch.LongTensor( + [[self.cls_token_id, self.sep_token_id, + self.pad_token_id, self.sep_token_id]], + ).to(caps.device).repeat(vfeats.size(0), 1) + dummy_cmasks = torch.BoolTensor( + [[0, 1, 0, 1]] # pad are valid for attention. + ).to(caps.device).repeat(vfeats.size(0), 1) + + # video forwarding, text is dummy; never use attention_mask. + attention_mask, token_type_ids = self._mm_on_the_fly( + dummy_cmasks, vmasks, None) + + outputs = self.mm_encoder( + input_ids=dummy_caps, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=output_hidden_states, + ) + + layer_idx = self.last_iso_layer \ + if self.last_iso_layer > 0 else self.num_hidden_layers + + video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select( + vmasks.unsqueeze(-1) + ).view(-1, self.hidden_size) + + # text forwarding, video is dummy + attention_mask, token_type_ids = self._mm_on_the_fly( + cmasks, dummy_vmasks, None) + + outputs = self.mm_encoder( + input_ids=caps, + input_video_embeds=dummy_vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + output_hidden_states=output_hidden_states, + ) + + _, pooled_text = self._pooling_vt_layer( + outputs[2], cmasks, dummy_vmasks) + # this line is not right. + logits = torch.mm(video_seq, pooled_text.transpose(1, 0)) + return {"logits": logits} + + +# --------------- MMFusionSeparate for end tasks --------------- + +class MMFusionSeparateActionSegmentation(MMFusionSeparate): + """Fine-tuning wrapper for action segmentation.""" + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask=None, + **kwargs + ): + # ActionLocalization assume of batch_size=1, squeeze it. 
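+        # despite the note above, the extra leading dimension here is folded into
+        # the batch dimension via view() rather than squeezed away.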
+ caps = caps.view(-1, caps.size(-1)) + cmasks = cmasks.view(-1, cmasks.size(-1)) + vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3)) + vmasks = vmasks.view(-1, vmasks.size(-1)) + logits = self.forward_video( + vfeats, + vmasks, + caps, + cmasks, + output_hidden_states=True + ) + return {"logits": logits[:, 1:vmasks.size(1)+1]} + + +class MMFusionSeparateActionLocalization(MMFusionSeparate): + def __init__(self, config, **kwargs): + super().__init__(config) + tokenizer = AutoTokenizer.from_pretrained( + config.dataset.bert_name) + self.cls_token_id = tokenizer.cls_token_id + self.sep_token_id = tokenizer.sep_token_id + self.pad_token_id = tokenizer.pad_token_id + + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + **kwargs + ): + # ActionLocalization assume of batch_size=1, squeeze it. + caps = caps.squeeze(0) + cmasks = cmasks.squeeze(0) + vfeats = vfeats.squeeze(0) + vmasks = vmasks.squeeze(0) + + # TODO (huxu): other ways to do negative examples; move the following + # into your criterion forward. + dummy_caps = torch.LongTensor( + [[self.cls_token_id, self.sep_token_id, + self.pad_token_id, self.sep_token_id]], + ).to(caps.device).repeat(vfeats.size(0), 1) + dummy_cmasks = torch.BoolTensor( + [[0, 1, 0, 1]] # pad are valid for attention. + ).to(caps.device).repeat(vfeats.size(0), 1) + + outputs = self.forward_video( + vfeats, + vmasks, + dummy_caps, + dummy_cmasks, + output_hidden_states=True + ) + + video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select( + vmasks.unsqueeze(-1) + ).view(-1, self.hidden_size) + + pooled_text = self.forward_text( + caps, + cmasks, + output_hidden_states=False + ) + + # this line is not right. + logits = torch.mm(video_seq, pooled_text.transpose(1, 0)) + return {"logits": logits} + + +class MMFusionShareActionLocalization(MMFusionShare): + def __init__(self, config, **kwargs): + super().__init__(config) + tokenizer = AutoTokenizer.from_pretrained( + config.dataset.bert_name) + self.cls_token_id = tokenizer.cls_token_id + self.sep_token_id = tokenizer.sep_token_id + self.pad_token_id = tokenizer.pad_token_id + + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + **kwargs + ): + # ActionLocalization assume of batch_size=1, squeeze it. + caps = caps.squeeze(0) + cmasks = cmasks.squeeze(0) + vfeats = vfeats.squeeze(0) + vmasks = vmasks.squeeze(0) + + # TODO (huxu): other ways to do negative examples; move the following + # into your criterion forward. + dummy_caps = torch.LongTensor( + [[self.cls_token_id, self.sep_token_id, + self.pad_token_id, self.sep_token_id]], + ).to(caps.device).repeat(vfeats.size(0), 1) + dummy_cmasks = torch.BoolTensor( + [[0, 1, 0, 1]] # pad are valid for attention. + ).to(caps.device).repeat(vfeats.size(0), 1) + + outputs = self.forward_video( + vfeats, + vmasks, + dummy_caps, + dummy_cmasks, + output_hidden_states=True + ) + + video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select( + vmasks.unsqueeze(-1) + ).view(-1, self.hidden_size) + + pooled_text = self.forward_text( + caps, + cmasks, + output_hidden_states=False + ) + + # this line is not right. 
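+        # computes a (num_unmasked_video_frames, num_text_segments) similarity
+        # matrix between per-frame video states and pooled text embeddings.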
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0)) + return {"logits": logits} diff --git a/fairseq/examples/MMPT/mmpt/models/mmfusionnlg.py b/fairseq/examples/MMPT/mmpt/models/mmfusionnlg.py new file mode 100644 index 0000000000000000000000000000000000000000..9207e77dab3025d7a26efcce0795183de1d34fc7 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/models/mmfusionnlg.py @@ -0,0 +1,999 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Facebook, Inc. All Rights Reserved + + +import torch + +from torch.nn import functional as F + +from typing import Optional, Iterable + +try: + from transformers import BertPreTrainedModel + from transformers.modeling_bert import BertOnlyMLMHead + + from transformers.file_utils import ModelOutput + from transformers.modeling_outputs import CausalLMOutput + from transformers.generation_utils import ( + BeamHypotheses, + top_k_top_p_filtering + ) +except ImportError: + pass + +from .mmfusion import MMFusion +from .transformermodel import MMBertModel +from ..modules import VideoTokenMLP + + +class MMFusionNLG(MMFusion): + def __init__(self, config, **kwargs): + super().__init__(config) + if config.model.max_decode_length is not None: + self.max_length = min( + config.model.max_decode_length, + config.dataset.max_len - config.dataset.max_video_len - 3 + ) + else: + self.max_length = \ + config.dataset.max_len - config.dataset.max_video_len - 3 + self.gen_param = config.gen_param if config.gen_param is not None \ + else {} + + def forward( + self, + caps, + cmasks, + vfeats, + vmasks, + attention_mask, + video_label=None, + text_label=None, + **kwargs + ): + """use pre-trained LM header for generation.""" + attention_mask, token_type_ids = self._mm_on_the_fly( + cmasks, vmasks, attention_mask) + + outputs = self.mm_encoder( + input_ids=caps, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + masked_lm_labels=text_label, + ) + return {"logits": outputs[0]} + + @torch.no_grad() + def generate( + self, + caps, cmasks, vfeats, vmasks, + attention_mask=None, + bos_token_id=None, + eos_token_id=None, + **kwargs + ): + # a simplified interface from + # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html#GenerationMixin.generate + + # caps now only have + # [CLS], [SEP] (for video) and [CLS] (as bos_token) + assert caps.size(1) == 3 + + attention_mask, token_type_ids = self._mm_on_the_fly( + cmasks, vmasks, attention_mask) + + output = self.mm_encoder.generate( + input_ids=caps, + input_video_embeds=vfeats, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + max_length=self.max_length, + **self.gen_param + ) + return output + + +class MMBertForNLG(BertPreTrainedModel): + def __init__(self, config): + 
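+        # multimodal BERT encoder plus a video-feature MLP, with the pretrained
+        # MLM head reused as the generation head.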
super().__init__(config) + self.bert = MMBertModel(config) + self.videomlp = VideoTokenMLP(config) + # we do not use `BertGenerationOnlyLMHead` + # because we can reuse pretraining. + self.cls = BertOnlyMLMHead(config) + self.hidden_size = config.hidden_size + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def forward( + self, + input_ids=None, + input_video_embeds=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + masked_lm_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # similar to MMBertForMFMMLM without MFM. + video_tokens = self.videomlp(input_video_embeds) + outputs = self.bert( + input_ids, + video_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + prediction_scores = None + if masked_lm_labels is not None: + text_offset = input_video_embeds.size(1) + 1 # [CLS] + # recover caps format: [CLS] [SEP] text [SEP] + text_sequence_output = torch.cat( + [sequence_output[:, :1], sequence_output[:, text_offset:]], + dim=1 + ) + + # only compute select tokens to training to speed up. + hidden_size = text_sequence_output.size(-1) + # masked_lm_labels = masked_lm_labels.reshape(-1) + labels_mask = masked_lm_labels != -100 + + selected_text_output = text_sequence_output.masked_select( + labels_mask.unsqueeze(-1) + ).view(-1, hidden_size) + prediction_scores = self.cls(selected_text_output) + + if not return_dict: + output = ( + prediction_scores, + ) + outputs[2:] + return output + + # for generation. + text_offset = input_video_embeds.size(1) + 2 # [CLS] + text_sequence_output = sequence_output[:, text_offset:] + prediction_scores = self.cls(text_sequence_output) + return CausalLMOutput( + loss=None, + logits=prediction_scores, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + input_video_embeds, + attention_mask=None, + token_type_ids=None, + **model_kwargs + ): + # must return a dictionary. 
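+        # trim the attention mask (2-D, 3-D, or 4-D) and token_type_ids to the
+        # current video + generated-text length before each decoding step.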
+ seq_len = input_ids.size(1) + input_video_embeds.size(1) + if attention_mask is not None: + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, :seq_len, :seq_len] + elif len(attention_mask.size()) == 3: + attention_mask = attention_mask[:, :seq_len, :seq_len] + else: + attention_mask = attention_mask[:, :seq_len] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, :seq_len] + + return { + "input_ids": input_ids, + "input_video_embeds": input_video_embeds, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + **model_kwargs + ) -> torch.LongTensor: + r""" + Generates sequences for models with a language modeling head. The method currently supports greedy decoding, + beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + Adapted in part from `Facebook's XLM beam search code + `__. + Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the + attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values + indicated are the default values of those config. + Most of these parameters are explained in more detail in `this blog post + `__. + Parameters: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes + it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only + decoder_start_token_id is passed as the first token to the decoder. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + min_length (:obj:`int`, `optional`, defaults to 10): + The minimum length of the sequence to be generated. + do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use sampling ; use greedy decoding otherwise. + early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beams (:obj:`int`, `optional`, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (:obj:`float`, `optional`, defaults tp 1.0): + The value used to module the next token probabilities. 
+ top_k (:obj:`int`, `optional`, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (:obj:`float`, `optional`, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or + higher are kept for generation. + repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See `this paper + `__ for more details. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + bos_token_id (:obj:`int`, `optional`): + The id of the `beginning-of-sequence` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. + Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in + order to encourage the model to produce longer sequences. + no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + bad_words_ids(:obj:`List[int]`, `optional`): + List of token ids that are not allowed to be generated. In order to get the tokens of the words that + should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. + num_return_sequences(:obj:`int`, `optional`, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for + tokens that are not masked, and 0 for masked tokens. + If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token. + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_start_token_id (:obj:`int`, `optional`): + If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. + use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + model_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + Examples:: + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + outputs = model.generate(max_length=40) # do greedy decoding + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. 
+ input_context = 'The dog' + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. + input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. + input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated + """ + + # We cannot generate if the model does not have a LM head + if self.get_output_embeddings() is None: + raise AttributeError( + "You tried to generate sequences with a model that does not have a LM Head." + "Please use another model class (e.g. 
`OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" + ) + + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + do_sample = do_sample if do_sample is not None else self.config.do_sample + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + use_cache = use_cache if use_cache is not None else self.config.use_cache + num_beams = num_beams if num_beams is not None else self.config.num_beams + temperature = temperature if temperature is not None else self.config.temperature + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] # overriden by the input batch_size + else: + batch_size = 1 + + assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." + assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." + assert isinstance(use_cache, bool), "`use_cache` should be a boolean." + assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." + assert temperature > 0, "`temperature` should be strictly positive." + assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." + assert input_ids is not None or ( + isinstance(bos_token_id, int) and bos_token_id >= 0 + ), "If input_ids is not defined, `bos_token_id` should be a positive integer." + assert pad_token_id is None or ( + isinstance(pad_token_id, int) and (pad_token_id >= 0) + ), "`pad_token_id` should be a positive integer." + assert (eos_token_id is None) or ( + isinstance(eos_token_id, int) and (eos_token_id >= 0) + ), "`eos_token_id` should be a positive integer." + assert length_penalty > 0, "`length_penalty` should be strictly positive." + assert ( + isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 + ), "`no_repeat_ngram_size` should be a positive integer." 
+ assert ( + isinstance(num_return_sequences, int) and num_return_sequences > 0 + ), "`num_return_sequences` should be a strictly positive integer." + assert ( + bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" + + if input_ids is None: + assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( + "you should either supply a context to complete as `input_ids` input " + "or a `bos_token_id` (integer >= 0) as a first token to start the generation." + ) + input_ids = torch.full( + (batch_size, 1), + bos_token_id, + dtype=torch.long, + device=next(self.parameters()).device, + ) + else: + assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." + + # not allow to duplicate outputs when greedy decoding + if do_sample is False: + if num_beams == 1: + # no_beam_search greedy generation conditions + assert ( + num_return_sequences == 1 + ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" + + else: + # beam_search greedy generation conditions + assert ( + num_beams >= num_return_sequences + ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" + + # create attention mask if necessary + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 + if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): + attention_mask = input_ids.ne(pad_token_id).long() + elif attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + # set pad_token_id to eos_token_id if not set. 
Important that this is done after + # attention_mask is created + if pad_token_id is None and eos_token_id is not None: + print( + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) + ) + pad_token_id = eos_token_id + + # vocab size + if hasattr(self.config, "vocab_size"): + vocab_size = self.config.vocab_size + elif ( + self.config.is_encoder_decoder + and hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "vocab_size") + ): + vocab_size = self.config.decoder.vocab_size + else: + raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined") + + # set effective batch size and effective batch multiplier according to do_sample + if do_sample: + effective_batch_size = batch_size * num_return_sequences + effective_batch_mult = num_return_sequences + else: + effective_batch_size = batch_size + effective_batch_mult = 1 + + if self.config.is_encoder_decoder: + if decoder_start_token_id is None: + # see if BOS token can be used for decoder_start_token_id + if bos_token_id is not None: + decoder_start_token_id = bos_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None + ): + decoder_start_token_id = self.config.decoder.bos_token_id + else: + raise ValueError( + "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + ) + + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True) + + # Expand input ids if num_beams > 1 or num_return_sequences > 1 + if num_return_sequences > 1 or num_beams > 1: + # TODO: make this a call-back function. + # input_ids=caps, + # input_video_embeds=vfeats, + # attention_mask=attention_mask, + # token_type_ids=token_type_ids, + input_video_embeds = model_kwargs.pop("input_video_embeds", None) + token_type_ids = model_kwargs.pop("token_type_ids", None) + + input_ids_len = input_ids.shape[-1] + input_ids = input_ids.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, input_ids_len) + + input_video_embeds_len, input_video_embeds_hidden = input_video_embeds.size(1), input_video_embeds.size(2) + input_video_embeds = input_video_embeds.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, input_video_embeds_len, input_video_embeds_hidden) + + attention_mask_from_len, attention_mask_to_len = attention_mask.size(1), attention_mask.size(2) + attention_mask = attention_mask.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, attention_mask_from_len, attention_mask_to_len + ) + + token_type_ids_len = token_type_ids.size(1) + token_type_ids = token_type_ids.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, token_type_ids_len + ) + + # contiguous ... 
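+            # flatten the expanded (batch, beams/return_sequences, ...) tensors to
+            # a single leading dimension of effective_batch_size * num_beams so
+            # they stay aligned with input_ids.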
+ input_ids = input_ids.contiguous().view( + effective_batch_size * num_beams, input_ids_len + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + + input_video_embeds = input_video_embeds.contiguous().view( + effective_batch_size * num_beams, input_video_embeds_len, input_video_embeds_hidden) + + attention_mask = attention_mask.contiguous().view( + effective_batch_size * num_beams, attention_mask_from_len, attention_mask_to_len + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + + token_type_ids = token_type_ids.contiguous().view( + effective_batch_size * num_beams, token_type_ids_len + ) + + model_kwargs["input_video_embeds"] = input_video_embeds + model_kwargs["token_type_ids"] = token_type_ids + + if self.config.is_encoder_decoder: + device = next(self.parameters()).device + if decoder_input_ids is not None: + # give initial decoder input ids + input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device) + else: + # create empty decoder input_ids + input_ids = torch.full( + (effective_batch_size * num_beams, 1), + decoder_start_token_id, + dtype=torch.long, + device=device, + ) + cur_len = input_ids.shape[-1] + + assert ( + batch_size == encoder_outputs.last_hidden_state.shape[0] + ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " + + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) + expanded_batch_idxs = ( + torch.arange(batch_size) + .view(-1, 1) + .repeat(1, num_beams * effective_batch_mult) + .view(-1) + .to(input_ids.device) + ) + + # expand encoder_outputs + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_batch_idxs + ) + + # save encoder_outputs in `model_kwargs` + model_kwargs["encoder_outputs"] = encoder_outputs + + else: + cur_len = input_ids.shape[-1] + + assert ( + cur_len < max_length + ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. 
Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + + if num_beams > 1: + output = self._generate_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + attention_mask=attention_mask, + use_cache=use_cache, + model_kwargs=model_kwargs, + ) + else: + output = self._generate_no_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + attention_mask=attention_mask, + use_cache=use_cache, + model_kwargs=model_kwargs, + ) + + return output + + def _generate_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + early_stopping, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + pad_token_id, + eos_token_id, + batch_size, + num_return_sequences, + length_penalty, + num_beams, + vocab_size, + attention_mask, + use_cache, + model_kwargs, + ): + """Generate sequences for each example with beam search.""" + + # generated hypotheses + generated_hyps = [ + BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) + for _ in range(batch_size) + ] + + # scores for each sentence in the beam + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + + # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times + if do_sample is False: + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) + + # cache compute states + past = None + + # done sentences + done = [False for _ in range(batch_size)] + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs + ) + outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) + next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) + + # if model has past, then set the past variable to speed up decoding + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems + + if self.config.is_encoder_decoder and do_sample is False: + # TODO (PVP) still a bit hacky here - there might be a better solution + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) + + scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) + + scores = self.postprocess_next_token_scores( + scores=scores, + input_ids=input_ids, + no_repeat_ngram_size=no_repeat_ngram_size, + 
bad_words_ids=bad_words_ids, + cur_len=cur_len, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + batch_size=batch_size, + num_beams=num_beams, + ) + + assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( + scores.shape, (batch_size * num_beams, vocab_size) + ) + + if do_sample: + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + # Temperature + if temperature != 1.0: + _scores = _scores / temperature + # Top-p/top-k filtering + _scores = top_k_top_p_filtering( + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) # (batch_size * num_beams, vocab_size) + # re-organize to group the beam together to sample from all beam_idxs + _scores = _scores.contiguous().view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) + probs = F.softmax(_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) + # Compute next scores + next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) + # sort the sampled vector to make sure that the first num_beams samples are the best + next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) + + else: + next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + next_scores = next_scores.view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + + assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) + + # next batch beam content + next_batch_beam = [] + + # for each sentence + for batch_idx in range(batch_size): + + # if we are done with this sentence, add a pad token + if done[batch_idx]: + assert ( + len(generated_hyps[batch_idx]) >= num_beams + ), "Batch can only be done if at least {} beams have been generated".format(num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch + continue + + # next sentence beam content, this will get added to next_batch_beam + next_sent_beam = [] + + # next tokens for this sentence + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx]) + ): + # get beam and token IDs + beam_id = beam_token_id // vocab_size + token_id = beam_token_id % vocab_size + + effective_beam_id = batch_idx * num_beams + beam_id + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (token_id.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams + if is_beam_token_worse_than_top_num_beams: + continue + generated_hyps[batch_idx].add( + input_ids[effective_beam_id].clone(), + beam_token_score.item(), + ) + else: + # add next predicted token since it is not 
eos_token + next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) + + # once the beam for next step is full, don't add more tokens to it. + if len(next_sent_beam) == num_beams: + break + + # Check if we are done so that we can save a pad step if all(done) + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + # update next beam content + assert len(next_sent_beam) == num_beams, "Beam should always be full" + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" + + # stop when we are done with each sentence + if all(done): + break + + # sanity check / prepare next batch + assert len(next_batch_beam) == batch_size * num_beams + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) + beam_idx = input_ids.new([x[2] for x in next_batch_beam]) + + # re-order batch and update current length + input_ids = input_ids[beam_idx, :] + input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) + cur_len = cur_len + 1 + + # re-order internal states + if past is not None: + past = self._reorder_cache(past, beam_idx) + + # extend attention_mask for new generated input if only decoder + # (huxu): move out since we trim attention_mask by ourselves. + # if self.config.is_encoder_decoder is False: + # attention_mask = torch.cat( + # [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + # ) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx in range(batch_size): + if done[batch_idx]: + continue + + # test that beam scores match previously calculated scores if not eos and batch_idx not done + if eos_token_id is not None and all( + (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] + ): + assert torch.all( + next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] + ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( + next_scores[:, :num_beams][batch_idx], + beam_scores.view(batch_size, num_beams)[batch_idx], + ) + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(num_beams): + effective_beam_id = batch_idx * num_beams + beam_id + final_score = beam_scores[effective_beam_id].item() + final_tokens = input_ids[effective_beam_id] + generated_hyps[batch_idx].add(final_tokens, final_score) + + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch + output_batch_size = batch_size if do_sample else batch_size * num_return_sequences + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences + + # select the best hypotheses + sent_lengths = input_ids.new(output_batch_size) + best = [] + + # retrieve best hypotheses + for i, hypotheses in enumerate(generated_hyps): + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) + for j in range(output_num_return_sequences_per_batch): + effective_batch_idx = output_num_return_sequences_per_batch * i + j + best_hyp = sorted_hyps.pop()[1] + sent_lengths[effective_batch_idx] = len(best_hyp) + best.append(best_hyp) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item() + 1, max_length) + decoded = input_ids.new(output_batch_size, sent_max_len) + # shorter batches are padded if 
needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < max_length: + decoded[i, sent_lengths[i]] = eos_token_id + + return decoded + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + pad_token_id, + eos_token_id, + batch_size, + attention_mask, + use_cache, + model_kwargs, + ): + """Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + # length of generated sentences / unfinished sentences + unfinished_sents = input_ids.new(batch_size).fill_(1) + sent_lengths = input_ids.new(batch_size).fill_(max_length) + + past = None + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs + ) + + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + scores = self.postprocess_next_token_scores( + scores=next_token_logits, + input_ids=input_ids, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + cur_len=cur_len, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + batch_size=batch_size, + num_beams=1, + ) + + # if model has past, then set the past variable to speed up decoding + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + scores = scores / temperature + # Top-p/top-k filtering + next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) + # Sample + probs = F.softmax(next_token_logscores, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # print(next_token_logits[0,next_token[0]], next_token_logits[0,eos_token_id]) + + # update generations and finished sentences + if eos_token_id is not None: + # pad finished sentences if eos_token_id exist + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) + else: + tokens_to_add = next_token + + # add token and increase length by one + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 + + if eos_token_id is not None: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() + sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) + # unfinished_sents is set to zero if eos in sentence + unfinished_sents.mul_((~eos_in_sents).long()) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sents.max() == 0: + break + + + # extend attention_mask for new generated input if only decoder + # if self.config.is_encoder_decoder is False: + # attention_mask = torch.cat( + # 
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + # ) + + return input_ids diff --git a/fairseq/examples/MMPT/mmpt/processors/__init__.py b/fairseq/examples/MMPT/mmpt/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..434d1d92b95846b0cd87625834c5d9bed279a44e --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .processor import * + +from .how2processor import * +from .how2retriprocessor import * + +from .dsprocessor import * + +try: + from .rawvideoprocessor import * + from .codecprocessor import * + from .webvidprocessor import * + from .expprocessor import * + from .exphow2processor import * + from .exphow2retriprocessor import * + from .expcodecprocessor import * + from .expfeatureencoder import * + from .expdsprocessor import * +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/processors/dsprocessor.py b/fairseq/examples/MMPT/mmpt/processors/dsprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..ecebf0eea5c57b7846a2bd46ddffbdbf55bd83ab --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/dsprocessor.py @@ -0,0 +1,848 @@ +# Copyright (c) Facebook, Inc. All Rights Reserved + +""" +Processors for all downstream (ds) tasks. +""" + +import json +import os +import pickle +import random +import math +import numpy as np +import torch + +from collections import defaultdict + +from .processor import ( + MetaProcessor, + VideoProcessor, + TextProcessor, + Aligner, + MMAttentionMask2DProcessor, +) + +from .how2processor import TextGenerationProcessor + + +# ------------- A General Aligner for all downstream tasks----------------- + + +class DSAligner(Aligner): + """ + Downstream (DS) aligner shared by all datasets. + """ + + def __call__(self, video_id, video_feature, text_feature, wps=0.7): + # random sample a starting sec for video. + video_start = 0 + video_end = min(len(video_feature), self.max_video_len) + # the whole sequence is a single clip. + video_clips = {"start": [video_start], "end": [video_end]} + + text_feature = { + "cap": [text_feature], + "start": [video_start], + "end": [len(text_feature) / wps], + } + text_clip_indexs = [0] + + vfeats, vmasks = self._build_video_seq( + video_feature, video_clips + ) + caps, cmasks = self._build_text_seq( + text_feature, text_clip_indexs + ) + + return { + "caps": caps, + "cmasks": cmasks, + "vfeats": vfeats, + "vmasks": vmasks, + "video_id": video_id, + } + + +class NLGTextProcessor(TextProcessor): + """ + Also return the original text as ref. 
+ """ + def __call__(self, text_id): + return super().__call__(text_id), text_id + + +class DSNLGAligner(DSAligner): + """extend with the capability of 2d mask for generation.""" + def __init__(self, config): + super().__init__(config) + self.attnmasker = MMAttentionMask2DProcessor() + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( + self.bert_name, use_fast=self.use_fast, + bos_token="[CLS]", eos_token="[SEP]" + ) + self.tokenizer = tokenizer + self.bos_token_id = tokenizer.bos_token_id + self.eos_token_id = tokenizer.eos_token_id + self.textgen = TextGenerationProcessor(tokenizer) + + def __call__(self, video_id, video_feature, text_feature): + output = super().__call__(video_id, video_feature, text_feature[0]) + if self.split == "test": + # output.update({"ref": text_feature[1]}) + output.update({"ref": self.tokenizer.decode( + output["caps"], skip_special_tokens=True)}) + text_label = output["caps"] + cmasks = torch.BoolTensor([1] * text_label.size(0)) + caps = torch.LongTensor([ + self.cls_token_id, + self.sep_token_id, + self.bos_token_id]) + else: + caps, text_label = self.textgen(output["caps"]) + cmasks = output["cmasks"] + + attention_mask = self.attnmasker( + output["vmasks"], cmasks, "textgen") + + output.update({ + "caps": caps, + "cmasks": cmasks, + "text_label": text_label, + "attention_mask": attention_mask, + }) + return output + + +# -------------------- MSRVTT ------------------------ + + +class MSRVTTMetaProcessor(MetaProcessor): + """MSRVTT dataset. + reference: `howto100m/msrvtt_dataloader.py` + """ + + def __init__(self, config): + super().__init__(config) + import pandas as pd + data = pd.read_csv(self._get_split_path(config)) + # TODO: add a text1ka flag. + if config.split == "train" \ + and config.full_test_path is not None \ + and config.jsfusion_path is not None: + # add testing videos from full_test_path not used by jfusion. + additional_data = pd.read_csv(config.full_test_path) + jsfusion_data = pd.read_csv(config.jsfusion_path) + + for video_id in additional_data["video_id"]: + if video_id not in jsfusion_data["video_id"].values: + data = data.append( + {"video_id": video_id}, ignore_index=True) + + if config.dup is not None and config.split == "train": + data = data.append([data] * (config.dup - 1), ignore_index=True) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + """slightly modify with if condition to combine train/test.""" + vid, sentence = None, None + vid = self.data["video_id"].values[idx] + if "sentence" in self.data: # for testing. + sentence = self.data["sentence"].values[idx] + else: # for training. + sentence = vid + return vid, sentence + + +class MSRVTTTextProcessor(TextProcessor): + """MSRVTT dataset. + reference: `msrvtt_dataloader.py` `MSRVTT_TrainDataLoader`. + TODO (huxu): add max_words. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.sentences = None + if config.json_path is not None and config.split == "train": + with open(config.json_path) as fd: + self.data = json.load(fd) + self.sentences = defaultdict(list) + for s in self.data["sentences"]: + self.sentences[s["video_id"]].append(s["caption"]) + + def __call__(self, text_id): + if self.sentences is not None: + rind = random.randint(0, len(self.sentences[text_id]) - 1) + sentence = self.sentences[text_id][rind] + else: + sentence = text_id + caption = self.tokenizer(sentence, add_special_tokens=False) + return caption["input_ids"] + + +class MSRVTTNLGTextProcessor(MSRVTTTextProcessor): + """TODO: change dsaligner and merge to avoid any NLG text processor.""" + def __call__(self, text_id): + if self.sentences is not None: + rind = random.randint(0, len(self.sentences[text_id]) - 1) + sentence = self.sentences[text_id][rind] + else: + sentence = text_id + caption = self.tokenizer(sentence, add_special_tokens=False) + return caption["input_ids"], sentence + + +class MSRVTTQAMetaProcessor(MetaProcessor): + """MSRVTT-QA: retrieval-based multi-choice QA from JSFusion dataset. + For simplicity, we use the train retrieval model. + reference: `https://github.com/yj-yu/lsmdc` + """ + + def __init__(self, config): + super().__init__(config) + import pandas as pd + csv_data = pd.read_csv(self._get_split_path(config), sep="\t") + data = [] + for video_id, a1, a2, a3, a4, a5, answer in zip( + csv_data["vid_key"].values, + csv_data["a1"].values, + csv_data["a2"].values, + csv_data["a3"].values, + csv_data["a4"].values, + csv_data["a5"].values, + csv_data["answer"].values): + video_id = video_id.replace("msr", "video") + data.append((video_id, (answer, [a1, a2, a3, a4, a5]))) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class MSRVTTQATextProcessor(TextProcessor): + """MSRVTT-QA dataset. + text_ans is of format `(answer, [a1, a2, a3, a4, a5])`. + """ + + def __call__(self, text_ans): + for ans_idx, ans in enumerate(text_ans[1]): + if isinstance(ans, str): + text_ans[1][ans_idx] = self.tokenizer(ans, add_special_tokens=False)["input_ids"] + return text_ans + + +class MSRVTTQAAligner(DSAligner): + """MSRVTT dataset. + similar to sample in how2. + we call __call__ multiple times. + """ + + def __call__(self, video_id, video_feature, text_feature, wps=0.7): + caps = [] + cmasks = [] + answer = text_feature[0] + for ans_idx, _text_feature in enumerate(text_feature[1]): + output = super().__call__( + video_id, video_feature, _text_feature, wps) + caps.append(output["caps"]) + cmasks.append(output["cmasks"]) + output.update({ + "caps": torch.stack(caps), + "cmasks": torch.stack(cmasks), + "answers": torch.LongTensor([answer]), + }) + return output + + +# -------------------- Youcook ----------------------- + + +class YoucookMetaProcessor(MetaProcessor): + """Youcook dataset. + reference: `howto100m/youcook_dataloader.py` + note that the data can be different as the + (1) some videos already in Howto100m are removed. + (2) stop words are removed from caption + TODO (huxu): make a flag to load the original caption. + (see youcookii_annotations_trainval.json). + + The max_video_len can be 264 and text can be 64 tokens. + In reality we may not need that long. 
see projects/task/youcook.yaml + """ + + def __init__(self, config): + super().__init__(config) + vfeat_dir = config.vfeat_dir + print(self._get_split_path(config)) + with open(self._get_split_path(config), "rb") as fd: + data = pickle.load(fd) + all_valid_video_ids = set( + [os.path.splitext(fn)[0] for fn in os.listdir(vfeat_dir)] + ) + recs = [] + video_ids = set() + valid_video_ids = set() + for rec in data: # filter videos not available. + udl_idx = rec["id"].rindex("_") + video_id = rec["id"][:udl_idx] + video_ids.add(video_id) + if video_id in all_valid_video_ids: + valid_video_ids.add(video_id) + recs.append(rec) + print("total video_ids in .pkl", len(video_ids)) + print("valid video_ids in .pkl", len(valid_video_ids)) + print("please verify {train,val}_list.txt") + data = recs + self.data = data + + with open(config.trainval_annotation) as fd: + self.youcook_annotation = json.load(fd)["database"] + if config.use_annotation_text is True: + print("using text in annotation.") + self.use_annotation_caption = True + else: + self.use_annotation_caption = False + + def __getitem__(self, idx): + def _get_video_and_caption(rec): + vid = rec["id"] + udl_idx = vid.rindex("_") + video_id, clip_id = vid[:udl_idx], int(vid[udl_idx + 1:]) + clip = self.youcook_annotation[video_id]["annotations"][clip_id] + start, end = clip["segment"] + if self.use_annotation_caption: + caption = clip["sentence"] + else: + caption = rec["caption"] + return (video_id, start, end), caption + + rec = self.data[idx] + video_info, text_info = _get_video_and_caption(rec) + return video_info, text_info + + +class YoucookVideoProcessor(VideoProcessor): + """video_fn is a tuple of (video_id, start, end) now.""" + + def __call__(self, video_fn): + video_id, start, end = video_fn + feat = np.load(os.path.join(self.vfeat_dir, video_id + ".npy")) + return feat[start:end] + + +class YoucookNLGMetaProcessor(MetaProcessor): + """NLG uses the original split: + `train_list.txt` and `val_list.txt` + """ + + def __init__(self, config): + super().__init__(config) + vfeat_dir = config.vfeat_dir + print(self._get_split_path(config)) + with open(self._get_split_path(config)) as fd: + video_ids = [ + line.strip().split("/")[1] for line in fd.readlines()] + print("total video_ids in train/val_list.txt", len(video_ids)) + + all_valid_video_ids = set( + [os.path.splitext(fn)[0] for fn in os.listdir(vfeat_dir)] + ) + video_ids = [ + video_id for video_id in video_ids + if video_id in all_valid_video_ids] + + print("valid video_ids in train/val_list.txt", len(video_ids)) + with open(config.trainval_annotation) as fd: + self.youcook_annotation = json.load(fd)["database"] + + data = [] + for video_id in video_ids: + for clip in self.youcook_annotation[video_id]["annotations"]: + start, end = clip["segment"] + caption = clip["sentence"] + data.append(((video_id, start, end), caption)) + self.data = data + + def __getitem__(self, idx): + return self.data[idx] + + +# --------------------- CrossTask ------------------------- + +class CrossTaskMetaProcessor(MetaProcessor): + def __init__(self, config): + super().__init__(config) + np.random.seed(0) # deterministic random split. + task_vids = self._get_vids( + config.train_csv_path, + config.vfeat_dir, + config.annotation_path) + + val_vids = self._get_vids( + config.val_csv_path, + config.vfeat_dir, + config.annotation_path) + + # filter out those task and vids appear in val_vids. 
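+        # keep a training video only if its task has no validation split or the
+        # video id is absent from that task's validation list.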
+ task_vids = { + task: [ + vid for vid in vids + if task not in val_vids or vid not in val_vids[task]] + for task, vids in task_vids.items()} + + primary_info = self._read_task_info(config.primary_path) + test_tasks = set(primary_info['steps'].keys()) + + # if args.use_related: + related_info = self._read_task_info(config.related_path) + task_steps = {**primary_info['steps'], **related_info['steps']} + n_steps = {**primary_info['n_steps'], **related_info['n_steps']} + # else: + # task_steps = primary_info['steps'] + # n_steps = primary_info['n_steps'] + all_tasks = set(n_steps.keys()) + # filter and keep task in primary or related. + task_vids = { + task: vids for task, vids in task_vids.items() + if task in all_tasks} + # vocab-by-step matrix (A) and vocab (M) + # (huxu): we do not use BoW. + # A, M = self._get_A(task_steps, share="words") + + train_vids, test_vids = self._random_split( + task_vids, test_tasks, config.n_train) + print("train_num_videos", sum(len(vids) for vids in train_vids.values())) + print("test_num_videos", sum(len(vids) for vids in test_vids.values())) + # added by huxu to automatically determine the split. + split_map = { + "train": train_vids, + "valid": test_vids, + "test": test_vids + } + task_vids = split_map[config.split] + + self.vids = [] + for task, vids in task_vids.items(): + self.vids.extend([(task, vid) for vid in vids]) + self.task_steps = task_steps + self.n_steps = n_steps + + def __getitem__(self, idx): + task, vid = self.vids[idx] + n_steps = self.n_steps[task] + steps = self.task_steps[task] + assert len(steps) == n_steps + return (task, vid, steps, n_steps), (task, vid, steps, n_steps) + + def __len__(self): + return len(self.vids) + + def _random_split(self, task_vids, test_tasks, n_train): + train_vids = {} + test_vids = {} + for task, vids in task_vids.items(): + if task in test_tasks and len(vids) > n_train: + train_vids[task] = np.random.choice( + vids, n_train, replace=False).tolist() + test_vids[task] = [ + vid for vid in vids if vid not in train_vids[task]] + else: + train_vids[task] = vids + return train_vids, test_vids + + def _get_vids(self, path, vfeat_dir, annotation_path): + """refactored from + https://github.com/DmZhukov/CrossTask/blob/master/data.py + changes: add `vfeat_dir` to check if the video is available. + add `annotation_path` to check if the video is available. + """ + + task_vids = {} + with open(path, 'r') as f: + for line in f: + task, vid, url = line.strip().split(',') + # double check the video is available. + if not os.path.exists( + os.path.join(vfeat_dir, vid + ".npy")): + continue + # double check the annotation is available. 
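+                # (annotations live in one csv per task/video pair, named
+                # `<task>_<vid>.csv` under `annotation_path`.)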
+ if not os.path.exists(os.path.join( + annotation_path, + task + "_" + vid + ".csv")): + continue + if task not in task_vids: + task_vids[task] = [] + task_vids[task].append(vid) + return task_vids + + def _read_task_info(self, path): + titles = {} + urls = {} + n_steps = {} + steps = {} + with open(path, 'r') as f: + idx = f.readline() + while idx != '': + idx = idx.strip() + titles[idx] = f.readline().strip() + urls[idx] = f.readline().strip() + n_steps[idx] = int(f.readline().strip()) + steps[idx] = f.readline().strip().split(',') + next(f) + idx = f.readline() + return { + 'title': titles, + 'url': urls, + 'n_steps': n_steps, + 'steps': steps + } + + def _get_A(self, task_steps, share="words"): + raise ValueError("running get_A is not allowed for BERT.") + """Step-to-component matrices.""" + if share == 'words': + # share words + task_step_comps = { + task: [step.split(' ') for step in steps] + for task, steps in task_steps.items()} + elif share == 'task_words': + # share words within same task + task_step_comps = { + task: [[task+'_'+tok for tok in step.split(' ')] for step in steps] + for task, steps in task_steps.items()} + elif share == 'steps': + # share whole step descriptions + task_step_comps = { + task: [[step] for step in steps] for task, steps in task_steps.items()} + else: + # no sharing + task_step_comps = { + task: [[task+'_'+step] for step in steps] + for task, steps in task_steps.items()} + # BERT tokenizer here? + vocab = [] + for task, steps in task_step_comps.items(): + for step in steps: + vocab.extend(step) + vocab = {comp: m for m, comp in enumerate(set(vocab))} + M = len(vocab) + A = {} + for task, steps in task_step_comps.items(): + K = len(steps) + a = torch.zeros(M, K) + for k, step in enumerate(steps): + a[[vocab[comp] for comp in step], k] = 1 + a /= a.sum(dim=0) + A[task] = a + return A, M + + +class CrossTaskVideoProcessor(VideoProcessor): + def __call__(self, video_fn): + task, vid, steps, n_steps = video_fn + video_fn = os.path.join(self.vfeat_dir, vid + ".npy") + feat = np.load(video_fn) + return feat + + +class CrossTaskTextProcessor(TextProcessor): + def __call__(self, text_id): + task, vid, steps, n_steps = text_id + step_ids = [] + for step_str in steps: + step_ids.append( + self.tokenizer(step_str, add_special_tokens=False)["input_ids"] + ) + return step_ids + + +class CrossTaskAligner(Aligner): + """ + TODO: it's not clear yet the formulation of the task; finish this later. + """ + def __init__(self, config): + super().__init__(config) + self.annotation_path = config.annotation_path + self.sliding_window = config.sliding_window + self.sliding_window_size = config.sliding_window_size + + def __call__(self, video_id, video_feature, text_feature): + task, vid, steps, n_steps = video_id + annot_path = os.path.join( + self.annotation_path, task + '_' + vid + '.csv') + video_len = len(video_feature) + + labels = torch.from_numpy(self._read_assignment( + video_len, n_steps, annot_path)).float() + + vfeats, vmasks, targets = [], [], [] + # sliding window on video features and targets. 
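+        # e.g. with sliding_window=16 and sliding_window_size=32, windows start
+        # at frame 0, 16, 32, ... and each covers up to 32 frames, so
+        # consecutive windows overlap; the loop below stops once a window
+        # reaches the end of the video.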
+ for window_start in range(0, video_len, self.sliding_window): + video_start = 0 + video_end = min(video_len - window_start, self.sliding_window_size) + video_clip = {"start": [video_start], "end": [video_end]} + + vfeat, vmask = self._build_video_seq( + video_feature[window_start: window_start + video_end], + video_clip + ) + + target = labels[window_start: window_start + video_end] + assert len(vfeat) >= len(target), "{},{}".format(len(vfeat), len(target)) + # TODO: randomly drop all zero targets for training ? + # if self.split == "train" and target.sum() == 0: + # continue + vfeats.append(vfeat) + vmasks.append(vmask) + targets.append(target) + + if (video_len - window_start) <= self.sliding_window_size: + break + + vfeats = torch.stack(vfeats) + vmasks = torch.stack(vmasks) + targets = torch.cat(targets, dim=0) + + caps, cmasks = [], [] + for step in text_feature: + step_text_feature = {"start": [0], "end": [1], "cap": [step]} + step_text_clip_index = [0] + cap, cmask = self._build_text_seq( + step_text_feature, step_text_clip_index + ) + caps.append(cap) + cmasks.append(cmask) + caps = torch.stack(caps) + cmasks = torch.stack(cmasks) + + return { + "caps": caps, + "cmasks": cmasks, + "vfeats": vfeats, # X for original code. + "vmasks": vmasks, + "targets": targets, + "video_id": vid, + "task": task, + "video_len": video_len # for later checking. + } + + def _read_assignment(self, T, K, path): + """ + refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py + Howto interpret contraints on loss that is going to be minimized: + lambd is a big number; + self.lambd * C is a big number for all valid position (csv stores invalids) + + def forward(self, O, Y, C): + return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum() + + This will load the csv file and fill-in the step col from start to end rows. + """ + + Y = np.zeros([T, K], dtype=np.uint8) + with open(path, 'r') as f: + for line in f: + step, start, end = line.strip().split(',') + start = int(math.floor(float(start))) + end = int(math.ceil(float(end))) + step = int(step) - 1 + Y[start:end, step] = 1 + return Y + + +# --------------------- COIN ------------------------- + +class MetaTextBinarizer(Aligner): + def __call__(self, text_feature): + text_feature = { + "cap": [text_feature], + "start": [0.], + "end": [100.], + } + text_clip_indexs = [0] + + caps, cmasks = self._build_text_seq( + text_feature, text_clip_indexs + ) + return {"caps": caps, "cmasks": cmasks} + + +class COINActionSegmentationMetaProcessor(MetaProcessor): + split_map = { + "train": "training", + "valid": "testing", + "test": "testing", + } + + def __init__(self, config): + super().__init__(config) + with open(self._get_split_path(config)) as fr: + database = json.load(fr)["database"] + id2label = {} + data = [] + # filter the data by split. 
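+        # each COIN record is expected to provide a `subset` field
+        # ("training"/"testing") and an `annotation` list whose segments carry
+        # an `id`, a `label` and a `[start, end]` `segment`.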
+ for video_id, rec in database.items(): + # always use testing to determine label_set + if rec["subset"] == "testing": + for segment in rec["annotation"]: + id2label[int(segment["id"])] = segment["label"] + # text_labels is used for ZS setting + self.text_labels = ["none"] * len(id2label) + for label_id in id2label: + self.text_labels[label_id-1] = id2label[label_id] + + id2label[0] = "O" + print("num of labels", len(id2label)) + + for video_id, rec in database.items(): + if not os.path.isfile(os.path.join(config.vfeat_dir, video_id + ".npy")): + continue + if rec["subset"] == COINActionSegmentationMetaProcessor.split_map[self.split]: + starts, ends, labels = [], [], [] + for segment in rec["annotation"]: + start, end = segment["segment"] + label = int(segment["id"]) + starts.append(start) + ends.append(end) + labels.append(label) + data.append( + (video_id, {"start": starts, "end": ends, "label": labels})) + self.data = data + + def meta_text_labels(self, config): + from transformers import default_data_collator + from ..utils import get_local_rank + + text_processor = TextProcessor(config) + binarizer = MetaTextBinarizer(config) + # TODO: add prompts to .yaml. + text_labels = [label for label in self.text_labels] + + if get_local_rank() == 0: + print(text_labels) + + outputs = [] + for text_label in text_labels: + text_feature = text_processor(text_label) + outputs.append(binarizer(text_feature)) + return default_data_collator(outputs) + + def __getitem__(self, idx): + return self.data[idx] + + +class COINActionSegmentationTextProcessor(TextProcessor): + def __call__(self, text_label): + return text_label + + +class COINActionSegmentationAligner(Aligner): + def __init__(self, config): + super().__init__(config) + self.sliding_window = config.sliding_window + self.sliding_window_size = config.sliding_window_size + + def __call__(self, video_id, video_feature, text_feature): + starts, ends, label_ids = text_feature["start"], text_feature["end"], text_feature["label"] + # sliding window. + video_len = len(video_feature) + + vfeats, vmasks, targets = [], [], [] + # sliding window on video features and targets. + for window_start in range(0, video_len, self.sliding_window): + video_start = 0 + video_end = min(video_len - window_start, self.sliding_window_size) + video_clip = {"start": [video_start], "end": [video_end]} + vfeat, vmask = self._build_video_seq( + video_feature[window_start: window_start + video_end], + video_clip + ) + # covers video length only. 
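+            # -100 is the default ignore_index of torch's cross-entropy loss,
+            # so frames outside the actual video can be ignored by such a loss.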
+ target = torch.full_like(vmask, -100, dtype=torch.long) + target[vmask] = 0 + for start, end, label_id in zip(starts, ends, label_ids): + if (window_start < end) and (start < (window_start + video_end)): + start_offset = max(0, math.floor(start) - window_start) + end_offset = min(video_end, math.ceil(end) - window_start) + target[start_offset:end_offset] = label_id + vfeats.append(vfeat) + vmasks.append(vmask) + targets.append(target) + if (video_len - window_start) <= self.sliding_window_size: + break + + vfeats = torch.stack(vfeats) + vmasks = torch.stack(vmasks) + targets = torch.stack(targets) + video_targets = torch.full((video_len,), 0) + for start, end, label_id in zip(starts, ends, label_ids): + start_offset = max(0, math.floor(start)) + end_offset = min(video_len, math.ceil(end)) + video_targets[start_offset:end_offset] = label_id + + caps = torch.LongTensor( + [[self.cls_token_id, self.sep_token_id, + self.pad_token_id, self.sep_token_id]], + ).repeat(vfeats.size(0), 1) + cmasks = torch.BoolTensor( + [[0, 1, 0, 1]] # pad are valid for attention. + ).repeat(vfeats.size(0), 1) + return { + "caps": caps, + "cmasks": cmasks, + "vfeats": vfeats, # X for original code. + "vmasks": vmasks, + "targets": targets, + "video_id": video_id, + "video_len": video_len, # for later checking. + "video_targets": video_targets + } + + +class DiDeMoMetaProcessor(MetaProcessor): + """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py + https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py + """ + def __init__(self, config): + super().__init__(config) + + assert "test" in self._get_split_path(config), "DiDeMo only supports zero-shot testing for now." + + with open(self._get_split_path(config)) as data_file: + json_data = json.load(data_file) + + data = [] + for record in json_data: + data.append((record["video"], record["description"])) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class DiDeMoTextProcessor(TextProcessor): + """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py + https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py + """ + + def __call__(self, text): + return self.tokenizer(text, add_special_tokens=False)["input_ids"] + + +class DiDeMoAligner(DSAligner): + """ + check video length. + """ + + def __call__(self, video_id, video_feature, text_feature): + # print(video_feature.shape[0]) + return super().__call__(video_id, video_feature, text_feature) diff --git a/fairseq/examples/MMPT/mmpt/processors/how2retriprocessor.py b/fairseq/examples/MMPT/mmpt/processors/how2retriprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a7730ec0bbe91d9997564214fffb10d0aef519 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/how2retriprocessor.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
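+"""Retrieval variants of the sharded how2 processors: each batch item is a
+list of candidate videos, represented as `(video_id, -1, shard_id, video_idx)`
+tuples, rather than a single video."""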
+ +from .how2processor import ( + ShardedHow2MetaProcessor, + ShardedVideoProcessor, + ShardedTextProcessor, + VariedLenAligner, + OverlappedAligner +) + + +class ShardedHow2VideoRetriMetaProcessor(ShardedHow2MetaProcessor): + def __init__(self, config): + super().__init__(config) + self.num_video_per_batch = config.num_video_per_batch + self.cands = [ + self.data[batch_offset:batch_offset + self.num_video_per_batch] + for batch_offset in + range(0, (len(self.data) // (8 * self.num_video_per_batch)) * 8 * self.num_video_per_batch, self.num_video_per_batch)] + + def __len__(self): + return len(self.cands) + + def set_candidates(self, cands): + # no changes on num of batches. + print(len(self.cands), "->", len(cands)) + # assert len(self.cands) == len(cands) + self.cands = cands + + def __getitem__(self, idx): + video_ids = self.cands[idx] + assert isinstance(video_ids, list) + sharded_video_idxs = [] + for video_id in video_ids: + shard_id, video_idx = self.video_id_to_shard[video_id] + sharded_video_idxs.append((video_id, -1, shard_id, video_idx)) + return sharded_video_idxs, sharded_video_idxs + + +class ShardedVideoRetriVideoProcessor(ShardedVideoProcessor): + """In retrival case the video_id + is a list of tuples: `(shard_id, video_idx)` .""" + + def __call__(self, sharded_video_idxs): + assert isinstance(sharded_video_idxs, list) + cand_feats = [] + for shared_video_idx in sharded_video_idxs: + feat = super().__call__(shared_video_idx) + cand_feats.append(feat) + return cand_feats + + +class ShardedVideoRetriTextProcessor(ShardedTextProcessor): + """In retrival case the video_id + is a list of tuples: `(shard_id, video_idx)` .""" + + def __call__(self, sharded_video_idxs): + assert isinstance(sharded_video_idxs, list) + cand_caps = [] + for shared_video_idx in sharded_video_idxs: + caps = super().__call__(shared_video_idx) + cand_caps.append(caps) + return cand_caps + + +class VideoRetriAligner(VariedLenAligner): + # Retritask will trim dim-0. + def __call__(self, sharded_video_idxs, video_features, text_features): + from transformers import default_data_collator + batch, video_ids = [], [] + for video_id, video_feature, text_feature in \ + zip(sharded_video_idxs, video_features, text_features): + sub_batch = super().__call__(video_id, video_feature, text_feature) + batch.append(sub_batch) + if isinstance(video_id, tuple): + video_id = video_id[0] + video_ids.append(video_id) + batch = default_data_collator(batch) + batch["video_id"] = video_ids + return batch + + +class VideoRetriOverlappedAligner(OverlappedAligner): + # Retritask will trim dim-0. + def __call__(self, sharded_video_idxs, video_features, text_features): + from transformers import default_data_collator + batch, video_ids = [], [] + for video_id, video_feature, text_feature in \ + zip(sharded_video_idxs, video_features, text_features): + sub_batch = super().__call__(video_id, video_feature, text_feature) + batch.append(sub_batch) + if isinstance(video_id, tuple): + video_id = video_id[0] + video_ids.append(video_id) + batch = default_data_collator(batch) + batch["video_id"] = video_ids + return batch diff --git a/fairseq/examples/MMPT/mmpt/tasks/__init__.py b/fairseq/examples/MMPT/mmpt/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e9323a530672ef9daecd793ef645a3c1d0f3e6 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/tasks/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .task import *
+from .vlmtask import *
+from .retritask import *
+
+try:
+    from .fairseqmmtask import *
+except ImportError:
+    pass
+
+try:
+    from .milncetask import *
+except ImportError:
+    pass
+
+try:
+    from .expretritask import *
+except ImportError:
+    pass
diff --git a/fairseq/examples/MMPT/mmpt/utils/shardedtensor.py b/fairseq/examples/MMPT/mmpt/utils/shardedtensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2424f360ef2a93f0f2783e38f0bd966086e92fd4
--- /dev/null
+++ b/fairseq/examples/MMPT/mmpt/utils/shardedtensor.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import pickle
+import numpy as np
+
+
+class ShardedTensor(object):
+    def __init__(self, data, starts):
+        self.data = data
+        self.starts = starts
+        assert self.starts[0] == 0
+        assert self.starts[-1] == len(self.data)
+        assert (self.starts[1:] >= self.starts[:-1]).all()
+        assert (self.starts > -1).all()
+
+    @staticmethod
+    def from_list(xs):
+        starts = np.full((len(xs) + 1,), -1, dtype=np.int64)
+        data = np.concatenate(xs, axis=0)
+        starts[0] = 0
+        for i, x in enumerate(xs):
+            starts[i + 1] = starts[i] + x.shape[0]
+        assert (starts > -1).all()
+        return ShardedTensor(data, starts)
+
+    def __getitem__(self, i):
+        return self.data[self.starts[i] : self.starts[i + 1]]
+
+    def __len__(self):
+        return len(self.starts) - 1
+
+    def lengths(self):
+        return self.starts[1:] - self.starts[:-1]
+
+    def save(self, path):
+        np.save(path + "_starts", self.starts)
+        np.save(path + "_data", self.data)
+
+    @staticmethod
+    def load(path, mmap_mode=None):
+        starts = np.load(path + "_starts.npy", mmap_mode)
+        data = np.load(path + "_data.npy", mmap_mode)
+        return ShardedTensor(data, starts)
diff --git a/fairseq/examples/MMPT/pretraining.md b/fairseq/examples/MMPT/pretraining.md
new file mode 100644
index 0000000000000000000000000000000000000000..8f8e6d0facaa47141342294d7eb26c4232a9677b
--- /dev/null
+++ b/fairseq/examples/MMPT/pretraining.md
@@ -0,0 +1,29 @@
+# Pretraining
+
+(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
+We mostly use the [howto100M](https://github.com/antoine77340/howto100m) dataset for pretraining (other datasets are coming), so you are unlikely to need a new `MetaProcessor`, `VideoProcessor` or `TextProcessor`; you will typically only write a new `Aligner`, a new model and a new loss.
+
+### Data Sharding
+Pretraining on Howto100M is heavy on IO since the millions of videos and captions on disk cannot all fit into memory.
+It is desirable to have an optimized preprocessing step before the actual dataloading.
+
+We support data sharding to pack multiple videos into shards of training data, for both videos and captions (see [dataset](DATASET.md) for preprocessing).
+These shards are memory-mapped to reduce the frequency of IO access to millions of small files (see the processors starting with `Sharded*`).
+This is the default config for the how2 dataset, `projects/task/how2.yaml`.
+
+Great thanks to Dmytro Okhonko for sharing the code from the MARGE project.
+
+### Training
+Pretraining on Howto100m is expected to run on one or multiple nodes, where each node has 8 GPUs with 32 GB of memory.
+Launching a pretraining on MFM+MLM can be done via:
+```python locallaunch.py projects/mfmmlm/how2.yaml```
+
+### Pre-training with a Retrieval Model (VideoCLIP)
+This project now supports alternating between running a retrieval model and pre-training.
+We implement a basic retrieval model that is built on the hidden states of a video and faiss.
+
+You may need to install faiss via `conda install faiss-cpu -c pytorch`.
+
+Right now, the hidden state of a video is computed as the average of the pooled visual/text hidden states of 8 of its clips.
+See `mmpt/tasks/retritask.py` for more details.
+The `.yaml` config for running pre-training with a retrieval model can be found at `projects/retri/videoretri.yaml`.
diff --git a/fairseq/examples/MMPT/projects/mfmmlm.yaml b/fairseq/examples/MMPT/projects/mfmmlm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f3450a1e091517de19422b0c09cea6a01a02729
--- /dev/null
+++ b/fairseq/examples/MMPT/projects/mfmmlm.yaml
@@ -0,0 +1,59 @@
+project_dir: mfmmlm
+run_task:
+  - how2.yaml
+  - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml]
+base_dir: task
+task_group:
+  pretrain:
+    task_list:
+      - how2.yaml
+    dataset:
+      subsampling: 32
+      sampled_min_len: 10
+      sampled_max_len: 64
+      max_video_len: 32
+      max_len: 96
+      aligner: MFMMLMAligner
+      lazy_vfeat_mask: True
+      mfm_probability: 0.15
+      mlm_probability: 0.15
+      mm_prob: 0.5
+    model:
+      model_cls: MMFusionMFMMLM
+      mm_encoder_cls: MMFusionForMFMMLM
+    loss:
+      loss_cls: MFMMLM
+    fairseq:
+      common:
+        fp16: true
+      dataset:
+        batch_size: 256
+      optimization:
+        max_epoch: 15
+  finetune:
+    task_list:
+      - vtt.yaml
+      - vttqa.yaml
+      - youcook.yaml
+      - youcookcap.yaml
+      - crosstask.yaml
+      - coin.yaml
+    dataset:
+      max_video_len: 32
+      max_len: 96
+    fairseq:
+      common:
+        fp16: true
+    # do not write any model or loss here (they are expected to be fixed in mmfusion).
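+  # the test group below only lists evaluation configs; it shares the same
+  # max_video_len/max_len caps as fine-tuning.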
+ test: + task_list: + - test_vtt.yaml + - test_vttqa.yaml + - test_youcook.yaml + - test_youcookcap.yaml + - test_crosstask.yaml + - test_crosstask_zs.yaml + - test_coin.yaml + dataset: + max_video_len: 32 + max_len: 96 diff --git a/fairseq/examples/MMPT/projects/mtm/vlm.yaml b/fairseq/examples/MMPT/projects/mtm/vlm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..022a2623c5fdee1da51f211bef25a0dfebd7cbb0 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm.yaml @@ -0,0 +1,8 @@ +includes: projects/mtm/mmfusionmtm.yaml +project_dir: mtm/vlm +task_group: + pretrain: + dataset: + sampled_min_len: 8 + loss: + loss_cls: MTM diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/crosstask.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/crosstask.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4e706b549e6085c9823049afaff3285aa2b88f4f --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/crosstask.yaml @@ -0,0 +1,53 @@ +dataset: + video_processor: CrossTaskVideoProcessor + bert_name: bert-base-uncased + meta_processor: CrossTaskMetaProcessor + train_path: data/crosstask/crosstask_release/videos.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + text_processor: CrossTaskTextProcessor + aligner: CrossTaskAligner + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 1 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 5 + checkpoint: + restore_file: runs/mtm/vlm/checkpoint11.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/mtm/vlm/crosstask +task_type: sweep_small +model: + model_cls: MMFusionActionLocalization + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +loss: + loss_cls: BCE diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_coin.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_coin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8df2e66ad175529e0e86f69ade1e392849507036 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_coin.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: COINActionSegmentationAligner + bert_name: bert-base-uncased + test_path: data/coin/COIN.json + meta_processor: COINActionSegmentationMetaProcessor + vfeat_dir: data/feat/feat_coin_s3d + text_processor: COINActionSegmentationTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/coin/checkpoint_best.pt +model: + model_cls: MMFusionActionSegmentation + mm_encoder_cls: MMBertForTokenClassification + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/coin/eval +metric: 
COINActionSegmentationMetric +predictor: COINPredictor diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59833c55406ce53225d704b46eb0c141ad4d5e9d --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml @@ -0,0 +1,38 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: CrossTaskVideoProcessor + aligner: CrossTaskAligner + bert_name: bert-base-uncased + meta_processor: CrossTaskMetaProcessor + test_path: data/crosstask/crosstask_release/videos_val.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + text_processor: CrossTaskTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/checkpoint_best.pt +model: + model_cls: MMFusionActionLocalization + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/crosstask_zs/eval +metric: CrossTaskMetric +predictor: CrossTaskPredictor diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_youcook.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_youcook.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a57d25c240758cf989ad0ec538562d046446232 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_youcook.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: YoucookVideoProcessor + aligner: DSAligner + bert_name: bert-base-uncased + meta_processor: YoucookMetaProcessor + test_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: true + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: TextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/youcook/checkpoint_last.pt +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/youcook/eval +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/youcookcap.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/youcookcap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d29dfad5cd861cc1e8f587c730938ed491d60289 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/youcookcap.yaml @@ -0,0 +1,45 @@ +dataset: + video_processor: YoucookVideoProcessor + bert_name: bert-base-uncased + meta_processor: YoucookNLGMetaProcessor + train_path: data/youcook/train_list.txt + val_path: data/youcook/val_list.txt + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: NLGTextProcessor + aligner: DSNLGAligner + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: 
+ num_workers: 4 + batch_size: 128 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 10 + checkpoint: + restore_file: runs/mtm/vlm/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/mtm/vlm/youcookcap +task_type: sweep_small +model: + model_cls: MMFusionNLG + mm_encoder_cls: MMBertForNLG + use_seg_emb: true +loss: + loss_cls: NLGLoss diff --git a/fairseq/examples/MMPT/setup.py b/fairseq/examples/MMPT/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a82296ea9c3a53f760c29c34020b5a90091a89 --- /dev/null +++ b/fairseq/examples/MMPT/setup.py @@ -0,0 +1,24 @@ +import setuptools + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="mmpt", + version="0.0.1", + author="Hu Xu, Po-yao Huang", + author_email="huxu@fb.com", + description="A package for multimodal pretraining.", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/pytorch/fairseq/examples/MMPT", + packages=setuptools.find_packages(), + install_requires=[ + ], + classifiers=[ + "Programming Language :: Python :: 3", + "License :: CC-BY-NC", + "Operating System :: OS Independent", + ], + python_requires='>=3.6', +) diff --git a/fairseq/hubconf.py b/fairseq/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..5949e274edd02e86cb323331211641ce0d0b9b93 --- /dev/null +++ b/fairseq/hubconf.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import functools +import importlib + + +dependencies = [ + "dataclasses", + "hydra", + "numpy", + "omegaconf", + "regex", + "requests", + "torch", +] + + +# Check for required dependencies and raise a RuntimeError if any are missing. +missing_deps = [] +for dep in dependencies: + try: + importlib.import_module(dep) + except ImportError: + # Hack: the hydra package is provided under the "hydra-core" name in + # pypi. We don't want the user mistakenly calling `pip install hydra` + # since that will install an unrelated package. + if dep == "hydra": + dep = "hydra-core" + missing_deps.append(dep) +if len(missing_deps) > 0: + raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) + + +# only do fairseq imports after checking for dependencies +from fairseq.hub_utils import ( # noqa; noqa + BPEHubInterface as bpe, + TokenizerHubInterface as tokenizer, +) +from fairseq.models import MODEL_REGISTRY # noqa + + +# torch.hub doesn't build Cython components, so if they are not found then try +# to build them here +try: + import fairseq.data.token_block_utils_fast # noqa +except ImportError: + try: + import cython # noqa + import os + from setuptools import sandbox + + sandbox.run_setup( + os.path.join(os.path.dirname(__file__), "setup.py"), + ["build_ext", "--inplace"], + ) + except ImportError: + print( + "Unable to build Cython components. Please make sure Cython is " + "installed if the torch.hub model you are loading depends on it." 
+ ) + + +# automatically expose models defined in FairseqModel::hub_models +for _model_type, _cls in MODEL_REGISTRY.items(): + for model_name in _cls.hub_models().keys(): + globals()[model_name] = functools.partial( + _cls.from_pretrained, + model_name, + ) diff --git a/fairseq/pyproject.toml b/fairseq/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..4d84c9bc361d51286b01a36960dae50428a26bac --- /dev/null +++ b/fairseq/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +requires = [ + "setuptools>=18.0", + "wheel", + "cython", + "numpy>=1.21.3", + "torch>=1.10", +] +build-backend = "setuptools.build_meta" + +[tool.black] +extend-exclude = ''' +( +^/examples/| +^/fairseq/model_parallel/megatron| +^/build/ +) +''' + +[tool.isort] +profile = "black" +known_third_party = "_cffi_backend,agg_results,aml,bitarray,boto3,botocore,dump_hubert_feature,dynamicconv_cuda,editdistance,faiss,fasttext,feature_utils,ffmpeg,g2p_en,h5py,hydra,hypothesis,indicnlp,inflect,iopath,joblib,kaldi_io,kenlm,libfb,librosa,lightconv_cuda,matplotlib,misc,mmpt,mmpt_cli,model,nltk,npy_append_array,numpy,omegaconf,pandas,pathbuilder,preprocessing,progressbar,pythainlp,random_sequence_shuffler,regex,sacrebleu,sacremoses,scipy,sentencepiece,setuptools,six,sklearn,soundfile,sweep,sweep_wmt_en2de_transformer_big_common,tabulate,torch,torchaudio,tqdm,unidecode,utils,videoreader,wav2vec_cluster_faiss,wget,yaml" +skip_gitignore = true diff --git a/fairseq/release_utils.py b/fairseq/release_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69a5e8dda3122643124b1e97d55b4efe7d0e0bc6 --- /dev/null +++ b/fairseq/release_utils.py @@ -0,0 +1,72 @@ +import argparse +from typing import Tuple + + +def get_next_version(release_type) -> Tuple[Tuple[int, int, int], str, str]: + current_ver = find_version("fairseq/version.txt") + version_list = [int(x) for x in current_ver.strip("'").split(".")] + major, minor, patch = version_list[0], version_list[1], version_list[2] + if release_type == "patch": + patch += 1 + elif release_type == "minor": + minor += 1 + patch = 0 + elif release_type == "major": + major += 1 + minor = patch = 0 + else: + raise ValueError( + "Incorrect release type specified. Acceptable types are major, minor and patch." + ) + + new_version_tuple = (major, minor, patch) + new_version_str = ".".join([str(x) for x in new_version_tuple]) + new_tag_str = "v" + new_version_str + return new_version_tuple, new_version_str, new_tag_str + + +def find_version(version_file_path) -> str: + with open(version_file_path) as f: + version = f.read().strip() + return version + + +def update_version(new_version_str) -> None: + """ + given the current version, update the version to the + next version depending on the type of release. 
+ """ + + with open("fairseq/version.txt", "w") as writer: + writer.write(new_version_str) + + +def main(args): + if args.release_type in ["major", "minor", "patch"]: + new_version_tuple, new_version, new_tag = get_next_version(args.release_type) + else: + raise ValueError("Incorrect release type specified") + + if args.update_version: + update_version(new_version) + + print(new_version, new_tag) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Versioning utils") + parser.add_argument( + "--release-type", + type=str, + required=True, + help="type of release = major/minor/patch", + ) + parser.add_argument( + "--update-version", + action="store_true", + required=False, + help="updates the version in fairseq/version.txt", + ) + + args = parser.parse_args() + main(args) diff --git a/fairseq/setup.cfg b/fairseq/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..3fa679ddf1dd8dc4aa4045e0ac1f0a8bdc842e27 --- /dev/null +++ b/fairseq/setup.cfg @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 127 +extend-ignore = E203, W503 +extend-exclude = fairseq/model_parallel/megatron diff --git a/fairseq/setup.py b/fairseq/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..dae06080c53ae986ecdf7628215dccc894099348 --- /dev/null +++ b/fairseq/setup.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import subprocess +import sys + +from setuptools import Extension, find_packages, setup +from torch.utils import cpp_extension + +if sys.version_info < (3, 6): + sys.exit("Sorry, Python >= 3.6 is required for fairseq.") + + +def write_version_py(): + with open(os.path.join("fairseq", "version.txt")) as f: + version = f.read().strip() + + # write version info to fairseq/version.py + with open(os.path.join("fairseq", "version.py"), "w") as f: + f.write('__version__ = "{}"\n'.format(version)) + return version + + +version = write_version_py() + + +with open("README.md") as f: + readme = f.read() + + +if sys.platform == "darwin": + extra_compile_args = ["-stdlib=libc++", "-O3"] +else: + extra_compile_args = ["-std=c++11", "-O3"] + + +class NumpyExtension(Extension): + """Source: https://stackoverflow.com/a/54128391""" + + def __init__(self, *args, **kwargs): + self.__include_dirs = [] + super().__init__(*args, **kwargs) + + @property + def include_dirs(self): + import numpy + + return self.__include_dirs + [numpy.get_include()] + + @include_dirs.setter + def include_dirs(self, dirs): + self.__include_dirs = dirs + + +extensions = [ + Extension( + "fairseq.libbleu", + sources=[ + "fairseq/clib/libbleu/libbleu.cpp", + "fairseq/clib/libbleu/module.cpp", + ], + extra_compile_args=extra_compile_args, + ), + NumpyExtension( + "fairseq.data.data_utils_fast", + sources=["fairseq/data/data_utils_fast.pyx"], + language="c++", + extra_compile_args=extra_compile_args, + ), + NumpyExtension( + "fairseq.data.token_block_utils_fast", + sources=["fairseq/data/token_block_utils_fast.pyx"], + language="c++", + extra_compile_args=extra_compile_args, + ), +] + + +extensions.extend( + [ + cpp_extension.CppExtension( + "fairseq.libbase", + sources=[ + "fairseq/clib/libbase/balanced_assignment.cpp", + ], + ), + cpp_extension.CppExtension( + "fairseq.libnat", + sources=[ + "fairseq/clib/libnat/edit_dist.cpp", + ], + ), + cpp_extension.CppExtension( + "alignment_train_cpu_binding", 
+ sources=[ + "examples/operators/alignment_train_cpu.cpp", + ], + ), + ] +) +if "CUDA_HOME" in os.environ: + extensions.extend( + [ + cpp_extension.CppExtension( + "fairseq.libnat_cuda", + sources=[ + "fairseq/clib/libnat_cuda/edit_dist.cu", + "fairseq/clib/libnat_cuda/binding.cpp", + ], + ), + cpp_extension.CppExtension( + "fairseq.ngram_repeat_block_cuda", + sources=[ + "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp", + "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu", + ], + ), + cpp_extension.CppExtension( + "alignment_train_cuda_binding", + sources=[ + "examples/operators/alignment_train_kernel.cu", + "examples/operators/alignment_train_cuda.cpp", + ], + ), + ] + ) + +cmdclass = {"build_ext": cpp_extension.BuildExtension} + +if "READTHEDOCS" in os.environ: + # don't build extensions when generating docs + extensions = [] + if "build_ext" in cmdclass: + del cmdclass["build_ext"] + + # use CPU build of PyTorch + dependency_links = [ + "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" + ] +else: + dependency_links = [] + + +if "clean" in sys.argv[1:]: + # Source: https://bit.ly/2NLVsgE + print("deleting Cython files...") + + subprocess.run( + ["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"], + shell=True, + ) + + +extra_packages = [] +if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu")): + extra_packages.append("fairseq.model_parallel.megatron.mpu") + + +def do_setup(package_data): + setup( + name="fairseq", + version=version, + description="Facebook AI Research Sequence-to-Sequence Toolkit", + url="https://github.com/pytorch/fairseq", + classifiers=[ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + long_description=readme, + long_description_content_type="text/markdown", + install_requires=[ + "cffi", + "cython", + "hydra-core>=1.0.7,<1.1", + "omegaconf<2.1", + "numpy>=1.21.3", + "regex", + "sacrebleu>=1.4.12", + "torch>=1.13", + "tqdm", + "bitarray", + "torchaudio>=0.8.0", + "scikit-learn", + "packaging", + ], + extras_require={ + "dev": ["flake8", "pytest", "black==22.3.0"], + "docs": ["sphinx", "sphinx-argparse"], + }, + dependency_links=dependency_links, + packages=find_packages( + exclude=[ + "examples", + "examples.*", + "scripts", + "scripts.*", + "tests", + "tests.*", + ] + ) + + extra_packages, + package_data=package_data, + ext_modules=extensions, + test_suite="tests", + entry_points={ + "console_scripts": [ + "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", + "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main", + "fairseq-interactive = fairseq_cli.interactive:cli_main", + "fairseq-preprocess = fairseq_cli.preprocess:cli_main", + "fairseq-score = fairseq_cli.score:cli_main", + "fairseq-train = fairseq_cli.train:cli_main", + "fairseq-validate = fairseq_cli.validate:cli_main", + ], + }, + cmdclass=cmdclass, + zip_safe=False, + ) + + +def get_files(path, relative_to="fairseq"): + all_files = [] + for root, _dirs, files in os.walk(path, followlinks=True): + root = os.path.relpath(root, relative_to) + for file in files: + if file.endswith(".pyc"): + continue + all_files.append(os.path.join(root, file)) + return all_files + + +if __name__ == "__main__": + try: + # symlink examples into fairseq package 
so package_data accepts them + fairseq_examples = os.path.join("fairseq", "examples") + if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples): + os.symlink(os.path.join("..", "examples"), fairseq_examples) + + package_data = { + "fairseq": ( + get_files(fairseq_examples) + + get_files(os.path.join("fairseq", "config")) + ) + } + do_setup(package_data) + finally: + if "build_ext" not in sys.argv[1:] and os.path.islink(fairseq_examples): + os.unlink(fairseq_examples) diff --git a/fairseq/train.py b/fairseq/train.py new file mode 100644 index 0000000000000000000000000000000000000000..321de3d9b53f8194b58c26f5cb2c03281afc2bb1 --- /dev/null +++ b/fairseq/train.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. +""" + +from fairseq_cli.train import cli_main + + +if __name__ == "__main__": + cli_main() diff --git a/model_hf.py b/model_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5f6d2aa58c84a9be7925675a67d488d791ead4 --- /dev/null +++ b/model_hf.py @@ -0,0 +1,609 @@ +import random +from typing import Union +from transformers import PreTrainedModel +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import fairseq +from config import SSLConfig + +___author__ = "Hemlata Tak" +__email__ = "tak@eurecom.fr" + +############################ +## FOR fine-tuned SSL MODEL +############################ + + +class SSLModel(nn.Module): + def __init__(self,device): + super(SSLModel, self).__init__() + + cp_path = '/export/fs06/agarg22/xlsr_chkpt/xlsr2_300m.pt' # Change the pre-trained XLSR model path. + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) + self.model = model[0] + self.model_device=device + self.out_dim = 1024 + return + + def extract_feat(self, input_data): + + # put the model to GPU if it not there + if next(self.model.parameters()).device != input_data.device \ + or next(self.model.parameters()).dtype != input_data.dtype: + self.model.to(input_data.device, dtype=input_data.dtype) + self.model.train() + + + if True: + # input should be in shape (batch, length) + if input_data.ndim == 3: + input_tmp = input_data[:, :, 0] + else: + input_tmp = input_data + + # [batch, length, dim] + emb = self.model(input_tmp, mask=False, features_only=True)['x'] + return emb + + +#---------AASIST back-end------------------------# +''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans. + AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks. + In Proc. ICASSP 2022, pp: 6367--6371.''' + + +class GraphAttentionLayer(nn.Module): + def __init__(self, in_dim, out_dim, **kwargs): + super().__init__() + + # attention map + self.att_proj = nn.Linear(in_dim, out_dim) + self.att_weight = self._init_new_params(out_dim, 1) + + # project + self.proj_with_att = nn.Linear(in_dim, out_dim) + self.proj_without_att = nn.Linear(in_dim, out_dim) + + # batch norm + self.bn = nn.BatchNorm1d(out_dim) + + # dropout for inputs + self.input_drop = nn.Dropout(p=0.2) + + # activate + self.act = nn.SELU(inplace=True) + + # temperature + self.temp = 1. 
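+        # softmax temperature, 1.0 unless overridden via the `temperature`
+        # kwarg (the AASIST back-end below passes values from `temperatures`).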
+ if "temperature" in kwargs: + self.temp = kwargs["temperature"] + + def forward(self, x): + ''' + x :(#bs, #node, #dim) + ''' + # apply input dropout + x = self.input_drop(x) + + # derive attention map + att_map = self._derive_att_map(x) + + # projection + x = self._project(x, att_map) + + # apply batch norm + x = self._apply_BN(x) + x = self.act(x) + return x + + def _pairwise_mul_nodes(self, x): + ''' + Calculates pairwise multiplication of nodes. + - for attention map + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, #dim) + ''' + + nb_nodes = x.size(1) + x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) + x_mirror = x.transpose(1, 2) + + return x * x_mirror + + def _derive_att_map(self, x): + ''' + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, 1) + ''' + att_map = self._pairwise_mul_nodes(x) + # size: (#bs, #node, #node, #dim_out) + att_map = torch.tanh(self.att_proj(att_map)) + # size: (#bs, #node, #node, 1) + att_map = torch.matmul(att_map, self.att_weight) + + # apply temperature + att_map = att_map / self.temp + + att_map = F.softmax(att_map, dim=-2) + + return att_map + + def _project(self, x, att_map): + x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) + x2 = self.proj_without_att(x) + + return x1 + x2 + + def _apply_BN(self, x): + org_size = x.size() + x = x.view(-1, org_size[-1]) + x = self.bn(x) + x = x.view(org_size) + + return x + + def _init_new_params(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + +class HtrgGraphAttentionLayer(nn.Module): + def __init__(self, in_dim, out_dim, **kwargs): + super().__init__() + + self.proj_type1 = nn.Linear(in_dim, in_dim) + self.proj_type2 = nn.Linear(in_dim, in_dim) + + # attention map + self.att_proj = nn.Linear(in_dim, out_dim) + self.att_projM = nn.Linear(in_dim, out_dim) + + self.att_weight11 = self._init_new_params(out_dim, 1) + self.att_weight22 = self._init_new_params(out_dim, 1) + self.att_weight12 = self._init_new_params(out_dim, 1) + self.att_weightM = self._init_new_params(out_dim, 1) + + # project + self.proj_with_att = nn.Linear(in_dim, out_dim) + self.proj_without_att = nn.Linear(in_dim, out_dim) + + self.proj_with_attM = nn.Linear(in_dim, out_dim) + self.proj_without_attM = nn.Linear(in_dim, out_dim) + + # batch norm + self.bn = nn.BatchNorm1d(out_dim) + + # dropout for inputs + self.input_drop = nn.Dropout(p=0.2) + + # activate + self.act = nn.SELU(inplace=True) + + # temperature + self.temp = 1. 
+ if "temperature" in kwargs: + self.temp = kwargs["temperature"] + + def forward(self, x1, x2, master=None): + ''' + x1 :(#bs, #node, #dim) + x2 :(#bs, #node, #dim) + ''' + #print('x1',x1.shape) + #print('x2',x2.shape) + num_type1 = x1.size(1) + num_type2 = x2.size(1) + #print('num_type1',num_type1) + #print('num_type2',num_type2) + x1 = self.proj_type1(x1) + #print('proj_type1',x1.shape) + x2 = self.proj_type2(x2) + #print('proj_type2',x2.shape) + x = torch.cat([x1, x2], dim=1) + #print('Concat x1 and x2',x.shape) + + if master is None: + master = torch.mean(x, dim=1, keepdim=True) + #print('master',master.shape) + # apply input dropout + x = self.input_drop(x) + + # derive attention map + att_map = self._derive_att_map(x, num_type1, num_type2) + #print('master',master.shape) + # directional edge for master node + master = self._update_master(x, master) + #print('master',master.shape) + # projection + x = self._project(x, att_map) + #print('proj x',x.shape) + # apply batch norm + x = self._apply_BN(x) + x = self.act(x) + + x1 = x.narrow(1, 0, num_type1) + #print('x1',x1.shape) + x2 = x.narrow(1, num_type1, num_type2) + #print('x2',x2.shape) + return x1, x2, master + + def _update_master(self, x, master): + + att_map = self._derive_att_map_master(x, master) + master = self._project_master(x, master, att_map) + + return master + + def _pairwise_mul_nodes(self, x): + ''' + Calculates pairwise multiplication of nodes. + - for attention map + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, #dim) + ''' + + nb_nodes = x.size(1) + x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) + x_mirror = x.transpose(1, 2) + + return x * x_mirror + + def _derive_att_map_master(self, x, master): + ''' + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, 1) + ''' + att_map = x * master + att_map = torch.tanh(self.att_projM(att_map)) + + att_map = torch.matmul(att_map, self.att_weightM) + + # apply temperature + att_map = att_map / self.temp + + att_map = F.softmax(att_map, dim=-2) + + return att_map + + def _derive_att_map(self, x, num_type1, num_type2): + ''' + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, 1) + ''' + att_map = self._pairwise_mul_nodes(x) + # size: (#bs, #node, #node, #dim_out) + att_map = torch.tanh(self.att_proj(att_map)) + # size: (#bs, #node, #node, 1) + + att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1) + + att_board[:, :num_type1, :num_type1, :] = torch.matmul( + att_map[:, :num_type1, :num_type1, :], self.att_weight11) + att_board[:, num_type1:, num_type1:, :] = torch.matmul( + att_map[:, num_type1:, num_type1:, :], self.att_weight22) + att_board[:, :num_type1, num_type1:, :] = torch.matmul( + att_map[:, :num_type1, num_type1:, :], self.att_weight12) + att_board[:, num_type1:, :num_type1, :] = torch.matmul( + att_map[:, num_type1:, :num_type1, :], self.att_weight12) + + att_map = att_board + + + + # apply temperature + att_map = att_map / self.temp + + att_map = F.softmax(att_map, dim=-2) + + return att_map + + def _project(self, x, att_map): + x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) + x2 = self.proj_without_att(x) + + return x1 + x2 + + def _project_master(self, x, master, att_map): + + x1 = self.proj_with_attM(torch.matmul( + att_map.squeeze(-1).unsqueeze(1), x)) + x2 = self.proj_without_attM(master) + + return x1 + x2 + + def _apply_BN(self, x): + org_size = x.size() + x = x.view(-1, org_size[-1]) + x = self.bn(x) + x = x.view(org_size) + + return x + + def _init_new_params(self, *size): + out = 
nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + +class GraphPool(nn.Module): + def __init__(self, k: float, in_dim: int, p: Union[float, int]): + super().__init__() + self.k = k + self.sigmoid = nn.Sigmoid() + self.proj = nn.Linear(in_dim, 1) + self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity() + self.in_dim = in_dim + + def forward(self, h): + Z = self.drop(h) + weights = self.proj(Z) + scores = self.sigmoid(weights) + new_h = self.top_k_graph(scores, h, self.k) + + return new_h + + def top_k_graph(self, scores, h, k): + """ + args + ===== + scores: attention-based weights (#bs, #node, 1) + h: graph data (#bs, #node, #dim) + k: ratio of remaining nodes, (float) + returns + ===== + h: graph pool applied data (#bs, #node', #dim) + """ + _, n_nodes, n_feat = h.size() + n_nodes = max(int(n_nodes * k), 1) + _, idx = torch.topk(scores, n_nodes, dim=1) + idx = idx.expand(-1, -1, n_feat) + + h = h * scores + h = torch.gather(h, 1, idx) + + return h + + + + +class Residual_block(nn.Module): + def __init__(self, nb_filts, first=False): + super().__init__() + self.first = first + + if not self.first: + self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0]) + self.conv1 = nn.Conv2d(in_channels=nb_filts[0], + out_channels=nb_filts[1], + kernel_size=(2, 3), + padding=(1, 1), + stride=1) + self.selu = nn.SELU(inplace=True) + + self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1]) + self.conv2 = nn.Conv2d(in_channels=nb_filts[1], + out_channels=nb_filts[1], + kernel_size=(2, 3), + padding=(0, 1), + stride=1) + + if nb_filts[0] != nb_filts[1]: + self.downsample = True + self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0], + out_channels=nb_filts[1], + padding=(0, 1), + kernel_size=(1, 3), + stride=1) + + else: + self.downsample = False + + + def forward(self, x): + identity = x + if not self.first: + out = self.bn1(x) + out = self.selu(out) + else: + out = x + + #print('out',out.shape) + out = self.conv1(x) + + #print('aft conv1 out',out.shape) + out = self.bn2(out) + out = self.selu(out) + # print('out',out.shape) + out = self.conv2(out) + #print('conv2 out',out.shape) + + if self.downsample: + identity = self.conv_downsample(identity) + + out += identity + #out = self.mp(out) + return out + + +class Model(PreTrainedModel,nn.Module): + config_class = SSLConfig + def __init__(self,device,config): + super().__init__(config) + # self.model_device = device + # print("Attributes and methods in PreTrainedModel:", dir(PreTrainedModel)) + # if hasattr(PreTrainedModel, "device") and isinstance(getattr(PreTrainedModel, "device"), property): + # print("device is a property in PreTrainedModel") + # else: + # print("device is NOT a property in PreTrainedModel") + + + # print(device) + # self.model_device ='cuda' if torch.cuda.is_available() else 'cpu' + # AASIST parameters + filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]] + gat_dims = [64, 32] + pool_ratios = [0.5, 0.5, 0.5, 0.5] + temperatures = [2.0, 2.0, 100.0, 100.0] + self.model_device = device + + + #### + # create network wav2vec 2.0 + #### + self.ssl_model = SSLModel(self.model_device) + self.LL = nn.Linear(self.ssl_model.out_dim, 128) + + self.first_bn = nn.BatchNorm2d(num_features=1) + self.first_bn1 = nn.BatchNorm2d(num_features=64) + self.drop = nn.Dropout(0.5, inplace=True) + self.drop_way = nn.Dropout(0.2, inplace=True) + self.selu = nn.SELU(inplace=True) + + # RawNet2 encoder + self.encoder = nn.Sequential( + nn.Sequential(Residual_block(nb_filts=filts[1], first=True)), + 
nn.Sequential(Residual_block(nb_filts=filts[2])), + nn.Sequential(Residual_block(nb_filts=filts[3])), + nn.Sequential(Residual_block(nb_filts=filts[4])), + nn.Sequential(Residual_block(nb_filts=filts[4])), + nn.Sequential(Residual_block(nb_filts=filts[4]))) + + self.attention = nn.Sequential( + nn.Conv2d(64, 128, kernel_size=(1,1)), + nn.SELU(inplace=True), + nn.BatchNorm2d(128), + nn.Conv2d(128, 64, kernel_size=(1,1)), + + ) + # position encoding + self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1])) + + self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) + self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) + + # Graph module + self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], + gat_dims[0], + temperature=temperatures[0]) + self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], + gat_dims[0], + temperature=temperatures[1]) + # HS-GAL layer + self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer( + gat_dims[0], gat_dims[1], temperature=temperatures[2]) + self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer( + gat_dims[1], gat_dims[1], temperature=temperatures[2]) + self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer( + gat_dims[0], gat_dims[1], temperature=temperatures[2]) + self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer( + gat_dims[1], gat_dims[1], temperature=temperatures[2]) + + # Graph pooling layers + self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3) + self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3) + self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + + self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + + self.out_layer = nn.Linear(5 * gat_dims[1], 2) + + def forward(self, x): + #-------pre-trained Wav2vec model fine tunning ------------------------## + x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1)) + x = self.LL(x_ssl_feat) #(bs,frame_number,feat_out_dim) + + # post-processing on front-end features + x = x.transpose(1, 2) #(bs,feat_out_dim,frame_number) + x = x.unsqueeze(dim=1) # add channel + x = F.max_pool2d(x, (3, 3)) + x = self.first_bn(x) + x = self.selu(x) + + # RawNet2-based encoder + x = self.encoder(x) + x = self.first_bn1(x) + x = self.selu(x) + + w = self.attention(x) + + #------------SA for spectral feature-------------# + w1 = F.softmax(w,dim=-1) + m = torch.sum(x * w1, dim=-1) + e_S = m.transpose(1, 2) + self.pos_S + + # graph module layer + gat_S = self.GAT_layer_S(e_S) + out_S = self.pool_S(gat_S) # (#bs, #node, #dim) + + #------------SA for temporal feature-------------# + w2 = F.softmax(w,dim=-2) + m1 = torch.sum(x * w2, dim=-2) + + e_T = m1.transpose(1, 2) + + # graph module layer + gat_T = self.GAT_layer_T(e_T) + out_T = self.pool_T(gat_T) + + # learnable master node + master1 = self.master1.expand(x.size(0), -1, -1) + master2 = self.master2.expand(x.size(0), -1, -1) + + # inference 1 + out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11( + out_T, out_S, master=self.master1) + + out_S1 = self.pool_hS1(out_S1) + out_T1 = self.pool_hT1(out_T1) + + out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12( + out_T1, out_S1, master=master1) + out_T1 = out_T1 + out_T_aug + out_S1 = out_S1 + out_S_aug + master1 = master1 + master_aug + + # inference 2 + out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21( + out_T, out_S, master=self.master2) + out_S2 = self.pool_hS2(out_S2) + out_T2 = self.pool_hT2(out_T2) + + out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22( + out_T2, 
out_S2, master=master2) + out_T2 = out_T2 + out_T_aug + out_S2 = out_S2 + out_S_aug + master2 = master2 + master_aug + + out_T1 = self.drop_way(out_T1) + out_T2 = self.drop_way(out_T2) + out_S1 = self.drop_way(out_S1) + out_S2 = self.drop_way(out_S2) + master1 = self.drop_way(master1) + master2 = self.drop_way(master2) + + out_T = torch.max(out_T1, out_T2) + out_S = torch.max(out_S1, out_S2) + master = torch.max(master1, master2) + + # Readout operation + T_max, _ = torch.max(torch.abs(out_T), dim=1) + T_avg = torch.mean(out_T, dim=1) + + S_max, _ = torch.max(torch.abs(out_S), dim=1) + S_avg = torch.mean(out_S, dim=1) + + # Sept 18: Features + last_hidden = torch.cat( + [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1) + + last_hidden = self.drop(last_hidden) + # Sept 18: classifier + output = self.out_layer(last_hidden) + + return output
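+
+
+# Minimal usage sketch (not part of the training/evaluation scripts): it
+# assumes a CUDA device, the hard-coded XLSR checkpoint path in `SSLModel`,
+# and a batch of raw 16 kHz waveforms shaped (batch, num_samples).
+#
+#     config = SSLConfig()
+#     model = Model(device="cuda", config=config).to("cuda")
+#     waveform = torch.randn(2, 64600, device="cuda")  # ~4 s at 16 kHz
+#     logits = model(waveform)                         # (batch, 2) scores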