CarlosMalaga committed
Commit 3376207
Parent(s): e883357
Delete relik
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- relik/__init__.py +0 -1
- relik/common/__init__.py +0 -0
- relik/common/log.py +0 -97
- relik/common/upload.py +0 -128
- relik/common/utils.py +0 -609
- relik/inference/__init__.py +0 -0
- relik/inference/annotator.py +0 -428
- relik/inference/data/__init__.py +0 -0
- relik/inference/data/objects.py +0 -64
- relik/inference/data/tokenizers/__init__.py +0 -89
- relik/inference/data/tokenizers/base_tokenizer.py +0 -84
- relik/inference/data/tokenizers/regex_tokenizer.py +0 -73
- relik/inference/data/tokenizers/spacy_tokenizer.py +0 -228
- relik/inference/data/tokenizers/whitespace_tokenizer.py +0 -70
- relik/inference/data/window/__init__.py +0 -0
- relik/inference/data/window/manager.py +0 -262
- relik/inference/gerbil.py +0 -254
- relik/inference/preprocessing.py +0 -4
- relik/inference/serve/__init__.py +0 -0
- relik/inference/serve/backend/__init__.py +0 -0
- relik/inference/serve/backend/relik.py +0 -210
- relik/inference/serve/backend/retriever.py +0 -206
- relik/inference/serve/backend/utils.py +0 -29
- relik/inference/serve/frontend/__init__.py +0 -0
- relik/inference/serve/frontend/relik.py +0 -231
- relik/inference/serve/frontend/style.css +0 -33
- relik/reader/__init__.py +0 -0
- relik/reader/conf/config.yaml +0 -14
- relik/reader/conf/data/base.yaml +0 -21
- relik/reader/conf/data/re.yaml +0 -54
- relik/reader/conf/training/base.yaml +0 -12
- relik/reader/conf/training/re.yaml +0 -12
- relik/reader/data/__init__.py +0 -0
- relik/reader/data/patches.py +0 -51
- relik/reader/data/relik_reader_data.py +0 -965
- relik/reader/data/relik_reader_data_utils.py +0 -51
- relik/reader/data/relik_reader_sample.py +0 -49
- relik/reader/lightning_modules/__init__.py +0 -0
- relik/reader/lightning_modules/relik_reader_pl_module.py +0 -50
- relik/reader/lightning_modules/relik_reader_re_pl_module.py +0 -54
- relik/reader/pytorch_modules/__init__.py +0 -0
- relik/reader/pytorch_modules/base.py +0 -248
- relik/reader/pytorch_modules/hf/__init__.py +0 -2
- relik/reader/pytorch_modules/hf/configuration_relik.py +0 -33
- relik/reader/pytorch_modules/hf/modeling_relik.py +0 -981
- relik/reader/pytorch_modules/optim/__init__.py +0 -6
- relik/reader/pytorch_modules/optim/adamw_with_warmup.py +0 -66
- relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py +0 -104
- relik/reader/pytorch_modules/span.py +0 -367
- relik/reader/relik_reader.py +0 -629
relik/__init__.py
DELETED
@@ -1 +0,0 @@
from relik.retriever.pytorch_modules.model import GoldenRetriever
relik/common/__init__.py
DELETED
File without changes
relik/common/log.py
DELETED
@@ -1,97 +0,0 @@
import logging
import sys
import threading
from typing import Optional

from rich import get_console

_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

_default_log_level = logging.WARNING

# fancy logger
_console = get_console()


def _get_library_name() -> str:
    return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:
    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_default_log_level)
        library_root_logger.propagate = False


def _reset_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None


def set_log_level(level: int, logger: logging.Logger = None) -> None:
    """
    Set the log level.
    Args:
        level (:obj:`int`):
            Logging level.
        logger (:obj:`logging.Logger`):
            Logger to set the log level.
    """
    if not logger:
        _configure_library_root_logger()
        logger = _get_library_root_logger()
    logger.setLevel(level)


def get_logger(
    name: Optional[str] = None,
    level: Optional[int] = None,
    formatter: Optional[str] = None,
) -> logging.Logger:
    """
    Return a logger with the specified name.
    """

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()

    if level is not None:
        set_log_level(level)

    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        )
    _default_handler.setFormatter(formatter)

    return logging.getLogger(name)


def get_console_logger():
    return _console
relik/common/upload.py
DELETED
@@ -1,128 +0,0 @@
import argparse
import json
import logging
import os
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

import huggingface_hub

from relik.common.log import get_logger
from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5

logger = get_logger(level=logging.DEBUG)


def create_info_file(tmpdir: Path):
    logger.debug("Computing md5 of model.zip")
    md5 = get_md5(tmpdir / "model.zip")
    date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)

    logger.debug("Dumping info.json file")
    with (tmpdir / "info.json").open("w") as f:
        json.dump(dict(md5=md5, upload_date=date), f, indent=2)


def zip_run(
    dir_path: Union[str, os.PathLike],
    tmpdir: Union[str, os.PathLike],
    zip_name: str = "model.zip",
) -> Path:
    logger.debug(f"zipping {dir_path} to {tmpdir}")
    # creates a zip version of the provided dir_path
    run_dir = Path(dir_path)
    zip_path = tmpdir / zip_name

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        # fully zip the run directory maintaining its structure
        for file in run_dir.rglob("*.*"):
            if file.is_dir():
                continue

            zip_file.write(file, arcname=file.relative_to(run_dir))

    return zip_path


def upload(
    model_dir: Union[str, os.PathLike],
    model_name: str,
    organization: Optional[str] = None,
    repo_name: Optional[str] = None,
    commit: Optional[str] = None,
    archive: bool = False,
):
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        print(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )
        return

    repo_id = repo_name or model_name
    if organization is not None:
        repo_id = f"{organization}/{repo_id}"
    with tempfile.TemporaryDirectory() as tmpdir:
        api = huggingface_hub.HfApi()
        repo_url = api.create_repo(
            token=token,
            repo_id=repo_id,
            exist_ok=True,
        )
        repo = huggingface_hub.Repository(
            str(tmpdir), clone_from=repo_url, use_auth_token=token
        )

        tmp_path = Path(tmpdir)
        if archive:
            # otherwise we zip the model_dir
            logger.debug(f"Zipping {model_dir} to {tmp_path}")
            zip_run(model_dir, tmp_path)
            create_info_file(tmp_path)
        else:
            # if the user wants to upload a transformers model, we don't need to zip it
            # we just need to copy the files to the tmpdir
            logger.debug(f"Copying {model_dir} to {tmpdir}")
            os.system(f"cp -r {model_dir}/* {tmpdir}")

        # this method automatically puts large files (>10MB) into git lfs
        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir", help="The directory of the model you want to upload"
    )
    parser.add_argument("model_name", help="The model you want to upload")
    parser.add_argument(
        "--organization",
        help="the name of the organization where you want to upload the model",
    )
    parser.add_argument(
        "--repo_name",
        help="Optional name to use when uploading to the HuggingFace repository",
    )
    parser.add_argument(
        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
    )
    parser.add_argument(
        "--archive",
        action="store_true",
        help="""
            Whether to compress the model directory before uploading it.
            If True, the model directory will be zipped and the zip file will be uploaded.
            If False, the model directory will be uploaded as is.""",
    )
    return parser.parse_args()


def main():
    upload(**vars(parse_args()))


if __name__ == "__main__":
    main()
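A hypothetical call into the deleted upload helper, mirroring its signature above; the paths, model name, and organization are illustrative only:

from relik.common.upload import upload

upload(
    model_dir="./outputs/my-model",  # hypothetical local directory with the model files
    model_name="my-model",
    organization="sapienzanlp",      # hypothetical organization on the HuggingFace Hub
    archive=True,                    # zip the directory and add info.json before pushing
)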
relik/common/utils.py
DELETED
@@ -1,609 +0,0 @@
import importlib.util
import json
import logging
import os
import shutil
import tarfile
import tempfile
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import huggingface_hub
import requests
import tqdm
from filelock import FileLock
from transformers.utils.hub import cached_file as hf_cached_file

from relik.common.log import get_logger

# name constants
WEIGHTS_NAME = "weights.pt"
ONNX_WEIGHTS_NAME = "weights.onnx"
CONFIG_NAME = "config.yaml"
LABELS_NAME = "labels.json"

# SAPIENZANLP_USER_NAME = "sapienzanlp"
SAPIENZANLP_USER_NAME = "riccorl"
SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
    f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
)
# path constants
SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"


logger = get_logger(__name__)


def sapienzanlp_model_urls(model_id: str) -> str:
    """
    Returns the URL for a possible SapienzaNLP valid model.

    Args:
        model_id (:obj:`str`):
            A SapienzaNLP model id.

    Returns:
        :obj:`str`: The url for the model id.
    """
    # check if there is already the namespace of the user
    if "/" in model_id:
        return model_id
    return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)


def is_package_available(package_name: str) -> bool:
    """
    Check if a package is available.

    Args:
        package_name (`str`): The name of the package to check.
    """
    return importlib.util.find_spec(package_name) is not None


def load_json(path: Union[str, Path]) -> Any:
    """
    Load a json file provided in input.

    Args:
        path (`Union[str, Path]`): The path to the json file to load.

    Returns:
        `Any`: The loaded json file.
    """
    with open(path, encoding="utf8") as f:
        return json.load(f)


def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
    """
    Dump input to json file.

    Args:
        document (`Any`): The document to dump.
        path (`Union[str, Path]`): The path to dump the document to.
        indent (`Optional[int]`): The indent to use for the json file.

    """
    with open(path, "w", encoding="utf8") as outfile:
        json.dump(document, outfile, indent=indent)


def get_md5(path: Path):
    """
    Get the MD5 value of a path.
    """
    import hashlib

    with path.open("rb") as fin:
        data = fin.read()
    return hashlib.md5(data).hexdigest()


def file_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the file at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the file exists.
    """
    return Path(path).exists()


def dir_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the directory at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the directory exists.
    """
    return Path(path).is_dir()


def is_remote_url(url_or_filename: Union[str, Path]):
    """
    Returns :obj:`True` if the input path is an url.

    Args:
        url_or_filename (:obj:`str`, :obj:`Path`):
            path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.

    """
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def url_to_filename(resource: str, etag: str = None) -> str:
    """
    Convert a `resource` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the resources's, delimited
    by a period.
    """
    resource_bytes = resource.encode("utf-8")
    resource_hash = sha256(resource_bytes)
    filename = resource_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    return filename


def download_resource(
    url: str,
    temp_file: BinaryIO,
    headers=None,
):
    """
    Download remote file.
    """

    if headers is None:
        headers = {}

    r = requests.get(url, stream=True, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        desc="Downloading",
        disable=logger.level in [logging.NOTSET],
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def download_and_cache(
    url: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
):
    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR
    if isinstance(url, Path):
        url = str(url)

    # check if cache dir exists
    Path(cache_dir).mkdir(parents=True, exist_ok=True)

    # check if file is private
    headers = {}
    try:
        r = requests.head(url, allow_redirects=False, timeout=10)
        r.raise_for_status()
    except requests.exceptions.HTTPError:
        if r.status_code == 401:
            hf_token = huggingface_hub.HfFolder.get_token()
            if hf_token is None:
                raise ValueError(
                    "You need to login to HuggingFace to download this model "
                    "(use the `huggingface-cli login` command)"
                )
            headers["Authorization"] = f"Bearer {hf_token}"

    etag = None
    try:
        r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
        r.raise_for_status()
        etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
        # We favor a custom header indicating the etag of the linked resource, and
        # we fallback to the regular etag header.
        # If we don't have any of those, raise an error.
        if etag is None:
            raise OSError(
                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
            )
        # In case of a redirect,
        # save an extra redirect on the request.get call,
        # and ensure we download the exact atomic version even if it changed
        # between the HEAD and the GET (unlikely, but hey).
        if 300 <= r.status_code <= 399:
            url = r.headers["Location"]
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # Actually raise for those subclasses of ConnectionError
        raise
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # Otherwise, our Internet connection is down.
        # etag is None
        pass

    # get filename from the url
    filename = url_to_filename(url, etag)
    # get cache path to put the file
    cache_path = cache_dir / filename

    # the file is already here, return it
    if file_exists(cache_path) and not force_download:
        logger.info(
            f"{url} found in cache, set `force_download=True` to force the download"
        )
        return cache_path

    cache_path = str(cache_path)
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if file_exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
        )

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
            )
            download_resource(url, temp_file, headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url}  # , "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path


def download_from_hf(
    path_or_repo_id: Union[str, Path],
    filenames: Optional[List[str]],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
):
    if isinstance(path_or_repo_id, Path):
        path_or_repo_id = str(path_or_repo_id)

    downloaded_paths = []
    for filename in filenames:
        downloaded_path = hf_cached_file(
            path_or_repo_id,
            filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            revision=revision,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        downloaded_paths.append(downloaded_path)

    # we want the folder where the files are downloaded
    # the best guess is the parent folder of the first file
    probably_the_folder = Path(downloaded_paths[0]).parent
    return probably_the_folder


def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
    """
    Resolve a model name or directory to a model archive name or directory.

    Args:
        model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
            A model name or directory.

    Returns:
        :obj:`str`: The model archive name or directory.
    """
    if is_remote_url(model_name_or_dir):
        # if model_name_or_dir is a URL
        # download it and try to load
        model_archive = model_name_or_dir
    elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
        # if model_name_or_dir is a local directory or
        # an archive file try to load it
        model_archive = model_name_or_dir
    else:
        # probably model_name_or_dir is a sapienzanlp model id
        # guess the url and try to download
        model_name_or_dir_ = model_name_or_dir
        # raise ValueError(f"Providing a model id is not supported yet.")
        model_archive = sapienzanlp_model_urls(model_name_or_dir_)

    return model_archive


def from_cache(
    url_or_filename: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    filenames: Optional[List[str]] = None,
) -> Path:
    """
    Given something that could be either a local path or a URL (or a SapienzaNLP model id),
    determine which one and return a path to the corresponding file.

    Args:
        url_or_filename (:obj:`str` or :obj:`Path`):
            A path to a local file or a URL (or a SapienzaNLP model id).
        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
            Path to a directory in which a downloaded file will be cached.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to re-download the file even if it already exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (:obj:`Union[bool, str]`, `optional`):
            Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
            :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to raise an error if the file to be downloaded is local.
        subfolder (:obj:`str`, `optional`):
            In case the relevant file is in a subfolder of the URL, specify it here.
        filenames (:obj:`List[str]`, `optional`):
            List of filenames to look for in the directory structure.

    Returns:
        :obj:`Path`: Path to the cached file.
    """

    url_or_filename = model_name_or_path_resolver(url_or_filename)

    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR

    if file_exists(url_or_filename):
        logger.info(f"{url_or_filename} is a local path or file")
        output_path = url_or_filename
    elif is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = download_and_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
        )
    else:
        if filenames is None:
            filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
        output_path = download_from_hf(
            url_or_filename,
            filenames,
            cache_dir,
            force_download,
            resume_download,
            proxies,
            use_auth_token,
            revision,
            local_files_only,
            subfolder,
        )

    # if is_hf_hub_url(url_or_filename):
    #     # HuggingFace Hub
    #     output_path = hf_hub_download_url(url_or_filename)
    # elif is_remote_url(url_or_filename):
    #     # URL, so get it from the cache (downloading if necessary)
    #     output_path = download_and_cache(
    #         url_or_filename,
    #         cache_dir=cache_dir,
    #         force_download=force_download,
    #     )
    # elif file_exists(url_or_filename):
    #     logger.info(f"{url_or_filename} is a local path or file")
    #     # File, and it exists.
    #     output_path = url_or_filename
    # elif urlparse(url_or_filename).scheme == "":
    #     # File, but it doesn't exist.
    #     raise EnvironmentError(f"file {url_or_filename} not found")
    # else:
    #     # Something unknown
    #     raise ValueError(
    #         f"unable to parse {url_or_filename} as a URL or as a local path"
    #     )

    if dir_exists(output_path) or (
        not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
    ):
        return Path(output_path)

    # Path where we extract compressed archives
    # for now it will extract it in the same folder
    # maybe implement extraction in the sapienzanlp folder
    # when using local archive path?
    logger.info("Extracting compressed archive")
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    # already extracted, do not extract
    if (
        os.path.isdir(output_path_extracted)
        and os.listdir(output_path_extracted)
        and not force_download
    ):
        return Path(output_path_extracted)

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
                zip_file.close()
        elif tarfile.is_tarfile(output_path):
            tar_file = tarfile.open(output_path)
            tar_file.extractall(output_path_extracted)
            tar_file.close()
        else:
            raise EnvironmentError(
                f"Archive format of {output_path} could not be identified"
            )

        # remove lock file, is it safe?
        os.remove(lock_path)

    return Path(output_path_extracted)


def is_str_a_path(maybe_path: str) -> bool:
    """
    Check if a string is a path.

    Args:
        maybe_path (`str`): The string to check.

    Returns:
        `bool`: `True` if the string is a path, `False` otherwise.
    """
    # first check if it is a path
    if Path(maybe_path).exists():
        return True
    # check if it is a relative path
    if Path(os.path.join(os.getcwd(), maybe_path)).exists():
        return True
    # otherwise it is not a path
    return False


def relative_to_absolute_path(path: str) -> os.PathLike:
    """
    Convert a relative path to an absolute path.

    Args:
        path (`str`): The relative path to convert.

    Returns:
        `os.PathLike`: The absolute path.
    """
    if not is_str_a_path(path):
        raise ValueError(f"{path} is not a path")
    if Path(path).exists():
        return Path(path).absolute()
    if Path(os.path.join(os.getcwd(), path)).exists():
        return Path(os.path.join(os.getcwd(), path)).absolute()
    raise ValueError(f"{path} is not a path")


def to_config(object_to_save: Any) -> Dict[str, Any]:
    """
    Convert an object to a dictionary.

    Returns:
        `Dict[str, Any]`: The dictionary representation of the object.
    """

    def obj_to_dict(obj):
        match obj:
            case dict():
                data = {}
                for k, v in obj.items():
                    data[k] = obj_to_dict(v)
                return data

            case list() | tuple():
                return [obj_to_dict(x) for x in obj]

            case object(__dict__=_):
                data = {
                    "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
                }
                for k, v in obj.__dict__.items():
                    if not k.startswith("_"):
                        data[k] = obj_to_dict(v)
                return data

            case _:
                return obj

    return obj_to_dict(object_to_save)


def get_callable_from_string(callable_fn: str) -> Any:
    """
    Get a callable from a string.

    Args:
        callable_fn (`str`):
            The string representation of the callable.

    Returns:
        `Any`: The callable.
    """
    # separate the function name from the module name
    module_name, function_name = callable_fn.rsplit(".", 1)
    # import the module
    module = importlib.import_module(module_name)
    # get the function
    return getattr(module, function_name)
relik/inference/__init__.py
DELETED
File without changes
relik/inference/annotator.py
DELETED
@@ -1,428 +0,0 @@
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import hydra
from omegaconf import OmegaConf
from relik.retriever.indexers.faiss import FaissDocumentIndex
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
from rich.pretty import pprint

from relik.common.log import get_console_logger, get_logger
from relik.common.upload import upload
from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
from relik.inference.data.objects import EntitySpan, RelikOutput
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.relik_reader import RelikReader
from relik.retriever.data.utils import batch_generator
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.model import GoldenRetriever

logger = get_logger(__name__)
console_logger = get_console_logger()


class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`Optional[GoldenRetriever]`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The passage encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
            The document index to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        reader (`Optional[Union[str, RelikReader]]`, `optional`):
            The reader to use. If `None`, a reader will be instantiated from the
            provided `reader`. Defaults to `None`.
        retriever_device (`str`, `optional`, defaults to `cpu`):
            The device to use for the retriever.

    """

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # retriever
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision
        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )
        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # overwrite default_retriever_kwargs with retriever_kwargs
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
            retriever.training = False
            retriever.eval()
        self.retriever = retriever

        # reader
        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        # windowization stuff
        self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager: WindowManager | None = None

        # candidates preprocessing
        # TODO: maybe move this logic somewhere else
        candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[int] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the whole text will be annotated.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, there will be no overlap between windows.
            retriever_batch_size (`int`, `optional`, defaults to `None`):
                The batch size to use for the retriever. The whole input is the batch for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `None`):
                The batch size to use for the reader. The whole input is the batch for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            **kwargs:
                Additional keyword arguments to pass to the retriever and the reader.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects will be returned.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        if isinstance(text, str):
            text = [text]

        if window_size is not None:
            if self.window_manager is None:
                self.window_manager = WindowManager(self.tokenizer)

            if window_size == "sentence":
                # todo: implement sentence windowizer
                raise NotImplementedError("Sentence windowizer not implemented yet")

            # if window_size < window_stride:
            #     raise ValueError(
            #         f"Window size ({window_size}) must be greater than window stride ({window_stride})"
            #     )

        # window generator
        windows = [
            window
            for doc_id, t in enumerate(text)
            for window in self.window_manager.create_windows(
                t,
                window_size=window_size,
                stride=window_stride,
                doc_id=doc_id,
            )
        ]

        # retrieve candidates first
        windows_candidates = []
        # TODO: Move batching inside retriever
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        # add passage to the windows
        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
        windows = self.window_manager.merge_windows(windows)

        # transform predictions into RelikOutput objects
        output = []
        for w in windows:
            sample_output = RelikOutput(
                text=text[w.doc_id],
                labels=sorted(
                    [
                        EntitySpan(
                            start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
                        )
                        for ss, se, sl in w.predicted_window_labels_chars
                    ],
                    key=lambda x: x.start,
                ),
            )
            output.append(sample_output)

        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # TODO: check merging behavior
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        # do we want to print the config? I like it
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, **kwargs)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments to pass to `OmegaConf.save`.
        """
        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config[
                        "question_encoder"
                    ] = self.retriever.question_encoder.name_or_path
                if self.retriever.passage_encoder is not None:
                    config[
                        "passage_encoder"
                    ] = self.retriever.passage_encoder.name_or_path
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            config["retriever_kwargs"] = self.retriever_kwargs
            config["reader_kwargs"] = self.reader_kwargs
            # expand the fn as to be able to save it and load it later
            config[
                "candidates_preprocessing_fn"
            ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}"

            # these are model-specific and should be saved
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"
            # save weights
            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            # push to hub
            logger.info(f"Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)


def main():
    from pprint import pprint

    document_index = FaissDocumentIndex.from_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
        config_kwargs={"_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex", "index_type": "IVFx,Flat"},
    )

    relik = Relik(
        question_encoder="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
        document_index=document_index,
        reader="/root/relik-spaces/models/relik-reader-aida-deberta-small",
        device="cuda",
        precision=16,
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )

    input_text = """
    Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
    The 92-year-old billionaire did not disclose the trust to the government in July 2015.
    Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
    Ecclestone had been due to go on trial next month.
    """

    preds = relik(input_text)
    pprint(preds)


if __name__ == "__main__":
    main()
relik/inference/data/__init__.py
DELETED
File without changes
relik/inference/data/objects.py
DELETED
@@ -1,64 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, NamedTuple, Optional

from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample


@dataclass
class Word:
    """
    A word representation that includes text, index in the sentence, POS tag, lemma,
    dependency relation, and similar information.

    # Parameters
    text : `str`, optional
        The text representation.
    index : `int`, optional
        The word offset in the sentence.
    lemma : `str`, optional
        The lemma of this word.
    pos : `str`, optional
        The coarse-grained part of speech of this word.
    dep : `str`, optional
        The dependency relation for this word.

    input_id : `int`, optional
        Integer representation of the word, used to pass it to a model.
    token_type_id : `int`, optional
        Token type id used by some transformers.
    attention_mask: `int`, optional
        Attention mask used by transformers, indicates to the model which tokens should
        be attended to, and which should not.
    """

    text: str
    index: int
    start_char: Optional[int] = None
    end_char: Optional[int] = None
    # preprocessing fields
    lemma: Optional[str] = None
    pos: Optional[str] = None
    dep: Optional[str] = None
    head: Optional[int] = None

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.__str__()


class EntitySpan(NamedTuple):
    start: int
    end: int
    label: str
    text: str


@dataclass
class RelikOutput:
    text: str
    labels: List[EntitySpan]
    windows: Optional[List[RelikReaderSample]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
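A minimal usage sketch of the deleted objects.py types (not part of the diff; the sample text and label values are invented purely for illustration, and the import assumes the pre-deletion package layout):

    from relik.inference.data.objects import EntitySpan, RelikOutput

    # invented example values, only to show how the fields fit together
    span = EntitySpan(start=0, end=14, label="Michael Jordan", text="Michael Jordan")
    output = RelikOutput(text="Michael Jordan played basketball.", labels=[span], windows=None)
    print(output.labels[0].label)  # "Michael Jordan"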
relik/inference/data/tokenizers/__init__.py
DELETED
@@ -1,89 +0,0 @@
-SPACY_LANGUAGE_MAPPER = {
-    "ca": "ca_core_news_sm",
-    "da": "da_core_news_sm",
-    "de": "de_core_news_sm",
-    "el": "el_core_news_sm",
-    "en": "en_core_web_sm",
-    "es": "es_core_news_sm",
-    "fr": "fr_core_news_sm",
-    "it": "it_core_news_sm",
-    "ja": "ja_core_news_sm",
-    "lt": "lt_core_news_sm",
-    "mk": "mk_core_news_sm",
-    "nb": "nb_core_news_sm",
-    "nl": "nl_core_news_sm",
-    "pl": "pl_core_news_sm",
-    "pt": "pt_core_news_sm",
-    "ro": "ro_core_news_sm",
-    "ru": "ru_core_news_sm",
-    "xx": "xx_sent_ud_sm",
-    "zh": "zh_core_web_sm",
-    "ca_core_news_sm": "ca_core_news_sm",
-    "ca_core_news_md": "ca_core_news_md",
-    "ca_core_news_lg": "ca_core_news_lg",
-    "ca_core_news_trf": "ca_core_news_trf",
-    "da_core_news_sm": "da_core_news_sm",
-    "da_core_news_md": "da_core_news_md",
-    "da_core_news_lg": "da_core_news_lg",
-    "da_core_news_trf": "da_core_news_trf",
-    "de_core_news_sm": "de_core_news_sm",
-    "de_core_news_md": "de_core_news_md",
-    "de_core_news_lg": "de_core_news_lg",
-    "de_dep_news_trf": "de_dep_news_trf",
-    "el_core_news_sm": "el_core_news_sm",
-    "el_core_news_md": "el_core_news_md",
-    "el_core_news_lg": "el_core_news_lg",
-    "en_core_web_sm": "en_core_web_sm",
-    "en_core_web_md": "en_core_web_md",
-    "en_core_web_lg": "en_core_web_lg",
-    "en_core_web_trf": "en_core_web_trf",
-    "es_core_news_sm": "es_core_news_sm",
-    "es_core_news_md": "es_core_news_md",
-    "es_core_news_lg": "es_core_news_lg",
-    "es_dep_news_trf": "es_dep_news_trf",
-    "fr_core_news_sm": "fr_core_news_sm",
-    "fr_core_news_md": "fr_core_news_md",
-    "fr_core_news_lg": "fr_core_news_lg",
-    "fr_dep_news_trf": "fr_dep_news_trf",
-    "it_core_news_sm": "it_core_news_sm",
-    "it_core_news_md": "it_core_news_md",
-    "it_core_news_lg": "it_core_news_lg",
-    "ja_core_news_sm": "ja_core_news_sm",
-    "ja_core_news_md": "ja_core_news_md",
-    "ja_core_news_lg": "ja_core_news_lg",
-    "ja_dep_news_trf": "ja_dep_news_trf",
-    "lt_core_news_sm": "lt_core_news_sm",
-    "lt_core_news_md": "lt_core_news_md",
-    "lt_core_news_lg": "lt_core_news_lg",
-    "mk_core_news_sm": "mk_core_news_sm",
-    "mk_core_news_md": "mk_core_news_md",
-    "mk_core_news_lg": "mk_core_news_lg",
-    "nb_core_news_sm": "nb_core_news_sm",
-    "nb_core_news_md": "nb_core_news_md",
-    "nb_core_news_lg": "nb_core_news_lg",
-    "nl_core_news_sm": "nl_core_news_sm",
-    "nl_core_news_md": "nl_core_news_md",
-    "nl_core_news_lg": "nl_core_news_lg",
-    "pl_core_news_sm": "pl_core_news_sm",
-    "pl_core_news_md": "pl_core_news_md",
-    "pl_core_news_lg": "pl_core_news_lg",
-    "pt_core_news_sm": "pt_core_news_sm",
-    "pt_core_news_md": "pt_core_news_md",
-    "pt_core_news_lg": "pt_core_news_lg",
-    "ro_core_news_sm": "ro_core_news_sm",
-    "ro_core_news_md": "ro_core_news_md",
-    "ro_core_news_lg": "ro_core_news_lg",
-    "ru_core_news_sm": "ru_core_news_sm",
-    "ru_core_news_md": "ru_core_news_md",
-    "ru_core_news_lg": "ru_core_news_lg",
-    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
-    "xx_sent_ud_sm": "xx_sent_ud_sm",
-    "zh_core_web_sm": "zh_core_web_sm",
-    "zh_core_web_md": "zh_core_web_md",
-    "zh_core_web_lg": "zh_core_web_lg",
-    "zh_core_web_trf": "zh_core_web_trf",
-}
-
-from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
-from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
-from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
relik/inference/data/tokenizers/base_tokenizer.py
DELETED
@@ -1,84 +0,0 @@
-from typing import List, Union
-
-from relik.inference.data.objects import Word
-
-
-class BaseTokenizer:
-    """
-    A :obj:`Tokenizer` splits strings of text into single words, optionally adds
-    pos tags and perform lemmatization.
-    """
-
-    def __call__(
-        self,
-        texts: Union[str, List[str], List[List[str]]],
-        is_split_into_words: bool = False,
-        **kwargs
-    ) -> List[List[Word]]:
-        """
-        Tokenize the input into single words.
-
-        Args:
-            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
-                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
-            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
-                If :obj:`True` and the input is a string, the input is split on spaces.
-
-        Returns:
-            :obj:`List[List[Word]]`: The input text tokenized in single words.
-        """
-        raise NotImplementedError
-
-    def tokenize(self, text: str) -> List[Word]:
-        """
-        Implements splitting words into tokens.
-
-        Args:
-            text (:obj:`str`):
-                Text to tokenize.
-
-        Returns:
-            :obj:`List[Word]`: The input text tokenized in single words.
-
-        """
-        raise NotImplementedError
-
-    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
-        """
-        Implements batch splitting words into tokens.
-
-        Args:
-            texts (:obj:`List[str]`):
-                Batch of text to tokenize.
-
-        Returns:
-            :obj:`List[List[Word]]`: The input batch tokenized in single words.
-
-        """
-        return [self.tokenize(text) for text in texts]
-
-    @staticmethod
-    def check_is_batched(
-        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
-    ):
-        """
-        Check if input is batched or a single sample.
-
-        Args:
-            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
-                Text to check.
-            is_split_into_words (:obj:`bool`):
-                If :obj:`True` and the input is a string, the input is split on spaces.
-
-        Returns:
-            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
-        """
-        return bool(
-            (not is_split_into_words and isinstance(texts, (list, tuple)))
-            or (
-                is_split_into_words
-                and isinstance(texts, (list, tuple))
-                and texts
-                and isinstance(texts[0], (list, tuple))
-            )
-        )
relik/inference/data/tokenizers/regex_tokenizer.py
DELETED
@@ -1,73 +0,0 @@
-import re
-from typing import List, Union
-
-from overrides import overrides
-
-from relik.inference.data.objects import Word
-from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
-
-
-class RegexTokenizer(BaseTokenizer):
-    """
-    A :obj:`Tokenizer` that splits the text based on a simple regex.
-    """
-
-    def __init__(self):
-        super(RegexTokenizer, self).__init__()
-        # regex for splitting on spaces and punctuation and new lines
-        # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n")
-        self._regex = re.compile(
-            r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL
-        )
-
-    def __call__(
-        self,
-        texts: Union[str, List[str], List[List[str]]],
-        is_split_into_words: bool = False,
-        **kwargs,
-    ) -> List[List[Word]]:
-        """
-        Tokenize the input into single words by splitting using a simple regex.
-
-        Args:
-            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
-                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
-            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
-                If :obj:`True` and the input is a string, the input is split on spaces.
-
-        Returns:
-            :obj:`List[List[Word]]`: The input text tokenized in single words.
-
-        Example::
-
-            >>> from relik.retriever.serve.tokenizers.regex_tokenizer import RegexTokenizer
-
-            >>> regex_tokenizer = RegexTokenizer()
-            >>> regex_tokenizer("Mary sold the car to John.")
-
-        """
-        # check if input is batched or a single sample
-        is_batched = self.check_is_batched(texts, is_split_into_words)
-
-        if is_batched:
-            tokenized = self.tokenize_batch(texts)
-        else:
-            tokenized = self.tokenize(texts)
-
-        return tokenized
-
-    @overrides
-    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
-        if not isinstance(text, (str, list)):
-            raise ValueError(
-                f"text must be either `str` or `list`, found: `{type(text)}`"
-            )
-
-        if isinstance(text, list):
-            text = " ".join(text)
-        return [
-            Word(t[0], i, start_char=t[1], end_char=t[2])
-            for i, t in enumerate(
-                (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text)
-            )
-        ]
relik/inference/data/tokenizers/spacy_tokenizer.py
DELETED
@@ -1,228 +0,0 @@
-import logging
-from typing import Dict, List, Tuple, Union
-
-import spacy
-
-# from ipa.common.utils import load_spacy
-from overrides import overrides
-from spacy.cli.download import download as spacy_download
-from spacy.tokens import Doc
-
-from relik.common.log import get_logger
-from relik.inference.data.objects import Word
-from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
-from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
-
-logger = get_logger(level=logging.DEBUG)
-
-# Spacy and Stanza stuff
-
-LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}
-
-
-def load_spacy(
-    language: str,
-    pos_tags: bool = False,
-    lemma: bool = False,
-    parse: bool = False,
-    split_on_spaces: bool = False,
-) -> spacy.Language:
-    """
-    Download and load spacy model.
-
-    Args:
-        language (:obj:`str`, defaults to :obj:`en`):
-            Language of the text to tokenize.
-        pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, performs POS tagging with spacy model.
-        lemma (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, performs lemmatization with spacy model.
-        parse (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, performs dependency parsing with spacy model.
-        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, will split by spaces without performing tokenization.
-
-    Returns:
-        :obj:`spacy.Language`: The spacy model loaded.
-    """
-    exclude = ["vectors", "textcat", "ner"]
-    if not pos_tags:
-        exclude.append("tagger")
-    if not lemma:
-        exclude.append("lemmatizer")
-    if not parse:
-        exclude.append("parser")
-
-    # check if the model is already loaded
-    # if so, there is no need to reload it
-    spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
-    if spacy_params not in LOADED_SPACY_MODELS:
-        try:
-            spacy_tagger = spacy.load(language, exclude=exclude)
-        except OSError:
-            logger.warning(
-                "Spacy model '%s' not found. Downloading and installing.", language
-            )
-            spacy_download(language)
-            spacy_tagger = spacy.load(language, exclude=exclude)
-
-        # if everything is disabled, return only the tokenizer
-        # for faster tokenization
-        # TODO: is it really faster?
-        # if len(exclude) >= 6:
-        #     spacy_tagger = spacy_tagger.tokenizer
-        LOADED_SPACY_MODELS[spacy_params] = spacy_tagger
-
-    return LOADED_SPACY_MODELS[spacy_params]
-
-
-class SpacyTokenizer(BaseTokenizer):
-    """
-    A :obj:`Tokenizer` that uses SpaCy to tokenizer and preprocess the text. It returns :obj:`Word` objects.
-
-    Args:
-        language (:obj:`str`, optional, defaults to :obj:`en`):
-            Language of the text to tokenize.
-        return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, performs POS tagging with spacy model.
-        return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, performs lemmatization with spacy model.
-        return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, performs dependency parsing with spacy model.
-        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, will split by spaces without performing tokenization.
-        use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
-            If :obj:`True`, will load the Stanza model on GPU.
-    """
-
-    def __init__(
-        self,
-        language: str = "en",
-        return_pos_tags: bool = False,
-        return_lemmas: bool = False,
-        return_deps: bool = False,
-        split_on_spaces: bool = False,
-        use_gpu: bool = False,
-    ):
-        super(SpacyTokenizer, self).__init__()
-        if language not in SPACY_LANGUAGE_MAPPER:
-            raise ValueError(
-                f"`{language}` language not supported. The supported "
-                f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
-            )
-        if use_gpu:
-            # load the model on GPU
-            # if the GPU is not available or not correctly configured,
-            # it will rise an error
-            spacy.require_gpu()
-        self.spacy = load_spacy(
-            SPACY_LANGUAGE_MAPPER[language],
-            return_pos_tags,
-            return_lemmas,
-            return_deps,
-            split_on_spaces,
-        )
-        self.split_on_spaces = split_on_spaces
-
-    def __call__(
-        self,
-        texts: Union[str, List[str], List[List[str]]],
-        is_split_into_words: bool = False,
-        **kwargs,
-    ) -> Union[List[Word], List[List[Word]]]:
-        """
-        Tokenize the input into single words using SpaCy models.
-
-        Args:
-            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
-                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
-            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
-                If :obj:`True` and the input is a string, the input is split on spaces.
-
-        Returns:
-            :obj:`List[List[Word]]`: The input text tokenized in single words.
-
-        Example::
-
-            >>> from ipa import SpacyTokenizer
-
-            >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True)
-            >>> spacy_tokenizer("Mary sold the car to John.")
-
-        """
-        # check if input is batched or a single sample
-        is_batched = self.check_is_batched(texts, is_split_into_words)
-        if is_batched:
-            tokenized = self.tokenize_batch(texts)
-        else:
-            tokenized = self.tokenize(texts)
-        return tokenized
-
-    @overrides
-    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
-        if self.split_on_spaces:
-            if isinstance(text, str):
-                text = text.split(" ")
-            spaces = [True] * len(text)
-            text = Doc(self.spacy.vocab, words=text, spaces=spaces)
-        return self._clean_tokens(self.spacy(text))
-
-    @overrides
-    def tokenize_batch(
-        self, texts: Union[List[str], List[List[str]]]
-    ) -> List[List[Word]]:
-        if self.split_on_spaces:
-            if isinstance(texts[0], str):
-                texts = [text.split(" ") for text in texts]
-            spaces = [[True] * len(text) for text in texts]
-            texts = [
-                Doc(self.spacy.vocab, words=text, spaces=space)
-                for text, space in zip(texts, spaces)
-            ]
-        return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)]
-
-    @staticmethod
-    def _clean_tokens(tokens: Doc) -> List[Word]:
-        """
-        Converts spaCy tokens to :obj:`Word`.
-
-        Args:
-            tokens (:obj:`spacy.tokens.Doc`):
-                Tokens from SpaCy model.
-
-        Returns:
-            :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects.
-        """
-        words = [
-            Word(
-                token.text,
-                token.i,
-                token.idx,
-                token.idx + len(token),
-                token.lemma_,
-                token.pos_,
-                token.dep_,
-                token.head.i,
-            )
-            for token in tokens
-        ]
-        return words
-
-
-class WhitespaceSpacyTokenizer:
-    """Simple white space tokenizer for SpaCy."""
-
-    def __init__(self, vocab):
-        self.vocab = vocab
-
-    def __call__(self, text):
-        if isinstance(text, str):
-            words = text.split(" ")
-        elif isinstance(text, list):
-            words = text
-        else:
-            raise ValueError(
-                f"text must be either `str` or `list`, found: `{type(text)}`"
-            )
-        spaces = [True] * len(words)
-        return Doc(self.vocab, words=words, spaces=spaces)
relik/inference/data/tokenizers/whitespace_tokenizer.py
DELETED
@@ -1,70 +0,0 @@
-import re
-from typing import List, Union
-
-from overrides import overrides
-
-from relik.inference.data.objects import Word
-from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
-
-
-class WhitespaceTokenizer(BaseTokenizer):
-    """
-    A :obj:`Tokenizer` that splits the text on spaces.
-    """
-
-    def __init__(self):
-        super(WhitespaceTokenizer, self).__init__()
-        self.whitespace_regex = re.compile(r"\S+")
-
-    def __call__(
-        self,
-        texts: Union[str, List[str], List[List[str]]],
-        is_split_into_words: bool = False,
-        **kwargs,
-    ) -> List[List[Word]]:
-        """
-        Tokenize the input into single words by splitting on spaces.
-
-        Args:
-            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
-                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
-            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
-                If :obj:`True` and the input is a string, the input is split on spaces.
-
-        Returns:
-            :obj:`List[List[Word]]`: The input text tokenized in single words.
-
-        Example::
-
-            >>> from nlp_preprocessing_wrappers import WhitespaceTokenizer
-
-            >>> whitespace_tokenizer = WhitespaceTokenizer()
-            >>> whitespace_tokenizer("Mary sold the car to John .")
-
-        """
-        # check if input is batched or a single sample
-        is_batched = self.check_is_batched(texts, is_split_into_words)
-
-        if is_batched:
-            tokenized = self.tokenize_batch(texts)
-        else:
-            tokenized = self.tokenize(texts)
-
-        return tokenized
-
-    @overrides
-    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
-        if not isinstance(text, (str, list)):
-            raise ValueError(
-                f"text must be either `str` or `list`, found: `{type(text)}`"
-            )
-
-        if isinstance(text, list):
-            text = " ".join(text)
-        return [
-            Word(t[0], i, start_char=t[1], end_char=t[2])
-            for i, t in enumerate(
-                (m.group(0), m.start(), m.end())
-                for m in self.whitespace_regex.finditer(text)
-            )
-        ]
relik/inference/data/window/__init__.py
DELETED
File without changes
relik/inference/data/window/manager.py
DELETED
@@ -1,262 +0,0 @@
-import collections
-import itertools
-from dataclasses import dataclass
-from typing import List, Optional, Set, Tuple
-
-from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
-from relik.reader.data.relik_reader_sample import RelikReaderSample
-
-
-@dataclass
-class Window:
-    doc_id: int
-    window_id: int
-    text: str
-    tokens: List[str]
-    doc_topic: Optional[str]
-    offset: int
-    token2char_start: dict
-    token2char_end: dict
-    window_candidates: Optional[List[str]] = None
-
-
-class WindowManager:
-    def __init__(self, tokenizer: BaseTokenizer) -> None:
-        self.tokenizer = tokenizer
-
-    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
-        tokenized_document = self.tokenizer(document)
-        tokens = []
-        tokens_char_mapping = []
-        for token in tokenized_document:
-            tokens.append(token.text)
-            tokens_char_mapping.append((token.start_char, token.end_char))
-        return tokens, tokens_char_mapping
-
-    def create_windows(
-        self,
-        document: str,
-        window_size: int,
-        stride: int,
-        doc_id: int = 0,
-        doc_topic: str = None,
-    ) -> List[RelikReaderSample]:
-        document_tokens, tokens_char_mapping = self.tokenize(document)
-        if doc_topic is None:
-            doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
-        document_windows = []
-        if len(document_tokens) <= window_size:
-            text = document
-            # relik_reader_sample = RelikReaderSample()
-            document_windows.append(
-                # Window(
-                RelikReaderSample(
-                    doc_id=doc_id,
-                    window_id=0,
-                    text=text,
-                    tokens=document_tokens,
-                    doc_topic=doc_topic,
-                    offset=0,
-                    token2char_start={
-                        str(i): tokens_char_mapping[i][0]
-                        for i in range(len(document_tokens))
-                    },
-                    token2char_end={
-                        str(i): tokens_char_mapping[i][1]
-                        for i in range(len(document_tokens))
-                    },
-                )
-            )
-        else:
-            for window_id, i in enumerate(range(0, len(document_tokens), stride)):
-                # if the last stride is smaller than the window size, then we can
-                # include more tokens form the previous window.
-                if i != 0 and i + window_size > len(document_tokens):
-                    overflowing_tokens = i + window_size - len(document_tokens)
-                    if overflowing_tokens >= stride:
-                        break
-                    i -= overflowing_tokens
-
-                involved_token_indices = list(
-                    range(i, min(i + window_size, len(document_tokens) - 1))
-                )
-                window_tokens = [document_tokens[j] for j in involved_token_indices]
-                window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
-                window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
-                text = document[window_text_start:window_text_end]
-                document_windows.append(
-                    # Window(
-                    RelikReaderSample(
-                        # dict(
-                        doc_id=doc_id,
-                        window_id=window_id,
-                        text=text,
-                        tokens=window_tokens,
-                        doc_topic=doc_topic,
-                        offset=window_text_start,
-                        token2char_start={
-                            str(i): tokens_char_mapping[ti][0]
-                            for i, ti in enumerate(involved_token_indices)
-                        },
-                        token2char_end={
-                            str(i): tokens_char_mapping[ti][1]
-                            for i, ti in enumerate(involved_token_indices)
-                        },
-                        # )
-                    )
-                )
-        return document_windows
-
-    def merge_windows(
-        self, windows: List[RelikReaderSample]
-    ) -> List[RelikReaderSample]:
-        windows_by_doc_id = collections.defaultdict(list)
-        for window in windows:
-            windows_by_doc_id[window.doc_id].append(window)
-
-        merged_window_by_doc = {
-            doc_id: self.merge_doc_windows(doc_windows)
-            for doc_id, doc_windows in windows_by_doc_id.items()
-        }
-
-        return list(merged_window_by_doc.values())
-
-    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
-        if len(windows) == 1:
-            return windows[0]
-
-        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
-            windows = sorted(windows, key=(lambda x: x.offset))
-
-        window_accumulator = windows[0]
-
-        for next_window in windows[1:]:
-            window_accumulator = self._merge_window_pair(
-                window_accumulator, next_window
-            )
-
-        return window_accumulator
-
-    def _merge_tokens(
-        self, window1: RelikReaderSample, window2: RelikReaderSample
-    ) -> Tuple[list, dict, dict]:
-        w1_tokens = window1.tokens[1:-1]
-        w2_tokens = window2.tokens[1:-1]
-
-        # find intersection
-        tokens_intersection = None
-        for k in reversed(range(1, len(w1_tokens))):
-            if w1_tokens[-k:] == w2_tokens[:k]:
-                tokens_intersection = k
-                break
-        assert tokens_intersection is not None, (
-            f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
-            + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
-            + f"w1 tokens: {w1_tokens}\n"
-            + f"w2 tokens: {w2_tokens}\n"
-        )
-
-        final_tokens = (
-            [window1.tokens[0]]  # CLS
-            + w1_tokens
-            + w2_tokens[tokens_intersection:]
-            + [window1.tokens[-1]]  # SEP
-        )
-
-        w2_starting_offset = len(w1_tokens) - tokens_intersection
-
-        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
-            final_t2c = dict()
-            final_t2c.update(t2c1)
-            for t, c in t2c2.items():
-                t = int(t)
-                if t < tokens_intersection:
-                    continue
-                final_t2c[str(t + w2_starting_offset)] = c
-            return final_t2c
-
-        return (
-            final_tokens,
-            merge_char_mapping(window1.token2char_start, window2.token2char_start),
-            merge_char_mapping(window1.token2char_end, window2.token2char_end),
-        )
-
-    def _merge_span_annotation(
-        self, span_annotation1: List[list], span_annotation2: List[list]
-    ) -> List[list]:
-        uniq_store = set()
-        final_span_annotation_store = []
-        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
-            span_annotation_id = tuple(span_annotation)
-            if span_annotation_id not in uniq_store:
-                uniq_store.add(span_annotation_id)
-                final_span_annotation_store.append(span_annotation)
-        return sorted(final_span_annotation_store, key=lambda x: x[0])
-
-    def _merge_predictions(
-        self,
-        window1: RelikReaderSample,
-        window2: RelikReaderSample,
-    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
-        merged_predictions = window1.predicted_window_labels_chars.union(
-            window2.predicted_window_labels_chars
-        )
-
-        span_title_probabilities = dict()
-        # probabilities
-        for span_prediction, predicted_probs in itertools.chain(
-            window1.probs_window_labels_chars.items(),
-            window2.probs_window_labels_chars.items(),
-        ):
-            if span_prediction not in span_title_probabilities:
-                span_title_probabilities[span_prediction] = predicted_probs
-
-        return merged_predictions, span_title_probabilities
-
-    def _merge_window_pair(
-        self,
-        window1: RelikReaderSample,
-        window2: RelikReaderSample,
-    ) -> RelikReaderSample:
-        merging_output = dict()
-
-        if getattr(window1, "doc_id", None) is not None:
-            assert window1.doc_id == window2.doc_id
-
-        if getattr(window1, "offset", None) is not None:
-            assert (
-                window1.offset < window2.offset
-            ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"
-
-        merging_output["doc_id"] = window1.doc_id
-        merging_output["offset"] = window2.offset
-
-        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
-            window1, window2
-        )
-
-        window_labels = None
-        if getattr(window1, "window_labels", None) is not None:
-            window_labels = self._merge_span_annotation(
-                window1.window_labels, window2.window_labels
-            )
-        (
-            predicted_window_labels_chars,
-            probs_window_labels_chars,
-        ) = self._merge_predictions(
-            window1,
-            window2,
-        )
-
-        merging_output.update(
-            dict(
-                tokens=m_tokens,
-                token2char_start=m_token2char_start,
-                token2char_end=m_token2char_end,
-                window_labels=window_labels,
-                predicted_window_labels_chars=predicted_window_labels_chars,
-                probs_window_labels_chars=probs_window_labels_chars,
-            )
-        )
-
-        return RelikReaderSample(**merging_output)
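A minimal sketch (not part of the diff) of how the deleted WindowManager was typically driven against the pre-deletion code; the sample document string is invented, and the window_size/stride values simply mirror the defaults used elsewhere in this commit:

    from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
    from relik.inference.data.window.manager import WindowManager

    manager = WindowManager(tokenizer=SpacyTokenizer(language="en"))
    # split one long document into overlapping 32-token windows with stride 16
    windows = manager.create_windows("Some long document ...", window_size=32, stride=16)
    # after the reader has annotated each window, merge them back into one sample per document
    merged = manager.merge_windows(windows)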
relik/inference/gerbil.py
DELETED
@@ -1,254 +0,0 @@
-import argparse
-import json
-import os
-import re
-import sys
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Iterator, List, Optional, Tuple
-
-from relik.inference.annotator import Relik
-from relik.inference.data.objects import RelikOutput
-
-# sys.path += ['../']
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class GerbilAlbyManager:
-    def __init__(
-        self,
-        annotator: Optional[Relik] = None,
-        response_logger_dir: Optional[str] = None,
-    ) -> None:
-        self.annotator = annotator
-        self.response_logger_dir = response_logger_dir
-        self.predictions_counter = 0
-        self.labels_mapping = None
-
-    def annotate(self, document: str):
-        relik_output: RelikOutput = self.annotator(document)
-        annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels]
-        if self.labels_mapping is not None:
-            return [
-                (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
-            ]
-        return annotations
-
-    def set_mapping_file(self, mapping_file_path: str):
-        with open(mapping_file_path) as f:
-            labels_mapping = json.load(f)
-            self.labels_mapping = {v: k for k, v in labels_mapping.items()}
-
-    def write_response_bundle(
-        self,
-        document: str,
-        new_document: str,
-        annotations: list,
-        mapped_annotations: list,
-    ) -> None:
-        if self.response_logger_dir is None:
-            return
-
-        if not os.path.isdir(self.response_logger_dir):
-            os.mkdir(self.response_logger_dir)
-
-        with open(
-            f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
-        ) as f:
-            out_json_obj = dict(
-                document=document,
-                new_document=new_document,
-                annotations=annotations,
-                mapped_annotations=mapped_annotations,
-            )
-
-            out_json_obj["span_annotations"] = [
-                (ss, se, document[ss:se], label) for (ss, se, label) in annotations
-            ]
-
-            out_json_obj["span_mapped_annotations"] = [
-                (ss, se, new_document[ss:se], label)
-                for (ss, se, label) in mapped_annotations
-            ]
-
-            json.dump(out_json_obj, f, indent=2)
-
-        self.predictions_counter += 1
-
-
-manager = GerbilAlbyManager()
-
-
-def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
-    pattern_subs = {
-        "-LPR- ": " (",
-        "-RPR-": ")",
-        "\n\n": "\n",
-        "-LRB-": "(",
-        "-RRB-": ")",
-        '","': ",",
-    }
-
-    document_acc = document
-    curr_offset = 0
-    char2offset = []
-
-    matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
-    for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
-        span_start, span_end = span_matching.span()
-        span_start -= curr_offset
-        span_end -= curr_offset
-
-        span_text = document_acc[span_start:span_end]
-        span_sub = pattern_subs[span_text]
-        document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]
-
-        offset = len(span_text) - len(span_sub)
-        curr_offset += offset
-
-        char2offset.append((span_start + len(span_sub), curr_offset))
-
-    return document_acc, char2offset
-
-
-def map_back_annotations(
-    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
-) -> Iterator[Tuple[int, int, str]]:
-    def map_char(char_idx: int) -> int:
-        current_offset = 0
-        for offset_idx, offset_value in char_mapping:
-            if char_idx >= offset_idx:
-                current_offset = offset_value
-            else:
-                break
-        return char_idx + current_offset
-
-    for ss, se, label in annotations:
-        yield map_char(ss), map_char(se), label
-
-
-def annotate(document: str) -> List[Tuple[int, int, str]]:
-    new_document, mapping = preprocess_document(document)
-    logger.info("Mapping: " + str(mapping))
-    logger.info("Document: " + str(document))
-    annotations = [
-        (cs, ce, label.replace(" ", "_"))
-        for cs, ce, label in manager.annotate(new_document)
-    ]
-    logger.info("New document: " + str(new_document))
-    mapped_annotations = (
-        list(map_back_annotations(annotations, mapping))
-        if len(mapping) > 0
-        else annotations
-    )
-
-    logger.info(
-        "Annotations: "
-        + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
-    )
-
-    manager.write_response_bundle(
-        document, new_document, mapped_annotations, annotations
-    )
-
-    if not all(
-        [
-            new_document[ss:se] == document[mss:mse]
-            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
-        ]
-    ):
-        diff_mappings = [
-            (new_document[ss:se], document[mss:mse])
-            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
-        ]
-        return None
-    assert all(
-        [
-            document[mss:mse] == new_document[ss:se]
-            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
-        ]
-    ), (mapped_annotations, annotations)
-
-    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
-
-
-class GetHandler(BaseHTTPRequestHandler):
-    def do_POST(self):
-        content_length = int(self.headers["Content-Length"])
-        post_data = self.rfile.read(content_length)
-        self.send_response(200)
-        self.end_headers()
-        doc_text = read_json(post_data)
-        # try:
-        response = annotate(doc_text)
-
-        self.wfile.write(bytes(json.dumps(response), "utf-8"))
-        return
-
-
-def read_json(post_data):
-    data = json.loads(post_data.decode("utf-8"))
-    # logger.info("received data:", data)
-    text = data["text"]
-    # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
-    return text
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--relik-model-name", required=True)
-    parser.add_argument("--responses-log-dir")
-    parser.add_argument("--log-file", default="logs/logging.txt")
-    parser.add_argument("--mapping-file")
-    return parser.parse_args()
-
-
-def main():
-    args = parse_args()
-
-    # init manager
-    manager.response_logger_dir = args.responses_log_dir
-    # manager.annotator = Relik.from_pretrained(args.relik_model_name)
-
-    print("Debugging, not using you relik model but an hardcoded one.")
-    manager.annotator = Relik(
-        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
-        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
-        reader="relik/reader/models/relik-reader-deberta-base-new-data",
-        window_size=32,
-        window_stride=16,
-        candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
-    )
-
-    if args.mapping_file is not None:
-        manager.set_mapping_file(args.mapping_file)
-
-    port = 6654
-    server = HTTPServer(("localhost", port), GetHandler)
-    logger.info(f"Starting server at http://localhost:{port}")
-
-    # Create a file handler and set its level
-    file_handler = logging.FileHandler(args.log_file)
-    file_handler.setLevel(logging.DEBUG)
-
-    # Create a log formatter and set it on the handler
-    formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
-    file_handler.setFormatter(formatter)
-
-    # Add the file handler to the logger
-    logger.addHandler(file_handler)
-
-    try:
-        server.serve_forever()
-    except KeyboardInterrupt:
-        exit(0)
-
-
-if __name__ == "__main__":
-    main()
relik/inference/preprocessing.py
DELETED
@@ -1,4 +0,0 @@
-def wikipedia_title_and_openings_preprocessing(
-    wikipedia_title_and_openings: str, sepator: str = " <def>"
-):
-    return wikipedia_title_and_openings.split(sepator, 1)[0]
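A quick illustration of the single helper above (not part of the diff; the candidate string is invented): it keeps only the Wikipedia title that precedes the first " <def>" separator in a retrieved candidate.

    from relik.inference.preprocessing import wikipedia_title_and_openings_preprocessing

    candidate = "Formula One <def> Formula One is the highest class of single-seater racing."
    print(wikipedia_title_and_openings_preprocessing(candidate))  # "Formula One"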
relik/inference/serve/__init__.py
DELETED
File without changes
relik/inference/serve/backend/__init__.py
DELETED
File without changes
relik/inference/serve/backend/relik.py
DELETED
@@ -1,210 +0,0 @@
-import logging
-from pathlib import Path
-from typing import List, Optional, Union
-
-from relik.common.utils import is_package_available
-from relik.inference.annotator import Relik
-
-if not is_package_available("fastapi"):
-    raise ImportError(
-        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
-    )
-from fastapi import FastAPI, HTTPException
-
-if not is_package_available("ray"):
-    raise ImportError(
-        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
-    )
-from ray import serve
-
-from relik.common.log import get_logger
-from relik.inference.serve.backend.utils import (
-    RayParameterManager,
-    ServerParameterManager,
-)
-from relik.retriever.data.utils import batch_generator
-
-logger = get_logger(__name__, level=logging.INFO)
-
-VERSION = {}  # type: ignore
-with open(
-    Path(__file__).parent.parent.parent.parent / "version.py", "r"
-) as version_file:
-    exec(version_file.read(), VERSION)
-
-# Env variables for server
-SERVER_MANAGER = ServerParameterManager()
-RAY_MANAGER = RayParameterManager()
-
-app = FastAPI(
-    title="ReLiK",
-    version=VERSION["VERSION"],
-    description="ReLiK REST API",
-)
-
-
-@serve.deployment(
-    ray_actor_options={
-        "num_gpus": RAY_MANAGER.num_gpus
-        if (
-            SERVER_MANAGER.retriver_device == "cuda"
-            or SERVER_MANAGER.reader_device == "cuda"
-        )
-        else 0
-    },
-    autoscaling_config={
-        "min_replicas": RAY_MANAGER.min_replicas,
-        "max_replicas": RAY_MANAGER.max_replicas,
-    },
-)
-@serve.ingress(app)
-class RelikServer:
-    def __init__(
-        self,
-        question_encoder: str,
-        document_index: str,
-        passage_encoder: Optional[str] = None,
-        reader_encoder: Optional[str] = None,
-        top_k: int = 100,
-        retriver_device: str = "cpu",
-        reader_device: str = "cpu",
-        index_device: Optional[str] = None,
-        precision: int = 32,
-        index_precision: Optional[int] = None,
-        use_faiss: bool = False,
-        window_batch_size: int = 32,
-        window_size: int = 32,
-        window_stride: int = 16,
-        split_on_spaces: bool = False,
-    ):
-        # parameters
-        self.question_encoder = question_encoder
-        self.passage_encoder = passage_encoder
-        self.reader_encoder = reader_encoder
-        self.document_index = document_index
-        self.top_k = top_k
-        self.retriver_device = retriver_device
-        self.index_device = index_device or retriver_device
-        self.reader_device = reader_device
-        self.precision = precision
-        self.index_precision = index_precision or precision
-        self.use_faiss = use_faiss
-        self.window_batch_size = window_batch_size
-        self.window_size = window_size
-        self.window_stride = window_stride
-        self.split_on_spaces = split_on_spaces
-
-        # log stuff for debugging
-        logger.info("Initializing RelikServer with parameters:")
-        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
-        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
-        logger.info(f"READER_ENCODER: {self.reader_encoder}")
-        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
-        logger.info(f"TOP_K: {self.top_k}")
-        logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
-        logger.info(f"READER_DEVICE: {self.reader_device}")
-        logger.info(f"INDEX_DEVICE: {self.index_device}")
-        logger.info(f"PRECISION: {self.precision}")
-        logger.info(f"INDEX_PRECISION: {self.index_precision}")
-        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
-        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
-
-        self.relik = Relik(
-            question_encoder=self.question_encoder,
-            passage_encoder=self.passage_encoder,
-            document_index=self.document_index,
-            reader=self.reader_encoder,
-            retriever_device=self.retriver_device,
-            document_index_device=self.index_device,
-            reader_device=self.reader_device,
-            retriever_precision=self.precision,
-            document_index_precision=self.index_precision,
-            reader_precision=self.precision,
-        )
-
-    # @serve.batch()
-    async def handle_batch(self, documents: List[str]) -> List:
-        return self.relik(
-            documents,
-            top_k=self.top_k,
-            window_size=self.window_size,
-            window_stride=self.window_stride,
-            batch_size=self.window_batch_size,
-        )
-
-    @app.post("/api/entities")
-    async def entities_endpoint(
-        self,
-        documents: Union[str, List[str]],
-    ):
-        try:
-            # normalize input
-            if isinstance(documents, str):
-                documents = [documents]
-            if document_topics is not None:
-                if isinstance(document_topics, str):
-                    document_topics = [document_topics]
-                assert len(documents) == len(document_topics)
-            # get predictions for the retriever
-            return await self.handle_batch(documents, document_topics)
-        except Exception as e:
-            # log the entire stack trace
-            logger.exception(e)
-            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
-
-    @app.post("/api/gerbil")
-    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
-        try:
-            # normalize input
-            if isinstance(documents, str):
-                documents = [documents]
-
-            # output list
-            windows_passages = []
-            # split documents into windows
-            document_windows = [
-                window
-                for doc_id, document in enumerate(documents)
-                for window in self.window_manager(
-                    self.tokenizer,
-                    document,
-                    window_size=self.window_size,
-                    stride=self.window_stride,
-                    doc_id=doc_id,
-                )
-            ]
-
-            # get text and topic from document windows and create new list
-            model_inputs = [
-                (window.text, window.doc_topic) for window in document_windows
-            ]
-
-            # batch generator
-            for batch in batch_generator(
-                model_inputs, batch_size=self.window_batch_size
-            ):
-                text, text_pair = zip(*batch)
-                batch_predictions = await self.handle_batch_retriever(text, text_pair)
-                windows_passages.extend(
-                    [
-                        [p.label for p in predictions]
-                        for predictions in batch_predictions
-                    ]
-                )
-
-            # add passage to document windows
-            for window, passages in zip(document_windows, windows_passages):
-                # clean up passages (remove everything after first <def> tag if present)
-                passages = [c.split(" <def>", 1)[0] for c in passages]
-                window.window_candidates = passages
-
-            # return document windows
-            return document_windows
-
-        except Exception as e:
-            # log the entire stack trace
-            logger.exception(e)
-            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
-
-
-server = RelikServer.bind(**vars(SERVER_MANAGER))
relik/inference/serve/backend/retriever.py
DELETED
@@ -1,206 +0,0 @@
-import logging
-from pathlib import Path
-from typing import List, Optional, Union
-
-from relik.common.utils import is_package_available
-
-if not is_package_available("fastapi"):
-    raise ImportError(
-        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
-    )
-from fastapi import FastAPI, HTTPException
-
-if not is_package_available("ray"):
-    raise ImportError(
-        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
-    )
-from ray import serve
-
-from relik.common.log import get_logger
-from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
-from relik.inference.data.window.manager import WindowManager
-from relik.inference.serve.backend.utils import (
-    RayParameterManager,
-    ServerParameterManager,
-)
-from relik.retriever.data.utils import batch_generator
-from relik.retriever.pytorch_modules import GoldenRetriever
-
-logger = get_logger(__name__, level=logging.INFO)
-
-VERSION = {}  # type: ignore
-with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file:
-    exec(version_file.read(), VERSION)
-
-# Env variables for server
-SERVER_MANAGER = ServerParameterManager()
-RAY_MANAGER = RayParameterManager()
-
-app = FastAPI(
-    title="Golden Retriever",
-    version=VERSION["VERSION"],
-    description="Golden Retriever REST API",
-)
-
-
-@serve.deployment(
-    ray_actor_options={
-        "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
-    },
-    autoscaling_config={
-        "min_replicas": RAY_MANAGER.min_replicas,
-        "max_replicas": RAY_MANAGER.max_replicas,
-    },
-)
-@serve.ingress(app)
-class GoldenRetrieverServer:
-    def __init__(
-        self,
-        question_encoder: str,
-        document_index: str,
-        passage_encoder: Optional[str] = None,
-        top_k: int = 100,
-        device: str = "cpu",
-        index_device: Optional[str] = None,
-        precision: int = 32,
-        index_precision: Optional[int] = None,
-        use_faiss: bool = False,
-        window_batch_size: int = 32,
-        window_size: int = 32,
-        window_stride: int = 16,
-        split_on_spaces: bool = False,
-    ):
-        # parameters
-        self.question_encoder = question_encoder
-        self.passage_encoder = passage_encoder
-        self.document_index = document_index
-        self.top_k = top_k
-        self.device = device
-        self.index_device = index_device or device
-        self.precision = precision
-        self.index_precision = index_precision or precision
-        self.use_faiss = use_faiss
-        self.window_batch_size = window_batch_size
-        self.window_size = window_size
-        self.window_stride = window_stride
-        self.split_on_spaces = split_on_spaces
-
-        # log stuff for debugging
-        logger.info("Initializing GoldenRetrieverServer with parameters:")
-        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
-        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
-        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
-        logger.info(f"TOP_K: {self.top_k}")
-        logger.info(f"DEVICE: {self.device}")
-        logger.info(f"INDEX_DEVICE: {self.index_device}")
-        logger.info(f"PRECISION: {self.precision}")
-        logger.info(f"INDEX_PRECISION: {self.index_precision}")
-        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
-        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
-
-        self.retriever = GoldenRetriever(
-            question_encoder=self.question_encoder,
-            passage_encoder=self.passage_encoder,
-            document_index=self.document_index,
-            device=self.device,
-            index_device=self.index_device,
-            index_precision=self.index_precision,
-        )
-        self.retriever.eval()
-
-        if self.split_on_spaces:
-            logger.info("Using WhitespaceTokenizer")
-            self.tokenizer = WhitespaceTokenizer()
-            # logger.info("Using RegexTokenizer")
-            # self.tokenizer = RegexTokenizer()
-        else:
-            logger.info("Using SpacyTokenizer")
-            self.tokenizer = SpacyTokenizer(language="en")
-
-        self.window_manager = WindowManager(tokenizer=self.tokenizer)
-
-    # @serve.batch()
-    async def handle_batch(
-        self, documents: List[str], document_topics: List[str]
-    ) -> List:
-        return self.retriever.retrieve(
-            documents, text_pair=document_topics, k=self.top_k, precision=self.precision
-        )
-
-    @app.post("/api/retrieve")
-    async def retrieve_endpoint(
-        self,
-        documents: Union[str, List[str]],
-        document_topics: Optional[Union[str, List[str]]] = None,
-    ):
-        try:
-            # normalize input
-            if isinstance(documents, str):
-                documents = [documents]
-            if document_topics is not None:
-                if isinstance(document_topics, str):
-                    document_topics = [document_topics]
-                assert len(documents) == len(document_topics)
-            # get predictions
-            return await self.handle_batch(documents, document_topics)
-        except Exception as e:
-            # log the entire stack trace
-            logger.exception(e)
-            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
-
-    @app.post("/api/gerbil")
-    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
-        try:
-            # normalize input
-            if isinstance(documents, str):
-                documents = [documents]
-
-            # output list
-            windows_passages = []
-            # split documents into windows
-            document_windows = [
-                window
-                for doc_id, document in enumerate(documents)
|
164 |
-
for window in self.window_manager(
|
165 |
-
self.tokenizer,
|
166 |
-
document,
|
167 |
-
window_size=self.window_size,
|
168 |
-
stride=self.window_stride,
|
169 |
-
doc_id=doc_id,
|
170 |
-
)
|
171 |
-
]
|
172 |
-
|
173 |
-
# get text and topic from document windows and create new list
|
174 |
-
model_inputs = [
|
175 |
-
(window.text, window.doc_topic) for window in document_windows
|
176 |
-
]
|
177 |
-
|
178 |
-
# batch generator
|
179 |
-
for batch in batch_generator(
|
180 |
-
model_inputs, batch_size=self.window_batch_size
|
181 |
-
):
|
182 |
-
text, text_pair = zip(*batch)
|
183 |
-
batch_predictions = await self.handle_batch(text, text_pair)
|
184 |
-
windows_passages.extend(
|
185 |
-
[
|
186 |
-
[p.label for p in predictions]
|
187 |
-
for predictions in batch_predictions
|
188 |
-
]
|
189 |
-
)
|
190 |
-
|
191 |
-
# add passage to document windows
|
192 |
-
for window, passages in zip(document_windows, windows_passages):
|
193 |
-
# clean up passages (remove everything after first <def> tag if present)
|
194 |
-
passages = [c.split(" <def>", 1)[0] for c in passages]
|
195 |
-
window.window_candidates = passages
|
196 |
-
|
197 |
-
# return document windows
|
198 |
-
return document_windows
|
199 |
-
|
200 |
-
except Exception as e:
|
201 |
-
# log the entire stack trace
|
202 |
-
logger.exception(e)
|
203 |
-
raise HTTPException(status_code=500, detail=f"Server Error: {e}")
|
204 |
-
|
205 |
-
|
206 |
-
server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/inference/serve/backend/utils.py
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from dataclasses import dataclass
|
3 |
-
from typing import Union
|
4 |
-
|
5 |
-
|
6 |
-
@dataclass
|
7 |
-
class ServerParameterManager:
|
8 |
-
retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu")
|
9 |
-
reader_device: str = os.environ.get("READER_DEVICE", "cpu")
|
10 |
-
index_device: str = os.environ.get("INDEX_DEVICE", retriver_device)
|
11 |
-
precision: Union[str, int] = os.environ.get("PRECISION", "fp32")
|
12 |
-
index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision)
|
13 |
-
question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
|
14 |
-
passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
|
15 |
-
document_index: str = os.environ.get("DOCUMENT_INDEX", None)
|
16 |
-
reader_encoder: str = os.environ.get("READER_ENCODER", None)
|
17 |
-
top_k: int = int(os.environ.get("TOP_K", 100))
|
18 |
-
use_faiss: bool = os.environ.get("USE_FAISS", False)
|
19 |
-
window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32))
|
20 |
-
window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
|
21 |
-
window_stride: int = int(os.environ.get("WINDOW_SIZE", 16))
|
22 |
-
split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
|
23 |
-
|
24 |
-
|
25 |
-
class RayParameterManager:
|
26 |
-
def __init__(self) -> None:
|
27 |
-
self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
|
28 |
-
self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
|
29 |
-
self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/inference/serve/frontend/__init__.py
DELETED
File without changes
|
relik/inference/serve/frontend/relik.py
DELETED
@@ -1,231 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import time
|
4 |
-
from pathlib import Path
|
5 |
-
|
6 |
-
import requests
|
7 |
-
import streamlit as st
|
8 |
-
from spacy import displacy
|
9 |
-
from streamlit_extras.badges import badge
|
10 |
-
from streamlit_extras.stylable_container import stylable_container
|
11 |
-
|
12 |
-
RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
|
13 |
-
|
14 |
-
import random
|
15 |
-
|
16 |
-
|
17 |
-
def get_random_color(ents):
|
18 |
-
colors = {}
|
19 |
-
random_colors = generate_pastel_colors(len(ents))
|
20 |
-
for ent in ents:
|
21 |
-
colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
|
22 |
-
return colors
|
23 |
-
|
24 |
-
|
25 |
-
def floatrange(start, stop, steps):
|
26 |
-
if int(steps) == 1:
|
27 |
-
return [stop]
|
28 |
-
return [
|
29 |
-
start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
|
30 |
-
]
|
31 |
-
|
32 |
-
|
33 |
-
def hsl_to_rgb(h, s, l):
|
34 |
-
def hue_2_rgb(v1, v2, v_h):
|
35 |
-
while v_h < 0.0:
|
36 |
-
v_h += 1.0
|
37 |
-
while v_h > 1.0:
|
38 |
-
v_h -= 1.0
|
39 |
-
if 6 * v_h < 1.0:
|
40 |
-
return v1 + (v2 - v1) * 6.0 * v_h
|
41 |
-
if 2 * v_h < 1.0:
|
42 |
-
return v2
|
43 |
-
if 3 * v_h < 2.0:
|
44 |
-
return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
|
45 |
-
return v1
|
46 |
-
|
47 |
-
# if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
|
48 |
-
# if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
|
49 |
-
|
50 |
-
r, b, g = (l * 255,) * 3
|
51 |
-
if s != 0.0:
|
52 |
-
if l < 0.5:
|
53 |
-
var_2 = l * (1.0 + s)
|
54 |
-
else:
|
55 |
-
var_2 = (l + s) - (s * l)
|
56 |
-
var_1 = 2.0 * l - var_2
|
57 |
-
r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
|
58 |
-
g = 255 * hue_2_rgb(var_1, var_2, h)
|
59 |
-
b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
|
60 |
-
|
61 |
-
return int(round(r)), int(round(g)), int(round(b))
|
62 |
-
|
63 |
-
|
64 |
-
def generate_pastel_colors(n):
|
65 |
-
"""Return different pastel colours.
|
66 |
-
|
67 |
-
Input:
|
68 |
-
n (integer) : The number of colors to return
|
69 |
-
|
70 |
-
Output:
|
71 |
-
A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
|
72 |
-
|
73 |
-
Example:
|
74 |
-
>>> print generate_pastel_colors(5)
|
75 |
-
['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
|
76 |
-
"""
|
77 |
-
if n == 0:
|
78 |
-
return []
|
79 |
-
|
80 |
-
# To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
|
81 |
-
start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue
|
82 |
-
saturation = 1.0
|
83 |
-
lightness = 0.8
|
84 |
-
# We take points around the chromatic circle (hue):
|
85 |
-
# (Note: we generate n+1 colors, then drop the last one ([:-1]) because
|
86 |
-
# it equals the first one (hue 0 = hue 1))
|
87 |
-
return [
|
88 |
-
"#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
|
89 |
-
for hue in floatrange(start_hue, start_hue + 1, n + 1)
|
90 |
-
][:-1]
|
91 |
-
|
92 |
-
|
93 |
-
def set_sidebar(css):
|
94 |
-
white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
|
95 |
-
with st.sidebar:
|
96 |
-
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|
97 |
-
st.image(
|
98 |
-
"http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
|
99 |
-
use_column_width=True,
|
100 |
-
)
|
101 |
-
st.markdown("## ReLiK")
|
102 |
-
st.write(
|
103 |
-
f"""
|
104 |
-
- {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
|
105 |
-
- {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
|
106 |
-
- {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
|
107 |
-
""",
|
108 |
-
unsafe_allow_html=True,
|
109 |
-
)
|
110 |
-
st.markdown("## Sapienza NLP")
|
111 |
-
st.write(
|
112 |
-
f"""
|
113 |
-
- {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
|
114 |
-
- {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
|
115 |
-
- {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
|
116 |
-
- {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
|
117 |
-
""",
|
118 |
-
unsafe_allow_html=True,
|
119 |
-
)
|
120 |
-
|
121 |
-
|
122 |
-
def get_el_annotations(response):
|
123 |
-
# swap labels key with ents
|
124 |
-
response["ents"] = response.pop("labels")
|
125 |
-
label_in_text = set(l["label"] for l in response["ents"])
|
126 |
-
options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
|
127 |
-
return response, options
|
128 |
-
|
129 |
-
|
130 |
-
def set_intro(css):
|
131 |
-
# intro
|
132 |
-
st.markdown("# ReLik")
|
133 |
-
st.markdown(
|
134 |
-
"### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
|
135 |
-
)
|
136 |
-
# st.markdown(
|
137 |
-
# "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
|
138 |
-
# "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
|
139 |
-
# "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
|
140 |
-
# "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
|
141 |
-
# )
|
142 |
-
badge(type="github", name="sapienzanlp/relik")
|
143 |
-
badge(type="pypi", name="relik")
|
144 |
-
|
145 |
-
|
146 |
-
def run_client():
|
147 |
-
with open(Path(__file__).parent / "style.css") as f:
|
148 |
-
css = f.read()
|
149 |
-
|
150 |
-
st.set_page_config(
|
151 |
-
page_title="ReLik",
|
152 |
-
page_icon="🦮",
|
153 |
-
layout="wide",
|
154 |
-
)
|
155 |
-
set_sidebar(css)
|
156 |
-
set_intro(css)
|
157 |
-
|
158 |
-
# text input
|
159 |
-
text = st.text_area(
|
160 |
-
"Enter Text Below:",
|
161 |
-
value="Obama went to Rome for a quick vacation.",
|
162 |
-
height=200,
|
163 |
-
max_chars=500,
|
164 |
-
)
|
165 |
-
|
166 |
-
with stylable_container(
|
167 |
-
key="annotate_button",
|
168 |
-
css_styles="""
|
169 |
-
button {
|
170 |
-
background-color: #802433;
|
171 |
-
color: white;
|
172 |
-
border-radius: 25px;
|
173 |
-
}
|
174 |
-
""",
|
175 |
-
):
|
176 |
-
submit = st.button("Annotate")
|
177 |
-
# submit = st.button("Run")
|
178 |
-
|
179 |
-
# ReLik API call
|
180 |
-
if submit:
|
181 |
-
text = text.strip()
|
182 |
-
if text:
|
183 |
-
st.markdown("####")
|
184 |
-
st.markdown("#### Entity Linking")
|
185 |
-
with st.spinner(text="In progress"):
|
186 |
-
response = requests.post(RELIK, json=text)
|
187 |
-
if response.status_code != 200:
|
188 |
-
st.error("Error: {}".format(response.status_code))
|
189 |
-
else:
|
190 |
-
response = response.json()
|
191 |
-
|
192 |
-
# Entity Linking
|
193 |
-
# with stylable_container(
|
194 |
-
# key="container_with_border",
|
195 |
-
# css_styles="""
|
196 |
-
# {
|
197 |
-
# border: 1px solid rgba(49, 51, 63, 0.2);
|
198 |
-
# border-radius: 0.5rem;
|
199 |
-
# padding: 0.5rem;
|
200 |
-
# padding-bottom: 2rem;
|
201 |
-
# }
|
202 |
-
# """,
|
203 |
-
# ):
|
204 |
-
# st.markdown("##")
|
205 |
-
dict_of_ents, options = get_el_annotations(response=response)
|
206 |
-
display = displacy.render(
|
207 |
-
dict_of_ents, manual=True, style="ent", options=options
|
208 |
-
)
|
209 |
-
display = display.replace("\n", " ")
|
210 |
-
# wsd_display = re.sub(
|
211 |
-
# r"(wiki::\d+\w)",
|
212 |
-
# r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
|
213 |
-
# language.upper()
|
214 |
-
# ),
|
215 |
-
# wsd_display,
|
216 |
-
# )
|
217 |
-
with st.container():
|
218 |
-
st.write(display, unsafe_allow_html=True)
|
219 |
-
|
220 |
-
st.markdown("####")
|
221 |
-
st.markdown("#### Relation Extraction")
|
222 |
-
|
223 |
-
with st.container():
|
224 |
-
st.write("Coming :)", unsafe_allow_html=True)
|
225 |
-
|
226 |
-
else:
|
227 |
-
st.error("Please enter some text.")
|
228 |
-
|
229 |
-
|
230 |
-
if __name__ == "__main__":
|
231 |
-
run_client()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/inference/serve/frontend/style.css
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
/* Sidebar */
|
2 |
-
.eczjsme11 {
|
3 |
-
background-color: #802433;
|
4 |
-
}
|
5 |
-
|
6 |
-
.st-emotion-cache-10oheav h2 {
|
7 |
-
color: white;
|
8 |
-
}
|
9 |
-
|
10 |
-
.st-emotion-cache-10oheav li {
|
11 |
-
color: white;
|
12 |
-
}
|
13 |
-
|
14 |
-
/* Main */
|
15 |
-
a:link {
|
16 |
-
text-decoration: none;
|
17 |
-
color: white;
|
18 |
-
}
|
19 |
-
|
20 |
-
a:visited {
|
21 |
-
text-decoration: none;
|
22 |
-
color: white;
|
23 |
-
}
|
24 |
-
|
25 |
-
a:hover {
|
26 |
-
text-decoration: none;
|
27 |
-
color: rgba(255, 255, 255, 0.871);
|
28 |
-
}
|
29 |
-
|
30 |
-
a:active {
|
31 |
-
text-decoration: none;
|
32 |
-
color: white;
|
33 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/__init__.py
DELETED
File without changes
|
relik/reader/conf/config.yaml
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
# Required to make the "experiments" dir the default one for the output of the models
|
2 |
-
hydra:
|
3 |
-
run:
|
4 |
-
dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
5 |
-
|
6 |
-
model_name: relik-reader-deberta-base # used to name the model in wandb and output dir
|
7 |
-
project_name: relik-reader # used to name the project in wandb
|
8 |
-
|
9 |
-
|
10 |
-
defaults:
|
11 |
-
- _self_
|
12 |
-
- training: base
|
13 |
-
- model: base
|
14 |
-
- data: base
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/conf/data/base.yaml
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
train_dataset_path: "relik/reader/data/train.jsonl"
|
2 |
-
val_dataset_path: "relik/reader/data/testa.jsonl"
|
3 |
-
|
4 |
-
train_dataset:
|
5 |
-
_target_: "relik.reader.relik_reader_data.RelikDataset"
|
6 |
-
transformer_model: "${model.model.transformer_model}"
|
7 |
-
materialize_samples: False
|
8 |
-
shuffle_candidates: 0.5
|
9 |
-
random_drop_gold_candidates: 0.05
|
10 |
-
noise_param: 0.0
|
11 |
-
for_inference: False
|
12 |
-
tokens_per_batch: 4096
|
13 |
-
special_symbols: null
|
14 |
-
|
15 |
-
val_dataset:
|
16 |
-
_target_: "relik.reader.relik_reader_data.RelikDataset"
|
17 |
-
transformer_model: "${model.model.transformer_model}"
|
18 |
-
materialize_samples: False
|
19 |
-
shuffle_candidates: False
|
20 |
-
for_inference: True
|
21 |
-
special_symbols: null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/conf/data/re.yaml
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl"
|
2 |
-
val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl"
|
3 |
-
test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl"
|
4 |
-
|
5 |
-
relations_definitions:
|
6 |
-
/people/person/nationality: "nationality"
|
7 |
-
/sports/sports_team/location: "sports team location"
|
8 |
-
/location/country/administrative_divisions: "administrative divisions"
|
9 |
-
/business/company/major_shareholders: "shareholders"
|
10 |
-
/people/ethnicity/people: "ethnicity"
|
11 |
-
/people/ethnicity/geographic_distribution: "geographic distributi6on"
|
12 |
-
/business/company_shareholder/major_shareholder_of: "major shareholder"
|
13 |
-
/location/location/contains: "location"
|
14 |
-
/business/company/founders: "founders"
|
15 |
-
/business/person/company: "company"
|
16 |
-
/business/company/advisors: "advisor"
|
17 |
-
/people/deceased_person/place_of_death: "place of death"
|
18 |
-
/business/company/industry: "industry"
|
19 |
-
/people/person/ethnicity: "ethnic background"
|
20 |
-
/people/person/place_of_birth: "place of birth"
|
21 |
-
/location/administrative_division/country: "country of an administration division"
|
22 |
-
/people/person/place_lived: "place lived"
|
23 |
-
/sports/sports_team_location/teams: "sports team"
|
24 |
-
/people/person/children: "child"
|
25 |
-
/people/person/religion: "religion"
|
26 |
-
/location/neighborhood/neighborhood_of: "neighborhood"
|
27 |
-
/location/country/capital: "capital"
|
28 |
-
/business/company/place_founded: "company founded location"
|
29 |
-
/people/person/profession: "occupation"
|
30 |
-
|
31 |
-
train_dataset:
|
32 |
-
_target_: "relik.reader.relik_reader_re_data.RelikREDataset"
|
33 |
-
transformer_model: "${model.model.transformer_model}"
|
34 |
-
materialize_samples: False
|
35 |
-
shuffle_candidates: False
|
36 |
-
flip_candidates: 1.0
|
37 |
-
noise_param: 0.0
|
38 |
-
for_inference: False
|
39 |
-
tokens_per_batch: 4096
|
40 |
-
min_length: -1
|
41 |
-
special_symbols: null
|
42 |
-
relations_definitions: ${data.relations_definitions}
|
43 |
-
sorting_fields:
|
44 |
-
- "predictable_candidates"
|
45 |
-
val_dataset:
|
46 |
-
_target_: "relik.reader.relik_reader_re_data.RelikREDataset"
|
47 |
-
transformer_model: "${model.model.transformer_model}"
|
48 |
-
materialize_samples: False
|
49 |
-
shuffle_candidates: False
|
50 |
-
flip_candidates: False
|
51 |
-
for_inference: True
|
52 |
-
min_length: -1
|
53 |
-
special_symbols: null
|
54 |
-
relations_definitions: ${data.relations_definitions}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/conf/training/base.yaml
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
seed: 94
|
2 |
-
|
3 |
-
trainer:
|
4 |
-
_target_: lightning.Trainer
|
5 |
-
devices:
|
6 |
-
- 0
|
7 |
-
precision: "16-mixed"
|
8 |
-
max_steps: 50000
|
9 |
-
val_check_interval: 1.0
|
10 |
-
num_sanity_val_steps: 0
|
11 |
-
limit_val_batches: 1
|
12 |
-
gradient_clip_val: 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/conf/training/re.yaml
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
seed: 15
|
2 |
-
|
3 |
-
trainer:
|
4 |
-
_target_: lightning.Trainer
|
5 |
-
devices:
|
6 |
-
- 0
|
7 |
-
precision: "16-mixed"
|
8 |
-
max_steps: 100000
|
9 |
-
val_check_interval: 1.0
|
10 |
-
num_sanity_val_steps: 0
|
11 |
-
limit_val_batches: 1
|
12 |
-
gradient_clip_val: 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/data/__init__.py
DELETED
File without changes
|
relik/reader/data/patches.py
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
from typing import List
|
2 |
-
|
3 |
-
from relik.reader.data.relik_reader_sample import RelikReaderSample
|
4 |
-
from relik.reader.utils.special_symbols import NME_SYMBOL
|
5 |
-
|
6 |
-
|
7 |
-
def merge_patches_predictions(sample) -> None:
|
8 |
-
sample._d["predicted_window_labels"] = dict()
|
9 |
-
predicted_window_labels = sample._d["predicted_window_labels"]
|
10 |
-
|
11 |
-
sample._d["span_title_probabilities"] = dict()
|
12 |
-
span_title_probabilities = sample._d["span_title_probabilities"]
|
13 |
-
|
14 |
-
span2title = dict()
|
15 |
-
for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
|
16 |
-
# selecting span predictions
|
17 |
-
for predicted_title, predicted_spans in patch_info[
|
18 |
-
"predicted_window_labels"
|
19 |
-
].items():
|
20 |
-
for pred_span in predicted_spans:
|
21 |
-
pred_span = tuple(pred_span)
|
22 |
-
curr_title = span2title.get(pred_span)
|
23 |
-
if curr_title is None or curr_title == NME_SYMBOL:
|
24 |
-
span2title[pred_span] = predicted_title
|
25 |
-
# else:
|
26 |
-
# print("Merging at patch level")
|
27 |
-
|
28 |
-
# selecting span predictions probability
|
29 |
-
for predicted_span, titles_probabilities in patch_info[
|
30 |
-
"span_title_probabilities"
|
31 |
-
].items():
|
32 |
-
if predicted_span not in span_title_probabilities:
|
33 |
-
span_title_probabilities[predicted_span] = titles_probabilities
|
34 |
-
|
35 |
-
for span, title in span2title.items():
|
36 |
-
if title not in predicted_window_labels:
|
37 |
-
predicted_window_labels[title] = list()
|
38 |
-
predicted_window_labels[title].append(span)
|
39 |
-
|
40 |
-
|
41 |
-
def remove_duplicate_samples(
|
42 |
-
samples: List[RelikReaderSample],
|
43 |
-
) -> List[RelikReaderSample]:
|
44 |
-
seen_sample = set()
|
45 |
-
samples_store = []
|
46 |
-
for sample in samples:
|
47 |
-
sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
|
48 |
-
if sample_id not in seen_sample:
|
49 |
-
seen_sample.add(sample_id)
|
50 |
-
samples_store.append(sample)
|
51 |
-
return samples_store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relik/reader/data/relik_reader_data.py
DELETED
@@ -1,965 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
from typing import (
|
3 |
-
Any,
|
4 |
-
Callable,
|
5 |
-
Dict,
|
6 |
-
Generator,
|
7 |
-
Iterable,
|
8 |
-
Iterator,
|
9 |
-
List,
|
10 |
-
NamedTuple,
|
11 |
-
Optional,
|
12 |
-
Tuple,
|
13 |
-
Union,
|
14 |
-
)
|
15 |
-
|
16 |
-
import numpy as np
|
17 |
-
import torch
|
18 |
-
from torch.utils.data import IterableDataset
|
19 |
-
from tqdm import tqdm
|
20 |
-
from transformers import AutoTokenizer, PreTrainedTokenizer
|
21 |
-
|
22 |
-
from relik.reader.data.relik_reader_data_utils import (
|
23 |
-
add_noise_to_value,
|
24 |
-
batchify,
|
25 |
-
chunks,
|
26 |
-
flatten,
|
27 |
-
)
|
28 |
-
from relik.reader.data.relik_reader_sample import (
|
29 |
-
RelikReaderSample,
|
30 |
-
load_relik_reader_samples,
|
31 |
-
)
|
32 |
-
from relik.reader.utils.special_symbols import NME_SYMBOL
|
33 |
-
|
34 |
-
logger = logging.getLogger(__name__)
|
35 |
-
|
36 |
-
|
37 |
-
def preprocess_dataset(
|
38 |
-
input_dataset: Iterable[dict],
|
39 |
-
transformer_model: str,
|
40 |
-
add_topic: bool,
|
41 |
-
) -> Iterable[dict]:
|
42 |
-
tokenizer = AutoTokenizer.from_pretrained(transformer_model)
|
43 |
-
for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
|
44 |
-
if len(dataset_elem["tokens"]) == 0:
|
45 |
-
print(
|
46 |
-
f"Dataset element with doc id: {dataset_elem['doc_id']}",
|
47 |
-
f"and offset {dataset_elem['offset']} does not contain any token",
|
48 |
-
"Skipping it",
|
49 |
-
)
|
50 |
-
continue
|
51 |
-
|
52 |
-
new_dataset_elem = dict(
|
53 |
-
doc_id=dataset_elem["doc_id"],
|
54 |
-
offset=dataset_elem["offset"],
|
55 |
-
)
|
56 |
-
|
57 |
-
tokenization_out = tokenizer(
|
58 |
-
dataset_elem["tokens"],
|
59 |
-
return_offsets_mapping=True,
|
60 |
-
add_special_tokens=False,
|
61 |
-
)
|
62 |
-
|
63 |
-
window_tokens = tokenization_out.input_ids
|
64 |
-
window_tokens = flatten(window_tokens)
|
65 |
-
|
66 |
-
offsets_mapping = [
|
67 |
-
[
|
68 |
-
(
|
69 |
-
ss + dataset_elem["token2char_start"][str(i)],
|
70 |
-
se + dataset_elem["token2char_start"][str(i)],
|
71 |
-
)
|
72 |
-
for ss, se in tokenization_out.offset_mapping[i]
|
73 |
-
]
|
74 |
-
for i in range(len(dataset_elem["tokens"]))
|
75 |
-
]
|
76 |
-
|
77 |
-
offsets_mapping = flatten(offsets_mapping)
|
78 |
-
|
79 |
-
assert len(offsets_mapping) == len(window_tokens)
|
80 |
-
|
81 |
-
window_tokens = (
|
82 |
-
[tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
|
83 |
-
)
|
84 |
-
|
85 |
-
topic_offset = 0
|
86 |
-
if add_topic:
|
87 |
-
topic_tokens = tokenizer(
|
88 |
-
dataset_elem["doc_topic"], add_special_tokens=False
|
89 |
-
).input_ids
|
90 |
-
topic_offset = len(topic_tokens)
|
91 |
-
new_dataset_elem["topic_tokens"] = topic_offset
|
92 |
-
window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
|
93 |
-
|
94 |
-
new_dataset_elem.update(
|
95 |
-
dict(
|
96 |
-
tokens=window_tokens,
|
97 |
-
token2char_start={
|
98 |
-
str(i): s
|
99 |
-
for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
|
100 |
-
},
|
101 |
-
token2char_end={
|
102 |
-
str(i): e
|
103 |
-
for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
|
104 |
-
},
|
105 |
-
window_candidates=dataset_elem["window_candidates"],
|
106 |
-
window_candidates_scores=dataset_elem.get("window_candidates_scores"),
|
107 |
-
)
|
108 |
-
)
|
109 |
-
|
110 |
-
if "window_labels" in dataset_elem:
|
111 |
-
window_labels = [
|
112 |
-
(s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
|
113 |
-
]
|
114 |
-
|
115 |
-
new_dataset_elem["window_labels"] = window_labels
|
116 |
-
|
117 |
-
if not all(
|
118 |
-
[
|
119 |
-
s in new_dataset_elem["token2char_start"].values()
|
120 |
-
for s, _, _ in new_dataset_elem["window_labels"]
|
121 |
-
]
|
122 |
-
):
|
123 |
-
print(
|
124 |
-
"Mismatching token start char mapping with labels",
|
125 |
-
new_dataset_elem["token2char_start"],
|
126 |
-
new_dataset_elem["window_labels"],
|
127 |
-
dataset_elem["tokens"],
|
128 |
-
)
|
129 |
-
continue
|
130 |
-
|
131 |
-
if not all(
|
132 |
-
[
|
133 |
-
e in new_dataset_elem["token2char_end"].values()
|
134 |
-
for _, e, _ in new_dataset_elem["window_labels"]
|
135 |
-
]
|
136 |
-
):
|
137 |
-
print(
|
138 |
-
"Mismatching token end char mapping with labels",
|
139 |
-
new_dataset_elem["token2char_end"],
|
140 |
-
new_dataset_elem["window_labels"],
|
141 |
-
dataset_elem["tokens"],
|
142 |
-
)
|
143 |
-
continue
|
144 |
-
|
145 |
-
yield new_dataset_elem
|
146 |
-
|
147 |
-
|
148 |
-
def preprocess_sample(
|
149 |
-
relik_sample: RelikReaderSample,
|
150 |
-
tokenizer,
|
151 |
-
lowercase_policy: float,
|
152 |
-
add_topic: bool = False,
|
153 |
-
) -> None:
|
154 |
-
if len(relik_sample.tokens) == 0:
|
155 |
-
return
|
156 |
-
|
157 |
-
if lowercase_policy > 0:
|
158 |
-
lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
|
159 |
-
relik_sample.tokens = [
|
160 |
-
t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
|
161 |
-
]
|
162 |
-
|
163 |
-
tokenization_out = tokenizer(
|
164 |
-
relik_sample.tokens,
|
165 |
-
return_offsets_mapping=True,
|
166 |
-
add_special_tokens=False,
|
167 |
-
)
|
168 |
-
|
169 |
-
window_tokens = tokenization_out.input_ids
|
170 |
-
window_tokens = flatten(window_tokens)
|
171 |
-
|
172 |
-
offsets_mapping = [
|
173 |
-
[
|
174 |
-
(
|
175 |
-
ss + relik_sample.token2char_start[str(i)],
|
176 |
-
se + relik_sample.token2char_start[str(i)],
|
177 |
-
)
|
178 |
-
for ss, se in tokenization_out.offset_mapping[i]
|
179 |
-
]
|
180 |
-
for i in range(len(relik_sample.tokens))
|
181 |
-
]
|
182 |
-
|
183 |
-
offsets_mapping = flatten(offsets_mapping)
|
184 |
-
|
185 |
-
assert len(offsets_mapping) == len(window_tokens)
|
186 |
-
|
187 |
-
window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
|
188 |
-
|
189 |
-
topic_offset = 0
|
190 |
-
if add_topic:
|
191 |
-
topic_tokens = tokenizer(
|
192 |
-
relik_sample.doc_topic, add_special_tokens=False
|
193 |
-
).input_ids
|
194 |
-
topic_offset = len(topic_tokens)
|
195 |
-
relik_sample.topic_tokens = topic_offset
|
196 |
-
window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
|
197 |
-
|
198 |
-
relik_sample._d.update(
|
199 |
-
dict(
|
200 |
-
tokens=window_tokens,
|
201 |
-
token2char_start={
|
202 |
-
str(i): s
|
203 |
-
for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
|
204 |
-
},
|
205 |
-
token2char_end={
|
206 |
-
str(i): e
|
207 |
-
for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
|
208 |
-
},
|
209 |
-
)
|
210 |
-
)
|
211 |
-
|
212 |
-
if "window_labels" in relik_sample._d:
|
213 |
-
relik_sample.window_labels = [
|
214 |
-
(s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
|
215 |
-
]
|
216 |
-
|
217 |
-
|
218 |
-
class TokenizationOutput(NamedTuple):
|
219 |
-
input_ids: torch.Tensor
|
220 |
-
attention_mask: torch.Tensor
|
221 |
-
token_type_ids: torch.Tensor
|
222 |
-
prediction_mask: torch.Tensor
|
223 |
-
special_symbols_mask: torch.Tensor
|
224 |
-
|
225 |
-
|
226 |
-
class RelikDataset(IterableDataset):
|
227 |
-
def __init__(
|
228 |
-
self,
|
229 |
-
dataset_path: Optional[str],
|
230 |
-
materialize_samples: bool,
|
231 |
-
transformer_model: Union[str, PreTrainedTokenizer],
|
232 |
-
special_symbols: List[str],
|
233 |
-
shuffle_candidates: Optional[Union[bool, float]] = False,
|
234 |
-
for_inference: bool = False,
|
235 |
-
noise_param: float = 0.1,
|
236 |
-
sorting_fields: Optional[str] = None,
|
237 |
-
tokens_per_batch: int = 2048,
|
238 |
-
batch_size: int = None,
|
239 |
-
max_batch_size: int = 128,
|
240 |
-
section_size: int = 50_000,
|
241 |
-
prebatch: bool = True,
|
242 |
-
random_drop_gold_candidates: float = 0.0,
|
243 |
-
use_nme: bool = True,
|
244 |
-
max_subwords_per_candidate: bool = 22,
|
245 |
-
mask_by_instances: bool = False,
|
246 |
-
min_length: int = 5,
|
247 |
-
max_length: int = 2048,
|
248 |
-
model_max_length: int = 1000,
|
249 |
-
split_on_cand_overload: bool = True,
|
250 |
-
skip_empty_training_samples: bool = False,
|
251 |
-
drop_last: bool = False,
|
252 |
-
samples: Optional[Iterator[RelikReaderSample]] = None,
|
253 |
-
lowercase_policy: float = 0.0,
|
254 |
-
**kwargs,
|
255 |
-
):
|
256 |
-
super().__init__(**kwargs)
|
257 |
-
self.dataset_path = dataset_path
|
258 |
-
self.materialize_samples = materialize_samples
|
259 |
-
self.samples: Optional[List[RelikReaderSample]] = None
|
260 |
-
if self.materialize_samples:
|
261 |
-
self.samples = list()
|
262 |
-
|
263 |
-
if isinstance(transformer_model, str):
|
264 |
-
self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
|
265 |
-
else:
|
266 |
-
self.tokenizer = transformer_model
|
267 |
-
self.special_symbols = special_symbols
|
268 |
-
self.shuffle_candidates = shuffle_candidates
|
269 |
-
self.for_inference = for_inference
|
270 |
-
self.noise_param = noise_param
|
271 |
-
self.batching_fields = ["input_ids"]
|
272 |
-
self.sorting_fields = (
|
273 |
-
sorting_fields if sorting_fields is not None else self.batching_fields
|
274 |
-
)
|
275 |
-
|
276 |
-
self.tokens_per_batch = tokens_per_batch
|
277 |
-
self.batch_size = batch_size
|
278 |
-
self.max_batch_size = max_batch_size
|
279 |
-
self.section_size = section_size
|
280 |
-
self.prebatch = prebatch
|
281 |
-
|
282 |
-
self.random_drop_gold_candidates = random_drop_gold_candidates
|
283 |
-
self.use_nme = use_nme
|
284 |
-
self.max_subwords_per_candidate = max_subwords_per_candidate
|
285 |
-
self.mask_by_instances = mask_by_instances
|
286 |
-
self.min_length = min_length
|
287 |
-
self.max_length = max_length
|
288 |
-
self.model_max_length = (
|
289 |
-
model_max_length
|
290 |
-
if model_max_length < self.tokenizer.model_max_length
|
291 |
-
else self.tokenizer.model_max_length
|
292 |
-
)
|
293 |
-
|
294 |
-
# retrocompatibility workaround
|
295 |
-
self.transformer_model = (
|
296 |
-
transformer_model
|
297 |
-
if isinstance(transformer_model, str)
|
298 |
-
else transformer_model.name_or_path
|
299 |
-
)
|
300 |
-
self.split_on_cand_overload = split_on_cand_overload
|
301 |
-
self.skip_empty_training_samples = skip_empty_training_samples
|
302 |
-
self.drop_last = drop_last
|
303 |
-
self.lowercase_policy = lowercase_policy
|
304 |
-
self.samples = samples
|
305 |
-
|
306 |
-
def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
|
307 |
-
return AutoTokenizer.from_pretrained(
|
308 |
-
transformer_model,
|
309 |
-
additional_special_tokens=[ss for ss in special_symbols],
|
310 |
-
add_prefix_space=True,
|
311 |
-
)
|
312 |
-
|
313 |
-
@property
|
314 |
-
def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
|
315 |
-
fields_batchers = {
|
316 |
-
"input_ids": lambda x: batchify(
|
317 |
-
x, padding_value=self.tokenizer.pad_token_id
|
318 |
-
),
|
319 |
-
"attention_mask": lambda x: batchify(x, padding_value=0),
|
320 |
-
"token_type_ids": lambda x: batchify(x, padding_value=0),
|
321 |
-
"prediction_mask": lambda x: batchify(x, padding_value=1),
|
322 |
-
"global_attention": lambda x: batchify(x, padding_value=0),
|
323 |
-
"token2word": None,
|
324 |
-
"sample": None,
|
325 |
-
"special_symbols_mask": lambda x: batchify(x, padding_value=False),
|
326 |
-
"start_labels": lambda x: batchify(x, padding_value=-100),
|
327 |
-
"end_labels": lambda x: batchify(x, padding_value=-100),
|
328 |
-
"predictable_candidates_symbols": None,
|
329 |
-
"predictable_candidates": None,
|
330 |
-
"patch_offset": None,
|
331 |
-
"optimus_labels": None,
|
332 |
-
}
|
333 |
-
|
334 |
-
if "roberta" in self.transformer_model:
|
335 |
-
del fields_batchers["token_type_ids"]
|
336 |
-
|
337 |
-
return fields_batchers
|
338 |
-
|
339 |
-
def _build_input_ids(
|
340 |
-
self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
|
341 |
-
) -> List[int]:
|
342 |
-
return (
|
343 |
-
[self.tokenizer.cls_token_id]
|
344 |
-
+ sentence_input_ids
|
345 |
-
+ [self.tokenizer.sep_token_id]
|
346 |
-
+ flatten(candidates_input_ids)
|
347 |
-
+ [self.tokenizer.sep_token_id]
|
348 |
-
)
|
349 |
-
|
350 |
-
def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
351 |
-
special_symbols_mask = input_ids >= (
|
352 |
-
len(self.tokenizer) - len(self.special_symbols)
|
353 |
-
)
|
354 |
-
special_symbols_mask[0] = True
|
355 |
-
return special_symbols_mask
|
356 |
-
|
357 |
-
def _build_tokenizer_essentials(
|
358 |
-
self, input_ids, original_sequence, sample
|
359 |
-
) -> TokenizationOutput:
|
360 |
-
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
361 |
-
attention_mask = torch.ones_like(input_ids)
|
362 |
-
|
363 |
-
total_sequence_len = len(input_ids)
|
364 |
-
predictable_sentence_len = len(original_sequence)
|
365 |
-
|
366 |
-
# token type ids
|
367 |
-
token_type_ids = torch.cat(
|
368 |
-
[
|
369 |
-
input_ids.new_zeros(
|
370 |
-
predictable_sentence_len + 2
|
371 |
-
), # original sentence bpes + CLS and SEP
|
372 |
-
input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
|
373 |
-
]
|
374 |
-
)
|
375 |
-
|
376 |
-
# prediction mask -> boolean on tokens that are predictable
|
377 |
-
|
378 |
-
prediction_mask = torch.tensor(
|
379 |
-
[1]
|
380 |
-
+ ([0] * predictable_sentence_len)
|
381 |
-
+ ([1] * (total_sequence_len - predictable_sentence_len - 1))
|
382 |
-
)
|
383 |
-
|
384 |
-
# add topic tokens to the prediction mask so that they cannot be predicted
|
385 |
-
# or optimized during training
|
386 |
-
topic_tokens = getattr(sample, "topic_tokens", None)
|
387 |
-
if topic_tokens is not None:
|
388 |
-
prediction_mask[1 : 1 + topic_tokens] = 1
|
389 |
-
|
390 |
-
# If mask by instances is active the prediction mask is applied to everything
|
391 |
-
# that is not indicated as an instance in the training set.
|
392 |
-
if self.mask_by_instances:
|
393 |
-
char_start2token = {
|
394 |
-
cs: int(tok) for tok, cs in sample.token2char_start.items()
|
395 |
-
}
|
396 |
-
char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
|
397 |
-
instances_mask = torch.ones_like(prediction_mask)
|
398 |
-
for _, span_info in sample.instance_id2span_data.items():
|
399 |
-
span_info = span_info[0]
|
400 |
-
token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS
|
401 |
-
token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS
|
402 |
-
instances_mask[token_start : token_end + 1] = 0
|
403 |
-
|
404 |
-
prediction_mask += instances_mask
|
405 |
-
prediction_mask[prediction_mask > 1] = 1
|
406 |
-
|
407 |
-
assert len(prediction_mask) == len(input_ids)
|
408 |
-
|
409 |
-
# special symbols mask
|
410 |
-
special_symbols_mask = self._get_special_symbols_mask(input_ids)
|
411 |
-
|
412 |
-
return TokenizationOutput(
|
413 |
-
input_ids,
|
414 |
-
attention_mask,
|
415 |
-
token_type_ids,
|
416 |
-
prediction_mask,
|
417 |
-
special_symbols_mask,
|
418 |
-
)
|
419 |
-
|
420 |
-
def _build_labels(
|
421 |
-
self,
|
422 |
-
sample,
|
423 |
-
tokenization_output: TokenizationOutput,
|
424 |
-
predictable_candidates: List[str],
|
425 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
426 |
-
start_labels = [0] * len(tokenization_output.input_ids)
|
427 |
-
end_labels = [0] * len(tokenization_output.input_ids)
|
428 |
-
|
429 |
-
char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
|
430 |
-
char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
|
431 |
-
for cs, ce, gold_candidate_title in sample.window_labels:
|
432 |
-
if gold_candidate_title not in predictable_candidates:
|
433 |
-
if self.use_nme:
|
434 |
-
gold_candidate_title = NME_SYMBOL
|
435 |
-
else:
|
436 |
-
continue
|
437 |
-
# +1 is to account for the CLS token
|
438 |
-
start_bpe = char_start2token[cs] + 1
|
439 |
-
end_bpe = char_end2token[ce] + 1
|
440 |
-
class_index = predictable_candidates.index(gold_candidate_title)
|
441 |
-
if (
|
442 |
-
start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
|
443 |
-
): # prevent from having entities that ends with the same label
|
444 |
-
start_labels[start_bpe] = class_index + 1 # +1 for the NONE class
|
445 |
-
end_labels[end_bpe] = class_index + 1 # +1 for the NONE class
|
446 |
-
else:
|
447 |
-
print(
|
448 |
-
"Found entity with the same last subword, it will not be included."
|
449 |
-
)
|
450 |
-
print(
|
451 |
-
cs,
|
452 |
-
ce,
|
453 |
-
gold_candidate_title,
|
454 |
-
start_labels,
|
455 |
-
end_labels,
|
456 |
-
sample.doc_id,
|
457 |
-
)
|
458 |
-
|
459 |
-
ignored_labels_indices = tokenization_output.prediction_mask == 1
|
460 |
-
|
461 |
-
start_labels = torch.tensor(start_labels, dtype=torch.long)
|
462 |
-
start_labels[ignored_labels_indices] = -100
|
463 |
-
|
464 |
-
end_labels = torch.tensor(end_labels, dtype=torch.long)
|
465 |
-
end_labels[ignored_labels_indices] = -100
|
466 |
-
|
467 |
-
return start_labels, end_labels
|
468 |
-
|
469 |
-
def produce_sample_bag(
|
470 |
-
self, sample, predictable_candidates: List[str], candidates_starting_offset: int
|
471 |
-
) -> Optional[Tuple[dict, list, int]]:
|
472 |
-
# input sentence tokenization
|
473 |
-
input_subwords = sample.tokens[1:-1] # removing special tokens
|
474 |
-
candidates_symbols = self.special_symbols[candidates_starting_offset:]
|
475 |
-
|
476 |
-
predictable_candidates = list(predictable_candidates)
|
477 |
-
original_predictable_candidates = list(predictable_candidates)
|
478 |
-
|
479 |
-
# add NME as a possible candidate
|
480 |
-
if self.use_nme:
|
481 |
-
predictable_candidates.insert(0, NME_SYMBOL)
|
482 |
-
|
483 |
-
# candidates encoding
|
484 |
-
candidates_symbols = candidates_symbols[: len(predictable_candidates)]
|
485 |
-
candidates_encoding_result = self.tokenizer.batch_encode_plus(
|
486 |
-
[
|
487 |
-
"{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
|
488 |
-
for cs, ct in zip(candidates_symbols, predictable_candidates)
|
489 |
-
],
|
490 |
-
add_special_tokens=False,
|
491 |
-
).input_ids
|
492 |
-
|
493 |
-
if (
|
494 |
-
self.max_subwords_per_candidate is not None
|
495 |
-
and self.max_subwords_per_candidate > 0
|
496 |
-
):
|
497 |
-
candidates_encoding_result = [
|
498 |
-
cer[: self.max_subwords_per_candidate]
|
499 |
-
for cer in candidates_encoding_result
|
500 |
-
]
|
501 |
-
|
502 |
-
# drop candidates if the number of input tokens is too long for the model
|
503 |
-
if (
|
504 |
-
sum(map(len, candidates_encoding_result))
|
505 |
-
+ len(input_subwords)
|
506 |
-
+ 20 # + 20 special tokens
|
507 |
-
> self.model_max_length
|
508 |
-
):
|
509 |
-
acceptable_tokens_from_candidates = (
|
510 |
-
self.model_max_length - 20 - len(input_subwords)
|
511 |
-
)
|
512 |
-
i = 0
|
513 |
-
cum_len = 0
|
514 |
-
while (
|
515 |
-
cum_len + len(candidates_encoding_result[i])
|
516 |
-
< acceptable_tokens_from_candidates
|
517 |
-
):
|
518 |
-
cum_len += len(candidates_encoding_result[i])
|
519 |
-
i += 1
|
520 |
-
|
521 |
-
candidates_encoding_result = candidates_encoding_result[:i]
|
522 |
-
candidates_symbols = candidates_symbols[:i]
|
523 |
-
predictable_candidates = predictable_candidates[:i]
|
524 |
-
|
525 |
-
# final input_ids build
|
526 |
-
input_ids = self._build_input_ids(
|
527 |
-
sentence_input_ids=input_subwords,
|
528 |
-
candidates_input_ids=candidates_encoding_result,
|
529 |
-
)
|
530 |
-
|
531 |
-
# complete input building (e.g. attention / prediction mask)
|
532 |
-
tokenization_output = self._build_tokenizer_essentials(
|
533 |
-
input_ids, input_subwords, sample
|
534 |
-
)
|
535 |
-
|
536 |
-
output_dict = {
|
537 |
-
"input_ids": tokenization_output.input_ids,
|
538 |
-
"attention_mask": tokenization_output.attention_mask,
|
539 |
-
"token_type_ids": tokenization_output.token_type_ids,
|
540 |
-
"prediction_mask": tokenization_output.prediction_mask,
|
541 |
-
"special_symbols_mask": tokenization_output.special_symbols_mask,
|
542 |
-
"sample": sample,
|
543 |
-
"predictable_candidates_symbols": candidates_symbols,
|
544 |
-
"predictable_candidates": predictable_candidates,
|
545 |
-
}
|
546 |
-
|
547 |
-
# labels creation
|
548 |
-
if sample.window_labels is not None:
|
549 |
-
start_labels, end_labels = self._build_labels(
|
550 |
-
sample,
|
551 |
-
tokenization_output,
|
552 |
-
predictable_candidates,
|
553 |
-
)
|
554 |
-
output_dict.update(start_labels=start_labels, end_labels=end_labels)
|
555 |
-
|
556 |
-
if (
|
557 |
-
"roberta" in self.transformer_model
|
558 |
-
or "longformer" in self.transformer_model
|
559 |
-
):
|
560 |
-
del output_dict["token_type_ids"]
|
561 |
-
|
562 |
-
predictable_candidates_set = set(predictable_candidates)
|
563 |
-
remaining_candidates = [
|
564 |
-
candidate
|
565 |
-
for candidate in original_predictable_candidates
|
566 |
-
if candidate not in predictable_candidates_set
|
567 |
-
]
|
568 |
-
total_used_candidates = (
|
569 |
-
candidates_starting_offset
|
570 |
-
+ len(predictable_candidates)
|
571 |
-
- (1 if self.use_nme else 0)
|
572 |
-
)
|
573 |
-
|
574 |
-
if self.use_nme:
|
575 |
-
assert predictable_candidates[0] == NME_SYMBOL
|
576 |
-
|
577 |
-
return output_dict, remaining_candidates, total_used_candidates
|
578 |
-
|
579 |
-
def __iter__(self):
|
580 |
-
dataset_iterator = self.dataset_iterator_func()
|
581 |
-
|
582 |
-
current_dataset_elements = []
|
583 |
-
|
584 |
-
i = None
|
585 |
-
for i, dataset_elem in enumerate(dataset_iterator, start=1):
|
586 |
-
if (
|
587 |
-
self.section_size is not None
|
588 |
-
and len(current_dataset_elements) == self.section_size
|
589 |
-
):
|
590 |
-
for batch in self.materialize_batches(current_dataset_elements):
|
591 |
-
yield batch
|
592 |
-
current_dataset_elements = []
|
593 |
-
|
594 |
-
current_dataset_elements.append(dataset_elem)
|
595 |
-
|
596 |
-
if i % 50_000 == 0:
|
597 |
-
logger.info(f"Processed: {i} number of elements")
|
598 |
-
|
599 |
-
if len(current_dataset_elements) != 0:
|
600 |
-
for batch in self.materialize_batches(current_dataset_elements):
|
601 |
-
yield batch
|
602 |
-
|
603 |
-
if i is not None:
|
604 |
-
logger.info(f"Dataset finished: {i} number of elements processed")
|
605 |
-
else:
|
606 |
-
logger.warning("Dataset empty")
|
607 |
-
|
608 |
-
def dataset_iterator_func(self):
|
609 |
-
skipped_instances = 0
|
610 |
-
data_samples = (
|
611 |
-
load_relik_reader_samples(self.dataset_path)
|
612 |
-
if self.samples is None
|
613 |
-
else self.samples
|
614 |
-
)
|
615 |
-
for sample in data_samples:
|
616 |
-
preprocess_sample(
|
617 |
-
sample, self.tokenizer, lowercase_policy=self.lowercase_policy
|
618 |
-
)
|
619 |
-
current_patch = 0
|
620 |
-
sample_bag, used_candidates = None, None
|
621 |
-
remaining_candidates = list(sample.window_candidates)
|
622 |
-
|
623 |
-
if not self.for_inference:
|
624 |
-
# randomly drop gold candidates at training time
|
625 |
-
if (
|
626 |
-
self.random_drop_gold_candidates > 0.0
|
627 |
-
and np.random.uniform() < self.random_drop_gold_candidates
|
628 |
-
and len(set(ct for _, _, ct in sample.window_labels)) > 1
|
629 |
-
):
|
630 |
-
# selecting candidates to drop
|
631 |
-
np.random.shuffle(sample.window_labels)
|
632 |
-
n_dropped_candidates = np.random.randint(
|
633 |
-
0, len(sample.window_labels) - 1
|
634 |
-
)
|
635 |
-
dropped_candidates = [
|
636 |
-
label_elem[-1]
|
637 |
-
for label_elem in sample.window_labels[:n_dropped_candidates]
|
638 |
-
]
|
639 |
-
dropped_candidates = set(dropped_candidates)
|
640 |
-
|
641 |
-
# saving NMEs because they should not be dropped
|
642 |
-
if NME_SYMBOL in dropped_candidates:
|
643 |
-
dropped_candidates.remove(NME_SYMBOL)
|
644 |
-
|
645 |
-
# sample update
|
646 |
-
sample.window_labels = [
|
647 |
-
(s, e, _l)
|
648 |
-
if _l not in dropped_candidates
|
649 |
-
else (s, e, NME_SYMBOL)
|
650 |
-
for s, e, _l in sample.window_labels
|
651 |
-
]
|
652 |
-
remaining_candidates = [
|
653 |
-
wc
|
654 |
-
for wc in remaining_candidates
|
655 |
-
if wc not in dropped_candidates
|
656 |
-
]
|
657 |
-
|
658 |
-
# shuffle candidates
|
659 |
-
if (
|
660 |
-
isinstance(self.shuffle_candidates, bool)
|
661 |
-
and self.shuffle_candidates
|
662 |
-
) or (
|
663 |
-
isinstance(self.shuffle_candidates, float)
|
664 |
-
and np.random.uniform() < self.shuffle_candidates
|
665 |
-
):
|
666 |
-
np.random.shuffle(remaining_candidates)
|
667 |
-
|
668 |
-
while len(remaining_candidates) != 0:
|
669 |
-
sample_bag = self.produce_sample_bag(
|
670 |
-
sample,
|
671 |
-
predictable_candidates=remaining_candidates,
|
672 |
-
candidates_starting_offset=used_candidates
|
673 |
-
if used_candidates is not None
|
674 |
-
else 0,
|
675 |
-
)
|
676 |
-
if sample_bag is not None:
|
677 |
-
sample_bag, remaining_candidates, used_candidates = sample_bag
|
678 |
-
if (
|
679 |
-
self.for_inference
|
680 |
-
or not self.skip_empty_training_samples
|
681 |
-
or (
|
682 |
-
(
|
683 |
-
sample_bag.get("start_labels") is not None
|
684 |
-
and torch.any(sample_bag["start_labels"] > 1).item()
|
685 |
-
)
|
686 |
-
or (
|
687 |
-
sample_bag.get("optimus_labels") is not None
|
688 |
-
and len(sample_bag["optimus_labels"]) > 0
|
689 |
-
)
|
690 |
-
)
|
691 |
-
):
|
692 |
-
sample_bag["patch_offset"] = current_patch
|
693 |
-
current_patch += 1
|
694 |
-
yield sample_bag
|
695 |
-
else:
|
696 |
-
skipped_instances += 1
|
697 |
-
if skipped_instances % 1000 == 0 and skipped_instances != 0:
|
698 |
-
logger.info(
|
699 |
-
f"Skipped {skipped_instances} instances since they did not have any gold labels..."
|
700 |
-
)
|
701 |
-
|
702 |
-
# Just use the first fitting candidates if split on
|
703 |
-
# cand is not True
|
704 |
-
if not self.split_on_cand_overload:
|
705 |
-
break
|
706 |
-
|
707 |
-
def preshuffle_elements(self, dataset_elements: List):
|
708 |
-
# This shuffling is done so that when using the sorting function,
|
709 |
-
# if it is deterministic given a collection and its order, we will
|
710 |
-
# make the whole operation not deterministic anymore.
|
711 |
-
# Basically, the aim is not to build every time the same batches.
|
712 |
-
if not self.for_inference:
|
713 |
-
dataset_elements = np.random.permutation(dataset_elements)
|
714 |
-
|
715 |
-
sorting_fn = (
|
716 |
-
lambda elem: add_noise_to_value(
|
717 |
-
sum(len(elem[k]) for k in self.sorting_fields),
|
718 |
-
noise_param=self.noise_param,
|
719 |
-
)
|
720 |
-
if not self.for_inference
|
721 |
-
else sum(len(elem[k]) for k in self.sorting_fields)
|
722 |
-
)
|
723 |
-
|
724 |
-
dataset_elements = sorted(dataset_elements, key=sorting_fn)
|
725 |
-
|
726 |
-
if self.for_inference:
|
727 |
-
return dataset_elements
|
728 |
-
|
729 |
-
ds = list(chunks(dataset_elements, 64))
|
730 |
-
np.random.shuffle(ds)
|
731 |
-
return flatten(ds)
|
732 |
-
|
733 |
-
def materialize_batches(
|
734 |
-
self, dataset_elements: List[Dict[str, Any]]
|
735 |
-
) -> Generator[Dict[str, Any], None, None]:
|
736 |
-
if self.prebatch:
|
737 |
-
dataset_elements = self.preshuffle_elements(dataset_elements)
|
738 |
-
|
739 |
-
current_batch = []
|
740 |
-
|
741 |
-
# function that creates a batch from the 'current_batch' list
|
742 |
-
def output_batch() -> Dict[str, Any]:
|
743 |
-
assert (
|
744 |
-
len(
|
745 |
-
set([len(elem["predictable_candidates"]) for elem in current_batch])
|
746 |
-
)
|
747 |
-
== 1
|
748 |
-
), " ".join(
|
749 |
-
map(
|
750 |
-
str, [len(elem["predictable_candidates"]) for elem in current_batch]
|
751 |
-
)
|
752 |
-
)
|
753 |
-
|
754 |
-
batch_dict = dict()
|
755 |
-
|
756 |
-
de_values_by_field = {
|
757 |
-
fn: [de[fn] for de in current_batch if fn in de]
|
758 |
-
for fn in self.fields_batcher
|
759 |
-            }
-
-            # in case you provide fields batchers but in the batch
-            # there are no elements for that field
-            de_values_by_field = {
-                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
-            }
-
-            assert len(set([len(v) for v in de_values_by_field.values()]))
-
-            # todo: maybe we should report the user about possible
-            # fields filtering due to "None" instances
-            de_values_by_field = {
-                fn: fvs
-                for fn, fvs in de_values_by_field.items()
-                if all([fv is not None for fv in fvs])
-            }
-
-            for field_name, field_values in de_values_by_field.items():
-                field_batch = (
-                    self.fields_batcher[field_name](field_values)
-                    if self.fields_batcher[field_name] is not None
-                    else field_values
-                )
-
-                batch_dict[field_name] = field_batch
-
-            return batch_dict
-
-        max_len_discards, min_len_discards = 0, 0
-
-        should_token_batch = self.batch_size is None
-
-        curr_pred_elements = -1
-        for de in dataset_elements:
-            if (
-                should_token_batch
-                and self.max_batch_size != -1
-                and len(current_batch) == self.max_batch_size
-            ) or (not should_token_batch and len(current_batch) == self.batch_size):
-                yield output_batch()
-                current_batch = []
-                curr_pred_elements = -1
-
-            too_long_fields = [
-                k
-                for k in de
-                if self.max_length != -1
-                and torch.is_tensor(de[k])
-                and len(de[k]) > self.max_length
-            ]
-            if len(too_long_fields) > 0:
-                max_len_discards += 1
-                continue
-
-            too_short_fields = [
-                k
-                for k in de
-                if self.min_length != -1
-                and torch.is_tensor(de[k])
-                and len(de[k]) < self.min_length
-            ]
-            if len(too_short_fields) > 0:
-                min_len_discards += 1
-                continue
-
-            if should_token_batch:
-                de_len = sum(len(de[k]) for k in self.batching_fields)
-
-                future_max_len = max(
-                    de_len,
-                    max(
-                        [
-                            sum(len(bde[k]) for k in self.batching_fields)
-                            for bde in current_batch
-                        ],
-                        default=0,
-                    ),
-                )
-
-                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
-
-                num_predictable_candidates = len(de["predictable_candidates"])
-
-                if len(current_batch) > 0 and (
-                    future_tokens_per_batch >= self.tokens_per_batch
-                    or (
-                        num_predictable_candidates != curr_pred_elements
-                        and curr_pred_elements != -1
-                    )
-                ):
-                    yield output_batch()
-                    current_batch = []
-
-            current_batch.append(de)
-            curr_pred_elements = len(de["predictable_candidates"])
-
-        if len(current_batch) != 0 and not self.drop_last:
-            yield output_batch()
-
-        if max_len_discards > 0:
-            if self.for_inference:
-                logger.warning(
-                    f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
-                    f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
-                    f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
-                    f"sample length exceeds the maximum length supported by the current model."
-                )
-            else:
-                logger.warning(
-                    f"During iteration, {max_len_discards} elements were "
-                    f"discarded since longer than max length {self.max_length}"
-                )
-
-        if min_len_discards > 0:
-            if self.for_inference:
-                logger.warning(
-                    f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
-                    f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
-                    f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
-                    f"sample length is shorter than the minimum length supported by the current model."
-                )
-            else:
-                logger.warning(
-                    f"During iteration, {min_len_discards} elements were "
-                    f"discarded since shorter than min length {self.min_length}"
-                )
-
-    @staticmethod
-    def convert_tokens_to_char_annotations(
-        sample: RelikReaderSample,
-        remove_nmes: bool = True,
-    ) -> RelikReaderSample:
-        """
-        Converts the token annotations to char annotations.
-
-        Args:
-            sample (:obj:`RelikReaderSample`):
-                The sample to convert.
-            remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
-                Whether to remove the NMEs from the annotations.
-        Returns:
-            :obj:`RelikReaderSample`: The converted sample.
-        """
-        char_annotations = set()
-        for (
-            predicted_entity,
-            predicted_spans,
-        ) in sample.predicted_window_labels.items():
-            if predicted_entity == NME_SYMBOL and remove_nmes:
-                continue
-
-            for span_start, span_end in predicted_spans:
-                span_start = sample.token2char_start[str(span_start)]
-                span_end = sample.token2char_end[str(span_end)]
-
-                char_annotations.add((span_start, span_end, predicted_entity))
-
-        char_probs_annotations = dict()
-        for (
-            span_start,
-            span_end,
-        ), candidates_probs in sample.span_title_probabilities.items():
-            span_start = sample.token2char_start[str(span_start)]
-            span_end = sample.token2char_end[str(span_end)]
-            char_probs_annotations[(span_start, span_end)] = {
-                title for title, _ in candidates_probs
-            }
-
-        sample.predicted_window_labels_chars = char_annotations
-        sample.probs_window_labels_chars = char_probs_annotations
-
-        return sample
-
-    @staticmethod
-    def merge_patches_predictions(sample) -> None:
-        sample._d["predicted_window_labels"] = dict()
-        predicted_window_labels = sample._d["predicted_window_labels"]
-
-        sample._d["span_title_probabilities"] = dict()
-        span_title_probabilities = sample._d["span_title_probabilities"]
-
-        span2title = dict()
-        for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
-            # selecting span predictions
-            for predicted_title, predicted_spans in patch_info[
-                "predicted_window_labels"
-            ].items():
-                for pred_span in predicted_spans:
-                    pred_span = tuple(pred_span)
-                    curr_title = span2title.get(pred_span)
-                    if curr_title is None or curr_title == NME_SYMBOL:
-                        span2title[pred_span] = predicted_title
-                    # else:
-                    #     print("Merging at patch level")
-
-            # selecting span predictions probability
-            for predicted_span, titles_probabilities in patch_info[
-                "span_title_probabilities"
-            ].items():
-                if predicted_span not in span_title_probabilities:
-                    span_title_probabilities[predicted_span] = titles_probabilities
-
-        for span, title in span2title.items():
-            if title not in predicted_window_labels:
-                predicted_window_labels[title] = list()
-            predicted_window_labels[title].append(span)
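
For context: the batching loop above fills batches up to a token budget (`tokens_per_batch`) and discards samples outside the configured length bounds. Below is a minimal, self-contained sketch of that policy, using plain dict elements with a hypothetical `input_ids` field instead of the dataset's real tensor fields; it is illustrative only, not the deleted implementation.

from typing import Dict, Iterable, Iterator, List


def token_budget_batches(
    elements: Iterable[Dict[str, list]],
    tokens_per_batch: int = 2048,
    max_length: int = 1000,
) -> Iterator[List[Dict[str, list]]]:
    current_batch: List[Dict[str, list]] = []
    for element in elements:
        length = len(element["input_ids"])
        if length > max_length:
            continue  # mirrors the max-length discard above
        longest = max([len(e["input_ids"]) for e in current_batch], default=0)
        # padding cost of the batch if this element were added to it
        future_tokens = max(longest, length) * (len(current_batch) + 1)
        if current_batch and future_tokens >= tokens_per_batch:
            yield current_batch
            current_batch = []
        current_batch.append(element)
    if current_batch:
        yield current_batch
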
relik/reader/data/relik_reader_data_utils.py
DELETED
@@ -1,51 +0,0 @@
-from typing import List
-
-import numpy as np
-import torch
-
-
-def flatten(lsts: List[list]) -> list:
-    acc_lst = list()
-    for lst in lsts:
-        acc_lst.extend(lst)
-    return acc_lst
-
-
-def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
-    return torch.nn.utils.rnn.pad_sequence(
-        tensors, batch_first=True, padding_value=padding_value
-    )
-
-
-def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
-    x = max([t.shape[0] for t in tensors])
-    y = max([t.shape[1] for t in tensors])
-    out_matrix = torch.zeros((len(tensors), x, y))
-    out_matrix += padding_value
-    for i, tensor in enumerate(tensors):
-        out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
-    return out_matrix
-
-
-def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
-    x = max([t.shape[0] for t in tensors])
-    y = max([t.shape[1] for t in tensors])
-    rest = tensors[0].shape[2]
-    out_matrix = torch.zeros((len(tensors), x, y, rest))
-    out_matrix += padding_value
-    for i, tensor in enumerate(tensors):
-        out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
-    return out_matrix
-
-
-def chunks(lst: list, chunk_size: int) -> List[list]:
-    chunks_acc = list()
-    for i in range(0, len(lst), chunk_size):
-        chunks_acc.append(lst[i : i + chunk_size])
-    return chunks_acc
-
-
-def add_noise_to_value(value: int, noise_param: float):
-    noise_value = value * noise_param
-    noise = np.random.uniform(-noise_value, noise_value)
-    return max(1, value + noise)

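A short usage sketch of the padding and chunking helpers above; tensor shapes are made up for illustration, and it assumes the module is still importable at its old path.

import torch

from relik.reader.data.relik_reader_data_utils import batchify, batchify_matrices, chunks

sequences = [torch.ones(3, dtype=torch.long), torch.ones(5, dtype=torch.long)]
padded = batchify(sequences, padding_value=0)  # shape (2, 5); first row right-padded with 0
matrices = [torch.ones(2, 3), torch.ones(4, 2)]
padded_mats = batchify_matrices(matrices, padding_value=0)  # shape (2, 4, 3)
print(chunks(list(range(10)), chunk_size=4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
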
relik/reader/data/relik_reader_sample.py
DELETED
@@ -1,49 +0,0 @@
-import json
-from typing import Iterable
-
-
-class RelikReaderSample:
-    def __init__(self, **kwargs):
-        super().__setattr__("_d", {})
-        self._d = kwargs
-
-    def __getattribute__(self, item):
-        return super(RelikReaderSample, self).__getattribute__(item)
-
-    def __getattr__(self, item):
-        if item.startswith("__") and item.endswith("__"):
-            # this is likely some python library-specific variable (such as __deepcopy__ for copy)
-            # better follow standard behavior here
-            raise AttributeError(item)
-        elif item in self._d:
-            return self._d[item]
-        else:
-            return None
-
-    def __setattr__(self, key, value):
-        if key in self._d:
-            self._d[key] = value
-        else:
-            super().__setattr__(key, value)
-
-    def to_jsons(self) -> str:
-        if "predicted_window_labels" in self._d:
-            new_obj = {
-                k: v
-                for k, v in self._d.items()
-                if k != "predicted_window_labels" and k != "span_title_probabilities"
-            }
-            new_obj["predicted_window_labels"] = [
-                [ss, se, pred_title]
-                for (ss, se), pred_title in self.predicted_window_labels_chars
-            ]
-        else:
-            return json.dumps(self._d)
-
-
-def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
-    with open(path) as f:
-        for line in f:
-            jsonl_line = json.loads(line.strip())
-            relik_reader_sample = RelikReaderSample(**jsonl_line)
-            yield relik_reader_sample

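A usage sketch for the sample wrapper above: attribute access falls back to the wrapped dict (missing keys read as None), and samples are streamed one JSON object per line. The file name below is hypothetical.

import json

from relik.reader.data.relik_reader_sample import RelikReaderSample, load_relik_reader_samples

sample = RelikReaderSample(doc_id=0, tokens=["Rome", "is", "in", "Italy"])
print(sample.tokens)      # ['Rome', 'is', 'in', 'Italy']
print(sample.candidates)  # None, because the key was never set

# Write one sample to a JSONL file and stream it back.
with open("samples.jsonl", "w") as f:
    f.write(json.dumps({"doc_id": 1, "tokens": ["Hello", "world"]}) + "\n")
for s in load_relik_reader_samples("samples.jsonl"):
    print(s.doc_id, s.tokens)
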
relik/reader/lightning_modules/__init__.py
DELETED
File without changes

relik/reader/lightning_modules/relik_reader_pl_module.py
DELETED
@@ -1,50 +0,0 @@
-from typing import Any, Optional
-
-import lightning
-from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
-
-from relik.reader.relik_reader_core import RelikReaderCoreModel
-
-
-class RelikReaderPLModule(lightning.LightningModule):
-    def __init__(
-        self,
-        cfg: dict,
-        transformer_model: str,
-        additional_special_symbols: int,
-        num_layers: Optional[int] = None,
-        activation: str = "gelu",
-        linears_hidden_size: Optional[int] = 512,
-        use_last_k_layers: int = 1,
-        training: bool = False,
-        *args: Any,
-        **kwargs: Any
-    ):
-        super().__init__(*args, **kwargs)
-        self.save_hyperparameters()
-        self.relik_reader_core_model = RelikReaderCoreModel(
-            transformer_model,
-            additional_special_symbols,
-            num_layers,
-            activation,
-            linears_hidden_size,
-            use_last_k_layers,
-            training=training,
-        )
-        self.optimizer_factory = None
-
-    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
-        relik_output = self.relik_reader_core_model(**batch)
-        self.log("train-loss", relik_output["loss"])
-        return relik_output["loss"]
-
-    def validation_step(
-        self, batch: dict, *args: Any, **kwargs: Any
-    ) -> Optional[STEP_OUTPUT]:
-        return
-
-    def set_optimizer_factory(self, optimizer_factory) -> None:
-        self.optimizer_factory = optimizer_factory
-
-    def configure_optimizers(self) -> OptimizerLRScheduler:
-        return self.optimizer_factory(self.relik_reader_core_model)

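The Lightning module above defers optimizer construction to an injected factory. A hedged sketch of how a training script might wire it up; the AdamW settings are illustrative and not taken from the deleted training configuration.

import torch

from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule

pl_module = RelikReaderPLModule(
    cfg={},
    transformer_model="microsoft/deberta-v3-base",
    additional_special_symbols=101,
    training=True,
)

# The factory receives the wrapped core model and returns the optimizer
# (or an optimizer/scheduler structure, per the Lightning contract).
def optimizer_factory(model: torch.nn.Module):
    return torch.optim.AdamW(model.parameters(), lr=1e-5)

pl_module.set_optimizer_factory(optimizer_factory)
# trainer = lightning.Trainer(...); trainer.fit(pl_module, datamodule=...)
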
relik/reader/lightning_modules/relik_reader_re_pl_module.py
DELETED
@@ -1,54 +0,0 @@
-from typing import Any, Optional
-
-import lightning
-from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
-
-from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
-
-
-class RelikReaderREPLModule(lightning.LightningModule):
-    def __init__(
-        self,
-        cfg: dict,
-        transformer_model: str,
-        additional_special_symbols: int,
-        num_layers: Optional[int] = None,
-        activation: str = "gelu",
-        linears_hidden_size: Optional[int] = 512,
-        use_last_k_layers: int = 1,
-        training: bool = False,
-        *args: Any,
-        **kwargs: Any
-    ):
-        super().__init__(*args, **kwargs)
-        self.save_hyperparameters()
-
-        self.relik_reader_re_model = RelikReaderForTripletExtraction(
-            transformer_model,
-            additional_special_symbols,
-            num_layers,
-            activation,
-            linears_hidden_size,
-            use_last_k_layers,
-            training=training,
-        )
-        self.optimizer_factory = None
-
-    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
-        relik_output = self.relik_reader_re_model(**batch)
-        self.log("train-loss", relik_output["loss"])
-        self.log("train-start_loss", relik_output["ned_start_loss"])
-        self.log("train-end_loss", relik_output["ned_end_loss"])
-        self.log("train-relation_loss", relik_output["re_loss"])
-        return relik_output["loss"]
-
-    def validation_step(
-        self, batch: dict, *args: Any, **kwargs: Any
-    ) -> Optional[STEP_OUTPUT]:
-        return
-
-    def set_optimizer_factory(self, optimizer_factory) -> None:
-        self.optimizer_factory = optimizer_factory
-
-    def configure_optimizers(self) -> OptimizerLRScheduler:
-        return self.optimizer_factory(self.relik_reader_re_model)

relik/reader/pytorch_modules/__init__.py
DELETED
File without changes

relik/reader/pytorch_modules/base.py
DELETED
@@ -1,248 +0,0 @@
-import logging
-import os
-from pathlib import Path
-from typing import Any, Dict, List
-
-import torch
-import transformers as tr
-from torch.utils.data import IterableDataset
-from transformers import AutoConfig
-
-from relik.common.log import get_console_logger, get_logger
-from relik.common.utils import get_callable_from_string
-from relik.reader.pytorch_modules.hf.modeling_relik import (
-    RelikReaderConfig,
-    RelikReaderSample,
-)
-
-console_logger = get_console_logger()
-logger = get_logger(__name__, level=logging.INFO)
-
-
-class RelikReaderBase(torch.nn.Module):
-    default_reader_class: str | None = None
-    default_data_class: str | None = None
-
-    def __init__(
-        self,
-        transformer_model: str | tr.PreTrainedModel | None = None,
-        additional_special_symbols: int = 0,
-        num_layers: int | None = None,
-        activation: str = "gelu",
-        linears_hidden_size: int | None = 512,
-        use_last_k_layers: int = 1,
-        training: bool = False,
-        device: str | torch.device | None = None,
-        precision: int = 32,
-        tokenizer: str | tr.PreTrainedTokenizer | None = None,
-        dataset: IterableDataset | str | None = None,
-        default_reader_class: tr.PreTrainedModel | str | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-
-        self.default_reader_class = default_reader_class or self.default_reader_class
-
-        if self.default_reader_class is None:
-            raise ValueError("You must specify a default reader class.")
-
-        # get the callable for the default reader class
-        self.default_reader_class: tr.PreTrainedModel = get_callable_from_string(
-            self.default_reader_class
-        )
-
-        if isinstance(transformer_model, str):
-            config = AutoConfig.from_pretrained(
-                transformer_model, trust_remote_code=True
-            )
-            if "relik-reader" in config.model_type:
-                transformer_model = self.default_reader_class.from_pretrained(
-                    transformer_model, **kwargs
-                )
-            else:
-                reader_config = RelikReaderConfig(
-                    transformer_model=transformer_model,
-                    additional_special_symbols=additional_special_symbols,
-                    num_layers=num_layers,
-                    activation=activation,
-                    linears_hidden_size=linears_hidden_size,
-                    use_last_k_layers=use_last_k_layers,
-                    training=training,
-                )
-                transformer_model = self.default_reader_class(reader_config)
-
-        self.relik_reader_model = transformer_model
-        self.relik_reader_model_config = self.relik_reader_model.config
-
-        # get the tokenizer
-        self._tokenizer = tokenizer
-
-        # and instantiate the dataset class
-        self.dataset: IterableDataset | None = dataset
-
-        # move the model to the device
-        self.to(device or torch.device("cpu"))
-
-        # set the precision
-        self.precision = precision
-
-    def forward(self, **kwargs) -> Dict[str, Any]:
-        return self.relik_reader_model(**kwargs)
-
-    def _read(self, *args, **kwargs) -> Any:
-        raise NotImplementedError
-
-    @torch.no_grad()
-    @torch.inference_mode()
-    def read(
-        self,
-        text: List[str] | List[List[str]] | None = None,
-        samples: List[RelikReaderSample] | None = None,
-        input_ids: torch.Tensor | None = None,
-        attention_mask: torch.Tensor | None = None,
-        token_type_ids: torch.Tensor | None = None,
-        prediction_mask: torch.Tensor | None = None,
-        special_symbols_mask: torch.Tensor | None = None,
-        candidates: List[List[str]] | None = None,
-        max_length: int = 1000,
-        max_batch_size: int = 128,
-        token_batch_size: int = 2048,
-        precision: int | str | None = None,
-        progress_bar: bool = False,
-        *args,
-        **kwargs,
-    ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
-        """
-        Reads the given text.
-
-        Args:
-            text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`):
-                The text to read in tokens. If a list of list of tokens is provided, each
-                inner list is considered a sentence.
-            samples (:obj:`List[RelikReaderSample]`, `optional`):
-                The samples to read. If provided, `text` and `candidates` are ignored.
-            input_ids (:obj:`torch.Tensor`, `optional`):
-                The input ids of the text.
-            attention_mask (:obj:`torch.Tensor`, `optional`):
-                The attention mask of the text.
-            token_type_ids (:obj:`torch.Tensor`, `optional`):
-                The token type ids of the text.
-            prediction_mask (:obj:`torch.Tensor`, `optional`):
-                The prediction mask of the text.
-            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
-                The special symbols mask of the text.
-            candidates (:obj:`List[List[str]]`, `optional`):
-                The candidates of the text.
-            max_length (:obj:`int`, `optional`, defaults to 1024):
-                The maximum length of the text.
-            max_batch_size (:obj:`int`, `optional`, defaults to 128):
-                The maximum batch size.
-            token_batch_size (:obj:`int`, `optional`):
-                The maximum number of tokens per batch.
-            precision (:obj:`int` or :obj:`str`, `optional`):
-                The precision to use. If not provided, the default is 32 bit.
-            progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Whether to show a progress bar.
-
-        Returns:
-            The predicted labels for each sample.
-        """
-        if text is None and input_ids is None and samples is None:
-            raise ValueError(
-                "Either `text` or `input_ids` or `samples` must be provided."
-            )
-        if (input_ids is None and samples is None) and (
-            text is None or candidates is None
-        ):
-            raise ValueError(
-                "`text` and `candidates` must be provided to return the predictions when "
-                "`input_ids` and `samples` is not provided."
-            )
-        if text is not None and samples is None:
-            if len(text) != len(candidates):
-                raise ValueError("`text` and `candidates` must have the same length.")
-            if isinstance(text[0], str):  # change to list of text
-                text = [text]
-                candidates = [candidates]
-
-            samples = [
-                RelikReaderSample(tokens=t, candidates=c)
-                for t, c in zip(text, candidates)
-            ]
-
-        return self._read(
-            samples,
-            input_ids,
-            attention_mask,
-            token_type_ids,
-            prediction_mask,
-            special_symbols_mask,
-            max_length,
-            max_batch_size,
-            token_batch_size,
-            precision or self.precision,
-            progress_bar,
-            *args,
-            **kwargs,
-        )
-
-    @property
-    def device(self) -> torch.device:
-        """
-        The device of the model.
-        """
-        return next(self.parameters()).device
-
-    @property
-    def tokenizer(self) -> tr.PreTrainedTokenizer:
-        """
-        The tokenizer.
-        """
-        if self._tokenizer:
-            return self._tokenizer
-
-        self._tokenizer = tr.AutoTokenizer.from_pretrained(
-            self.relik_reader_model.config.name_or_path
-        )
-        return self._tokenizer
-
-    def save_pretrained(
-        self,
-        output_dir: str | os.PathLike,
-        model_name: str | None = None,
-        push_to_hub: bool = False,
-        **kwargs,
-    ) -> None:
-        """
-        Saves the model to the given path.
-
-        Args:
-            output_dir (`str` or :obj:`os.PathLike`):
-                The path to save the model to.
-            model_name (`str`, `optional`):
-                The name of the model. If not provided, the model will be saved as
-                `default_reader_class.__name__`.
-            push_to_hub (`bool`, `optional`, defaults to `False`):
-                Whether to push the model to the HuggingFace Hub.
-            **kwargs:
-                Additional keyword arguments to pass to the `save_pretrained` method
-        """
-        # create the output directory
-        output_dir = Path(output_dir)
-        output_dir.mkdir(parents=True, exist_ok=True)
-
-        model_name = model_name or self.default_reader_class.__name__
-
-        logger.info(f"Saving reader to {output_dir / model_name}")
-
-        # save the model
-        self.relik_reader_model.register_for_auto_class()
-        self.relik_reader_model.save_pretrained(
-            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
-        )
-
-        if self.tokenizer:
-            logger.info("Saving also the tokenizer")
-            self.tokenizer.save_pretrained(
-                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
-            )

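Only `read()` and its argument handling come from the base class above; `_read()` is abstract here. The sketch below shows the expected call pattern under the assumption that a concrete subclass implements `_read()`; the `reader` object and checkpoint are placeholders, not names from this repository.

# reader = <concrete RelikReaderBase subclass>("path-or-hub-id-of-a-relik-reader-checkpoint", device="cpu")

tokens = ["Rome", "is", "the", "capital", "of", "Italy"]
candidates = ["Rome", "Italy", "Capital city"]

# text/candidates pairs are zipped into RelikReaderSample objects and routed to _read().
# predictions = reader.read(
#     text=[tokens],            # one token list per sentence
#     candidates=[candidates],  # parallel candidate list per sentence
#     max_length=1000,
#     token_batch_size=2048,
#     progress_bar=False,
# )
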
relik/reader/pytorch_modules/hf/__init__.py
DELETED
@@ -1,2 +0,0 @@
-from .configuration_relik import RelikReaderConfig
-from .modeling_relik import RelikReaderREModel

|
|
relik/reader/pytorch_modules/hf/configuration_relik.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
from typing import Optional
|
2 |
-
|
3 |
-
from transformers import AutoConfig
|
4 |
-
from transformers.configuration_utils import PretrainedConfig
|
5 |
-
|
6 |
-
|
7 |
-
class RelikReaderConfig(PretrainedConfig):
|
8 |
-
model_type = "relik-reader"
|
9 |
-
|
10 |
-
def __init__(
|
11 |
-
self,
|
12 |
-
transformer_model: str = "microsoft/deberta-v3-base",
|
13 |
-
additional_special_symbols: int = 101,
|
14 |
-
num_layers: Optional[int] = None,
|
15 |
-
activation: str = "gelu",
|
16 |
-
linears_hidden_size: Optional[int] = 512,
|
17 |
-
use_last_k_layers: int = 1,
|
18 |
-
training: bool = False,
|
19 |
-
default_reader_class: Optional[str] = None,
|
20 |
-
**kwargs
|
21 |
-
) -> None:
|
22 |
-
self.transformer_model = transformer_model
|
23 |
-
self.additional_special_symbols = additional_special_symbols
|
24 |
-
self.num_layers = num_layers
|
25 |
-
self.activation = activation
|
26 |
-
self.linears_hidden_size = linears_hidden_size
|
27 |
-
self.use_last_k_layers = use_last_k_layers
|
28 |
-
self.training = training
|
29 |
-
self.default_reader_class = default_reader_class
|
30 |
-
super().__init__(**kwargs)
|
31 |
-
|
32 |
-
|
33 |
-
AutoConfig.register("relik-reader", RelikReaderConfig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
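A small sketch of how the config class above round-trips through `AutoConfig`, which works because of the `AutoConfig.register("relik-reader", ...)` call at the bottom of the module; the output directory name is illustrative.

from transformers import AutoConfig

from relik.reader.pytorch_modules.hf.configuration_relik import RelikReaderConfig

# Build a config programmatically; the values below are just the defaults made explicit.
config = RelikReaderConfig(
    transformer_model="microsoft/deberta-v3-base",
    additional_special_symbols=101,
    linears_hidden_size=512,
)
config.save_pretrained("relik-reader-config")

# Since model_type "relik-reader" is registered, AutoConfig resolves the saved
# config.json back to RelikReaderConfig.
reloaded = AutoConfig.from_pretrained("relik-reader-config")
print(type(reloaded).__name__, reloaded.transformer_model)
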
relik/reader/pytorch_modules/hf/modeling_relik.py
DELETED
@@ -1,981 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, Optional
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from transformers import AutoModel, PreTrainedModel
|
5 |
-
from transformers.activations import ClippedGELUActivation, GELUActivation
|
6 |
-
from transformers.configuration_utils import PretrainedConfig
|
7 |
-
from transformers.modeling_utils import PoolerEndLogits
|
8 |
-
|
9 |
-
from .configuration_relik import RelikReaderConfig
|
10 |
-
|
11 |
-
|
12 |
-
class RelikReaderSample:
|
13 |
-
def __init__(self, **kwargs):
|
14 |
-
super().__setattr__("_d", {})
|
15 |
-
self._d = kwargs
|
16 |
-
|
17 |
-
def __getattribute__(self, item):
|
18 |
-
return super(RelikReaderSample, self).__getattribute__(item)
|
19 |
-
|
20 |
-
def __getattr__(self, item):
|
21 |
-
if item.startswith("__") and item.endswith("__"):
|
22 |
-
# this is likely some python library-specific variable (such as __deepcopy__ for copy)
|
23 |
-
# better follow standard behavior here
|
24 |
-
raise AttributeError(item)
|
25 |
-
elif item in self._d:
|
26 |
-
return self._d[item]
|
27 |
-
else:
|
28 |
-
return None
|
29 |
-
|
30 |
-
def __setattr__(self, key, value):
|
31 |
-
if key in self._d:
|
32 |
-
self._d[key] = value
|
33 |
-
else:
|
34 |
-
super().__setattr__(key, value)
|
35 |
-
|
36 |
-
|
37 |
-
activation2functions = {
|
38 |
-
"relu": torch.nn.ReLU(),
|
39 |
-
"gelu": GELUActivation(),
|
40 |
-
"gelu_10": ClippedGELUActivation(-10, 10),
|
41 |
-
}
|
42 |
-
|
43 |
-
|
44 |
-
class PoolerEndLogitsBi(PoolerEndLogits):
|
45 |
-
def __init__(self, config: PretrainedConfig):
|
46 |
-
super().__init__(config)
|
47 |
-
self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
|
48 |
-
|
49 |
-
def forward(
|
50 |
-
self,
|
51 |
-
hidden_states: torch.FloatTensor,
|
52 |
-
start_states: Optional[torch.FloatTensor] = None,
|
53 |
-
start_positions: Optional[torch.LongTensor] = None,
|
54 |
-
p_mask: Optional[torch.FloatTensor] = None,
|
55 |
-
) -> torch.FloatTensor:
|
56 |
-
if p_mask is not None:
|
57 |
-
p_mask = p_mask.unsqueeze(-1)
|
58 |
-
logits = super().forward(
|
59 |
-
hidden_states,
|
60 |
-
start_states,
|
61 |
-
start_positions,
|
62 |
-
p_mask,
|
63 |
-
)
|
64 |
-
return logits
|
65 |
-
|
66 |
-
|
67 |
-
class RelikReaderSpanModel(PreTrainedModel):
|
68 |
-
config_class = RelikReaderConfig
|
69 |
-
|
70 |
-
def __init__(self, config: RelikReaderConfig, *args, **kwargs):
|
71 |
-
super().__init__(config)
|
72 |
-
# Transformer model declaration
|
73 |
-
self.config = config
|
74 |
-
self.transformer_model = (
|
75 |
-
AutoModel.from_pretrained(self.config.transformer_model)
|
76 |
-
if self.config.num_layers is None
|
77 |
-
else AutoModel.from_pretrained(
|
78 |
-
self.config.transformer_model, num_hidden_layers=self.config.num_layers
|
79 |
-
)
|
80 |
-
)
|
81 |
-
self.transformer_model.resize_token_embeddings(
|
82 |
-
self.transformer_model.config.vocab_size
|
83 |
-
+ self.config.additional_special_symbols
|
84 |
-
)
|
85 |
-
|
86 |
-
self.activation = self.config.activation
|
87 |
-
self.linears_hidden_size = self.config.linears_hidden_size
|
88 |
-
self.use_last_k_layers = self.config.use_last_k_layers
|
89 |
-
|
90 |
-
# named entity detection layers
|
91 |
-
self.ned_start_classifier = self._get_projection_layer(
|
92 |
-
self.activation, last_hidden=2, layer_norm=False
|
93 |
-
)
|
94 |
-
self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
|
95 |
-
|
96 |
-
# END entity disambiguation layer
|
97 |
-
self.ed_start_projector = self._get_projection_layer(self.activation)
|
98 |
-
self.ed_end_projector = self._get_projection_layer(self.activation)
|
99 |
-
|
100 |
-
self.training = self.config.training
|
101 |
-
|
102 |
-
# criterion
|
103 |
-
self.criterion = torch.nn.CrossEntropyLoss()
|
104 |
-
|
105 |
-
def _get_projection_layer(
|
106 |
-
self,
|
107 |
-
activation: str,
|
108 |
-
last_hidden: Optional[int] = None,
|
109 |
-
input_hidden=None,
|
110 |
-
layer_norm: bool = True,
|
111 |
-
) -> torch.nn.Sequential:
|
112 |
-
head_components = [
|
113 |
-
torch.nn.Dropout(0.1),
|
114 |
-
torch.nn.Linear(
|
115 |
-
self.transformer_model.config.hidden_size * self.use_last_k_layers
|
116 |
-
if input_hidden is None
|
117 |
-
else input_hidden,
|
118 |
-
self.linears_hidden_size,
|
119 |
-
),
|
120 |
-
activation2functions[activation],
|
121 |
-
torch.nn.Dropout(0.1),
|
122 |
-
torch.nn.Linear(
|
123 |
-
self.linears_hidden_size,
|
124 |
-
self.linears_hidden_size if last_hidden is None else last_hidden,
|
125 |
-
),
|
126 |
-
]
|
127 |
-
|
128 |
-
if layer_norm:
|
129 |
-
head_components.append(
|
130 |
-
torch.nn.LayerNorm(
|
131 |
-
self.linears_hidden_size if last_hidden is None else last_hidden,
|
132 |
-
self.transformer_model.config.layer_norm_eps,
|
133 |
-
)
|
134 |
-
)
|
135 |
-
|
136 |
-
return torch.nn.Sequential(*head_components)
|
137 |
-
|
138 |
-
def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
139 |
-
mask = mask.unsqueeze(-1)
|
140 |
-
if next(self.parameters()).dtype == torch.float16:
|
141 |
-
logits = logits * (1 - mask) - 65500 * mask
|
142 |
-
else:
|
143 |
-
logits = logits * (1 - mask) - 1e30 * mask
|
144 |
-
return logits
|
145 |
-
|
146 |
-
def _get_model_features(
|
147 |
-
self,
|
148 |
-
input_ids: torch.Tensor,
|
149 |
-
attention_mask: torch.Tensor,
|
150 |
-
token_type_ids: Optional[torch.Tensor],
|
151 |
-
):
|
152 |
-
model_input = {
|
153 |
-
"input_ids": input_ids,
|
154 |
-
"attention_mask": attention_mask,
|
155 |
-
"output_hidden_states": self.use_last_k_layers > 1,
|
156 |
-
}
|
157 |
-
|
158 |
-
if token_type_ids is not None:
|
159 |
-
model_input["token_type_ids"] = token_type_ids
|
160 |
-
|
161 |
-
model_output = self.transformer_model(**model_input)
|
162 |
-
|
163 |
-
if self.use_last_k_layers > 1:
|
164 |
-
model_features = torch.cat(
|
165 |
-
model_output[1][-self.use_last_k_layers :], dim=-1
|
166 |
-
)
|
167 |
-
else:
|
168 |
-
model_features = model_output[0]
|
169 |
-
|
170 |
-
return model_features
|
171 |
-
|
172 |
-
def compute_ned_end_logits(
|
173 |
-
self,
|
174 |
-
start_predictions,
|
175 |
-
start_labels,
|
176 |
-
model_features,
|
177 |
-
prediction_mask,
|
178 |
-
batch_size,
|
179 |
-
) -> Optional[torch.Tensor]:
|
180 |
-
# todo: maybe when constraining on the spans,
|
181 |
-
# we should not use a prediction_mask for the end tokens.
|
182 |
-
# at least we should not during training imo
|
183 |
-
start_positions = start_labels if self.training else start_predictions
|
184 |
-
start_positions_indices = (
|
185 |
-
torch.arange(start_positions.size(1), device=start_positions.device)
|
186 |
-
.unsqueeze(0)
|
187 |
-
.expand(batch_size, -1)[start_positions > 0]
|
188 |
-
).to(start_positions.device)
|
189 |
-
|
190 |
-
if len(start_positions_indices) > 0:
|
191 |
-
expanded_features = torch.cat(
|
192 |
-
[
|
193 |
-
model_features[i].unsqueeze(0).expand(x, -1, -1)
|
194 |
-
for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
|
195 |
-
if x > 0
|
196 |
-
],
|
197 |
-
dim=0,
|
198 |
-
).to(start_positions_indices.device)
|
199 |
-
|
200 |
-
expanded_prediction_mask = torch.cat(
|
201 |
-
[
|
202 |
-
prediction_mask[i].unsqueeze(0).expand(x, -1)
|
203 |
-
for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
|
204 |
-
if x > 0
|
205 |
-
],
|
206 |
-
dim=0,
|
207 |
-
).to(expanded_features.device)
|
208 |
-
|
209 |
-
end_logits = self.ned_end_classifier(
|
210 |
-
hidden_states=expanded_features,
|
211 |
-
start_positions=start_positions_indices,
|
212 |
-
p_mask=expanded_prediction_mask,
|
213 |
-
)
|
214 |
-
|
215 |
-
return end_logits
|
216 |
-
|
217 |
-
return None
|
218 |
-
|
219 |
-
def compute_classification_logits(
|
220 |
-
self,
|
221 |
-
model_features,
|
222 |
-
special_symbols_mask,
|
223 |
-
prediction_mask,
|
224 |
-
batch_size,
|
225 |
-
start_positions=None,
|
226 |
-
end_positions=None,
|
227 |
-
) -> torch.Tensor:
|
228 |
-
if start_positions is None or end_positions is None:
|
229 |
-
start_positions = torch.zeros_like(prediction_mask)
|
230 |
-
end_positions = torch.zeros_like(prediction_mask)
|
231 |
-
|
232 |
-
model_start_features = self.ed_start_projector(model_features)
|
233 |
-
model_end_features = self.ed_end_projector(model_features)
|
234 |
-
model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
|
235 |
-
|
236 |
-
model_ed_features = torch.cat(
|
237 |
-
[model_start_features, model_end_features], dim=-1
|
238 |
-
)
|
239 |
-
|
240 |
-
# computing ed features
|
241 |
-
classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
|
242 |
-
special_symbols_representation = model_ed_features[special_symbols_mask].view(
|
243 |
-
batch_size, classes_representations, -1
|
244 |
-
)
|
245 |
-
|
246 |
-
logits = torch.bmm(
|
247 |
-
model_ed_features,
|
248 |
-
torch.permute(special_symbols_representation, (0, 2, 1)),
|
249 |
-
)
|
250 |
-
|
251 |
-
logits = self._mask_logits(logits, prediction_mask)
|
252 |
-
|
253 |
-
return logits
|
254 |
-
|
255 |
-
def forward(
|
256 |
-
self,
|
257 |
-
input_ids: torch.Tensor,
|
258 |
-
attention_mask: torch.Tensor,
|
259 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
260 |
-
prediction_mask: Optional[torch.Tensor] = None,
|
261 |
-
special_symbols_mask: Optional[torch.Tensor] = None,
|
262 |
-
start_labels: Optional[torch.Tensor] = None,
|
263 |
-
end_labels: Optional[torch.Tensor] = None,
|
264 |
-
use_predefined_spans: bool = False,
|
265 |
-
*args,
|
266 |
-
**kwargs,
|
267 |
-
) -> Dict[str, Any]:
|
268 |
-
batch_size, seq_len = input_ids.shape
|
269 |
-
|
270 |
-
model_features = self._get_model_features(
|
271 |
-
input_ids, attention_mask, token_type_ids
|
272 |
-
)
|
273 |
-
|
274 |
-
ned_start_labels = None
|
275 |
-
|
276 |
-
# named entity detection if required
|
277 |
-
if use_predefined_spans: # no need to compute spans
|
278 |
-
ned_start_logits, ned_start_probabilities, ned_start_predictions = (
|
279 |
-
None,
|
280 |
-
None,
|
281 |
-
torch.clone(start_labels)
|
282 |
-
if start_labels is not None
|
283 |
-
else torch.zeros_like(input_ids),
|
284 |
-
)
|
285 |
-
ned_end_logits, ned_end_probabilities, ned_end_predictions = (
|
286 |
-
None,
|
287 |
-
None,
|
288 |
-
torch.clone(end_labels)
|
289 |
-
if end_labels is not None
|
290 |
-
else torch.zeros_like(input_ids),
|
291 |
-
)
|
292 |
-
|
293 |
-
ned_start_predictions[ned_start_predictions > 0] = 1
|
294 |
-
ned_end_predictions[ned_end_predictions > 0] = 1
|
295 |
-
|
296 |
-
else: # compute spans
|
297 |
-
# start boundary prediction
|
298 |
-
ned_start_logits = self.ned_start_classifier(model_features)
|
299 |
-
ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
|
300 |
-
ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
|
301 |
-
ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
|
302 |
-
|
303 |
-
# end boundary prediction
|
304 |
-
ned_start_labels = (
|
305 |
-
torch.zeros_like(start_labels) if start_labels is not None else None
|
306 |
-
)
|
307 |
-
|
308 |
-
if ned_start_labels is not None:
|
309 |
-
ned_start_labels[start_labels == -100] = -100
|
310 |
-
ned_start_labels[start_labels > 0] = 1
|
311 |
-
|
312 |
-
ned_end_logits = self.compute_ned_end_logits(
|
313 |
-
ned_start_predictions,
|
314 |
-
ned_start_labels,
|
315 |
-
model_features,
|
316 |
-
prediction_mask,
|
317 |
-
batch_size,
|
318 |
-
)
|
319 |
-
|
320 |
-
if ned_end_logits is not None:
|
321 |
-
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
|
322 |
-
ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
|
323 |
-
else:
|
324 |
-
ned_end_logits, ned_end_probabilities = None, None
|
325 |
-
ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
|
326 |
-
|
327 |
-
# flattening end predictions
|
328 |
-
# (flattening can happen only if the
|
329 |
-
# end boundaries were not predicted using the gold labels)
|
330 |
-
if not self.training:
|
331 |
-
flattened_end_predictions = torch.clone(ned_start_predictions)
|
332 |
-
flattened_end_predictions[flattened_end_predictions > 0] = 0
|
333 |
-
|
334 |
-
batch_start_predictions = list()
|
335 |
-
for elem_idx in range(batch_size):
|
336 |
-
batch_start_predictions.append(
|
337 |
-
torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
|
338 |
-
)
|
339 |
-
|
340 |
-
# check that the total number of start predictions
|
341 |
-
# is equal to the end predictions
|
342 |
-
total_start_predictions = sum(map(len, batch_start_predictions))
|
343 |
-
total_end_predictions = len(ned_end_predictions)
|
344 |
-
assert (
|
345 |
-
total_start_predictions == 0
|
346 |
-
or total_start_predictions == total_end_predictions
|
347 |
-
), (
|
348 |
-
f"Total number of start predictions = {total_start_predictions}. "
|
349 |
-
f"Total number of end predictions = {total_end_predictions}"
|
350 |
-
)
|
351 |
-
|
352 |
-
curr_end_pred_num = 0
|
353 |
-
for elem_idx, bsp in enumerate(batch_start_predictions):
|
354 |
-
for sp in bsp:
|
355 |
-
ep = ned_end_predictions[curr_end_pred_num].item()
|
356 |
-
if ep < sp:
|
357 |
-
ep = sp
|
358 |
-
|
359 |
-
# if we already set this span throw it (no overlap)
|
360 |
-
if flattened_end_predictions[elem_idx, ep] == 1:
|
361 |
-
ned_start_predictions[elem_idx, sp] = 0
|
362 |
-
else:
|
363 |
-
flattened_end_predictions[elem_idx, ep] = 1
|
364 |
-
|
365 |
-
curr_end_pred_num += 1
|
366 |
-
|
367 |
-
ned_end_predictions = flattened_end_predictions
|
368 |
-
|
369 |
-
start_position, end_position = (
|
370 |
-
(start_labels, end_labels)
|
371 |
-
if self.training
|
372 |
-
else (ned_start_predictions, ned_end_predictions)
|
373 |
-
)
|
374 |
-
|
375 |
-
# Entity disambiguation
|
376 |
-
ed_logits = self.compute_classification_logits(
|
377 |
-
model_features,
|
378 |
-
special_symbols_mask,
|
379 |
-
prediction_mask,
|
380 |
-
batch_size,
|
381 |
-
start_position,
|
382 |
-
end_position,
|
383 |
-
)
|
384 |
-
ed_probabilities = torch.softmax(ed_logits, dim=-1)
|
385 |
-
ed_predictions = torch.argmax(ed_probabilities, dim=-1)
|
386 |
-
|
387 |
-
# output build
|
388 |
-
output_dict = dict(
|
389 |
-
batch_size=batch_size,
|
390 |
-
ned_start_logits=ned_start_logits,
|
391 |
-
ned_start_probabilities=ned_start_probabilities,
|
392 |
-
ned_start_predictions=ned_start_predictions,
|
393 |
-
ned_end_logits=ned_end_logits,
|
394 |
-
ned_end_probabilities=ned_end_probabilities,
|
395 |
-
ned_end_predictions=ned_end_predictions,
|
396 |
-
ed_logits=ed_logits,
|
397 |
-
ed_probabilities=ed_probabilities,
|
398 |
-
ed_predictions=ed_predictions,
|
399 |
-
)
|
400 |
-
|
401 |
-
# compute loss if labels
|
402 |
-
if start_labels is not None and end_labels is not None and self.training:
|
403 |
-
# named entity detection loss
|
404 |
-
|
405 |
-
# start
|
406 |
-
if ned_start_logits is not None:
|
407 |
-
ned_start_loss = self.criterion(
|
408 |
-
ned_start_logits.view(-1, ned_start_logits.shape[-1]),
|
409 |
-
ned_start_labels.view(-1),
|
410 |
-
)
|
411 |
-
else:
|
412 |
-
ned_start_loss = 0
|
413 |
-
|
414 |
-
# end
|
415 |
-
if ned_end_logits is not None:
|
416 |
-
ned_end_labels = torch.zeros_like(end_labels)
|
417 |
-
ned_end_labels[end_labels == -100] = -100
|
418 |
-
ned_end_labels[end_labels > 0] = 1
|
419 |
-
|
420 |
-
ned_end_loss = self.criterion(
|
421 |
-
ned_end_logits,
|
422 |
-
(
|
423 |
-
torch.arange(
|
424 |
-
ned_end_labels.size(1), device=ned_end_labels.device
|
425 |
-
)
|
426 |
-
.unsqueeze(0)
|
427 |
-
.expand(batch_size, -1)[ned_end_labels > 0]
|
428 |
-
).to(ned_end_labels.device),
|
429 |
-
)
|
430 |
-
|
431 |
-
else:
|
432 |
-
ned_end_loss = 0
|
433 |
-
|
434 |
-
# entity disambiguation loss
|
435 |
-
start_labels[ned_start_labels != 1] = -100
|
436 |
-
ed_labels = torch.clone(start_labels)
|
437 |
-
ed_labels[end_labels > 0] = end_labels[end_labels > 0]
|
438 |
-
ed_loss = self.criterion(
|
439 |
-
ed_logits.view(-1, ed_logits.shape[-1]),
|
440 |
-
ed_labels.view(-1),
|
441 |
-
)
|
442 |
-
|
443 |
-
output_dict["ned_start_loss"] = ned_start_loss
|
444 |
-
output_dict["ned_end_loss"] = ned_end_loss
|
445 |
-
output_dict["ed_loss"] = ed_loss
|
446 |
-
|
447 |
-
output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
|
448 |
-
|
449 |
-
return output_dict
|
450 |
-
|
451 |
-
|
452 |
-
class RelikReaderREModel(PreTrainedModel):
|
453 |
-
config_class = RelikReaderConfig
|
454 |
-
|
455 |
-
def __init__(self, config, *args, **kwargs):
|
456 |
-
super().__init__(config)
|
457 |
-
# Transformer model declaration
|
458 |
-
# self.transformer_model_name = transformer_model
|
459 |
-
self.config = config
|
460 |
-
self.transformer_model = (
|
461 |
-
AutoModel.from_pretrained(config.transformer_model)
|
462 |
-
if config.num_layers is None
|
463 |
-
else AutoModel.from_pretrained(
|
464 |
-
config.transformer_model, num_hidden_layers=config.num_layers
|
465 |
-
)
|
466 |
-
)
|
467 |
-
self.transformer_model.resize_token_embeddings(
|
468 |
-
self.transformer_model.config.vocab_size + config.additional_special_symbols
|
469 |
-
)
|
470 |
-
|
471 |
-
# named entity detection layers
|
472 |
-
self.ned_start_classifier = self._get_projection_layer(
|
473 |
-
config.activation, last_hidden=2, layer_norm=False
|
474 |
-
)
|
475 |
-
|
476 |
-
self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
|
477 |
-
|
478 |
-
self.entity_type_loss = (
|
479 |
-
config.entity_type_loss if hasattr(config, "entity_type_loss") else False
|
480 |
-
)
|
481 |
-
self.relation_disambiguation_loss = (
|
482 |
-
config.relation_disambiguation_loss
|
483 |
-
if hasattr(config, "relation_disambiguation_loss")
|
484 |
-
else False
|
485 |
-
)
|
486 |
-
|
487 |
-
input_hidden_ents = 2 * self.transformer_model.config.hidden_size
|
488 |
-
|
489 |
-
self.re_subject_projector = self._get_projection_layer(
|
490 |
-
config.activation, input_hidden=input_hidden_ents
|
491 |
-
)
|
492 |
-
self.re_object_projector = self._get_projection_layer(
|
493 |
-
config.activation, input_hidden=input_hidden_ents
|
494 |
-
)
|
495 |
-
self.re_relation_projector = self._get_projection_layer(config.activation)
|
496 |
-
|
497 |
-
if self.entity_type_loss or self.relation_disambiguation_loss:
|
498 |
-
self.re_entities_projector = self._get_projection_layer(
|
499 |
-
config.activation,
|
500 |
-
input_hidden=2 * self.transformer_model.config.hidden_size,
|
501 |
-
)
|
502 |
-
self.re_definition_projector = self._get_projection_layer(
|
503 |
-
config.activation,
|
504 |
-
)
|
505 |
-
|
506 |
-
self.re_classifier = self._get_projection_layer(
|
507 |
-
config.activation,
|
508 |
-
input_hidden=config.linears_hidden_size,
|
509 |
-
last_hidden=2,
|
510 |
-
layer_norm=False,
|
511 |
-
)
|
512 |
-
|
513 |
-
if self.entity_type_loss or self.relation_disambiguation_loss:
|
514 |
-
self.re_ed_classifier = self._get_projection_layer(
|
515 |
-
config.activation,
|
516 |
-
input_hidden=config.linears_hidden_size,
|
517 |
-
last_hidden=2,
|
518 |
-
layer_norm=False,
|
519 |
-
)
|
520 |
-
|
521 |
-
self.training = config.training
|
522 |
-
|
523 |
-
# criterion
|
524 |
-
self.criterion = torch.nn.CrossEntropyLoss()
|
525 |
-
|
526 |
-
def _get_projection_layer(
|
527 |
-
self,
|
528 |
-
activation: str,
|
529 |
-
last_hidden: Optional[int] = None,
|
530 |
-
input_hidden=None,
|
531 |
-
layer_norm: bool = True,
|
532 |
-
) -> torch.nn.Sequential:
|
533 |
-
head_components = [
|
534 |
-
torch.nn.Dropout(0.1),
|
535 |
-
torch.nn.Linear(
|
536 |
-
self.transformer_model.config.hidden_size
|
537 |
-
* self.config.use_last_k_layers
|
538 |
-
if input_hidden is None
|
539 |
-
else input_hidden,
|
540 |
-
self.config.linears_hidden_size,
|
541 |
-
),
|
542 |
-
activation2functions[activation],
|
543 |
-
torch.nn.Dropout(0.1),
|
544 |
-
torch.nn.Linear(
|
545 |
-
self.config.linears_hidden_size,
|
546 |
-
self.config.linears_hidden_size if last_hidden is None else last_hidden,
|
547 |
-
),
|
548 |
-
]
|
549 |
-
|
550 |
-
if layer_norm:
|
551 |
-
head_components.append(
|
552 |
-
torch.nn.LayerNorm(
|
553 |
-
self.config.linears_hidden_size
|
554 |
-
if last_hidden is None
|
555 |
-
else last_hidden,
|
556 |
-
self.transformer_model.config.layer_norm_eps,
|
557 |
-
)
|
558 |
-
)
|
559 |
-
|
560 |
-
return torch.nn.Sequential(*head_components)
|
561 |
-
|
562 |
-
def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
563 |
-
mask = mask.unsqueeze(-1)
|
564 |
-
if next(self.parameters()).dtype == torch.float16:
|
565 |
-
logits = logits * (1 - mask) - 65500 * mask
|
566 |
-
else:
|
567 |
-
logits = logits * (1 - mask) - 1e30 * mask
|
568 |
-
return logits
|
569 |
-
|
570 |
-
def _get_model_features(
|
571 |
-
self,
|
572 |
-
input_ids: torch.Tensor,
|
573 |
-
attention_mask: torch.Tensor,
|
574 |
-
token_type_ids: Optional[torch.Tensor],
|
575 |
-
):
|
576 |
-
model_input = {
|
577 |
-
"input_ids": input_ids,
|
578 |
-
"attention_mask": attention_mask,
|
579 |
-
"output_hidden_states": self.config.use_last_k_layers > 1,
|
580 |
-
}
|
581 |
-
|
582 |
-
if token_type_ids is not None:
|
583 |
-
model_input["token_type_ids"] = token_type_ids
|
584 |
-
|
585 |
-
model_output = self.transformer_model(**model_input)
|
586 |
-
|
587 |
-
if self.config.use_last_k_layers > 1:
|
588 |
-
model_features = torch.cat(
|
589 |
-
model_output[1][-self.config.use_last_k_layers :], dim=-1
|
590 |
-
)
|
591 |
-
else:
|
592 |
-
model_features = model_output[0]
|
593 |
-
|
594 |
-
return model_features
|
595 |
-
|
596 |
-
def compute_ned_end_logits(
|
597 |
-
self,
|
598 |
-
start_predictions,
|
599 |
-
start_labels,
|
600 |
-
model_features,
|
601 |
-
prediction_mask,
|
602 |
-
batch_size,
|
603 |
-
) -> Optional[torch.Tensor]:
|
604 |
-
# todo: maybe when constraining on the spans,
|
605 |
-
# we should not use a prediction_mask for the end tokens.
|
606 |
-
# at least we should not during training imo
|
607 |
-
start_positions = start_labels if self.training else start_predictions
|
608 |
-
start_positions_indices = (
|
609 |
-
torch.arange(start_positions.size(1), device=start_positions.device)
|
610 |
-
.unsqueeze(0)
|
611 |
-
.expand(batch_size, -1)[start_positions > 0]
|
612 |
-
).to(start_positions.device)
|
613 |
-
|
614 |
-
if len(start_positions_indices) > 0:
|
615 |
-
expanded_features = torch.cat(
|
616 |
-
[
|
617 |
-
model_features[i].unsqueeze(0).expand(x, -1, -1)
|
618 |
-
for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
|
619 |
-
if x > 0
|
620 |
-
],
|
621 |
-
dim=0,
|
622 |
-
).to(start_positions_indices.device)
|
623 |
-
|
624 |
-
expanded_prediction_mask = torch.cat(
|
625 |
-
[
|
626 |
-
prediction_mask[i].unsqueeze(0).expand(x, -1)
|
627 |
-
for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
|
628 |
-
if x > 0
|
629 |
-
],
|
630 |
-
dim=0,
|
631 |
-
).to(expanded_features.device)
|
632 |
-
|
633 |
-
# mask all tokens before start_positions_indices ie, mask all tokens with
|
634 |
-
# indices < start_positions_indices with 1, ie. [range(x) for x in start_positions_indices]
|
635 |
-
expanded_prediction_mask = torch.stack(
|
636 |
-
[
|
637 |
-
torch.cat(
|
638 |
-
[
|
639 |
-
torch.ones(x, device=expanded_features.device),
|
640 |
-
expanded_prediction_mask[i, x:],
|
641 |
-
]
|
642 |
-
)
|
643 |
-
for i, x in enumerate(start_positions_indices)
|
644 |
-
if x > 0
|
645 |
-
],
|
646 |
-
dim=0,
|
647 |
-
).to(expanded_features.device)
|
648 |
-
|
649 |
-
end_logits = self.ned_end_classifier(
|
650 |
-
hidden_states=expanded_features,
|
651 |
-
start_positions=start_positions_indices,
|
652 |
-
p_mask=expanded_prediction_mask,
|
653 |
-
)
|
654 |
-
|
655 |
-
return end_logits
|
656 |
-
|
657 |
-
return None
|
658 |
-
|
659 |
-
def compute_relation_logits(
|
660 |
-
self,
|
661 |
-
model_entity_features,
|
662 |
-
special_symbols_features,
|
663 |
-
) -> torch.Tensor:
|
664 |
-
model_subject_features = self.re_subject_projector(model_entity_features)
|
665 |
-
model_object_features = self.re_object_projector(model_entity_features)
|
666 |
-
special_symbols_start_representation = self.re_relation_projector(
|
667 |
-
special_symbols_features
|
668 |
-
)
|
669 |
-
re_logits = torch.einsum(
|
670 |
-
"bse,bde,bfe->bsdfe",
|
671 |
-
model_subject_features,
|
672 |
-
model_object_features,
|
673 |
-
special_symbols_start_representation,
|
674 |
-
)
|
675 |
-
re_logits = self.re_classifier(re_logits)
|
676 |
-
|
677 |
-
return re_logits
|
678 |
-
|
679 |
-
def compute_entity_logits(
|
680 |
-
self,
|
681 |
-
model_entity_features,
|
682 |
-
special_symbols_features,
|
683 |
-
) -> torch.Tensor:
|
684 |
-
model_ed_features = self.re_entities_projector(model_entity_features)
|
685 |
-
        special_symbols_ed_representation = self.re_definition_projector(
            special_symbols_features
        )
        logits = torch.einsum(
            "bce,bde->bcde",
            model_ed_features,
            special_symbols_ed_representation,
        )
        logits = self.re_ed_classifier(logits)
        start_logits = self._mask_logits(
            logits,
            (model_entity_features == -100)
            .all(2)
            .long()
            .unsqueeze(2)
            .repeat(1, 1, torch.sum(model_entity_features, dim=1)[0].item()),
        )

        return logits

    def compute_loss(self, logits, labels, mask=None):
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1).long()
        if mask is not None:
            return self.criterion(logits[mask], labels[mask])
        return self.criterion(logits, labels)

    def compute_ned_end_loss(self, ned_end_logits, end_labels):
        if ned_end_logits is None:
            return 0
        ned_end_labels = torch.zeros_like(end_labels)
        ned_end_labels[end_labels == -100] = -100
        ned_end_labels[end_labels > 0] = 1
        return self.compute_loss(ned_end_logits, ned_end_labels)

    def compute_ned_type_loss(
        self,
        disambiguation_labels,
        re_ned_entities_logits,
        ned_type_logits,
        re_entities_logits,
        entity_types,
    ):
        if self.entity_type_loss and self.relation_disambiguation_loss:
            return self.compute_loss(disambiguation_labels, re_ned_entities_logits)
        if self.entity_type_loss:
            return self.compute_loss(
                disambiguation_labels[:, :, :entity_types], ned_type_logits
            )
        if self.relation_disambiguation_loss:
            return self.compute_loss(disambiguation_labels, re_entities_logits)
        return 0

    def compute_relation_loss(self, relation_labels, re_logits):
        return self.compute_loss(
            re_logits, relation_labels, relation_labels.view(-1) != -100
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size = input_ids.shape[0]

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection
        if is_prediction and start_labels is not None:
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.zeros_like(start_labels),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.zeros_like(end_labels),
            )

            ned_start_predictions[start_labels > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(
                ned_start_logits, prediction_mask
            )  # why?
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            # start_labels contain entity id at their position, we just need 1 for start of entity
            if ned_start_labels is not None:
                ned_start_labels[start_labels > 0] = 1

            # compute end logits only if there are any start predictions.
            # For each start prediction, n end predictions are made
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )
            # For each start prediction, n end predictions are made based on
            # binary classification ie. argmax at each position.
            ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
            ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
            if is_prediction or is_validation:
                end_preds_count = ned_end_predictions.sum(1)
                # If there are no end predictions for a start prediction, remove the start prediction
                ned_start_predictions[ned_start_predictions == 1] = (
                    end_preds_count != 0
                ).long()
                ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if (not is_prediction and not is_validation)
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

        # We can only predict relations if we have start and end predictions
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            model_subject_features = torch.cat(
                [
                    torch.repeat_interleave(
                        model_features[start_position > 0], ends_count, dim=0
                    ),  # start position features
                    torch.repeat_interleave(model_features, start_counts, dim=0)[
                        end_position > 0
                    ],  # end position features
                ],
                dim=-1,
            )
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)
            model_subject_features = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_subject_features, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            if is_validation or is_prediction:
                model_subject_features = model_subject_features[:, :30, :]

            # entity disambiguation. Here relation_disambiguation_loss would only be useful to
            # reduce the number of candidate relations for the next step, but currently unused.
            if self.entity_type_loss or self.relation_disambiguation_loss:
                (re_ned_entities_logits) = self.compute_entity_logits(
                    model_subject_features,
                    model_features[
                        special_symbols_mask | special_symbols_mask_entities
                    ].view(batch_size, -1, model_features.shape[-1]),
                )
                entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
                ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
                re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

                if self.entity_type_loss:
                    ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
                    ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
                    ned_type_predictions = ned_type_predictions.argmax(dim=-1)

                re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
                re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
            else:
                (
                    ned_type_logits,
                    ned_type_probabilities,
                    re_entities_logits,
                    re_entities_probabilities,
                ) = (None, None, None, None)
                ned_type_predictions, re_entities_predictions = (
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                )

            # Compute relation logits
            re_logits = self.compute_relation_logits(
                model_subject_features,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )

            re_probabilities = torch.softmax(re_logits, dim=-1)
            # we set a thresshold instead of argmax in cause it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
            # re_predictions = re_probabilities.argmax(dim=-1)
            re_probabilities = re_probabilities[:, :, :, :, 1]

        else:
            (
                ned_type_logits,
                ned_type_probabilities,
                re_entities_logits,
                re_entities_probabilities,
            ) = (None, None, None, None)
            ned_type_predictions, re_entities_predictions = (
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            )
            re_logits, re_probabilities, re_predictions = (
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ned_type_logits=ned_type_logits,
            ned_type_probabilities=ned_type_probabilities,
            ned_type_predictions=ned_type_predictions,
            re_entities_logits=re_entities_logits,
            re_entities_probabilities=re_entities_probabilities,
            re_entities_predictions=re_entities_predictions,
            re_logits=re_logits,
            re_probabilities=re_probabilities,
            re_predictions=re_predictions,
        )

        if (
            start_labels is not None
            and end_labels is not None
            and relation_labels is not None
        ):
            ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
            ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
            if self.entity_type_loss or self.relation_disambiguation_loss:
                ned_type_loss = self.compute_ned_type_loss(
                    disambiguation_labels,
                    re_ned_entities_logits,
                    ned_type_logits,
                    re_entities_logits,
                    entity_types,
                )
            relation_loss = self.compute_relation_loss(relation_labels, re_logits)
            # compute loss. We can skip the relation loss if we are in the first epochs (optional)
            if self.entity_type_loss or self.relation_disambiguation_loss:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
                ) / 4
                output_dict["ned_type_loss"] = ned_type_loss
            else:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss
                ) / 3

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["re_loss"] = relation_loss

        return output_dict
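Note (not part of the deleted file): a minimal sketch of what the "bce,bde->bcde" einsum above computes. For every pair of a span representation (index c) and a special-symbol representation (index d) it builds one hidden-sized joint vector that the classifier then scores. Tensor names and sizes below are made up for illustration.

import torch

batch, spans, symbols, hidden = 2, 3, 4, 8           # hypothetical sizes
span_feats = torch.randn(batch, spans, hidden)       # stands in for model_ed_features
symbol_feats = torch.randn(batch, symbols, hidden)   # stands in for special_symbols_ed_representation

# no index is summed out, so the result keeps one hidden vector per (span, symbol) pair
pairwise = torch.einsum("bce,bde->bcde", span_feats, symbol_feats)
print(pairwise.shape)  # torch.Size([2, 3, 4, 8])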
relik/reader/pytorch_modules/optim/__init__.py
DELETED
@@ -1,6 +0,0 @@
from relik.reader.pytorch_modules.optim.adamw_with_warmup import (
    AdamWWithWarmupOptimizer,
)
from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import (
    LayerWiseLRDecayOptimizer,
)
relik/reader/pytorch_modules/optim/adamw_with_warmup.py
DELETED
@@ -1,66 +0,0 @@
from typing import List

import torch
import transformers
from torch.optim import AdamW


class AdamWWithWarmupOptimizer:
    def __init__(
        self,
        lr: float,
        warmup_steps: int,
        total_steps: int,
        weight_decay: float,
        no_decay_params: List[str],
    ):
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.weight_decay = weight_decay
        self.no_decay_params = no_decay_params

    def group_params(self, module: torch.nn.Module) -> list:
        if self.no_decay_params is not None:
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in module.named_parameters()
                        if not any(nd in n for nd in self.no_decay_params)
                    ],
                    "weight_decay": self.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in module.named_parameters()
                        if any(nd in n for nd in self.no_decay_params)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        else:
            optimizer_grouped_parameters = [
                {"params": module.parameters(), "weight_decay": self.weight_decay}
            ]

        return optimizer_grouped_parameters

    def __call__(self, module: torch.nn.Module):
        optimizer_grouped_parameters = self.group_params(module)
        optimizer = AdamW(
            optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer, self.warmup_steps, self.total_steps
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py
DELETED
@@ -1,104 +0,0 @@
import collections
from typing import List

import torch
import transformers
from torch.optim import AdamW


class LayerWiseLRDecayOptimizer:
    def __init__(
        self,
        lr: float,
        warmup_steps: int,
        total_steps: int,
        weight_decay: float,
        lr_decay: float,
        no_decay_params: List[str],
        total_reset: int,
    ):
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.weight_decay = weight_decay
        self.lr_decay = lr_decay
        self.no_decay_params = no_decay_params
        self.total_reset = total_reset

    def group_layers(self, module) -> dict:
        grouped_layers = collections.defaultdict(list)
        module_named_parameters = list(module.named_parameters())
        for ln, lp in module_named_parameters:
            if "embeddings" in ln:
                grouped_layers["embeddings"].append((ln, lp))
            elif "encoder.layer" in ln:
                layer_num = ln.split("transformer_model.encoder.layer.")[-1]
                layer_num = layer_num[0 : layer_num.index(".")]
                grouped_layers[layer_num].append((ln, lp))
            else:
                grouped_layers["head"].append((ln, lp))

        depth = len(grouped_layers) - 1
        final_dict = dict()
        for key, value in grouped_layers.items():
            if key == "head":
                final_dict[0] = value
            elif key == "embeddings":
                final_dict[depth] = value
            else:
                # -1 because layer number starts from zero
                final_dict[depth - int(key) - 1] = value

        assert len(module_named_parameters) == sum(
            len(v) for _, v in final_dict.items()
        )

        return final_dict

    def group_params(self, module) -> list:
        optimizer_grouped_params = []
        for inverse_depth, layer in self.group_layers(module).items():
            layer_lr = self.lr * (self.lr_decay**inverse_depth)
            layer_wd_params = {
                "params": [
                    lp
                    for ln, lp in layer
                    if not any(nd in ln for nd in self.no_decay_params)
                ],
                "weight_decay": self.weight_decay,
                "lr": layer_lr,
            }
            layer_no_wd_params = {
                "params": [
                    lp
                    for ln, lp in layer
                    if any(nd in ln for nd in self.no_decay_params)
                ],
                "weight_decay": 0,
                "lr": layer_lr,
            }

            if len(layer_wd_params) != 0:
                optimizer_grouped_params.append(layer_wd_params)
            if len(layer_no_wd_params) != 0:
                optimizer_grouped_params.append(layer_no_wd_params)

        return optimizer_grouped_params

    def __call__(self, module: torch.nn.Module):
        optimizer_grouped_parameters = self.group_params(module)
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
        scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            self.warmup_steps,
            self.total_steps,
            num_cycles=self.total_reset,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
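Note (not part of the deleted file): the schedule above assigns layer_lr = lr * lr_decay ** inverse_depth, where inverse_depth 0 is the task head and the largest value is the embeddings, so deeper (earlier) layers train with smaller learning rates. A tiny numeric illustration with made-up values:

base_lr, lr_decay, depth = 1e-5, 0.9, 12   # hypothetical settings
for inverse_depth in range(depth + 1):     # 0 = head, depth = embeddings
    layer_lr = base_lr * (lr_decay ** inverse_depth)
    print(inverse_depth, f"{layer_lr:.2e}")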
relik/reader/pytorch_modules/span.py
DELETED
@@ -1,367 +0,0 @@
import collections
import contextlib
import logging
from typing import Any, Dict, Iterator, List

import torch
import transformers as tr
from lightning_fabric.utilities import move_data_to_device
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from relik.common.log import get_console_logger, get_logger
from relik.common.utils import get_callable_from_string
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.base import RelikReaderBase
from relik.reader.utils.special_symbols import get_special_symbols
from relik.retriever.pytorch_modules import PRECISION_MAP

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForSpanExtraction(RelikReaderBase):
    """
    A class for the RelikReader model for span extraction.

    Args:
        transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The transformer model to use. If `None`, the default model is used.
        additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
            The number of additional special symbols to add to the tokenizer.
        num_layers (:obj:`int`, `optional`):
            The number of layers to use. If `None`, all layers are used.
        activation (:obj:`str`, `optional`, defaults to "gelu"):
            The activation function to use.
        linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
            The hidden size of the linears.
        use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
            The number of last layers to use.
        training (:obj:`bool`, `optional`, defaults to False):
            Whether the model is in training mode.
        device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
            The device to use. If `None`, the default device is used.
        tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
            The tokenizer to use. If `None`, the default tokenizer is used.
        dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
            The dataset to use. If `None`, the default dataset is used.
        dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
            The keyword arguments to pass to the dataset class.
        default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The default reader class to use. If `None`, the default reader class is used.
        **kwargs:
            Keyword arguments.
    """

    default_reader_class: str = (
        "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel"
    )
    default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset"

    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        dataset: IterableDataset | str | None = None,
        dataset_kwargs: Dict[str, Any] | None = None,
        default_reader_class: tr.PreTrainedModel | str | None = None,
        **kwargs,
    ):
        super().__init__(
            transformer_model=transformer_model,
            additional_special_symbols=additional_special_symbols,
            num_layers=num_layers,
            activation=activation,
            linears_hidden_size=linears_hidden_size,
            use_last_k_layers=use_last_k_layers,
            training=training,
            device=device,
            tokenizer=tokenizer,
            dataset=dataset,
            default_reader_class=default_reader_class,
            **kwargs,
        )
        # and instantiate the dataset class
        self.dataset = dataset
        if self.dataset is None:
            default_data_kwargs = dict(
                dataset_path=None,
                materialize_samples=False,
                transformer_model=self.tokenizer,
                special_symbols=get_special_symbols(
                    self.relik_reader_model.config.additional_special_symbols
                ),
                for_inference=True,
            )
            # merge the default data kwargs with the ones passed to the model
            default_data_kwargs.update(dataset_kwargs or {})
            self.dataset = get_callable_from_string(self.default_data_class)(
                **default_data_kwargs
            )

    @torch.no_grad()
    @torch.inference_mode()
    def _read(
        self,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        max_length: int = 1000,
        max_batch_size: int = 128,
        token_batch_size: int = 2048,
        precision: str = 32,
        annotation_type: str = "char",
        progress_bar: bool = False,
        *args: object,
        **kwargs: object,
    ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.

        Args:
            samples (:obj:`List[RelikReaderSample]`, `optional`):
                The samples to read. If provided, `text` and `candidates` are ignored.
            input_ids (:obj:`torch.Tensor`, `optional`):
                The input ids of the text. If `samples` is provided, this is ignored.
            attention_mask (:obj:`torch.Tensor`, `optional`):
                The attention mask of the text. If `samples` is provided, this is ignored.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text. If `samples` is provided, this is ignored.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text. If `samples` is provided, this is ignored.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text. If `samples` is provided, this is ignored.
            max_length (:obj:`int`, `optional`, defaults to 1000):
                The maximum length of the text.
            max_batch_size (:obj:`int`, `optional`, defaults to 128):
                The maximum batch size.
            token_batch_size (:obj:`int`, `optional`):
                The token batch size.
            progress_bar (:obj:`bool`, `optional`, defaults to False):
                Whether to show a progress bar.
            precision (:obj:`str`, `optional`, defaults to 32):
                The precision to use for the model.
            annotation_type (:obj:`str`, `optional`, defaults to "char"):
                The annotation type to use. It can be either "char", "token" or "word".
            *args:
                Positional arguments.
            **kwargs:
                Keyword arguments.

        Returns:
            :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
                The predicted labels for each sample.
        """

        precision = precision or self.precision
        if samples is not None:

            def _read_iterator():
                def samples_it():
                    for i, sample in enumerate(samples):
                        assert sample._mixin_prediction_position is None
                        sample._mixin_prediction_position = i
                        yield sample

                next_prediction_position = 0
                position2predicted_sample = {}

                # instantiate dataset
                if self.dataset is None:
                    raise ValueError(
                        "You need to pass a dataset to the model in order to predict"
                    )
                self.dataset.samples = samples_it()
                self.dataset.model_max_length = max_length
                self.dataset.tokens_per_batch = token_batch_size
                self.dataset.max_batch_size = max_batch_size

                # instantiate dataloader
                iterator = DataLoader(
                    self.dataset, batch_size=None, num_workers=0, shuffle=False
                )
                if progress_bar:
                    iterator = tqdm(iterator, desc="Predicting with RelikReader")

                # fucking autocast only wants pure strings like 'cpu' or 'cuda'
                # we need to convert the model device to that
                device_type_for_autocast = str(self.device).split(":")[0]
                # autocast doesn't work with CPU and stuff different from bfloat16
                autocast_mngr = (
                    contextlib.nullcontext()
                    if device_type_for_autocast == "cpu"
                    else (
                        torch.autocast(
                            device_type=device_type_for_autocast,
                            dtype=PRECISION_MAP[precision],
                        )
                    )
                )

                with autocast_mngr:
                    for batch in iterator:
                        batch = move_data_to_device(batch, self.device)
                        batch_out = self._batch_predict(**batch)

                        for sample in batch_out:
                            if (
                                sample._mixin_prediction_position
                                >= next_prediction_position
                            ):
                                position2predicted_sample[
                                    sample._mixin_prediction_position
                                ] = sample

                        # yield
                        while next_prediction_position in position2predicted_sample:
                            yield position2predicted_sample[next_prediction_position]
                            del position2predicted_sample[next_prediction_position]
                            next_prediction_position += 1

            outputs = list(_read_iterator())
            for sample in outputs:
                self.dataset.merge_patches_predictions(sample)
                self.dataset.convert_tokens_to_char_annotations(sample)

        else:
            outputs = list(
                self._batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    **kwargs,
                )
            )
        return outputs

    def _batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.
        It also adds the predicted labels to the samples.

        Args:
            input_ids (:obj:`torch.Tensor`):
                The input ids of the text.
            attention_mask (:obj:`torch.Tensor`):
                The attention mask of the text.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text.
            sample (:obj:`List[RelikReaderSample]`, `optional`):
                The samples to read. If provided, `text` and `candidates` are ignored.
            top_k (:obj:`int`, `optional`, defaults to 5):
                The amount of top-k most probable entities to predict.
            *args:
                Positional arguments.
            **kwargs:
                Keyword arguments.

        Returns:
            The predicted labels for each sample.
        """
        forward_output = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                titles_2_probs = []
                top_k = (
                    min(
                        top_k,
                        len(classes_probabilities_best_indices),
                    )
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                for i in range(top_k):
                    titles_2_probs.append(
                        (
                            pred_cands[classes_probabilities_best_indices[i] - 1],
                            classes_probabilities[
                                classes_probabilities_best_indices[i]
                            ].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts
relik/reader/relik_reader.py
DELETED
@@ -1,629 +0,0 @@
import collections
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Union

import torch
import transformers as tr
from tqdm import tqdm
from transformers import AutoConfig

from relik.common.log import get_console_logger, get_logger
from relik.reader.data.relik_reader_data_utils import batchify, flatten
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderSpanModel,
)
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.save_load_utilities import load_model_and_conf
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForSpanExtraction(torch.nn.Module):
    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        **kwargs,
    ) -> None:
        super().__init__()

        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik-reader" in config.model_type:
                transformer_model = RelikReaderSpanModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderSpanModel(reader_config)

        self.relik_reader_model = transformer_model

        self._tokenizer = tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        start_labels: torch.Tensor | None = None,
        end_labels: torch.Tensor | None = None,
        disambiguation_labels: torch.Tensor | None = None,
        relation_labels: torch.Tensor | None = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        return self.relik_reader_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """


        Args:
            input_ids:
            attention_mask:
            token_type_ids:
            prediction_mask:
            special_symbols_mask:
            sample:
            top_k:
            *args:
            **kwargs:

        Returns:

        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                titles_2_probs = []
                top_k = (
                    min(
                        top_k,
                        len(classes_probabilities_best_indices),
                    )
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                for i in range(top_k):
                    titles_2_probs.append(
                        (
                            pred_cands[classes_probabilities_best_indices[i] - 1],
                            classes_probabilities[
                                classes_probabilities_best_indices[i]
                            ].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts

    def _build_input(self, text: List[str], candidates: List[List[str]]) -> list[str]:
        candidates_symbols = get_special_symbols(len(candidates))
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        offsets_mapping = offsets_mapping.numpy()
        token2word = []
        word2token = {}
        count = 0
        for i, offset in enumerate(offsets_mapping):
            if offset[0] == 0:
                token2word.append(i - count)
                word2token[i - count] = [i]
            else:
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(i)
                count += 1
        return token2word, word2token

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        triplets = []
        entities = []
        for entity in sample.predicted_entities:
            if sample.entity_candidates:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        sample.entity_candidates[entity[2]],
                    )
                )
            else:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        -1,
                    )
                )
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    @torch.no_grad()
    @torch.inference_mode()
    def read(
        self,
        text: List[str] | List[List[str]] | None = None,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        candidates: List[List[str]] | None = None,
        max_length: int | None = 1024,
        max_batch_size: int | None = 64,
        token_batch_size: int | None = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[List[RelikReaderSample]]:
        """
        Reads the given text.
        Args:
            text: The text to read in tokens.
            samples:
            input_ids: The input ids of the text.
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities of the text.
            candidates: The candidates of the text.
            max_length: The maximum length of the text.
            max_batch_size: The maximum batch size.
            token_batch_size: The maximum number of tokens per batch.
            progress_bar:
        Returns:
            The predicted labels for each sample.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions when "
                "`input_ids` and `samples` is not provided."
            )
        if text is not None and samples is None:
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            if isinstance(text[0], str):  # change to list of text
                text = [text]
                candidates = [candidates]

            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is not None:
            # function that creates a batch from the 'current_batch' list
            def output_batch() -> Dict[str, Any]:
                assert (
                    len(
                        set(
                            [
                                len(elem["predictable_candidates"])
                                for elem in current_batch
                            ]
                        )
                    )
                    == 1
                ), " ".join(
                    map(
                        str,
                        [len(elem["predictable_candidates"]) for elem in current_batch],
                    )
                )

                batch_dict = dict()

                de_values_by_field = {
                    fn: [de[fn] for de in current_batch if fn in de]
                    for fn in self.fields_batcher
                }

                # in case you provide fields batchers but in the batch
                # there are no elements for that field
                de_values_by_field = {
                    fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
                }

                assert len(set([len(v) for v in de_values_by_field.values()]))

                # todo: maybe we should report the user about possible
                # fields filtering due to "None" instances
                de_values_by_field = {
                    fn: fvs
                    for fn, fvs in de_values_by_field.items()
                    if all([fv is not None for fv in fvs])
                }

                for field_name, field_values in de_values_by_field.items():
                    field_batch = (
                        self.fields_batcher[field_name]([fv[0] for fv in field_values])
                        if self.fields_batcher[field_name] is not None
                        else field_values
                    )

                    batch_dict[field_name] = field_batch

                batch_dict = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch_dict.items()
                }
                return batch_dict

            current_batch = []
            predictions = []
            current_cand_len = -1

            for sample in tqdm(samples, disable=not progress_bar):
                sample.candidates = [NME_SYMBOL] + sample.candidates
                inputs_text = self._build_input(sample.tokens, sample.candidates)
                model_inputs = self.tokenizer(
                    inputs_text,
                    is_split_into_words=True,
                    add_special_tokens=False,
                    padding=False,
                    truncation=True,
                    max_length=max_length or self.tokenizer.model_max_length,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                )
                model_inputs["special_symbols_mask"] = (
                    model_inputs["input_ids"] > self.tokenizer.vocab_size
                )
                # prediction mask is 0 until the first special symbol
                model_inputs["token_type_ids"] = (
                    torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
                ).long()
                # shift prediction_mask to the left
                model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
                    shifts=-1, dims=1
                )
                model_inputs["prediction_mask"][:, -1] = 1
                model_inputs["prediction_mask"][:, 0] = 1

                assert (
                    len(model_inputs["special_symbols_mask"])
                    == len(model_inputs["prediction_mask"])
                    == len(model_inputs["input_ids"])
                )

                model_inputs["sample"] = sample

                # compute cand_len using special_symbols_mask
                model_inputs["predictable_candidates"] = sample.candidates[
                    : model_inputs["special_symbols_mask"].sum().item()
                ]
                # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]])
                offsets = model_inputs.pop("offset_mapping")
                offsets = offsets[model_inputs["prediction_mask"] == 0]
                sample.token2word, sample.word2token = self._compute_offsets(offsets)
                future_max_len = max(
                    len(model_inputs["input_ids"]),
                    max([len(b["input_ids"]) for b in current_batch], default=0),
                )
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                if len(current_batch) > 0 and (
                    (
                        len(model_inputs["predictable_candidates"]) != current_cand_len
                        and current_cand_len != -1
                    )
                    or (
                        isinstance(token_batch_size, int)
                        and future_tokens_per_batch >= token_batch_size
                    )
                    or len(current_batch) == max_batch_size
                ):
                    batch_inputs = output_batch()
                    current_batch = []
                    predictions.extend(list(self.batch_predict(**batch_inputs)))
                current_cand_len = len(model_inputs["predictable_candidates"])
                current_batch.append(model_inputs)

            if current_batch:
                batch_inputs = output_batch()
                predictions.extend(list(self.batch_predict(**batch_inputs)))
        else:
            predictions = list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    special_symbols_mask_entities,
                    *args,
                    **kwargs,
                )
            )
        return predictions

    @property
    def device(self) -> torch.device:
        """
        The device of the model.
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The tokenizer.
        """
        if self._tokenizer:
            return self._tokenizer

        self._tokenizer = tr.AutoTokenizer.from_pretrained(
            self.relik_reader_model.config.name_or_path
        )
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_model.config.model_type:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: str | None = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model to the given path.
        Args:
            output_dir: The path to save the model to.
            model_name: The name of the model.
            push_to_hub: Whether to push the model to the hub.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik-reader-for-span-extraction"

        logger.info(f"Saving reader to {output_dir / model_name}")

        # save the model
        self.relik_reader_model.register_for_auto_class()
        self.relik_reader_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )

        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")


class RelikReader:
    def __init__(self, model_path: str, predict_nmes: bool = False):
        model, model_conf = load_model_and_conf(model_path)
        model.training = False
        model.eval()

        val_dataset_conf = model_conf.data.val_dataset
        val_dataset_conf.special_symbols = get_special_symbols(
            model_conf.model.entities_per_forward
        )
        val_dataset_conf.transformer_model = model_conf.model.model.transformer_model

        self.predictor = RelikReaderPredictor(
            model,
            dataset_conf=model_conf.data.val_dataset,
            predict_nmes=predict_nmes,
        )
        self.model_path = model_path

    def link_entities(
        self,
        dataset_path_or_samples: str | Iterator[RelikReaderSample],
        token_batch_size: int = 2048,
        progress_bar: bool = False,
    ) -> List[RelikReaderSample]:
        data_input = (
            (dataset_path_or_samples, None)
            if isinstance(dataset_path_or_samples, str)
            else (None, dataset_path_or_samples)
        )
        return self.predictor.predict(
            *data_input,
            dataset_conf=None,
            token_batch_size=token_batch_size,
            progress_bar=progress_bar,
        )

    # def save_pretrained(self, path: Union[str, Path]):
    #     self.predictor.save(path)


def main():
    rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True)
    predictions = rr.link_entities(
        "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl"
    )
    print(predictions)


if __name__ == "__main__":
    main()
|