Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- config.py +18 -0
- fairseq/.github/CODEOWNERS +21 -0
- fairseq/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
- fairseq/.github/stale.yml +30 -0
- fairseq/.github/workflows/build.yml +81 -0
- fairseq/.github/workflows/depreview.yml +14 -0
- fairseq/.github/workflows/release.yml +161 -0
- fairseq/.gitignore +141 -0
- fairseq/.gitmodules +4 -0
- fairseq/.pre-commit-config.yaml +40 -0
- fairseq/CODE_OF_CONDUCT.md +77 -0
- fairseq/CONTRIBUTING.md +82 -0
- fairseq/LICENSE +21 -0
- fairseq/MANIFEST.in +1 -0
- fairseq/README.md +242 -0
- fairseq/RELEASE.md +13 -0
- fairseq/docs/Makefile +20 -0
- fairseq/docs/command_line_tools.rst +85 -0
- fairseq/docs/conf.py +98 -0
- fairseq/docs/criterions.rst +31 -0
- fairseq/docs/data.rst +58 -0
- fairseq/docs/docutils.conf +2 -0
- fairseq/docs/fairseq_logo.png +0 -0
- fairseq/docs/getting_started.rst +216 -0
- fairseq/docs/hydra_integration.md +284 -0
- fairseq/docs/index.rst +49 -0
- fairseq/docs/lr_scheduler.rst +34 -0
- fairseq/docs/make.bat +36 -0
- fairseq/docs/modules.rst +9 -0
- fairseq/docs/optim.rst +38 -0
- fairseq/docs/overview.rst +74 -0
- fairseq/docs/tasks.rst +61 -0
- fairseq/docs/tutorial_simple_lstm.rst +518 -0
- fairseq/examples/.gitignore +2 -0
- fairseq/examples/MMPT/CONFIG.md +41 -0
- fairseq/examples/MMPT/DATASET.md +34 -0
- fairseq/examples/MMPT/locallaunch.py +148 -0
- fairseq/examples/MMPT/mmpt/__init__.py +12 -0
- fairseq/examples/MMPT/mmpt/datasets/__init__.py +10 -0
- fairseq/examples/MMPT/mmpt/evaluators/metric.py +313 -0
- fairseq/examples/MMPT/mmpt/evaluators/predictor.py +595 -0
- fairseq/examples/MMPT/mmpt/losses/__init__.py +16 -0
- fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py +63 -0
- fairseq/examples/MMPT/mmpt/losses/loss.py +87 -0
- fairseq/examples/MMPT/mmpt/losses/nce.py +156 -0
config.py
ADDED
@@ -0,0 +1,18 @@
+from transformers import PretrainedConfig
+
+class SSLConfig(PretrainedConfig):
+    model_type = "ssl-aasist"
+    def __init__(
+        self,
+        filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]],
+        gat_dims = [64, 32],
+        pool_ratios = [0.5, 0.5, 0.5, 0.5],
+        temperatures = [2.0, 2.0, 100.0, 100.0],
+        **kwargs,
+    ):
+
+        self.filts = filts
+        self.gat_dims = gat_dims
+        self.pool_ratios = pool_ratios
+        self.temperatures = temperatures
+        super().__init__(**kwargs)
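
A quick way to sanity-check the new `SSLConfig` is the standard `transformers` config round-trip. The snippet below is a minimal sketch, not part of this commit: it assumes `config.py` is importable from the working directory, and the `AutoConfig.register` call and the output directory name are illustrative.

```python
from transformers import AutoConfig
from config import SSLConfig  # the file added in this commit (assumed importable)

# Build a config with the default AASIST-style hyperparameters.
cfg = SSLConfig()
print(cfg.model_type)  # "ssl-aasist"
print(cfg.gat_dims)    # [64, 32]

# Round-trip through the standard PretrainedConfig serialization.
cfg.save_pretrained("ssl-aasist-config")                      # writes config.json
reloaded = SSLConfig.from_pretrained("ssl-aasist-config")
assert reloaded.pool_ratios == cfg.pool_ratios

# Optionally register the type so AutoConfig can resolve "ssl-aasist" by name.
AutoConfig.register("ssl-aasist", SSLConfig)
```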
fairseq/.github/CODEOWNERS
ADDED
@@ -0,0 +1,21 @@
+# Setting up CODEOWNERS for UST related codebase
+# Documentation for open sourced models relevant to UST
+examples/speech_to_text @kahne @sravyapopuri388 @jmp84
+examples/speech_to_speech @an918tw @sravyapopuri388 @jmp84
+examples/speech_synthesis @kahne @jmp84
+examples/simultaneous_translation @kahne @jmp84
+examples/speech_text_joint_to_text @yuntang @jmp84
+
+# Speech related models relevant to UST
+fairseq/models/speech_to_speech @sravyapopuri388 @jmp84
+fairseq/models/speech_to_text @kahne @sravyapopuri388 @jmp84
+fairseq/models/text_to_speech @kahne @jmp84
+
+# CONFORMER IMPLEMENTATION
+fairseq/modules/conformer_layer.py @sravyapopuri388 @jmp84
+fairseq/modules/espnet_multihead_attention.py @sravyapopuri388 @jmp84
+fairseq/modules/rotary_positional_embedding.py @sravyapopuri388 @jmp84
+fairseq/modules/positional_encoding.py @sravyapopuri388 @jmp84
+
+# Machine Translation/NLLB
+fairseq/tasks/translation.py @gwenzek
fairseq/.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,3 @@
+## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
+
+Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,43 @@
+---
+name: 🐛 Bug Report
+about: Submit a bug report to help us improve
+labels: 'bug, needs triage'
+---
+
+## 🐛 Bug
+
+<!-- A clear and concise description of what the bug is. -->
+
+### To Reproduce
+
+Steps to reproduce the behavior (**always include the command you ran**):
+
+1. Run cmd '....'
+2. See error
+
+<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
+
+
+#### Code sample
+<!-- Ideally attach a minimal code sample to reproduce the described issue.
+Minimal means having the shortest code but still preserving the bug. -->
+
+### Expected behavior
+
+<!-- A clear and concise description of what you expected to happen. -->
+
+### Environment
+
+ - fairseq Version (e.g., 1.0 or main):
+ - PyTorch Version (e.g., 1.0)
+ - OS (e.g., Linux):
+ - How you installed fairseq (`pip`, source):
+ - Build command you used (if compiling from source):
+ - Python version:
+ - CUDA/cuDNN version:
+ - GPU models and configuration:
+ - Any other relevant information:
+
+### Additional context
+
+<!-- Add any other context about the problem here. -->
fairseq/.github/ISSUE_TEMPLATE/documentation.md
ADDED
@@ -0,0 +1,15 @@
+---
+name: 📚 Documentation/Typos
+about: Report an issue related to documentation or a typo
+labels: 'documentation, needs triage'
+---
+
+## 📚 Documentation
+
+For typos and doc fixes, please go ahead and:
+
+1. Create an issue.
+2. Fix the typo.
+3. Submit a PR.
+
+Thanks!
fairseq/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,24 @@
+---
+name: 🚀 Feature Request
+about: Submit a proposal/request for a new feature
+labels: 'enhancement, help wanted, needs triage'
+---
+
+## 🚀 Feature Request
+<!-- A clear and concise description of the feature proposal -->
+
+### Motivation
+
+<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
+
+### Pitch
+
+<!-- A clear and concise description of what you want to happen. -->
+
+### Alternatives
+
+<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
+
+### Additional context
+
+<!-- Add any other context or screenshots about the feature request here. -->
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md
ADDED
@@ -0,0 +1,33 @@
+---
+name: ❓ Questions/Help
+about: If you have questions, please first search existing issues and docs
+labels: 'question, needs triage'
+---
+
+## ❓ Questions and Help
+
+### Before asking:
+1. search the issues.
+2. search the docs.
+
+<!-- If you still can't find what you need: -->
+
+#### What is your question?
+
+#### Code
+
+<!-- Please paste a code snippet if your question requires it! -->
+
+#### What have you tried?
+
+#### What's your environment?
+
+ - fairseq Version (e.g., 1.0 or main):
+ - PyTorch Version (e.g., 1.0)
+ - OS (e.g., Linux):
+ - How you installed fairseq (`pip`, source):
+ - Build command you used (if compiling from source):
+ - Python version:
+ - CUDA/cuDNN version:
+ - GPU models and configuration:
+ - Any other relevant information:
fairseq/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,16 @@
+# Before submitting
+
+- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
+- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
+- [ ] Did you make sure to update the docs?
+- [ ] Did you write any new necessary tests?
+
+## What does this PR do?
+Fixes # (issue).
+
+## PR review
+Anyone in the community is free to review the PR once the tests have passed.
+If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
+
+## Did you have fun?
+Make sure you had fun coding 🙃
fairseq/.github/stale.yml
ADDED
@@ -0,0 +1,30 @@
+# Configuration for probot-stale - https://github.com/probot/stale
+# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 90
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+  - bug
+# Label to use when marking an issue as stale
+staleLabel: stale
+issues:
+  # Comment to post when marking an issue as stale.
+  markComment: >
+    This issue has been automatically marked as stale.
+    **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
+  # Comment to post when closing a stale issue.
+  closeComment: >
+    Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
+pulls:
+  # Comment to post when marking a pull request as stale.
+  markComment: >
+    This pull request has been automatically marked as stale.
+    **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
+  # Comment to post when closing a stale pull request.
+  closeComment: >
+    Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
+
fairseq/.github/workflows/build.yml
ADDED
@@ -0,0 +1,81 @@
+name: build
+
+on:
+  # Trigger the workflow on push to main or any pull request
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  build:
+
+    strategy:
+      max-parallel: 4
+      matrix:
+        platform: [ubuntu-latest, macos-latest]
+        python-version: [3.8, 3.9]
+
+    runs-on: ${{ matrix.platform }}
+
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Conditionally install pytorch
+        if: matrix.platform == 'windows-latest'
+        run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
+
+      - name: Install locally
+        run: |
+          python -m pip install --upgrade pip
+          git submodule update --init --recursive
+          python -m pip install .
+
+      - name: Check installation
+        working-directory: /tmp
+        run: python $GITHUB_WORKSPACE/scripts/check_installation.py
+
+      - name: Install optional test requirements
+        run: |
+          python -m pip install '.[dev,docs]'
+          python -m pip install iopath transformers pyarrow
+          python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
+          python -m pip install pygit2 pgzip
+
+      - name: Install xformers for Macos
+        if: matrix.platform == 'macos-latest'
+        run: |
+          brew install llvm libomp
+          CC=/usr/local/opt/llvm/bin/clang CXX=clang++ pip install git+https://github.com/facebookresearch/xformers.git@main
+
+      - name: Install xformers for non-MacOS
+        if: matrix.platform != 'macos-latest'
+        run: |
+          python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main
+
+      - name: Lint with black
+        run: black --check --diff .
+
+      - name: Lint with flake8
+        run: |
+          # stop the build if there are Python syntax errors or undefined names
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+
+      - name: Build doc
+        run: make singlehtml
+        working-directory: docs/
+
+      - name: Run tests
+        # When installing in non-editable mode, the .so files will be generated in 'site-packages/fairseq'.
+        # But by default, pytest import machinery will load local fairseq, and won't see the .so.
+        # Use --import-mode=append to favor the 'site-packages/fairseq'.
+        # https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html
+        run: pytest --import-mode=append -vvv tests/
+
fairseq/.github/workflows/depreview.yml
ADDED
@@ -0,0 +1,14 @@
+name: 'Dependency Review'
+on: [pull_request]
+
+permissions:
+  contents: read
+
+jobs:
+  dependency-review:
+    runs-on: ubuntu-latest
+    steps:
+      - name: 'Checkout Repository'
+        uses: actions/checkout@v4
+      - name: Dependency Review
+        uses: actions/dependency-review-action@v4
fairseq/.github/workflows/release.yml
ADDED
@@ -0,0 +1,161 @@
+name: Fairseq Release
+
+on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'Release Type'
+        default: 'patch'
+        required: true
+
+jobs:
+
+  get_next_version:
+    runs-on: ubuntu-latest
+    steps:
+      - name: checkout-repo-content
+        uses: actions/checkout@v2
+
+      - name: setup-python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+
+      - name: get next version and tag
+        id: get-next-version-and-tag
+        run: |
+          output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }})
+          echo $output
+          new_version=$(echo $output | awk '{print $1}')
+          new_tag=$(echo $output | awk '{print $2}')
+          echo "new version is $new_version"
+          echo "new tag is $new_tag"
+          echo ::set-output name=version::$new_version
+          echo ::set-output name=tag::$new_tag
+          echo ::set-output name=branch_name::$new_version-release
+          echo "NEW_TAG=$new_tag" >> $GITHUB_ENV
+          echo "NEW_BRANCH=$new_version-release" >> $GITHUB_ENV
+
+
+      # update the version number in version.txt
+      - name: update version
+        id: update-version
+        run: |
+          echo "current folder = $PWD"
+          echo "current branch = $(git branch --show-current)"
+          output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }} --update-version)
+
+      - name: add and commit
+        uses: EndBug/add-and-commit@v9
+        with:
+          author_name: ${{ secrets.AUTHOR_NAME }}
+          author_email: ${{ secrets.AUTHOR_EMAIL }}
+
+          # TODO: change this to main once shipit is disabled.
+          new_branch: '${{ env.NEW_BRANCH }}'
+          default_author: github_actor
+          message: '${{ env.NEW_TAG }} release'
+          pathspec_error_handling: exitAtEnd
+
+          # Arguments for the git pull command. Use NO-PULL to avoid the action pulling at all.
+          # pull: 'NO-PULL'
+          tag: '${{ env.NEW_TAG }}'
+
+    outputs:
+      new_version: ${{ steps.get-next-version-and-tag.outputs.version }}
+      new_tag: ${{ steps.get-next-version-and-tag.outputs.tag }}
+      branch_name: ${{ steps.get-next-version-and-tag.outputs.branch_name }}
+
+  create_sdist:
+    runs-on: ubuntu-latest
+    name: Create Source Distribution
+    needs: get_next_version
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          ref: ${{ needs.get_next_version.outputs.branch_name }}
+
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+
+      - name: Upgrade pip
+        run: |
+          python3 -m pip install --upgrade pip
+
+      - name: Create Source Distribution
+        run: |
+          python3 -m pip install setuptools wheel twine torch
+          python3 setup.py sdist
+
+      - uses: actions/upload-artifact@v2
+        with:
+          path: dist/*.tar.gz
+
+  build_wheels:
+    name: Build wheels on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
+    needs: get_next_version
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          ref: ${{ needs.get_next_version.outputs.branch_name }}
+
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+
+      - name: Upgrade pip
+        run: |
+          python3 -m pip install --upgrade pip
+
+      - name: Install cibuildwheel
+        run: |
+          python3 -m pip install cibuildwheel
+
+      - name: Build wheels for CPython
+        run: |
+          python3 -m cibuildwheel --output-dir dist
+        env:
+          CIBW_BUILD: "cp38-*64"
+          CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
+          CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
+          # Install system library
+          CIBW_BEFORE_BUILD_LINUX: (yum install -y libffi-devel || apt-get install -y libffi-devel || apk add --update --no-cache libffi-devel || true) && (yum install -y libc6 || apt-get install -y libc6 || apk add --update --no-cache libc6 || true)
+          CIBW_ENVIRONMENT: "PIP_ONLY_BINARY=numpy"
+          CIBW_SKIP: "*musllinux*"
+
+      - uses: actions/upload-artifact@v2
+        with:
+          path: dist
+
+  upload:
+    name: Upload to PyPi and create release
+    runs-on: ubuntu-latest
+    needs: [build_wheels, create_sdist, get_next_version]
+    steps:
+      - uses: actions/download-artifact@v2
+        with:
+          name: artifact
+          path: dist
+
+      # build the PyPI package and upload it
+      - name: upload
+        env:
+          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+        run: |
+          pip install setuptools wheel twine
+          python3 -m twine upload --repository pypi dist/*
+
+      # create the release on github
+      - name: create release on github
+        uses: ncipollo/release-action@v1
+        with:
+          tag: '${{ needs.get_next_version.outputs.new_tag }}'
fairseq/.gitignore
ADDED
@@ -0,0 +1,141 @@
+# JetBrains PyCharm IDE
+.idea/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# macOS dir files
+.DS_Store
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Checkpoints
+checkpoints
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# Generated files
+/fairseq/temporal_convolution_tbc
+/fairseq/modules/*_layer/*_forward.cu
+/fairseq/modules/*_layer/*_backward.cu
+/fairseq/version.py
+
+# data
+data-bin/
+
+# reranking
+/examples/reranking/rerank_data
+
+# Cython-generated C++ source files
+/fairseq/data/data_utils_fast.cpp
+/fairseq/data/token_block_utils_fast.cpp
+
+# VSCODE
+.vscode/ftp-sync.json
+.vscode/settings.json
+
+# Experimental Folder
+experimental/*
+
+# Weights and Biases logs
+wandb/
+
+# Hydra artifacts
+nohup.out
+multirun
+outputs
fairseq/.gitmodules
ADDED
@@ -0,0 +1,4 @@
+[submodule "fairseq/model_parallel/megatron"]
+	path = fairseq/model_parallel/megatron
+	url = https://github.com/ngoyal2707/Megatron-LM
+	branch = fairseq
fairseq/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,40 @@
+exclude: 'build|stubs'
+
+default_language_version:
+  python: python3
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.1.0
+    hooks:
+      - id: trailing-whitespace
+      - id: check-ast
+      - id: check-merge-conflict
+      - id: no-commit-to-branch
+        args: ['--branch=master']
+      - id: check-added-large-files
+        args: ['--maxkb=500']
+      - id: end-of-file-fixer
+
+  - repo: https://github.com/ambv/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+        language_version: python3.8
+
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: 3.9.2
+    hooks:
+      - id: flake8
+        args: [
+          # only error for syntax errors and undefined names
+          "--select=E9,F63,F7,F82",
+        ]
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.10.1
+    hooks:
+      - id: isort
+        exclude: README.md
+        additional_dependencies: [toml]
+        args: ["--profile", "black"]
fairseq/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,77 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <[email protected]>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
+
fairseq/CONTRIBUTING.md
ADDED
@@ -0,0 +1,82 @@
+# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+## License
+By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
+you agree that your contributions will be licensed under the LICENSE file in
+the root directory of this source tree.
+
+## Pre-commit hooks
+In order to ensure your code lints, there are pre-commit hooks configured in the repository which you can install.
+After installation, they will automatically run each time you commit.
+An abbreviated guide is given below; for more information, refer to [the official pre-commit documentation](https://pre-commit.com/).
+
+### Installation
+```
+pip install pre-commit
+pre-commit install
+```
+
+### Usage
+Just commit your changes:
+```
+git commit -m "My informative commit message"
+```
+
+If there was a failure, you will get feedback:
+```
+[INFO] Initializing environment for https://github.com/PyCQA/flake8.
+[INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks.
+[INFO] Once installed this environment will be reused.
+[INFO] This may take a few minutes...
+[INFO] Installing environment for https://github.com/PyCQA/flake8.
+[INFO] Once installed this environment will be reused.
+[INFO] This may take a few minutes...
+Trim Trailing Whitespace.................................................Failed
+- hook id: trailing-whitespace
+- exit code: 1
+- files were modified by this hook
+Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh
+Fix End of Files.........................................................Failed
+- hook id: end-of-file-fixer
+- exit code: 1
+- files were modified by this hook
+Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py
+flake8...................................................................Passed
+```
+
+Certain hooks modify your files to comply.
+To include these modifications, you will need to add them (i.e. `git add ...`) and commit again.
+
+If all is well, you should see something like:
+```
+Trim Trailing Whitespace.................................................Passed
+Fix End of Files.........................................................Passed
+flake8...................................................................Passed
+[gshard-fix-ci 8698644e1] Fix lint, add pre-commit hooks
+ 10 files changed, 148 insertions(+), 110 deletions(-)
+ create mode 100644 .flake8
+ create mode 100644 .pre-commit-config.yaml
+ rename examples/nllb/modeling/wmt15_benchmark/{eval_langs2.py => eval_langs2.sh} (99%)
+```
fairseq/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
fairseq/MANIFEST.in
ADDED
@@ -0,0 +1 @@
+include fairseq/version.txt
fairseq/README.md
ADDED
@@ -0,0 +1,242 @@
+<p align="center">
+  <img src="docs/fairseq_logo.png" width="150">
+  <br />
+  <br />
+  <a href="https://opensource.fb.com/support-ukraine"><img alt="Support Ukraine" src="https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB" /></a>
+  <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
+  <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
+  <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
+  <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
+  <a href="https://app.circleci.com/pipelines/github/facebookresearch/fairseq/"><img alt="CircleCI Status" src="https://circleci.com/gh/facebookresearch/fairseq.svg?style=shield" /></a>
+</p>
+
+--------------------------------------------------------------------------------
+
+Fairseq(-py) is a sequence modeling toolkit that allows researchers and
+developers to train custom models for translation, summarization, language
+modeling and other text generation tasks.
+
+We provide reference implementations of various sequence modeling papers:
+
+<details><summary>List of implemented papers</summary><p>
+
+* **Convolutional Neural Networks (CNN)**
+  + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+  + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+  + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+  + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+  + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* **LightConv and DynamicConv models**
+  + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* **Long Short-Term Memory (LSTM) networks**
+  + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+* **Transformer (self-attention) networks**
+  + Attention Is All You Need (Vaswani et al., 2017)
+  + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+  + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+  + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+  + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+  + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+  + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+  + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+  + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+  + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+  + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+  + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+  + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+  + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+  + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+  + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+  + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+  + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+  + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+  + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+  + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
+  + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027)
+  + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084)
+  + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
+  + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
+  + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+  + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
+* **Non-autoregressive Transformers**
+  + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+  + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)
+  + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)
+  + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+  + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* **Finetuning**
+  + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md)
+
+</p></details>
+
+### What's New:
+* May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages (Pratap et al., 2023)](examples/mms/README.md)
+* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu et al., 2022)](examples/wav2vec/unsupervised/README.md)
+* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
+* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
+* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
+* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
+* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
+* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
+* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
+* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
+* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
+* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
+* February 2021 [Added LASER training code](examples/laser/README.md)
+* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
+* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
+* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
+  * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
+* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
+* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
+* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
+* October 2020: [Added CRISS models and code](examples/criss/README.md)
+
+<details><summary>Previous updates</summary><p>
+
+* September 2020: [Added Linformer code](examples/linformer/README.md)
+* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+* February 2020: [mBART model and code released](examples/mbart/README.md)
+* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+* November 2019: [BART model and code released](examples/bart/README.md)
+* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+* August 2019: [WMT'19 models released](examples/wmt19/README.md)
+* July 2019: fairseq relicensed under MIT license
+* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
+* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
+
+</p></details>
+
+### Features:
+
+* multi-GPU training on one machine or across multiple machines (data and model parallel)
+* fast generation on both CPU and GPU with multiple search algorithms implemented:
+  + beam search
+  + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+  + sampling (unconstrained, top-k and top-p/nucleus)
+  + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
+* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
+* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
+* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
+* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
+* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
+* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
+
+We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
+with a convenient `torch.hub` interface:
+
+``` python
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
+en2de.translate('Hello world', beam=5)
+# 'Hallo Welt'
+```
+
+See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
+and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
+
+# Requirements and Installation
+
+* [PyTorch](http://pytorch.org/) version >= 1.10.0
+* Python version >= 3.8
+* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
+* **To install fairseq** and develop locally:
+
+``` bash
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+pip install --editable ./
+
+# on MacOS:
+# CFLAGS="-stdlib=libc++" pip install --editable ./
+
+# to install the latest stable release (0.10.x)
+# pip install fairseq
+```
+
+* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+
+``` bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
+  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
+  --global-option="--fast_multihead_attn" ./
+```
+
+* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
+* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
+as command line options to `nvidia-docker run`.
+
+# Getting Started
+
+The [full documentation](https://fairseq.readthedocs.io/) contains instructions
+for getting started, training new models and extending fairseq with new model
+types and tasks.
+
+# Pre-trained models and examples
+
+We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
+as well as example training and evaluation commands.
+
+* [Translation](examples/translation/README.md): convolutional and transformer models are available
+* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
+
+We also have more detailed READMEs to reproduce results from specific papers:
+
+* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
+* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
+* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
+* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
+
+# Join the fairseq community
+
+* Twitter: https://twitter.com/fairseq
+* Facebook page: https://www.facebook.com/groups/fairseq.users
+* Google group: https://groups.google.com/forum/#!forum/fairseq-users
+
+# License
+
+fairseq(-py) is MIT-licensed.
+The license applies to the pre-trained models as well.
+
+# Citation
+
+Please cite as:
+
+``` bibtex
+@inproceedings{ott2019fairseq,
+  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+  year = {2019},
+}
+```
fairseq/RELEASE.md
ADDED
@@ -0,0 +1,13 @@
# Creating a New Release

In order to create a new release:

1. Navigate to the [Fairseq Workflows](https://github.com/facebookresearch/fairseq/actions) and find the one named _Fairseq Release_.

2. Under _Run Workflow_ choose the branch `main` and for _Release Type_ enter either `major`, `minor`, or `patch`.

3. A branch named `$new_version-release` will be created where the `version.txt` file is updated. Merge those changes into `main`.

4. Make sure that a [new PyPI package](https://pypi.org/project/fairseq/) has been uploaded.

5. Make sure that a [new GitHub release](https://github.com/facebookresearch/fairseq/releases) has been created.
fairseq/docs/Makefile
ADDED
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = python -msphinx
SPHINXPROJ  = fairseq
SOURCEDIR   = .
BUILDDIR    = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
fairseq/docs/command_line_tools.rst
ADDED
@@ -0,0 +1,85 @@
1 |
+
.. _Command-line Tools:
|
2 |
+
|
3 |
+
Command-line Tools
|
4 |
+
==================
|
5 |
+
|
6 |
+
Fairseq provides several command-line tools for training and evaluating models:
|
7 |
+
|
8 |
+
- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
|
9 |
+
- :ref:`fairseq-train`: Train a new model on one or multiple GPUs
|
10 |
+
- :ref:`fairseq-generate`: Translate pre-processed data with a trained model
|
11 |
+
- :ref:`fairseq-interactive`: Translate raw text with a trained model
|
12 |
+
- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
|
13 |
+
- :ref:`fairseq-eval-lm`: Language model evaluation
|
14 |
+
|
15 |
+
|
16 |
+
.. _fairseq-preprocess:
|
17 |
+
|
18 |
+
fairseq-preprocess
|
19 |
+
~~~~~~~~~~~~~~~~~~
|
20 |
+
.. automodule:: fairseq_cli.preprocess
|
21 |
+
|
22 |
+
.. argparse::
|
23 |
+
:module: fairseq.options
|
24 |
+
:func: get_preprocessing_parser
|
25 |
+
:prog: fairseq-preprocess
|
26 |
+
|
27 |
+
|
28 |
+
.. _fairseq-train:
|
29 |
+
|
30 |
+
fairseq-train
|
31 |
+
~~~~~~~~~~~~~
|
32 |
+
.. automodule:: fairseq_cli.train
|
33 |
+
|
34 |
+
.. argparse::
|
35 |
+
:module: fairseq.options
|
36 |
+
:func: get_training_parser
|
37 |
+
:prog: fairseq-train
|
38 |
+
|
39 |
+
|
40 |
+
.. _fairseq-generate:
|
41 |
+
|
42 |
+
fairseq-generate
|
43 |
+
~~~~~~~~~~~~~~~~
|
44 |
+
.. automodule:: fairseq_cli.generate
|
45 |
+
|
46 |
+
.. argparse::
|
47 |
+
:module: fairseq.options
|
48 |
+
:func: get_generation_parser
|
49 |
+
:prog: fairseq-generate
|
50 |
+
|
51 |
+
|
52 |
+
.. _fairseq-interactive:
|
53 |
+
|
54 |
+
fairseq-interactive
|
55 |
+
~~~~~~~~~~~~~~~~~~~
|
56 |
+
.. automodule:: fairseq_cli.interactive
|
57 |
+
|
58 |
+
.. argparse::
|
59 |
+
:module: fairseq.options
|
60 |
+
:func: get_interactive_generation_parser
|
61 |
+
:prog: fairseq-interactive
|
62 |
+
|
63 |
+
|
64 |
+
.. _fairseq-score:
|
65 |
+
|
66 |
+
fairseq-score
|
67 |
+
~~~~~~~~~~~~~
|
68 |
+
.. automodule:: fairseq_cli.score
|
69 |
+
|
70 |
+
.. argparse::
|
71 |
+
:module: fairseq_cli.score
|
72 |
+
:func: get_parser
|
73 |
+
:prog: fairseq-score
|
74 |
+
|
75 |
+
|
76 |
+
.. _fairseq-eval-lm:
|
77 |
+
|
78 |
+
fairseq-eval-lm
|
79 |
+
~~~~~~~~~~~~~~~
|
80 |
+
.. automodule:: fairseq_cli.eval_lm
|
81 |
+
|
82 |
+
.. argparse::
|
83 |
+
:module: fairseq.options
|
84 |
+
:func: get_eval_lm_parser
|
85 |
+
:prog: fairseq-eval-lm
|
fairseq/docs/conf.py
ADDED
@@ -0,0 +1,98 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
#
|
4 |
+
# fairseq documentation build configuration file, created by
|
5 |
+
# sphinx-quickstart on Fri Aug 17 21:45:30 2018.
|
6 |
+
#
|
7 |
+
# This file is execfile()d with the current directory set to its
|
8 |
+
# containing dir.
|
9 |
+
#
|
10 |
+
# Note that not all possible configuration values are present in this
|
11 |
+
# autogenerated file.
|
12 |
+
#
|
13 |
+
# All configuration values have a default; values that are commented out
|
14 |
+
# serve to show the default.
|
15 |
+
|
16 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
17 |
+
# add these directories to sys.path here. If the directory is relative to the
|
18 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
19 |
+
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
from fairseq import __version__
|
23 |
+
|
24 |
+
|
25 |
+
# source code directory, relative to this file, for sphinx-autobuild
|
26 |
+
sys.path.insert(0, os.path.abspath(".."))
|
27 |
+
|
28 |
+
source_suffix = [".rst"]
|
29 |
+
|
30 |
+
# -- General configuration ------------------------------------------------
|
31 |
+
|
32 |
+
# If your documentation needs a minimal Sphinx version, state it here.
|
33 |
+
#
|
34 |
+
# needs_sphinx = '1.0'
|
35 |
+
|
36 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
37 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
38 |
+
# ones.
|
39 |
+
extensions = [
|
40 |
+
"sphinx.ext.autodoc",
|
41 |
+
"sphinx.ext.intersphinx",
|
42 |
+
"sphinx.ext.viewcode",
|
43 |
+
"sphinx.ext.napoleon",
|
44 |
+
"sphinxarg.ext",
|
45 |
+
]
|
46 |
+
|
47 |
+
# Add any paths that contain templates here, relative to this directory.
|
48 |
+
templates_path = ["_templates"]
|
49 |
+
|
50 |
+
# The master toctree document.
|
51 |
+
master_doc = "index"
|
52 |
+
|
53 |
+
# General information about the project.
|
54 |
+
project = "fairseq"
|
55 |
+
copyright = "Facebook AI Research (FAIR)"
|
56 |
+
author = "Facebook AI Research (FAIR)"
|
57 |
+
|
58 |
+
github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
|
59 |
+
|
60 |
+
# The version info for the project you're documenting, acts as replacement for
|
61 |
+
# |version| and |release|, also used in various other places throughout the
|
62 |
+
# built documents.
|
63 |
+
#
|
64 |
+
# The short X.Y version.
|
65 |
+
version = __version__
|
66 |
+
# The full version, including alpha/beta/rc tags.
|
67 |
+
release = __version__
|
68 |
+
|
69 |
+
# The language for content autogenerated by Sphinx. Refer to documentation
|
70 |
+
# for a list of supported languages.
|
71 |
+
#
|
72 |
+
# This is also used if you do content translation via gettext catalogs.
|
73 |
+
# Usually you set "language" from the command line for these cases.
|
74 |
+
language = None
|
75 |
+
|
76 |
+
# List of patterns, relative to source directory, that match files and
|
77 |
+
# directories to ignore when looking for source files.
|
78 |
+
# This patterns also effect to html_static_path and html_extra_path
|
79 |
+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
80 |
+
|
81 |
+
# The name of the Pygments (syntax highlighting) style to use.
|
82 |
+
pygments_style = "sphinx"
|
83 |
+
highlight_language = "python"
|
84 |
+
|
85 |
+
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
86 |
+
todo_include_todos = False
|
87 |
+
|
88 |
+
|
89 |
+
# -- Options for HTML output ----------------------------------------------
|
90 |
+
|
91 |
+
html_theme = "classic"
|
92 |
+
|
93 |
+
# Example configuration for intersphinx: refer to the Python standard library.
|
94 |
+
intersphinx_mapping = {
|
95 |
+
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
|
96 |
+
"python": ("https://docs.python.org/", None),
|
97 |
+
"torch": ("https://pytorch.org/docs/master/", None),
|
98 |
+
}
|
fairseq/docs/criterions.rst
ADDED
@@ -0,0 +1,31 @@
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. _Criterions:
|
5 |
+
|
6 |
+
Criterions
|
7 |
+
==========
|
8 |
+
|
9 |
+
Criterions compute the loss function given the model and batch, roughly::
|
10 |
+
|
11 |
+
loss = criterion(model, batch)
|
12 |
+
|
13 |
+
.. automodule:: fairseq.criterions
|
14 |
+
:members:
|
15 |
+
|
16 |
+
.. autoclass:: fairseq.criterions.FairseqCriterion
|
17 |
+
:members:
|
18 |
+
:undoc-members:
|
19 |
+
|
20 |
+
.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
|
21 |
+
:members:
|
22 |
+
:undoc-members:
|
23 |
+
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
|
24 |
+
:members:
|
25 |
+
:undoc-members:
|
26 |
+
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
|
27 |
+
:members:
|
28 |
+
:undoc-members:
|
29 |
+
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
|
30 |
+
:members:
|
31 |
+
:undoc-members:
|
fairseq/docs/data.rst
ADDED
@@ -0,0 +1,58 @@
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. module:: fairseq.data
|
5 |
+
|
6 |
+
Data Loading and Utilities
|
7 |
+
==========================
|
8 |
+
|
9 |
+
.. _datasets:
|
10 |
+
|
11 |
+
Datasets
|
12 |
+
--------
|
13 |
+
|
14 |
+
**Datasets** define the data format and provide helpers for creating
|
15 |
+
mini-batches.
|
16 |
+
|
17 |
+
.. autoclass:: fairseq.data.FairseqDataset
|
18 |
+
:members:
|
19 |
+
.. autoclass:: fairseq.data.LanguagePairDataset
|
20 |
+
:members:
|
21 |
+
.. autoclass:: fairseq.data.MonolingualDataset
|
22 |
+
:members:
|
23 |
+
|
24 |
+
**Helper Datasets**
|
25 |
+
|
26 |
+
These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
|
27 |
+
provide additional functionality:
|
28 |
+
|
29 |
+
.. autoclass:: fairseq.data.BacktranslationDataset
|
30 |
+
:members:
|
31 |
+
.. autoclass:: fairseq.data.ConcatDataset
|
32 |
+
:members:
|
33 |
+
.. autoclass:: fairseq.data.ResamplingDataset
|
34 |
+
:members:
|
35 |
+
.. autoclass:: fairseq.data.RoundRobinZipDatasets
|
36 |
+
:members:
|
37 |
+
.. autoclass:: fairseq.data.TransformEosDataset
|
38 |
+
:members:
|
39 |
+
|
40 |
+
|
41 |
+
Dictionary
|
42 |
+
----------
|
43 |
+
|
44 |
+
.. autoclass:: fairseq.data.Dictionary
|
45 |
+
:members:
|
46 |
+
|
47 |
+
|
48 |
+
Iterators
|
49 |
+
---------
|
50 |
+
|
51 |
+
.. autoclass:: fairseq.data.CountingIterator
|
52 |
+
:members:
|
53 |
+
.. autoclass:: fairseq.data.EpochBatchIterator
|
54 |
+
:members:
|
55 |
+
.. autoclass:: fairseq.data.GroupedIterator
|
56 |
+
:members:
|
57 |
+
.. autoclass:: fairseq.data.ShardedIterator
|
58 |
+
:members:
|
fairseq/docs/docutils.conf
ADDED
@@ -0,0 +1,2 @@
[writers]
option-limit=0
fairseq/docs/fairseq_logo.png
ADDED
fairseq/docs/getting_started.rst
ADDED
@@ -0,0 +1,216 @@
Evaluating Pre-trained Models
=============================

First, download a pre-trained model along with its vocabularies:

.. code-block:: console

   > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -

This model uses a `Byte Pair Encoding (BPE)
vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
the encoding to the source text before it can be translated. This can be
done with the
`apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
used as a continuation marker and the original text can be easily
recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
using ``tokenizer.perl`` from
`mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
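
The same ``@@``-marker removal can also be done directly in Python. The helper
below is a small illustrative sketch (not part of fairseq) that mirrors
``sed s/@@ //g``:

.. code-block:: python

   def remove_bpe(line, bpe_symbol="@@ "):
       """Strip subword-nmt style continuation markers from one line."""
       # "mam@@ mal" -> "mammal"; the appended space handles a marker at line end.
       return (line + " ").replace(bpe_symbol, "").rstrip()

   print(remove_bpe("new marine mam@@ mal species ?"))
   # new marine mammal species ?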

Let's use :ref:`fairseq-interactive` to generate translations interactively.
Here, we use a beam size of 5 and preprocess the input with the Moses
tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
remove the BPE continuation markers and detokenize the output.

.. code-block:: console

   > MODEL_DIR=wmt14.en-fr.fconv-py
   > fairseq-interactive \
       --path $MODEL_DIR/model.pt $MODEL_DIR \
       --beam 5 --source-lang en --target-lang fr \
       --tokenizer moses \
       --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
   | loading model(s) from wmt14.en-fr.fconv-py/model.pt
   | [en] dictionary: 44206 types
   | [fr] dictionary: 44463 types
   | Type the input sentence and press return:
   Why is it rare to discover new marine mammal species?
   S-0 Why is it rare to discover new marine mam@@ mal species ?
   H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
   P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015

This generation script produces three types of outputs: a line prefixed
with *S* is a copy of the source sentence after preprocessing; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.

Other types of output lines you might see are *D*, the detokenized hypothesis,
*T*, the reference target, *A*, alignment info, and *E*, the history of generation steps.
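
When post-processing this output programmatically, a small parser keyed on the
one-letter prefixes above is usually enough. The function below is a hypothetical
helper (not shipped with fairseq), assuming the default tab-separated format:

.. code-block:: python

   def parse_generation_output(lines):
       """Group S/T/H/D/P/A lines from fairseq-generate/-interactive by sample id."""
       samples = {}
       for line in lines:
           fields = line.rstrip("\n").split("\t")
           tag = fields[0]
           if "-" not in tag or tag.split("-")[0] not in {"S", "T", "H", "D", "P", "A"}:
               continue  # skip log lines such as "| loading model(s) ..."
           kind, sample_id = tag.split("-", 1)
           samples.setdefault(sample_id, {})[kind] = fields[1:]
       return samples

   with open("gen.out") as f:
       hyps = parse_generation_output(f)
   print(hyps["0"]["H"])  # [score, hypothesis]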

See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
full list of pre-trained models available.

Training a New Model
====================

The following tutorial is for machine translation. For an example of how
to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
``examples/`` directory.

Data Pre-processing
-------------------

Fairseq contains example pre-processing scripts for several translation
datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
2014 (English-German). To pre-process and binarize the IWSLT dataset:

.. code-block:: console

   > cd examples/translation/
   > bash prepare-iwslt14.sh
   > cd ../..
   > TEXT=examples/translation/iwslt14.tokenized.de-en
   > fairseq-preprocess --source-lang de --target-lang en \
       --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
       --destdir data-bin/iwslt14.tokenized.de-en

This will write binarized data that can be used for model training to
``data-bin/iwslt14.tokenized.de-en``.

Training
--------

Use :ref:`fairseq-train` to train a new model. Here are a few example settings that work
well for the IWSLT 2014 dataset:

.. code-block:: console

   > mkdir -p checkpoints/fconv
   > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
       --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
       --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
change the number of GPU devices that will be used.

Also note that the batch size is specified in terms of the maximum
number of tokens per batch (``--max-tokens``). You may need to use a
smaller value depending on the available GPU memory on your system.

Generation
----------

Once your model is trained, you can generate translations using
:ref:`fairseq-generate` **(for binarized data)** or
:ref:`fairseq-interactive` **(for raw text)**:

.. code-block:: console

   > fairseq-generate data-bin/iwslt14.tokenized.de-en \
       --path checkpoints/fconv/checkpoint_best.pt \
       --batch-size 128 --beam 5
   | [de] dictionary: 35475 types
   | [en] dictionary: 24739 types
   | data-bin/iwslt14.tokenized.de-en test 6750 examples
   | model fconv
   | loaded checkpoint trainings/fconv/checkpoint_best.pt
   S-721 danke .
   T-721 thank you .
   ...

To generate translations with only a CPU, use the ``--cpu`` flag. BPE
continuation markers can be removed with the ``--remove-bpe`` flag.

Advanced Training Options
=========================

Large mini-batch training with delayed updates
----------------------------------------------

The ``--update-freq`` option can be used to accumulate gradients from
multiple mini-batches and delay updating, creating a larger effective
batch size. Delayed updates can also improve training speed by reducing
inter-GPU communication costs and by saving idle time caused by variance
in workload across GPUs. See `Ott et al.
(2018) <https://arxiv.org/abs/1806.00187>`__ for more details.

To train on a single GPU with an effective batch size that is equivalent
to training on 8 GPUs:

.. code-block:: console

   > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)

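Conceptually, ``--update-freq N`` accumulates gradients over ``N`` mini-batches
before each optimizer step. fairseq handles this internally (including gradient
normalization across the accumulated batches); the self-contained PyTorch sketch
below, with a toy model and random data, only illustrates the idea:

.. code-block:: python

   import torch
   import torch.nn as nn

   model = nn.Linear(16, 2)
   criterion = nn.CrossEntropyLoss()
   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
   update_freq = 8  # mirrors --update-freq 8

   optimizer.zero_grad()
   for i in range(64):
       x, y = torch.randn(4, 16), torch.randint(0, 2, (4,))
       loss = criterion(model(x), y)
       (loss / update_freq).backward()  # scale so the accumulated gradient is an average
       if (i + 1) % update_freq == 0:   # one optimizer step per 8 mini-batches
           optimizer.step()
           optimizer.zero_grad()
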
Training with half precision floating point (FP16)
--------------------------------------------------

.. note::

   FP16 training requires a Volta GPU and CUDA 9.1 or greater

Recent GPUs enable efficient half precision floating point computation,
e.g., using `Nvidia Tensor Cores
<https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
Fairseq supports FP16 training with the ``--fp16`` flag:

.. code-block:: console

   > fairseq-train --fp16 (...)

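fairseq implements FP16 with its own optimizer wrapper and dynamic loss scaling.
Purely as a point of reference, the same ingredients in plain PyTorch look like
the ``torch.cuda.amp`` sketch below; this is illustrative and is not how
fairseq's internals are written:

.. code-block:: python

   import torch
   import torch.nn as nn

   use_cuda = torch.cuda.is_available()
   device = "cuda" if use_cuda else "cpu"
   model = nn.Linear(16, 2).to(device)
   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
   scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

   for _ in range(10):
       x = torch.randn(4, 16, device=device)
       y = torch.randint(0, 2, (4,), device=device)
       optimizer.zero_grad()
       with torch.cuda.amp.autocast(enabled=use_cuda):
           loss = nn.functional.cross_entropy(model(x), y)
       scaler.scale(loss).backward()  # loss scaling guards against FP16 underflow
       scaler.step(optimizer)
       scaler.update()
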
Distributed training
--------------------

Distributed training in fairseq is implemented on top of ``torch.distributed``.
The easiest way to launch jobs is with the `torch.distributed.launch
<https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.

For example, to train a large English-German Transformer model on 2 nodes each
with 8 GPUs (in total 16 GPUs), run the following command on each node,
replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
sure to update ``--master_addr`` to the IP address of the first node:

.. code-block:: console

   > python -m torch.distributed.launch --nproc_per_node=8 \
       --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
       --master_port=12345 \
       $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
       --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
       --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
       --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
       --lr 0.0005 \
       --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
       --max-tokens 3584 \
       --max-epoch 70 \
       --fp16

On SLURM clusters, fairseq will automatically detect the number of nodes and
GPUs, but a port number must be provided:

.. code-block:: console

   > salloc --gpus=16 --nodes 2 (...)
   > srun fairseq-train --distributed-port 12345 (...)

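Both launchers ultimately hand each worker a rank, a world size and a rendezvous
address. The sketch below shows the generic ``torch.distributed`` initialization
they enable; it is illustrative only, since fairseq performs this setup for you:

.. code-block:: python

   import os
   import torch.distributed as dist

   def init_distributed():
       # torch.distributed.launch exports these variables for every worker process.
       rank = int(os.environ.get("RANK", 0))
       world_size = int(os.environ.get("WORLD_SIZE", 1))
       dist.init_process_group(
           backend="nccl",        # use "gloo" for CPU-only training
           init_method="env://",  # reads MASTER_ADDR / MASTER_PORT from the environment
           rank=rank,
           world_size=world_size,
       )
       return rank, world_size
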
Sharding very large datasets
----------------------------

It can be challenging to train over very large datasets, particularly if your
machine does not have much system RAM. Most tasks in fairseq support training
over "sharded" datasets, in which the original dataset has been preprocessed
into non-overlapping chunks (or "shards").

For example, instead of preprocessing all your data into a single "data-bin"
directory, you can split the data and create "data-bin1", "data-bin2", etc.
Then you can adapt your training command like so:

.. code-block:: console

   > fairseq-train data-bin1:data-bin2:data-bin3 (...)

Training will now iterate over each shard, one by one, with each shard
corresponding to an "epoch", thus reducing system memory usage.
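
Producing the shards is up to you. A simple round-robin split of the raw parallel
text, done before running :ref:`fairseq-preprocess` on each piece, can look like
the sketch below (file names are placeholders):

.. code-block:: python

   def shard_file(path, num_shards, out_prefix):
       """Round-robin split so each shard gets roughly the same number of lines."""
       outs = [open(f"{out_prefix}{i}.txt", "w") for i in range(num_shards)]
       with open(path) as f:
           for i, line in enumerate(f):
               outs[i % num_shards].write(line)
       for handle in outs:
           handle.close()

   # Shard source and target sides identically to keep sentence pairs aligned.
   shard_file("train.de", num_shards=3, out_prefix="shard_de_")
   shard_file("train.en", num_shards=3, out_prefix="shard_en_")
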
fairseq/docs/hydra_integration.md
ADDED
@@ -0,0 +1,284 @@
1 |
+
## Hydra
|
2 |
+
|
3 |
+
[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
|
4 |
+
framework that simplifies the development of research and other complex
|
5 |
+
applications. The key feature is the ability to dynamically create a
|
6 |
+
hierarchical configuration by composition and override it through config files
|
7 |
+
and the command line. The name Hydra comes from its ability to run multiple
|
8 |
+
similar jobs - much like a Hydra with multiple heads.
|
9 |
+
|
10 |
+
## Motivation
|
11 |
+
|
12 |
+
Until recently, all components in fairseq were configured through a shared
|
13 |
+
`args` namespace that was created at application startup. Components declared
|
14 |
+
their own `add_args` method to update the argparse parser, hoping that the names
|
15 |
+
would not clash with arguments from other components. While this model works for
|
16 |
+
smaller applications, as fairseq grew and became integrated into other
|
17 |
+
applications, this became problematic. In order to determine how to configure
|
18 |
+
each component, one needed to a) examine what args were added by this component,
|
19 |
+
and b) read the code to figure out what shared arguments it is using that were
|
20 |
+
added in other places. Reproducing models involved sharing commands that often
|
21 |
+
contained dozens of command line switches.
|
22 |
+
|
23 |
+
The model described above is still supported by fairseq for backward
|
24 |
+
compatibility, but will be deprecated some time in the future.
|
25 |
+
|
26 |
+
New components in fairseq should now create a dataclass that encapsulates all
|
27 |
+
parameters required to configure this component. The dataclass is registered
|
28 |
+
along with the component, and fairseq takes care of constructing and providing
|
29 |
+
this configuration object to the component's constructor. Note that sharing
|
30 |
+
parameters can optionally still work, but one has to explicitly point to the
|
31 |
+
"source of truth" (see inheritance example below). These changes make components
|
32 |
+
in fairseq more independent and re-usable by other applications: all that is
|
33 |
+
needed to create a component is to initialize its dataclass and overwrite some
|
34 |
+
of the defaults.
|
35 |
+
|
36 |
+
While configuring fairseq through command line (using either the legacy argparse
|
37 |
+
based or the new Hydra based entry points) is still fully supported, you can now
|
38 |
+
take advantage of configuring fairseq completely or piece-by-piece through
|
39 |
+
hierarchical YAML configuration files. These files can also be shipped as
|
40 |
+
examples that others can use to run an identically configured job.
|
41 |
+
|
42 |
+
Additionally, Hydra has a rich and growing [library of
|
43 |
+
plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
|
44 |
+
provide functionality such as hyperparameter sweeping (including using bayesian
|
45 |
+
optimization through the [Ax](https://github.com/facebook/Ax) library), job
|
46 |
+
launching across various platforms, and more.
|
47 |
+
|
48 |
+
## Creating or migrating components
|
49 |
+
|
50 |
+
In general, each new (or updated) component should provide a companion
|
51 |
+
[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are
|
52 |
+
typically located in the same file as the component and are passed as arguments
|
53 |
+
to the `register_*()` functions. Top-level configs that should be present in
|
54 |
+
every fairseq application are placed in the
|
55 |
+
[global](fairseq/dataclass/configs.py) config file and added to the
|
56 |
+
`FairseqConfig` object.
|
57 |
+
|
58 |
+
Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
|
59 |
+
classes are decorated with a `@dataclass` decorator, and typically inherit from
|
60 |
+
`FairseqDataclass` (which adds some functionality for backward compatibility).
|
61 |
+
Each field must have a type, and generally has metadata (such as a help string)
|
62 |
+
and a default value. Only primitive types or other config objects are allowed as
|
63 |
+
data types for each field.
|
64 |
+
|
65 |
+
#### Example:
|
66 |
+
|
67 |
+
```python
|
68 |
+
from dataclasses import dataclass, field
|
69 |
+
from fairseq.dataclass import FairseqDataclass
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class InteractiveConfig(FairseqDataclass):
|
73 |
+
buffer_size: int = field(
|
74 |
+
default=0,
|
75 |
+
metadata={
|
76 |
+
"help": "read this many sentences into a buffer before processing them"
|
77 |
+
},
|
78 |
+
)
|
79 |
+
input: str = field(
|
80 |
+
default="-",
|
81 |
+
metadata={"help": "file to read from; use - for stdin"},
|
82 |
+
)
|
83 |
+
```
|
84 |
+
|
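Because the dataclass is plain Python, it can also be instantiated and overridden directly, for example in a test or a notebook. The snippet below continues from the `InteractiveConfig` example above and uses `omegaconf`, which Hydra builds on:

```python
from omegaconf import OmegaConf

cfg = InteractiveConfig(buffer_size=16)  # override one default, keep the rest
print(cfg.buffer_size)  # 16
print(cfg.input)        # "-"

# The same object can be turned into a structured OmegaConf node / YAML.
print(OmegaConf.to_yaml(OmegaConf.structured(cfg)))
```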
85 |
+
### Inheriting values
|
86 |
+
|
87 |
+
Some components require sharing a value. For example, a learning rate scheduler
|
88 |
+
and an optimizer may both need to know the initial learning rate value. One can
|
89 |
+
declare a field that, by default, will inherit its value from another config
|
90 |
+
node in the same hierarchy:
|
91 |
+
|
92 |
+
```python
|
93 |
+
@dataclass
|
94 |
+
class FairseqAdamConfig(FairseqDataclass):
|
95 |
+
...
|
96 |
+
lr: List[float] = II("optimization.lr")
|
97 |
+
...
|
98 |
+
```
|
99 |
+
|
100 |
+
`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
|
101 |
+
the value one can use in a YAML config file or through command line to achieve
|
102 |
+
the same effect. Note that this assumes that there is an "optimization" config
|
103 |
+
object in the root config and it has a field called "lr".
|
104 |
+
|
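The interpolation itself is ordinary OmegaConf behaviour, so it can be seen in isolation with a tiny sketch (the keys below are illustrative, not fairseq's real config tree):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "optimization": {"lr": [0.25]},
    "optimizer": {"lr": "${optimization.lr}"},  # what II("optimization.lr") expands to
})
print(cfg.optimizer.lr)  # [0.25], resolved from optimization.lr
```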
105 |
+
### Tasks and Models
|
106 |
+
|
107 |
+
Creating Tasks and Models works the same as before, except that legacy
|
108 |
+
implementations now inherit from `LegacyFairseq*` base classes, while new
|
109 |
+
components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
|
110 |
+
to the `register_*()` functions.
|
111 |
+
|
112 |
+
#### Task example:
|
113 |
+
|
114 |
+
```python
|
115 |
+
@dataclass
|
116 |
+
class LanguageModelingConfig(FairseqDataclass):
|
117 |
+
data: Optional[str] = field(
|
118 |
+
default=None, metadata={"help": "path to data directory"}
|
119 |
+
)
|
120 |
+
...
|
121 |
+
|
122 |
+
@register_task("language_modeling", dataclass=LanguageModelingConfig)
|
123 |
+
class LanguageModelingTask(FairseqTask):
|
124 |
+
...
|
125 |
+
@classmethod
|
126 |
+
def setup_task(cls, cfg: LanguageModelingConfig):
|
127 |
+
...
|
128 |
+
```
|
129 |
+
|
130 |
+
#### Model example:
|
131 |
+
|
132 |
+
```python
|
133 |
+
@dataclass
|
134 |
+
class TransformerLanguageModelConfig(FairseqDataclass):
|
135 |
+
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
136 |
+
default="relu", metadata={"help": "activation function to use"}
|
137 |
+
)
|
138 |
+
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
|
139 |
+
...
|
140 |
+
|
141 |
+
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
|
142 |
+
class TransformerLanguageModel(FairseqLanguageModel):
|
143 |
+
...
|
144 |
+
@classmethod
|
145 |
+
def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
|
146 |
+
...
|
147 |
+
```
|
148 |
+
|
149 |
+
### Other components
|
150 |
+
|
151 |
+
Other components work as before, but they now take their configuration dataclass
|
152 |
+
as the only constructor argument:
|
153 |
+
|
154 |
+
```python
|
155 |
+
@dataclass
|
156 |
+
class MosesTokenizerConfig(FairseqDataclass):
|
157 |
+
source_lang: str = field(default="en", metadata={"help": "source language"})
|
158 |
+
...
|
159 |
+
|
160 |
+
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
|
161 |
+
class MosesTokenizer(object):
|
162 |
+
def __init__(self, cfg: MosesTokenizerConfig):
|
163 |
+
...
|
164 |
+
```
|
165 |
+
|
166 |
+
Note that if you are adding a new registry for a new set of components, you need
|
167 |
+
to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
|
168 |
+
|
169 |
+
```python
|
170 |
+
@dataclass
|
171 |
+
class FairseqConfig(object):
|
172 |
+
...
|
173 |
+
my_new_registry: Any = None
|
174 |
+
```
|
175 |
+
|
176 |
+
## Training with `fairseq-hydra-train`
|
177 |
+
|
178 |
+
To fully take advantage of configuration flexibility offered by Hydra, you may
|
179 |
+
want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
|
180 |
+
tools such as `fairseq-train` will remain supported for the foreseeable future
|
181 |
+
but will be deprecated eventually.
|
182 |
+
|
183 |
+
On startup, Hydra will create a configuration object that contains a hierarchy
|
184 |
+
of all the necessary dataclasses populated with their default values in the
|
185 |
+
code. The default values are overwritten by values found in YAML files in
|
186 |
+
`fairseq/config` directory (which currently sets minimal defaults) and then
|
187 |
+
further overwritten by values provided through command line arguments.
|
188 |
+
|
189 |
+
Some of the most common use cases are shown below:
|
190 |
+
|
191 |
+
### 1. Override default values through command line:
|
192 |
+
|
193 |
+
```shell script
|
194 |
+
$ fairseq-hydra-train \
|
195 |
+
distributed_training.distributed_world_size=1 \
|
196 |
+
dataset.batch_size=2 \
|
197 |
+
task.data=data-bin \
|
198 |
+
model=transformer_lm/transformer_lm_gpt \
|
199 |
+
task=language_modeling \
|
200 |
+
optimization.max_update=5000
|
201 |
+
```
|
202 |
+
|
203 |
+
Note that along with explicitly providing values for parameters such as
|
204 |
+
`dataset.batch_size`, this also tells Hydra to overlay configuration found in
|
205 |
+
`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
|
206 |
+
values in the dataclass. If you want to train a model without specifying a
|
207 |
+
particular architecture you can simply specify `model=transformer_lm`. This only
|
208 |
+
works for migrated tasks and models.
|
209 |
+
|
210 |
+
### 2. Replace bundled configs with an external config:
|
211 |
+
|
212 |
+
```shell script
|
213 |
+
$ fairseq-hydra-train \
|
214 |
+
--config-dir /path/to/external/configs \
|
215 |
+
--config-name wiki103
|
216 |
+
```
|
217 |
+
|
218 |
+
where `/path/to/external/configs/wiki103.yaml` contains:
|
219 |
+
|
220 |
+
```yaml
|
221 |
+
# @package _group_
|
222 |
+
|
223 |
+
model:
|
224 |
+
_name: transformer_lm
|
225 |
+
distributed_training:
|
226 |
+
distributed_world_size: 1
|
227 |
+
dataset:
|
228 |
+
batch_size: 2
|
229 |
+
task:
|
230 |
+
_name: language_modeling
|
231 |
+
data: /path/to/data
|
232 |
+
add_bos_token: false
|
233 |
+
max_target_positions: 1024
|
234 |
+
optimization:
|
235 |
+
max_update: 50000
|
236 |
+
lr: [ 0.25 ]
|
237 |
+
criterion: cross_entropy
|
238 |
+
optimizer: adam
|
239 |
+
lr_scheduler:
|
240 |
+
_name: cosine
|
241 |
+
```
|
242 |
+
|
243 |
+
Note that here bundled configs from `fairseq/config` directory are not used,
|
244 |
+
however the defaults from each dataclass will still be used (unless overwritten
|
245 |
+
by your external config).
|
246 |
+
|
247 |
+
Additionally you can choose to break up your configs by creating a directory
|
248 |
+
structure in the same location as your main config file, with the names of the
|
249 |
+
top-level fields (such as "model", "dataset", etc), and placing config files
|
250 |
+
with meaningful names that would populate that specific section of your
|
251 |
+
top-level config file (for example, you might have
|
252 |
+
`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
|
253 |
+
can then specify the correct configuration via command line, defaults in the
|
254 |
+
main config, or even launch all of them as a sweep (see Hydra documentation on
|
255 |
+
how to do this).
|
256 |
+
|
257 |
+
### 3. Add an external config directory to Hydra search path:
|
258 |
+
|
259 |
+
This allows combining default configuration (including using any bundled config
|
260 |
+
files), while specifying your own config files for some parts of the
|
261 |
+
configuration.
|
262 |
+
|
263 |
+
```shell script
|
264 |
+
$ fairseq-hydra-train \
|
265 |
+
distributed_training.distributed_world_size=1 \
|
266 |
+
dataset.batch_size=2 \
|
267 |
+
task.data=/path/to/data/ \
|
268 |
+
model=transformer_lm/2_layers \
|
269 |
+
task=language_modeling \
|
270 |
+
optimization.max_update=5000 \
|
271 |
+
--config-dir /path/to/external/configs
|
272 |
+
```
|
273 |
+
|
274 |
+
where `/path/to/external/configs` has the following structure:
|
275 |
+
```
|
276 |
+
.
|
277 |
+
+-- model
|
278 |
+
| +-- transformer_lm
|
279 |
+
| | +-- 2_layers.yaml
|
280 |
+
```
|
281 |
+
|
282 |
+
and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
|
283 |
+
`decoder_layers` set to 2. You can add other configs to configure other
|
284 |
+
components as well.
|
fairseq/docs/index.rst
ADDED
@@ -0,0 +1,49 @@
1 |
+
.. fairseq documentation master file, created by
|
2 |
+
sphinx-quickstart on Fri Aug 17 21:45:30 2018.
|
3 |
+
You can adapt this file completely to your liking, but it should at least
|
4 |
+
contain the root `toctree` directive.
|
5 |
+
|
6 |
+
:github_url: https://github.com/pytorch/fairseq
|
7 |
+
|
8 |
+
|
9 |
+
fairseq documentation
|
10 |
+
=====================
|
11 |
+
|
12 |
+
Fairseq is a sequence modeling toolkit written in `PyTorch
|
13 |
+
<http://pytorch.org/>`_ that allows researchers and developers to
|
14 |
+
train custom models for translation, summarization, language modeling and other
|
15 |
+
text generation tasks.
|
16 |
+
|
17 |
+
.. toctree::
|
18 |
+
:maxdepth: 1
|
19 |
+
:caption: Getting Started
|
20 |
+
|
21 |
+
getting_started
|
22 |
+
command_line_tools
|
23 |
+
|
24 |
+
.. toctree::
|
25 |
+
:maxdepth: 1
|
26 |
+
:caption: Extending Fairseq
|
27 |
+
|
28 |
+
overview
|
29 |
+
tutorial_simple_lstm
|
30 |
+
tutorial_classifying_names
|
31 |
+
|
32 |
+
.. toctree::
|
33 |
+
:maxdepth: 2
|
34 |
+
:caption: Library Reference
|
35 |
+
|
36 |
+
tasks
|
37 |
+
models
|
38 |
+
criterions
|
39 |
+
optim
|
40 |
+
lr_scheduler
|
41 |
+
data
|
42 |
+
modules
|
43 |
+
|
44 |
+
|
45 |
+
Indices and tables
|
46 |
+
==================
|
47 |
+
|
48 |
+
* :ref:`genindex`
|
49 |
+
* :ref:`search`
|
fairseq/docs/lr_scheduler.rst
ADDED
@@ -0,0 +1,34 @@
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. _Learning Rate Schedulers:
|
5 |
+
|
6 |
+
Learning Rate Schedulers
|
7 |
+
========================
|
8 |
+
|
9 |
+
Learning Rate Schedulers update the learning rate over the course of training.
|
10 |
+
Learning rates can be updated after each update via :func:`step_update` or at
|
11 |
+
epoch boundaries via :func:`step`.
|
12 |
+
|
13 |
+
.. automodule:: fairseq.optim.lr_scheduler
|
14 |
+
:members:
|
15 |
+
|
16 |
+
.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
|
17 |
+
:members:
|
18 |
+
:undoc-members:
|
19 |
+
|
20 |
+
.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
|
21 |
+
:members:
|
22 |
+
:undoc-members:
|
23 |
+
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
|
24 |
+
:members:
|
25 |
+
:undoc-members:
|
26 |
+
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
|
27 |
+
:members:
|
28 |
+
:undoc-members:
|
29 |
+
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
|
30 |
+
:members:
|
31 |
+
:undoc-members:
|
32 |
+
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
|
33 |
+
:members:
|
34 |
+
:undoc-members:
|
fairseq/docs/make.bat
ADDED
@@ -0,0 +1,36 @@
1 |
+
@ECHO OFF
|
2 |
+
|
3 |
+
pushd %~dp0
|
4 |
+
|
5 |
+
REM Command file for Sphinx documentation
|
6 |
+
|
7 |
+
if "%SPHINXBUILD%" == "" (
|
8 |
+
set SPHINXBUILD=python -msphinx
|
9 |
+
)
|
10 |
+
set SOURCEDIR=.
|
11 |
+
set BUILDDIR=_build
|
12 |
+
set SPHINXPROJ=fairseq
|
13 |
+
|
14 |
+
if "%1" == "" goto help
|
15 |
+
|
16 |
+
%SPHINXBUILD% >NUL 2>NUL
|
17 |
+
if errorlevel 9009 (
|
18 |
+
echo.
|
19 |
+
echo.The Sphinx module was not found. Make sure you have Sphinx installed,
|
20 |
+
echo.then set the SPHINXBUILD environment variable to point to the full
|
21 |
+
echo.path of the 'sphinx-build' executable. Alternatively you may add the
|
22 |
+
echo.Sphinx directory to PATH.
|
23 |
+
echo.
|
24 |
+
echo.If you don't have Sphinx installed, grab it from
|
25 |
+
echo.http://sphinx-doc.org/
|
26 |
+
exit /b 1
|
27 |
+
)
|
28 |
+
|
29 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
30 |
+
goto end
|
31 |
+
|
32 |
+
:help
|
33 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
34 |
+
|
35 |
+
:end
|
36 |
+
popd
|
fairseq/docs/modules.rst
ADDED
@@ -0,0 +1,9 @@
1 |
+
Modules
|
2 |
+
=======
|
3 |
+
|
4 |
+
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
|
5 |
+
be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
|
6 |
+
|
7 |
+
.. automodule:: fairseq.modules
|
8 |
+
:members:
|
9 |
+
:undoc-members:
|
fairseq/docs/optim.rst
ADDED
@@ -0,0 +1,38 @@
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. _optimizers:
|
5 |
+
|
6 |
+
Optimizers
|
7 |
+
==========
|
8 |
+
|
9 |
+
Optimizers update the Model parameters based on the gradients.
|
10 |
+
|
11 |
+
.. automodule:: fairseq.optim
|
12 |
+
:members:
|
13 |
+
|
14 |
+
.. autoclass:: fairseq.optim.FairseqOptimizer
|
15 |
+
:members:
|
16 |
+
:undoc-members:
|
17 |
+
|
18 |
+
.. autoclass:: fairseq.optim.adadelta.Adadelta
|
19 |
+
:members:
|
20 |
+
:undoc-members:
|
21 |
+
.. autoclass:: fairseq.optim.adagrad.Adagrad
|
22 |
+
:members:
|
23 |
+
:undoc-members:
|
24 |
+
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
|
25 |
+
:members:
|
26 |
+
:undoc-members:
|
27 |
+
.. autoclass:: fairseq.optim.adam.FairseqAdam
|
28 |
+
:members:
|
29 |
+
:undoc-members:
|
30 |
+
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
|
31 |
+
:members:
|
32 |
+
:undoc-members:
|
33 |
+
.. autoclass:: fairseq.optim.nag.FairseqNAG
|
34 |
+
:members:
|
35 |
+
:undoc-members:
|
36 |
+
.. autoclass:: fairseq.optim.sgd.SGD
|
37 |
+
:members:
|
38 |
+
:undoc-members:
|
fairseq/docs/overview.rst
ADDED
@@ -0,0 +1,74 @@
1 |
+
Overview
|
2 |
+
========
|
3 |
+
|
4 |
+
Fairseq can be extended through user-supplied `plug-ins
|
5 |
+
<https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
|
6 |
+
plug-ins:
|
7 |
+
|
8 |
+
- :ref:`Models` define the neural network architecture and encapsulate all of the
|
9 |
+
learnable parameters.
|
10 |
+
- :ref:`Criterions` compute the loss function given the model outputs and targets.
|
11 |
+
- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
|
12 |
+
Datasets, initializing the Model/Criterion and calculating the loss.
|
13 |
+
- :ref:`Optimizers` update the Model parameters based on the gradients.
|
14 |
+
- :ref:`Learning Rate Schedulers` update the learning rate over the course of
|
15 |
+
training.
|
16 |
+
|
17 |
+
**Training Flow**
|
18 |
+
|
19 |
+
Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
|
20 |
+
fairseq implements the following high-level training flow::
|
21 |
+
|
22 |
+
for epoch in range(num_epochs):
|
23 |
+
itr = task.get_batch_iterator(task.dataset('train'))
|
24 |
+
for num_updates, batch in enumerate(itr):
|
25 |
+
task.train_step(batch, model, criterion, optimizer)
|
26 |
+
average_and_clip_gradients()
|
27 |
+
optimizer.step()
|
28 |
+
lr_scheduler.step_update(num_updates)
|
29 |
+
lr_scheduler.step(epoch)
|
30 |
+
|
31 |
+
where the default implementation for ``task.train_step`` is roughly::
|
32 |
+
|
33 |
+
def train_step(self, batch, model, criterion, optimizer, **unused):
|
34 |
+
loss = criterion(model, batch)
|
35 |
+
optimizer.backward(loss)
|
36 |
+
return loss
|
37 |
+
|
38 |
+
**Registering new plug-ins**
|
39 |
+
|
40 |
+
New plug-ins are *registered* through a set of ``@register`` function
|
41 |
+
decorators, for example::
|
42 |
+
|
43 |
+
@register_model('my_lstm')
|
44 |
+
class MyLSTM(FairseqEncoderDecoderModel):
|
45 |
+
(...)
|
46 |
+
|
47 |
+
Once registered, new plug-ins can be used with the existing :ref:`Command-line
|
48 |
+
Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
|
49 |
+
new plug-ins.
|
50 |
+
|
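For a flavour of what such a plug-in can look like end to end, here is a rough
sketch of a custom criterion; the method signatures are simplified, so see
:ref:`Criterions` and the tutorials for the exact interface fairseq expects::

    import torch.nn.functional as F
    from fairseq.criterions import FairseqCriterion, register_criterion

    @register_criterion('my_cross_entropy')
    class MyCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)), target.view(-1),
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {'loss': loss.data, 'sample_size': sample_size}
            return loss, sample_size, logging_output
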
51 |
+
**Loading plug-ins from another directory**
|
52 |
+
|
53 |
+
New plug-ins can be defined in a custom module stored in the user system. In
|
54 |
+
order to import the module, and make the plugin available to *fairseq*, the
|
55 |
+
command line supports the ``--user-dir`` flag that can be used to specify a
|
56 |
+
custom location for additional modules to load into *fairseq*.
|
57 |
+
|
58 |
+
For example, assuming this directory tree::
|
59 |
+
|
60 |
+
/home/user/my-module/
|
61 |
+
└── __init__.py
|
62 |
+
|
63 |
+
with ``__init__.py``::
|
64 |
+
|
65 |
+
from fairseq.models import register_model_architecture
|
66 |
+
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
|
67 |
+
|
68 |
+
@register_model_architecture('transformer', 'my_transformer')
|
69 |
+
def transformer_mmt_big(args):
|
70 |
+
transformer_vaswani_wmt_en_de_big(args)
|
71 |
+
|
72 |
+
it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::
|
73 |
+
|
74 |
+
fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
|
fairseq/docs/tasks.rst
ADDED
@@ -0,0 +1,61 @@
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. module:: fairseq.tasks
|
5 |
+
|
6 |
+
.. _Tasks:
|
7 |
+
|
8 |
+
Tasks
|
9 |
+
=====
|
10 |
+
|
11 |
+
Tasks store dictionaries and provide helpers for loading/iterating over
|
12 |
+
Datasets, initializing the Model/Criterion and calculating the loss.
|
13 |
+
|
14 |
+
Tasks can be selected via the ``--task`` command-line argument. Once selected, a
|
15 |
+
task may expose additional command-line arguments for further configuration.
|
16 |
+
|
17 |
+
Example usage::
|
18 |
+
|
19 |
+
# setup the task (e.g., load dictionaries)
|
20 |
+
task = fairseq.tasks.setup_task(args)
|
21 |
+
|
22 |
+
# build model and criterion
|
23 |
+
model = task.build_model(args)
|
24 |
+
criterion = task.build_criterion(args)
|
25 |
+
|
26 |
+
# load datasets
|
27 |
+
task.load_dataset('train')
|
28 |
+
task.load_dataset('valid')
|
29 |
+
|
30 |
+
# iterate over mini-batches of data
|
31 |
+
batch_itr = task.get_batch_iterator(
|
32 |
+
task.dataset('train'), max_tokens=4096,
|
33 |
+
)
|
34 |
+
for batch in batch_itr:
|
35 |
+
# compute the loss
|
36 |
+
loss, sample_size, logging_output = task.get_loss(
|
37 |
+
model, criterion, batch,
|
38 |
+
)
|
39 |
+
loss.backward()
|
40 |
+
|
41 |
+
|
42 |
+
Translation
|
43 |
+
-----------
|
44 |
+
|
45 |
+
.. autoclass:: fairseq.tasks.translation.TranslationTask
|
46 |
+
|
47 |
+
.. _language modeling:
|
48 |
+
|
49 |
+
Language Modeling
|
50 |
+
-----------------
|
51 |
+
|
52 |
+
.. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
|
53 |
+
|
54 |
+
|
55 |
+
Adding new tasks
|
56 |
+
----------------
|
57 |
+
|
58 |
+
.. autofunction:: fairseq.tasks.register_task
|
59 |
+
.. autoclass:: fairseq.tasks.FairseqTask
|
60 |
+
:members:
|
61 |
+
:undoc-members:
|
fairseq/docs/tutorial_simple_lstm.rst
ADDED
@@ -0,0 +1,518 @@
1 |
+
Tutorial: Simple LSTM
|
2 |
+
=====================
|
3 |
+
|
4 |
+
In this tutorial we will extend fairseq by adding a new
|
5 |
+
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
|
6 |
+
sentence with an LSTM and then passes the final hidden state to a second LSTM
|
7 |
+
that decodes the target sentence (without attention).
|
8 |
+
|
9 |
+
This tutorial covers:
|
10 |
+
|
11 |
+
1. **Writing an Encoder and Decoder** to encode/decode the source/target
|
12 |
+
sentence, respectively.
|
13 |
+
2. **Registering a new Model** so that it can be used with the existing
|
14 |
+
:ref:`Command-line tools`.
|
15 |
+
3. **Training the Model** using the existing command-line tools.
|
16 |
+
4. **Making generation faster** by modifying the Decoder to use
|
17 |
+
:ref:`Incremental decoding`.
|
18 |
+
|
19 |
+
|
20 |
+
1. Building an Encoder and Decoder
|
21 |
+
----------------------------------
|
22 |
+
|
23 |
+
In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
|
24 |
+
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
|
25 |
+
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
|
26 |
+
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
|
27 |
+
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
|
28 |
+
Modules.
|
29 |
+
|
30 |
+
|
31 |
+
Encoder
|
32 |
+
~~~~~~~
|
33 |
+
|
34 |
+
Our Encoder will embed the tokens in the source sentence, feed them to a
|
35 |
+
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
|
36 |
+
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::
|
37 |
+
|
38 |
+
import torch.nn as nn
|
39 |
+
from fairseq import utils
|
40 |
+
from fairseq.models import FairseqEncoder
|
41 |
+
|
42 |
+
class SimpleLSTMEncoder(FairseqEncoder):
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
|
46 |
+
):
|
47 |
+
super().__init__(dictionary)
|
48 |
+
self.args = args
|
49 |
+
|
50 |
+
# Our encoder will embed the inputs before feeding them to the LSTM.
|
51 |
+
self.embed_tokens = nn.Embedding(
|
52 |
+
num_embeddings=len(dictionary),
|
53 |
+
embedding_dim=embed_dim,
|
54 |
+
padding_idx=dictionary.pad(),
|
55 |
+
)
|
56 |
+
self.dropout = nn.Dropout(p=dropout)
|
57 |
+
|
58 |
+
# We'll use a single-layer, unidirectional LSTM for simplicity.
|
59 |
+
self.lstm = nn.LSTM(
|
60 |
+
input_size=embed_dim,
|
61 |
+
hidden_size=hidden_dim,
|
62 |
+
num_layers=1,
|
63 |
+
bidirectional=False,
|
64 |
+
batch_first=True,
|
65 |
+
)
|
66 |
+
|
67 |
+
def forward(self, src_tokens, src_lengths):
|
68 |
+
# The inputs to the ``forward()`` function are determined by the
|
69 |
+
# Task, and in particular the ``'net_input'`` key in each
|
70 |
+
# mini-batch. We discuss Tasks in the next tutorial, but for now just
|
71 |
+
# know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
|
72 |
+
# has shape `(batch)`.
|
73 |
+
|
74 |
+
# Note that the source is typically padded on the left. This can be
|
75 |
+
# configured by adding the `--left-pad-source "False"` command-line
|
76 |
+
# argument, but here we'll make the Encoder handle either kind of
|
77 |
+
# padding by converting everything to be right-padded.
|
78 |
+
if self.args.left_pad_source:
|
79 |
+
# Convert left-padding to right-padding.
|
80 |
+
src_tokens = utils.convert_padding_direction(
|
81 |
+
src_tokens,
|
82 |
+
padding_idx=self.dictionary.pad(),
|
83 |
+
left_to_right=True
|
84 |
+
)
|
85 |
+
|
86 |
+
# Embed the source.
|
87 |
+
x = self.embed_tokens(src_tokens)
|
88 |
+
|
89 |
+
# Apply dropout.
|
90 |
+
x = self.dropout(x)
|
91 |
+
|
92 |
+
# Pack the sequence into a PackedSequence object to feed to the LSTM.
|
93 |
+
x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
|
94 |
+
|
95 |
+
# Get the output from the LSTM.
|
96 |
+
_outputs, (final_hidden, _final_cell) = self.lstm(x)
|
97 |
+
|
98 |
+
# Return the Encoder's output. This can be any object and will be
|
99 |
+
# passed directly to the Decoder.
|
100 |
+
return {
|
101 |
+
# this will have shape `(bsz, hidden_dim)`
|
102 |
+
'final_hidden': final_hidden.squeeze(0),
|
103 |
+
}
|
104 |
+
|
105 |
+
# Encoders are required to implement this method so that we can rearrange
|
106 |
+
# the order of the batch elements during inference (e.g., beam search).
|
107 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
108 |
+
"""
|
109 |
+
Reorder encoder output according to `new_order`.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
encoder_out: output from the ``forward()`` method
|
113 |
+
new_order (LongTensor): desired order
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
`encoder_out` rearranged according to `new_order`
|
117 |
+
"""
|
118 |
+
final_hidden = encoder_out['final_hidden']
|
119 |
+
return {
|
120 |
+
'final_hidden': final_hidden.index_select(0, new_order),
|
121 |
+
}
|
122 |
+
|
123 |
+
|
124 |
+
Decoder
|
125 |
+
~~~~~~~
|
126 |
+
|
127 |
+
Our Decoder will predict the next word, conditioned on the Encoder's final
|
128 |
+
hidden state and an embedded representation of the previous target word -- which
|
129 |
+
is sometimes called *teacher forcing*. More specifically, we'll use a
|
130 |
+
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
|
131 |
+
to the size of the output vocabulary to predict each target word.
|
132 |
+
|
133 |
+
::
|
134 |
+
|
135 |
+
import torch
|
136 |
+
from fairseq.models import FairseqDecoder
|
137 |
+
|
138 |
+
class SimpleLSTMDecoder(FairseqDecoder):
|
139 |
+
|
140 |
+
def __init__(
|
141 |
+
self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
|
142 |
+
dropout=0.1,
|
143 |
+
):
|
144 |
+
super().__init__(dictionary)
|
145 |
+
|
146 |
+
# Our decoder will embed the inputs before feeding them to the LSTM.
|
147 |
+
self.embed_tokens = nn.Embedding(
|
148 |
+
num_embeddings=len(dictionary),
|
149 |
+
embedding_dim=embed_dim,
|
150 |
+
padding_idx=dictionary.pad(),
|
151 |
+
)
|
152 |
+
self.dropout = nn.Dropout(p=dropout)
|
153 |
+
|
154 |
+
# We'll use a single-layer, unidirectional LSTM for simplicity.
|
155 |
+
self.lstm = nn.LSTM(
|
156 |
+
# For the first layer we'll concatenate the Encoder's final hidden
|
157 |
+
# state with the embedded target tokens.
|
158 |
+
input_size=encoder_hidden_dim + embed_dim,
|
159 |
+
hidden_size=hidden_dim,
|
160 |
+
num_layers=1,
|
161 |
+
bidirectional=False,
|
162 |
+
)
|
163 |
+
|
164 |
+
# Define the output projection.
|
165 |
+
self.output_projection = nn.Linear(hidden_dim, len(dictionary))
|
166 |
+
|
167 |
+
# During training Decoders are expected to take the entire target sequence
|
168 |
+
# (shifted right by one position) and produce logits over the vocabulary.
|
169 |
+
# The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
|
170 |
+
# ``dictionary.eos()``, followed by the target sequence.
|
171 |
+
def forward(self, prev_output_tokens, encoder_out):
|
172 |
+
"""
|
173 |
+
Args:
|
174 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
175 |
+
`(batch, tgt_len)`, for teacher forcing
|
176 |
+
encoder_out (Tensor, optional): output from the encoder, used for
|
177 |
+
encoder-side attention
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
tuple:
|
181 |
+
- the last decoder layer's output of shape
|
182 |
+
`(batch, tgt_len, vocab)`
|
183 |
+
- the last decoder layer's attention weights of shape
|
184 |
+
`(batch, tgt_len, src_len)`
|
185 |
+
"""
|
186 |
+
bsz, tgt_len = prev_output_tokens.size()
|
187 |
+
|
188 |
+
# Extract the final hidden state from the Encoder.
|
189 |
+
final_encoder_hidden = encoder_out['final_hidden']
|
190 |
+
|
191 |
+
# Embed the target sequence, which has been shifted right by one
|
192 |
+
# position and now starts with the end-of-sentence symbol.
|
193 |
+
x = self.embed_tokens(prev_output_tokens)
|
194 |
+
|
195 |
+
# Apply dropout.
|
196 |
+
x = self.dropout(x)
|
197 |
+
|
198 |
+
# Concatenate the Encoder's final hidden state to *every* embedded
|
199 |
+
# target token.
|
200 |
+
x = torch.cat(
|
201 |
+
[x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
|
202 |
+
dim=2,
|
203 |
+
)
|
204 |
+
|
205 |
+
# Using PackedSequence objects in the Decoder is harder than in the
|
206 |
+
# Encoder, since the targets are not sorted in descending length order,
|
207 |
+
# which is a requirement of ``pack_padded_sequence()``. Instead we'll
|
208 |
+
# feed nn.LSTM directly.
|
209 |
+
initial_state = (
|
210 |
+
final_encoder_hidden.unsqueeze(0), # hidden
|
211 |
+
torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
|
212 |
+
)
|
213 |
+
output, _ = self.lstm(
|
214 |
+
x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)`
|
215 |
+
initial_state,
|
216 |
+
)
|
217 |
+
x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)`
|
218 |
+
|
219 |
+
# Project the outputs to the size of the vocabulary.
|
220 |
+
x = self.output_projection(x)
|
221 |
+
|
222 |
+
# Return the logits and ``None`` for the attention weights
|
223 |
+
return x, None
|
224 |
+
|
225 |
+
|
226 |
+
2. Registering the Model
|
227 |
+
------------------------
|
228 |
+
|
229 |
+
Now that we've defined our Encoder and Decoder we must *register* our model with
|
230 |
+
fairseq using the :func:`~fairseq.models.register_model` function decorator.
|
231 |
+
Once the model is registered we'll be able to use it with the existing
|
232 |
+
:ref:`Command-line Tools`.
|
233 |
+
|
234 |
+
All registered models must implement the
|
235 |
+
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
|
236 |
+
models (i.e., any model with a single Encoder and Decoder), we can instead
|
237 |
+
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
|
238 |
+
|
239 |
+
Create a small wrapper class in the same file and register it in fairseq with
|
240 |
+
the name ``'simple_lstm'``::
|
241 |
+
|
242 |
+
from fairseq.models import FairseqEncoderDecoderModel, register_model
|
243 |
+
|
244 |
+
# Note: the register_model "decorator" should immediately precede the
|
245 |
+
# definition of the Model class.
|
246 |
+
|
247 |
+
@register_model('simple_lstm')
|
248 |
+
class SimpleLSTMModel(FairseqEncoderDecoderModel):
|
249 |
+
|
250 |
+
@staticmethod
|
251 |
+
def add_args(parser):
|
252 |
+
# Models can override this method to add new command-line arguments.
|
253 |
+
# Here we'll add some new command-line arguments to configure dropout
|
254 |
+
# and the dimensionality of the embeddings and hidden states.
|
255 |
+
parser.add_argument(
|
256 |
+
'--encoder-embed-dim', type=int, metavar='N',
|
257 |
+
help='dimensionality of the encoder embeddings',
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
'--encoder-hidden-dim', type=int, metavar='N',
|
261 |
+
help='dimensionality of the encoder hidden state',
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
'--encoder-dropout', type=float, default=0.1,
|
265 |
+
help='encoder dropout probability',
|
266 |
+
)
|
267 |
+
parser.add_argument(
|
268 |
+
'--decoder-embed-dim', type=int, metavar='N',
|
269 |
+
help='dimensionality of the decoder embeddings',
|
270 |
+
)
|
271 |
+
parser.add_argument(
|
272 |
+
'--decoder-hidden-dim', type=int, metavar='N',
|
273 |
+
help='dimensionality of the decoder hidden state',
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
'--decoder-dropout', type=float, default=0.1,
|
277 |
+
help='decoder dropout probability',
|
278 |
+
)
|
279 |
+
|
280 |
+
@classmethod
|
281 |
+
def build_model(cls, args, task):
|
282 |
+
# Fairseq initializes models by calling the ``build_model()``
|
283 |
+
# function. This provides more flexibility, since the returned model
|
284 |
+
# instance can be of a different type than the one that was called.
|
285 |
+
# In this case we'll just return a SimpleLSTMModel instance.
|
286 |
+
|
287 |
+
# Initialize our Encoder and Decoder.
|
288 |
+
encoder = SimpleLSTMEncoder(
|
289 |
+
args=args,
|
290 |
+
dictionary=task.source_dictionary,
|
291 |
+
embed_dim=args.encoder_embed_dim,
|
292 |
+
hidden_dim=args.encoder_hidden_dim,
|
293 |
+
dropout=args.encoder_dropout,
|
294 |
+
)
|
295 |
+
decoder = SimpleLSTMDecoder(
|
296 |
+
dictionary=task.target_dictionary,
|
297 |
+
encoder_hidden_dim=args.encoder_hidden_dim,
|
298 |
+
embed_dim=args.decoder_embed_dim,
|
299 |
+
hidden_dim=args.decoder_hidden_dim,
|
300 |
+
dropout=args.decoder_dropout,
|
301 |
+
)
|
302 |
+
model = SimpleLSTMModel(encoder, decoder)
|
303 |
+
|
304 |
+
# Print the model architecture.
|
305 |
+
print(model)
|
306 |
+
|
307 |
+
return model
|
308 |
+
|
309 |
+
# We could override the ``forward()`` if we wanted more control over how
|
310 |
+
# the encoder and decoder interact, but it's not necessary for this
|
311 |
+
# tutorial since we can inherit the default implementation provided by
|
312 |
+
# the FairseqEncoderDecoderModel base class, which looks like:
|
313 |
+
#
|
314 |
+
# def forward(self, src_tokens, src_lengths, prev_output_tokens):
|
315 |
+
# encoder_out = self.encoder(src_tokens, src_lengths)
|
316 |
+
# decoder_out = self.decoder(prev_output_tokens, encoder_out)
|
317 |
+
# return decoder_out
|
318 |
+
|
319 |
+
Finally let's define a *named architecture* with the configuration for our
|
320 |
+
model. This is done with the :func:`~fairseq.models.register_model_architecture`
|
321 |
+
function decorator. Thereafter this named architecture can be used with the
|
322 |
+
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::
|
323 |
+
|
324 |
+
from fairseq.models import register_model_architecture
|
325 |
+
|
326 |
+
# The first argument to ``register_model_architecture()`` should be the name
|
327 |
+
# of the model we registered above (i.e., 'simple_lstm'). The function we
|
328 |
+
# register here should take a single argument *args* and modify it in-place
|
329 |
+
# to match the desired architecture.
|
330 |
+
|
331 |
+
@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
|
332 |
+
def tutorial_simple_lstm(args):
|
333 |
+
# We use ``getattr()`` to prioritize arguments that are explicitly given
|
334 |
+
# on the command-line, so that the defaults defined below are only used
|
335 |
+
# when no other value has been specified.
|
336 |
+
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
|
337 |
+
args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
|
338 |
+
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
|
339 |
+
args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
|
340 |
+
|
341 |
+
|
342 |
+
3. Training the Model
|
343 |
+
---------------------
|
344 |
+
|
345 |
+
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
|
346 |
+
command-line tool for this, making sure to specify our new Model architecture
|
347 |
+
(``--arch tutorial_simple_lstm``).
|
348 |
+
|
349 |
+
.. note::
|
350 |
+
|
351 |
+
Make sure you've already preprocessed the data from the IWSLT example in the
|
352 |
+
:file:`examples/translation/` directory.
|
353 |
+
|
354 |
+
.. code-block:: console
|
355 |
+
|
356 |
+
> fairseq-train data-bin/iwslt14.tokenized.de-en \
|
357 |
+
--arch tutorial_simple_lstm \
|
358 |
+
--encoder-dropout 0.2 --decoder-dropout 0.2 \
|
359 |
+
--optimizer adam --lr 0.005 --lr-shrink 0.5 \
|
360 |
+
--max-tokens 12000
|
361 |
+
(...)
|
362 |
+
| epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
|
363 |
+
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
|
364 |
+
|
365 |
+
The model files should appear in the :file:`checkpoints/` directory. While this
|
366 |
+
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
|
367 |
+
generate translations and compute our BLEU score over the test set:
|
368 |
+
|
369 |
+
.. code-block:: console
|
370 |
+
|
371 |
+
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
|
372 |
+
--path checkpoints/checkpoint_best.pt \
|
373 |
+
--beam 5 \
|
374 |
+
--remove-bpe
|
375 |
+
(...)
|
376 |
+
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
|
377 |
+
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
|
378 |
+
|
379 |
+
|
380 |
+
4. Making generation faster
|
381 |
+
---------------------------
|
382 |
+
|
383 |
+
While autoregressive generation from sequence-to-sequence models is inherently
|
384 |
+
slow, our implementation above is especially slow because it recomputes the
|
385 |
+
entire sequence of Decoder hidden states for every output token (i.e., it is
|
386 |
+
``O(n^2)``). We can make this significantly faster by instead caching the
|
387 |
+
previous hidden states.
|
388 |
+
|
389 |
+
In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
|
390 |
+
special mode at inference time where the Model only receives a single timestep
|
391 |
+
of input corresponding to the immediately previous output token (for teacher
|
392 |
+
forcing) and must produce the next output incrementally. Thus the model must
|
393 |
+
cache any long-term state that is needed about the sequence, e.g., hidden
|
394 |
+
states, convolutional states, etc.
|
395 |
+
|
396 |
+
To implement incremental decoding we will modify our model to implement the
|
397 |
+
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
|
398 |
+
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
|
399 |
+
decoder interface allows ``forward()`` methods to take an extra keyword argument
|
400 |
+
(*incremental_state*) that can be used to cache state across time-steps.
|
401 |
+
|
402 |
+
Let's replace our ``SimpleLSTMDecoder`` with an incremental one::
|
403 |
+
|
404 |
+
import torch
|
405 |
+
from fairseq.models import FairseqIncrementalDecoder
|
406 |
+
|
407 |
+
class SimpleLSTMDecoder(FairseqIncrementalDecoder):
|
408 |
+
|
409 |
+
def __init__(
|
410 |
+
self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
|
411 |
+
dropout=0.1,
|
412 |
+
):
|
413 |
+
# This remains the same as before.
|
414 |
+
super().__init__(dictionary)
|
415 |
+
self.embed_tokens = nn.Embedding(
|
416 |
+
num_embeddings=len(dictionary),
|
417 |
+
embedding_dim=embed_dim,
|
418 |
+
padding_idx=dictionary.pad(),
|
419 |
+
)
|
420 |
+
self.dropout = nn.Dropout(p=dropout)
|
421 |
+
self.lstm = nn.LSTM(
|
422 |
+
input_size=encoder_hidden_dim + embed_dim,
|
423 |
+
hidden_size=hidden_dim,
|
424 |
+
num_layers=1,
|
425 |
+
bidirectional=False,
|
426 |
+
)
|
427 |
+
self.output_projection = nn.Linear(hidden_dim, len(dictionary))
|
428 |
+
|
429 |
+
# We now take an additional kwarg (*incremental_state*) for caching the
|
430 |
+
# previous hidden and cell states.
|
431 |
+
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
|
432 |
+
if incremental_state is not None:
|
433 |
+
# If the *incremental_state* argument is not ``None`` then we are
|
434 |
+
# in incremental inference mode. While *prev_output_tokens* will
|
435 |
+
# still contain the entire decoded prefix, we will only use the
|
436 |
+
# last step and assume that the rest of the state is cached.
|
437 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
438 |
+
|
439 |
+
# This remains the same as before.
|
440 |
+
bsz, tgt_len = prev_output_tokens.size()
|
441 |
+
final_encoder_hidden = encoder_out['final_hidden']
|
442 |
+
x = self.embed_tokens(prev_output_tokens)
|
443 |
+
x = self.dropout(x)
|
444 |
+
x = torch.cat(
|
445 |
+
[x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
|
446 |
+
dim=2,
|
447 |
+
)
|
448 |
+
|
449 |
+
# We will now check the cache and load the cached previous hidden and
|
450 |
+
# cell states, if they exist, otherwise we will initialize them to
|
451 |
+
# zeros (as before). We will use the ``utils.get_incremental_state()``
|
452 |
+
# and ``utils.set_incremental_state()`` helpers.
|
453 |
+
initial_state = utils.get_incremental_state(
|
454 |
+
self, incremental_state, 'prev_state',
|
455 |
+
)
|
456 |
+
if initial_state is None:
|
457 |
+
# first time initialization, same as the original version
|
458 |
+
initial_state = (
|
459 |
+
final_encoder_hidden.unsqueeze(0), # hidden
|
460 |
+
torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
|
461 |
+
)
|
462 |
+
|
463 |
+
# Run one step of our LSTM.
|
464 |
+
output, latest_state = self.lstm(x.transpose(0, 1), initial_state)
|
465 |
+
|
466 |
+
# Update the cache with the latest hidden and cell states.
|
467 |
+
utils.set_incremental_state(
|
468 |
+
self, incremental_state, 'prev_state', latest_state,
|
469 |
+
)
|
470 |
+
|
471 |
+
# This remains the same as before
|
472 |
+
x = output.transpose(0, 1)
|
473 |
+
x = self.output_projection(x)
|
474 |
+
return x, None
|
475 |
+
|
476 |
+
# The ``FairseqIncrementalDecoder`` interface also requires implementing a
|
477 |
+
# ``reorder_incremental_state()`` method, which is used during beam search
|
478 |
+
# to select and reorder the incremental state.
|
479 |
+
def reorder_incremental_state(self, incremental_state, new_order):
|
480 |
+
# Load the cached state.
|
481 |
+
prev_state = utils.get_incremental_state(
|
482 |
+
self, incremental_state, 'prev_state',
|
483 |
+
)
|
484 |
+
|
485 |
+
# Reorder batches according to *new_order*.
|
486 |
+
reordered_state = (
|
487 |
+
prev_state[0].index_select(1, new_order), # hidden
|
488 |
+
prev_state[1].index_select(1, new_order), # cell
|
489 |
+
)
|
490 |
+
|
491 |
+
# Update the cached state.
|
492 |
+
utils.set_incremental_state(
|
493 |
+
self, incremental_state, 'prev_state', reordered_state,
|
494 |
+
)
|
495 |
+
|
496 |
+
Finally, we can rerun generation and observe the speedup:
|
497 |
+
|
498 |
+
.. code-block:: console
|
499 |
+
|
500 |
+
# Before
|
501 |
+
|
502 |
+
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
|
503 |
+
--path checkpoints/checkpoint_best.pt \
|
504 |
+
--beam 5 \
|
505 |
+
--remove-bpe
|
506 |
+
(...)
|
507 |
+
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
|
508 |
+
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
|
509 |
+
|
510 |
+
# After
|
511 |
+
|
512 |
+
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
|
513 |
+
--path checkpoints/checkpoint_best.pt \
|
514 |
+
--beam 5 \
|
515 |
+
--remove-bpe
|
516 |
+
(...)
|
517 |
+
| Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
|
518 |
+
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
|
fairseq/examples/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
!*/*.sh
|
2 |
+
!*/*.md
|
fairseq/examples/MMPT/CONFIG.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Config Files Explained
|
2 |
+
|
3 |
+
Taking `projects/mfmmlm.yaml` for example, which run pretraining using masked frame model (MFM) and masked language model (MLM) on a single BERT:
|
4 |
+
|
5 |
+
```yaml
|
6 |
+
project_dir: mfmmlm # specify the project dir for this baseline.
|
7 |
+
run_task:
|
8 |
+
- how2.yaml # run pretraining on how2 when launching `projects/taskmfmmlm.yaml`
|
9 |
+
- [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml] # run fine-tuning tasks.
|
10 |
+
base_dir: task # a global template folder to specify each training task.
|
11 |
+
task_group:
|
12 |
+
pretrain: # section for pretraining. Most baselines differs in this section.
|
13 |
+
task_list:
|
14 |
+
- how2.yaml # reconfig `projects/task/how2.yaml`
|
15 |
+
dataset:
|
16 |
+
aligner: MFMMLMAligner # overwrite the aligner for MFMMLM training task.
|
17 |
+
model:
|
18 |
+
model_cls: MMFusionMFMMLM # overwrite the model, which constructs negative examples for MFM on-the-fly.
|
19 |
+
loss:
|
20 |
+
loss_cls: MFMMLM # overwrite the loss as MFMMLM, which combines MFM and MLM together.
|
21 |
+
fairseq: # all fairseq args can be expecified under this name.
|
22 |
+
dataset:
|
23 |
+
batch_size: 128
|
24 |
+
finetune: # section for fine-tuning tasks, we don't need to change anything here mostly since we want to see how pretraining can contribute to finetuning.
|
25 |
+
task_list: # specify the list of downstream tasks, e.g., copy `projects/task/vtt.yaml` to `projects/mfmmlm`.
|
26 |
+
- vtt.yaml
|
27 |
+
- vttqa.yaml
|
28 |
+
- youcook.yaml
|
29 |
+
- youcookcap.yaml
|
30 |
+
- crosstask.yaml
|
31 |
+
- coin.yaml
|
32 |
+
test: # section for testing.
|
33 |
+
task_list:
|
34 |
+
- test_vtt.yaml
|
35 |
+
- test_vttqa.yaml
|
36 |
+
- test_youcook.yaml
|
37 |
+
- test_youcookcap.yaml
|
38 |
+
- test_crosstask.yaml
|
39 |
+
- test_crosstask_zs.yaml
|
40 |
+
- test_coin.yaml
|
41 |
+
```
|
fairseq/examples/MMPT/DATASET.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dataset
|
2 |
+
|
3 |
+
We understand video data are challenging to download and process. For videos, we provide our preprocessing scripts under `scripts/video_feature_extractor` (deeply adapted from `https://github.com/antoine77340/video_feature_extractor`); for text, we pre-tokenizing scripts under `scripts/text_token_extractor`.
|
4 |
+
|
5 |
+
### S3D Feature Extraction
|
6 |
+
We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
|
7 |
+
|
8 |
+
We implement a `PathBuilder` to automatically track video ids, source video paths to their feature locations (you may need `conda install -c anaconda pandas`). Decoding may need `pip install ffmpeg-python`.
|
9 |
+
|
10 |
+
### Howto100M
|
11 |
+
[Howto100M](https://www.di.ens.fr/willow/research/howto100m/) is a large-scale video pre-training datasets. You may download videos by yourself and run preprocessing of our scripts.
|
12 |
+
|
13 |
+
Several key differences of our preprocessing from existing papers: (1) we use `raw_caption.json` instead of `caption.json` to have pure self-supervision on text (`caption.json` has manual removal of stop words); (2) we remove partially duplicated texts that are originally designed for real-time readability (see `mmpt/processors/dedupprocessor.py`); (3) then we shard video/text features using `SharedTensor` in `mmpt/utils/shardedtensor.py` for fast loading during training (faster than `h5py`).
|
14 |
+
|
15 |
+
#### Steps
|
16 |
+
##### video
|
17 |
+
To extract video features: edit and run `bash scripts/video_feature_extractor/how2/s3d.sh`. (consider to run this on multiple machines; by default, we store features in fp16 to save space and also for faster training).
|
18 |
+
|
19 |
+
Split available video ids as `data/how2/how2_s3d_train.lst` and `data/how2/how2_s3d_val.lst`.
|
20 |
+
|
21 |
+
Lastly, pack video features into `ShardedTensor` using `python scripts/video_feature_extractor/shard_feature.py`.
|
22 |
+
|
23 |
+
##### text
|
24 |
+
Clean captions using `python -m mmpt.processors.dedupprocessor`.
|
25 |
+
|
26 |
+
Tokenize dedupped captions `data/how2/raw_caption_dedup.pkl` into sharded numpy arrays:
|
27 |
+
```
|
28 |
+
python scripts/text_token_extractor/pretokenization.py scripts/text_token_extractor/configs/bert-base-uncased.yaml
|
29 |
+
```
|
30 |
+
|
31 |
+
### Youcook, MSRVTT etc.
|
32 |
+
We use the version of Youcook and MSRVTT come with Howto100M and MILNCE. Please download the data to `data/youcook` and `data/msrvtt` accordingly, you can also check `projects/task/youcook.yaml` and `projects/task/vtt.yaml` etc. in details.
|
33 |
+
We extract features for Youcook, MSRVTT similar to the first step of Howto100M but we read text from meta data directly and perform on-the-fly tokenization.
|
34 |
+
|
fairseq/examples/MMPT/locallaunch.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
from mmpt.utils import recursive_config, overwrite_dir
|
11 |
+
from mmpt_cli.localjob import LocalJob
|
12 |
+
|
13 |
+
|
14 |
+
class JobLauncher(object):
|
15 |
+
JOB_CONFIG = {
|
16 |
+
"local": LocalJob,
|
17 |
+
}
|
18 |
+
|
19 |
+
def __init__(self, yaml_file):
|
20 |
+
self.yaml_file = yaml_file
|
21 |
+
job_key = "local"
|
22 |
+
|
23 |
+
if yaml_file.endswith(".yaml"):
|
24 |
+
config = recursive_config(yaml_file)
|
25 |
+
if config.task_type is not None:
|
26 |
+
job_key = config.task_type.split("_")[0]
|
27 |
+
else:
|
28 |
+
raise ValueError("unknown extension of job file:", yaml_file)
|
29 |
+
self.job_key = job_key
|
30 |
+
|
31 |
+
def __call__(self, job_type=None, dryrun=False):
|
32 |
+
if job_type is not None:
|
33 |
+
self.job_key = job_type.split("_")[0]
|
34 |
+
print("[JobLauncher] job_key", self.job_key)
|
35 |
+
job = JobLauncher.JOB_CONFIG[self.job_key](
|
36 |
+
self.yaml_file, job_type=job_type, dryrun=dryrun)
|
37 |
+
return job.submit()
|
38 |
+
|
39 |
+
|
40 |
+
class Pipeline(object):
|
41 |
+
"""a job that loads yaml config."""
|
42 |
+
|
43 |
+
def __init__(self, fn):
|
44 |
+
"""
|
45 |
+
load a yaml config of a job and save generated configs as yaml for each task.
|
46 |
+
return: a list of files to run as specified by `run_task`.
|
47 |
+
"""
|
48 |
+
if fn.endswith(".py"):
|
49 |
+
# a python command.
|
50 |
+
self.backend = "python"
|
51 |
+
self.run_yamls = [fn]
|
52 |
+
return
|
53 |
+
|
54 |
+
job_config = recursive_config(fn)
|
55 |
+
if job_config.base_dir is None: # single file job config.
|
56 |
+
self.run_yamls = [fn]
|
57 |
+
return
|
58 |
+
|
59 |
+
self.project_dir = os.path.join("projects", job_config.project_dir)
|
60 |
+
self.run_dir = os.path.join("runs", job_config.project_dir)
|
61 |
+
|
62 |
+
if job_config.run_task is not None:
|
63 |
+
run_yamls = []
|
64 |
+
for stage in job_config.run_task:
|
65 |
+
# each stage can have multiple tasks running in parallel.
|
66 |
+
if OmegaConf.is_list(stage):
|
67 |
+
stage_yamls = []
|
68 |
+
for task_file in stage:
|
69 |
+
stage_yamls.append(
|
70 |
+
os.path.join(self.project_dir, task_file))
|
71 |
+
run_yamls.append(stage_yamls)
|
72 |
+
else:
|
73 |
+
run_yamls.append(os.path.join(self.project_dir, stage))
|
74 |
+
self.run_yamls = run_yamls
|
75 |
+
configs_to_save = self._overwrite_task(job_config)
|
76 |
+
self._save_configs(configs_to_save)
|
77 |
+
|
78 |
+
def __getitem__(self, idx):
|
79 |
+
yaml_files = self.run_yamls[idx]
|
80 |
+
if isinstance(yaml_files, list):
|
81 |
+
return [JobLauncher(yaml_file) for yaml_file in yaml_files]
|
82 |
+
return [JobLauncher(yaml_files)]
|
83 |
+
|
84 |
+
def __len__(self):
|
85 |
+
return len(self.run_yamls)
|
86 |
+
|
87 |
+
def _save_configs(self, configs_to_save: dict):
|
88 |
+
# save
|
89 |
+
os.makedirs(self.project_dir, exist_ok=True)
|
90 |
+
for config_file in configs_to_save:
|
91 |
+
config = configs_to_save[config_file]
|
92 |
+
print("saving", config_file)
|
93 |
+
OmegaConf.save(config=config, f=config_file)
|
94 |
+
|
95 |
+
def _overwrite_task(self, job_config):
|
96 |
+
configs_to_save = {}
|
97 |
+
self.base_project_dir = os.path.join("projects", job_config.base_dir)
|
98 |
+
self.base_run_dir = os.path.join("runs", job_config.base_dir)
|
99 |
+
|
100 |
+
for config_sets in job_config.task_group:
|
101 |
+
overwrite_config = job_config.task_group[config_sets]
|
102 |
+
if (
|
103 |
+
overwrite_config.task_list is None
|
104 |
+
or len(overwrite_config.task_list) == 0
|
105 |
+
):
|
106 |
+
print(
|
107 |
+
"[warning]",
|
108 |
+
job_config.task_group,
|
109 |
+
"has no task_list specified.")
|
110 |
+
# we don't want this added to a final config.
|
111 |
+
task_list = overwrite_config.pop("task_list", None)
|
112 |
+
for config_file in task_list:
|
113 |
+
config_file_path = os.path.join(
|
114 |
+
self.base_project_dir, config_file)
|
115 |
+
config = recursive_config(config_file_path)
|
116 |
+
# overwrite it.
|
117 |
+
if overwrite_config:
|
118 |
+
config = OmegaConf.merge(config, overwrite_config)
|
119 |
+
overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
|
120 |
+
save_file_path = os.path.join(self.project_dir, config_file)
|
121 |
+
configs_to_save[save_file_path] = config
|
122 |
+
return configs_to_save
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
|
126 |
+
job_type = args.jobtype if args.jobtype else None
|
127 |
+
# parse multiple pipelines.
|
128 |
+
pipelines = [Pipeline(fn) for fn in args.yamls.split(",")]
|
129 |
+
|
130 |
+
for pipe_id, pipeline in enumerate(pipelines):
|
131 |
+
if not hasattr(pipeline, "project_dir"):
|
132 |
+
for job in pipeline[0]:
|
133 |
+
job(job_type=job_type, dryrun=args.dryrun)
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == "__main__":
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument("yamls", type=str)
|
139 |
+
parser.add_argument(
|
140 |
+
"--dryrun",
|
141 |
+
action="store_true",
|
142 |
+
help="run config and prepare to submit without launch the job.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--jobtype", type=str, default="",
|
146 |
+
help="force to run jobs as specified.")
|
147 |
+
args = parser.parse_args()
|
148 |
+
main(args)
|
fairseq/examples/MMPT/mmpt/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
try:
|
6 |
+
# fairseq user dir
|
7 |
+
from .datasets import FairseqMMDataset
|
8 |
+
from .losses import FairseqCriterion
|
9 |
+
from .models import FairseqMMModel
|
10 |
+
from .tasks import FairseqMMTask
|
11 |
+
except ImportError:
|
12 |
+
pass
|
fairseq/examples/MMPT/mmpt/datasets/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .mmdataset import *
|
6 |
+
|
7 |
+
try:
|
8 |
+
from .fairseqmmdataset import *
|
9 |
+
except ImportError:
|
10 |
+
pass
|
fairseq/examples/MMPT/mmpt/evaluators/metric.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
|
9 |
+
|
10 |
+
class Metric(object):
|
11 |
+
def __init__(self, config, metric_names):
|
12 |
+
self.metric_names = metric_names
|
13 |
+
|
14 |
+
def best_metric(self, metric):
|
15 |
+
return metric[self.metric_names[0]]
|
16 |
+
|
17 |
+
def save_metrics(self, fn, metrics):
|
18 |
+
with open(fn, "w") as fw:
|
19 |
+
json.dump(fw, metrics)
|
20 |
+
|
21 |
+
def print_computed_metrics(self, metrics):
|
22 |
+
raise NotImplementedError
|
23 |
+
|
24 |
+
|
25 |
+
class RetrievalMetric(Metric):
|
26 |
+
"""
|
27 |
+
this is modified from `howto100m/metrics.py`.
|
28 |
+
History of changes:
|
29 |
+
refactor as a class.
|
30 |
+
add metric_key in __init__
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, config, metric_names=["R1", "R5", "R10", "MR"]):
|
34 |
+
super().__init__(config, metric_names)
|
35 |
+
self.error = False # TODO(huxu): add to config to print error.
|
36 |
+
|
37 |
+
def compute_metrics(self, outputs, texts, **kwargs):
|
38 |
+
x = outputs
|
39 |
+
sx = np.sort(-x, axis=1)
|
40 |
+
d = np.diag(-x)
|
41 |
+
d = d[:, np.newaxis]
|
42 |
+
ind = sx - d
|
43 |
+
ind = np.where(ind == 0)
|
44 |
+
ind = ind[1]
|
45 |
+
metrics = {}
|
46 |
+
metrics["R1"] = float(np.sum(ind == 0)) / len(ind)
|
47 |
+
metrics["R5"] = float(np.sum(ind < 5)) / len(ind)
|
48 |
+
metrics["R10"] = float(np.sum(ind < 10)) / len(ind)
|
49 |
+
metrics["MR"] = np.median(ind) + 1
|
50 |
+
|
51 |
+
max_idx = np.argmax(outputs, axis=1)
|
52 |
+
if self.error:
|
53 |
+
# print top-20 errors.
|
54 |
+
error = []
|
55 |
+
for ex_idx in range(20):
|
56 |
+
error.append((texts[ex_idx], texts[max_idx[ex_idx]]))
|
57 |
+
metrics["error"] = error
|
58 |
+
return metrics
|
59 |
+
|
60 |
+
def print_computed_metrics(self, metrics):
|
61 |
+
r1 = metrics["R1"]
|
62 |
+
r5 = metrics["R5"]
|
63 |
+
r10 = metrics["R10"]
|
64 |
+
mr = metrics["MR"]
|
65 |
+
print(
|
66 |
+
"R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}".format(
|
67 |
+
r1, r5, r10, mr
|
68 |
+
)
|
69 |
+
)
|
70 |
+
if "error" in metrics:
|
71 |
+
print(metrics["error"])
|
72 |
+
|
73 |
+
|
74 |
+
class DiDeMoMetric(Metric):
|
75 |
+
"""
|
76 |
+
History of changes:
|
77 |
+
python 2.x to python 3.x.
|
78 |
+
merge utils.py into eval to save one file.
|
79 |
+
reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
|
80 |
+
Code to evaluate your results on the DiDeMo dataset.
|
81 |
+
"""
|
82 |
+
def __init__(self, config, metric_names=["rank1", "rank5", "miou"]):
|
83 |
+
super().__init__(config, metric_names)
|
84 |
+
|
85 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
86 |
+
assert len(outputs) == len(targets)
|
87 |
+
rank1, rank5, miou = self._eval_predictions(outputs, targets)
|
88 |
+
metrics = {
|
89 |
+
"rank1": rank1,
|
90 |
+
"rank5": rank5,
|
91 |
+
"miou": miou
|
92 |
+
}
|
93 |
+
return metrics
|
94 |
+
|
95 |
+
def print_computed_metrics(self, metrics):
|
96 |
+
rank1 = metrics["rank1"]
|
97 |
+
rank5 = metrics["rank5"]
|
98 |
+
miou = metrics["miou"]
|
99 |
+
# print("Average rank@1: %f" % rank1)
|
100 |
+
# print("Average rank@5: %f" % rank5)
|
101 |
+
# print("Average iou: %f" % miou)
|
102 |
+
|
103 |
+
print(
|
104 |
+
"Average rank@1: {:.4f} Average rank@5: {:.4f} Average iou: {:.4f}".format(
|
105 |
+
rank1, rank5, miou
|
106 |
+
)
|
107 |
+
)
|
108 |
+
|
109 |
+
def _iou(self, pred, gt):
|
110 |
+
intersection = max(0, min(pred[1], gt[1]) + 1 - max(pred[0], gt[0]))
|
111 |
+
union = max(pred[1], gt[1]) + 1 - min(pred[0], gt[0])
|
112 |
+
return float(intersection)/union
|
113 |
+
|
114 |
+
def _rank(self, pred, gt):
|
115 |
+
return pred.index(tuple(gt)) + 1
|
116 |
+
|
117 |
+
def _eval_predictions(self, segments, data):
|
118 |
+
'''
|
119 |
+
Inputs:
|
120 |
+
segments: For each item in the ground truth data, rank possible video segments given the description and video.
|
121 |
+
In DiDeMo, there are 21 posible moments extracted for each video so the list of video segments will be of length 21.
|
122 |
+
The first video segment should be the video segment that best corresponds to the text query.
|
123 |
+
There are 4180 sentence in the validation data, so when evaluating a model on the val dataset,
|
124 |
+
segments should be a list of lenght 4180, and each item in segments should be a list of length 21.
|
125 |
+
data: ground truth data
|
126 |
+
'''
|
127 |
+
average_ranks = []
|
128 |
+
average_iou = []
|
129 |
+
for s, d in zip(segments, data):
|
130 |
+
pred = s[0]
|
131 |
+
ious = [self._iou(pred, t) for t in d['times']]
|
132 |
+
average_iou.append(np.mean(np.sort(ious)[-3:]))
|
133 |
+
ranks = [self._rank(s, t) for t in d['times'] if tuple(t) in s] # if t in s] is added for s, e not in prediction.
|
134 |
+
average_ranks.append(np.mean(np.sort(ranks)[:3]))
|
135 |
+
rank1 = np.sum(np.array(average_ranks) <= 1)/float(len(average_ranks))
|
136 |
+
rank5 = np.sum(np.array(average_ranks) <= 5)/float(len(average_ranks))
|
137 |
+
miou = np.mean(average_iou)
|
138 |
+
|
139 |
+
# print("Average rank@1: %f" % rank1)
|
140 |
+
# print("Average rank@5: %f" % rank5)
|
141 |
+
# print("Average iou: %f" % miou)
|
142 |
+
return rank1, rank5, miou
|
143 |
+
|
144 |
+
|
145 |
+
class NLGMetric(Metric):
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
config,
|
149 |
+
metric_names=[
|
150 |
+
"Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4",
|
151 |
+
"METEOR", "ROUGE_L", "CIDEr"
|
152 |
+
]
|
153 |
+
):
|
154 |
+
super().__init__(config, metric_names)
|
155 |
+
# please install NLGEval from `https://github.com/Maluuba/nlg-eval`
|
156 |
+
from nlgeval import NLGEval
|
157 |
+
self.nlg = NLGEval()
|
158 |
+
|
159 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
160 |
+
return self.nlg.compute_metrics(
|
161 |
+
hyp_list=outputs, ref_list=targets)
|
162 |
+
|
163 |
+
def print_computed_metrics(self, metrics):
|
164 |
+
Bleu_1 = metrics["Bleu_1"]
|
165 |
+
Bleu_2 = metrics["Bleu_2"]
|
166 |
+
Bleu_3 = metrics["Bleu_3"]
|
167 |
+
Bleu_4 = metrics["Bleu_4"]
|
168 |
+
METEOR = metrics["METEOR"]
|
169 |
+
ROUGE_L = metrics["ROUGE_L"]
|
170 |
+
CIDEr = metrics["CIDEr"]
|
171 |
+
|
172 |
+
print(
|
173 |
+
"Bleu_1: {:.4f} - Bleu_2: {:.4f} - Bleu_3: {:.4f} - Bleu_4: {:.4f} - METEOR: {:.4f} - ROUGE_L: {:.4f} - CIDEr: {:.4f}".format(
|
174 |
+
Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr
|
175 |
+
)
|
176 |
+
)
|
177 |
+
|
178 |
+
|
179 |
+
class QAMetric(Metric):
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
config,
|
183 |
+
metric_names=["acc"]
|
184 |
+
):
|
185 |
+
super().__init__(config, metric_names)
|
186 |
+
|
187 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
188 |
+
from sklearn.metrics import accuracy_score
|
189 |
+
return {"acc": accuracy_score(targets, outputs)}
|
190 |
+
|
191 |
+
def print_computed_metrics(self, metrics):
|
192 |
+
print("acc: {:.4f}".format(metrics["acc"]))
|
193 |
+
|
194 |
+
|
195 |
+
class COINActionSegmentationMetric(Metric):
|
196 |
+
"""
|
197 |
+
COIN dataset listed 3 repos for Action Segmentation.
|
198 |
+
Action Sets, NeuralNetwork-Viterbi, TCFPN-ISBA.
|
199 |
+
The first and second are the same.
|
200 |
+
https://github.com/alexanderrichard/action-sets/blob/master/eval.py
|
201 |
+
|
202 |
+
Future reference for the third:
|
203 |
+
`https://github.com/Zephyr-D/TCFPN-ISBA/blob/master/utils/metrics.py`
|
204 |
+
"""
|
205 |
+
def __init__(self, config, metric_name=["frame_acc"]):
|
206 |
+
super().__init__(config, metric_name)
|
207 |
+
|
208 |
+
def compute_metrics(self, outputs, targets):
|
209 |
+
n_frames = 0
|
210 |
+
n_errors = 0
|
211 |
+
n_errors = sum(outputs != targets)
|
212 |
+
n_frames = len(targets)
|
213 |
+
return {"frame_acc": 1.0 - float(n_errors) / n_frames}
|
214 |
+
|
215 |
+
def print_computed_metrics(self, metrics):
|
216 |
+
fa = metrics["frame_acc"]
|
217 |
+
print("frame accuracy:", fa)
|
218 |
+
|
219 |
+
|
220 |
+
class CrossTaskMetric(Metric):
|
221 |
+
def __init__(self, config, metric_names=["recall"]):
|
222 |
+
super().__init__(config, metric_names)
|
223 |
+
|
224 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
225 |
+
"""refactored from line 166:
|
226 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
|
227 |
+
|
228 |
+
recalls = self._get_recalls(Y_true=targets, Y_pred=outputs)
|
229 |
+
results = {}
|
230 |
+
for task, rec in recalls.items():
|
231 |
+
results[str(task)] = rec
|
232 |
+
|
233 |
+
avg_recall = np.mean(list(recalls.values()))
|
234 |
+
results["recall"] = avg_recall
|
235 |
+
return results
|
236 |
+
|
237 |
+
def print_computed_metrics(self, metrics):
|
238 |
+
print('Recall: {0:0.3f}'.format(metrics["recall"]))
|
239 |
+
for task in metrics:
|
240 |
+
if task != "recall":
|
241 |
+
print('Task {0}. Recall = {1:0.3f}'.format(
|
242 |
+
task, metrics[task]))
|
243 |
+
|
244 |
+
def _get_recalls(self, Y_true, Y_pred):
|
245 |
+
"""refactored from
|
246 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
|
247 |
+
|
248 |
+
step_match = {task: 0 for task in Y_true.keys()}
|
249 |
+
step_total = {task: 0 for task in Y_true.keys()}
|
250 |
+
for task, ys_true in Y_true.items():
|
251 |
+
ys_pred = Y_pred[task]
|
252 |
+
for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())):
|
253 |
+
y_true = ys_true[vid]
|
254 |
+
y_pred = ys_pred[vid]
|
255 |
+
step_total[task] += (y_true.sum(axis=0) > 0).sum()
|
256 |
+
step_match[task] += (y_true*y_pred).sum()
|
257 |
+
recalls = {
|
258 |
+
task: step_match[task] / n for task, n in step_total.items()}
|
259 |
+
return recalls
|
260 |
+
|
261 |
+
|
262 |
+
class ActionRecognitionMetric(Metric):
|
263 |
+
def __init__(
|
264 |
+
self,
|
265 |
+
config,
|
266 |
+
metric_names=["acc", "acc_splits", "r1_splits", "r5_splits", "r10_splits"]
|
267 |
+
):
|
268 |
+
super().__init__(config, metric_names)
|
269 |
+
|
270 |
+
def compute_metrics(self, outputs, targets, splits, **kwargs):
|
271 |
+
all_video_embd = outputs
|
272 |
+
labels = targets
|
273 |
+
split1, split2, split3 = splits
|
274 |
+
accs = []
|
275 |
+
r1s = []
|
276 |
+
r5s = []
|
277 |
+
r10s = []
|
278 |
+
for split in range(3):
|
279 |
+
if split == 0:
|
280 |
+
s = split1
|
281 |
+
elif split == 1:
|
282 |
+
s = split2
|
283 |
+
else:
|
284 |
+
s = split3
|
285 |
+
|
286 |
+
X_pred = all_video_embd[np.where(s == 2)[0]]
|
287 |
+
label_test = labels[np.where(s == 2)[0]]
|
288 |
+
logits = X_pred
|
289 |
+
X_pred = np.argmax(X_pred, axis=1)
|
290 |
+
acc = np.sum(X_pred == label_test) / float(len(X_pred))
|
291 |
+
accs.append(acc)
|
292 |
+
# compute recall.
|
293 |
+
sorted_pred = (-logits).argsort(axis=-1)
|
294 |
+
label_test_sp = label_test.reshape(-1, 1)
|
295 |
+
|
296 |
+
r1 = np.mean((sorted_pred[:, :1] == label_test_sp).sum(axis=1), axis=0)
|
297 |
+
r5 = np.mean((sorted_pred[:, :5] == label_test_sp).sum(axis=1), axis=0)
|
298 |
+
r10 = np.mean((sorted_pred[:, :10] == label_test_sp).sum(axis=1), axis=0)
|
299 |
+
r1s.append(r1)
|
300 |
+
r5s.append(r5)
|
301 |
+
r10s.append(r10)
|
302 |
+
|
303 |
+
return {"acc": accs[0], "acc_splits": accs, "r1_splits": r1s, "r5_splits": r5s, "r10_splits": r10s}
|
304 |
+
|
305 |
+
def print_computed_metrics(self, metrics):
|
306 |
+
for split, acc in enumerate(metrics["acc_splits"]):
|
307 |
+
print("Top 1 accuracy on split {}: {}; r1 {}; r5 {}; r10 {}".format(
|
308 |
+
split + 1, acc,
|
309 |
+
metrics["r1_splits"][split],
|
310 |
+
metrics["r5_splits"][split],
|
311 |
+
metrics["r10_splits"][split],
|
312 |
+
)
|
313 |
+
)
|
fairseq/examples/MMPT/mmpt/evaluators/predictor.py
ADDED
@@ -0,0 +1,595 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import pickle
|
11 |
+
import math
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
class Predictor(object):
|
17 |
+
"""this base class is used to save predictions to disk
|
18 |
+
(and being called by a evaluator later).
|
19 |
+
Predictor has minimum support of single gpu prediction.
|
20 |
+
"""
|
21 |
+
def __init__(self, config):
|
22 |
+
self.pred_dir = None # on-the-fly eval does not save the results.
|
23 |
+
if hasattr(config, "eval") and config.eval is not None:
|
24 |
+
self.pred_dir = config.eval.save_path
|
25 |
+
os.makedirs(self.pred_dir, exist_ok=True)
|
26 |
+
|
27 |
+
def __call__(self, outputs):
|
28 |
+
"""extract the prediction and save it."""
|
29 |
+
raise NotImplementedError
|
30 |
+
|
31 |
+
def predict_loop(self, model, eval_dataloader, output_file=None):
|
32 |
+
"""on-the-fly prediction on a single gpu."""
|
33 |
+
self.full_scores = []
|
34 |
+
model.eval()
|
35 |
+
model = model.to(0)
|
36 |
+
with torch.no_grad():
|
37 |
+
for data in eval_dataloader:
|
38 |
+
data = self.to_ctx(data)
|
39 |
+
outputs = model(**data)
|
40 |
+
outputs.update(data)
|
41 |
+
self(outputs)
|
42 |
+
return self.finalize(output_file)
|
43 |
+
|
44 |
+
def finalize(self, output_file):
|
45 |
+
pass
|
46 |
+
|
47 |
+
def to_ctx(self, data, ctx=0, dtype=None):
|
48 |
+
if isinstance(data, dict):
|
49 |
+
for key in data:
|
50 |
+
if torch.is_tensor(data[key]):
|
51 |
+
if dtype is not None and data[key].dtype == torch.float32:
|
52 |
+
data[key] = data[key].to(dtype)
|
53 |
+
data[key] = data[key].to(ctx)
|
54 |
+
return data
|
55 |
+
else:
|
56 |
+
raise ValueError("non-dict type of batch is not supported yet.")
|
57 |
+
|
58 |
+
|
59 |
+
class NLGPredictor(Predictor):
|
60 |
+
"""Predicting Text from MMFusion models."""
|
61 |
+
"""TODO: make a context."""
|
62 |
+
def __init__(self, config):
|
63 |
+
super().__init__(config)
|
64 |
+
from transformers import AutoTokenizer
|
65 |
+
|
66 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
67 |
+
config.dataset.bert_name,
|
68 |
+
bos_token="[CLS]", eos_token="[SEP]")
|
69 |
+
self.bos_token_id = self.tokenizer.bos_token_id
|
70 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
71 |
+
|
72 |
+
def predict_loop(self, model, eval_dataloader, output_file=None):
|
73 |
+
"""TODO: refactor base classes."""
|
74 |
+
ctx = 0
|
75 |
+
outputs = {"outputs": [], "targets": [[]]}
|
76 |
+
model.eval()
|
77 |
+
model = model.to(ctx)
|
78 |
+
with torch.no_grad():
|
79 |
+
for data in tqdm(eval_dataloader):
|
80 |
+
data = self.to_ctx(data, ctx)
|
81 |
+
self(data, model, outputs)
|
82 |
+
return self.finalize(outputs, output_file)
|
83 |
+
|
84 |
+
def __call__(self, data, model, outputs):
|
85 |
+
data.update({
|
86 |
+
"bos_token_id": self.bos_token_id,
|
87 |
+
"eos_token_id": self.eos_token_id
|
88 |
+
})
|
89 |
+
|
90 |
+
output = model.generate(**data)
|
91 |
+
assert len(output) == len(data["ref"])
|
92 |
+
for idx, _output in enumerate(output):
|
93 |
+
generated_text = self.tokenizer.decode(
|
94 |
+
_output, skip_special_tokens=True)
|
95 |
+
if generated_text == "":
|
96 |
+
generated_text = "none"
|
97 |
+
outputs["outputs"].append(generated_text)
|
98 |
+
outputs["targets"][0].append(data["ref"][idx])
|
99 |
+
if random.random() < 0.001:
|
100 |
+
print("_output", _output)
|
101 |
+
print("generated_text", generated_text)
|
102 |
+
print("ref", data["ref"][idx])
|
103 |
+
|
104 |
+
def finalize(self, outputs, output_file=None):
|
105 |
+
if output_file is not None:
|
106 |
+
with open(os.path.join(
|
107 |
+
self.pred_dir, output_file + ".json"), "w") as fw:
|
108 |
+
json.dump(outputs, fw, indent=4)
|
109 |
+
return outputs
|
110 |
+
|
111 |
+
|
112 |
+
class RetrievalPredictor(Predictor):
|
113 |
+
"""generated `pooled_video` and `pooled_text`."""
|
114 |
+
def __init__(self, config):
|
115 |
+
super().__init__(config)
|
116 |
+
from transformers import AutoTokenizer
|
117 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
118 |
+
config.dataset.bert_name)
|
119 |
+
|
120 |
+
def predict_loop(
|
121 |
+
self,
|
122 |
+
model,
|
123 |
+
eval_dataloader,
|
124 |
+
output_file="retrieval.npy"
|
125 |
+
):
|
126 |
+
"""on-the-fly prediction on a single gpu."""
|
127 |
+
full_scores = []
|
128 |
+
texts = []
|
129 |
+
model.eval()
|
130 |
+
model = model.cuda()
|
131 |
+
with torch.no_grad():
|
132 |
+
for data in eval_dataloader:
|
133 |
+
# convert to dict.
|
134 |
+
if not isinstance(data, dict):
|
135 |
+
data = {
|
136 |
+
"caps": data[0],
|
137 |
+
"cmasks": data[1],
|
138 |
+
"vfeats": data[2],
|
139 |
+
"vmasks": data[3],
|
140 |
+
"video_id": data[4]
|
141 |
+
}
|
142 |
+
data = self.to_ctx(data)
|
143 |
+
outputs = model(**data)
|
144 |
+
outputs.update(data)
|
145 |
+
self(outputs, full_scores)
|
146 |
+
for _cap in data["caps"]:
|
147 |
+
texts.append(
|
148 |
+
self.tokenizer.decode(_cap, skip_special_tokens=True)
|
149 |
+
)
|
150 |
+
|
151 |
+
return self.finalize(full_scores, texts, output_file)
|
152 |
+
|
153 |
+
def __call__(self, sample, full_scores):
|
154 |
+
scores = self._get_pooled_outputs(sample)
|
155 |
+
self._append_scores(scores, full_scores)
|
156 |
+
|
157 |
+
def finalize(self, full_scores, texts, output_file=None):
|
158 |
+
outputs = self._aggregate_scores(full_scores)
|
159 |
+
if output_file is not None:
|
160 |
+
np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
|
161 |
+
return {"outputs": outputs, "texts": texts}
|
162 |
+
|
163 |
+
def _get_pooled_outputs(self, outputs):
|
164 |
+
if "pooled_video" in outputs:
|
165 |
+
return outputs["pooled_video"], outputs["pooled_text"]
|
166 |
+
else:
|
167 |
+
raise ValueError("unknown format of outputs.")
|
168 |
+
|
169 |
+
def _append_scores(self, scores, full_scores):
|
170 |
+
assert len(scores) == 2
|
171 |
+
if len(full_scores) == 0:
|
172 |
+
full_scores.append([])
|
173 |
+
full_scores.append([])
|
174 |
+
full_scores[0].append(scores[0].cpu().detach().numpy())
|
175 |
+
full_scores[1].append(scores[1].cpu().detach().numpy())
|
176 |
+
|
177 |
+
def _aggregate_scores(self, scores):
|
178 |
+
assert len(scores) == 2
|
179 |
+
video_hidden = np.concatenate(scores[0], axis=0)
|
180 |
+
text_hidden = np.concatenate(scores[1], axis=0)
|
181 |
+
# clear up.
|
182 |
+
self.full_scores = []
|
183 |
+
return np.matmul(text_hidden, video_hidden.T)
|
184 |
+
|
185 |
+
|
186 |
+
class QAPredictor(Predictor):
|
187 |
+
"""generated `pooled_video` and `pooled_text`."""
|
188 |
+
def __init__(self, config):
|
189 |
+
super().__init__(config)
|
190 |
+
"""predictor maintains scores and aggregate them."""
|
191 |
+
|
192 |
+
def predict_loop(self, model, eval_dataloader, output_file="qa.npy"):
|
193 |
+
"""on-the-fly prediction on a single gpu."""
|
194 |
+
self.full_scores = []
|
195 |
+
model.eval()
|
196 |
+
model = model.cuda()
|
197 |
+
with torch.no_grad():
|
198 |
+
for data in eval_dataloader:
|
199 |
+
# reshape ans and dup video 5 times.
|
200 |
+
v_len = data["vfeats"].size(1)
|
201 |
+
hidden_size = data["vfeats"].size(2)
|
202 |
+
data["vfeats"] = data["vfeats"].unsqueeze(1).repeat(1, 5, 1, 1).view(-1, v_len, hidden_size)
|
203 |
+
data["vmasks"] = data["vmasks"].unsqueeze(1).repeat(1, 5, 1).view(-1, v_len)
|
204 |
+
|
205 |
+
t_len = data["caps"].size(-1)
|
206 |
+
data["caps"] = data["caps"].view(-1, t_len)
|
207 |
+
data["cmasks"] = data["cmasks"].view(-1, t_len)
|
208 |
+
|
209 |
+
data = self.to_ctx(data)
|
210 |
+
outputs = model(**data)
|
211 |
+
outputs.update(data)
|
212 |
+
self(outputs)
|
213 |
+
return self.finalize(output_file)
|
214 |
+
|
215 |
+
def __call__(self, sample):
|
216 |
+
hidden_size = sample["pooled_video"].size(-1)
|
217 |
+
pooled_video = sample["pooled_video"].view(-1, 5, hidden_size)
|
218 |
+
pooled_text = sample["pooled_text"].view(-1, 5, hidden_size)
|
219 |
+
scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1))
|
220 |
+
scores = scores.argmax(-1)
|
221 |
+
self._append_scores(scores[:, 0], sample["answers"], self.full_scores)
|
222 |
+
|
223 |
+
def finalize(self, output_file=None):
|
224 |
+
outputs, targets = self._aggregate_scores(self.full_scores)
|
225 |
+
if output_file is not None:
|
226 |
+
np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
|
227 |
+
return {"outputs": outputs, "targets": targets}
|
228 |
+
|
229 |
+
def _append_scores(self, scores, answers, full_scores):
|
230 |
+
if len(full_scores) == 0:
|
231 |
+
full_scores.append([])
|
232 |
+
full_scores.append([])
|
233 |
+
full_scores[0].append(scores.cpu().detach().numpy())
|
234 |
+
full_scores[1].append(answers.cpu().detach().numpy())
|
235 |
+
|
236 |
+
def _aggregate_scores(self, scores):
|
237 |
+
assert len(scores) == 2
|
238 |
+
outputs = np.concatenate(scores[0], axis=0)
|
239 |
+
targets = np.concatenate(scores[1], axis=0)
|
240 |
+
# clear up.
|
241 |
+
self.full_scores = []
|
242 |
+
return outputs, targets
|
243 |
+
|
244 |
+
|
245 |
+
class CrossTaskPredictor(Predictor):
|
246 |
+
"""
|
247 |
+
CrossTaskPredictor needs to compute the average of logits
|
248 |
+
for overlapped sliding-window.
|
249 |
+
"""
|
250 |
+
def __init__(self, config):
|
251 |
+
super().__init__(config)
|
252 |
+
self.lsm = torch.nn.LogSoftmax(dim=1)
|
253 |
+
self.max_video_len = config.dataset.max_video_len
|
254 |
+
self.sliding_window = config.dataset.sliding_window
|
255 |
+
self.sliding_window_size = config.dataset.sliding_window_size
|
256 |
+
self.annotation_path = config.dataset.annotation_path
|
257 |
+
|
258 |
+
def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
|
259 |
+
"""refactored from line 144:
|
260 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py
|
261 |
+
"""
|
262 |
+
ctx = 0
|
263 |
+
model.eval()
|
264 |
+
model = model.to(ctx)
|
265 |
+
# this is not a loss but just compute neg_log_prob.
|
266 |
+
Y_pred = {}
|
267 |
+
Y_true = {}
|
268 |
+
with torch.no_grad():
|
269 |
+
for batch in eval_dataloader:
|
270 |
+
self(batch, model, Y_pred, Y_true)
|
271 |
+
return self.finalize(Y_pred, Y_true, output_file)
|
272 |
+
|
273 |
+
def __call__(self, sample, model, Y_pred, Y_true):
|
274 |
+
# please install dp from `https://github.com/DmZhukov/CrossTask`
|
275 |
+
from dp import dp
|
276 |
+
vid, task = sample['video_id'][0], sample['task'][0]
|
277 |
+
sample = self.to_ctx(sample)
|
278 |
+
# compute the average logits over sliding windows.
|
279 |
+
output = model(**sample)
|
280 |
+
batch_logits = output["logits"].cpu()
|
281 |
+
|
282 |
+
video_len = sample["video_len"][0]
|
283 |
+
|
284 |
+
# the following version is slow.
|
285 |
+
logits = torch.zeros((video_len, batch_logits.size(1)))
|
286 |
+
logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
|
287 |
+
# use the same loop as aligner to recover.
|
288 |
+
batch_logit_idx = 0
|
289 |
+
for window_start in range(0, video_len, self.sliding_window):
|
290 |
+
video_end = min(video_len - window_start, self.sliding_window_size)
|
291 |
+
logits[window_start: window_start + video_end] += batch_logits[
|
292 |
+
batch_logit_idx: batch_logit_idx + video_end]
|
293 |
+
batch_logit_idx += video_end
|
294 |
+
logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
|
295 |
+
|
296 |
+
if (video_len - window_start) <= self.sliding_window_size:
|
297 |
+
break
|
298 |
+
|
299 |
+
logits /= logits_counts
|
300 |
+
assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
|
301 |
+
|
302 |
+
O = self.lsm(logits)
|
303 |
+
y = np.zeros(O.size(), dtype=np.float32)
|
304 |
+
dp(y, -O.detach().cpu().numpy())
|
305 |
+
if task not in Y_pred:
|
306 |
+
Y_pred[task] = {}
|
307 |
+
Y_pred[task][vid] = y
|
308 |
+
annot_path = os.path.join(
|
309 |
+
self.annotation_path, task+'_'+vid+'.csv')
|
310 |
+
if os.path.exists(annot_path):
|
311 |
+
if task not in Y_true:
|
312 |
+
Y_true[task] = {}
|
313 |
+
Y_true[task][vid] = self._read_assignment(
|
314 |
+
*y.shape, annot_path)
|
315 |
+
|
316 |
+
def finalize(self, Y_pred, Y_true, output_file=None):
|
317 |
+
if output_file is not None:
|
318 |
+
with open(
|
319 |
+
os.path.join(self.pred_dir, output_file + ".pkl"),
|
320 |
+
"wb") as fw:
|
321 |
+
pickle.dump(
|
322 |
+
{"Y_pred": Y_pred, "Y_true": Y_true}, fw,
|
323 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
324 |
+
return {"outputs": Y_pred, "targets": Y_true}
|
325 |
+
|
326 |
+
def _read_assignment(self, T, K, path):
|
327 |
+
"""
|
328 |
+
refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
|
329 |
+
Howto interpret contraints on loss that is going to be minimized:
|
330 |
+
lambd is a big number;
|
331 |
+
self.lambd * C is a big number for all valid position (csv stores invalids)
|
332 |
+
|
333 |
+
def forward(self, O, Y, C):
|
334 |
+
return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
|
335 |
+
|
336 |
+
This will load the csv file and fill-in the step col from start to end rows.
|
337 |
+
"""
|
338 |
+
|
339 |
+
Y = np.zeros([T, K], dtype=np.uint8)
|
340 |
+
with open(path, 'r') as f:
|
341 |
+
for line in f:
|
342 |
+
step, start, end = line.strip().split(',')
|
343 |
+
start = int(math.floor(float(start)))
|
344 |
+
end = int(math.ceil(float(end)))
|
345 |
+
step = int(step) - 1
|
346 |
+
Y[start:end, step] = 1
|
347 |
+
return Y
|
348 |
+
|
349 |
+
|
350 |
+
class COINPredictor(Predictor):
|
351 |
+
"""
|
352 |
+
COINPredictor is similar to CrossTask on sliding windows.
|
353 |
+
"""
|
354 |
+
def __init__(self, config):
|
355 |
+
super().__init__(config)
|
356 |
+
self.max_video_len = config.dataset.max_video_len
|
357 |
+
self.sliding_window = config.dataset.sliding_window
|
358 |
+
self.sliding_window_size = config.dataset.sliding_window_size
|
359 |
+
|
360 |
+
def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
|
361 |
+
"""refactored from line 144:
|
362 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py
|
363 |
+
"""
|
364 |
+
ctx = 0
|
365 |
+
model.eval()
|
366 |
+
model = model.to(ctx)
|
367 |
+
# this is not a loss but just compute neg_log_prob.
|
368 |
+
Y_pred = []
|
369 |
+
Y_true = []
|
370 |
+
with torch.no_grad():
|
371 |
+
for batch in eval_dataloader:
|
372 |
+
self(batch, model, Y_pred, Y_true)
|
373 |
+
return self.finalize(Y_pred, Y_true, output_file)
|
374 |
+
|
375 |
+
def __call__(self, sample, model, Y_pred, Y_true):
|
376 |
+
sample = self.to_ctx(sample)
|
377 |
+
# compute the average logits over sliding windows.
|
378 |
+
output = model(**sample)
|
379 |
+
logits = self._merge_windows(sample, output)
|
380 |
+
Y_pred.append(logits.argmax(dim=1))
|
381 |
+
Y_true.append(sample["video_targets"].squeeze(0).cpu())
|
382 |
+
|
383 |
+
def _merge_windows(self, sample, output):
|
384 |
+
targets = sample["targets"].reshape(-1).cpu()
|
385 |
+
valid_mask = targets != -100
|
386 |
+
targets = targets[valid_mask]
|
387 |
+
batch_logits = output["logits"].cpu()
|
388 |
+
batch_logits = batch_logits.reshape(-1, batch_logits.size(-1))
|
389 |
+
batch_logits = batch_logits[valid_mask]
|
390 |
+
|
391 |
+
video_len = sample["video_len"][0]
|
392 |
+
|
393 |
+
# the following version is slow.
|
394 |
+
logits = torch.zeros((video_len, batch_logits.size(1)))
|
395 |
+
logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
|
396 |
+
# use the same loop as aligner to recover.
|
397 |
+
batch_logit_idx = 0
|
398 |
+
for window_start in range(0, video_len, self.sliding_window):
|
399 |
+
video_end = min(video_len - window_start, self.sliding_window_size)
|
400 |
+
logits[window_start: window_start + video_end] += batch_logits[
|
401 |
+
batch_logit_idx: batch_logit_idx + video_end]
|
402 |
+
batch_logit_idx += video_end
|
403 |
+
logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
|
404 |
+
if (video_len - window_start) <= self.sliding_window_size:
|
405 |
+
break
|
406 |
+
logits /= logits_counts
|
407 |
+
assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
|
408 |
+
return logits
|
409 |
+
|
410 |
+
def finalize(self, Y_pred, Y_true, output_file=None):
|
411 |
+
Y_pred = torch.cat(Y_pred, dim=0).numpy()
|
412 |
+
Y_true = torch.cat(Y_true, dim=0).numpy()
|
413 |
+
assert len(Y_pred) == len(Y_true)
|
414 |
+
|
415 |
+
error_mask = Y_pred != Y_true
|
416 |
+
print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
|
417 |
+
print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
|
418 |
+
|
419 |
+
if output_file is not None:
|
420 |
+
with open(
|
421 |
+
os.path.join(self.pred_dir, output_file + ".pkl"),
|
422 |
+
"wb") as fw:
|
423 |
+
pickle.dump(
|
424 |
+
{"Y_pred": Y_pred, "Y_true": Y_true}, fw,
|
425 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
426 |
+
return {"outputs": Y_pred, "targets": Y_true}
|
427 |
+
|
428 |
+
|
429 |
+
class COINZSPredictor(COINPredictor):
|
430 |
+
"""
|
431 |
+
COINZSPredictor for COIN zero-shot prediction.
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(self, config):
|
435 |
+
super().__init__(config)
|
436 |
+
self.dataset_config = config.dataset
|
437 |
+
|
438 |
+
def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
|
439 |
+
"""refactored from line 144:
|
440 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py
|
441 |
+
"""
|
442 |
+
ctx = 0
|
443 |
+
model.eval()
|
444 |
+
model = model.to(ctx)
|
445 |
+
|
446 |
+
with torch.no_grad():
|
447 |
+
outputs = eval_dataloader.dataset.meta_processor.meta_text_labels(
|
448 |
+
self.dataset_config)
|
449 |
+
outputs = self.to_ctx(outputs, ctx)
|
450 |
+
label_hidden_states = model.forward_text(**outputs).cpu()
|
451 |
+
label_sim = label_hidden_states @ label_hidden_states.t()
|
452 |
+
num_labels = label_sim.size(0)
|
453 |
+
eye_mask = ~torch.eye(num_labels, dtype=torch.bool)
|
454 |
+
label_sim = label_sim.masked_select(eye_mask).view(num_labels, num_labels - 1)
|
455 |
+
lbd = label_sim.max()
|
456 |
+
|
457 |
+
# this is not a loss but just compute neg_log_prob.
|
458 |
+
Y_pred = []
|
459 |
+
Y_true = []
|
460 |
+
with torch.no_grad():
|
461 |
+
for batch in eval_dataloader:
|
462 |
+
self(batch, label_hidden_states, model, lbd, Y_pred, Y_true)
|
463 |
+
return self.finalize(Y_pred, Y_true, output_file)
|
464 |
+
|
465 |
+
def reshape_subsample(self, sample):
|
466 |
+
for key in sample:
|
467 |
+
if torch.is_tensor(sample[key]):
|
468 |
+
sample[key] = self.flat_subsample(sample[key])
|
469 |
+
return sample
|
470 |
+
|
471 |
+
def flat_subsample(self, tensor):
|
472 |
+
if len(tensor.size()) > 1 and tensor.size(0) == 1:
|
473 |
+
tensor = tensor.squeeze(0)
|
474 |
+
return tensor
|
475 |
+
|
476 |
+
def __call__(self, sample, label_hidden_states, model, lbd, Y_pred, Y_true):
|
477 |
+
sample = self.reshape_subsample(sample)
|
478 |
+
sample = self.to_ctx(sample)
|
479 |
+
# compute the average logits over sliding windows.
|
480 |
+
sample["output_hidden_states"] = True
|
481 |
+
video_outputs = model.forward_video(**sample).cpu()
|
482 |
+
output = {"logits": video_outputs[:, 1:sample["vmasks"].size(1)+1] @ label_hidden_states.t()}
|
483 |
+
logits = self._merge_windows(sample, output)
|
484 |
+
# logic of zero-shot for sequence labeling.
|
485 |
+
logits_argmax = logits.argmax(dim=1) + 1 # 0 is "O" label.
|
486 |
+
logits_max = logits.max(dim=1)[0]
|
487 |
+
|
488 |
+
pred = torch.zeros_like(logits_argmax)
|
489 |
+
label_select = logits_max > lbd # 73 or 74
|
490 |
+
pred[label_select] = logits_argmax[label_select]
|
491 |
+
|
492 |
+
Y_pred.append(pred)
|
493 |
+
Y_true.append(sample["video_targets"].squeeze(0).cpu())
|
494 |
+
|
495 |
+
def finalize(self, Y_pred, Y_true, output_file=None):
|
496 |
+
Y_pred = torch.cat(Y_pred, dim=0).numpy()
|
497 |
+
Y_true = torch.cat(Y_true, dim=0).numpy()
|
498 |
+
assert len(Y_pred) == len(Y_true)
|
499 |
+
|
500 |
+
error_mask = Y_pred != Y_true
|
501 |
+
print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
|
502 |
+
print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
|
503 |
+
|
504 |
+
if output_file is not None:
|
505 |
+
with open(
|
506 |
+
os.path.join(self.pred_dir, output_file + ".pkl"),
|
507 |
+
"wb") as fw:
|
508 |
+
pickle.dump(
|
509 |
+
{"Y_pred": Y_pred, "Y_true": Y_true}, fw,
|
510 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
511 |
+
return {"outputs": Y_pred, "targets": Y_true}
|
512 |
+
|
513 |
+
|
514 |
+
class DiDeMoPredictor(Predictor):
|
515 |
+
"""reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
|
516 |
+
https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
|
517 |
+
"""
|
518 |
+
def __init__(self, config):
|
519 |
+
super().__init__(config)
|
520 |
+
# load targets.
|
521 |
+
with open(config.dataset.test_path) as data_file:
|
522 |
+
self.test_data = json.load(data_file)
|
523 |
+
|
524 |
+
def predict_loop(self, model, eval_dataloader, output_file="didemo.npy"):
|
525 |
+
"""
|
526 |
+
TODO: two solutions here.
|
527 |
+
"""
|
528 |
+
import itertools
|
529 |
+
# 21 chunks.
|
530 |
+
self.possible_segments = [(0,0), (1,1), (2,2), (3,3), (4,4), (5,5)]
|
531 |
+
for i in itertools.combinations(range(6), 2):
|
532 |
+
self.possible_segments.append(i)
|
533 |
+
# pick segments from a video.
|
534 |
+
|
535 |
+
"""on-the-fly prediction on a single gpu."""
|
536 |
+
self.full_scores = []
|
537 |
+
model.eval()
|
538 |
+
model = model.cuda()
|
539 |
+
with torch.no_grad():
|
540 |
+
for data in eval_dataloader:
|
541 |
+
# TODO special forwarding logic here.
|
542 |
+
data = self.to_ctx(data)
|
543 |
+
data["output_hidden_states"] = True
|
544 |
+
hidden_video = model.forward_video(**data)
|
545 |
+
data["output_hidden_states"] = False
|
546 |
+
pooled_text = model.forward_text(**data)
|
547 |
+
outputs = {
|
548 |
+
"hidden_video": hidden_video,
|
549 |
+
"pooled_text": pooled_text
|
550 |
+
}
|
551 |
+
outputs.update(data)
|
552 |
+
self(outputs)
|
553 |
+
return self.finalize(output_file)
|
554 |
+
|
555 |
+
def __call__(self, sample):
|
556 |
+
# TODO: make an index select from self.possible_segments.
|
557 |
+
hidden_video = sample["hidden_video"]
|
558 |
+
pooled_text = sample["pooled_text"]
|
559 |
+
vmasks = sample["vmasks"]
|
560 |
+
# probably maintain valid results here.
|
561 |
+
|
562 |
+
hidden_video = hidden_video[:, 1:-1, :]
|
563 |
+
# probably maintain valid results here.
|
564 |
+
pooled_video = []
|
565 |
+
for s, e in self.possible_segments:
|
566 |
+
pooled_video.append(
|
567 |
+
torch.mean(
|
568 |
+
hidden_video[:, int(s*5):int((e+1)*5), :],
|
569 |
+
dim=1, keepdim=True)
|
570 |
+
)
|
571 |
+
pooled_video = torch.cat(pooled_video, dim=1)
|
572 |
+
scores = torch.bmm(
|
573 |
+
pooled_video, pooled_text.unsqueeze(-1)).squeeze(-1).cpu()
|
574 |
+
|
575 |
+
ranks = scores.argsort(dim=-1, descending=True)
|
576 |
+
|
577 |
+
for batch_idx, rank in enumerate(ranks):
|
578 |
+
rank_of_moment = []
|
579 |
+
for m_idx, moment in enumerate(rank):
|
580 |
+
s, e = self.possible_segments[moment.item()]
|
581 |
+
if torch.any(
|
582 |
+
vmasks[batch_idx, int(s*5):int((e+1)*5)]
|
583 |
+
):
|
584 |
+
rank_of_moment.append((s, e))
|
585 |
+
self.full_scores.append(rank_of_moment)
|
586 |
+
|
587 |
+
def finalize(self, output_file=None):
|
588 |
+
outputs = self._aggregate_scores(self.full_scores)
|
589 |
+
if output_file is not None:
|
590 |
+
np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
|
591 |
+
return {"outputs": outputs, "targets": self.test_data}
|
592 |
+
|
593 |
+
def _aggregate_scores(self, scores):
|
594 |
+
self.full_scores = []
|
595 |
+
return scores
|
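Both CrossTaskPredictor and COINPredictor above recover per-frame logits by averaging over overlapped sliding windows. The following is a minimal, self-contained sketch of that recovery loop on dummy tensors; the sizes (video_len, sliding_window, sliding_window_size) and the random logits are invented for illustration and do not come from the MMPT configs.

# Standalone sketch of the overlapped sliding-window averaging used above.
import torch

video_len, num_classes = 12, 4
sliding_window, sliding_window_size = 3, 6

# Fake per-window logits, laid out window after window, the way the model emits them.
windows = []
for start in range(0, video_len, sliding_window):
    end = min(video_len - start, sliding_window_size)
    windows.append(torch.randn(end, num_classes))
    if (video_len - start) <= sliding_window_size:
        break
batch_logits = torch.cat(windows, dim=0)

# Scatter each window back onto the frame axis and count how often each frame is covered.
logits = torch.zeros(video_len, num_classes)
counts = torch.zeros(video_len, 1)
idx = 0
for start in range(0, video_len, sliding_window):
    end = min(video_len - start, sliding_window_size)
    logits[start:start + end] += batch_logits[idx:idx + end]
    idx += end
    counts[start:start + end] += 1
    if (video_len - start) <= sliding_window_size:
        break
logits /= counts  # each frame's logit is the mean over every window that covers it
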
fairseq/examples/MMPT/mmpt/losses/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .loss import *
from .nce import *

try:
    from .fairseqmmloss import *
except ImportError:
    pass

try:
    from .expnce import *
except ImportError:
    pass
fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
TODO (huxu): a general fairseq criterion for all your pre-defined losses.
"""

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging import metrics


@register_criterion("mmloss")
class MMCriterion(FairseqCriterion):
    def __init__(self, task):
        super().__init__(task)
        # TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
        self.mmtask = task.mmtask

    def forward(self, model, sample):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        outputs = self.mmtask(model, sample)

        loss, loss_scalar, max_len, batch_size, sample_size = (
            outputs["loss"],
            outputs["loss_scalar"],
            outputs["max_len"],
            outputs["batch_size"],
            outputs["sample_size"],
        )

        logging_output = {
            "loss": loss_scalar,
            "ntokens": max_len * batch_size,  # dummy report.
            "nsentences": batch_size,  # dummy report.
            "sample_size": sample_size,
        }

        return loss, 1, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        """since we use NCE, our actual batch_size is 1 per GPU.
        Then we take the mean over the workers."""
        loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar("loss", loss_sum / sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
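A small illustration of what `reduce_metrics` above computes: assuming each worker reports `sample_size` of 1, as the docstring suggests, the logged loss is simply the mean of the per-worker loss scalars. The numbers below are invented for the sketch.

# Toy check of the aggregation in MMCriterion.reduce_metrics (values are made up).
logging_outputs = [
    {"loss": 0.52, "ntokens": 96, "nsentences": 1, "sample_size": 1},  # worker 0
    {"loss": 0.48, "ntokens": 96, "nsentences": 1, "sample_size": 1},  # worker 1
]
loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
print(round(loss_sum / sample_size, 3))  # 0.5, the mean loss across workers
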
fairseq/examples/MMPT/mmpt/losses/loss.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch

from torch import nn


class Loss(object):
    def __call__(self, *args, **kwargs):
        raise NotImplementedError


# Dummy Loss for testing.
class DummyLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class DummyK400Loss(Loss):
    """dummy k400 loss for MViT."""
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(
            logits, torch.randint(0, 400, (logits.size(0),), device=logits.device))


class CrossEntropy(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


class ArgmaxCrossEntropy(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets.argmax(dim=1))


class BCE(Loss):
    def __init__(self):
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, logits, targets, **kwargs):
        targets = targets.squeeze(0)
        return self.loss(logits, targets)


class NLGLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, text_label, **kwargs):
        targets = text_label[text_label != -100]
        return self.loss(logits, targets)


class MSE(Loss):
    def __init__(self):
        self.loss = nn.MSELoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class L1(Loss):
    def __init__(self):
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class SmoothL1(Loss):
    def __init__(self):
        self.loss = nn.SmoothL1Loss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)
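The wrappers in loss.py are thin callables around the corresponding torch.nn criteria, so they can be exercised directly. A usage sketch follows; the import path assumes the mmpt package laid out in this diff is importable from your checkout, and all shapes are illustrative only.

# Hedged usage sketch for the loss wrappers above.
import torch
from mmpt.losses.loss import CrossEntropy, BCE  # path as laid out in this diff; adjust to your checkout

ce = CrossEntropy()
logits = torch.randn(2, 7, 10)           # (batch, seq_len, num_classes)
targets = torch.randint(0, 10, (2, 7))   # (batch, seq_len)
print(ce(logits, targets))                # reshaped internally to (batch*seq_len, num_classes)

bce = BCE()
multi_hot = torch.randint(0, 2, (1, 4, 5)).float()  # BCE squeezes the leading dim of targets
print(bce(torch.randn(4, 5), multi_hot))
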
fairseq/examples/MMPT/mmpt/losses/nce.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
softmax-based NCE loss, used by this project.
"""

import torch

from torch import nn

from .loss import Loss


class NCE(Loss):
    def __init__(self):
        # TODO (huxu): define temperature.
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, align_scores, **kargs):
        # note: we reuse the same shape as cls head in BERT (batch_size, 2)
        # but NCE only needs one logits.
        # (so we drop all weights in the second neg logits.)
        align_scores = align_scores[:, :1]
        # duplicate negative examples
        batch_size = align_scores.size(0) // 2
        pos_scores = align_scores[:batch_size]
        neg_scores = align_scores[batch_size:].view(1, batch_size).repeat(
            batch_size, 1)
        scores = torch.cat([pos_scores, neg_scores], dim=1)
        return self.loss(
            scores,
            torch.zeros(
                (batch_size,),
                dtype=torch.long,
                device=align_scores.device),
        )


class T2VContraLoss(Loss):
    """NCE for MM joint space, on softmax text2video matrix.
    """
    def __init__(self):
        # TODO (huxu): define temperature.
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, pooled_video, pooled_text, **kargs):
        batch_size = pooled_video.size(0)
        logits = torch.mm(pooled_text, pooled_video.transpose(1, 0))
        targets = torch.arange(
            batch_size,
            dtype=torch.long,
            device=pooled_video.device)
        return self.loss(logits, targets)


class V2TContraLoss(Loss):
    """NCE for MM joint space, with softmax on video2text matrix."""

    def __init__(self):
        # TODO (huxu): define temperature.
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, pooled_video, pooled_text, **kargs):
        batch_size = pooled_video.size(0)
        logits = torch.mm(pooled_video, pooled_text.transpose(1, 0))
        targets = torch.arange(
            batch_size,
            dtype=torch.long,
            device=pooled_video.device)
        return self.loss(logits, targets)


class MMContraLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, pooled_video, pooled_text, **kwargs):
        logits_per_video = pooled_video @ pooled_text.t()
        logits_per_text = pooled_text @ pooled_video.t()

        targets = torch.arange(
            pooled_video.size(0),
            dtype=torch.long,
            device=pooled_video.device)
        loss_video = self.loss(logits_per_video, targets)
        loss_text = self.loss(logits_per_text, targets)
        return loss_video + loss_text


class MTM(Loss):
    """Combination of MFM and MLM."""

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(
        self,
        video_logits,
        text_logits,
        video_label,
        text_label,
        **kwargs
    ):
        text_logits = torch.cat([
            text_logits,
            torch.zeros(
                (text_logits.size(0), 1), device=text_logits.device)
        ], dim=1)
        vt_logits = torch.cat([video_logits, text_logits], dim=0)
        # loss for video.
        video_label = torch.zeros(
            (video_logits.size(0),),
            dtype=torch.long,
            device=video_logits.device
        )

        # loss for text.
        text_label = text_label.reshape(-1)
        labels_mask = text_label != -100
        selected_text_label = text_label[labels_mask]

        vt_label = torch.cat([video_label, selected_text_label], dim=0)
        return self.loss(vt_logits, vt_label)


class MFMMLM(Loss):
    """Combination of MFM and MLM."""

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(
        self,
        video_logits,
        text_logits,
        video_label,
        text_label,
        **kwargs
    ):
        # loss for video.
        video_label = torch.zeros(
            (video_logits.size(0),),
            dtype=torch.long,
            device=video_logits.device
        )
        masked_frame_loss = self.loss(video_logits, video_label)

        # loss for text.
        text_label = text_label.reshape(-1)
        labels_mask = text_label != -100
        selected_text_label = text_label[labels_mask]
        masked_lm_loss = self.loss(text_logits, selected_text_label)
        return masked_frame_loss + masked_lm_loss
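The contrastive losses above (T2VContraLoss, V2TContraLoss, MMContraLoss) share one idea: with a batch of paired video/text embeddings, the similarity matrix should peak on its diagonal, so row i is trained with target i. Below is a minimal sketch of that setup on random embeddings; the L2 normalization is only there to keep the toy similarities in a sensible range and is not part of the classes above.

# Minimal sketch of the diagonal-target contrastive setup used by the losses above.
import torch
from torch import nn

batch_size, hidden = 8, 32
pooled_video = nn.functional.normalize(torch.randn(batch_size, hidden), dim=-1)
pooled_text = nn.functional.normalize(torch.randn(batch_size, hidden), dim=-1)

logits = pooled_text @ pooled_video.t()   # (batch, batch) text-to-video similarity matrix
targets = torch.arange(batch_size)        # the matching pair for row i sits at column i
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss)  # roughly log(batch_size) for random, untrained embeddings
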