diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..5407cc64f2b48d0da5303963952efa61c11d22da 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+figs/mirror-frontpage.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5a95fee2cbe69e506ba780dda73574822b7042ec
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,187 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+.DS_Store
+._.DS_Store
+debug.py
+outputs/
+resources/NER/msra/cache/
+resources/NER/msra/mrc/
+resources/NER/msra/formatted/
+resources/MRC/cmrc2018/cache/
+resources/MRC/cmrc2018/formatted/
+cache/*.cache
+resources/MRC/DuReader-*/
+resources/**/*.json
+resources/**/*.jsonl
+resources/**/*.zip
+resources/**/*.tsv
+resources/**/*.xml
+resources/**/raw/
+resources.tar.gz
+debug/
+debug.json
+mirror_outputs/
+sampled_stats.xlsx
+mirror_fewshot_outputs/
+conll03-100.jsonl
+tmp*/
+resources/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1da8d0f898fec24d4e2a98814a9e8d3398ea7a4f
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,19 @@
+repos:
+- repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ name: isort (python)
+ args: ["--profile", "black", "--filter-files"]
+- repo: https://github.com/psf/black
+ rev: 22.12.0
+ hooks:
+ - id: black
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-added-large-files
+ args: [--maxkb=900]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7b8a5b210ed30226cc558b91f53250629576e154
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..a153a772c60e50e5b644a0aec1110dbb28d81349
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,26 @@
+all: format clean test pre
+ echo 'finished'
+
+.PHONY: format
+format:
+ isort --profile black --filter-files .
+ black .
+
+.PHONY: test
+test:
+ coverage run --source src -m pytest -vv .
+ coverage report -m
+ flake8
+
+.PHONY: pre
+pre:
+ pre-commit run --all-files
+
+.PHONY: clean
+clean:
+ rm -rf build/
+ rm -rf dist/
+ rm -rf *.egg-info/
+ rm -f .coverage
+ rm -f coverage.xml
+ find . | grep -E '(__pycache__|\.pyc|\.pyo$$)' | xargs rm -rf
diff --git a/README.md b/README.md
index 6567813f2d2afe8111afe15240ab8292b4eba473..65f75b91bdff8fa2059a3c1d1e4e1229140d297e 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,143 @@
---
title: Mirror
-emoji: ๐
-colorFrom: green
-colorTo: red
+emoji: 🪞
+colorFrom: blue
+colorTo: yellow
sdk: gradio
sdk_version: 4.1.2
-app_file: app.py
-pinned: false
+app_file: src/app/gradio_app.py
+pinned: true
license: apache-2.0
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# 🪞 Mirror: A Universal Framework for Various Information Extraction Tasks
+
+![Mirror front page](figs/mirror-frontpage.png)
+
+*Image generated by DALL·E 3*
+
+[Paper](https://arxiv.org/abs/2311.05419) | [Demo]
+
+🎉 Our paper has been accepted to the EMNLP 2023 main conference, check it out!
+
+🪞: This is the official implementation of [🪞Mirror](https://arxiv.org/abs/2311.05419), which supports *almost* all Information Extraction tasks.
+
+The name, Mirror, comes from the classical story *Snow White and the Seven Dwarfs*, where a magic mirror knows everything in the world.
+We aim to build such a powerful tool for the IE community.
+
+## 🔥 Supported Tasks
+
+1. Named Entity Recognition
+2. Entity Relationship Extraction (Triplet Extraction)
+3. Event Extraction
+4. Aspect-based Sentiment Analysis
+5. Multi-span Extraction (e.g. Discontinuous NER)
+6. N-ary Extraction (e.g. Hyper Relation Extraction)
+7. Extractive Machine Reading Comprehension (MRC) and Question Answering
+8. Classification & Multi-choice MRC
+
+
+
+## 🌴 Dependencies
+
+Python >= 3.10
+
+```bash
+pip install -r requirements.txt
+```
+
+## 🚀 QuickStart
+
+### Pretrained Model Weights & Datasets
+
+Download the pretrained model weights & datasets from [[OSF]](https://osf.io/kwsm4/?view_only=5b66734d88cf456b93f17b6bac8a44fb).
+
+No worries, it's an anonymous link just for double-blind peer review.
+
+### Pretraining
+
+1. Download and unzip the pretraining corpus into `resources/Mirror/v1.4_sampled_v3/merged/all_excluded`
+2. Start training:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 rex train -m src.task -dc conf/Pretrain_excluded.yaml
+```
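+
+The config appears to use OmegaConf-style `${...}` interpolation (`omegaconf` is listed in `requirements.txt`), so nested values such as `train_filepath: ${data_dir}/train.jsonl` resolve automatically. A minimal sketch for inspecting the resolved config, assuming OmegaConf is indeed the loader:
+
+```python
+from omegaconf import OmegaConf
+
+# Load the pretraining config and resolve ${...} references.
+cfg = OmegaConf.load("conf/Pretrain_excluded.yaml")
+resolved = OmegaConf.to_container(cfg, resolve=True)
+print(resolved["train_filepath"])
+# -> resources/Mirror/v1.4_sampled_v3/merged/all_excluded/train.jsonl
+```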
+
+### Fine-tuning
+
+⚠️ Due to data license constraints, we cannot provide some datasets directly (e.g. ACE04, ACE05).
+
+1. Download and unzip the pretraining corpus into `resources/Mirror/v1.4_sampled_v3/merged/all_excluded`
+2. Download and unzip the fine-tuning datasets into `resources/Mirror/uie/`
+3. Start fine-tuning:
+
+```bash
+# UIE tasks
+CUDA_VISIBLE_DEVICES=0 bash scripts/single_task_wPTAllExcluded_wInstruction/run1.sh
+CUDA_VISIBLE_DEVICES=1 bash scripts/single_task_wPTAllExcluded_wInstruction/run2.sh
+CUDA_VISIBLE_DEVICES=2 bash scripts/single_task_wPTAllExcluded_wInstruction/run3.sh
+CUDA_VISIBLE_DEVICES=3 bash scripts/single_task_wPTAllExcluded_wInstruction/run4.sh
+# Multi-span and N-ary extraction
+CUDA_VISIBLE_DEVICES=4 bash scripts/single_task_wPTAllExcluded_wInstruction/run_new_tasks.sh
+# GLUE datasets
+CUDA_VISIBLE_DEVICES=5 bash scripts/single_task_wPTAllExcluded_wInstruction/glue.sh
+```
+
+### Analysis Experiments
+
+- Few-shot experiments: `scripts/run_fewshot.sh`. Collect the results with `python mirror_fewshot_outputs/get_avg_results.py`
+- Mirror w/ PT w/o Inst.: `scripts/single_task_wPTAllExcluded_woInstruction`
+- Mirror w/o PT w/ Inst.: `scripts/single_task_wo_pretrain`
+- Mirror w/o PT w/o Inst.: `scripts/single_task_wo_pretrain_wo_instruction`
+
+### Evaluation
+
+1. Change `task_dir` and `data_pairs` to the settings you want to evaluate. The default setting is to get the results of Mirror_direct on all downstream tasks.
+2. `CUDA_VISIBLE_DEVICES=0 python -m src.eval`
+
+### Demo
+
+1. Download and unzip the pretrained task dump into `mirror_outputs/Mirror_Pretrain_AllExcluded_2`
+2. Try our demo:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -m src.app.api_backend
+```
+
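+Once the backend is running, you can also query it over HTTP: `src/app/api_backend.py` exposes a `POST /process` route on port 7860 that takes `{"data": [...]}` and returns `{"ok", "msg", "results"}`. Below is a minimal sketch; the per-item field names are assumptions borrowed from the web UI inputs (Instruction / Schema Labels / Text / Background), not a documented contract:
+
+```python
+import requests
+
+# Hypothetical example payload: the "instruction"/"schema"/"text"/"background"
+# keys follow the demo front-end inputs and may differ from the actual format.
+payload = {
+    "data": [
+        {
+            "instruction": "Please extract all entities from the text.",
+            "schema": {"ent": ["person", "location"]},
+            "text": "Alice moved from Paris to London.",
+            "background": "",
+        }
+    ]
+}
+
+resp = requests.post("http://127.0.0.1:7860/process", json=payload)
+resp.raise_for_status()
+result = resp.json()  # {"ok": bool, "msg": str, "results": ...}
+print(result["ok"], result["msg"])
+print(result["results"])
+```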
+
+
+## 📜 Citation
+
+```bibtex
+@misc{zhu_mirror_2023,
+ shorttitle = {Mirror},
+ title = {Mirror: A Universal Framework for Various Information Extraction Tasks},
+ author = {Zhu, Tong and Ren, Junfei and Yu, Zijian and Wu, Mengsong and Zhang, Guoliang and Qu, Xiaoye and Chen, Wenliang and Wang, Zhefeng and Huai, Baoxing and Zhang, Min},
+ url = {http://arxiv.org/abs/2311.05419},
+ doi = {10.48550/arXiv.2311.05419},
+ urldate = {2023-11-10},
+ publisher = {arXiv},
+ month = nov,
+ year = {2023},
+ note = {arXiv:2311.05419 [cs]},
+ keywords = {Computer Science - Artificial Intelligence, Computer Science - Computation and Language},
+}
+```
+
+## 🛣️ Roadmap
+
+- [ ] Convert the current model into a Hugging Face version that supports loading from `transformers`, like other newly released LLMs.
+- [ ] Remove `Background` area, merge `TL`, `TP` into a single `T` token
+- [ ] Add more task data: keyword extraction, coreference resolution, FrameNet, WikiNER, T-Rex relation extraction dataset, etc.
+- [ ] Pre-train on all the data (including benchmarks) to build a nice out-of-the-box toolkit for universal IE.
+
+## 💌 Yours sincerely
+
+This project is licensed under Apache-2.0.
+We hope you enjoy it ~
+
+
+
diff --git a/conf/Pretrain_excluded.yaml b/conf/Pretrain_excluded.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1c8d23a61cdecb8189933ec77c442448e84cf404
--- /dev/null
+++ b/conf/Pretrain_excluded.yaml
@@ -0,0 +1,51 @@
+# task
+task_type: SchemaGuidedInstructBertTask
+task_name: Mirror_Pretrain_AllExcluded_2
+comment: '~~content as label, (start, end + 1) span'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+label_span: tag # tag `[LM]` or content `person`
+mode: span # w2 (1,2,3) or span (1,3)
+stream_mode: false
+
+# filepaths
+plm_dir: microsoft/deberta-v3-large
+data_dir: resources/Mirror/v1.4_sampled_v3/merged/all_excluded
+output_dir: mirror_outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: false
+
+# training
+random_seed: 1227
+base_model_path: null
+eval_on_data: [train]
+select_best_on_data: train
+select_best_by_key: loss
+final_eval_on_test: false
+save_every_ckpt: true
+save_best_ckpt: true
+
+warmup_proportion: 0.1
+num_epochs: 3
+epoch_patience: -1
+num_steps: -1
+step_patience: -1
+step_eval_interval: 10000
+train_batch_size: 8
+eval_batch_size: 8
+grad_accum_steps: 1
+learning_rate: !!float 2e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+use_rope: true
+biaffine_size: 512
diff --git a/conf/Pretrain_v1.5.yaml b/conf/Pretrain_v1.5.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8091120acbe10f010adb314d95370f7e33c37e35
--- /dev/null
+++ b/conf/Pretrain_v1.5.yaml
@@ -0,0 +1,51 @@
+# task
+task_type: SchemaGuidedInstructBertTask
+task_name: Mirror_Pretrain_DataV1.5_2
+comment: '~~content as label, (start, end + 1) span'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+label_span: tag # tag `[LM]` or content `person`
+mode: span # w2 (1,2,3) or span (1,3)
+stream_mode: false
+
+# filepaths
+plm_dir: microsoft/deberta-v3-large
+data_dir: resources/Mirror/v1.5/merged/t-rex-200k
+output_dir: mirror_outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: false
+
+# training
+random_seed: 1227
+base_model_path: null
+eval_on_data: [train]
+select_best_on_data: train
+select_best_by_key: loss
+final_eval_on_test: false
+save_every_ckpt: true
+save_best_ckpt: true
+
+warmup_proportion: 0.1
+num_epochs: 3
+epoch_patience: -1
+num_steps: -1
+step_patience: -1
+step_eval_interval: 10000
+train_batch_size: 8
+eval_batch_size: 8
+grad_accum_steps: 1
+learning_rate: !!float 2e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+use_rope: true
+biaffine_size: 512
diff --git a/conf/Pretrain_v1.5_woInstruction.yaml b/conf/Pretrain_v1.5_woInstruction.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91cab75be973e1f90c9b1d543a8542e2f121d104
--- /dev/null
+++ b/conf/Pretrain_v1.5_woInstruction.yaml
@@ -0,0 +1,51 @@
+# task
+task_type: SchemaGuidedInstructBertTask
+task_name: Mirror_Pretrain_DataV1.5_woInstruction
+comment: '~~content as label, (start, end + 1) span'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+label_span: tag # tag `[LM]` or content `person`
+mode: span # w2 (1,2,3) or span (1,3)
+stream_mode: false
+
+# filepaths
+plm_dir: microsoft/deberta-v3-large
+data_dir: resources/Mirror/v1.5/merged/t-rex-200k-woInstruction/remove_instruction
+output_dir: mirror_outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: false
+
+# training
+random_seed: 1227
+base_model_path: null
+eval_on_data: [train]
+select_best_on_data: train
+select_best_by_key: loss
+final_eval_on_test: false
+save_every_ckpt: true
+save_best_ckpt: true
+
+warmup_proportion: 0.1
+num_epochs: 3
+epoch_patience: -1
+num_steps: -1
+step_patience: -1
+step_eval_interval: 10000
+train_batch_size: 8
+eval_batch_size: 8
+grad_accum_steps: 1
+learning_rate: !!float 2e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+use_rope: true
+biaffine_size: 512
diff --git a/conf/Pretrain_woOverlapV2.yaml b/conf/Pretrain_woOverlapV2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e68ef555106e91486618cba971e5b8f3e5f5fd3
--- /dev/null
+++ b/conf/Pretrain_woOverlapV2.yaml
@@ -0,0 +1,51 @@
+# task
+task_type: SchemaGuidedInstructBertTask
+task_name: Mirror_Pretrain_woOverlapV2
+comment: '~~content as label, (start, end + 1) span'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+label_span: tag # tag `[LM]` or content `person`
+mode: span # w2 (1,2,3) or span (1,3)
+stream_mode: false
+
+# filepaths
+plm_dir: microsoft/deberta-v3-large
+data_dir: resources/Mirror/v1.4_sampled_v3/merged/all
+output_dir: mirror_outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train_wo_overlap_v2.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: false
+
+# training
+random_seed: 1227
+base_model_path: null
+eval_on_data: [train]
+select_best_on_data: train
+select_best_by_key: loss
+final_eval_on_test: false
+save_every_ckpt: true
+save_best_ckpt: true
+
+warmup_proportion: 0.1
+num_epochs: 3
+epoch_patience: -1
+num_steps: -1
+step_patience: -1
+step_eval_interval: 10000
+train_batch_size: 8
+eval_batch_size: 8
+grad_accum_steps: 1
+learning_rate: !!float 2e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+use_rope: true
+biaffine_size: 512
diff --git a/conf/ac/g1_dpspd.yaml b/conf/ac/g1_dpspd.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..af21ebdeb172b58c62fef4f4a3b95bc535d6e385
--- /dev/null
+++ b/conf/ac/g1_dpspd.yaml
@@ -0,0 +1,18 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 1
+ zero3_init_flag: false
+ zero_stage: 1
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/conf/ac/g1_dpspd_fp16.yaml b/conf/ac/g1_dpspd_fp16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..381b21d64e7eeb23734f041ca1ed2c7a9c65dd52
--- /dev/null
+++ b/conf/ac/g1_dpspd_fp16.yaml
@@ -0,0 +1,18 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 4
+ zero3_init_flag: false
+ zero_stage: 1
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/conf/cadec.yaml b/conf/cadec.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4186d74890b5c8869c414ea1ebb61a986e40dc30
--- /dev/null
+++ b/conf/cadec.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_DiscontinuousNER_CADEC
+data_dir: resources/Mirror/new_abilities_v2/cadec/new
+best_metric_field: discontinuous_ent.micro.f1
diff --git a/conf/hyperred.yaml b/conf/hyperred.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78e4089063ff593d5ce331d56cea2e99798e781d
--- /dev/null
+++ b/conf/hyperred.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_HyperRel_HyperRED
+data_dir: resources/Mirror/new_abilities_v2/HyperRED/new
+best_metric_field: hyper_rel.micro.f1
diff --git a/conf/merge_all_data.yaml b/conf/merge_all_data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..45ec8711892db5a9b966d99f39cdafefbb59f4c7
--- /dev/null
+++ b/conf/merge_all_data.yaml
@@ -0,0 +1,6 @@
+task_name: InstructBert_MergedAllData
+data_dir: resources/Mirror/v1.3/merged_pretrained_data
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: resources/Mirror/v1.3/uie_data/dev.jsonl
+test_filepath: resources/Mirror/v1.3/uie_data/test.jsonl
+num_epochs: 1
diff --git a/conf/merge_analysis_data.yaml b/conf/merge_analysis_data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eff69048bafd4a7039c7f53995f404ac12741396
--- /dev/null
+++ b/conf/merge_analysis_data.yaml
@@ -0,0 +1,18 @@
+task_name: Mirror_MultiTask_Analysis
+plm_dir: microsoft/deberta-v3-large
+
+data_dir: resources/Mirror/uie/merged_analysis
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+num_epochs: 20
+epoch_patience: 3
+regenerate_cache: true
+
+eval_on_data: [dev]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: general_spans.micro.f1
+final_eval_on_test: true
+
+base_model_path: null
diff --git a/conf/merge_analysis_data_woInstruction.yaml b/conf/merge_analysis_data_woInstruction.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e5d3d7dad622641d49ff9ddd53b452324ac280f2
--- /dev/null
+++ b/conf/merge_analysis_data_woInstruction.yaml
@@ -0,0 +1,18 @@
+task_name: Mirror_MultiTask_Analysis_woInstruction
+plm_dir: microsoft/deberta-v3-large
+
+data_dir: resources/Mirror/uie/merged_analysis/remove_instruction
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+num_epochs: 20
+epoch_patience: 3
+regenerate_cache: true
+
+eval_on_data: [dev]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: general_spans.micro.f1
+final_eval_on_test: true
+
+base_model_path: null
diff --git a/conf/merge_uie_data.yaml b/conf/merge_uie_data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b2450e266315b54a07a5a772252acdf27fcb3953
--- /dev/null
+++ b/conf/merge_uie_data.yaml
@@ -0,0 +1,18 @@
+task_name: Mirror_woPT_NewMergedUIEData_woOverlap
+plm_dir: microsoft/deberta-v3-large
+
+data_dir: resources/Mirror/uie/merged
+train_filepath: ${data_dir}/train_wo_overlap.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+num_epochs: 20
+epoch_patience: 3
+regenerate_cache: true
+
+eval_on_data: [dev]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: general_spans.micro.f1
+final_eval_on_test: true
+
+base_model_path: null
diff --git a/conf/mirror-ace05en.yaml b/conf/mirror-ace05en.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..106103f6b9aef1747af2d843b98e500738bd8899
--- /dev/null
+++ b/conf/mirror-ace05en.yaml
@@ -0,0 +1,70 @@
+# task
+task_type: SchemaGuidedInstructBertTask
+task_name: InstructBert_TagSpan_DebertaV3Base_ACE05ENPlus
+comment: '~~content as label, (start, end + 1) span'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+label_span: tag # tag `[LM]` or content `person`
+mode: span # w2 (1,2,3) or span (1,3)
+
+# filepaths
+plm_dir: microsoft/deberta-v3-base
+# plm_dir: bert-base-cased
+# data_dir: resources/Mirror/Tasks/EE/ACE05-EN
+# data_dir: resources/Mirror/Tasks/RE/merged-20230502-2340-v1
+# data_dir: resources/Mirror/Tasks/RE/merged-20230502-2358-v2-woADE
+# data_dir: resources/Mirror/Tasks/EE/ACE05-EN-labelmap
+data_dir: resources/Mirror/v1.3/event/en/ACE05-EN-plus/fixed_instructed
+output_dir: outputs
+task_dir: ${output_dir}/${task_name}
+# train_filepath: ${data_dir}/ACE2005_plus_train.jsonl
+# dev_filepath: ${data_dir}/ACE2005_plus_dev.jsonl
+# test_filepath: ${data_dir}/ACE2005_plus_test.jsonl
+# train_filepath: ${data_dir}/ACE2005_oneie_NER_train.jsonl
+# dev_filepath: ${data_dir}/ACE2005_oneie_NER_dev.jsonl
+# test_filepath: ${data_dir}/ACE2005_oneie_NER_test.jsonl
+# train_filepath: ${data_dir}/ACE2005_oneie_RE_train.jsonl
+# dev_filepath: ${data_dir}/ACE2005_oneie_RE_dev.jsonl
+# test_filepath: ${data_dir}/ACE2005_oneie_RE_test.jsonl
+# train_filepath: ${data_dir}/ACE2005_oneie_EE_train.jsonl
+# dev_filepath: ${data_dir}/ACE2005_oneie_EE_dev.jsonl
+# test_filepath: ${data_dir}/ACE2005_oneie_EE_test.jsonl
+# train_filepath: ${data_dir}/ACE2005_oneie_train.jsonl
+# dev_filepath: ${data_dir}/ACE2005_oneie_dev.jsonl
+# test_filepath: ${data_dir}/ACE2005_oneie_test.jsonl
+# train_filepath: ${data_dir}/train.jsonl
+# dev_filepath: ${data_dir}/dev.jsonl
+# test_filepath: ${data_dir}/test.jsonl
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: false
+
+# training
+random_seed: 1227
+eval_on_data: [dev, test]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: general_spans.micro.f1
+final_eval_on_test: true
+save_every_ckpt: false
+save_best_ckpt: true
+
+warmup_proportion: 0.1
+num_epochs: 50
+epoch_patience: 5
+train_batch_size: 32
+eval_batch_size: 32
+learning_rate: !!float 3e-5
+other_learning_rate: !!float 3e-5
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+use_rope: true
+biaffine_size: 512
diff --git a/conf/mirror-multi-task-pretrain.yaml b/conf/mirror-multi-task-pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..22b0d039890e82c093114876038e1664ca798b5c
--- /dev/null
+++ b/conf/mirror-multi-task-pretrain.yaml
@@ -0,0 +1,51 @@
+# task
+task_type: SchemaGuidedInstructBertTask
+task_name: MirrorLarge_SamplingPretrain_woLowResource_woOverlap
+comment: '~~content as label, (start, end + 1) span'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+label_span: tag # tag `[LM]` or content `person`
+mode: span # w2 (1,2,3) or span (1,3)
+stream_mode: false
+
+# filepaths
+plm_dir: microsoft/deberta-v3-large
+data_dir: resources/Mirror/v1.4_sampled_v3/merged/woLowResource
+output_dir: mirror_outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train_wo_overlap.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: false
+
+# training
+random_seed: 1227
+base_model_path: null
+eval_on_data: [train]
+select_best_on_data: train
+select_best_by_key: loss
+final_eval_on_test: false
+save_every_ckpt: true
+save_best_ckpt: true
+
+warmup_proportion: 0.1
+num_epochs: 1
+epoch_patience: -1
+num_steps: -1
+step_patience: -1
+step_eval_interval: 3000
+train_batch_size: 8
+eval_batch_size: 8
+grad_accum_steps: 1
+learning_rate: !!float 2e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+use_rope: true
+biaffine_size: 512
diff --git a/conf/mrc.yaml b/conf/mrc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a967f09a984c9b51f07fb41873d7d586bee337d
--- /dev/null
+++ b/conf/mrc.yaml
@@ -0,0 +1,43 @@
+# task
+task_type: MrcQaTask
+task_name: Mirror_RobertaBaseWwm_Cons_MsraMrc
+comment: 'GlobalPointer with RoPE'
+
+# data preprocessing
+max_seq_len: 512
+debug_mode: false
+mode: cons
+
+# filepaths
+plm_dir: hfl/chinese-roberta-wwm-ext
+data_dir: resources/NER/msra/mrc
+output_dir: outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/test.jsonl
+test_filepath: ${data_dir}/test.jsonl
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: true
+
+# training
+random_seed: 1227
+eval_on_data: [dev]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: micro.f1
+final_eval_on_test: true
+
+warmup_proportion: 0.1
+step_eval_interval: 20000
+step_patience: -1
+num_epochs: 5
+epoch_patience: 5
+train_batch_size: 32
+eval_batch_size: 64
+learning_rate: !!float 5e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+
+# model
+dropout: 0.3
+biaffine_size: 512
diff --git a/conf/ner.yaml b/conf/ner.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..349156646000fe19e98aea7fb005fd40a48514a3
--- /dev/null
+++ b/conf/ner.yaml
@@ -0,0 +1,45 @@
+# task
+task_type: MrcTaggingTask
+task_name: debug-Mirror_W2_MSRAv2_NER_FreezeBertEmbAnd0-3_bs64
+comment: 'bert mrc w/ w2ner for NER'
+
+# data preprocessing
+max_seq_len: 300
+negative_sample_prob: 1.0
+debug_mode: false
+mode: w2
+
+# filepaths
+base_model_path: outputs/RobertaBase_data20230314v2/ckpt/MrcGlobalPointerModel.best.pth
+plm_dir: hfl/chinese-roberta-wwm-ext
+data_dir: resources/NER/MSRA_v2/formatted
+output_dir: outputs
+task_dir: ${output_dir}/${task_name}
+train_filepath: ${data_dir}/train.char.bmes.jsonl
+dev_filepath: ${data_dir}/dev.char.bmes.jsonl
+test_filepath: ${data_dir}/test.char.bmes.jsonl
+ent_type2query_filepath: ${data_dir}/query.json
+dump_cache_dir: ${task_dir}/cache
+regenerate_cache: true
+
+# training
+random_seed: 1227
+eval_on_data: [dev, test]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: micro.f1
+final_eval_on_test: true
+
+warmup_proportion: 0.1
+num_epochs: 5
+epoch_patience: 5
+train_batch_size: 64
+eval_batch_size: 128
+learning_rate: !!float 5e-5
+other_learning_rate: !!float 1e-4
+max_grad_norm: 1.0
+weight_decay: 0.1
+
+# model
+dropout: 0.3
+biaffine_size: 512
diff --git a/conf/nlu/cola.yaml b/conf/nlu/cola.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d7dfcdc810cdf630cabfddcc2e50a0095adf4882
--- /dev/null
+++ b/conf/nlu/cola.yaml
@@ -0,0 +1,6 @@
+task_name: Mirror_SingleTask_Cls_CoLA
+data_dir: resources/Mirror/v1.3/cls/en/CoLA/formated
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/dev.jsonl
+best_metric_field: cls.mcc
diff --git a/conf/nlu/mnli.yaml b/conf/nlu/mnli.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d7ddb6a927dcd8510fceb376d9cb178a72b2838a
--- /dev/null
+++ b/conf/nlu/mnli.yaml
@@ -0,0 +1,6 @@
+task_name: Mirror_SingleTask_Cls_MNLI
+data_dir: resources/Mirror/v1.3/cls/en/MNLI/formated
+train_filepath: ${data_dir}/MNLI_train.jsonl
+dev_filepath: ${data_dir}/MNLI_dev.jsonl
+test_filepath: ${data_dir}/MNLI_dev.jsonl
+best_metric_field: cls.acc
diff --git a/conf/nlu/mrpc.yaml b/conf/nlu/mrpc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fc9841b3c6fbd3b36ddcd50dcb81a8958660439f
--- /dev/null
+++ b/conf/nlu/mrpc.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Cls_MRPC
+data_dir: resources/Mirror/v1.3/cls/en/MRPC/formated
+best_metric_field: cls.acc
diff --git a/conf/nlu/plm.yaml b/conf/nlu/plm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..770e53393bc86b8eb4a002424e7d91f7881bef51
--- /dev/null
+++ b/conf/nlu/plm.yaml
@@ -0,0 +1,19 @@
+plm_dir: microsoft/deberta-v3-large
+base_model_path: mirror_outputs/Mirror_Pretrain_AllExcluded_2/ckpt/SchemaGuidedInstructBertModel.best.pth
+
+stream_mode: false
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+
+num_epochs: 5
+epoch_patience: -1
+num_steps: -1
+step_patience: -1
+step_eval_interval: -1
+
+eval_on_data: [dev]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: general_spans.micro.f1
+final_eval_on_test: true
diff --git a/conf/nlu/qnli.yaml b/conf/nlu/qnli.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b95f03aba58455d2a21585cc7f63751e7152694d
--- /dev/null
+++ b/conf/nlu/qnli.yaml
@@ -0,0 +1,6 @@
+task_name: Mirror_SingleTask_Cls_QNLI
+data_dir: resources/Mirror/v1.3/cls/en/QNLI/processed
+train_filepath: ${data_dir}/QNLI_train.jsonl
+dev_filepath: ${data_dir}/QNLI_dev.jsonl
+test_filepath: ${data_dir}/QNLI_dev.jsonl
+best_metric_field: cls.acc
diff --git a/conf/nlu/qqp.yaml b/conf/nlu/qqp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..62e64176c9e75536beee7adc1f912e8694e5f17a
--- /dev/null
+++ b/conf/nlu/qqp.yaml
@@ -0,0 +1,6 @@
+task_name: Mirror_SingleTask_Cls_QQP
+data_dir: resources/Mirror/v1.3/cls/en/QQP/new
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/dev.jsonl
+best_metric_field: cls.acc
diff --git a/conf/nlu/rte.yaml b/conf/nlu/rte.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c281cab7a5f8c5f0f816cd809066ab68f883c126
--- /dev/null
+++ b/conf/nlu/rte.yaml
@@ -0,0 +1,6 @@
+task_name: Mirror_SingleTask_Cls_RTE
+data_dir: resources/Mirror/v1.3/cls/en/RTE/formated
+train_filepath: ${data_dir}/RTE_train.jsonl
+dev_filepath: ${data_dir}/RTE_dev.jsonl
+test_filepath: ${data_dir}/RTE_dev.jsonl
+best_metric_field: cls.acc
diff --git a/conf/nlu/squad_v2.yaml b/conf/nlu/squad_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fada03a197b25ee0976cfb29b041cabd8917513
--- /dev/null
+++ b/conf/nlu/squad_v2.yaml
@@ -0,0 +1,4 @@
+task_name: Mirror_SingleTask_MRC_SQuADv2
+data_dir: resources/Mirror/v1.3/span/en/squad_v2
+test_filepath: ${data_dir}/dev.jsonl
+best_metric_field: span.f1.f1
diff --git a/conf/nlu/sst-2.yaml b/conf/nlu/sst-2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..873b510f4b2c90935643eda318d3a1bc12fba37e
--- /dev/null
+++ b/conf/nlu/sst-2.yaml
@@ -0,0 +1,6 @@
+task_name: Mirror_SingleTask_Cls_SST2
+data_dir: resources/Mirror/v1.3/cls/en/SST-2/instructed
+train_filepath: ${data_dir}/SST-2_train.jsonl
+dev_filepath: ${data_dir}/SST-2_dev.jsonl
+test_filepath: ${data_dir}/SST-2_dev.jsonl
+best_metric_field: cls.acc
diff --git a/conf/t-rex_pretrain.yaml b/conf/t-rex_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9ba2f299aea8c43218609a8463afe14d7241c8e9
--- /dev/null
+++ b/conf/t-rex_pretrain.yaml
@@ -0,0 +1,9 @@
+task_name: InstructBert_TagSpan_DebertaV3Base_TRExPretrain
+data_dir: resources/Mirror/v1.3/rel/en/T-REx/instructed
+train_filepath: ${data_dir}/t-rex.udi.fix.jsonl
+
+num_epochs: 3
+eval_on_data: [train]
+select_best_on_data: train
+select_best_by_key: loss
+final_eval_on_test: false
diff --git a/conf/uie_data/absa_14lap.yaml b/conf/uie_data/absa_14lap.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee860dcb9141cd9819cca6d26a60f845fa2db2f6
--- /dev/null
+++ b/conf/uie_data/absa_14lap.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_ABSA_14lap
+data_dir: resources/Mirror/uie/absa/14lap
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/absa_14res.yaml b/conf/uie_data/absa_14res.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e3447a0bc10d81d54389b4478499afe480b68273
--- /dev/null
+++ b/conf/uie_data/absa_14res.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_ABSA_14res
+data_dir: resources/Mirror/uie/absa/14res
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/absa_15res.yaml b/conf/uie_data/absa_15res.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d70f989e346d14ded78e5aa99f4fd17411236bea
--- /dev/null
+++ b/conf/uie_data/absa_15res.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_ABSA_15res
+data_dir: resources/Mirror/uie/absa/15res
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/absa_16res.yaml b/conf/uie_data/absa_16res.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..10c1a8a073e2e26a02dfb626823bf5920867e670
--- /dev/null
+++ b/conf/uie_data/absa_16res.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_ABSA_16res
+data_dir: resources/Mirror/uie/absa/16res
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/ent_ace04.yaml b/conf/uie_data/ent_ace04.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..49969a4a5d5c69150453f7e9ee9cc2e9004f9cdb
--- /dev/null
+++ b/conf/uie_data/ent_ace04.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Ent_ACE04
+data_dir: resources/Mirror/uie/ent/ace04
+best_metric_field: ent.micro.f1
diff --git a/conf/uie_data/ent_ace05.yaml b/conf/uie_data/ent_ace05.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..413e9082313f503b991e51ce9dbe6c022f4a4a83
--- /dev/null
+++ b/conf/uie_data/ent_ace05.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Ent_ACE05
+data_dir: resources/Mirror/uie/ent/ace05
+best_metric_field: ent.micro.f1
diff --git a/conf/uie_data/ent_conll03.yaml b/conf/uie_data/ent_conll03.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..377ceb61de89ed29e9841f94f9a1bfea536d40a4
--- /dev/null
+++ b/conf/uie_data/ent_conll03.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Ent_CoNLL03
+data_dir: resources/Mirror/uie/ent/conll03
+best_metric_field: ent.micro.f1
diff --git a/conf/uie_data/event_ace05.yaml b/conf/uie_data/event_ace05.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f31ef7454b35217b4142ebc2de7dfed0565eb69e
--- /dev/null
+++ b/conf/uie_data/event_ace05.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Event_ACE05
+data_dir: resources/Mirror/uie/event/ace05-evt
+best_metric_field: event.arg_cls.f1
diff --git a/conf/uie_data/event_casie.yaml b/conf/uie_data/event_casie.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..79de43b538f11fbe6459723abbd9333912c85e39
--- /dev/null
+++ b/conf/uie_data/event_casie.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Event_CASIE
+data_dir: resources/Mirror/uie/event/casie
+best_metric_field: event.arg_cls.f1
diff --git a/conf/uie_data/fewshot.yaml b/conf/uie_data/fewshot.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e20665083de1cd3a17b7fa83ff665e58aeea0eb9
--- /dev/null
+++ b/conf/uie_data/fewshot.yaml
@@ -0,0 +1,5 @@
+num_epochs: 200
+epoch_patience: 10
+output_dir: mirror_fewshot_outputs
+base_model_path: mirror_outputs/Mirror_Pretrain_AllExcluded_2/ckpt/SchemaGuidedInstructBertModel.best.pth
+save_every_ckpt: false
diff --git a/conf/uie_data/merged.yaml b/conf/uie_data/merged.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c88730b7f7ff6384b06a1e871df7ba490dccc93
--- /dev/null
+++ b/conf/uie_data/merged.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_MultiTask_UIE
+data_dir: resources/Mirror/uie/merged
+best_metric_field: general_spans.micro.f1
diff --git a/conf/uie_data/rel_ace05.yaml b/conf/uie_data/rel_ace05.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c8951eb48c2e0e6476a711ad4d5445a6c68597f
--- /dev/null
+++ b/conf/uie_data/rel_ace05.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Rel_ACE05
+data_dir: resources/Mirror/uie/rel/ace05-rel
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/rel_conll04.yaml b/conf/uie_data/rel_conll04.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e3bde2d363451745a3fff17bb0779a69b888e5ab
--- /dev/null
+++ b/conf/uie_data/rel_conll04.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Rel_CoNLL04
+data_dir: resources/Mirror/uie/rel/conll04
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/rel_nyt.yaml b/conf/uie_data/rel_nyt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1131946b1bce7022603c684084531e9e1c80eb9d
--- /dev/null
+++ b/conf/uie_data/rel_nyt.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Rel_NYT
+data_dir: resources/Mirror/uie/rel/nyt
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/rel_scierc.yaml b/conf/uie_data/rel_scierc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7655405e61cae5f07ce9261c74839f371ccb6428
--- /dev/null
+++ b/conf/uie_data/rel_scierc.yaml
@@ -0,0 +1,3 @@
+task_name: Mirror_SingleTask_Rel_SciERC
+data_dir: resources/Mirror/uie/rel/scierc
+best_metric_field: rel.rel.micro.f1
diff --git a/conf/uie_data/wPretrain.yaml b/conf/uie_data/wPretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a54be67d2b04181502a91dd25dca1a1bfa8daedc
--- /dev/null
+++ b/conf/uie_data/wPretrain.yaml
@@ -0,0 +1,19 @@
+plm_dir: microsoft/deberta-v3-large
+base_model_path: mirror_outputs/Mirror_Pretrain_AllExcluded_2/ckpt/SchemaGuidedInstructBertModel.best.pth
+
+stream_mode: false
+train_filepath: ${data_dir}/train.jsonl
+dev_filepath: ${data_dir}/dev.jsonl
+test_filepath: ${data_dir}/test.jsonl
+
+num_epochs: 20
+epoch_patience: 3
+num_steps: -1
+step_patience: -1
+step_eval_interval: -1
+
+eval_on_data: [dev]
+select_best_on_data: dev
+select_best_by_key: metric
+best_metric_field: general_spans.micro.f1
+final_eval_on_test: true
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/index.html b/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..1af05bcbaf3e3599411534ddf8f35633b50a306e
--- /dev/null
+++ b/index.html
@@ -0,0 +1,288 @@
+
+
+
+
+
+
+
+ 🪞Mirror
+
+
+
+
+
+
+
+
+
+
+
+ Instruction
+
+
+
+
+Schema Labels
+Split with # for multiple inputs
+For entities, relations or classification, input {"ent|rel|cls": ["cls1", "type2"]}.
+For events and hyper relations, input {"type": ["role1", "role2"]}.
+
+
+
+
+
+ Text
+
+
+
+
+
+ Reset
+ Clear Output
+ Ask Mirror
+
+
+
+
+⏱️ {{ searchSecondsString }}
+
+
+
+
+
Output
+
+
+ Item
+ Predicted
+
+
+
+
+ {{ key }}
+ {{ value }}
+
+
+
+
+
+
+
+
+
+
+
+
+ Made by Mirror Team w/ 💖
+
+
+
+
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8089003957db8735e2913a88283c46afc480d77f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+pandas
+rich
+numpy
+omegaconf
+gpu-watchmen
+tqdm
+datasets
+transformers
+gradio
+git+https://github.com/Spico197/REx
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/analyze.py b/src/analyze.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f25c0c9a52e34d5c331ccfa98e11abbc13e55e
--- /dev/null
+++ b/src/analyze.py
@@ -0,0 +1,135 @@
+from collections import defaultdict
+
+from rex.metrics.tagging import tagging_prf1
+from rex.utils.io import load_jsonlines
+from rex.utils.position import find_all_positions
+
+
+def main():
+ middle_filepath = "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_labelmap_Rel_updateTag_bs32/middle/test.final.jsonl"
+ data = load_jsonlines(middle_filepath)
+ for ins in data:
+ gold = ins["gold"]
+ pred = ins["pred"]
+ if gold["spans"] != pred["spans"]:
+ breakpoint()
+
+
+def check_ent_string_matching_upper_bound(filepath: str, strategy: str = "first"):
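+ # Upper-bound analysis: if a model predicted only (entity string, type)
+ # pairs, how well could span-level P/R/F1 do after mapping each string back
+ # to positions in the text? `strategy` either takes the first unmatched
+ # position ("first") or the first position that does not overlap an already
+ # matched span (any other value, used as "longer_first" below).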
+ def _check_overlap(x, y):
+ if x[0] > y[1] or y[0] > x[1]:
+ return False
+ else:
+ return True
+
+ data = load_jsonlines(filepath)
+ golds = []
+ preds = []
+ for ins in data:
+ text = ins["text"]
+ gold_ents = ins["ans"]["ent"]
+ gold_ents = list(
+ set([(ent["text"], ent["type"], tuple(ent["span"])) for ent in gold_ents])
+ )
+ gold_ents.sort(key=lambda x: len(x[0]), reverse=True)
+ pred_ents = []
+ matched = set()
+ for gold_ent in gold_ents:
+ ent_string = gold_ent[0]
+ ent_type = gold_ent[1]
+ positions = find_all_positions(text, ent_string)
+ if strategy == "first":
+ for position in positions:
+ if (ent_type, position) not in matched:
+ matched.add((ent_type, position))
+ pred_ents.append((ent_string, ent_type, tuple(position)))
+ else:
+ for position in positions:
+ # skip candidate positions that overlap an already matched span
+ flag = False
+ for _, g in matched:
+ if _check_overlap(g, position):
+ flag = True
+ if flag:
+ continue
+
+ if (ent_type, position) not in matched:
+ matched.add((ent_type, position))
+ pred_ents.append((ent_string, ent_type, tuple(position)))
+ break
+
+ golds.append(gold_ents)
+ preds.append(pred_ents)
+
+ results = tagging_prf1(golds, preds)
+
+ print(f"filepath: {filepath}, Strategy: {strategy}")
+ print(f"Results: {results['micro']}")
+
+
+def check_rel_tanl_upper_bound(filepath):
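+ # Upper-bound analysis for TANL-style relation decoding: a predicted tail is
+ # kept only if its entity string maps to exactly one known span, so the
+ # resulting P/R/F1 bounds what pure string matching can recover.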
+ data = load_jsonlines(filepath)
+ golds = []
+ preds = []
+ for ins in data:
+ text = ins["text"]
+ gold_rels = ins["ans"]["rel"]
+ ent_text_to_spans = defaultdict(set)
+ for ent in ins["ans"]["ent"]:
+ ent_text_to_spans[ent["text"]].add(tuple(ent["span"]))
+ gold_rels = list(
+ set(
+ [
+ (
+ tuple(rel["head"]["span"]),
+ rel["relation"],
+ tuple(rel["tail"]["span"]),
+ )
+ for rel in gold_rels
+ ]
+ )
+ )
+ pred_rels = []
+ for pred_rel in ins["ans"]["rel"]:
+ # pred_triple = ()
+ tail_text = pred_rel["tail"]["text"]
+ if (
+ tail_text in ent_text_to_spans
+ and len(ent_text_to_spans[tail_text]) == 1
+ ):
+ tail_span = list(ent_text_to_spans[tail_text])[0]
+ pred_rels.append(
+ (tuple(pred_rel["head"]["span"]), pred_rel["relation"], tail_span)
+ )
+ # if tail_text in ent_text_to_spans:
+ # tail_span = list(ent_text_to_spans[tail_text])[0]
+ # else:
+ # tail_span = find_all_positions(text, tail_text)[0]
+ # pred_rels.append((tuple(pred_rel["head"]["span"]), pred_rel["relation"], tail_span))
+
+ golds.append(gold_rels)
+ preds.append(pred_rels)
+
+ results = tagging_prf1(golds, preds)
+
+ print(f"filepath: {filepath}")
+ print(f"Results: {results['micro']}")
+
+
+if __name__ == "__main__":
+ # main()
+
+ # for filepath in [
+ # "/data/tzhu/Mirror/resources/Mirror/uie/ent/ace04/test.jsonl",
+ # "/data/tzhu/Mirror/resources/Mirror/uie/ent/ace05/test.jsonl",
+ # "/data/tzhu/Mirror/resources/Mirror/uie/ent/conll03/test.jsonl",
+ # ]:
+ # for strategy in ["first", "longer_first"]:
+ # check_ent_string_matching_upper_bound(filepath, strategy)
+
+ for filepath in [
+ "/data/tzhu/Mirror/resources/Mirror/uie/rel/ace05-rel/test.jsonl",
+ "/data/tzhu/Mirror/resources/Mirror/uie/rel/conll04/test.jsonl",
+ "/data/tzhu/Mirror/resources/Mirror/uie/rel/nyt/test.jsonl",
+ "/data/tzhu/Mirror/resources/Mirror/uie/rel/scierc/test.jsonl",
+ ]:
+ check_rel_tanl_upper_bound(filepath)
diff --git a/src/app/__init__.py b/src/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/app/api_backend.py b/src/app/api_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..220f6484880bfc78b2052a5c1d72a11e5ecfe38f
--- /dev/null
+++ b/src/app/api_backend.py
@@ -0,0 +1,80 @@
+import traceback
+from typing import Any, Dict, List
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+from rex.utils.initialization import set_seed_and_log_path
+
+from src.task import SchemaGuidedInstructBertTask
+
+set_seed_and_log_path(log_path="debug.log")
+
+app = FastAPI()
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+class RequestData(BaseModel):
+ data: List[Dict[str, Any]]
+
+
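+# Load the pretrained task dump (see the README "Demo" section) and restore
+# its best checkpoint for serving.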
+task = SchemaGuidedInstructBertTask.from_taskdir(
+ "mirror_outputs/Mirror_Pretrain_AllExcluded_2",
+ load_best_model=True,
+ initialize=False,
+ dump_configfile=False,
+ update_config={
+ "regenerate_cache": False,
+ },
+)
+
+
+@app.post("/process")
+def process_data(data: RequestData):
+ input_data = data.data
+
+ ok = True
+ msg = ""
+ results = {}
+ try:
+ results = task.predict(input_data)
+ msg = "success"
+ except KeyboardInterrupt:
+ raise KeyboardInterrupt
+ except Exception:
+ ok = False
+ msg = traceback.format_exc()
+
+ # Return the processed data
+ return {"ok": ok, "msg": msg, "results": results}
+
+
+@app.get("/")
+async def api():
+ return FileResponse("./index.html", media_type="text/html")
+
+
+if __name__ == "__main__":
+ log_config = uvicorn.config.LOGGING_CONFIG
+ log_config["formatters"]["access"]["fmt"] = (
+ "%(asctime)s | " + log_config["formatters"]["access"]["fmt"]
+ )
+ log_config["formatters"]["default"]["fmt"] = (
+ "%(asctime)s | " + log_config["formatters"]["default"]["fmt"]
+ )
+ uvicorn.run(
+ "src.app.api_backend:app",
+ host="0.0.0.0",
+ port=7860,
+ log_level="debug",
+ log_config=log_config,
+ reload=True,
+ )
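
A minimal client sketch for the `/process` endpoint above, assuming the backend is running locally on port 7860 and `requests` is installed; the request body follows the UDI format documented in `SchemaGuidedInstructBertTask.predict_api`, and the concrete instance is hypothetical:

    import requests

    payload = {
        "data": [
            {
                "id": "demo-0",
                "instruction": "Please extract all entities from the text.",
                "schema": {"ent": ["person", "location"]},
                "text": "Alice flew to Paris.",
                "bg": "",
                "ans": {},
            }
        ]
    }
    resp = requests.post("http://127.0.0.1:7860/process", json=payload)
    print(resp.json())  # {"ok": True, "msg": "success", "results": ...}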
diff --git a/src/app/demo1_deprecated.py b/src/app/demo1_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..e41018bfdb17651c97e2a03660044e67d629ebfc
--- /dev/null
+++ b/src/app/demo1_deprecated.py
@@ -0,0 +1,97 @@
+import gradio as gr
+from rex.utils.initialization import set_seed_and_log_path
+from rex.utils.logging import logger
+
+from src.task import MrcQaTask, SchemaGuidedInstructBertTask
+
+set_seed_and_log_path(log_path="app.log")
+
+
+class MrcQaPipeline:
+ def __init__(self, task_dir: str, load_path: str = None) -> None:
+ self.task = MrcQaTask.from_taskdir(
+ task_dir, load_best_model=load_path is None, initialize=False
+ )
+ if load_path:
+ self.task.load(load_path, load_history=False)
+
+ def predict(self, query, context, background=None):
+ data = [
+ {
+ "query": query,
+ "context": context,
+ "background": background,
+ }
+ ]
+ results = self.task.predict(data)
+ ret = results[0]
+
+ data[0]["pred"] = ret
+ logger.opt(colors=False).debug(data[0])
+
+ return ret
+
+
+class InstructBertPipeline:
+ def __init__(self, task_dir: str, load_path: str = None) -> None:
+ self.task = SchemaGuidedInstructBertTask.from_taskdir(
+ task_dir, load_best_model=load_path is None, initialize=False
+ )
+ if load_path:
+ self.task.load(load_path, load_history=False)
+
+ def predict(self, instruction, schema, text, background):
+ # build a UDI-style instance from the form inputs
+ data = [
+ {
+ "instruction": instruction,
+ "schema": schema,
+ "text": text,
+ "bg": background,
+ "ans": {},
+ }
+ ]
+ results = self.task.predict(data)
+ ret = results[0]
+
+ data[0]["pred"] = ret
+ logger.opt(colors=False).debug(data[0])
+
+ return ret
+
+
+def mrc_qa():
+ pipe = MrcQaPipeline("outputs/RobertaBase_data20230314v2")
+
+ with gr.Blocks() as demo:
+ gr.Markdown("# ๐ช Mirror Mirror")
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ query = gr.Textbox(
+ label="Query", placeholder="Mirror Mirror, tell me ..."
+ )
+ with gr.Row():
+ context = gr.TextArea(
+ label="Candidates",
+ placeholder="Separated by comma (,) without spaces.",
+ )
+ with gr.Row():
+ background = gr.TextArea(
+ label="Background",
+ placeholder="Background explanation, could be empty",
+ )
+
+ with gr.Column():
+ with gr.Row():
+ trigger_button = gr.Button("Tell me the truth", variant="primary")
+ with gr.Row():
+ output = gr.TextArea(label="Output")
+
+ trigger_button.click(
+ pipe.predict, inputs=[query, context, background], outputs=output
+ )
+
+ demo.launch(show_error=True, share=False)
+
+
+def instruct_bert_pipeline():
+ task = SchemaGuidedInstructBertTask.from_taskdir()
diff --git a/src/app/gradio_app.py b/src/app/gradio_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aa58aa0b48245f4ef5bedaf45e89c99023e3f15
--- /dev/null
+++ b/src/app/gradio_app.py
@@ -0,0 +1,58 @@
+import json
+
+import gradio as gr
+from rex.utils.initialization import set_seed_and_log_path
+
+from src.task import SchemaGuidedInstructBertTask
+
+set_seed_and_log_path(log_path="debug.log")
+
+
+task = SchemaGuidedInstructBertTask.from_taskdir(
+ "mirror_outputs/Mirror_Pretrain_AllExcluded_2",
+ load_best_model=True,
+ initialize=False,
+ dump_configfile=False,
+ update_config={
+ "regenerate_cache": False,
+ },
+)
+
+
+def ask_mirror(instruction, schema, text):
+ input_data = {
+ "id": "app",
+ "instruction": instruction,
+ "schema": json.loads(schema),
+ "text": text,
+ "ans": {},
+ }
+ # predict expects a list of UDI instances
+ results = task.predict([input_data])
+ return results
+
+
+with gr.Blocks() as demo:
+ gr.Markdown("# ๐ชMirror")
+ gr.Markdown(
+ "๐ชMirror can help you deal with a wide range of Natural Language Understanding and Information Extraction tasks."
+ )
+ gr.Markdown(
+ "[[paper]](https://arxiv.org/abs/2311.05419) | [[code]](https://github.com/Spico197/Mirror)"
+ )
+
+ instruction = gr.Textbox(label="Instruction")
+ schema = gr.Textbox(
+ label="Schema",
+ placeholder='{"cls": ["class1", "class2"], "ent": ["type1", "type2"], "rel": ["relation1", "relation2"]} Leave it as {} for plain span extraction.',
+ )
+ text = gr.TextArea(label="Text")
+ output = gr.Textbox(label="Output")
+
+ submit_btn = gr.Button("Ask Mirror")
+ submit_btn.click(ask_mirror, inputs=[instruction, schema, text], outputs=output)
+
+ gr.Markdown("Made by Mirror Team w/ ๐")
+
+
+if __name__ == "__main__":
+ demo.launch()
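
The app can also be exercised without the UI by calling `ask_mirror` directly; the instruction and schema below are illustrative, and the task directory configured above must exist locally:

    print(
        ask_mirror(
            instruction="Please extract all entities from the given text.",
            schema='{"ent": ["person", "location"]}',
            text="Alice flew to Paris.",
        )
    )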
diff --git a/src/eval.py b/src/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5158f3d892c73e0a2d15042b12e11904d7f7ac9
--- /dev/null
+++ b/src/eval.py
@@ -0,0 +1,142 @@
+from pathlib import Path
+
+import pandas as pd
+from rex.utils.initialization import set_seed_and_log_path
+from rex.utils.io import load_json
+from rich.console import Console
+from rich.table import Table
+
+from src.task import SchemaGuidedInstructBertTask
+
+set_seed_and_log_path(log_path="tmp_eval.log")
+
+
+if __name__ == "__main__":
+ task_dir = "mirror_outputs/Mirror_Pretrain_AllExcluded_2"
+ # task_dir = "mirror_outputs/Mirror_SingleTask_wPTAllExcluded_Event_ACE05"
+ task: SchemaGuidedInstructBertTask = SchemaGuidedInstructBertTask.from_taskdir(
+ task_dir,
+ load_best_model=True,
+ initialize=False,
+ dump_configfile=False,
+ update_config={
+ "regenerate_cache": True,
+ "eval_on_data": ["dev"],
+ "select_best_on_data": "dev",
+ "select_best_by_key": "metric",
+ "best_metric_field": "general_spans.micro.f1",
+ "eval_batch_size": 32,
+ },
+ )
+ table = Table(title=task_dir)
+
+ data_pairs = [
+ # fmt: off
+
+ # UIE eval data
+ # ["ent_ace04_test", "resources/Mirror/uie/ent/ace04/test.jsonl"],
+ # ["ent_ace05_test", "resources/Mirror/uie/ent/ace05/test.jsonl"],
+ ["ent_conll03_test", "resources/Mirror/uie/ent/conll03/test.jsonl"],
+ # ["rel_ace05_test", "resources/Mirror/uie/rel/ace05-rel/test.jsonl"],
+ ["rel_conll04_test", "resources/Mirror/uie/rel/conll04/test.jsonl"],
+ # ["rel_nyt_test", "resources/Mirror/uie/rel/nyt/test.jsonl"],
+ # ["rel_scierc_test", "resources/Mirror/uie/rel/scierc/test.jsonl"],
+ ["event_ace05_test", "resources/Mirror/uie/event/ace05-evt/test.jsonl"],
+ # ["event_casie_test", "resources/Mirror/uie/event/casie/test.jsonl"],
+ # ["absa_14res_test", "resources/Mirror/uie/absa/14res/test.jsonl"],
+ # ["absa_14lap_test", "resources/Mirror/uie/absa/14lap/test.jsonl"],
+ # ["absa_15res_test", "resources/Mirror/uie/absa/15res/test.jsonl"],
+ # ["absa_16res_test", "resources/Mirror/uie/absa/16res/test.jsonl"],
+ # # discontinuous NER
+ # ["discontinuous_ent", "resources/Mirror/new_abilities_v2/cadec/new/test.jsonl"],
+ # # hyper-RE
+ # ["hyper_rel", "resources/Mirror/new_abilities_v2/HyperRED/new/test.jsonl"],
+ # # zero-shot NER
+ # ["ent_movie", "resources/Mirror/v1.3/ent/en/MIT_MOVIE_Review/instructed/test.jsonl"],
+ # ["ent_restaurant", "resources/Mirror/v1.3/ent/en/MIT_Restaurant_Review/instructed/test.jsonl"],
+ # ["ent_ai", "resources/Mirror/v1.3/ent/en/CrossNER_AI/instructed/test.jsonl"],
+ # ["ent_literature", "resources/Mirror/v1.3/ent/en/CrossNER_literature/instructed/test.jsonl"],
+ # ["ent_music", "resources/Mirror/v1.3/ent/en/CrossNER_music/instructed/test.jsonl"],
+ # ["ent_politics", "resources/Mirror/v1.3/ent/en/CrossNER_politics/instructed/test.jsonl"],
+ # ["ent_science", "resources/Mirror/v1.3/ent/en/CrossNER_science/instructed/test.jsonl"],
+ # # mrc
+ # ["span_squad2", "resources/Mirror/v1.3/span/en/squad_v2/dev.jsonl"],
+ # # glue
+ # ["cls_glue_cola", "resources/Mirror/v1.3/cls/en/CoLA/formated/dev.jsonl"],
+ # ["cls_glue_qqp", "resources/Mirror/v1.3/cls/en/QQP/new/dev.jsonl"],
+ # ["cls_glue_mnli", "resources/Mirror/v1.3/cls/en/MNLI/formated/MNLI_dev.jsonl"],
+ # ["cls_glue_sst2", "resources/Mirror/v1.3/cls/en/SST-2/instructed/SST-2_dev.jsonl"],
+ # ["cls_glue_qnli", "resources/Mirror/v1.3/cls/en/QNLI/processed/QNLI_dev.jsonl"],
+ # ["cls_glue_rte", "resources/Mirror/v1.3/cls/en/RTE/formated/RTE_dev.jsonl"],
+ # ["cls_glue_mrpc", "resources/Mirror/v1.3/cls/en/MRPC/formated/dev.jsonl"],
+ # fmt: on
+ ]
+
+ eval_res = {"task": [], "dataset": [], "metric_val": []}
+ table.add_column("Task", justify="left", style="cyan")
+ table.add_column("Dataset", justify="left", style="magenta")
+ table.add_column("Metric (%)", justify="right", style="green")
+ for dname, fpath in data_pairs:
+ dname = dname.lower()
+ task.data_manager.update_datapath(dname, fpath)
+ _, res = task.eval(dname, verbose=True, dump=True, dump_middle=True)
+ # res = load_json(Path(task_dir) / "measures" / f"{dname}.json")["metrics"]
+ if dname.startswith("ent_"):
+ eval_res["task"].append("ent")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["ent"]["micro"]["f1"])
+ elif dname.startswith("rel_"):
+ eval_res["task"].append("rel")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"])
+ elif dname.startswith("event_"):
+ eval_res["task"].append("event")
+ eval_res["dataset"].append(dname + "_tgg")
+ eval_res["metric_val"].append(res["event"]["trigger_cls"]["f1"])
+ eval_res["task"].append("event")
+ eval_res["dataset"].append(dname + "_arg")
+ eval_res["metric_val"].append(res["event"]["arg_cls"]["f1"])
+ elif dname.startswith("absa_"):
+ eval_res["task"].append("absa")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"])
+ elif dname.startswith("cls_"):
+ eval_res["task"].append("cls")
+ eval_res["dataset"].append(dname)
+ if "_glue_" in dname:
+ if "_cola" in dname:
+ eval_res["metric_val"].append(res["cls"]["mcc"])
+ else:
+ eval_res["metric_val"].append(res["cls"]["acc"])
+ else:
+ eval_res["metric_val"].append(res["cls"]["mf1"]["micro"]["f1"])
+ elif dname.startswith("span"):
+ eval_res["task"].append("span_em")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["span"]["em"])
+ eval_res["task"].append("span_f1")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["span"]["f1"]["f1"])
+ elif dname.startswith("discontinuous_ent"):
+ eval_res["task"].append("discontinuous_ent")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["discontinuous_ent"]["micro"]["f1"])
+ elif dname.startswith("hyper_rel"):
+ eval_res["task"].append("hyper_rel")
+ eval_res["dataset"].append(dname)
+ eval_res["metric_val"].append(res["hyper_rel"]["micro"]["f1"])
+ else:
+ raise ValueError(f"Unrecognized dataset name: {dname}")
+
+ for i in range(len(eval_res["task"])):
+ table.add_row(
+ eval_res["task"][i],
+ eval_res["dataset"][i],
+ f"{100*eval_res['metric_val'][i]:.3f}",
+ )
+
+ console = Console()
+ console.print(table)
+
+ df = pd.DataFrame(eval_res)
+ df.to_excel(task.measures_path.joinpath("data_eval_res.xlsx"))
diff --git a/src/get_avg_results.py b/src/get_avg_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..f57d9705424451452bb2b17e3f17cc3b23e1897e
--- /dev/null
+++ b/src/get_avg_results.py
@@ -0,0 +1,89 @@
+import os
+import re
+import statistics as sts
+from collections import defaultdict
+from pathlib import Path
+
+from rex.utils.dict import get_dict_content
+from rex.utils.io import load_json
+from rich.console import Console
+from rich.table import Table
+
+inputs_dir = Path("mirror_fewshot_outputs")
+# regex = re.compile(r"Mirror_SingleTask_(.*?)_seed(\d+)_(\d+)shot")
+regex = re.compile(r"Mirror_wPT_woInst_(.*?)_seed(\d+)_(\d+)shot")
+
+# task -> shot -> seeds
+results = defaultdict(lambda: defaultdict(list))
+
+for dirname in os.listdir(inputs_dir):
+ dpath = inputs_dir / dirname
+ re_matched = regex.match(dirname)
+ if dpath.is_dir() and re_matched:
+ task, seed, shot = re_matched.groups()
+ results_json_p = dpath / "measures" / "test.final.json"
+ metrics = load_json(results_json_p)
+ if "Ent_" in task:
+ results[task][shot].append(
+ get_dict_content(metrics, "metrics.ent.micro.f1")
+ )
+ elif "Rel_" in task or "ABSA_" in task:
+ results[task][shot].append(
+ get_dict_content(metrics, "metrics.rel.rel.micro.f1")
+ )
+ elif "Event_" in task:
+ results[task + "_Trigger"][shot].append(
+ get_dict_content(metrics, "metrics.event.trigger_cls.f1")
+ )
+ results[task + "_Arg"][shot].append(
+ get_dict_content(metrics, "metrics.event.arg_cls.f1")
+ )
+ else:
+ raise RuntimeError(f"Cannot infer task type from dirname: {dirname}")
+
+table = Table(title="Few-shot results")
+table.add_column("Task", justify="center")
+table.add_column("1-shot", justify="right")
+table.add_column("5-shot", justify="right")
+table.add_column("10-shot", justify="right")
+table.add_column("Avg.", justify="right")
+for task in results:
+ shots = sorted(results[task].keys(), key=lambda x: int(x))
+ all_seeds = []
+ shot_results = []
+ for shot in shots:
+ seeds = results[task][shot]
+ all_seeds.extend(seeds)
+ avg = sts.mean(seeds)
+ shot_results.append(f"{100*avg:.2f}±{100*sts.stdev(seeds):.2f}")
+ shot_results.append(f"{100*sts.mean(all_seeds):.2f}")
+ table.add_row(task, *shot_results)
+
+console = Console()
+console.print(table)
+
+"""
+ Few-shot results wPT wInst
+โโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโณโโโโโโโโโโโโโโณโโโโโโโโโโโโโณโโโโโโโโ
+โ Task โ 1-shot โ 5-shot โ 10-shot โ Avg. โ
+โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ
+โ Ent_CoNLL03 โ 77.50ยฑ1.64 โ 82.73ยฑ2.29 โ 84.48ยฑ1.62 โ 81.57 โ
+โ Rel_CoNLL04 โ 34.66ยฑ10.52 โ 52.23ยฑ3.16 โ 58.68ยฑ1.77 โ 48.52 โ
+โ Event_ACE05_Trigger โ 49.50ยฑ3.59 โ 65.61ยฑ19.29 โ 60.68ยฑ2.45 โ 58.60 โ
+โ Event_ACE05_Arg โ 23.46ยฑ1.66 โ 48.32ยฑ28.91 โ 41.90ยฑ1.95 โ 37.89 โ
+โ ABSA_16res โ 67.06ยฑ0.56 โ 73.51ยฑ14.75 โ 68.70ยฑ1.46 โ 69.76 โ
+โโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโดโโโโโโโโโโโโโโดโโโโโโโโโโโโโดโโโโโโโโ
+
+ Few-shot results wPT woInst
+โโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโณโโโโโโโโโโโโโณโโโโโโโโโโโโโณโโโโโโโโ
+โ Task โ 1-shot โ 5-shot โ 10-shot โ Avg. โ
+โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ
+โ Ent_CoNLL03 โ 76.33ยฑ1.74 โ 82.50ยฑ1.87 โ 84.47ยฑ1.18 โ 81.10 โ
+โ woInst_Rel_CoNLL04 โ 34.86ยฑ6.20 โ 48.00ยฑ4.44 โ 55.65ยฑ2.53 โ 46.17 โ
+โ Rel_CoNLL04 โ 26.83ยฑ15.22 โ 47.39ยฑ3.60 โ 55.38ยฑ2.41 โ 43.20 โ
+โ Event_ACE05_Trigger โ 46.60ยฑ1.09 โ 57.21ยฑ3.51 โ 59.67ยฑ3.20 โ 54.49 โ
+โ Event_ACE05_Arg โ 21.60ยฑ3.61 โ 34.43ยฑ3.63 โ 39.62ยฑ2.60 โ 31.88 โ
+โ ABSA_16res โ 8.10ยฑ18.11 โ 52.73ยฑ5.52 โ 57.32ยฑ1.73 โ 39.38 โ
+โโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโดโโโโโโโโโโโโโดโโโโโโโโโโโโโดโโโโโโโโ
+"""
diff --git a/src/inference.py b/src/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..c501bf8de812c3e1e097a2e87a3a2898813568c8
--- /dev/null
+++ b/src/inference.py
@@ -0,0 +1,23 @@
+import os
+
+from rex.utils.logging import logger
+
+from src.task import MrcTaggingTask
+
+if __name__ == "__main__":
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+ task = MrcTaggingTask.from_taskdir(
+ "outputs/bert_mrc_ner",
+ load_best_model=True,
+ update_config={
+ "skip_train": True,
+ "debug_mode": False,
+ },
+ )
+
+ cases = ["123123", "123123"]
+ logger.info(f"Cases: {cases}")
+
+ ents = task.predict(cases)
+ logger.info(f"Results: {ents}")
diff --git a/src/metric.py b/src/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f4a18e0fac965b239c2e8fc2c490746e06929c1
--- /dev/null
+++ b/src/metric.py
@@ -0,0 +1,555 @@
+from collections import defaultdict
+from typing import Tuple
+
+from rex.metrics import calc_p_r_f1_from_tp_fp_fn, safe_division
+from rex.metrics.base import MetricBase
+from rex.metrics.tagging import tagging_prf1
+from rex.utils.batch import decompose_batch_into_instances
+from rex.utils.iteration import windowed_queue_iter
+from rex.utils.random import generate_random_string_with_datetime
+from sklearn.metrics import accuracy_score, matthews_corrcoef
+
+
+class MrcNERMetric(MetricBase):
+ def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
+ gold_instances = []
+ pred_instances = []
+
+ batch_gold = decompose_batch_into_instances(raw_batch)
+ assert len(batch_gold) == len(out_batch["pred"])
+
+ for i, gold in enumerate(batch_gold):
+ gold_instances.append(
+ {
+ "id": gold["id"],
+ "ents": {(gold["ent_type"], gent) for gent in gold["gold_ents"]},
+ }
+ )
+ pred_instances.append(
+ {
+ "id": gold["id"],
+ "ents": {(gold["ent_type"], pent) for pent in out_batch["pred"][i]},
+ }
+ )
+
+ return gold_instances, pred_instances
+
+ def calculate_scores(self, golds: list, preds: list) -> dict:
+ id2gold = defaultdict(set)
+ id2pred = defaultdict(set)
+ # aggregate all ents with diff queries before evaluating
+ for gold in golds:
+ id2gold[gold["id"]].update(gold["ents"])
+ for pred in preds:
+ id2pred[pred["id"]].update(pred["ents"])
+ assert len(id2gold) == len(id2pred)
+
+ gold_ents = []
+ pred_ents = []
+ for _id in id2gold:
+ gold_ents.append(id2gold[_id])
+ pred_ents.append(id2pred[_id])
+
+ return tagging_prf1(gold_ents, pred_ents, type_idx=0)
+
+
+class MrcSpanMetric(MetricBase):
+ def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
+ gold_instances = []
+ pred_instances = []
+
+ batch_gold = decompose_batch_into_instances(raw_batch)
+ assert len(batch_gold) == len(out_batch["pred"])
+
+ for i, gold in enumerate(batch_gold):
+ gold_instances.append(
+ {
+ "id": gold["id"],
+ "spans": set(tuple(span) for span in gold["gold_spans"]),
+ }
+ )
+ pred_instances.append(
+ {
+ "id": gold["id"],
+ "spans": set(out_batch["pred"][i]),
+ }
+ )
+
+ return gold_instances, pred_instances
+
+ def calculate_scores(self, golds: list, preds: list) -> dict:
+ id2gold = defaultdict(set)
+ id2pred = defaultdict(set)
+ # aggregate all ents with diff queries before evaluating
+ for gold in golds:
+ id2gold[gold["id"]].update(gold["spans"])
+ for pred in preds:
+ id2pred[pred["id"]].update(pred["spans"])
+ assert len(id2gold) == len(id2pred)
+
+ gold_spans = []
+ pred_spans = []
+ for _id in id2gold:
+ gold_spans.append(id2gold[_id])
+ pred_spans.append(id2pred[_id])
+
+ return tagging_prf1(gold_spans, pred_spans, type_idx=None)
+
+
+def calc_char_event(golds, preds):
+ """
+ Calculate char-level event argument scores
+
+ References:
+ - https://aistudio.baidu.com/aistudio/competition/detail/46/0/submit-result
+
+ Args:
+ golds: a list of gold answers (a list of `event_list`), len=#data,
+ format is a list of `event_list`
+ preds: a list of pred answers, len=#data
+ """
+
+ def _match_arg_char_f1(gold_arg, pred_args):
+ gtype, grole, gstring = gold_arg
+ gchars = set(gstring)
+ garg_len = len(gchars)
+ cands = []
+ for parg in pred_args:
+ if parg[0] == gtype and parg[1] == grole:
+ pchars = set(str(parg[-1]))
+ parg_len = len(pchars)
+ pmatch = len(pchars & gchars)
+ p = safe_division(pmatch, parg_len)
+ r = safe_division(pmatch, garg_len)
+ f1 = safe_division(2 * p * r, p + r)
+ cands.append(f1)
+ if len(cands) > 0:
+ f1 = sorted(cands)[-1]
+ return f1
+ else:
+ return 0.0
+
+ pscore = num_gargs = num_pargs = 0
+ for _golds, _preds in zip(golds, preds):
+ # _golds and _preds pair in one data instance
+ gold_args = []
+ pred_args = []
+ for gold in _golds:
+ for arg in gold.get("arguments", []):
+ gold_args.append(
+ (gold.get("event_type"), arg.get("role"), arg.get("argument"))
+ )
+ for pred in _preds:
+ for arg in pred.get("arguments", []):
+ pred_args.append(
+ (pred.get("event_type"), arg.get("role"), arg.get("argument"))
+ )
+
+ num_gargs += len(gold_args)
+ num_pargs += len(pred_args)
+ for gold_arg in gold_args:
+ pscore += _match_arg_char_f1(gold_arg, pred_args)
+
+ p = safe_division(pscore, num_pargs)
+ r = safe_division(pscore, num_gargs)
+ f1 = safe_division(2 * p * r, p + r)
+ return {
+ "p": p,
+ "r": r,
+ "f1": f1,
+ "pscore": pscore,
+ "num_pargs": num_pargs,
+ "num_gargs": num_gargs,
+ }
+
+
+def calc_trigger_identification_metrics(golds, preds):
+ tp = fp = fn = 0
+ for _golds, _preds in zip(golds, preds):
+ gold_triggers = {gold["trigger"] for gold in _golds}
+ pred_triggers = {pred["trigger"] for pred in _preds}
+ tp += len(gold_triggers & pred_triggers)
+ fp += len(pred_triggers - gold_triggers)
+ fn += len(gold_triggers - pred_triggers)
+ metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
+ return metrics
+
+
+def calc_trigger_classification_metrics(golds, preds):
+ tp = fp = fn = 0
+ for _golds, _preds in zip(golds, preds):
+ gold_tgg_cls = {(gold["trigger"], gold["event_type"]) for gold in _golds}
+ pred_tgg_cls = {(pred["trigger"], pred["event_type"]) for pred in _preds}
+ tp += len(gold_tgg_cls & pred_tgg_cls)
+ fp += len(pred_tgg_cls - gold_tgg_cls)
+ fn += len(gold_tgg_cls - pred_tgg_cls)
+ metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
+ return metrics
+
+
+def calc_arg_identification_metrics(golds, preds):
+ """Calculate argument identification metrics
+
+ Notice:
+ An entity could take different roles in an event,
+ so the base number must be calculated by
+ (arg, event type, pos, role)
+ """
+ tp = fp = fn = 0
+ for _golds, _preds in zip(golds, preds):
+ gold_args = set()
+ pred_args = set()
+ for gold in _golds:
+ _args = {
+ (arg["role"], arg["argument"], gold["event_type"])
+ for arg in gold["arguments"]
+ }
+ gold_args.update(_args)
+ for pred in _preds:
+ _args = {
+ (arg["role"], arg["argument"], pred["event_type"])
+ for arg in pred["arguments"]
+ }
+ pred_args.update(_args)
+ # logic derived from OneIE
+ _tp = 0
+ _tp_fp = len(pred_args)
+ _tp_fn = len(gold_args)
+ _gold_args_wo_role = {_ga[1:] for _ga in gold_args}
+ for pred_arg in pred_args:
+ if pred_arg[1:] in _gold_args_wo_role:
+ _tp += 1
+ tp += _tp
+ fp += _tp_fp - _tp
+ fn += _tp_fn - _tp
+ metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
+ return metrics
+
+
+def calc_arg_classification_metrics(golds, preds):
+ tp = fp = fn = 0
+ for _golds, _preds in zip(golds, preds):
+ gold_arg_cls = set()
+ pred_arg_cls = set()
+ for gold in _golds:
+ _args = {
+ (arg["argument"], arg["role"], gold["event_type"])
+ for arg in gold["arguments"]
+ }
+ gold_arg_cls.update(_args)
+ for pred in _preds:
+ _args = {
+ (arg["argument"], arg["role"], pred["event_type"])
+ for arg in pred["arguments"]
+ }
+ pred_arg_cls.update(_args)
+ tp += len(gold_arg_cls & pred_arg_cls)
+ fp += len(pred_arg_cls - gold_arg_cls)
+ fn += len(gold_arg_cls - pred_arg_cls)
+ metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
+ return metrics
+
+
+def calc_ent(golds, preds):
+ """
+ Args:
+ golds, preds: [(type, index list), ...]
+ """
+ res = tagging_prf1(golds, preds, type_idx=0)
+ return res
+
+
+def calc_rel(golds, preds):
+ gold_ents = []
+ pred_ents = []
+ for gold, pred in zip(golds, preds):
+ gold_ins_ents = []
+ for t in gold:
+ gold_ins_ents.extend(t[1:])
+ gold_ents.append(gold_ins_ents)
+ pred_ins_ents = []
+ for t in pred:
+ pred_ins_ents.extend(t[1:])
+ pred_ents.append(pred_ins_ents)
+
+ metrics = {
+ "ent": tagging_prf1(gold_ents, pred_ents, type_idx=None),
+ "rel": tagging_prf1(golds, preds, type_idx=None),
+ }
+ return metrics
+
+
+def calc_cls(golds, preds):
+ metrics = {
+ "mcc": -1,
+ "acc": -1,
+ "mf1": tagging_prf1(golds, preds, type_idx=None),
+ }
+ y_true = []
+ y_pred = []
+ for gold, pred in zip(golds, preds):
+ y_true.append(" ".join(sorted(gold)))
+ y_pred.append(" ".join(sorted(pred)))
+ # only score when there is at least one instance
+ if y_true and y_pred:
+ metrics["acc"] = accuracy_score(y_true, y_pred)
+ metrics["mcc"] = matthews_corrcoef(y_true, y_pred)
+ else:
+ metrics["acc"] = 0.0
+ metrics["mcc"] = 0.0
+ return metrics
+
+
+def calc_span(golds, preds, mode="span"):
+ def _get_tokens(spans: list[tuple[tuple[int]]]) -> list[int]:
+ tokens = []
+ for span in spans:
+ for part in span:
+ _toks = []
+ if len(part) == 1:
+ _toks = [part[0]]
+ elif len(part) > 1:
+ if mode == "w2":
+ _toks = [*part]
+ elif mode == "span":
+ _toks = [*range(part[0], part[1] + 1)]
+ else:
+ raise ValueError
+ tokens.extend(_toks)
+ return tokens
+
+ metrics = {
+ "em": -1,
+ "f1": None,
+ }
+ acc_num = 0
+ tp = fp = fn = 0
+ for gold, pred in zip(golds, preds):
+ if gold == pred:
+ acc_num += 1
+ gold_tokens = _get_tokens(gold)
+ pred_tokens = _get_tokens(pred)
+ tp += len(set(gold_tokens) & set(pred_tokens))
+ fp += len(set(pred_tokens) - set(gold_tokens))
+ fn += len(set(gold_tokens) - set(pred_tokens))
+ if len(golds) > 0:
+ metrics["em"] = acc_num / len(golds)
+ else:
+ metrics["em"] = 0.0
+ metrics["f1"] = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
+ return metrics
+
+
+class MultiPartSpanMetric(MetricBase):
+ def _encode_span_to_label_dict(self, span_to_label: dict) -> list:
+ span_to_label_list = []
+ for key, val in span_to_label.items():
+ span_to_label_list.append({"key": key, "val": val})
+ return span_to_label_list
+
+ def _decode_span_to_label(self, span_to_label_list: list) -> dict:
+ span_to_label = {}
+ for content in span_to_label_list:
+ span_to_label[tuple(content["key"])] = content["val"]
+ return span_to_label
+
+ def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
+ gold_instances = []
+ pred_instances = []
+
+ batch_gold = decompose_batch_into_instances(raw_batch)
+ assert len(batch_gold) == len(out_batch["pred"])
+
+ for i, gold in enumerate(batch_gold):
+ ins_id = gold["raw"].get("id", generate_random_string_with_datetime())
+ # encode to list to make the span_to_label dict json-serializable
+ # where the original dict key is a tuple
+ span_to_label_list = self._encode_span_to_label_dict(gold["span_to_label"])
+ gold["span_to_label"] = span_to_label_list
+ gold_instances.append(
+ {
+ "id": ins_id,
+ "span_to_label_list": span_to_label_list,
+ "raw_gold_content": gold,
+ "spans": set(
+ tuple(multi_part_span) for multi_part_span in gold["spans"]
+ ),
+ }
+ )
+ pred_instances.append(
+ {
+ "id": ins_id,
+ "spans": set(
+ tuple(multi_part_span)
+ for multi_part_span in out_batch["pred"][i]
+ ),
+ }
+ )
+
+ return gold_instances, pred_instances
+
+ def calculate_scores(self, golds: list, preds: list) -> dict:
+ # for general purpose evaluation
+ general_gold_spans, general_pred_spans = [], []
+ # cls task
+ gold_cls_list, pred_cls_list = [], []
+ # ent task
+ gold_ent_list, pred_ent_list = [], []
+ # rel task
+ gold_rel_list, pred_rel_list = [], []
+ # event task
+ gold_event_list, pred_event_list = [], []
+ # span task
+ gold_span_list, pred_span_list = [], []
+ # discon ent task
+ gold_discon_ent_list, pred_discon_ent_list = [], []
+ # hyper rel task
+ gold_hyper_rel_list, pred_hyper_rel_list = [], []
+
+ for gold, pred in zip(golds, preds):
+ general_gold_spans.append(gold["spans"])
+ general_pred_spans.append(pred["spans"])
+ span_to_label = self._decode_span_to_label(gold["span_to_label_list"])
+ gold_clses, pred_clses = [], []
+ gold_ents, pred_ents = [], []
+ gold_rels, pred_rels = [], []
+ gold_trigger_to_event = defaultdict(
+ lambda: {"event_type": "", "arguments": []}
+ )
+ pred_trigger_to_event = defaultdict(
+ lambda: {"event_type": "", "arguments": []}
+ )
+ gold_events, pred_events = [], []
+ gold_spans, pred_spans = [], []
+ gold_discon_ents, pred_discon_ents = [], []
+ gold_hyper_rels, pred_hyper_rels = [], []
+
+ raw_schema = gold["raw_gold_content"]["raw"]["schema"]
+ for span in gold["spans"]:
+ if span[0] in span_to_label:
+ label = span_to_label[span[0]]
+ if label["task"] == "cls" and len(span) == 1:
+ gold_clses.append(label["string"])
+ elif label["task"] == "ent" and len(span) == 2:
+ gold_ents.append((label["string"], *span[1:]))
+ elif label["task"] == "rel" and len(span) == 3:
+ gold_rels.append((label["string"], *span[1:]))
+ elif label["task"] == "event":
+ if label["type"] == "lm" and len(span) == 2:
+ gold_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip
+ elif label["type"] == "lr" and len(span) == 3:
+ gold_trigger_to_event[span[1]]["arguments"].append(
+ {"argument": span[2], "role": label["string"]}
+ )
+ elif label["task"] == "discontinuous_ent" and len(span) > 1:
+ gold_discon_ents.append((label["string"], *span[1:]))
+ elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip
+ q_label = span_to_label[span[3]]
+ gold_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4])) # fmt: skip
+ else:
+ # span task has no labels
+ gold_spans.append(tuple(span))
+ for trigger, item in gold_trigger_to_event.items():
+ legal_roles = raw_schema["event"][item["event_type"]]
+ gold_events.append(
+ {
+ "trigger": trigger,
+ "event_type": item["event_type"],
+ "arguments": [
+ arg
+ for arg in filter(
+ lambda arg: arg["role"] in legal_roles,
+ item["arguments"],
+ )
+ ],
+ }
+ )
+
+ for span in pred["spans"]:
+ if span[0] in span_to_label:
+ label = span_to_label[span[0]]
+ if label["task"] == "cls" and len(span) == 1:
+ pred_clses.append(label["string"])
+ elif label["task"] == "ent" and len(span) == 2:
+ pred_ents.append((label["string"], *span[1:]))
+ elif label["task"] == "rel" and len(span) == 3:
+ pred_rels.append((label["string"], *span[1:]))
+ elif label["task"] == "event":
+ if label["type"] == "lm" and len(span) == 2:
+ pred_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip
+ elif label["type"] == "lr" and len(span) == 3:
+ pred_trigger_to_event[span[1]]["arguments"].append(
+ {"argument": span[2], "role": label["string"]}
+ )
+ elif label["task"] == "discontinuous_ent" and len(span) > 1:
+ pred_discon_ents.append((label["string"], *span[1:]))
+ elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip
+ q_label = span_to_label[span[3]]
+ pred_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4])) # fmt: skip
+ else:
+ # span task has no labels
+ pred_spans.append(tuple(span))
+ for trigger, item in pred_trigger_to_event.items():
+ if item["event_type"] not in raw_schema["event"]:
+ continue
+ legal_roles = raw_schema["event"][item["event_type"]]
+ pred_events.append(
+ {
+ "trigger": trigger,
+ "event_type": item["event_type"],
+ "arguments": [
+ arg
+ for arg in filter(
+ lambda arg: arg["role"] in legal_roles,
+ item["arguments"],
+ )
+ ],
+ }
+ )
+
+ gold_cls_list.append(gold_clses)
+ pred_cls_list.append(pred_clses)
+ gold_ent_list.append(gold_ents)
+ pred_ent_list.append(pred_ents)
+ gold_rel_list.append(gold_rels)
+ pred_rel_list.append(pred_rels)
+ gold_event_list.append(gold_events)
+ pred_event_list.append(pred_events)
+ gold_span_list.append(gold_spans)
+ pred_span_list.append(pred_spans)
+ gold_discon_ent_list.append(gold_discon_ents)
+ pred_discon_ent_list.append(pred_discon_ents)
+ gold_hyper_rel_list.append(gold_hyper_rels)
+ pred_hyper_rel_list.append(pred_hyper_rels)
+
+ metrics = {
+ "general_spans": tagging_prf1(
+ general_gold_spans, general_pred_spans, type_idx=None
+ ),
+ "cls": calc_cls(gold_cls_list, pred_cls_list),
+ "ent": calc_ent(gold_ent_list, pred_ent_list),
+ "rel": calc_rel(gold_rel_list, pred_rel_list),
+ "event": {
+ "trigger_id": calc_trigger_identification_metrics(
+ gold_event_list, pred_event_list
+ ),
+ "trigger_cls": calc_trigger_classification_metrics(
+ gold_event_list, pred_event_list
+ ),
+ "arg_id": calc_arg_identification_metrics(
+ gold_event_list, pred_event_list
+ ),
+ "arg_cls": calc_arg_classification_metrics(
+ gold_event_list, pred_event_list
+ ),
+ "char_event": calc_char_event(gold_event_list, pred_event_list),
+ },
+ "discontinuous_ent": tagging_prf1(
+ gold_discon_ent_list, pred_discon_ent_list, type_idx=None
+ ),
+ "hyper_rel": tagging_prf1(
+ gold_hyper_rel_list, pred_hyper_rel_list, type_idx=None
+ ),
+ # "span": tagging_prf1(gold_span_list, pred_span_list, type_idx=None),
+ "span": calc_span(gold_span_list, pred_span_list),
+ }
+
+ return metrics
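
A toy sanity check for the event metrics above (hypothetical gold and predicted events; the returned dict layout comes from `rex.metrics.calc_p_r_f1_from_tp_fp_fn`, which this repo depends on):

    from src.metric import (
        calc_arg_classification_metrics,
        calc_trigger_classification_metrics,
    )

    golds = [[{
        "trigger": "attack",
        "event_type": "Conflict.Attack",
        "arguments": [{"argument": "troops", "role": "Attacker"}],
    }]]
    preds = [[{
        "trigger": "attack",
        "event_type": "Conflict.Attack",
        "arguments": [],  # the argument is missed
    }]]

    print(calc_trigger_classification_metrics(golds, preds))  # p = r = f1 = 1.0
    print(calc_arg_classification_metrics(golds, preds))      # f1 = 0.0 (fn = 1)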
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..57ba605e6752cab782a437f7ddfbd8054fac36d7
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,533 @@
+import torch
+import torch.nn as nn
+from rex.utils.iteration import windowed_queue_iter
+from transformers import AutoModel, BertModel
+
+from src.utils import decode_nnw_nsw_thw_mat, decode_nnw_thw_mat, decode_pointer_mat
+
+
+class Biaffine(nn.Module):
+ """Biaffine transformation
+
+ References:
+ - https://github.com/yzhangcs/parser/blob/main/supar/modules/affine.py
+ - https://github.com/ljynlp/W2NER
+ """
+
+ def __init__(self, n_in, n_out=2, bias_x=True, bias_y=True):
+ super().__init__()
+
+ self.n_in = n_in
+ self.n_out = n_out
+ self.bias_x = bias_x
+ self.bias_y = bias_y
+ weight = torch.zeros(n_out, n_in + int(bias_x), n_in + int(bias_y))
+ nn.init.xavier_normal_(weight)
+ self.weight = nn.Parameter(weight, requires_grad=True)
+
+ def extra_repr(self):
+ s = f"n_in={self.n_in}, n_out={self.n_out}"
+ if self.bias_x:
+ s += f", bias_x={self.bias_x}"
+ if self.bias_y:
+ s += f", bias_y={self.bias_y}"
+
+ return s
+
+ def forward(self, x, y):
+ if self.bias_x:
+ x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+ if self.bias_y:
+ y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+ # [batch_size, n_out, seq_len, seq_len]
+ s = torch.einsum("bxi,oij,byj->boxy", x, self.weight, y)
+ # s = s.permute(0, 2, 3, 1)
+
+ return s
+
+
+class LinearWithAct(nn.Module):
+ def __init__(self, n_in, n_out, dropout=0) -> None:
+ super().__init__()
+
+ self.linear = nn.Linear(n_in, n_out)
+ self.act_fn = nn.GELU()
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.act_fn(x)
+ x = self.dropout(x)
+ return x
+
+
+class PointerMatrix(nn.Module):
+ """Pointer Matrix Prediction
+
+ References:
+ - https://github.com/ljynlp/W2NER
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ biaffine_size,
+ cls_num=2,
+ dropout=0,
+ biaffine_bias=False,
+ use_rope=False,
+ ):
+ super().__init__()
+ self.linear_h = LinearWithAct(
+ n_in=hidden_size, n_out=biaffine_size, dropout=dropout
+ )
+ self.linear_t = LinearWithAct(
+ n_in=hidden_size, n_out=biaffine_size, dropout=dropout
+ )
+ self.biaffine = Biaffine(
+ n_in=biaffine_size,
+ n_out=cls_num,
+ bias_x=biaffine_bias,
+ bias_y=biaffine_bias,
+ )
+ self.use_rope = use_rope
+
+ def sinusoidal_position_embedding(self, qw, kw):
+ batch_size, seq_len, output_dim = qw.shape
+ position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
+
+ indices = torch.arange(0, output_dim // 2, dtype=torch.float)
+ indices = torch.pow(10000, -2 * indices / output_dim)
+ pos_emb = position_ids * indices
+ pos_emb = torch.stack([torch.sin(pos_emb), torch.cos(pos_emb)], dim=-1)
+ pos_emb = pos_emb.repeat((batch_size, *([1] * len(pos_emb.shape))))
+ pos_emb = torch.reshape(pos_emb, (batch_size, seq_len, output_dim))
+ pos_emb = pos_emb.to(qw)
+
+ # (bs, seq_len, 1, hz) -> (bs, seq_len, hz)
+ cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
+ # (bs, seq_len, 1, hz) -> (bs, seq_len, hz)
+ sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)
+ qw2 = torch.cat([-qw[..., 1::2], qw[..., ::2]], -1)
+ qw = qw * cos_pos + qw2 * sin_pos
+ kw2 = torch.cat([-kw[..., 1::2], kw[..., ::2]], -1)
+ kw = kw * cos_pos + kw2 * sin_pos
+ return qw, kw
+
+ def forward(self, x):
+ h = self.linear_h(x)
+ t = self.linear_t(x)
+ if self.use_rope:
+ h, t = self.sinusoidal_position_embedding(h, t)
+ o = self.biaffine(h, t)
+ return o
+
+
+def multilabel_categorical_crossentropy(y_pred, y_true, bit_mask=None):
+ """
+ https://kexue.fm/archives/7359
+ https://github.com/gaohongkui/GlobalPointer_pytorch/blob/main/common/utils.py
+ """
+ y_pred = (1 - 2 * y_true) * y_pred # -1 -> pos classes, 1 -> neg classes
+ y_pred_neg = y_pred - y_true * 1e12 # mask the pred outputs of pos classes
+ y_pred_pos = y_pred - (1 - y_true) * 1e12 # mask the pred outputs of neg classes
+ zeros = torch.zeros_like(y_pred[..., :1])
+ y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
+ y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
+ neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
+ pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
+
+ if bit_mask is None:
+ return neg_loss + pos_loss
+ else:
+ raise NotImplementedError
+
+
+class MrcPointerMatrixModel(nn.Module):
+ def __init__(
+ self,
+ plm_dir: str,
+ cls_num: int = 2,
+ biaffine_size: int = 384,
+ none_type_id: int = 0,
+ text_mask_id: int = 4,
+ dropout: float = 0.3,
+ ):
+ super().__init__()
+
+ # num of predicted classes
+ self.cls_num = cls_num
+ # None type id: 0, Next Neighboring Word (NNW): 1, Tail Head Word (THW): 2
+ self.none_type_id = none_type_id
+ # input: cls instruction sep text sep pad
+ # mask: 1 2 3 4 5 0
+ self.text_mask_id = text_mask_id
+
+ self.plm = BertModel.from_pretrained(plm_dir)
+ hidden_size = self.plm.config.hidden_size
+ # self.biaffine_size = biaffine_size
+ self.nnw_mat = PointerMatrix(
+ hidden_size, biaffine_size, cls_num=2, dropout=dropout
+ )
+ self.thw_mat = PointerMatrix(
+ hidden_size, biaffine_size, cls_num=2, dropout=dropout
+ )
+ self.criterion = nn.CrossEntropyLoss()
+
+ def input_encoding(self, input_ids, mask):
+ attention_mask = mask.gt(0).float()
+ plm_outputs = self.plm(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ )
+ return plm_outputs.last_hidden_state
+
+ def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ # mask: (batch_size, seq_len)
+ bs, seq_len = mask.shape
+ mask_mat = (
+ mask.eq(self.text_mask_id).unsqueeze(-1).expand((bs, seq_len, seq_len))
+ )
+ # bit_mask: (batch_size, seq_len, seq_len, 1)
+ bit_mask = (
+ torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).long()
+ )
+ return bit_mask
+
+ def forward(self, input_ids, mask, labels=None, is_eval=False, **kwargs):
+ hidden = self.input_encoding(input_ids, mask)
+ nnw_hidden = self.nnw_mat(hidden)
+ thw_hidden = self.thw_mat(hidden)
+ # nnw_hidden = nnw_hidden / self.biaffine_size ** 0.5
+ # thw_hidden = thw_hidden / self.biaffine_size ** 0.5
+ # # (bs, 2, seq_len, seq_len)
+ bs, _, seq_len, seq_len = nnw_hidden.shape
+
+ bit_mask = self.build_bit_mask(mask)
+
+ results = {"logits": {"nnw": nnw_hidden, "thw": thw_hidden}}
+ if labels is not None:
+ # mean
+ nnw_loss = self.criterion(
+ nnw_hidden.permute(0, 2, 3, 1).reshape(-1, 2),
+ labels[:, 0, :, :].reshape(-1),
+ )
+ thw_loss = self.criterion(
+ thw_hidden.permute(0, 2, 3, 1).reshape(-1, 2),
+ labels[:, 1, :, :].reshape(-1),
+ )
+ loss = nnw_loss + thw_loss
+ results["loss"] = loss
+
+ if is_eval:
+ batch_positions = self.decode(nnw_hidden, thw_hidden, bit_mask, **kwargs)
+ results["pred"] = batch_positions
+ return results
+
+ def decode(
+ self,
+ nnw_hidden: torch.Tensor,
+ thw_hidden: torch.Tensor,
+ bit_mask: torch.Tensor,
+ **kwargs,
+ ):
+ # B x L x L
+ nnw_pred = nnw_hidden.argmax(1)
+ thw_pred = thw_hidden.argmax(1)
+ # B x 2 x L x L
+ pred = torch.stack([nnw_pred, thw_pred], dim=1)
+ pred = pred * bit_mask
+
+ batch_preds = decode_nnw_thw_mat(pred, offsets=kwargs.get("offset"))
+
+ return batch_preds
+
+
+class MrcGlobalPointerModel(nn.Module):
+ def __init__(
+ self,
+ plm_dir: str,
+ use_rope: bool = True,
+ cls_num: int = 2,
+ biaffine_size: int = 384,
+ none_type_id: int = 0,
+ text_mask_id: int = 4,
+ dropout: float = 0.3,
+ mode: str = "w2",
+ ):
+ super().__init__()
+
+ # num of predicted classes
+ self.cls_num = cls_num
+ # None type id: 0, Next Neighboring Word (NNW): 1, Tail Head Word (THW): 2
+ self.none_type_id = none_type_id
+ # input: cls instruction sep text sep pad
+ # mask: 1 2 3 4 5 0
+ self.text_mask_id = text_mask_id
+ self.use_rope = use_rope
+
+ # mode: w2: w2ner, cons: consecutive spans
+ self.mode = mode
+ assert self.mode in ["w2", "cons"]
+
+ self.plm = BertModel.from_pretrained(plm_dir)
+ self.hidden_size = self.plm.config.hidden_size
+ self.biaffine_size = biaffine_size
+ self.pointer = PointerMatrix(
+ self.hidden_size,
+ biaffine_size,
+ cls_num=2 if self.mode == "w2" else 1,
+ dropout=dropout,
+ biaffine_bias=True,
+ use_rope=use_rope,
+ )
+
+ def input_encoding(self, input_ids, mask):
+ attention_mask = mask.gt(0).float()
+ plm_outputs = self.plm(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ )
+ return plm_outputs.last_hidden_state
+
+ def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ # mask: (batch_size, seq_len)
+ bs, seq_len = mask.shape
+ mask_mat = (
+ mask.eq(self.text_mask_id).unsqueeze(-1).expand((bs, seq_len, seq_len))
+ )
+ # bit_mask: (batch_size, 1, seq_len, seq_len)
+ bit_mask = (
+ torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).float()
+ )
+ if self.mode == "cons":
+ bit_mask = bit_mask.triu()
+
+ return bit_mask
+
+ def forward(
+ self, input_ids, mask, labels=None, is_eval=False, top_p=0.5, top_k=-1, **kwargs
+ ):
+ bit_mask = self.build_bit_mask(mask)
+ hidden = self.input_encoding(input_ids, mask)
+ # (bs, 2, seq_len, seq_len)
+ logits = self.pointer(hidden)
+ logits = logits * bit_mask - (1.0 - bit_mask) * 1e12
+ logits = logits / (self.biaffine_size**0.5)
+ # # (bs, 2, seq_len, seq_len)
+ bs, cls_num, seq_len, _ = logits.shape
+ if labels is not None:
+ assert labels.shape == (bs, cls_num, seq_len, seq_len)
+
+ results = {"logits": logits}
+ if labels is not None:
+ loss = multilabel_categorical_crossentropy(
+ logits.reshape(bs * cls_num, -1), labels.reshape(bs * cls_num, -1)
+ )
+ loss = loss.mean()
+ results["loss"] = loss
+
+ if is_eval:
+ batch_positions = self.decode(logits, top_p=top_p, top_k=top_k, **kwargs)
+ results["pred"] = batch_positions
+ return results
+
+ def calc_path_prob(self, probs, paths):
+ """
+ Args:
+ probs: (2, seq_len, seq_len) | (1, seq_len, seq_len)
+ paths: a list of paths in tuple
+
+ Returns:
+ [(path: tuple, prob: float), ...]
+ """
+ assert self.mode in ["w2", "cons"]
+ paths_with_prob = []
+ for path in paths:
+ path_prob = 1.0
+ if self.mode == "w2":
+ for se in windowed_queue_iter(path, 2, 1, drop_last=True):
+ path_prob *= probs[0, se[0], se[-1]]
+ path_prob *= probs[1, path[-1], path[0]]
+ elif self.mode == "cons":
+ path_prob = probs[0, path[0], path[-1]]
+ paths_with_prob.append((path, path_prob))
+ return paths_with_prob
+
+ def decode(
+ self,
+ logits: torch.Tensor,
+ top_p: float = 0.5,
+ top_k: int = -1,
+ **kwargs,
+ ):
+ # mode: w2: w2ner with nnw and thw labels, cons: consecutive spans with one type of labels
+ assert self.mode in ["w2", "cons"]
+ # B x 2 x L x L
+ probs = logits.sigmoid()
+ pred = (probs > top_p).long()
+ if self.mode == "w2":
+ preds = decode_nnw_thw_mat(pred, offsets=kwargs.get("offset"))
+ elif self.mode == "cons":
+ pred = pred.triu()
+ preds = decode_pointer_mat(pred, offsets=kwargs.get("offset"))
+
+ if top_k == -1:
+ batch_preds = preds
+ else:
+ batch_preds = []
+ for i, paths in enumerate(preds):
+ paths_with_prob = self.calc_path_prob(probs[i], paths)
+ paths_with_prob.sort(key=lambda pp: pp[1], reverse=True)
+ batch_preds.append([pp[0] for pp in paths_with_prob[:top_k]])
+
+ return batch_preds
+
+
+class SchemaGuidedInstructBertModel(nn.Module):
+ def __init__(
+ self,
+ plm_dir: str,
+ vocab_size: int = None,
+ use_rope: bool = True,
+ biaffine_size: int = 512,
+ label_mask_id: int = 4,
+ text_mask_id: int = 7,
+ dropout: float = 0.3,
+ ):
+ super().__init__()
+
+ # input: [CLS] [I] Instruction [LM] PER [LM] LOC [LM] ORG [TL] Text [B] Background [SEP] [PAD]
+ # mask: 1 2 3 4 5 4 5 4 5 6 7 8 9 10 0
+ self.label_mask_id = label_mask_id
+ self.text_mask_id = text_mask_id
+ self.use_rope = use_rope
+
+ self.plm = AutoModel.from_pretrained(plm_dir)
+ if vocab_size:
+ self.plm.resize_token_embeddings(vocab_size)
+ self.hidden_size = self.plm.config.hidden_size
+ self.biaffine_size = biaffine_size
+ self.pointer = PointerMatrix(
+ self.hidden_size,
+ biaffine_size,
+ cls_num=3,
+ dropout=dropout,
+ biaffine_bias=True,
+ use_rope=use_rope,
+ )
+
+ def input_encoding(self, input_ids, mask):
+ attention_mask = mask.gt(0).float()
+ plm_outputs = self.plm(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ )
+ return plm_outputs.last_hidden_state
+
+ def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ # mask: (batch_size, seq_len)
+ bs, seq_len = mask.shape
+ # _m = torch.logical_or(mask.eq(self.label_mask_id), mask.eq(self.text_mask_id))
+ # mask_mat = _m.unsqueeze(-1).expand((bs, seq_len, seq_len))
+ # # bit_mask: (batch_size, 1, seq_len, seq_len)
+ # bit_mask = (
+ # torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).float()
+ # )
+ bit_mask = (
+ mask.gt(0).unsqueeze(1).unsqueeze(1).expand(bs, 1, seq_len, seq_len).float()
+ )
+
+ return bit_mask
+
+ def forward(
+ self, input_ids, mask, labels=None, is_eval=False, top_p=0.5, top_k=-1, **kwargs
+ ):
+ bit_mask = self.build_bit_mask(mask)
+ hidden = self.input_encoding(input_ids, mask)
+ # (bs, 3, seq_len, seq_len)
+ logits = self.pointer(hidden)
+ logits = logits * bit_mask - (1.0 - bit_mask) * 1e12
+ logits = logits / (self.biaffine_size**0.5)
+ # # (bs, 3, seq_len, seq_len)
+ bs, cls_num, seq_len, _ = logits.shape
+ if labels is not None:
+ assert labels.shape == (bs, cls_num, seq_len, seq_len)
+
+ results = {"logits": logits}
+ if labels is not None:
+ loss = multilabel_categorical_crossentropy(
+ logits.reshape(bs * cls_num, -1), labels.reshape(bs * cls_num, -1)
+ )
+ loss = loss.mean()
+ results["loss"] = loss
+
+ if is_eval:
+ batch_positions = self.decode(logits, top_p=top_p, top_k=top_k, **kwargs)
+ results["pred"] = batch_positions
+ return results
+
+ def calc_path_prob(self, probs, paths):
+ """
+ Args:
+ probs: (2, seq_len, seq_len) | (1, seq_len, seq_len)
+ paths: a list of paths in tuple
+
+ Returns:
+ [(path: tuple, prob: float), ...]
+ """
+ paths_with_prob = []
+ for path in paths:
+ path_prob = 1.0
+ for se in windowed_queue_iter(path, 2, 1, drop_last=True):
+ path_prob *= probs[0, se[0], se[-1]]
+ path_prob *= probs[1, path[-1], path[0]]
+ paths_with_prob.append((path, path_prob))
+ return paths_with_prob
+
+ def decode(
+ self,
+ logits: torch.Tensor,
+ top_p: float = 0.5,
+ top_k: int = -1,
+ # legal_num_parts: tuple = (1, 2, 3),
+ legal_num_parts: tuple = None,
+ labels: torch.Tensor = None,
+ **kwargs,
+ ):
+ # B x 3 x L x L
+ if labels is None:
+ # `labels` is used for upper bound analysis
+ probs = logits.sigmoid()
+ pred = (probs > top_p).long()
+ else:
+ pred = labels
+ preds = decode_nnw_nsw_thw_mat(pred, offsets=kwargs.get("offset"))
+ # for pred, gold in zip(preds, kwargs.get("spans")):
+ # sorted_pred = sorted(set(tuple(x) for x in pred))
+ # sorted_gold = sorted(set(tuple(x) for x in gold))
+ # if sorted_pred != sorted_gold:
+ # breakpoint()
+
+ if top_k == -1:
+ batch_preds = preds
+ else:
+ batch_preds = []
+ for i, paths in enumerate(preds):
+ paths_with_prob = self.calc_path_prob(probs[i], paths)
+ paths_with_prob.sort(key=lambda pp: pp[1], reverse=True)
+ batch_preds.append([pp[0] for pp in paths_with_prob[:top_k]])
+
+ if legal_num_parts is not None:
+ legal_preds = []
+ for ins_paths in batch_preds:
+ legal_paths = []
+ for path in ins_paths:
+ if len(path) in legal_num_parts:
+ legal_paths.append(path)
+ legal_preds.append(legal_paths)
+ else:
+ legal_preds = batch_preds
+
+ return legal_preds
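
A quick shape check for the pointer modules above, using random tensors only (hypothetical sizes, no pretrained weights involved):

    import torch

    from src.model import Biaffine, PointerMatrix

    x = torch.randn(2, 16, 768)  # (batch, seq_len, hidden)
    pointer = PointerMatrix(hidden_size=768, biaffine_size=64, cls_num=3)
    print(pointer(x).shape)  # torch.Size([2, 3, 16, 16])

    biaffine = Biaffine(n_in=64, n_out=2, bias_x=True, bias_y=True)
    h = torch.randn(2, 16, 64)
    print(biaffine(h, h).shape)  # torch.Size([2, 2, 16, 16])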
diff --git a/src/preprocess.py b/src/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/task.py b/src/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..52285fdd5cacbbb2e8c3276343884c111647d89a
--- /dev/null
+++ b/src/task.py
@@ -0,0 +1,590 @@
+import math
+import re
+from collections import defaultdict
+from datetime import datetime
+from typing import List
+
+import torch
+import torch.optim as optim
+from rex import accelerator
+from rex.data.data_manager import DataManager
+from rex.data.dataset import CachedDataset, StreamReadDataset
+from rex.tasks.simple_metric_task import SimpleMetricTask
+from rex.utils.batch import decompose_batch_into_instances
+from rex.utils.config import ConfigParser
+from rex.utils.dict import flatten_dict
+from rex.utils.io import load_jsonlines
+from rex.utils.registry import register
+from torch.utils.tensorboard import SummaryWriter
+from transformers.optimization import (
+ get_cosine_schedule_with_warmup,
+ get_linear_schedule_with_warmup,
+)
+
+from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric
+from .model import (
+ MrcGlobalPointerModel,
+ MrcPointerMatrixModel,
+ SchemaGuidedInstructBertModel,
+)
+from .transform import (
+ CachedLabelPointerTransform,
+ CachedPointerMRCTransform,
+ CachedPointerTaggingTransform,
+)
+
+
+@register("task")
+class MrcTaggingTask(SimpleMetricTask):
+ def __init__(self, config, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+
+ def after_initialization(self):
+ now_string = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+ self.tb_logger: SummaryWriter = SummaryWriter(
+ log_dir=self.task_path / "tb_summary" / now_string,
+ comment=self.config.comment,
+ )
+
+ def after_whole_train(self):
+ self.tb_logger.close()
+
+ def get_grad_norm(self):
+ # for name, param in self.model.named_parameters():
+ # if param.grad is not None:
+ # grads = param.grad.detach().data
+ # grad_norm = (grads.norm(p=2) / grads.numel()).item()
+ total_norm = 0.0
+ for p in self.model.parameters():
+ if p.grad is not None:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm += param_norm.item() ** 2
+ total_norm = total_norm ** (1.0 / 2)
+ return total_norm
+
+ def log_loss(
+ self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str
+ ):
+ self.tb_logger.add_scalar(
+ f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx
+ )
+ # self.tb_logger.add_scalars(
+ # "lr",
+ # {
+ # str(i): self.optimizer.param_groups[i]["lr"]
+ # for i in range(len(self.optimizer.param_groups))
+ # },
+ # idx,
+ # )
+ self.tb_logger.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx)
+ self.tb_logger.add_scalar("grad_norm_total", self.get_grad_norm(), idx)
+
+ def log_metrics(
+ self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str
+ ):
+ metrics = flatten_dict(metrics)
+ self.tb_logger.add_scalars(f"{dataset_name}/{step_or_epoch}", metrics, idx)
+
+ def init_transform(self):
+ return CachedPointerTaggingTransform(
+ self.config.max_seq_len,
+ self.config.plm_dir,
+ self.config.ent_type2query_filepath,
+ mode=self.config.mode,
+ negative_sample_prob=self.config.negative_sample_prob,
+ )
+
+ def init_data_manager(self):
+ return DataManager(
+ self.config.train_filepath,
+ self.config.dev_filepath,
+ self.config.test_filepath,
+ CachedDataset,
+ self.transform,
+ load_jsonlines,
+ self.config.train_batch_size,
+ self.config.eval_batch_size,
+ self.transform.collate_fn,
+ use_stream_transform=False,
+ debug_mode=self.config.debug_mode,
+ dump_cache_dir=self.config.dump_cache_dir,
+ regenerate_cache=self.config.regenerate_cache,
+ )
+
+ def init_model(self):
+ # m = MrcPointerMatrixModel(
+ m = MrcGlobalPointerModel(
+ self.config.plm_dir,
+ biaffine_size=self.config.biaffine_size,
+ dropout=self.config.dropout,
+ mode=self.config.mode,
+ )
+ return m
+
+ def init_metric(self):
+ return MrcNERMetric()
+
+ def init_optimizer(self):
+ no_decay = r"(embedding|LayerNorm|\.bias$)"
+ plm_lr = r"^plm\."
+ non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])"
+
+ param_groups = []
+ for name, param in self.model.named_parameters():
+ lr = self.config.learning_rate
+ weight_decay = self.config.weight_decay
+ if re.search(non_trainable, name):
+ param.requires_grad = False
+ if not re.search(plm_lr, name):
+ lr = self.config.other_learning_rate
+ if re.search(no_decay, name):
+ weight_decay = 0.0
+ param_groups.append(
+ {"params": param, "lr": lr, "weight_decay": weight_decay}
+ )
+ return optim.AdamW(
+ param_groups,
+ lr=self.config.learning_rate,
+ betas=(0.9, 0.98),
+ eps=1e-6,
+ )
+
+ def init_lr_scheduler(self):
+ num_training_steps = int(
+ len(self.data_manager.train_loader)
+ * self.config.num_epochs
+ * accelerator.num_processes
+ )
+ num_warmup_steps = math.floor(
+ num_training_steps * self.config.warmup_proportion
+ )
+ # return get_linear_schedule_with_warmup(
+ return get_cosine_schedule_with_warmup(
+ self.optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ )
+
+ def predict_api(self, texts: List[str], **kwargs):
+ raw_dataset = self.transform.predict_transform(texts)
+ text_ids = sorted(list({ins["id"] for ins in raw_dataset}))
+ loader = self.data_manager.prepare_loader(raw_dataset)
+ # to prepare input device
+ loader = accelerator.prepare_data_loader(loader)
+ id2ents = defaultdict(set)
+ for batch in loader:
+ batch_out = self.model(**batch, is_eval=True)
+ for _id, _pred in zip(batch["id"], batch_out["pred"]):
+ id2ents[_id].update(_pred)
+ results = [id2ents[_id] for _id in text_ids]
+
+ return results
+
+
+@register("task")
+class MrcQaTask(MrcTaggingTask):
+ def init_transform(self):
+ return CachedPointerMRCTransform(
+ self.config.max_seq_len,
+ self.config.plm_dir,
+ mode=self.config.mode,
+ )
+
+ def init_model(self):
+ # m = MrcPointerMatrixModel(
+ m = MrcGlobalPointerModel(
+ self.config.plm_dir,
+ biaffine_size=self.config.biaffine_size,
+ dropout=self.config.dropout,
+ mode=self.config.mode,
+ )
+ return m
+
+ def init_metric(self):
+ return MrcSpanMetric()
+
+ def predict_api(self, data: list[dict], **kwargs):
+ """
+ Args:
+ data: a list of dict with query, context, and background strings
+ """
+ raw_dataset = self.transform.predict_transform(data)
+ loader = self.data_manager.prepare_loader(raw_dataset)
+ results = []
+ for batch in loader:
+ batch_out = self.model(**batch, is_eval=True)
+ batch["pred"] = batch_out["pred"]
+ instances = decompose_batch_into_instances(batch)
+ for ins in instances:
+ preds = ins["pred"]
+ ins_results = []
+ for index_list in preds:
+ ins_result = []
+ for i in index_list:
+ ins_result.append(ins["raw_tokens"][i])
+ ins_results.append(("".join(ins_result), tuple(index_list)))
+ results.append(ins_results)
+
+ return results
+
+
+class StreamReadDatasetWithLen(StreamReadDataset):
+ def __len__(self):
+ # hard-coded dataset size used in stream mode, where len() is still
+ # required (e.g., by the lr scheduler's step estimate)
+ return 631346
+
+
+@register("task")
+class SchemaGuidedInstructBertTask(MrcTaggingTask):
+ # def __init__(self, config, **kwargs) -> None:
+ # super().__init__(config, **kwargs)
+
+ # from watchmen import ClientMode, WatchClient
+
+ # client = WatchClient(
+ # id=config.task_name,
+ # gpus=[4],
+ # req_gpu_num=1,
+ # mode=ClientMode.SCHEDULE,
+ # server_host="127.0.0.1",
+ # server_port=62333,
+ # )
+ # client.wait()
+
+ # def init_lr_scheduler(self):
+ # num_training_steps = int(
+ # 631346 / self.config.train_batch_size
+ # * self.config.num_epochs
+ # * accelerator.num_processes
+ # )
+ # num_warmup_steps = math.floor(
+ # num_training_steps * self.config.warmup_proportion
+ # )
+ # # return get_linear_schedule_with_warmup(
+ # return get_cosine_schedule_with_warmup(
+ # self.optimizer,
+ # num_warmup_steps=num_warmup_steps,
+ # num_training_steps=num_training_steps,
+ # )
+
+ def init_transform(self):
+ self.transform: CachedLabelPointerTransform
+ return CachedLabelPointerTransform(
+ self.config.max_seq_len,
+ self.config.plm_dir,
+ mode=self.config.mode,
+ label_span=self.config.label_span,
+ include_instructions=self.config.get("include_instructions", True),
+ )
+
+ def init_data_manager(self):
+ if self.config.get("stream_mode", False):
+ DatasetClass = StreamReadDatasetWithLen
+ transform = self.transform.transform
+ else:
+ DatasetClass = CachedDataset
+ transform = self.transform
+ return DataManager(
+ self.config.train_filepath,
+ self.config.dev_filepath,
+ self.config.test_filepath,
+ DatasetClass,
+ transform,
+ load_jsonlines,
+ self.config.train_batch_size,
+ self.config.eval_batch_size,
+ self.transform.collate_fn,
+ use_stream_transform=self.config.get("stream_mode", False),
+ debug_mode=self.config.debug_mode,
+ dump_cache_dir=self.config.dump_cache_dir,
+ regenerate_cache=self.config.regenerate_cache,
+ )
+
+ def init_model(self):
+ self.model = SchemaGuidedInstructBertModel(
+ self.config.plm_dir,
+ vocab_size=len(self.transform.tokenizer),
+ use_rope=self.config.use_rope,
+ biaffine_size=self.config.biaffine_size,
+ dropout=self.config.dropout,
+ )
+
+ if self.config.get("base_model_path"):
+ self.load(
+ self.config.base_model_path,
+ load_config=False,
+ load_model=True,
+ load_optimizer=False,
+ load_history=False,
+ )
+ return self.model
+
+ def init_optimizer(self):
+ no_decay = r"(embedding|LayerNorm|\.bias$)"
+ plm_lr = r"^plm\."
+ # non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])"
+ non_trainable = "no_non_trainable"
+
+ param_groups = []
+ for name, param in self.model.named_parameters():
+ lr = self.config.learning_rate
+ weight_decay = self.config.weight_decay
+ if re.search(non_trainable, name):
+ param.requires_grad = False
+ if not re.search(plm_lr, name):
+ lr = self.config.other_learning_rate
+ if re.search(no_decay, name):
+ weight_decay = 0.0
+ param_groups.append(
+ {"params": param, "lr": lr, "weight_decay": weight_decay}
+ )
+ return optim.AdamW(
+ param_groups,
+ lr=self.config.learning_rate,
+ betas=(0.9, 0.98),
+ eps=1e-6,
+ )
+
+ def init_metric(self):
+ return MultiPartSpanMetric()
+
+ def _convert_span_to_string(self, span, token_ids, tokenizer):
+ string = ""
+ if len(span) == 0 or len(span) > 2:
+ pass
+ elif len(span) == 1:
+ string = tokenizer.decode(token_ids[span[0]])
+ elif len(span) == 2:
+ string = tokenizer.decode(token_ids[span[0] : span[1] + 1])
+ return (string, self.reset_position(token_ids, span))
+
+ def reset_position(self, token_ids: list[int], span: list[int]) -> list[int]:
+        if isinstance(token_ids, torch.Tensor):
+            input_ids = token_ids.cpu().tolist()
+        else:
+            input_ids = token_ids
+ if len(span) < 1:
+ return span
+
+ tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids(
+ [self.transform.tp_token, self.transform.tl_token]
+ )
+ offset = 0
+ if tp_token_id in input_ids:
+ offset = input_ids.index(tp_token_id) + 1
+ elif tl_token_id in input_ids:
+ offset = input_ids.index(tl_token_id) + 1
+ return [i - offset for i in span]
+
+ def predict_api(self, data: list[dict], **kwargs):
+ """
+ Args:
+ data: a list of dict in UDI:
+ {
+ "id": str,
+ "instruction": str,
+ "schema": {
+ "ent": list,
+ "rel": list,
+ "event": dict,
+ "cls": list,
+ "discontinuous_ent": list,
+ "hyper_rel": dict
+ },
+ "text": str,
+ "bg": str,
+ "ans": {}, # empty dict
+ }
+ """
+        raw_dataset = [self.transform.transform(d) for d in data]
+        # drop instances the transform rejected (it returns None for them)
+        raw_dataset = [ins for ins in raw_dataset if ins is not None]
+ loader = self.data_manager.prepare_loader(raw_dataset)
+ results = []
+ for batch in loader:
+ batch_out = self.model(**batch, is_eval=True)
+ batch["pred"] = batch_out["pred"]
+ instances = decompose_batch_into_instances(batch)
+ for ins in instances:
+ pred_clses = []
+ pred_ents = []
+ pred_rels = []
+ pred_trigger_to_event = defaultdict(
+ lambda: {"event_type": "", "arguments": []}
+ )
+ pred_events = []
+ pred_spans = []
+ pred_discon_ents = []
+ pred_hyper_rels = []
+ raw_schema = ins["raw"]["schema"]
+ for multi_part_span in ins["pred"]:
+ span = tuple(multi_part_span)
+ span_to_label = ins["span_to_label"]
+ if span[0] in span_to_label:
+ label = span_to_label[span[0]]
+ if label["task"] == "cls" and len(span) == 1:
+ pred_clses.append(label["string"])
+ elif label["task"] == "ent" and len(span) == 2:
+ string = self._convert_span_to_string(
+ span[1], ins["input_ids"], self.transform.tokenizer
+ )
+ pred_ents.append((label["string"], string))
+ elif label["task"] == "rel" and len(span) == 3:
+ head = self._convert_span_to_string(
+ span[1], ins["input_ids"], self.transform.tokenizer
+ )
+ tail = self._convert_span_to_string(
+ span[2], ins["input_ids"], self.transform.tokenizer
+ )
+ pred_rels.append((label["string"], head, tail))
+ elif label["task"] == "event":
+ if label["type"] == "lm" and len(span) == 2:
+ pred_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip
+ elif label["type"] == "lr" and len(span) == 3:
+ arg = self._convert_span_to_string(
+ span[2], ins["input_ids"], self.transform.tokenizer
+ )
+ pred_trigger_to_event[span[1]]["arguments"].append(
+ {"argument": arg, "role": label["string"]}
+ )
+ elif label["task"] == "discontinuous_ent" and len(span) > 1:
+ parts = [
+ self._convert_span_to_string(
+ part, ins["input_ids"], self.transform.tokenizer
+ )
+ for part in span[1:]
+ ]
+                            string = " ".join(part[0] for part in parts)
+                            # the part positions were already reset inside
+                            # _convert_span_to_string, so collect them as-is
+                            position = [part[1] for part in parts]
+                            pred_discon_ents.append(
+                                (label["string"], string, position)
+                            )
+ elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip
+ q_label = span_to_label[span[3]]
+ span_1 = self._convert_span_to_string(
+ span[1], ins["input_ids"], self.transform.tokenizer
+ )
+ span_2 = self._convert_span_to_string(
+ span[2], ins["input_ids"], self.transform.tokenizer
+ )
+ span_4 = self._convert_span_to_string(
+ span[4], ins["input_ids"], self.transform.tokenizer
+ )
+ pred_hyper_rels.append((label["string"], span_1, span_2, q_label["string"], span_4)) # fmt: skip
+ else:
+ # span task has no labels
+ pred_token_ids = []
+ for part in span:
+ _pred_token_ids = [ins["input_ids"][i] for i in part]
+ pred_token_ids.extend(_pred_token_ids)
+ span_string = self.transform.tokenizer.decode(pred_token_ids)
+ pred_spans.append(
+ (
+ span_string,
+ tuple(
+ [
+ tuple(
+ self.reset_position(
+ ins["input_ids"].cpu().tolist(), part
+ )
+ )
+ for part in span
+ ]
+ ),
+ )
+ )
+                for trigger_span, item in pred_trigger_to_event.items():
+                    trigger = self._convert_span_to_string(
+                        trigger_span, ins["input_ids"], self.transform.tokenizer
+                    )
+ if item["event_type"] not in raw_schema["event"]:
+ continue
+ legal_roles = raw_schema["event"][item["event_type"]]
+ pred_events.append(
+ {
+ "trigger": trigger,
+ "event_type": item["event_type"],
+ "arguments": [
+ arg
+ for arg in filter(
+ lambda arg: arg["role"] in legal_roles,
+ item["arguments"],
+ )
+ ],
+ }
+ )
+ results.append(
+ {
+ "id": ins["raw"]["id"],
+ "results": {
+ "cls": pred_clses,
+ "ent": pred_ents,
+ "rel": pred_rels,
+ "event": pred_events,
+ "span": pred_spans,
+ "discon_ent": pred_discon_ents,
+ "hyper_rel": pred_hyper_rels,
+ },
+ }
+ )
+
+ return results
+
+
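+# A minimal, hedged usage sketch of `predict_api` (the task directory and the
+# schema values below are hypothetical; the input dict follows the UDI format
+# documented in the method's docstring):
+#
+# task = SchemaGuidedInstructBertTask.from_taskdir(
+#     "outputs/SomeTaskDir", initialize=True, load_config=True, dump_configfile=False
+# )
+# data = [{
+#     "id": "demo.0",
+#     "instruction": "Please extract entities from the text.",
+#     "schema": {"cls": [], "ent": ["person", "location"], "rel": [], "event": {}},
+#     "text": "Alice moved to Paris.",
+#     "bg": "",
+#     "ans": {},
+# }]
+# print(task.predict_api(data))
+
+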
+if __name__ == "__main__":
+ pass
+ # further_finetune()
+
+ # from rex.utils.config import ConfigParser
+
+ # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/ner.yaml"])
+ # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/mirror-ace05en.yaml"])
+
+ # task = MrcTaggingTask(
+ # config,
+ # initialize=True,
+ # makedirs=True,
+ # dump_configfile=True,
+ # )
+ # task = SchemaGuidedInstructBertTask.from_taskdir(
+ # "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_Rel",
+ # initialize=True,
+ # load_config=True,
+ # dump_configfile=False,
+ # )
+ # task = SchemaGuidedInstructBertTask(
+ # config,
+ # initialize=True,
+ # makedirs=True,
+ # dump_configfile=False,
+ # )
+ # task.load(
+ # "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_NerRelEvent/ckpt/SchemaGuidedInstructBertModel.epoch.0.pth",
+ # load_config=False,
+ # )
+ # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval")
+ # task.load(
+ # # "outputs/Mirror_RobertaBaseWwm_Cons_MsraMrc/ckpt/MrcGlobalPointerModel.best.pth",
+ # # "outputs/Mirror_RobertaBaseWwm_W2_MsraMrc_HyperParamExp1/ckpt/MrcGlobalPointerModel.best.pth",
+ # config.base_model_path,
+ # load_config=False,
+ # load_model=True,
+ # load_optimizer=False,
+ # load_history=False,
+ # )
+ # task.train()
+ # task = MrcTaggingTask.from_taskdir(
+ # "outputs/Mirror_W2_MSRAv2_NER",
+ # initialize=True,
+ # dump_configfile=False,
+ # load_config=True,
+ # )
+ # for name, _ in task.model.named_parameters():
+ # print(name)
+ # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval.0.1")
+
+ # task = MrcQaTask(
+ # config,
+ # initialize=True,
+ # makedirs=True,
+ # dump_configfile=True,
+ # )
+ # task.train()
+ # task.eval("dev", verbose=True, dump=True, dump_middle=True, postfix="re_eval")
diff --git a/src/transform.py b/src/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fdd5371922bb75d5193aeac2f2926d65a4f436
--- /dev/null
+++ b/src/transform.py
@@ -0,0 +1,693 @@
+import random
+import re
+from collections import defaultdict
+from typing import Iterable, Iterator, List, MutableSet, Optional, Tuple, TypeVar, Union
+
+import torch
+import torch.nn.functional as F
+from rex.data.collate_fn import GeneralCollateFn
+from rex.data.transforms.base import CachedTransformBase, CachedTransformOneBase
+from rex.metrics import calc_p_r_f1_from_tp_fp_fn
+from rex.utils.io import load_json
+from rex.utils.iteration import windowed_queue_iter
+from rex.utils.logging import logger
+from transformers import AutoTokenizer
+from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
+from transformers.models.deberta_v2.tokenization_deberta_v2_fast import (
+ DebertaV2TokenizerFast,
+)
+from transformers.tokenization_utils_base import BatchEncoding
+
+from src.utils import (
+ decode_nnw_nsw_thw_mat,
+ decode_nnw_thw_mat,
+ encode_nnw_nsw_thw_mat,
+ encode_nnw_thw_mat,
+)
+
+Filled = TypeVar("Filled")
+
+
+class PaddingMixin:
+ max_seq_len: int
+
+    def pad_seq(
+        self, batch_seqs: List[List[Filled]], fill: Filled
+    ) -> List[List[Filled]]:
+ max_len = max(len(seq) for seq in batch_seqs)
+ assert max_len <= self.max_seq_len
+ for i in range(len(batch_seqs)):
+ batch_seqs[i] = batch_seqs[i] + [fill] * (max_len - len(batch_seqs[i]))
+ return batch_seqs
+
+ def pad_mat(
+ self, mats: List[torch.Tensor], fill: Union[int, float]
+ ) -> List[torch.Tensor]:
+ max_len = max(mat.shape[0] for mat in mats)
+ assert max_len <= self.max_seq_len
+ for i in range(len(mats)):
+ num_add = max_len - mats[i].shape[0]
+ mats[i] = F.pad(
+ mats[i], (0, 0, 0, num_add, 0, num_add), mode="constant", value=fill
+ )
+ return mats
+
+
+class PointerTransformMixin:
+ tokenizer: BertTokenizerFast
+ max_seq_len: int
+ space_token: str = "[unused1]"
+
+ def build_ins(
+ self,
+ query_tokens: list[str],
+ context_tokens: list[str],
+ answer_indexes: list[list[int]],
+ add_context_tokens: list[str] = None,
+ ) -> Tuple:
+        # -3: [CLS], the [SEP] after the query, and the trailing [SEP]
+        reserved_seq_len = self.max_seq_len - 3 - len(query_tokens)
+ # reserve at least 20 tokens
+ if reserved_seq_len < 20:
+ raise ValueError(
+ f"Query {query_tokens} too long: {len(query_tokens)} "
+ f"while max seq len is {self.max_seq_len}"
+ )
+
+ input_tokens = [self.tokenizer.cls_token]
+ input_tokens += query_tokens
+ input_tokens += [self.tokenizer.sep_token]
+ offset = len(input_tokens)
+ input_tokens += context_tokens[:reserved_seq_len]
+ available_token_range = range(
+ offset, offset + len(context_tokens[:reserved_seq_len])
+ )
+ input_tokens += [self.tokenizer.sep_token]
+
+ add_context_len = 0
+ max_add_context_len = self.max_seq_len - len(input_tokens) - 1
+ add_context_flag = False
+ if add_context_tokens and len(add_context_tokens) > 0:
+ add_context_flag = True
+ add_context_len = len(add_context_tokens[:max_add_context_len])
+ input_tokens += add_context_tokens[:max_add_context_len]
+ input_tokens += [self.tokenizer.sep_token]
+ new_tokens = []
+ for t in input_tokens:
+ if len(t.strip()) > 0:
+ new_tokens.append(t)
+ else:
+ new_tokens.append(self.space_token)
+ input_tokens = new_tokens
+ input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
+
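+        # `mask` doubles as a segment-id sequence rather than a plain 0/1
+        # attention mask: 1=[CLS], 2=query, 3=[SEP], 4=context, 5=[SEP],
+        # 6=additional (background) context, 7=[SEP], 0=padding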
+ mask = [1]
+ mask += [2] * len(query_tokens)
+ mask += [3]
+ mask += [4] * len(context_tokens[:reserved_seq_len])
+ mask += [5]
+ if add_context_flag:
+ mask += [6] * add_context_len
+ mask += [7]
+ assert len(mask) == len(input_ids) <= self.max_seq_len
+
+ available_spans = [tuple(i + offset for i in index) for index in answer_indexes]
+ available_spans = list(
+ filter(
+ lambda index: all(i in available_token_range for i in index),
+ available_spans,
+ )
+ )
+
+ token_len = len(input_ids)
+ pad_len = self.max_seq_len - token_len
+ input_tokens += pad_len * [self.tokenizer.pad_token]
+ input_ids += pad_len * [self.tokenizer.pad_token_id]
+ mask += pad_len * [0]
+
+ return input_tokens, input_ids, mask, offset, available_spans
+
+ def update_labels(self, data: dict) -> dict:
+ bs = len(data["input_ids"])
+ seq_len = self.max_seq_len
+ labels = torch.zeros((bs, 2, seq_len, seq_len))
+ for i, batch_spans in enumerate(data["available_spans"]):
+ # offset = data["offset"][i]
+ # pad_len = data["mask"].count(0)
+ # token_len = seq_len - pad_len
+ for span in batch_spans:
+ if len(span) == 1:
+ labels[i, :, span[0], span[0]] = 1
+ else:
+ for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
+ labels[i, 0, s, e] = 1
+ labels[i, 1, span[-1], span[0]] = 1
+ # labels[i, :, 0:offset, :] = -100
+ # labels[i, :, :, 0:offset] = -100
+ # labels[i, :, :, token_len:] = -100
+ # labels[i, :, token_len:, :] = -100
+ data["labels"] = labels
+ return data
+
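+    # A hedged illustration of the (2, seq_len, seq_len) label layout built in
+    # `update_labels` above, for one span covering token positions (3, 4, 5):
+    #   labels[i, 0, 3, 4] = labels[i, 0, 4, 5] = 1   # NNW: word-to-next-word links
+    #   labels[i, 1, 5, 3] = 1                        # THW: tail-to-head link
+    # a single-token span (3,) instead sets labels[i, :, 3, 3] = 1
+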
+ def update_consecutive_span_labels(self, data: dict) -> dict:
+ bs = len(data["input_ids"])
+ seq_len = self.max_seq_len
+ labels = torch.zeros((bs, 1, seq_len, seq_len))
+ for i, batch_spans in enumerate(data["available_spans"]):
+ for span in batch_spans:
+ assert span == tuple(sorted(set(span)))
+ if len(span) == 1:
+ labels[i, 0, span[0], span[0]] = 1
+ else:
+ labels[i, 0, span[0], span[-1]] = 1
+ data["labels"] = labels
+ return data
+
+
+class CachedPointerTaggingTransform(CachedTransformBase, PointerTransformMixin):
+ def __init__(
+ self,
+ max_seq_len: int,
+ plm_dir: str,
+ ent_type2query_filepath: str,
+ mode: str = "w2",
+ negative_sample_prob: float = 1.0,
+ ) -> None:
+ super().__init__()
+
+ self.max_seq_len: int = max_seq_len
+ self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir)
+ self.ent_type2query: dict = load_json(ent_type2query_filepath)
+ self.negative_sample_prob = negative_sample_prob
+
+ self.collate_fn: GeneralCollateFn = GeneralCollateFn(
+ {
+ "input_ids": torch.long,
+ "mask": torch.long,
+ "labels": torch.long,
+ },
+ guessing=False,
+ missing_key_as_null=True,
+ )
+ if mode == "w2":
+ self.collate_fn.update_before_tensorify = self.update_labels
+ elif mode == "cons":
+ self.collate_fn.update_before_tensorify = (
+ self.update_consecutive_span_labels
+ )
+ else:
+ raise ValueError(f"Mode: {mode} not recognizable")
+
+ def transform(
+ self,
+ transform_loader: Iterator,
+ dataset_name: str = None,
+ **kwargs,
+ ) -> Iterable:
+ final_data = []
+ # tp = fp = fn = 0
+ for data in transform_loader:
+ ent_type2ents = defaultdict(set)
+ for ent in data["ents"]:
+ ent_type2ents[ent["type"]].add(tuple(ent["index"]))
+ for ent_type in self.ent_type2query:
+ gold_ents = ent_type2ents[ent_type]
+ if (
+ len(gold_ents) < 1
+ and dataset_name == "train"
+ and random.random() > self.negative_sample_prob
+ ):
+ # skip negative samples
+ continue
+ # res = self.build_ins(ent_type, data["tokens"], gold_ents)
+ query = self.ent_type2query[ent_type]
+ query_tokens = self.tokenizer.tokenize(query)
+ try:
+ res = self.build_ins(query_tokens, data["tokens"], gold_ents)
+ except (ValueError, AssertionError):
+ continue
+ input_tokens, input_ids, mask, offset, available_spans = res
+ ins = {
+ "id": data.get("id", str(len(final_data))),
+ "ent_type": ent_type,
+ "gold_ents": gold_ents,
+ "raw_tokens": data["tokens"],
+ "input_tokens": input_tokens,
+ "input_ids": input_ids,
+ "mask": mask,
+ "offset": offset,
+ "available_spans": available_spans,
+ # labels are dynamically padded in collate fn
+ "labels": None,
+ # "labels": labels.tolist(),
+ }
+ final_data.append(ins)
+
+ # # upper bound analysis
+ # pred_spans = set(decode_nnw_thw_mat(labels.unsqueeze(0))[0])
+ # g_ents = set(available_spans)
+ # tp += len(g_ents & pred_spans)
+ # fp += len(pred_spans - g_ents)
+ # fn += len(g_ents - pred_spans)
+
+ # # upper bound results
+ # measures = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
+ # logger.info(f"Upper Bound: {measures}")
+
+ return final_data
+
+ def predict_transform(self, texts: List[str]):
+ dataset = []
+ for text_id, text in enumerate(texts):
+ data_id = f"Prediction#{text_id}"
+ tokens = self.tokenizer.tokenize(text)
+ dataset.append(
+ {
+ "id": data_id,
+ "tokens": tokens,
+ "ents": [],
+ }
+ )
+ final_data = self(dataset, disable_pbar=True)
+ return final_data
+
+
+class CachedPointerMRCTransform(CachedTransformBase, PointerTransformMixin):
+ def __init__(
+ self,
+ max_seq_len: int,
+ plm_dir: str,
+ mode: str = "w2",
+ ) -> None:
+ super().__init__()
+
+ self.max_seq_len: int = max_seq_len
+ self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir)
+
+ self.collate_fn: GeneralCollateFn = GeneralCollateFn(
+ {
+ "input_ids": torch.long,
+ "mask": torch.long,
+ "labels": torch.long,
+ },
+ guessing=False,
+ missing_key_as_null=True,
+ )
+
+ if mode == "w2":
+ self.collate_fn.update_before_tensorify = self.update_labels
+ elif mode == "cons":
+ self.collate_fn.update_before_tensorify = (
+ self.update_consecutive_span_labels
+ )
+ else:
+ raise ValueError(f"Mode: {mode} not recognizable")
+
+ def transform(
+ self,
+ transform_loader: Iterator,
+ dataset_name: str = None,
+ **kwargs,
+ ) -> Iterable:
+ final_data = []
+ for data in transform_loader:
+ try:
+ res = self.build_ins(
+ data["query_tokens"],
+ data["context_tokens"],
+ data["answer_index"],
+ data.get("background_tokens"),
+ )
+ except (ValueError, AssertionError):
+ continue
+ input_tokens, input_ids, mask, offset, available_spans = res
+ ins = {
+ "id": data.get("id", str(len(final_data))),
+ "gold_spans": sorted(set(tuple(x) for x in data["answer_index"])),
+ "raw_tokens": data["context_tokens"],
+ "input_tokens": input_tokens,
+ "input_ids": input_ids,
+ "mask": mask,
+ "offset": offset,
+ "available_spans": available_spans,
+ "labels": None,
+ }
+ final_data.append(ins)
+
+ return final_data
+
+ def predict_transform(self, data: list[dict]):
+ """
+ Args:
+ data: a list of dict with query, context, and background strings
+ """
+ dataset = []
+        for idx, ins in enumerate(data):
+            data_id = f"Prediction#{idx}"
+            dataset.append(
+                {
+                    "id": data_id,
+                    "query_tokens": list(ins["query"]),
+                    "context_tokens": list(ins["context"]),
+                    # `background` is optional, so fall back to an empty string
+                    "background_tokens": list(ins.get("background") or ""),
+                    "answer_index": [],
+                }
+            )
+ final_data = self(dataset, disable_pbar=True, num_samples=0)
+ return final_data
+
+
+class CachedLabelPointerTransform(CachedTransformOneBase):
+ """Transform for label-token linking for skip consecutive spans"""
+
+ def __init__(
+ self,
+ max_seq_len: int,
+ plm_dir: str,
+ mode: str = "w2",
+ label_span: str = "tag",
+ include_instructions: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ self.max_seq_len: int = max_seq_len
+ self.mode = mode
+ self.label_span = label_span
+ self.include_instructions = include_instructions
+
+ self.tokenizer: DebertaV2TokenizerFast = DebertaV2TokenizerFast.from_pretrained(
+ plm_dir
+ )
+ self.lc_token = "[LC]"
+ self.lm_token = "[LM]"
+ self.lr_token = "[LR]"
+ self.i_token = "[I]"
+ self.tl_token = "[TL]"
+ self.tp_token = "[TP]"
+ self.b_token = "[B]"
+ num_added = self.tokenizer.add_tokens(
+ [
+ self.lc_token,
+ self.lm_token,
+ self.lr_token,
+ self.i_token,
+ self.tl_token,
+ self.tp_token,
+ self.b_token,
+ ]
+ )
+ assert num_added == 7
+
+ self.collate_fn: GeneralCollateFn = GeneralCollateFn(
+ {
+ "input_ids": torch.long,
+ "mask": torch.long,
+ "labels": torch.long,
+ "spans": None,
+ },
+ guessing=False,
+ missing_key_as_null=True,
+ # only for pre-training
+ discard_missing=False,
+ )
+
+ self.collate_fn.update_before_tensorify = self.skip_consecutive_span_labels
+
+ def transform(self, instance: dict, **kwargs):
+ # input
+ tokens = [self.tokenizer.cls_token]
+ mask = [1]
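+        # `mask` is a segment-id sequence: 1=[CLS], 2=[I], 3=instruction,
+        # 4=label tag ([LC]/[LM]/[LR]), 5=label text, 6=[TL]/[TP], 7=text,
+        # 8=[B], 9=background, 10=[SEP]; 0 is used for padding in collate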
+ label_map = {"lc": {}, "lm": {}, "lr": {}}
+ # (2, 3): {"type": "lc", "task": "cls/ent/rel/event/hyper_rel/discontinuous_ent", "string": ""}
+ span_to_label = {}
+
+ def _update_seq(
+ label: str,
+ label_type: str,
+ task: str = "",
+ label_mask: int = 4,
+ content_mask: int = 5,
+ ):
+ if label not in label_map[label_type]:
+ label_token_map = {
+ "lc": self.lc_token,
+ "lm": self.lm_token,
+ "lr": self.lr_token,
+ }
+ label_tag_start_idx = len(tokens)
+ tokens.append(label_token_map[label_type])
+ mask.append(label_mask)
+ label_tag_end_idx = len(tokens) - 1 # exact end position
+ label_tokens = self.tokenizer(label, add_special_tokens=False).tokens()
+ label_content_start_idx = len(tokens)
+ tokens.extend(label_tokens)
+ mask.extend([content_mask] * len(label_tokens))
+ label_content_end_idx = len(tokens) - 1 # exact end position
+
+ if self.label_span == "tag":
+ start_idx = label_tag_start_idx
+ end_idx = label_tag_end_idx
+ elif self.label_span == "content":
+ start_idx = label_content_start_idx
+ end_idx = label_content_end_idx
+ else:
+ raise ValueError(f"label_span={self.label_span} is not supported")
+
+ if end_idx == start_idx:
+ label_map[label_type][label] = (start_idx,)
+ else:
+ label_map[label_type][label] = (start_idx, end_idx)
+ span_to_label[label_map[label_type][label]] = {
+ "type": label_type,
+ "task": task,
+ "string": label,
+ }
+ return label_map[label_type][label]
+
+ if self.include_instructions:
+ instruction = instance.get("instruction")
+ if not instruction:
+ logger.warning(
+ "include_instructions=True, while the instruction is empty!"
+ )
+ else:
+ instruction = ""
+ if instruction:
+ tokens.append(self.i_token)
+ mask.append(2)
+ instruction_tokens = self.tokenizer(
+ instruction, add_special_tokens=False
+ ).tokens()
+ tokens.extend(instruction_tokens)
+ mask.extend([3] * len(instruction_tokens))
+ types = instance["schema"].get("cls")
+ if types:
+ for t in types:
+ _update_seq(t, "lc", task="cls")
+ mention_types = instance["schema"].get("ent")
+ if mention_types:
+ for mt in mention_types:
+ _update_seq(mt, "lm", task="ent")
+ discon_ent_types = instance["schema"].get("discontinuous_ent")
+ if discon_ent_types:
+ for mt in discon_ent_types:
+ _update_seq(mt, "lm", task="discontinuous_ent")
+ rel_types = instance["schema"].get("rel")
+ if rel_types:
+ for rt in rel_types:
+ _update_seq(rt, "lr", task="rel")
+ hyper_rel_schema = instance["schema"].get("hyper_rel")
+ if hyper_rel_schema:
+ for rel, qualifiers in hyper_rel_schema.items():
+ _update_seq(rel, "lr", task="hyper_rel")
+ for qualifier in qualifiers:
+ _update_seq(qualifier, "lr", task="hyper_rel")
+ event_schema = instance["schema"].get("event")
+ if event_schema:
+ for event_type, roles in event_schema.items():
+ _update_seq(event_type, "lm", task="event")
+ for role in roles:
+ _update_seq(role, "lr", task="event")
+
+ text = instance.get("text")
+ if text:
+ text_tokenized = self.tokenizer(
+ text, return_offsets_mapping=True, add_special_tokens=False
+ )
+ if any(val for val in label_map.values()):
+ text_label_token = self.tl_token
+ else:
+ text_label_token = self.tp_token
+ tokens.append(text_label_token)
+ mask.append(6)
+ remain_token_len = self.max_seq_len - 1 - len(tokens)
+ if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train":
+ return None
+ text_off = len(tokens)
+ text_tokens = text_tokenized.tokens()[:remain_token_len]
+ tokens.extend(text_tokens)
+ mask.extend([7] * len(text_tokens))
+ else:
+ text_tokenized = None
+
+ bg = instance.get("bg")
+ if bg:
+ bg_tokenized = self.tokenizer(
+ bg, return_offsets_mapping=True, add_special_tokens=False
+ )
+ tokens.append(self.b_token)
+ mask.append(8)
+ remain_token_len = self.max_seq_len - 1 - len(tokens)
+ if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train":
+ return None
+ bg_tokens = bg_tokenized.tokens()[:remain_token_len]
+ tokens.extend(bg_tokens)
+ mask.extend([9] * len(bg_tokens))
+ else:
+ bg_tokenized = None
+
+ tokens.append(self.tokenizer.sep_token)
+ mask.append(10)
+
+ # labels
+        # spans: [[(label start, label end), (ent start, ent end)]]; ends are inclusive
+ spans = [] # one span may have many parts
+ if "cls" in instance["ans"]:
+ for t in instance["ans"]["cls"]:
+ part = label_map["lc"][t]
+ spans.append([part])
+ if "ent" in instance["ans"]:
+ for ent in instance["ans"]["ent"]:
+ label_part = label_map["lm"][ent["type"]]
+ position_seq = self.char_to_token_span(
+ ent["span"], text_tokenized, text_off
+ )
+ spans.append([label_part, position_seq])
+ if "discontinuous_ent" in instance["ans"]:
+ for ent in instance["ans"]["discontinuous_ent"]:
+ label_part = label_map["lm"][ent["type"]]
+ ent_span = [label_part]
+ for part in ent["span"]:
+ position_seq = self.char_to_token_span(
+ part, text_tokenized, text_off
+ )
+ ent_span.append(position_seq)
+ spans.append(ent_span)
+ if "rel" in instance["ans"]:
+ for rel in instance["ans"]["rel"]:
+ label_part = label_map["lr"][rel["relation"]]
+ head_position_seq = self.char_to_token_span(
+ rel["head"]["span"], text_tokenized, text_off
+ )
+ tail_position_seq = self.char_to_token_span(
+ rel["tail"]["span"], text_tokenized, text_off
+ )
+ spans.append([label_part, head_position_seq, tail_position_seq])
+ if "hyper_rel" in instance["ans"]:
+ for rel in instance["ans"]["hyper_rel"]:
+ label_part = label_map["lr"][rel["relation"]]
+ head_position_seq = self.char_to_token_span(
+ rel["head"]["span"], text_tokenized, text_off
+ )
+ tail_position_seq = self.char_to_token_span(
+ rel["tail"]["span"], text_tokenized, text_off
+ )
+ # rel_span = [label_part, head_position_seq, tail_position_seq]
+ for q in rel["qualifiers"]:
+ q_label_part = label_map["lr"][q["label"]]
+ q_position_seq = self.char_to_token_span(
+ q["span"], text_tokenized, text_off
+ )
+ spans.append(
+ [
+ label_part,
+ head_position_seq,
+ tail_position_seq,
+ q_label_part,
+ q_position_seq,
+ ]
+ )
+ if "event" in instance["ans"]:
+ for event in instance["ans"]["event"]:
+ event_type_label_part = label_map["lm"][event["event_type"]]
+ trigger_position_seq = self.char_to_token_span(
+ event["trigger"]["span"], text_tokenized, text_off
+ )
+ trigger_part = [event_type_label_part, trigger_position_seq]
+ spans.append(trigger_part)
+ for arg in event["args"]:
+ role_label_part = label_map["lr"][arg["role"]]
+ arg_position_seq = self.char_to_token_span(
+ arg["span"], text_tokenized, text_off
+ )
+ arg_part = [role_label_part, trigger_position_seq, arg_position_seq]
+ spans.append(arg_part)
+ if "span" in instance["ans"]:
+ # Extractive-QA or Extractive-MRC tasks
+ for span in instance["ans"]["span"]:
+ span_position_seq = self.char_to_token_span(
+ span["span"], text_tokenized, text_off
+ )
+ spans.append([span_position_seq])
+
+ if self.mode == "w2":
+ new_spans = []
+ for parts in spans:
+ new_parts = []
+ for part in parts:
+ new_parts.append(tuple(range(part[0], part[-1] + 1)))
+ new_spans.append(new_parts)
+ spans = new_spans
+ elif self.mode == "span":
+ spans = spans
+ else:
+ raise ValueError(f"mode={self.mode} is not supported")
+
+ ins = {
+ "raw": instance,
+ "tokens": tokens,
+ "input_ids": self.tokenizer.convert_tokens_to_ids(tokens),
+ "mask": mask,
+ "spans": spans,
+ "label_map": label_map,
+ "span_to_label": span_to_label,
+ "labels": None, # labels are calculated dynamically in collate_fn
+ }
+ return ins
+
+    def char_to_token_span(
+        self, span: list[int], tokenized: BatchEncoding, offset: int = 0
+    ) -> tuple[int, ...]:
+ token_s = tokenized.char_to_token(span[0])
+ token_e = tokenized.char_to_token(span[1] - 1)
+ if token_e == token_s:
+ position_seq = (offset + token_s,)
+ else:
+ position_seq = (offset + token_s, offset + token_e)
+ return position_seq
+
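+    # Hedged illustration: if the char span [0, 7) ("machine" in
+    # "machine learning") maps to a single token, char_to_token_span([0, 7],
+    # tokenized, offset=10) returns (10,); if it covers tokens 0..1, it
+    # returns (10, 11).
+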
+ def skip_consecutive_span_labels(self, data: dict) -> dict:
+ bs = len(data["input_ids"])
+ max_seq_len = max(len(input_ids) for input_ids in data["input_ids"])
+ batch_seq_len = min(self.max_seq_len, max_seq_len)
+ for i in range(bs):
+ data["input_ids"][i] = data["input_ids"][i][:batch_seq_len]
+ data["mask"][i] = data["mask"][i][:batch_seq_len]
+ assert len(data["input_ids"][i]) == len(data["mask"][i])
+ pad_len = batch_seq_len - len(data["mask"][i])
+ data["input_ids"][i] = (
+ data["input_ids"][i] + [self.tokenizer.pad_token_id] * pad_len
+ )
+ data["mask"][i] = data["mask"][i] + [0] * pad_len
+ data["labels"][i] = encode_nnw_nsw_thw_mat(data["spans"][i], batch_seq_len)
+
+ # # for debugging only
+ # pred_spans = decode_nnw_nsw_thw_mat(data["labels"][i].unsqueeze(0))[0]
+ # sorted_gold = sorted(set(tuple(x) for x in data["spans"][i]))
+ # sorted_pred = sorted(set(tuple(x) for x in pred_spans))
+ # if sorted_gold != sorted_pred:
+ # breakpoint()
+
+ # # for pre-training only
+ # del data["spans"]
+
+ return data
diff --git a/src/udi/__init__.py b/src/udi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..535ca0ac8e86273765bceec6ac221c3bb9b39f0f
--- /dev/null
+++ b/src/udi/__init__.py
@@ -0,0 +1,38 @@
+# udi-v1: universal data interface
+{
+ "id": "semeval.train.0",
+ "instruction": "instruction text",
+ "schema": {
+ "cls": ["class1", "class2"],
+ "ent": ["person", "location"],
+ "rel": ["birth in", "study in"],
+ "event": {
+ "event type (attack)": ["roles like instrument", "attacker"],
+ "another type": ["role", "role"],
+ },
+ },
+ "ans": {
+ "cls": ["class1"],
+ "ent": [
+ {"type": "person", "text": "1234", "span": [0, 4]}
+ ], # span: [start, end + 1]
+ "rel": [
+ {
+ "relation": "study in",
+ "head": {"text": "1234", "span": [0, 4]},
+ "tail": {"text": "1234", "span": [5, 9]},
+ }
+ ],
+ "event": [
+ {
+ "event_type": "attack",
+ "trigger": {"text": "hit", "span": [6, 9]},
+ "args": [{"role": "instrument", "text": "ax", "span": [8, 10]}],
+ }
+ ],
+ "span": [{"text": "machine learning", "span": [16, 32]}],
+ },
+ # DONE: whether or not to concatenate instruction with text (v2)
+ "text": "plain text",
+ "bg": "background text",
+}
diff --git a/src/udi/check.py b/src/udi/check.py
new file mode 100644
index 0000000000000000000000000000000000000000..07a15e03c1452e66a743abeb495eb4e1ea956afa
--- /dev/null
+++ b/src/udi/check.py
@@ -0,0 +1,139 @@
+from rex.utils.io import load_jsonlines
+
+
+def check_udi_instance(instance: dict):
+ assert isinstance(instance["id"], str)
+ assert isinstance(instance["instruction"], str)
+ assert isinstance(instance["schema"], dict)
+ for key in instance["schema"]:
+ assert key in ["cls", "ent", "rel", "event"]
+ if key in ["cls", "ent", "rel"]:
+ assert isinstance(instance["schema"][key], list) and all(
+ isinstance(x, str) for x in instance["schema"][key]
+ )
+ elif key == "event":
+ assert isinstance(instance["schema"][key], dict)
+ for event_type in instance["schema"][key]:
+ assert isinstance(instance["schema"][key][event_type], list) and all(
+ isinstance(x, str) for x in instance["schema"][key][event_type]
+ )
+ else:
+ raise ValueError
+ assert isinstance(instance["ans"], dict)
+ for key in instance["ans"]:
+ assert key in ["cls", "ent", "rel", "event", "span"]
+ if key == "cls":
+ assert isinstance(instance["ans"][key], list) and all(
+ isinstance(x, str) for x in instance["ans"][key]
+ )
+ elif key == "ent":
+ assert isinstance(instance["ans"][key], list) and all(
+ isinstance(x, dict) for x in instance["ans"][key]
+ )
+ for ent in instance["ans"][key]:
+ assert (
+ isinstance(ent["type"], str)
+ and ent["type"] in instance["schema"]["ent"]
+ )
+ assert (
+ isinstance(ent["text"], str)
+ and instance["text"][ent["span"][0] : ent["span"][1]] == ent["text"]
+ )
+ assert (
+ isinstance(ent["span"], list)
+ and len(ent["span"]) == 2
+ and all(isinstance(x, int) for x in ent["span"])
+ )
+ elif key == "rel":
+ assert isinstance(instance["ans"][key], list) and all(
+ isinstance(x, dict) for x in instance["ans"][key]
+ )
+ for rel in instance["ans"][key]:
+ assert (
+ isinstance(rel["relation"], str)
+ and rel["relation"] in instance["schema"]["rel"]
+ )
+ assert (
+ isinstance(rel["head"], dict)
+ and instance["text"][
+ rel["head"]["span"][0] : rel["head"]["span"][1]
+ ]
+ == rel["head"]["text"]
+ )
+ assert (
+ isinstance(rel["tail"], dict)
+ and instance["text"][
+ rel["tail"]["span"][0] : rel["tail"]["span"][1]
+ ]
+ == rel["tail"]["text"]
+ )
+ elif key == "event":
+ assert isinstance(instance["ans"][key], list) and all(
+ isinstance(x, dict) for x in instance["ans"][key]
+ )
+ for event in instance["ans"][key]:
+ assert event["event_type"] in instance["schema"]["event"]
+ assert (
+ isinstance(event["trigger"], dict)
+ and event["trigger"]["text"] in instance["text"]
+ and instance["text"][
+ event["trigger"]["span"][0] : event["trigger"]["span"][1]
+ ]
+ == event["trigger"]["text"]
+ )
+ for arg in event["args"]:
+ assert (
+ arg["role"] in instance["schema"]["event"][event["event_type"]]
+ )
+ assert (
+ isinstance(arg["text"], str)
+ and instance["text"][arg["span"][0] : arg["span"][1]]
+ == arg["text"]
+ )
+ elif key == "span":
+ assert isinstance(instance["ans"][key], list) and all(
+ isinstance(x, dict) for x in instance["ans"][key]
+ )
+ for span in instance["ans"][key]:
+ assert (
+ isinstance(span["text"], str)
+ and instance["text"][span["span"][0] : span["span"][1]]
+ == span["text"]
+ )
+ else:
+ raise ValueError
+ assert isinstance(instance["text"], str)
+ assert isinstance(instance["bg"], str)
+ for key in ["ent", "rel", "event"]:
+ if instance["schema"].get(key):
+ assert len(instance["text"]) > 0
+ if "span" in instance["ans"]:
+ assert len(instance["text"]) > 0
+ assert instance["instruction"] or instance["text"] or instance["bg"]
+
+
+def is_valid_udi_instance(instance: dict):
+ ok = True
+ try:
+ check_udi_instance(instance)
+    except Exception:
+ ok = False
+ return ok
+
+
+def main():
+ filepaths = []
+ for filepath in filepaths:
+ data = load_jsonlines(filepath)
+ data_ok = True
+ for ins in data:
+ ok = is_valid_udi_instance(ins)
+ if not ok:
+ data_ok = False
+ break
+ if not data_ok:
+ print(filepath)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7f5516fdea1457db640645847db0a842dfda398
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,320 @@
+from collections import defaultdict
+
+import torch
+from rex.utils.iteration import windowed_queue_iter
+from rex.utils.position import find_all_positions
+
+
+def find_paths_from_adj_mat(adj_mat: torch.Tensor) -> list[tuple[int]]:
+ assert adj_mat.shape[0] == adj_mat.shape[1] and len(adj_mat.shape) == 2
+
+ paths = []
+ self_loops = set()
+ adj_map = defaultdict(set)
+ rev_adj_map = defaultdict(set)
+ # current -> next
+ for c, n in adj_mat.detach().nonzero().tolist():
+ # self-loop
+ if c == n:
+ self_loops.add(c)
+ else:
+ adj_map[c].add(n)
+ # reversed map
+ rev_adj_map[n].add(c)
+ for self_loop_node in self_loops:
+ paths.append((self_loop_node,))
+
+ def track(path: tuple[int], c: int):
+ visited: set[tuple[int]] = set()
+ stack = [(path, c)]
+ while stack:
+ path, c = stack.pop()
+ if c in adj_map:
+ for n in adj_map[c]:
+ if (c, n) in visited:
+ continue
+ visited.add((c, n))
+ stack.append((path + (c,), n))
+            # emit the path ending at the current node; prefixes of longer
+            # paths are emitted on purpose, since the THW decoders match
+            # (end, start) pairs against every candidate path
+            if path:
+                paths.append(path + (c,))
+
+ # def track(path: tuple[int], c: int, visited: set[tuple[int]]):
+ # if c in adj_map:
+ # for n in adj_map[c]:
+ # if (c, n) in visited:
+ # continue
+ # visited.add((c, n))
+ # track(path + (c,), n, visited)
+ # else:
+ # if path:
+ # paths.append(path + (c,))
+
+ # # # include loops
+ # # if path not in paths and all(not set(path).issubset(p) for p in paths):
+ # # paths.append(path)
+
+ start_nodes = set(adj_map.keys()) - set(rev_adj_map.keys())
+ for c in start_nodes:
+ ns = adj_map[c]
+ for n in ns:
+ track((c,), n)
+
+ return paths
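+
+# Hedged example (toy values). Note that prefixes of a longer path are emitted
+# as well; the THW decoders below rely on this to match shorter spans that
+# share a start node:
+# >>> adj = torch.zeros(5, 5)
+# >>> adj[1, 2] = adj[2, 3] = 1   # chain 1 -> 2 -> 3
+# >>> adj[4, 4] = 1               # self-loop
+# >>> sorted(find_paths_from_adj_mat(adj))
+# [(1, 2), (1, 2, 3), (4,)]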
+
+
+def encode_nnw_thw_mat(
+ spans: list[tuple[int]], seq_len: int, nnw_id: int = 0, thw_id: int = 1
+) -> torch.Tensor:
+ mat = torch.zeros(2, seq_len, seq_len)
+ for span in spans:
+ if len(span) == 1:
+ mat[:, span[0], span[0]] = 1
+ else:
+ for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
+ mat[nnw_id, s, e] = 1
+ mat[thw_id, span[-1], span[0]] = 1
+ return mat
+
+
+def decode_nnw_thw_mat(
+ batch_mat: torch.LongTensor,
+ nnw_id: int = 0,
+ thw_id: int = 1,
+ offsets: list[int] = None,
+) -> list[list[tuple[int]]]:
+ """Decode NNW THW matrix into a list of spans
+
+ Args:
+ matrix: (batch_size, 2, seq_len, seq_len)
+ """
+ ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
+ assert seq_len1 == seq_len2
+ assert cls_num == 2
+
+ result_batch = []
+ for ins_id in range(ins_num):
+ offset = offsets[ins_id] if offsets else 0
+ ins_span_paths = []
+ # ins_mat: (2, seq_len, seq_len)
+ ins_mat = batch_mat[ins_id]
+ nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
+ end_start_to_paths = defaultdict(set)
+ for path in nnw_paths:
+ end_start_to_paths[(path[-1], path[0])].add(path)
+ thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()
+ # reversed match, end -> start
+ for e, s in thw_pairs:
+ for path in end_start_to_paths[(e, s)]:
+ ins_span_paths.append(tuple(i - offset for i in path))
+ result_batch.append(ins_span_paths)
+
+ return result_batch
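+
+# Hedged round-trip sketch for the encode/decode pair above (the decoded order
+# follows the order of nonzero THW entries, not the input order):
+# >>> mat = encode_nnw_thw_mat([(2, 3, 4), (7,)], seq_len=10)
+# >>> decode_nnw_thw_mat(mat.unsqueeze(0))[0]
+# [(2, 3, 4), (7,)]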
+
+
+def decode_pointer_mat(
+ batch_mat: torch.LongTensor, offsets: list[int] = None
+) -> list[list[tuple[int]]]:
+ batch_paths = []
+ for i in range(len(batch_mat)):
+ offset = offsets[i] if offsets else 0
+ coordinates = (batch_mat[i, 0] == 1).nonzero().tolist()
+ paths = []
+ for s, e in coordinates:
+ path = tuple(range(s - offset, e + 1 - offset))
+ paths.append(path)
+ batch_paths.append(paths)
+ return batch_paths
+
+
+def encode_nnw_nsw_thw_mat(
+ spans: list[list[tuple[int]]],
+ seq_len: int,
+ nnw_id: int = 0,
+ nsw_id: int = 1,
+ thw_id: int = 2,
+) -> torch.Tensor:
+ mat = torch.zeros(3, seq_len, seq_len)
+ for parts in spans:
+ span = ()
+ for p_i, part in enumerate(parts):
+ if not all(0 <= el <= seq_len - 1 for el in part):
+ continue
+ span += part
+ if p_i < len(parts) - 1 and 0 <= parts[p_i + 1][0] <= seq_len - 1:
+ # current part to next part
+ mat[nsw_id, parts[p_i][-1], parts[p_i + 1][0]] = 1
+ if len(span) == 1:
+ mat[:, span[0], span[0]] = 1
+ elif len(span) > 1:
+ for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
+ mat[nnw_id, s, e] = 1
+ if span:
+ mat[thw_id, span[-1], span[0]] = 1
+ return mat
+
+
+def split_tuple_by_positions(nums, positions) -> list:
+ """
+ Examples:
+ >>> nums = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ >>> positions = [2, 5, 7]
+ >>> split_tuple_by_positions(nums, positions)
+    [(1, 2), (3, 4, 5), (6, 7), (8, 9, 10)]
+ """
+ # Check if the given positions are valid
+ if not all(p < len(nums) for p in positions):
+ raise ValueError("Invalid positions")
+
+ # Add 0 and len(nums) to the list of positions
+ positions = [0] + sorted(positions) + [len(nums)]
+
+ # Split the tuple into multiple tuples based on the positions
+ result = []
+ for i in range(1, len(positions)):
+ start = positions[i - 1]
+ end = positions[i]
+ result.append(nums[start:end])
+
+ return result
+
+
+def decode_nnw_nsw_thw_mat(
+ batch_mat: torch.LongTensor,
+ nnw_id: int = 0,
+ nsw_id: int = 1,
+ thw_id: int = 2,
+ offsets: list[int] = None,
+) -> list[list[tuple[int]]]:
+ """Decode NNW NSW THW matrix into a list of spans
+ One span has multiple parts
+
+ Args:
+ batch_mat: (batch_size, 3, seq_len, seq_len)
+ """
+ ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
+ assert seq_len1 == seq_len2
+ assert cls_num == 3
+
+ result_batch = []
+ for ins_id in range(ins_num):
+ offset = offsets[ins_id] if offsets else 0
+ ins_span_paths = set()
+        # ins_mat: (3, seq_len, seq_len)
+ ins_mat = batch_mat[ins_id]
+ nsw_connections = {
+ (part1e, part2s)
+ for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()
+ }
+ nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
+ end_start_to_paths = defaultdict(set)
+ for path in nnw_paths:
+ end_start_to_paths[(path[-1], path[0])].add(path)
+ thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()
+ # reversed match, end -> start
+        for e, s in thw_pairs:
+            for path in nnw_paths:
+                if s in path:
+                    sub_path = path[path.index(s) :]
+                    if e not in sub_path:
+                        # the path never reaches the tail word, so this THW
+                        # link cannot close it into a span
+                        continue
+                    sub_path = sub_path[: sub_path.index(e) + 1]
+ chain = tuple(i - offset for i in sub_path)
+ parts = []
+ all_sep_positions = set()
+ # cut path into multiple spans if there are skip links
+ if len(chain) > 1:
+ for sep in nsw_connections:
+ sep = tuple(i - offset for i in sep)
+ positions = find_all_positions(list(chain), list(sep))
+ if positions:
+ # +1: (5, 6, 269) with (6, 269) as sep, found position is 1,
+ # while we want to split after 6, which needs +1
+ positions = {p[0] + 1 for p in positions}
+ all_sep_positions.update(positions)
+ parts = split_tuple_by_positions(chain, all_sep_positions)
+ if not parts:
+ parts = [chain]
+ ins_span_paths.add(tuple(parts))
+ result_batch.append(list(ins_span_paths))
+
+ return result_batch
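+
+# Hedged round-trip sketch for a discontinuous span: its two parts are bridged
+# through the NSW plane and closed by a single THW link:
+# >>> mat = encode_nnw_nsw_thw_mat([[(2, 3), (6, 7)]], seq_len=10)
+# >>> decode_nnw_nsw_thw_mat(mat.unsqueeze(0))[0]
+# [((2, 3), (6, 7))]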
+
+
+# def encode_nnw_nsw_thw_mat(
+# spans: list[list[tuple[int]]],
+# seq_len: int,
+# nnw_id: int = 0,
+# nsw_id: int = 1,
+# thw_id: int = 2,
+# ) -> torch.Tensor:
+# mat = torch.zeros(3, seq_len, seq_len)
+# for span in spans:
+# for p_i, part in enumerate(span):
+# if len(part) == 1:
+# mat[:, part[0], part[0]] = 1
+# else:
+# for s, e in windowed_queue_iter(part, 2, 1, drop_last=True):
+# mat[nnw_id, s, e] = 1
+# if p_i < len(span) - 1:
+# # current part to next part
+# mat[nsw_id, span[p_i][-1], span[p_i + 1][0]] = 1
+# mat[thw_id, span[-1][-1], span[0][0]] = 1
+# return mat
+
+
+# def decode_nnw_nsw_thw_mat(
+# batch_mat: torch.LongTensor,
+# nnw_id: int = 0,
+# nsw_id: int = 1,
+# thw_id: int = 2,
+# offsets: list[int] = None,
+# ) -> list[list[tuple[int]]]:
+# """Decode NNW NSW THW matrix into a list of spans
+# One span has multiple parts
+
+# Args:
+# batch_mat: (batch_size, 3, seq_len, seq_len)
+# """
+
+# ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
+# assert seq_len1 == seq_len2
+# assert cls_num == 2
+
+# result_batch = []
+# for ins_id in range(ins_num):
+# offset = offsets[ins_id] if offsets else 0
+# ins_span_paths = []
+# # ins_mat: (3, seq_len, seq_len)
+# ins_mat = batch_mat[ins_id]
+# nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
+
+# path_index = {"s": defaultdict(set), "e": defaultdict(set)}
+# for path in nnw_paths:
+# s = path[0]
+# e = path[-1]
+# path_index["s"][s].add(path)
+# path_index["e"][e].add(path)
+
+# nsw_connections = {(part1e, part2s) for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()}
+# thw_connections = {(span_e, span_s) for span_e, span_s in ins_mat[thw_id, ...].detach().nonzero().tolist()}
+# for e, s in thw_connections:
+
+
+# path_span_combinations = []
+# for part1_e, part2_s in nsw_connections:
+# part1s = path_index["e"][part1_e]
+# part2s = path_index["s"][part2_s]
+# # for part1 in part1s:
+# # for part2 in part2s:
+# # if ()
+
+# end_start_to_paths = defaultdict(set)
+# for path in nnw_paths:
+# end_start_to_paths[(path[-1], path[0])].add(path)
+# # reversed match, end -> start
+# for e, s in thw_pairs:
+# for path in end_start_to_paths[(e, s)]:
+# ins_span_paths.append(tuple(i - offset for i in path))
+# result_batch.append(ins_span_paths)
+
+# return result_batch
diff --git a/src/wait.py b/src/wait.py
new file mode 100644
index 0000000000000000000000000000000000000000..adb992ce66bb549841b24b0492b73ccbf41ab32d
--- /dev/null
+++ b/src/wait.py
@@ -0,0 +1,47 @@
+import argparse
+import random
+import string
+import sys
+
+from watchmen import WatchClient
+
+
+def parse_args(in_args=None):
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument("--task_name", type=str, required=True, help="Take Name")
+ arg_parser.add_argument("--cuda", type=str, required=True, help="cuda to be waited")
+ arg_parser.add_argument(
+ "--req_gpu_num",
+ type=int,
+ required=False,
+ default=1,
+ help="request number of gpus",
+ )
+ arg_parser.add_argument(
+ "--wait",
+ choices=["schedule", "queue", "none"],
+ default="none",
+ help="scheduling/queue wait",
+ )
+ arg_info = arg_parser.parse_args(args=in_args)
+ return arg_info
+
+
+if __name__ == "__main__":
+ in_argv = parse_args()
+ if in_argv.wait == "none":
+ sys.exit(0)
+ random_id = "-" + "".join(random.sample(string.ascii_letters + string.digits, 8))
+ exp_id = in_argv.task_name + random_id
+ watch_client = WatchClient(
+ id=exp_id,
+        # parse the comma-separated GPU ids instead of eval-ing user input
+        gpus=[int(x) for x in in_argv.cuda.split(",")],
+ server_host="localhost",
+ server_port=62333,
+ req_gpu_num=in_argv.req_gpu_num,
+ mode=in_argv.wait,
+ timeout=60,
+ )
+ available_gpus = watch_client.wait()
+ available_gpus = [str(x) for x in available_gpus]
+ print(",".join(available_gpus))
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000000000000000000000000000000000000..b8d0f8abf145d51bc3d27a024717321e1ece6bbc
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,12 @@
+[flake8]
+ignore=
+ # line length
+ E501,
+ # whitespace before ':'
+ E203,
+ # line break before binary operator
+ W503,
+ # import but not used
+ F401
+exclude=
+ debug.py