CarlosMalaga commited on
Commit
3376207
1 Parent(s): e883357

Delete relik

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. relik/__init__.py +0 -1
  2. relik/common/__init__.py +0 -0
  3. relik/common/log.py +0 -97
  4. relik/common/upload.py +0 -128
  5. relik/common/utils.py +0 -609
  6. relik/inference/__init__.py +0 -0
  7. relik/inference/annotator.py +0 -428
  8. relik/inference/data/__init__.py +0 -0
  9. relik/inference/data/objects.py +0 -64
  10. relik/inference/data/tokenizers/__init__.py +0 -89
  11. relik/inference/data/tokenizers/base_tokenizer.py +0 -84
  12. relik/inference/data/tokenizers/regex_tokenizer.py +0 -73
  13. relik/inference/data/tokenizers/spacy_tokenizer.py +0 -228
  14. relik/inference/data/tokenizers/whitespace_tokenizer.py +0 -70
  15. relik/inference/data/window/__init__.py +0 -0
  16. relik/inference/data/window/manager.py +0 -262
  17. relik/inference/gerbil.py +0 -254
  18. relik/inference/preprocessing.py +0 -4
  19. relik/inference/serve/__init__.py +0 -0
  20. relik/inference/serve/backend/__init__.py +0 -0
  21. relik/inference/serve/backend/relik.py +0 -210
  22. relik/inference/serve/backend/retriever.py +0 -206
  23. relik/inference/serve/backend/utils.py +0 -29
  24. relik/inference/serve/frontend/__init__.py +0 -0
  25. relik/inference/serve/frontend/relik.py +0 -231
  26. relik/inference/serve/frontend/style.css +0 -33
  27. relik/reader/__init__.py +0 -0
  28. relik/reader/conf/config.yaml +0 -14
  29. relik/reader/conf/data/base.yaml +0 -21
  30. relik/reader/conf/data/re.yaml +0 -54
  31. relik/reader/conf/training/base.yaml +0 -12
  32. relik/reader/conf/training/re.yaml +0 -12
  33. relik/reader/data/__init__.py +0 -0
  34. relik/reader/data/patches.py +0 -51
  35. relik/reader/data/relik_reader_data.py +0 -965
  36. relik/reader/data/relik_reader_data_utils.py +0 -51
  37. relik/reader/data/relik_reader_sample.py +0 -49
  38. relik/reader/lightning_modules/__init__.py +0 -0
  39. relik/reader/lightning_modules/relik_reader_pl_module.py +0 -50
  40. relik/reader/lightning_modules/relik_reader_re_pl_module.py +0 -54
  41. relik/reader/pytorch_modules/__init__.py +0 -0
  42. relik/reader/pytorch_modules/base.py +0 -248
  43. relik/reader/pytorch_modules/hf/__init__.py +0 -2
  44. relik/reader/pytorch_modules/hf/configuration_relik.py +0 -33
  45. relik/reader/pytorch_modules/hf/modeling_relik.py +0 -981
  46. relik/reader/pytorch_modules/optim/__init__.py +0 -6
  47. relik/reader/pytorch_modules/optim/adamw_with_warmup.py +0 -66
  48. relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py +0 -104
  49. relik/reader/pytorch_modules/span.py +0 -367
  50. relik/reader/relik_reader.py +0 -629
relik/__init__.py DELETED
@@ -1 +0,0 @@
1
- from relik.retriever.pytorch_modules.model import GoldenRetriever
 
 
relik/common/__init__.py DELETED
File without changes
relik/common/log.py DELETED
@@ -1,97 +0,0 @@
1
- import logging
2
- import sys
3
- import threading
4
- from typing import Optional
5
-
6
- from rich import get_console
7
-
8
- _lock = threading.Lock()
9
- _default_handler: Optional[logging.Handler] = None
10
-
11
- _default_log_level = logging.WARNING
12
-
13
- # fancy logger
14
- _console = get_console()
15
-
16
-
17
- def _get_library_name() -> str:
18
- return __name__.split(".")[0]
19
-
20
-
21
- def _get_library_root_logger() -> logging.Logger:
22
- return logging.getLogger(_get_library_name())
23
-
24
-
25
- def _configure_library_root_logger() -> None:
26
- global _default_handler
27
-
28
- with _lock:
29
- if _default_handler:
30
- # This library has already configured the library root logger.
31
- return
32
- _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
33
- _default_handler.flush = sys.stderr.flush
34
-
35
- # Apply our default configuration to the library root logger.
36
- library_root_logger = _get_library_root_logger()
37
- library_root_logger.addHandler(_default_handler)
38
- library_root_logger.setLevel(_default_log_level)
39
- library_root_logger.propagate = False
40
-
41
-
42
- def _reset_library_root_logger() -> None:
43
- global _default_handler
44
-
45
- with _lock:
46
- if not _default_handler:
47
- return
48
-
49
- library_root_logger = _get_library_root_logger()
50
- library_root_logger.removeHandler(_default_handler)
51
- library_root_logger.setLevel(logging.NOTSET)
52
- _default_handler = None
53
-
54
-
55
- def set_log_level(level: int, logger: logging.Logger = None) -> None:
56
- """
57
- Set the log level.
58
- Args:
59
- level (:obj:`int`):
60
- Logging level.
61
- logger (:obj:`logging.Logger`):
62
- Logger to set the log level.
63
- """
64
- if not logger:
65
- _configure_library_root_logger()
66
- logger = _get_library_root_logger()
67
- logger.setLevel(level)
68
-
69
-
70
- def get_logger(
71
- name: Optional[str] = None,
72
- level: Optional[int] = None,
73
- formatter: Optional[str] = None,
74
- ) -> logging.Logger:
75
- """
76
- Return a logger with the specified name.
77
- """
78
-
79
- if name is None:
80
- name = _get_library_name()
81
-
82
- _configure_library_root_logger()
83
-
84
- if level is not None:
85
- set_log_level(level)
86
-
87
- if formatter is None:
88
- formatter = logging.Formatter(
89
- "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
90
- )
91
- _default_handler.setFormatter(formatter)
92
-
93
- return logging.getLogger(name)
94
-
95
-
96
- def get_console_logger():
97
- return _console
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
relik/common/upload.py DELETED
@@ -1,128 +0,0 @@
1
- import argparse
2
- import json
3
- import logging
4
- import os
5
- import tempfile
6
- import zipfile
7
- from datetime import datetime
8
- from pathlib import Path
9
- from typing import Optional, Union
10
-
11
- import huggingface_hub
12
-
13
- from relik.common.log import get_logger
14
- from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
15
-
16
- logger = get_logger(level=logging.DEBUG)
17
-
18
-
19
- def create_info_file(tmpdir: Path):
20
- logger.debug("Computing md5 of model.zip")
21
- md5 = get_md5(tmpdir / "model.zip")
22
- date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)
23
-
24
- logger.debug("Dumping info.json file")
25
- with (tmpdir / "info.json").open("w") as f:
26
- json.dump(dict(md5=md5, upload_date=date), f, indent=2)
27
-
28
-
29
- def zip_run(
30
- dir_path: Union[str, os.PathLike],
31
- tmpdir: Union[str, os.PathLike],
32
- zip_name: str = "model.zip",
33
- ) -> Path:
34
- logger.debug(f"zipping {dir_path} to {tmpdir}")
35
- # creates a zip version of the provided dir_path
36
- run_dir = Path(dir_path)
37
- zip_path = tmpdir / zip_name
38
-
39
- with zipfile.ZipFile(zip_path, "w") as zip_file:
40
- # fully zip the run directory maintaining its structure
41
- for file in run_dir.rglob("*.*"):
42
- if file.is_dir():
43
- continue
44
-
45
- zip_file.write(file, arcname=file.relative_to(run_dir))
46
-
47
- return zip_path
48
-
49
-
50
- def upload(
51
- model_dir: Union[str, os.PathLike],
52
- model_name: str,
53
- organization: Optional[str] = None,
54
- repo_name: Optional[str] = None,
55
- commit: Optional[str] = None,
56
- archive: bool = False,
57
- ):
58
- token = huggingface_hub.HfFolder.get_token()
59
- if token is None:
60
- print(
61
- "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
62
- )
63
- return
64
-
65
- repo_id = repo_name or model_name
66
- if organization is not None:
67
- repo_id = f"{organization}/{repo_id}"
68
- with tempfile.TemporaryDirectory() as tmpdir:
69
- api = huggingface_hub.HfApi()
70
- repo_url = api.create_repo(
71
- token=token,
72
- repo_id=repo_id,
73
- exist_ok=True,
74
- )
75
- repo = huggingface_hub.Repository(
76
- str(tmpdir), clone_from=repo_url, use_auth_token=token
77
- )
78
-
79
- tmp_path = Path(tmpdir)
80
- if archive:
81
- # otherwise we zip the model_dir
82
- logger.debug(f"Zipping {model_dir} to {tmp_path}")
83
- zip_run(model_dir, tmp_path)
84
- create_info_file(tmp_path)
85
- else:
86
- # if the user wants to upload a transformers model, we don't need to zip it
87
- # we just need to copy the files to the tmpdir
88
- logger.debug(f"Copying {model_dir} to {tmpdir}")
89
- os.system(f"cp -r {model_dir}/* {tmpdir}")
90
-
91
- # this method automatically puts large files (>10MB) into git lfs
92
- repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
93
-
94
-
95
- def parse_args() -> argparse.Namespace:
96
- parser = argparse.ArgumentParser()
97
- parser.add_argument(
98
- "model_dir", help="The directory of the model you want to upload"
99
- )
100
- parser.add_argument("model_name", help="The model you want to upload")
101
- parser.add_argument(
102
- "--organization",
103
- help="the name of the organization where you want to upload the model",
104
- )
105
- parser.add_argument(
106
- "--repo_name",
107
- help="Optional name to use when uploading to the HuggingFace repository",
108
- )
109
- parser.add_argument(
110
- "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
111
- )
112
- parser.add_argument(
113
- "--archive",
114
- action="store_true",
115
- help="""
116
- Whether to compress the model directory before uploading it.
117
- If True, the model directory will be zipped and the zip file will be uploaded.
118
- If False, the model directory will be uploaded as is.""",
119
- )
120
- return parser.parse_args()
121
-
122
-
123
- def main():
124
- upload(**vars(parse_args()))
125
-
126
-
127
- if __name__ == "__main__":
128
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
relik/common/utils.py DELETED
@@ -1,609 +0,0 @@
1
- import importlib.util
2
- import json
3
- import logging
4
- import os
5
- import shutil
6
- import tarfile
7
- import tempfile
8
- from functools import partial
9
- from hashlib import sha256
10
- from pathlib import Path
11
- from typing import Any, BinaryIO, Dict, List, Optional, Union
12
- from urllib.parse import urlparse
13
- from zipfile import ZipFile, is_zipfile
14
-
15
- import huggingface_hub
16
- import requests
17
- import tqdm
18
- from filelock import FileLock
19
- from transformers.utils.hub import cached_file as hf_cached_file
20
-
21
- from relik.common.log import get_logger
22
-
23
- # name constants
24
- WEIGHTS_NAME = "weights.pt"
25
- ONNX_WEIGHTS_NAME = "weights.onnx"
26
- CONFIG_NAME = "config.yaml"
27
- LABELS_NAME = "labels.json"
28
-
29
- # SAPIENZANLP_USER_NAME = "sapienzanlp"
30
- SAPIENZANLP_USER_NAME = "riccorl"
31
- SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
32
- SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
33
- f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
34
- )
35
- # path constants
36
- SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
37
- SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"
38
-
39
-
40
- logger = get_logger(__name__)
41
-
42
-
43
- def sapienzanlp_model_urls(model_id: str) -> str:
44
- """
45
- Returns the URL for a possible SapienzaNLP valid model.
46
-
47
- Args:
48
- model_id (:obj:`str`):
49
- A SapienzaNLP model id.
50
-
51
- Returns:
52
- :obj:`str`: The url for the model id.
53
- """
54
- # check if there is already the namespace of the user
55
- if "/" in model_id:
56
- return model_id
57
- return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)
58
-
59
-
60
- def is_package_available(package_name: str) -> bool:
61
- """
62
- Check if a package is available.
63
-
64
- Args:
65
- package_name (`str`): The name of the package to check.
66
- """
67
- return importlib.util.find_spec(package_name) is not None
68
-
69
-
70
- def load_json(path: Union[str, Path]) -> Any:
71
- """
72
- Load a json file provided in input.
73
-
74
- Args:
75
- path (`Union[str, Path]`): The path to the json file to load.
76
-
77
- Returns:
78
- `Any`: The loaded json file.
79
- """
80
- with open(path, encoding="utf8") as f:
81
- return json.load(f)
82
-
83
-
84
- def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
85
- """
86
- Dump input to json file.
87
-
88
- Args:
89
- document (`Any`): The document to dump.
90
- path (`Union[str, Path]`): The path to dump the document to.
91
- indent (`Optional[int]`): The indent to use for the json file.
92
-
93
- """
94
- with open(path, "w", encoding="utf8") as outfile:
95
- json.dump(document, outfile, indent=indent)
96
-
97
-
98
- def get_md5(path: Path):
99
- """
100
- Get the MD5 value of a path.
101
- """
102
- import hashlib
103
-
104
- with path.open("rb") as fin:
105
- data = fin.read()
106
- return hashlib.md5(data).hexdigest()
107
-
108
-
109
- def file_exists(path: Union[str, os.PathLike]) -> bool:
110
- """
111
- Check if the file at :obj:`path` exists.
112
-
113
- Args:
114
- path (:obj:`str`, :obj:`os.PathLike`):
115
- Path to check.
116
-
117
- Returns:
118
- :obj:`bool`: :obj:`True` if the file exists.
119
- """
120
- return Path(path).exists()
121
-
122
-
123
- def dir_exists(path: Union[str, os.PathLike]) -> bool:
124
- """
125
- Check if the directory at :obj:`path` exists.
126
-
127
- Args:
128
- path (:obj:`str`, :obj:`os.PathLike`):
129
- Path to check.
130
-
131
- Returns:
132
- :obj:`bool`: :obj:`True` if the directory exists.
133
- """
134
- return Path(path).is_dir()
135
-
136
-
137
- def is_remote_url(url_or_filename: Union[str, Path]):
138
- """
139
- Returns :obj:`True` if the input path is an url.
140
-
141
- Args:
142
- url_or_filename (:obj:`str`, :obj:`Path`):
143
- path to check.
144
-
145
- Returns:
146
- :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.
147
-
148
- """
149
- if isinstance(url_or_filename, Path):
150
- url_or_filename = str(url_or_filename)
151
- parsed = urlparse(url_or_filename)
152
- return parsed.scheme in ("http", "https")
153
-
154
-
155
- def url_to_filename(resource: str, etag: str = None) -> str:
156
- """
157
- Convert a `resource` into a hashed filename in a repeatable way.
158
- If `etag` is specified, append its hash to the resources's, delimited
159
- by a period.
160
- """
161
- resource_bytes = resource.encode("utf-8")
162
- resource_hash = sha256(resource_bytes)
163
- filename = resource_hash.hexdigest()
164
-
165
- if etag:
166
- etag_bytes = etag.encode("utf-8")
167
- etag_hash = sha256(etag_bytes)
168
- filename += "." + etag_hash.hexdigest()
169
-
170
- return filename
171
-
172
-
173
- def download_resource(
174
- url: str,
175
- temp_file: BinaryIO,
176
- headers=None,
177
- ):
178
- """
179
- Download remote file.
180
- """
181
-
182
- if headers is None:
183
- headers = {}
184
-
185
- r = requests.get(url, stream=True, headers=headers)
186
- r.raise_for_status()
187
- content_length = r.headers.get("Content-Length")
188
- total = int(content_length) if content_length is not None else None
189
- progress = tqdm(
190
- unit="B",
191
- unit_scale=True,
192
- total=total,
193
- desc="Downloading",
194
- disable=logger.level in [logging.NOTSET],
195
- )
196
- for chunk in r.iter_content(chunk_size=1024):
197
- if chunk: # filter out keep-alive new chunks
198
- progress.update(len(chunk))
199
- temp_file.write(chunk)
200
- progress.close()
201
-
202
-
203
- def download_and_cache(
204
- url: Union[str, Path],
205
- cache_dir: Union[str, Path] = None,
206
- force_download: bool = False,
207
- ):
208
- if cache_dir is None:
209
- cache_dir = SAPIENZANLP_CACHE_DIR
210
- if isinstance(url, Path):
211
- url = str(url)
212
-
213
- # check if cache dir exists
214
- Path(cache_dir).mkdir(parents=True, exist_ok=True)
215
-
216
- # check if file is private
217
- headers = {}
218
- try:
219
- r = requests.head(url, allow_redirects=False, timeout=10)
220
- r.raise_for_status()
221
- except requests.exceptions.HTTPError:
222
- if r.status_code == 401:
223
- hf_token = huggingface_hub.HfFolder.get_token()
224
- if hf_token is None:
225
- raise ValueError(
226
- "You need to login to HuggingFace to download this model "
227
- "(use the `huggingface-cli login` command)"
228
- )
229
- headers["Authorization"] = f"Bearer {hf_token}"
230
-
231
- etag = None
232
- try:
233
- r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
234
- r.raise_for_status()
235
- etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
236
- # We favor a custom header indicating the etag of the linked resource, and
237
- # we fallback to the regular etag header.
238
- # If we don't have any of those, raise an error.
239
- if etag is None:
240
- raise OSError(
241
- "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
242
- )
243
- # In case of a redirect,
244
- # save an extra redirect on the request.get call,
245
- # and ensure we download the exact atomic version even if it changed
246
- # between the HEAD and the GET (unlikely, but hey).
247
- if 300 <= r.status_code <= 399:
248
- url = r.headers["Location"]
249
- except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
250
- # Actually raise for those subclasses of ConnectionError
251
- raise
252
- except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
253
- # Otherwise, our Internet connection is down.
254
- # etag is None
255
- pass
256
-
257
- # get filename from the url
258
- filename = url_to_filename(url, etag)
259
- # get cache path to put the file
260
- cache_path = cache_dir / filename
261
-
262
- # the file is already here, return it
263
- if file_exists(cache_path) and not force_download:
264
- logger.info(
265
- f"{url} found in cache, set `force_download=True` to force the download"
266
- )
267
- return cache_path
268
-
269
- cache_path = str(cache_path)
270
- # Prevent parallel downloads of the same file with a lock.
271
- lock_path = cache_path + ".lock"
272
- with FileLock(lock_path):
273
- # If the download just completed while the lock was activated.
274
- if file_exists(cache_path) and not force_download:
275
- # Even if returning early like here, the lock will be released.
276
- return cache_path
277
-
278
- temp_file_manager = partial(
279
- tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
280
- )
281
-
282
- # Download to temporary file, then copy to cache dir once finished.
283
- # Otherwise, you get corrupt cache entries if the download gets interrupted.
284
- with temp_file_manager() as temp_file:
285
- logger.info(
286
- f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
287
- )
288
- download_resource(url, temp_file, headers)
289
-
290
- logger.info(f"storing {url} in cache at {cache_path}")
291
- os.replace(temp_file.name, cache_path)
292
-
293
- # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
294
- umask = os.umask(0o666)
295
- os.umask(umask)
296
- os.chmod(cache_path, 0o666 & ~umask)
297
-
298
- logger.info(f"creating metadata file for {cache_path}")
299
- meta = {"url": url} # , "etag": etag}
300
- meta_path = cache_path + ".json"
301
- with open(meta_path, "w") as meta_file:
302
- json.dump(meta, meta_file)
303
-
304
- return cache_path
305
-
306
-
307
- def download_from_hf(
308
- path_or_repo_id: Union[str, Path],
309
- filenames: Optional[List[str]],
310
- cache_dir: Union[str, Path] = None,
311
- force_download: bool = False,
312
- resume_download: bool = False,
313
- proxies: Optional[Dict[str, str]] = None,
314
- use_auth_token: Optional[Union[bool, str]] = None,
315
- revision: Optional[str] = None,
316
- local_files_only: bool = False,
317
- subfolder: str = "",
318
- ):
319
- if isinstance(path_or_repo_id, Path):
320
- path_or_repo_id = str(path_or_repo_id)
321
-
322
- downloaded_paths = []
323
- for filename in filenames:
324
- downloaded_path = hf_cached_file(
325
- path_or_repo_id,
326
- filename,
327
- cache_dir=cache_dir,
328
- force_download=force_download,
329
- proxies=proxies,
330
- resume_download=resume_download,
331
- use_auth_token=use_auth_token,
332
- revision=revision,
333
- local_files_only=local_files_only,
334
- subfolder=subfolder,
335
- )
336
- downloaded_paths.append(downloaded_path)
337
-
338
- # we want the folder where the files are downloaded
339
- # the best guess is the parent folder of the first file
340
- probably_the_folder = Path(downloaded_paths[0]).parent
341
- return probably_the_folder
342
-
343
-
344
- def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
345
- """
346
- Resolve a model name or directory to a model archive name or directory.
347
-
348
- Args:
349
- model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
350
- A model name or directory.
351
-
352
- Returns:
353
- :obj:`str`: The model archive name or directory.
354
- """
355
- if is_remote_url(model_name_or_dir):
356
- # if model_name_or_dir is a URL
357
- # download it and try to load
358
- model_archive = model_name_or_dir
359
- elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
360
- # if model_name_or_dir is a local directory or
361
- # an archive file try to load it
362
- model_archive = model_name_or_dir
363
- else:
364
- # probably model_name_or_dir is a sapienzanlp model id
365
- # guess the url and try to download
366
- model_name_or_dir_ = model_name_or_dir
367
- # raise ValueError(f"Providing a model id is not supported yet.")
368
- model_archive = sapienzanlp_model_urls(model_name_or_dir_)
369
-
370
- return model_archive
371
-
372
-
373
- def from_cache(
374
- url_or_filename: Union[str, Path],
375
- cache_dir: Union[str, Path] = None,
376
- force_download: bool = False,
377
- resume_download: bool = False,
378
- proxies: Optional[Dict[str, str]] = None,
379
- use_auth_token: Optional[Union[bool, str]] = None,
380
- revision: Optional[str] = None,
381
- local_files_only: bool = False,
382
- subfolder: str = "",
383
- filenames: Optional[List[str]] = None,
384
- ) -> Path:
385
- """
386
- Given something that could be either a local path or a URL (or a SapienzaNLP model id),
387
- determine which one and return a path to the corresponding file.
388
-
389
- Args:
390
- url_or_filename (:obj:`str` or :obj:`Path`):
391
- A path to a local file or a URL (or a SapienzaNLP model id).
392
- cache_dir (:obj:`str` or :obj:`Path`, `optional`):
393
- Path to a directory in which a downloaded file will be cached.
394
- force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
395
- Whether or not to re-download the file even if it already exists.
396
- resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
397
- Whether or not to delete incompletely received files. Attempts to resume the download if such a file
398
- exists.
399
- proxies (:obj:`Dict[str, str]`, `optional`):
400
- A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
401
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
402
- use_auth_token (:obj:`Union[bool, str]`, `optional`):
403
- Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
404
- :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
405
- revision (:obj:`str`, `optional`):
406
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
407
- git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
408
- identifier allowed by git.
409
- local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
410
- Whether or not to raise an error if the file to be downloaded is local.
411
- subfolder (:obj:`str`, `optional`):
412
- In case the relevant file is in a subfolder of the URL, specify it here.
413
- filenames (:obj:`List[str]`, `optional`):
414
- List of filenames to look for in the directory structure.
415
-
416
- Returns:
417
- :obj:`Path`: Path to the cached file.
418
- """
419
-
420
- url_or_filename = model_name_or_path_resolver(url_or_filename)
421
-
422
- if cache_dir is None:
423
- cache_dir = SAPIENZANLP_CACHE_DIR
424
-
425
- if file_exists(url_or_filename):
426
- logger.info(f"{url_or_filename} is a local path or file")
427
- output_path = url_or_filename
428
- elif is_remote_url(url_or_filename):
429
- # URL, so get it from the cache (downloading if necessary)
430
- output_path = download_and_cache(
431
- url_or_filename,
432
- cache_dir=cache_dir,
433
- force_download=force_download,
434
- )
435
- else:
436
- if filenames is None:
437
- filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
438
- output_path = download_from_hf(
439
- url_or_filename,
440
- filenames,
441
- cache_dir,
442
- force_download,
443
- resume_download,
444
- proxies,
445
- use_auth_token,
446
- revision,
447
- local_files_only,
448
- subfolder,
449
- )
450
-
451
- # if is_hf_hub_url(url_or_filename):
452
- # HuggingFace Hub
453
- # output_path = hf_hub_download_url(url_or_filename)
454
- # elif is_remote_url(url_or_filename):
455
- # # URL, so get it from the cache (downloading if necessary)
456
- # output_path = download_and_cache(
457
- # url_or_filename,
458
- # cache_dir=cache_dir,
459
- # force_download=force_download,
460
- # )
461
- # elif file_exists(url_or_filename):
462
- # logger.info(f"{url_or_filename} is a local path or file")
463
- # # File, and it exists.
464
- # output_path = url_or_filename
465
- # elif urlparse(url_or_filename).scheme == "":
466
- # # File, but it doesn't exist.
467
- # raise EnvironmentError(f"file {url_or_filename} not found")
468
- # else:
469
- # # Something unknown
470
- # raise ValueError(
471
- # f"unable to parse {url_or_filename} as a URL or as a local path"
472
- # )
473
-
474
- if dir_exists(output_path) or (
475
- not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
476
- ):
477
- return Path(output_path)
478
-
479
- # Path where we extract compressed archives
480
- # for now it will extract it in the same folder
481
- # maybe implement extraction in the sapienzanlp folder
482
- # when using local archive path?
483
- logger.info("Extracting compressed archive")
484
- output_dir, output_file = os.path.split(output_path)
485
- output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
486
- output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
487
-
488
- # already extracted, do not extract
489
- if (
490
- os.path.isdir(output_path_extracted)
491
- and os.listdir(output_path_extracted)
492
- and not force_download
493
- ):
494
- return Path(output_path_extracted)
495
-
496
- # Prevent parallel extractions
497
- lock_path = output_path + ".lock"
498
- with FileLock(lock_path):
499
- shutil.rmtree(output_path_extracted, ignore_errors=True)
500
- os.makedirs(output_path_extracted)
501
- if is_zipfile(output_path):
502
- with ZipFile(output_path, "r") as zip_file:
503
- zip_file.extractall(output_path_extracted)
504
- zip_file.close()
505
- elif tarfile.is_tarfile(output_path):
506
- tar_file = tarfile.open(output_path)
507
- tar_file.extractall(output_path_extracted)
508
- tar_file.close()
509
- else:
510
- raise EnvironmentError(
511
- f"Archive format of {output_path} could not be identified"
512
- )
513
-
514
- # remove lock file, is it safe?
515
- os.remove(lock_path)
516
-
517
- return Path(output_path_extracted)
518
-
519
-
520
- def is_str_a_path(maybe_path: str) -> bool:
521
- """
522
- Check if a string is a path.
523
-
524
- Args:
525
- maybe_path (`str`): The string to check.
526
-
527
- Returns:
528
- `bool`: `True` if the string is a path, `False` otherwise.
529
- """
530
- # first check if it is a path
531
- if Path(maybe_path).exists():
532
- return True
533
- # check if it is a relative path
534
- if Path(os.path.join(os.getcwd(), maybe_path)).exists():
535
- return True
536
- # otherwise it is not a path
537
- return False
538
-
539
-
540
- def relative_to_absolute_path(path: str) -> os.PathLike:
541
- """
542
- Convert a relative path to an absolute path.
543
-
544
- Args:
545
- path (`str`): The relative path to convert.
546
-
547
- Returns:
548
- `os.PathLike`: The absolute path.
549
- """
550
- if not is_str_a_path(path):
551
- raise ValueError(f"{path} is not a path")
552
- if Path(path).exists():
553
- return Path(path).absolute()
554
- if Path(os.path.join(os.getcwd(), path)).exists():
555
- return Path(os.path.join(os.getcwd(), path)).absolute()
556
- raise ValueError(f"{path} is not a path")
557
-
558
-
559
- def to_config(object_to_save: Any) -> Dict[str, Any]:
560
- """
561
- Convert an object to a dictionary.
562
-
563
- Returns:
564
- `Dict[str, Any]`: The dictionary representation of the object.
565
- """
566
-
567
- def obj_to_dict(obj):
568
- match obj:
569
- case dict():
570
- data = {}
571
- for k, v in obj.items():
572
- data[k] = obj_to_dict(v)
573
- return data
574
-
575
- case list() | tuple():
576
- return [obj_to_dict(x) for x in obj]
577
-
578
- case object(__dict__=_):
579
- data = {
580
- "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
581
- }
582
- for k, v in obj.__dict__.items():
583
- if not k.startswith("_"):
584
- data[k] = obj_to_dict(v)
585
- return data
586
-
587
- case _:
588
- return obj
589
-
590
- return obj_to_dict(object_to_save)
591
-
592
-
593
- def get_callable_from_string(callable_fn: str) -> Any:
594
- """
595
- Get a callable from a string.
596
-
597
- Args:
598
- callable_fn (`str`):
599
- The string representation of the callable.
600
-
601
- Returns:
602
- `Any`: The callable.
603
- """
604
- # separate the function name from the module name
605
- module_name, function_name = callable_fn.rsplit(".", 1)
606
- # import the module
607
- module = importlib.import_module(module_name)
608
- # get the function
609
- return getattr(module, function_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
relik/inference/__init__.py DELETED
File without changes
relik/inference/annotator.py DELETED
@@ -1,428 +0,0 @@
1
- import os
2
- from pathlib import Path
3
- from typing import Any, Callable, Dict, Optional, Union
4
-
5
- import hydra
6
- from omegaconf import OmegaConf
7
- from relik.retriever.indexers.faiss import FaissDocumentIndex
8
- from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
9
- from rich.pretty import pprint
10
-
11
- from relik.common.log import get_console_logger, get_logger
12
- from relik.common.upload import upload
13
- from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
14
- from relik.inference.data.objects import EntitySpan, RelikOutput
15
- from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
16
- from relik.inference.data.window.manager import WindowManager
17
- from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
18
- from relik.reader.relik_reader import RelikReader
19
- from relik.retriever.data.utils import batch_generator
20
- from relik.retriever.indexers.base import BaseDocumentIndex
21
- from relik.retriever.pytorch_modules.model import GoldenRetriever
22
-
23
- logger = get_logger(__name__)
24
- console_logger = get_console_logger()
25
-
26
-
27
- class Relik:
28
- """
29
- Relik main class. It is a wrapper around a retriever and a reader.
30
-
31
- Args:
32
- retriever (`Optional[GoldenRetriever]`, `optional`):
33
- The retriever to use. If `None`, a retriever will be instantiated from the
34
- provided `question_encoder`, `passage_encoder` and `document_index`.
35
- Defaults to `None`.
36
- question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
37
- The question encoder to use. If `retriever` is `None`, a retriever will be
38
- instantiated from this parameter. Defaults to `None`.
39
- passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
40
- The passage encoder to use. If `retriever` is `None`, a retriever will be
41
- instantiated from this parameter. Defaults to `None`.
42
- document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
43
- The document index to use. If `retriever` is `None`, a retriever will be
44
- instantiated from this parameter. Defaults to `None`.
45
- reader (`Optional[Union[str, RelikReader]]`, `optional`):
46
- The reader to use. If `None`, a reader will be instantiated from the
47
- provided `reader`. Defaults to `None`.
48
- retriever_device (`str`, `optional`, defaults to `cpu`):
49
- The device to use for the retriever.
50
-
51
- """
52
-
53
- def __init__(
54
- self,
55
- retriever: GoldenRetriever | None = None,
56
- question_encoder: str | GoldenRetrieverModel | None = None,
57
- passage_encoder: str | GoldenRetrieverModel | None = None,
58
- document_index: str | BaseDocumentIndex | None = None,
59
- reader: str | RelikReader | None = None,
60
- device: str = "cpu",
61
- retriever_device: str | None = None,
62
- document_index_device: str | None = None,
63
- reader_device: str | None = None,
64
- precision: int = 32,
65
- retriever_precision: int | None = None,
66
- document_index_precision: int | None = None,
67
- reader_precision: int | None = None,
68
- reader_kwargs: dict | None = None,
69
- retriever_kwargs: dict | None = None,
70
- candidates_preprocessing_fn: str | Callable | None = None,
71
- top_k: int | None = None,
72
- window_size: int | None = None,
73
- window_stride: int | None = None,
74
- **kwargs,
75
- ) -> None:
76
- # retriever
77
- retriever_device = retriever_device or device
78
- document_index_device = document_index_device or device
79
- retriever_precision = retriever_precision or precision
80
- document_index_precision = document_index_precision or precision
81
- if retriever is None and question_encoder is None:
82
- raise ValueError(
83
- "Either `retriever` or `question_encoder` must be provided"
84
- )
85
- if retriever is None:
86
- self.retriever_kwargs = dict(
87
- question_encoder=question_encoder,
88
- passage_encoder=passage_encoder,
89
- document_index=document_index,
90
- device=retriever_device,
91
- precision=retriever_precision,
92
- index_device=document_index_device,
93
- index_precision=document_index_precision,
94
- )
95
- # overwrite default_retriever_kwargs with retriever_kwargs
96
- self.retriever_kwargs.update(retriever_kwargs or {})
97
- retriever = GoldenRetriever(**self.retriever_kwargs)
98
- retriever.training = False
99
- retriever.eval()
100
- self.retriever = retriever
101
-
102
- # reader
103
- self.reader_device = reader_device or device
104
- self.reader_precision = reader_precision or precision
105
- self.reader_kwargs = reader_kwargs
106
- if isinstance(reader, str):
107
- reader_kwargs = reader_kwargs or {}
108
- reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
109
- self.reader = reader
110
-
111
- # windowization stuff
112
- self.tokenizer = SpacyTokenizer(language="en")
113
- self.window_manager: WindowManager | None = None
114
-
115
- # candidates preprocessing
116
- # TODO: maybe move this logic somewhere else
117
- candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
118
- if isinstance(candidates_preprocessing_fn, str):
119
- candidates_preprocessing_fn = get_callable_from_string(
120
- candidates_preprocessing_fn
121
- )
122
- self.candidates_preprocessing_fn = candidates_preprocessing_fn
123
-
124
- # inference params
125
- self.top_k = top_k
126
- self.window_size = window_size
127
- self.window_stride = window_stride
128
-
129
- def __call__(
130
- self,
131
- text: Union[str, list],
132
- top_k: Optional[int] = None,
133
- window_size: Optional[int] = None,
134
- window_stride: Optional[int] = None,
135
- retriever_batch_size: Optional[int] = 32,
136
- reader_batch_size: Optional[int] = 32,
137
- return_also_windows: bool = False,
138
- **kwargs,
139
- ) -> Union[RelikOutput, list[RelikOutput]]:
140
- """
141
- Annotate a text with entities.
142
-
143
- Args:
144
- text (`str` or `list`):
145
- The text to annotate. If a list is provided, each element of the list
146
- will be annotated separately.
147
- top_k (`int`, `optional`, defaults to `None`):
148
- The number of candidates to retrieve for each window.
149
- window_size (`int`, `optional`, defaults to `None`):
150
- The size of the window. If `None`, the whole text will be annotated.
151
- window_stride (`int`, `optional`, defaults to `None`):
152
- The stride of the window. If `None`, there will be no overlap between windows.
153
- retriever_batch_size (`int`, `optional`, defaults to `None`):
154
- The batch size to use for the retriever. The whole input is the batch for the retriever.
155
- reader_batch_size (`int`, `optional`, defaults to `None`):
156
- The batch size to use for the reader. The whole input is the batch for the reader.
157
- return_also_windows (`bool`, `optional`, defaults to `False`):
158
- Whether to return the windows in the output.
159
- **kwargs:
160
- Additional keyword arguments to pass to the retriever and the reader.
161
-
162
- Returns:
163
- `RelikOutput` or `list[RelikOutput]`:
164
- The annotated text. If a list was provided as input, a list of
165
- `RelikOutput` objects will be returned.
166
- """
167
- if top_k is None:
168
- top_k = self.top_k or 100
169
- if window_size is None:
170
- window_size = self.window_size
171
- if window_stride is None:
172
- window_stride = self.window_stride
173
-
174
- if isinstance(text, str):
175
- text = [text]
176
-
177
- if window_size is not None:
178
- if self.window_manager is None:
179
- self.window_manager = WindowManager(self.tokenizer)
180
-
181
- if window_size == "sentence":
182
- # todo: implement sentence windowizer
183
- raise NotImplementedError("Sentence windowizer not implemented yet")
184
-
185
- # if window_size < window_stride:
186
- # raise ValueError(
187
- # f"Window size ({window_size}) must be greater than window stride ({window_stride})"
188
- # )
189
-
190
- # window generator
191
- windows = [
192
- window
193
- for doc_id, t in enumerate(text)
194
- for window in self.window_manager.create_windows(
195
- t,
196
- window_size=window_size,
197
- stride=window_stride,
198
- doc_id=doc_id,
199
- )
200
- ]
201
-
202
- # retrieve candidates first
203
- windows_candidates = []
204
- # TODO: Move batching inside retriever
205
- for batch in batch_generator(windows, batch_size=retriever_batch_size):
206
- retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
207
- windows_candidates.extend(
208
- [[p.label for p in predictions] for predictions in retriever_out]
209
- )
210
-
211
- # add passage to the windows
212
- for window, candidates in zip(windows, windows_candidates):
213
- window.window_candidates = [
214
- self.candidates_preprocessing_fn(c) for c in candidates
215
- ]
216
-
217
- windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
218
- windows = self.window_manager.merge_windows(windows)
219
-
220
- # transform predictions into RelikOutput objects
221
- output = []
222
- for w in windows:
223
- sample_output = RelikOutput(
224
- text=text[w.doc_id],
225
- labels=sorted(
226
- [
227
- EntitySpan(
228
- start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
229
- )
230
- for ss, se, sl in w.predicted_window_labels_chars
231
- ],
232
- key=lambda x: x.start,
233
- ),
234
- )
235
- output.append(sample_output)
236
-
237
- if return_also_windows:
238
- for i, sample_output in enumerate(output):
239
- sample_output.windows = [w for w in windows if w.doc_id == i]
240
-
241
- # if only one text was provided, return a single RelikOutput object
242
- if len(output) == 1:
243
- return output[0]
244
-
245
- return output
246
-
247
- @classmethod
248
- def from_pretrained(
249
- cls,
250
- model_name_or_dir: Union[str, os.PathLike],
251
- config_kwargs: Optional[Dict] = None,
252
- config_file_name: str = CONFIG_NAME,
253
- *args,
254
- **kwargs,
255
- ) -> "Relik":
256
- cache_dir = kwargs.pop("cache_dir", None)
257
- force_download = kwargs.pop("force_download", False)
258
-
259
- model_dir = from_cache(
260
- model_name_or_dir,
261
- filenames=[config_file_name],
262
- cache_dir=cache_dir,
263
- force_download=force_download,
264
- )
265
-
266
- config_path = model_dir / config_file_name
267
- if not config_path.exists():
268
- raise FileNotFoundError(
269
- f"Model configuration file not found at {config_path}."
270
- )
271
-
272
- # overwrite config with config_kwargs
273
- config = OmegaConf.load(config_path)
274
- if config_kwargs is not None:
275
- # TODO: check merging behavior
276
- config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
277
- # do we want to print the config? I like it
278
- pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)
279
-
280
- # load relik from config
281
- relik = hydra.utils.instantiate(config, *args, **kwargs)
282
-
283
- return relik
284
-
285
- def save_pretrained(
286
- self,
287
- output_dir: Union[str, os.PathLike],
288
- config: Optional[Dict[str, Any]] = None,
289
- config_file_name: Optional[str] = None,
290
- save_weights: bool = False,
291
- push_to_hub: bool = False,
292
- model_id: Optional[str] = None,
293
- organization: Optional[str] = None,
294
- repo_name: Optional[str] = None,
295
- **kwargs,
296
- ):
297
- """
298
- Save the configuration of Relik to the specified directory as a YAML file.
299
-
300
- Args:
301
- output_dir (`str`):
302
- The directory to save the configuration file to.
303
- config (`Optional[Dict[str, Any]]`, `optional`):
304
- The configuration to save. If `None`, the current configuration will be
305
- saved. Defaults to `None`.
306
- config_file_name (`Optional[str]`, `optional`):
307
- The name of the configuration file. Defaults to `config.yaml`.
308
- save_weights (`bool`, `optional`):
309
- Whether to save the weights of the model. Defaults to `False`.
310
- push_to_hub (`bool`, `optional`):
311
- Whether to push the saved model to the hub. Defaults to `False`.
312
- model_id (`Optional[str]`, `optional`):
313
- The id of the model to push to the hub. If `None`, the name of the
314
- directory will be used. Defaults to `None`.
315
- organization (`Optional[str]`, `optional`):
316
- The organization to push the model to. Defaults to `None`.
317
- repo_name (`Optional[str]`, `optional`):
318
- The name of the repository to push the model to. Defaults to `None`.
319
- **kwargs:
320
- Additional keyword arguments to pass to `OmegaConf.save`.
321
- """
322
- if config is None:
323
- # create a default config
324
- config = {
325
- "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
326
- }
327
- if self.retriever is not None:
328
- if self.retriever.question_encoder is not None:
329
- config[
330
- "question_encoder"
331
- ] = self.retriever.question_encoder.name_or_path
332
- if self.retriever.passage_encoder is not None:
333
- config[
334
- "passage_encoder"
335
- ] = self.retriever.passage_encoder.name_or_path
336
- if self.retriever.document_index is not None:
337
- config["document_index"] = self.retriever.document_index.name_or_dir
338
- if self.reader is not None:
339
- config["reader"] = self.reader.model_path
340
-
341
- config["retriever_kwargs"] = self.retriever_kwargs
342
- config["reader_kwargs"] = self.reader_kwargs
343
- # expand the fn as to be able to save it and load it later
344
- config[
345
- "candidates_preprocessing_fn"
346
- ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}"
347
-
348
- # these are model-specific and should be saved
349
- config["top_k"] = self.top_k
350
- config["window_size"] = self.window_size
351
- config["window_stride"] = self.window_stride
352
-
353
- config_file_name = config_file_name or CONFIG_NAME
354
-
355
- # create the output directory
356
- output_dir = Path(output_dir)
357
- output_dir.mkdir(parents=True, exist_ok=True)
358
-
359
- logger.info(f"Saving relik config to {output_dir / config_file_name}")
360
- # pretty print the config
361
- pprint(config, console=console_logger, expand_all=True)
362
- OmegaConf.save(config, output_dir / config_file_name)
363
-
364
- if save_weights:
365
- model_id = model_id or output_dir.name
366
- retriever_model_id = model_id + "-retriever"
367
- # save weights
368
- logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
369
- self.retriever.save_pretrained(
370
- output_dir / retriever_model_id,
371
- question_encoder_name=retriever_model_id + "-question-encoder",
372
- passage_encoder_name=retriever_model_id + "-passage-encoder",
373
- document_index_name=retriever_model_id + "-index",
374
- push_to_hub=push_to_hub,
375
- organization=organization,
376
- repo_name=repo_name,
377
- **kwargs,
378
- )
379
- reader_model_id = model_id + "-reader"
380
- logger.info(f"Saving reader to {output_dir / reader_model_id}")
381
- self.reader.save_pretrained(
382
- output_dir / reader_model_id,
383
- push_to_hub=push_to_hub,
384
- organization=organization,
385
- repo_name=repo_name,
386
- **kwargs,
387
- )
388
-
389
- if push_to_hub:
390
- # push to hub
391
- logger.info(f"Pushing to hub")
392
- model_id = model_id or output_dir.name
393
- upload(output_dir, model_id, organization=organization, repo_name=repo_name)
394
-
395
-
396
- def main():
397
- from pprint import pprint
398
-
399
- document_index = FaissDocumentIndex.from_pretrained(
400
- "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
401
- config_kwargs={"_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex", "index_type": "IVFx,Flat"},
402
- )
403
-
404
- relik = Relik(
405
- question_encoder="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
406
- document_index=document_index,
407
- reader="/root/relik-spaces/models/relik-reader-aida-deberta-small",
408
- device="cuda",
409
- precision=16,
410
- top_k=100,
411
- window_size=32,
412
- window_stride=16,
413
- candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
414
- )
415
-
416
- input_text = """
417
- Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
418
- The 92-year-old billionaire did not disclose the trust to the government in July 2015.
419
- Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
420
- Ecclestone had been due to go on trial next month.
421
- """
422
-
423
- preds = relik(input_text)
424
- pprint(preds)
425
-
426
-
427
- if __name__ == "__main__":
428
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
relik/inference/data/__init__.py DELETED
File without changes
relik/inference/data/objects.py DELETED
@@ -1,64 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import List, NamedTuple, Optional
5
-
6
- from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
7
-
8
-
9
- @dataclass
10
- class Word:
11
- """
12
- A word representation that includes text, index in the sentence, POS tag, lemma,
13
- dependency relation, and similar information.
14
-
15
- # Parameters
16
- text : `str`, optional
17
- The text representation.
18
- index : `int`, optional
19
- The word offset in the sentence.
20
- lemma : `str`, optional
21
- The lemma of this word.
22
- pos : `str`, optional
23
- The coarse-grained part of speech of this word.
24
- dep : `str`, optional
25
- The dependency relation for this word.
26
-
27
- input_id : `int`, optional
28
- Integer representation of the word, used to pass it to a model.
29
- token_type_id : `int`, optional
30
- Token type id used by some transformers.
31
- attention_mask: `int`, optional
32
- Attention mask used by transformers, indicates to the model which tokens should
33
- be attended to, and which should not.
34
- """
35
-
36
- text: str
37
- index: int
38
- start_char: Optional[int] = None
39
- end_char: Optional[int] = None
40
- # preprocessing fields
41
- lemma: Optional[str] = None
42
- pos: Optional[str] = None
43
- dep: Optional[str] = None
44
- head: Optional[int] = None
45
-
46
- def __str__(self):
47
- return self.text
48
-
49
- def __repr__(self):
50
- return self.__str__()
51
-
52
-
53
- class EntitySpan(NamedTuple):
54
- start: int
55
- end: int
56
- label: str
57
- text: str
58
-
59
-
60
- @dataclass
61
- class RelikOutput:
62
- text: str
63
- labels: List[EntitySpan]
64
- windows: Optional[List[RelikReaderSample]] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
relik/inference/data/tokenizers/__init__.py DELETED
@@ -1,89 +0,0 @@
1
- SPACY_LANGUAGE_MAPPER = {
2
- "ca": "ca_core_news_sm",
3
- "da": "da_core_news_sm",
4
- "de": "de_core_news_sm",
5
- "el": "el_core_news_sm",
6
- "en": "en_core_web_sm",
7
- "es": "es_core_news_sm",
8
- "fr": "fr_core_news_sm",
9
- "it": "it_core_news_sm",
10
- "ja": "ja_core_news_sm",
11
- "lt": "lt_core_news_sm",
12
- "mk": "mk_core_news_sm",
13
- "nb": "nb_core_news_sm",
14
- "nl": "nl_core_news_sm",
15
- "pl": "pl_core_news_sm",
16
- "pt": "pt_core_news_sm",
17
- "ro": "ro_core_news_sm",
18
- "ru": "ru_core_news_sm",
19
- "xx": "xx_sent_ud_sm",
20
- "zh": "zh_core_web_sm",
21
- "ca_core_news_sm": "ca_core_news_sm",
22
- "ca_core_news_md": "ca_core_news_md",
23
- "ca_core_news_lg": "ca_core_news_lg",
24
- "ca_core_news_trf": "ca_core_news_trf",
25
- "da_core_news_sm": "da_core_news_sm",
26
- "da_core_news_md": "da_core_news_md",
27
- "da_core_news_lg": "da_core_news_lg",
28
- "da_core_news_trf": "da_core_news_trf",
29
- "de_core_news_sm": "de_core_news_sm",
30
- "de_core_news_md": "de_core_news_md",
31
- "de_core_news_lg": "de_core_news_lg",
32
- "de_dep_news_trf": "de_dep_news_trf",
33
- "el_core_news_sm": "el_core_news_sm",
34
- "el_core_news_md": "el_core_news_md",
35
- "el_core_news_lg": "el_core_news_lg",
36
- "en_core_web_sm": "en_core_web_sm",
37
- "en_core_web_md": "en_core_web_md",
38
- "en_core_web_lg": "en_core_web_lg",
39
- "en_core_web_trf": "en_core_web_trf",
40
- "es_core_news_sm": "es_core_news_sm",
41
- "es_core_news_md": "es_core_news_md",
42
- "es_core_news_lg": "es_core_news_lg",
43
- "es_dep_news_trf": "es_dep_news_trf",
44
- "fr_core_news_sm": "fr_core_news_sm",
45
- "fr_core_news_md": "fr_core_news_md",
46
- "fr_core_news_lg": "fr_core_news_lg",
47
- "fr_dep_news_trf": "fr_dep_news_trf",
48
- "it_core_news_sm": "it_core_news_sm",
49
- "it_core_news_md": "it_core_news_md",
50
- "it_core_news_lg": "it_core_news_lg",
51
- "ja_core_news_sm": "ja_core_news_sm",
52
- "ja_core_news_md": "ja_core_news_md",
53
- "ja_core_news_lg": "ja_core_news_lg",
54
- "ja_dep_news_trf": "ja_dep_news_trf",
55
- "lt_core_news_sm": "lt_core_news_sm",
56
- "lt_core_news_md": "lt_core_news_md",
57
- "lt_core_news_lg": "lt_core_news_lg",
58
- "mk_core_news_sm": "mk_core_news_sm",
59
- "mk_core_news_md": "mk_core_news_md",
60
- "mk_core_news_lg": "mk_core_news_lg",
61
- "nb_core_news_sm": "nb_core_news_sm",
62
- "nb_core_news_md": "nb_core_news_md",
63
- "nb_core_news_lg": "nb_core_news_lg",
64
- "nl_core_news_sm": "nl_core_news_sm",
65
- "nl_core_news_md": "nl_core_news_md",
66
- "nl_core_news_lg": "nl_core_news_lg",
67
- "pl_core_news_sm": "pl_core_news_sm",
68
- "pl_core_news_md": "pl_core_news_md",
69
- "pl_core_news_lg": "pl_core_news_lg",
70
- "pt_core_news_sm": "pt_core_news_sm",
71
- "pt_core_news_md": "pt_core_news_md",
72
- "pt_core_news_lg": "pt_core_news_lg",
73
- "ro_core_news_sm": "ro_core_news_sm",
74
- "ro_core_news_md": "ro_core_news_md",
75
- "ro_core_news_lg": "ro_core_news_lg",
76
- "ru_core_news_sm": "ru_core_news_sm",
77
- "ru_core_news_md": "ru_core_news_md",
78
- "ru_core_news_lg": "ru_core_news_lg",
79
- "xx_ent_wiki_sm": "xx_ent_wiki_sm",
80
- "xx_sent_ud_sm": "xx_sent_ud_sm",
81
- "zh_core_web_sm": "zh_core_web_sm",
82
- "zh_core_web_md": "zh_core_web_md",
83
- "zh_core_web_lg": "zh_core_web_lg",
84
- "zh_core_web_trf": "zh_core_web_trf",
85
- }
86
-
87
- from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
88
- from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
89
- from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
relik/inference/data/tokenizers/base_tokenizer.py DELETED
@@ -1,84 +0,0 @@
1
- from typing import List, Union
2
-
3
- from relik.inference.data.objects import Word
4
-
5
-
6
- class BaseTokenizer:
7
- """
8
- A :obj:`Tokenizer` splits strings of text into single words, optionally adds
9
- pos tags and perform lemmatization.
10
- """
11
-
12
- def __call__(
13
- self,
14
- texts: Union[str, List[str], List[List[str]]],
15
- is_split_into_words: bool = False,
16
- **kwargs
17
- ) -> List[List[Word]]:
18
- """
19
- Tokenize the input into single words.
20
-
21
- Args:
22
- texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
23
- Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
24
- is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
25
- If :obj:`True` and the input is a string, the input is split on spaces.
26
-
27
- Returns:
28
- :obj:`List[List[Word]]`: The input text tokenized in single words.
29
- """
30
- raise NotImplementedError
31
-
32
- def tokenize(self, text: str) -> List[Word]:
33
- """
34
- Implements splitting words into tokens.
35
-
36
- Args:
37
- text (:obj:`str`):
38
- Text to tokenize.
39
-
40
- Returns:
41
- :obj:`List[Word]`: The input text tokenized in single words.
42
-
43
- """
44
- raise NotImplementedError
45
-
46
- def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
47
- """
48
- Implements batch splitting words into tokens.
49
-
50
- Args:
51
- texts (:obj:`List[str]`):
52
- Batch of text to tokenize.
53
-
54
- Returns:
55
- :obj:`List[List[Word]]`: The input batch tokenized in single words.
56
-
57
- """
58
- return [self.tokenize(text) for text in texts]
59
-
60
- @staticmethod
61
- def check_is_batched(
62
- texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
63
- ):
64
- """
65
- Check if input is batched or a single sample.
66
-
67
- Args:
68
- texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
69
- Text to check.
70
- is_split_into_words (:obj:`bool`):
71
- If :obj:`True` and the input is a string, the input is split on spaces.
72
-
73
- Returns:
74
- :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
75
- """
76
- return bool(
77
- (not is_split_into_words and isinstance(texts, (list, tuple)))
78
- or (
79
- is_split_into_words
80
- and isinstance(texts, (list, tuple))
81
- and texts
82
- and isinstance(texts[0], (list, tuple))
83
- )
84
- )
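
The interface above is what the concrete tokenizers below implement. A minimal, standalone sketch of the same contract (illustrative only: it returns plain (text, index, start_char, end_char) tuples instead of relik's Word objects, so it runs without the deleted package):

    from typing import List, Tuple


    class ToyWhitespaceTokenizer:
        """Illustrative stand-in for a BaseTokenizer subclass."""

        def tokenize(self, text: str) -> List[Tuple[str, int, int, int]]:
            tokens, cursor = [], 0
            for i, tok in enumerate(text.split()):
                start = text.index(tok, cursor)  # first occurrence after the previous token
                end = start + len(tok)
                tokens.append((tok, i, start, end))
                cursor = end
            return tokens

        def tokenize_batch(self, texts: List[str]) -> List[List[Tuple[str, int, int, int]]]:
            return [self.tokenize(t) for t in texts]


    print(ToyWhitespaceTokenizer().tokenize("Mary sold the car to John."))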
relik/inference/data/tokenizers/regex_tokenizer.py DELETED
@@ -1,73 +0,0 @@
1
- import re
2
- from typing import List, Union
3
-
4
- from overrides import overrides
5
-
6
- from relik.inference.data.objects import Word
7
- from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
8
-
9
-
10
- class RegexTokenizer(BaseTokenizer):
11
- """
12
- A :obj:`Tokenizer` that splits the text based on a simple regex.
13
- """
14
-
15
- def __init__(self):
16
- super(RegexTokenizer, self).__init__()
17
- # regex for splitting on spaces and punctuation and new lines
18
- # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n")
19
- self._regex = re.compile(
20
- r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL
21
- )
22
-
23
- def __call__(
24
- self,
25
- texts: Union[str, List[str], List[List[str]]],
26
- is_split_into_words: bool = False,
27
- **kwargs,
28
- ) -> List[List[Word]]:
29
- """
30
- Tokenize the input into single words by splitting using a simple regex.
31
-
32
- Args:
33
- texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
34
- Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
35
- is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
36
- If :obj:`True` and the input is a string, the input is split on spaces.
37
-
38
- Returns:
39
- :obj:`List[List[Word]]`: The input text tokenized in single words.
40
-
41
- Example::
42
-
43
- >>> from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
44
-
45
- >>> regex_tokenizer = RegexTokenizer()
46
- >>> regex_tokenizer("Mary sold the car to John.")
47
-
48
- """
49
- # check if input is batched or a single sample
50
- is_batched = self.check_is_batched(texts, is_split_into_words)
51
-
52
- if is_batched:
53
- tokenized = self.tokenize_batch(texts)
54
- else:
55
- tokenized = self.tokenize(texts)
56
-
57
- return tokenized
58
-
59
- @overrides
60
- def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
61
- if not isinstance(text, (str, list)):
62
- raise ValueError(
63
- f"text must be either `str` or `list`, found: `{type(text)}`"
64
- )
65
-
66
- if isinstance(text, list):
67
- text = " ".join(text)
68
- return [
69
- Word(t[0], i, start_char=t[1], end_char=t[2])
70
- for i, t in enumerate(
71
- (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text)
72
- )
73
- ]
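
A standalone sketch of what the regex split above yields, i.e. (token, start_char, end_char) triples; the Word wrapper is dropped so the snippet runs without the deleted package, and the pattern is the one from RegexTokenizer.__init__:

    import re

    pattern = re.compile(r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL)
    text = "Mary sold the car to John."
    print([(m.group(0), m.start(), m.end()) for m in pattern.finditer(text)])
    # [('Mary', 0, 4), ('sold', 5, 9), ('the', 10, 13), ('car', 14, 17), ('to', 18, 20), ('John', 21, 25), ('.', 25, 26)]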
relik/inference/data/tokenizers/spacy_tokenizer.py DELETED
@@ -1,228 +0,0 @@
1
- import logging
2
- from typing import Dict, List, Tuple, Union
3
-
4
- import spacy
5
-
6
- # from ipa.common.utils import load_spacy
7
- from overrides import overrides
8
- from spacy.cli.download import download as spacy_download
9
- from spacy.tokens import Doc
10
-
11
- from relik.common.log import get_logger
12
- from relik.inference.data.objects import Word
13
- from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
14
- from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
15
-
16
- logger = get_logger(level=logging.DEBUG)
17
-
18
- # Spacy and Stanza stuff
19
-
20
- LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}
21
-
22
-
23
- def load_spacy(
24
- language: str,
25
- pos_tags: bool = False,
26
- lemma: bool = False,
27
- parse: bool = False,
28
- split_on_spaces: bool = False,
29
- ) -> spacy.Language:
30
- """
31
- Download and load spacy model.
32
-
33
- Args:
34
- language (:obj:`str`, defaults to :obj:`en`):
35
- Language of the text to tokenize.
36
- pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
37
- If :obj:`True`, performs POS tagging with spacy model.
38
- lemma (:obj:`bool`, optional, defaults to :obj:`False`):
39
- If :obj:`True`, performs lemmatization with spacy model.
40
- parse (:obj:`bool`, optional, defaults to :obj:`False`):
41
- If :obj:`True`, performs dependency parsing with spacy model.
42
- split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
43
- If :obj:`True`, will split by spaces without performing tokenization.
44
-
45
- Returns:
46
- :obj:`spacy.Language`: The spacy model loaded.
47
- """
48
- exclude = ["vectors", "textcat", "ner"]
49
- if not pos_tags:
50
- exclude.append("tagger")
51
- if not lemma:
52
- exclude.append("lemmatizer")
53
- if not parse:
54
- exclude.append("parser")
55
-
56
- # check if the model is already loaded
57
- # if so, there is no need to reload it
58
- spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
59
- if spacy_params not in LOADED_SPACY_MODELS:
60
- try:
61
- spacy_tagger = spacy.load(language, exclude=exclude)
62
- except OSError:
63
- logger.warning(
64
- "Spacy model '%s' not found. Downloading and installing.", language
65
- )
66
- spacy_download(language)
67
- spacy_tagger = spacy.load(language, exclude=exclude)
68
-
69
- # if everything is disabled, return only the tokenizer
70
- # for faster tokenization
71
- # TODO: is it really faster?
72
- # if len(exclude) >= 6:
73
- # spacy_tagger = spacy_tagger.tokenizer
74
- LOADED_SPACY_MODELS[spacy_params] = spacy_tagger
75
-
76
- return LOADED_SPACY_MODELS[spacy_params]
77
-
78
-
79
- class SpacyTokenizer(BaseTokenizer):
80
- """
81
- A :obj:`Tokenizer` that uses SpaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.
82
-
83
- Args:
84
- language (:obj:`str`, optional, defaults to :obj:`en`):
85
- Language of the text to tokenize.
86
- return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
87
- If :obj:`True`, performs POS tagging with spacy model.
88
- return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
89
- If :obj:`True`, performs lemmatization with spacy model.
90
- return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
91
- If :obj:`True`, performs dependency parsing with spacy model.
92
- split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
93
- If :obj:`True`, will split by spaces without performing tokenization.
94
- use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
95
- If :obj:`True`, will load the spaCy model on GPU.
96
- """
97
-
98
- def __init__(
99
- self,
100
- language: str = "en",
101
- return_pos_tags: bool = False,
102
- return_lemmas: bool = False,
103
- return_deps: bool = False,
104
- split_on_spaces: bool = False,
105
- use_gpu: bool = False,
106
- ):
107
- super(SpacyTokenizer, self).__init__()
108
- if language not in SPACY_LANGUAGE_MAPPER:
109
- raise ValueError(
110
- f"`{language}` language not supported. The supported "
111
- f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
112
- )
113
- if use_gpu:
114
- # load the model on GPU
115
- # if the GPU is not available or not correctly configured,
116
- # it will raise an error
117
- spacy.require_gpu()
118
- self.spacy = load_spacy(
119
- SPACY_LANGUAGE_MAPPER[language],
120
- return_pos_tags,
121
- return_lemmas,
122
- return_deps,
123
- split_on_spaces,
124
- )
125
- self.split_on_spaces = split_on_spaces
126
-
127
- def __call__(
128
- self,
129
- texts: Union[str, List[str], List[List[str]]],
130
- is_split_into_words: bool = False,
131
- **kwargs,
132
- ) -> Union[List[Word], List[List[Word]]]:
133
- """
134
- Tokenize the input into single words using SpaCy models.
135
-
136
- Args:
137
- texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
138
- Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
139
- is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
140
- If :obj:`True` and the input is a string, the input is split on spaces.
141
-
142
- Returns:
143
- :obj:`List[List[Word]]`: The input text tokenized in single words.
144
-
145
- Example::
146
-
147
- >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
148
-
149
- >>> spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
150
- >>> spacy_tokenizer("Mary sold the car to John.")
151
-
152
- """
153
- # check if input is batched or a single sample
154
- is_batched = self.check_is_batched(texts, is_split_into_words)
155
- if is_batched:
156
- tokenized = self.tokenize_batch(texts)
157
- else:
158
- tokenized = self.tokenize(texts)
159
- return tokenized
160
-
161
- @overrides
162
- def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
163
- if self.split_on_spaces:
164
- if isinstance(text, str):
165
- text = text.split(" ")
166
- spaces = [True] * len(text)
167
- text = Doc(self.spacy.vocab, words=text, spaces=spaces)
168
- return self._clean_tokens(self.spacy(text))
169
-
170
- @overrides
171
- def tokenize_batch(
172
- self, texts: Union[List[str], List[List[str]]]
173
- ) -> List[List[Word]]:
174
- if self.split_on_spaces:
175
- if isinstance(texts[0], str):
176
- texts = [text.split(" ") for text in texts]
177
- spaces = [[True] * len(text) for text in texts]
178
- texts = [
179
- Doc(self.spacy.vocab, words=text, spaces=space)
180
- for text, space in zip(texts, spaces)
181
- ]
182
- return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)]
183
-
184
- @staticmethod
185
- def _clean_tokens(tokens: Doc) -> List[Word]:
186
- """
187
- Converts spaCy tokens to :obj:`Word`.
188
-
189
- Args:
190
- tokens (:obj:`spacy.tokens.Doc`):
191
- Tokens from SpaCy model.
192
-
193
- Returns:
194
- :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects.
195
- """
196
- words = [
197
- Word(
198
- token.text,
199
- token.i,
200
- token.idx,
201
- token.idx + len(token),
202
- token.lemma_,
203
- token.pos_,
204
- token.dep_,
205
- token.head.i,
206
- )
207
- for token in tokens
208
- ]
209
- return words
210
-
211
-
212
- class WhitespaceSpacyTokenizer:
213
- """Simple white space tokenizer for SpaCy."""
214
-
215
- def __init__(self, vocab):
216
- self.vocab = vocab
217
-
218
- def __call__(self, text):
219
- if isinstance(text, str):
220
- words = text.split(" ")
221
- elif isinstance(text, list):
222
- words = text
223
- else:
224
- raise ValueError(
225
- f"text must be either `str` or `list`, found: `{type(text)}`"
226
- )
227
- spaces = [True] * len(words)
228
- return Doc(self.vocab, words=words, spaces=spaces)
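
A hedged usage sketch of the deleted SpacyTokenizer; it assumes an installed relik checkout that still ships this module and that SPACY_LANGUAGE_MAPPER maps the short code "en" to a downloadable spaCy pipeline, as the class docstring suggests:

    from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

    tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
    words = tokenizer("Mary sold the car to John.")  # single string -> one list of Word objects
    print(words)

Note the module-level LOADED_SPACY_MODELS cache: pipelines are keyed on (language, pos_tags, lemma, parse, split_on_spaces), so two tokenizers created with the same settings share a single spaCy model within the process.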
relik/inference/data/tokenizers/whitespace_tokenizer.py DELETED
@@ -1,70 +0,0 @@
1
- import re
2
- from typing import List, Union
3
-
4
- from overrides import overrides
5
-
6
- from relik.inference.data.objects import Word
7
- from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
8
-
9
-
10
- class WhitespaceTokenizer(BaseTokenizer):
11
- """
12
- A :obj:`Tokenizer` that splits the text on spaces.
13
- """
14
-
15
- def __init__(self):
16
- super(WhitespaceTokenizer, self).__init__()
17
- self.whitespace_regex = re.compile(r"\S+")
18
-
19
- def __call__(
20
- self,
21
- texts: Union[str, List[str], List[List[str]]],
22
- is_split_into_words: bool = False,
23
- **kwargs,
24
- ) -> List[List[Word]]:
25
- """
26
- Tokenize the input into single words by splitting on spaces.
27
-
28
- Args:
29
- texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
30
- Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
31
- is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
32
- If :obj:`True` and the input is a string, the input is split on spaces.
33
-
34
- Returns:
35
- :obj:`List[List[Word]]`: The input text tokenized in single words.
36
-
37
- Example::
38
-
39
- >>> from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
40
-
41
- >>> whitespace_tokenizer = WhitespaceTokenizer()
42
- >>> whitespace_tokenizer("Mary sold the car to John .")
43
-
44
- """
45
- # check if input is batched or a single sample
46
- is_batched = self.check_is_batched(texts, is_split_into_words)
47
-
48
- if is_batched:
49
- tokenized = self.tokenize_batch(texts)
50
- else:
51
- tokenized = self.tokenize(texts)
52
-
53
- return tokenized
54
-
55
- @overrides
56
- def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
57
- if not isinstance(text, (str, list)):
58
- raise ValueError(
59
- f"text must be either `str` or `list`, found: `{type(text)}`"
60
- )
61
-
62
- if isinstance(text, list):
63
- text = " ".join(text)
64
- return [
65
- Word(t[0], i, start_char=t[1], end_char=t[2])
66
- for i, t in enumerate(
67
- (m.group(0), m.start(), m.end())
68
- for m in self.whitespace_regex.finditer(text)
69
- )
70
- ]
relik/inference/data/window/__init__.py DELETED
File without changes
relik/inference/data/window/manager.py DELETED
@@ -1,262 +0,0 @@
1
- import collections
2
- import itertools
3
- from dataclasses import dataclass
4
- from typing import List, Optional, Set, Tuple
5
-
6
- from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
7
- from relik.reader.data.relik_reader_sample import RelikReaderSample
8
-
9
-
10
- @dataclass
11
- class Window:
12
- doc_id: int
13
- window_id: int
14
- text: str
15
- tokens: List[str]
16
- doc_topic: Optional[str]
17
- offset: int
18
- token2char_start: dict
19
- token2char_end: dict
20
- window_candidates: Optional[List[str]] = None
21
-
22
-
23
- class WindowManager:
24
- def __init__(self, tokenizer: BaseTokenizer) -> None:
25
- self.tokenizer = tokenizer
26
-
27
- def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
28
- tokenized_document = self.tokenizer(document)
29
- tokens = []
30
- tokens_char_mapping = []
31
- for token in tokenized_document:
32
- tokens.append(token.text)
33
- tokens_char_mapping.append((token.start_char, token.end_char))
34
- return tokens, tokens_char_mapping
35
-
36
- def create_windows(
37
- self,
38
- document: str,
39
- window_size: int,
40
- stride: int,
41
- doc_id: int = 0,
42
- doc_topic: str = None,
43
- ) -> List[RelikReaderSample]:
44
- document_tokens, tokens_char_mapping = self.tokenize(document)
45
- if doc_topic is None:
46
- doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
47
- document_windows = []
48
- if len(document_tokens) <= window_size:
49
- text = document
50
- # relik_reader_sample = RelikReaderSample()
51
- document_windows.append(
52
- # Window(
53
- RelikReaderSample(
54
- doc_id=doc_id,
55
- window_id=0,
56
- text=text,
57
- tokens=document_tokens,
58
- doc_topic=doc_topic,
59
- offset=0,
60
- token2char_start={
61
- str(i): tokens_char_mapping[i][0]
62
- for i in range(len(document_tokens))
63
- },
64
- token2char_end={
65
- str(i): tokens_char_mapping[i][1]
66
- for i in range(len(document_tokens))
67
- },
68
- )
69
- )
70
- else:
71
- for window_id, i in enumerate(range(0, len(document_tokens), stride)):
72
- # if the last stride is smaller than the window size, then we can
73
- # include more tokens from the previous window.
74
- if i != 0 and i + window_size > len(document_tokens):
75
- overflowing_tokens = i + window_size - len(document_tokens)
76
- if overflowing_tokens >= stride:
77
- break
78
- i -= overflowing_tokens
79
-
80
- involved_token_indices = list(
81
- range(i, min(i + window_size, len(document_tokens) - 1))
82
- )
83
- window_tokens = [document_tokens[j] for j in involved_token_indices]
84
- window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
85
- window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
86
- text = document[window_text_start:window_text_end]
87
- document_windows.append(
88
- # Window(
89
- RelikReaderSample(
90
- # dict(
91
- doc_id=doc_id,
92
- window_id=window_id,
93
- text=text,
94
- tokens=window_tokens,
95
- doc_topic=doc_topic,
96
- offset=window_text_start,
97
- token2char_start={
98
- str(i): tokens_char_mapping[ti][0]
99
- for i, ti in enumerate(involved_token_indices)
100
- },
101
- token2char_end={
102
- str(i): tokens_char_mapping[ti][1]
103
- for i, ti in enumerate(involved_token_indices)
104
- },
105
- # )
106
- )
107
- )
108
- return document_windows
109
-
110
- def merge_windows(
111
- self, windows: List[RelikReaderSample]
112
- ) -> List[RelikReaderSample]:
113
- windows_by_doc_id = collections.defaultdict(list)
114
- for window in windows:
115
- windows_by_doc_id[window.doc_id].append(window)
116
-
117
- merged_window_by_doc = {
118
- doc_id: self.merge_doc_windows(doc_windows)
119
- for doc_id, doc_windows in windows_by_doc_id.items()
120
- }
121
-
122
- return list(merged_window_by_doc.values())
123
-
124
- def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
125
- if len(windows) == 1:
126
- return windows[0]
127
-
128
- if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
129
- windows = sorted(windows, key=(lambda x: x.offset))
130
-
131
- window_accumulator = windows[0]
132
-
133
- for next_window in windows[1:]:
134
- window_accumulator = self._merge_window_pair(
135
- window_accumulator, next_window
136
- )
137
-
138
- return window_accumulator
139
-
140
- def _merge_tokens(
141
- self, window1: RelikReaderSample, window2: RelikReaderSample
142
- ) -> Tuple[list, dict, dict]:
143
- w1_tokens = window1.tokens[1:-1]
144
- w2_tokens = window2.tokens[1:-1]
145
-
146
- # find intersection
147
- tokens_intersection = None
148
- for k in reversed(range(1, len(w1_tokens))):
149
- if w1_tokens[-k:] == w2_tokens[:k]:
150
- tokens_intersection = k
151
- break
152
- assert tokens_intersection is not None, (
153
- f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
154
- + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
155
- + f"w1 tokens: {w1_tokens}\n"
156
- + f"w2 tokens: {w2_tokens}\n"
157
- )
158
-
159
- final_tokens = (
160
- [window1.tokens[0]] # CLS
161
- + w1_tokens
162
- + w2_tokens[tokens_intersection:]
163
- + [window1.tokens[-1]] # SEP
164
- )
165
-
166
- w2_starting_offset = len(w1_tokens) - tokens_intersection
167
-
168
- def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
169
- final_t2c = dict()
170
- final_t2c.update(t2c1)
171
- for t, c in t2c2.items():
172
- t = int(t)
173
- if t < tokens_intersection:
174
- continue
175
- final_t2c[str(t + w2_starting_offset)] = c
176
- return final_t2c
177
-
178
- return (
179
- final_tokens,
180
- merge_char_mapping(window1.token2char_start, window2.token2char_start),
181
- merge_char_mapping(window1.token2char_end, window2.token2char_end),
182
- )
183
-
184
- def _merge_span_annotation(
185
- self, span_annotation1: List[list], span_annotation2: List[list]
186
- ) -> List[list]:
187
- uniq_store = set()
188
- final_span_annotation_store = []
189
- for span_annotation in itertools.chain(span_annotation1, span_annotation2):
190
- span_annotation_id = tuple(span_annotation)
191
- if span_annotation_id not in uniq_store:
192
- uniq_store.add(span_annotation_id)
193
- final_span_annotation_store.append(span_annotation)
194
- return sorted(final_span_annotation_store, key=lambda x: x[0])
195
-
196
- def _merge_predictions(
197
- self,
198
- window1: RelikReaderSample,
199
- window2: RelikReaderSample,
200
- ) -> Tuple[Set[Tuple[int, int, str]], dict]:
201
- merged_predictions = window1.predicted_window_labels_chars.union(
202
- window2.predicted_window_labels_chars
203
- )
204
-
205
- span_title_probabilities = dict()
206
- # probabilities
207
- for span_prediction, predicted_probs in itertools.chain(
208
- window1.probs_window_labels_chars.items(),
209
- window2.probs_window_labels_chars.items(),
210
- ):
211
- if span_prediction not in span_title_probabilities:
212
- span_title_probabilities[span_prediction] = predicted_probs
213
-
214
- return merged_predictions, span_title_probabilities
215
-
216
- def _merge_window_pair(
217
- self,
218
- window1: RelikReaderSample,
219
- window2: RelikReaderSample,
220
- ) -> RelikReaderSample:
221
- merging_output = dict()
222
-
223
- if getattr(window1, "doc_id", None) is not None:
224
- assert window1.doc_id == window2.doc_id
225
-
226
- if getattr(window1, "offset", None) is not None:
227
- assert (
228
- window1.offset < window2.offset
229
- ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"
230
-
231
- merging_output["doc_id"] = window1.doc_id
232
- merging_output["offset"] = window2.offset
233
-
234
- m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
235
- window1, window2
236
- )
237
-
238
- window_labels = None
239
- if getattr(window1, "window_labels", None) is not None:
240
- window_labels = self._merge_span_annotation(
241
- window1.window_labels, window2.window_labels
242
- )
243
- (
244
- predicted_window_labels_chars,
245
- probs_window_labels_chars,
246
- ) = self._merge_predictions(
247
- window1,
248
- window2,
249
- )
250
-
251
- merging_output.update(
252
- dict(
253
- tokens=m_tokens,
254
- token2char_start=m_token2char_start,
255
- token2char_end=m_token2char_end,
256
- window_labels=window_labels,
257
- predicted_window_labels_chars=predicted_window_labels_chars,
258
- probs_window_labels_chars=probs_window_labels_chars,
259
- )
260
- )
261
-
262
- return RelikReaderSample(**merging_output)
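
The windowing arithmetic in create_windows (fixed window_size, overlapping stride, and a last window that is shifted back rather than allowed to overflow) can be checked standalone on token indices. Note that the deleted code caps the window end at len(document_tokens) - 1, which drops the final token; the sketch below uses the full length instead:

    def window_spans(n_tokens: int, window_size: int, stride: int):
        """Token-index spans produced by the sliding-window logic above (illustrative)."""
        spans = []
        for i in range(0, n_tokens, stride):
            if i != 0 and i + window_size > n_tokens:
                overflow = i + window_size - n_tokens
                if overflow >= stride:  # window would only repeat already-covered tokens
                    break
                i -= overflow           # shift the last window back instead of overflowing
            spans.append((i, min(i + window_size, n_tokens)))
        return spans

    print(window_spans(10, window_size=6, stride=4))  # [(0, 6), (4, 10)]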
relik/inference/gerbil.py DELETED
@@ -1,254 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
- import re
5
- import sys
6
- from http.server import BaseHTTPRequestHandler, HTTPServer
7
- from typing import Iterator, List, Optional, Tuple
8
-
9
- from relik.inference.annotator import Relik
10
- from relik.inference.data.objects import RelikOutput
11
-
12
- # sys.path += ['../']
13
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
14
-
15
-
16
- import logging
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
-
21
- class GerbilAlbyManager:
22
- def __init__(
23
- self,
24
- annotator: Optional[Relik] = None,
25
- response_logger_dir: Optional[str] = None,
26
- ) -> None:
27
- self.annotator = annotator
28
- self.response_logger_dir = response_logger_dir
29
- self.predictions_counter = 0
30
- self.labels_mapping = None
31
-
32
- def annotate(self, document: str):
33
- relik_output: RelikOutput = self.annotator(document)
34
- annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels]
35
- if self.labels_mapping is not None:
36
- return [
37
- (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
38
- ]
39
- return annotations
40
-
41
- def set_mapping_file(self, mapping_file_path: str):
42
- with open(mapping_file_path) as f:
43
- labels_mapping = json.load(f)
44
- self.labels_mapping = {v: k for k, v in labels_mapping.items()}
45
-
46
- def write_response_bundle(
47
- self,
48
- document: str,
49
- new_document: str,
50
- annotations: list,
51
- mapped_annotations: list,
52
- ) -> None:
53
- if self.response_logger_dir is None:
54
- return
55
-
56
- if not os.path.isdir(self.response_logger_dir):
57
- os.mkdir(self.response_logger_dir)
58
-
59
- with open(
60
- f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
61
- ) as f:
62
- out_json_obj = dict(
63
- document=document,
64
- new_document=new_document,
65
- annotations=annotations,
66
- mapped_annotations=mapped_annotations,
67
- )
68
-
69
- out_json_obj["span_annotations"] = [
70
- (ss, se, document[ss:se], label) for (ss, se, label) in annotations
71
- ]
72
-
73
- out_json_obj["span_mapped_annotations"] = [
74
- (ss, se, new_document[ss:se], label)
75
- for (ss, se, label) in mapped_annotations
76
- ]
77
-
78
- json.dump(out_json_obj, f, indent=2)
79
-
80
- self.predictions_counter += 1
81
-
82
-
83
- manager = GerbilAlbyManager()
84
-
85
-
86
- def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
87
- pattern_subs = {
88
- "-LPR- ": " (",
89
- "-RPR-": ")",
90
- "\n\n": "\n",
91
- "-LRB-": "(",
92
- "-RRB-": ")",
93
- '","': ",",
94
- }
95
-
96
- document_acc = document
97
- curr_offset = 0
98
- char2offset = []
99
-
100
- matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
101
- for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
102
- span_start, span_end = span_matching.span()
103
- span_start -= curr_offset
104
- span_end -= curr_offset
105
-
106
- span_text = document_acc[span_start:span_end]
107
- span_sub = pattern_subs[span_text]
108
- document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]
109
-
110
- offset = len(span_text) - len(span_sub)
111
- curr_offset += offset
112
-
113
- char2offset.append((span_start + len(span_sub), curr_offset))
114
-
115
- return document_acc, char2offset
116
-
117
-
118
- def map_back_annotations(
119
- annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
120
- ) -> Iterator[Tuple[int, int, str]]:
121
- def map_char(char_idx: int) -> int:
122
- current_offset = 0
123
- for offset_idx, offset_value in char_mapping:
124
- if char_idx >= offset_idx:
125
- current_offset = offset_value
126
- else:
127
- break
128
- return char_idx + current_offset
129
-
130
- for ss, se, label in annotations:
131
- yield map_char(ss), map_char(se), label
132
-
133
-
134
- def annotate(document: str) -> List[Tuple[int, int, str]]:
135
- new_document, mapping = preprocess_document(document)
136
- logger.info("Mapping: " + str(mapping))
137
- logger.info("Document: " + str(document))
138
- annotations = [
139
- (cs, ce, label.replace(" ", "_"))
140
- for cs, ce, label in manager.annotate(new_document)
141
- ]
142
- logger.info("New document: " + str(new_document))
143
- mapped_annotations = (
144
- list(map_back_annotations(annotations, mapping))
145
- if len(mapping) > 0
146
- else annotations
147
- )
148
-
149
- logger.info(
150
- "Annotations: "
151
- + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
152
- )
153
-
154
- manager.write_response_bundle(
155
- document, new_document, mapped_annotations, annotations
156
- )
157
-
158
- if not all(
159
- [
160
- new_document[ss:se] == document[mss:mse]
161
- for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
162
- ]
163
- ):
164
- diff_mappings = [
165
- (new_document[ss:se], document[mss:mse])
166
- for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
167
- ]
168
- return None
169
- assert all(
170
- [
171
- document[mss:mse] == new_document[ss:se]
172
- for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
173
- ]
174
- ), (mapped_annotations, annotations)
175
-
176
- return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
177
-
178
-
179
- class GetHandler(BaseHTTPRequestHandler):
180
- def do_POST(self):
181
- content_length = int(self.headers["Content-Length"])
182
- post_data = self.rfile.read(content_length)
183
- self.send_response(200)
184
- self.end_headers()
185
- doc_text = read_json(post_data)
186
- # try:
187
- response = annotate(doc_text)
188
-
189
- self.wfile.write(bytes(json.dumps(response), "utf-8"))
190
- return
191
-
192
-
193
- def read_json(post_data):
194
- data = json.loads(post_data.decode("utf-8"))
195
- # logger.info("received data:", data)
196
- text = data["text"]
197
- # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
198
- return text
199
-
200
-
201
- def parse_args() -> argparse.Namespace:
202
- parser = argparse.ArgumentParser()
203
- parser.add_argument("--relik-model-name", required=True)
204
- parser.add_argument("--responses-log-dir")
205
- parser.add_argument("--log-file", default="logs/logging.txt")
206
- parser.add_argument("--mapping-file")
207
- return parser.parse_args()
208
-
209
-
210
- def main():
211
- args = parse_args()
212
-
213
- # init manager
214
- manager.response_logger_dir = args.responses_log_dir
215
- # manager.annotator = Relik.from_pretrained(args.relik_model_name)
216
-
217
- print("Debugging: not using your relik model but a hardcoded one.")
218
- manager.annotator = Relik(
219
- question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
220
- document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
221
- reader="relik/reader/models/relik-reader-deberta-base-new-data",
222
- window_size=32,
223
- window_stride=16,
224
- candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
225
- )
226
-
227
- if args.mapping_file is not None:
228
- manager.set_mapping_file(args.mapping_file)
229
-
230
- port = 6654
231
- server = HTTPServer(("localhost", port), GetHandler)
232
- logger.info(f"Starting server at http://localhost:{port}")
233
-
234
- # Create a file handler and set its level
235
- file_handler = logging.FileHandler(args.log_file)
236
- file_handler.setLevel(logging.DEBUG)
237
-
238
- # Create a log formatter and set it on the handler
239
- formatter = logging.Formatter(
240
- "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
241
- )
242
- file_handler.setFormatter(formatter)
243
-
244
- # Add the file handler to the logger
245
- logger.addHandler(file_handler)
246
-
247
- try:
248
- server.serve_forever()
249
- except KeyboardInterrupt:
250
- exit(0)
251
-
252
-
253
- if __name__ == "__main__":
254
- main()
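
preprocess_document rewrites GERBIL markup (e.g. "-LRB-" becomes "(") and records, for each substitution, the position in the rewritten text after which character offsets change and by how much; map_back_annotations then projects predicted spans back onto the original input. A standalone sketch of that projection with a single substitution:

    def map_char(char_idx, char_mapping):
        """char_mapping holds (position_in_new_text, cumulative_offset) pairs, as built by preprocess_document."""
        current = 0
        for pos, offset in char_mapping:
            if char_idx >= pos:
                current = offset
            else:
                break
        return char_idx + current

    old_text = "Rome -LRB-Italy) is nice"
    new_text = "Rome (Italy) is nice"   # "-LRB-" rewritten to "(" -> 4 chars shorter
    char_mapping = [(6, 4)]              # offsets after position 6 of the new text grow by 4
    start, end = 6, 11                   # "Italy" in the rewritten text
    print(old_text[map_char(start, char_mapping):map_char(end, char_mapping)])  # Italy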
relik/inference/preprocessing.py DELETED
@@ -1,4 +0,0 @@
1
- def wikipedia_title_and_openings_preprocessing(
2
- wikipedia_title_and_openings: str, separator: str = " <def>"
3
- ):
4
- return wikipedia_title_and_openings.split(separator, 1)[0]
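
A quick check of the helper above (assuming a checkout that still ships this module); the " <def>" separator splits a retrieved candidate into its Wikipedia title and opening text:

    from relik.inference.preprocessing import wikipedia_title_and_openings_preprocessing

    print(wikipedia_title_and_openings_preprocessing("Rome <def> Rome is the capital of Italy."))  # Rome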
relik/inference/serve/__init__.py DELETED
File without changes
relik/inference/serve/backend/__init__.py DELETED
File without changes
relik/inference/serve/backend/relik.py DELETED
@@ -1,210 +0,0 @@
1
- import logging
2
- from pathlib import Path
3
- from typing import List, Optional, Union
4
-
5
- from relik.common.utils import is_package_available
6
- from relik.inference.annotator import Relik
7
-
8
- if not is_package_available("fastapi"):
9
- raise ImportError(
10
- "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
11
- )
12
- from fastapi import FastAPI, HTTPException
13
-
14
- if not is_package_available("ray"):
15
- raise ImportError(
16
- "Ray is not installed. Please install Ray with `pip install relik[serve]`."
17
- )
18
- from ray import serve
19
-
20
- from relik.common.log import get_logger
21
- from relik.inference.serve.backend.utils import (
22
- RayParameterManager,
23
- ServerParameterManager,
24
- )
25
- from relik.retriever.data.utils import batch_generator
26
-
27
- logger = get_logger(__name__, level=logging.INFO)
28
-
29
- VERSION = {} # type: ignore
30
- with open(
31
- Path(__file__).parent.parent.parent.parent / "version.py", "r"
32
- ) as version_file:
33
- exec(version_file.read(), VERSION)
34
-
35
- # Env variables for server
36
- SERVER_MANAGER = ServerParameterManager()
37
- RAY_MANAGER = RayParameterManager()
38
-
39
- app = FastAPI(
40
- title="ReLiK",
41
- version=VERSION["VERSION"],
42
- description="ReLiK REST API",
43
- )
44
-
45
-
46
- @serve.deployment(
47
- ray_actor_options={
48
- "num_gpus": RAY_MANAGER.num_gpus
49
- if (
50
- SERVER_MANAGER.retriver_device == "cuda"
51
- or SERVER_MANAGER.reader_device == "cuda"
52
- )
53
- else 0
54
- },
55
- autoscaling_config={
56
- "min_replicas": RAY_MANAGER.min_replicas,
57
- "max_replicas": RAY_MANAGER.max_replicas,
58
- },
59
- )
60
- @serve.ingress(app)
61
- class RelikServer:
62
- def __init__(
63
- self,
64
- question_encoder: str,
65
- document_index: str,
66
- passage_encoder: Optional[str] = None,
67
- reader_encoder: Optional[str] = None,
68
- top_k: int = 100,
69
- retriver_device: str = "cpu",
70
- reader_device: str = "cpu",
71
- index_device: Optional[str] = None,
72
- precision: int = 32,
73
- index_precision: Optional[int] = None,
74
- use_faiss: bool = False,
75
- window_batch_size: int = 32,
76
- window_size: int = 32,
77
- window_stride: int = 16,
78
- split_on_spaces: bool = False,
79
- ):
80
- # parameters
81
- self.question_encoder = question_encoder
82
- self.passage_encoder = passage_encoder
83
- self.reader_encoder = reader_encoder
84
- self.document_index = document_index
85
- self.top_k = top_k
86
- self.retriver_device = retriver_device
87
- self.index_device = index_device or retriver_device
88
- self.reader_device = reader_device
89
- self.precision = precision
90
- self.index_precision = index_precision or precision
91
- self.use_faiss = use_faiss
92
- self.window_batch_size = window_batch_size
93
- self.window_size = window_size
94
- self.window_stride = window_stride
95
- self.split_on_spaces = split_on_spaces
96
-
97
- # log stuff for debugging
98
- logger.info("Initializing RelikServer with parameters:")
99
- logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
100
- logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
101
- logger.info(f"READER_ENCODER: {self.reader_encoder}")
102
- logger.info(f"DOCUMENT_INDEX: {self.document_index}")
103
- logger.info(f"TOP_K: {self.top_k}")
104
- logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
105
- logger.info(f"READER_DEVICE: {self.reader_device}")
106
- logger.info(f"INDEX_DEVICE: {self.index_device}")
107
- logger.info(f"PRECISION: {self.precision}")
108
- logger.info(f"INDEX_PRECISION: {self.index_precision}")
109
- logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
110
- logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
111
-
112
- self.relik = Relik(
113
- question_encoder=self.question_encoder,
114
- passage_encoder=self.passage_encoder,
115
- document_index=self.document_index,
116
- reader=self.reader_encoder,
117
- retriever_device=self.retriver_device,
118
- document_index_device=self.index_device,
119
- reader_device=self.reader_device,
120
- retriever_precision=self.precision,
121
- document_index_precision=self.index_precision,
122
- reader_precision=self.precision,
123
- )
124
-
125
- # @serve.batch()
126
- async def handle_batch(self, documents: List[str]) -> List:
127
- return self.relik(
128
- documents,
129
- top_k=self.top_k,
130
- window_size=self.window_size,
131
- window_stride=self.window_stride,
132
- batch_size=self.window_batch_size,
133
- )
134
-
135
- @app.post("/api/entities")
136
- async def entities_endpoint(
137
- self,
138
- documents: Union[str, List[str]],
139
- ):
140
- try:
141
- # normalize input
142
- if isinstance(documents, str):
143
- documents = [documents]
144
- # get predictions for the documents
149
- return await self.handle_batch(documents)
150
- except Exception as e:
151
- # log the entire stack trace
152
- logger.exception(e)
153
- raise HTTPException(status_code=500, detail=f"Server Error: {e}")
154
-
155
- @app.post("/api/gerbil")
156
- async def gerbil_endpoint(self, documents: Union[str, List[str]]):
157
- try:
158
- # normalize input
159
- if isinstance(documents, str):
160
- documents = [documents]
161
-
162
- # output list
163
- windows_passages = []
164
- # split documents into windows
165
- document_windows = [
166
- window
167
- for doc_id, document in enumerate(documents)
168
- for window in self.window_manager(
169
- self.tokenizer,
170
- document,
171
- window_size=self.window_size,
172
- stride=self.window_stride,
173
- doc_id=doc_id,
174
- )
175
- ]
176
-
177
- # get text and topic from document windows and create new list
178
- model_inputs = [
179
- (window.text, window.doc_topic) for window in document_windows
180
- ]
181
-
182
- # batch generator
183
- for batch in batch_generator(
184
- model_inputs, batch_size=self.window_batch_size
185
- ):
186
- text, text_pair = zip(*batch)
187
- batch_predictions = await self.handle_batch_retriever(text, text_pair)
188
- windows_passages.extend(
189
- [
190
- [p.label for p in predictions]
191
- for predictions in batch_predictions
192
- ]
193
- )
194
-
195
- # add passage to document windows
196
- for window, passages in zip(document_windows, windows_passages):
197
- # clean up passages (remove everything after first <def> tag if present)
198
- passages = [c.split(" <def>", 1)[0] for c in passages]
199
- window.window_candidates = passages
200
-
201
- # return document windows
202
- return document_windows
203
-
204
- except Exception as e:
205
- # log the entire stack trace
206
- logger.exception(e)
207
- raise HTTPException(status_code=500, detail=f"Server Error: {e}")
208
-
209
-
210
- server = RelikServer.bind(**vars(SERVER_MANAGER))
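
Once the deployment is running (for example through Ray Serve with the bound `server` object above), the entities endpoint accepts a plain JSON string as its body. A hedged client sketch; the host and port are assumptions (8000 matches the default used by the Streamlit frontend below), not values taken from this file:

    import requests

    response = requests.post(
        "http://localhost:8000/api/entities",  # assumed Ray Serve address
        json="Obama went to Rome for a quick vacation.",
    )
    response.raise_for_status()
    print(response.json())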
relik/inference/serve/backend/retriever.py DELETED
@@ -1,206 +0,0 @@
1
- import logging
2
- from pathlib import Path
3
- from typing import List, Optional, Union
4
-
5
- from relik.common.utils import is_package_available
6
-
7
- if not is_package_available("fastapi"):
8
- raise ImportError(
9
- "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
10
- )
11
- from fastapi import FastAPI, HTTPException
12
-
13
- if not is_package_available("ray"):
14
- raise ImportError(
15
- "Ray is not installed. Please install Ray with `pip install relik[serve]`."
16
- )
17
- from ray import serve
18
-
19
- from relik.common.log import get_logger
20
- from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
21
- from relik.inference.data.window.manager import WindowManager
22
- from relik.inference.serve.backend.utils import (
23
- RayParameterManager,
24
- ServerParameterManager,
25
- )
26
- from relik.retriever.data.utils import batch_generator
27
- from relik.retriever.pytorch_modules import GoldenRetriever
28
-
29
- logger = get_logger(__name__, level=logging.INFO)
30
-
31
- VERSION = {} # type: ignore
32
- with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file:
33
- exec(version_file.read(), VERSION)
34
-
35
- # Env variables for server
36
- SERVER_MANAGER = ServerParameterManager()
37
- RAY_MANAGER = RayParameterManager()
38
-
39
- app = FastAPI(
40
- title="Golden Retriever",
41
- version=VERSION["VERSION"],
42
- description="Golden Retriever REST API",
43
- )
44
-
45
-
46
- @serve.deployment(
47
- ray_actor_options={
48
- "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
49
- },
50
- autoscaling_config={
51
- "min_replicas": RAY_MANAGER.min_replicas,
52
- "max_replicas": RAY_MANAGER.max_replicas,
53
- },
54
- )
55
- @serve.ingress(app)
56
- class GoldenRetrieverServer:
57
- def __init__(
58
- self,
59
- question_encoder: str,
60
- document_index: str,
61
- passage_encoder: Optional[str] = None,
62
- top_k: int = 100,
63
- device: str = "cpu",
64
- index_device: Optional[str] = None,
65
- precision: int = 32,
66
- index_precision: Optional[int] = None,
67
- use_faiss: bool = False,
68
- window_batch_size: int = 32,
69
- window_size: int = 32,
70
- window_stride: int = 16,
71
- split_on_spaces: bool = False,
72
- ):
73
- # parameters
74
- self.question_encoder = question_encoder
75
- self.passage_encoder = passage_encoder
76
- self.document_index = document_index
77
- self.top_k = top_k
78
- self.device = device
79
- self.index_device = index_device or device
80
- self.precision = precision
81
- self.index_precision = index_precision or precision
82
- self.use_faiss = use_faiss
83
- self.window_batch_size = window_batch_size
84
- self.window_size = window_size
85
- self.window_stride = window_stride
86
- self.split_on_spaces = split_on_spaces
87
-
88
- # log stuff for debugging
89
- logger.info("Initializing GoldenRetrieverServer with parameters:")
90
- logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
91
- logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
92
- logger.info(f"DOCUMENT_INDEX: {self.document_index}")
93
- logger.info(f"TOP_K: {self.top_k}")
94
- logger.info(f"DEVICE: {self.device}")
95
- logger.info(f"INDEX_DEVICE: {self.index_device}")
96
- logger.info(f"PRECISION: {self.precision}")
97
- logger.info(f"INDEX_PRECISION: {self.index_precision}")
98
- logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
99
- logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
100
-
101
- self.retriever = GoldenRetriever(
102
- question_encoder=self.question_encoder,
103
- passage_encoder=self.passage_encoder,
104
- document_index=self.document_index,
105
- device=self.device,
106
- index_device=self.index_device,
107
- index_precision=self.index_precision,
108
- )
109
- self.retriever.eval()
110
-
111
- if self.split_on_spaces:
112
- logger.info("Using WhitespaceTokenizer")
113
- self.tokenizer = WhitespaceTokenizer()
114
- # logger.info("Using RegexTokenizer")
115
- # self.tokenizer = RegexTokenizer()
116
- else:
117
- logger.info("Using SpacyTokenizer")
118
- self.tokenizer = SpacyTokenizer(language="en")
119
-
120
- self.window_manager = WindowManager(tokenizer=self.tokenizer)
121
-
122
- # @serve.batch()
123
- async def handle_batch(
124
- self, documents: List[str], document_topics: List[str]
125
- ) -> List:
126
- return self.retriever.retrieve(
127
- documents, text_pair=document_topics, k=self.top_k, precision=self.precision
128
- )
129
-
130
- @app.post("/api/retrieve")
131
- async def retrieve_endpoint(
132
- self,
133
- documents: Union[str, List[str]],
134
- document_topics: Optional[Union[str, List[str]]] = None,
135
- ):
136
- try:
137
- # normalize input
138
- if isinstance(documents, str):
139
- documents = [documents]
140
- if document_topics is not None:
141
- if isinstance(document_topics, str):
142
- document_topics = [document_topics]
143
- assert len(documents) == len(document_topics)
144
- # get predictions
145
- return await self.handle_batch(documents, document_topics)
146
- except Exception as e:
147
- # log the entire stack trace
148
- logger.exception(e)
149
- raise HTTPException(status_code=500, detail=f"Server Error: {e}")
150
-
151
- @app.post("/api/gerbil")
152
- async def gerbil_endpoint(self, documents: Union[str, List[str]]):
153
- try:
154
- # normalize input
155
- if isinstance(documents, str):
156
- documents = [documents]
157
-
158
- # output list
159
- windows_passages = []
160
- # split documents into windows
161
- document_windows = [
162
- window
163
- for doc_id, document in enumerate(documents)
164
- for window in self.window_manager(
165
- self.tokenizer,
166
- document,
167
- window_size=self.window_size,
168
- stride=self.window_stride,
169
- doc_id=doc_id,
170
- )
171
- ]
172
-
173
- # get text and topic from document windows and create new list
174
- model_inputs = [
175
- (window.text, window.doc_topic) for window in document_windows
176
- ]
177
-
178
- # batch generator
179
- for batch in batch_generator(
180
- model_inputs, batch_size=self.window_batch_size
181
- ):
182
- text, text_pair = zip(*batch)
183
- batch_predictions = await self.handle_batch(text, text_pair)
184
- windows_passages.extend(
185
- [
186
- [p.label for p in predictions]
187
- for predictions in batch_predictions
188
- ]
189
- )
190
-
191
- # add passage to document windows
192
- for window, passages in zip(document_windows, windows_passages):
193
- # clean up passages (remove everything after first <def> tag if present)
194
- passages = [c.split(" <def>", 1)[0] for c in passages]
195
- window.window_candidates = passages
196
-
197
- # return document windows
198
- return document_windows
199
-
200
- except Exception as e:
201
- # log the entire stack trace
202
- logger.exception(e)
203
- raise HTTPException(status_code=500, detail=f"Server Error: {e}")
204
-
205
-
206
- server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
relik/inference/serve/backend/utils.py DELETED
@@ -1,29 +0,0 @@
1
- import os
2
- from dataclasses import dataclass
3
- from typing import Union
4
-
5
-
6
- @dataclass
7
- class ServerParameterManager:
8
- retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu")
9
- reader_device: str = os.environ.get("READER_DEVICE", "cpu")
10
- index_device: str = os.environ.get("INDEX_DEVICE", retriver_device)
11
- precision: Union[str, int] = os.environ.get("PRECISION", "fp32")
12
- index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision)
13
- question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
14
- passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
15
- document_index: str = os.environ.get("DOCUMENT_INDEX", None)
16
- reader_encoder: str = os.environ.get("READER_ENCODER", None)
17
- top_k: int = int(os.environ.get("TOP_K", 100))
18
- use_faiss: bool = os.environ.get("USE_FAISS", False)
19
- window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32))
20
- window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
21
- window_stride: int = int(os.environ.get("WINDOW_STRIDE", 16))
22
- split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
23
-
24
-
25
- class RayParameterManager:
26
- def __init__(self) -> None:
27
- self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
28
- self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
29
- self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
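
Both servers are configured purely through environment variables read by ServerParameterManager, so the variables have to be set before the backend module is imported. A minimal sketch; the model identifiers are the ones hardcoded in gerbil.py above, used here only as an example:

    import os

    # must be set before importing relik.inference.serve.backend.relik / .retriever
    os.environ["QUESTION_ENCODER"] = "riccorl/relik-retriever-aida-blink-pretrain-omniencoder"
    os.environ["DOCUMENT_INDEX"] = "riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder"
    os.environ["RETRIEVER_DEVICE"] = "cuda"
    os.environ["TOP_K"] = "100"

    from relik.inference.serve.backend.utils import ServerParameterManager

    print(vars(ServerParameterManager()))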
relik/inference/serve/frontend/__init__.py DELETED
File without changes
relik/inference/serve/frontend/relik.py DELETED
@@ -1,231 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from pathlib import Path
5
-
6
- import requests
7
- import streamlit as st
8
- from spacy import displacy
9
- from streamlit_extras.badges import badge
10
- from streamlit_extras.stylable_container import stylable_container
11
-
12
- RELIK = os.getenv("RELIK", "http://localhost:8000/api/entities")
13
-
14
- import random
15
-
16
-
17
- def get_random_color(ents):
18
- colors = {}
19
- random_colors = generate_pastel_colors(len(ents))
20
- for ent in ents:
21
- colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
22
- return colors
23
-
24
-
25
- def floatrange(start, stop, steps):
26
- if int(steps) == 1:
27
- return [stop]
28
- return [
29
- start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
30
- ]
31
-
32
-
33
- def hsl_to_rgb(h, s, l):
34
- def hue_2_rgb(v1, v2, v_h):
35
- while v_h < 0.0:
36
- v_h += 1.0
37
- while v_h > 1.0:
38
- v_h -= 1.0
39
- if 6 * v_h < 1.0:
40
- return v1 + (v2 - v1) * 6.0 * v_h
41
- if 2 * v_h < 1.0:
42
- return v2
43
- if 3 * v_h < 2.0:
44
- return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
45
- return v1
46
-
47
- # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
48
- # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
49
-
50
- r, b, g = (l * 255,) * 3
51
- if s != 0.0:
52
- if l < 0.5:
53
- var_2 = l * (1.0 + s)
54
- else:
55
- var_2 = (l + s) - (s * l)
56
- var_1 = 2.0 * l - var_2
57
- r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
58
- g = 255 * hue_2_rgb(var_1, var_2, h)
59
- b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
60
-
61
- return int(round(r)), int(round(g)), int(round(b))
62
-
63
-
64
- def generate_pastel_colors(n):
65
- """Return different pastel colours.
66
-
67
- Input:
68
- n (integer) : The number of colors to return
69
-
70
- Output:
71
- A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
72
-
73
- Example:
74
- >>> print(generate_pastel_colors(5))
75
- ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
76
- """
77
- if n == 0:
78
- return []
79
-
80
- # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
81
- start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue
82
- saturation = 1.0
83
- lightness = 0.8
84
- # We take points around the chromatic circle (hue):
85
- # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
86
- # it equals the first one (hue 0 = hue 1))
87
- return [
88
- "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
89
- for hue in floatrange(start_hue, start_hue + 1, n + 1)
90
- ][:-1]
91
-
92
-
93
- def set_sidebar(css):
94
- white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
95
- with st.sidebar:
96
- st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
97
- st.image(
98
- "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
99
- use_column_width=True,
100
- )
101
- st.markdown("## ReLiK")
102
- st.write(
103
- f"""
104
- - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i>&nbsp; Paper")}
105
- - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
106
- - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i>&nbsp; Docker Hub")}
107
- """,
108
- unsafe_allow_html=True,
109
- )
110
- st.markdown("## Sapienza NLP")
111
- st.write(
112
- f"""
113
- - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i>&nbsp; Webpage")}
114
- - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
115
- - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i>&nbsp; Twitter")}
116
- - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i>&nbsp; LinkedIn")}
117
- """,
118
- unsafe_allow_html=True,
119
- )
120
-
121
-
122
- def get_el_annotations(response):
123
- # swap labels key with ents
124
- response["ents"] = response.pop("labels")
125
- label_in_text = set(l["label"] for l in response["ents"])
126
- options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
127
- return response, options
128
-
129
-
130
- def set_intro(css):
131
- # intro
132
- st.markdown("# ReLik")
133
- st.markdown(
134
- "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
135
- )
136
- # st.markdown(
137
- # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
138
- # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
139
- # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
140
- # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
141
- # )
142
- badge(type="github", name="sapienzanlp/relik")
143
- badge(type="pypi", name="relik")
144
-
145
-
146
- def run_client():
147
- with open(Path(__file__).parent / "style.css") as f:
148
- css = f.read()
149
-
150
- st.set_page_config(
151
- page_title="ReLik",
152
- page_icon="🦮",
153
- layout="wide",
154
- )
155
- set_sidebar(css)
156
- set_intro(css)
157
-
158
- # text input
159
- text = st.text_area(
160
- "Enter Text Below:",
161
- value="Obama went to Rome for a quick vacation.",
162
- height=200,
163
- max_chars=500,
164
- )
165
-
166
- with stylable_container(
167
- key="annotate_button",
168
- css_styles="""
169
- button {
170
- background-color: #802433;
171
- color: white;
172
- border-radius: 25px;
173
- }
174
- """,
175
- ):
176
- submit = st.button("Annotate")
177
- # submit = st.button("Run")
178
-
179
- # ReLik API call
180
- if submit:
181
- text = text.strip()
182
- if text:
183
- st.markdown("####")
184
- st.markdown("#### Entity Linking")
185
- with st.spinner(text="In progress"):
186
- response = requests.post(RELIK, json=text)
187
- if response.status_code != 200:
188
- st.error("Error: {}".format(response.status_code))
189
- else:
190
- response = response.json()
191
-
192
- # Entity Linking
193
- # with stylable_container(
194
- # key="container_with_border",
195
- # css_styles="""
196
- # {
197
- # border: 1px solid rgba(49, 51, 63, 0.2);
198
- # border-radius: 0.5rem;
199
- # padding: 0.5rem;
200
- # padding-bottom: 2rem;
201
- # }
202
- # """,
203
- # ):
204
- # st.markdown("##")
205
- dict_of_ents, options = get_el_annotations(response=response)
206
- display = displacy.render(
207
- dict_of_ents, manual=True, style="ent", options=options
208
- )
209
- display = display.replace("\n", " ")
210
- # wsd_display = re.sub(
211
- # r"(wiki::\d+\w)",
212
- # r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
213
- # language.upper()
214
- # ),
215
- # wsd_display,
216
- # )
217
- with st.container():
218
- st.write(display, unsafe_allow_html=True)
219
-
220
- st.markdown("####")
221
- st.markdown("#### Relation Extraction")
222
-
223
- with st.container():
224
- st.write("Coming :)", unsafe_allow_html=True)
225
-
226
- else:
227
- st.error("Please enter some text.")
228
-
229
-
230
- if __name__ == "__main__":
231
- run_client()
relik/inference/serve/frontend/style.css DELETED
@@ -1,33 +0,0 @@
1
- /* Sidebar */
2
- .eczjsme11 {
3
- background-color: #802433;
4
- }
5
-
6
- .st-emotion-cache-10oheav h2 {
7
- color: white;
8
- }
9
-
10
- .st-emotion-cache-10oheav li {
11
- color: white;
12
- }
13
-
14
- /* Main */
15
- a:link {
16
- text-decoration: none;
17
- color: white;
18
- }
19
-
20
- a:visited {
21
- text-decoration: none;
22
- color: white;
23
- }
24
-
25
- a:hover {
26
- text-decoration: none;
27
- color: rgba(255, 255, 255, 0.871);
28
- }
29
-
30
- a:active {
31
- text-decoration: none;
32
- color: white;
33
- }
relik/reader/__init__.py DELETED
File without changes
relik/reader/conf/config.yaml DELETED
@@ -1,14 +0,0 @@
1
- # Required to make the "experiments" dir the default one for the output of the models
2
- hydra:
3
- run:
4
- dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5
-
6
- model_name: relik-reader-deberta-base # used to name the model in wandb and output dir
7
- project_name: relik-reader # used to name the project in wandb
8
-
9
-
10
- defaults:
11
- - _self_
12
- - training: base
13
- - model: base
14
- - data: base
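As a hedged sketch of how a Hydra config like this is usually consumed (the entry-point function name is hypothetical; the config path and the defaults list come from the file above):

    import hydra
    from omegaconf import DictConfig, OmegaConf

    @hydra.main(config_path="relik/reader/conf", config_name="config", version_base=None)  # version_base needs Hydra >= 1.2
    def train(cfg: DictConfig) -> None:
        # cfg.training, cfg.model and cfg.data are composed from the defaults list above
        print(OmegaConf.to_yaml(cfg))

    if __name__ == "__main__":
        train()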
relik/reader/conf/data/base.yaml DELETED
@@ -1,21 +0,0 @@
1
- train_dataset_path: "relik/reader/data/train.jsonl"
2
- val_dataset_path: "relik/reader/data/testa.jsonl"
3
-
4
- train_dataset:
5
- _target_: "relik.reader.relik_reader_data.RelikDataset"
6
- transformer_model: "${model.model.transformer_model}"
7
- materialize_samples: False
8
- shuffle_candidates: 0.5
9
- random_drop_gold_candidates: 0.05
10
- noise_param: 0.0
11
- for_inference: False
12
- tokens_per_batch: 4096
13
- special_symbols: null
14
-
15
- val_dataset:
16
- _target_: "relik.reader.relik_reader_data.RelikDataset"
17
- transformer_model: "${model.model.transformer_model}"
18
- materialize_samples: False
19
- shuffle_candidates: False
20
- for_inference: True
21
- special_symbols: null
relik/reader/conf/data/re.yaml DELETED
@@ -1,54 +0,0 @@
1
- train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl"
2
- val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl"
3
- test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl"
4
-
5
- relations_definitions:
6
- /people/person/nationality: "nationality"
7
- /sports/sports_team/location: "sports team location"
8
- /location/country/administrative_divisions: "administrative divisions"
9
- /business/company/major_shareholders: "shareholders"
10
- /people/ethnicity/people: "ethnicity"
11
- /people/ethnicity/geographic_distribution: "geographic distribution"
12
- /business/company_shareholder/major_shareholder_of: "major shareholder"
13
- /location/location/contains: "location"
14
- /business/company/founders: "founders"
15
- /business/person/company: "company"
16
- /business/company/advisors: "advisor"
17
- /people/deceased_person/place_of_death: "place of death"
18
- /business/company/industry: "industry"
19
- /people/person/ethnicity: "ethnic background"
20
- /people/person/place_of_birth: "place of birth"
21
- /location/administrative_division/country: "country of an administration division"
22
- /people/person/place_lived: "place lived"
23
- /sports/sports_team_location/teams: "sports team"
24
- /people/person/children: "child"
25
- /people/person/religion: "religion"
26
- /location/neighborhood/neighborhood_of: "neighborhood"
27
- /location/country/capital: "capital"
28
- /business/company/place_founded: "company founded location"
29
- /people/person/profession: "occupation"
30
-
31
- train_dataset:
32
- _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
33
- transformer_model: "${model.model.transformer_model}"
34
- materialize_samples: False
35
- shuffle_candidates: False
36
- flip_candidates: 1.0
37
- noise_param: 0.0
38
- for_inference: False
39
- tokens_per_batch: 4096
40
- min_length: -1
41
- special_symbols: null
42
- relations_definitions: ${data.relations_definitions}
43
- sorting_fields:
44
- - "predictable_candidates"
45
- val_dataset:
46
- _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
47
- transformer_model: "${model.model.transformer_model}"
48
- materialize_samples: False
49
- shuffle_candidates: False
50
- flip_candidates: False
51
- for_inference: True
52
- min_length: -1
53
- special_symbols: null
54
- relations_definitions: ${data.relations_definitions}
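The `_target_` entries above are meant for hydra.utils.instantiate; a minimal sketch, assuming the composed config is available as `cfg` and that the special symbols are built elsewhere (their construction is not shown in this file):

    from hydra.utils import instantiate

    # special_symbols is an assumption here; in the reader it is derived from the relation definitions
    train_dataset = instantiate(
        cfg.data.train_dataset,
        dataset_path=cfg.data.train_dataset_path,
        special_symbols=special_symbols,
    )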
relik/reader/conf/training/base.yaml DELETED
@@ -1,12 +0,0 @@
1
- seed: 94
2
-
3
- trainer:
4
- _target_: lightning.Trainer
5
- devices:
6
- - 0
7
- precision: "16-mixed"
8
- max_steps: 50000
9
- val_check_interval: 1.0
10
- num_sanity_val_steps: 0
11
- limit_val_batches: 1
12
- gradient_clip_val: 1.0
relik/reader/conf/training/re.yaml DELETED
@@ -1,12 +0,0 @@
1
- seed: 15
2
-
3
- trainer:
4
- _target_: lightning.Trainer
5
- devices:
6
- - 0
7
- precision: "16-mixed"
8
- max_steps: 100000
9
- val_check_interval: 1.0
10
- num_sanity_val_steps: 0
11
- limit_val_batches: 1
12
- gradient_clip_val: 1.0
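Both training configs declare a lightning.Trainer target, so they can be instantiated the same way; a short sketch, assuming `cfg` is the composed Hydra config:

    from hydra.utils import instantiate

    trainer = instantiate(cfg.training.trainer)
    # equivalent to lightning.Trainer(devices=[0], precision="16-mixed", max_steps=100000, ...)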
relik/reader/data/__init__.py DELETED
File without changes
relik/reader/data/patches.py DELETED
@@ -1,51 +0,0 @@
1
- from typing import List
2
-
3
- from relik.reader.data.relik_reader_sample import RelikReaderSample
4
- from relik.reader.utils.special_symbols import NME_SYMBOL
5
-
6
-
7
- def merge_patches_predictions(sample) -> None:
8
- sample._d["predicted_window_labels"] = dict()
9
- predicted_window_labels = sample._d["predicted_window_labels"]
10
-
11
- sample._d["span_title_probabilities"] = dict()
12
- span_title_probabilities = sample._d["span_title_probabilities"]
13
-
14
- span2title = dict()
15
- for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
16
- # selecting span predictions
17
- for predicted_title, predicted_spans in patch_info[
18
- "predicted_window_labels"
19
- ].items():
20
- for pred_span in predicted_spans:
21
- pred_span = tuple(pred_span)
22
- curr_title = span2title.get(pred_span)
23
- if curr_title is None or curr_title == NME_SYMBOL:
24
- span2title[pred_span] = predicted_title
25
- # else:
26
- # print("Merging at patch level")
27
-
28
- # selecting span predictions probability
29
- for predicted_span, titles_probabilities in patch_info[
30
- "span_title_probabilities"
31
- ].items():
32
- if predicted_span not in span_title_probabilities:
33
- span_title_probabilities[predicted_span] = titles_probabilities
34
-
35
- for span, title in span2title.items():
36
- if title not in predicted_window_labels:
37
- predicted_window_labels[title] = list()
38
- predicted_window_labels[title].append(span)
39
-
40
-
41
- def remove_duplicate_samples(
42
- samples: List[RelikReaderSample],
43
- ) -> List[RelikReaderSample]:
44
- seen_sample = set()
45
- samples_store = []
46
- for sample in samples:
47
- sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
48
- if sample_id not in seen_sample:
49
- seen_sample.add(sample_id)
50
- samples_store.append(sample)
51
- return samples_store
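A usage sketch for the two helpers above, assuming `samples` is a list of RelikReaderSample objects whose `patches`, `doc_id`, `sent_id` and `offset` fields are populated:

    from relik.reader.data.patches import merge_patches_predictions, remove_duplicate_samples

    samples = remove_duplicate_samples(samples)  # drop windows seen more than once
    for sample in samples:
        merge_patches_predictions(sample)  # collapse per-patch predictions into sample-level ones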
relik/reader/data/relik_reader_data.py DELETED
@@ -1,965 +0,0 @@
1
- import logging
2
- from typing import (
3
- Any,
4
- Callable,
5
- Dict,
6
- Generator,
7
- Iterable,
8
- Iterator,
9
- List,
10
- NamedTuple,
11
- Optional,
12
- Tuple,
13
- Union,
14
- )
15
-
16
- import numpy as np
17
- import torch
18
- from torch.utils.data import IterableDataset
19
- from tqdm import tqdm
20
- from transformers import AutoTokenizer, PreTrainedTokenizer
21
-
22
- from relik.reader.data.relik_reader_data_utils import (
23
- add_noise_to_value,
24
- batchify,
25
- chunks,
26
- flatten,
27
- )
28
- from relik.reader.data.relik_reader_sample import (
29
- RelikReaderSample,
30
- load_relik_reader_samples,
31
- )
32
- from relik.reader.utils.special_symbols import NME_SYMBOL
33
-
34
- logger = logging.getLogger(__name__)
35
-
36
-
37
- def preprocess_dataset(
38
- input_dataset: Iterable[dict],
39
- transformer_model: str,
40
- add_topic: bool,
41
- ) -> Iterable[dict]:
42
- tokenizer = AutoTokenizer.from_pretrained(transformer_model)
43
- for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
44
- if len(dataset_elem["tokens"]) == 0:
45
- print(
46
- f"Dataset element with doc id: {dataset_elem['doc_id']}",
47
- f"and offset {dataset_elem['offset']} does not contain any token",
48
- "Skipping it",
49
- )
50
- continue
51
-
52
- new_dataset_elem = dict(
53
- doc_id=dataset_elem["doc_id"],
54
- offset=dataset_elem["offset"],
55
- )
56
-
57
- tokenization_out = tokenizer(
58
- dataset_elem["tokens"],
59
- return_offsets_mapping=True,
60
- add_special_tokens=False,
61
- )
62
-
63
- window_tokens = tokenization_out.input_ids
64
- window_tokens = flatten(window_tokens)
65
-
66
- offsets_mapping = [
67
- [
68
- (
69
- ss + dataset_elem["token2char_start"][str(i)],
70
- se + dataset_elem["token2char_start"][str(i)],
71
- )
72
- for ss, se in tokenization_out.offset_mapping[i]
73
- ]
74
- for i in range(len(dataset_elem["tokens"]))
75
- ]
76
-
77
- offsets_mapping = flatten(offsets_mapping)
78
-
79
- assert len(offsets_mapping) == len(window_tokens)
80
-
81
- window_tokens = (
82
- [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
83
- )
84
-
85
- topic_offset = 0
86
- if add_topic:
87
- topic_tokens = tokenizer(
88
- dataset_elem["doc_topic"], add_special_tokens=False
89
- ).input_ids
90
- topic_offset = len(topic_tokens)
91
- new_dataset_elem["topic_tokens"] = topic_offset
92
- window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
93
-
94
- new_dataset_elem.update(
95
- dict(
96
- tokens=window_tokens,
97
- token2char_start={
98
- str(i): s
99
- for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
100
- },
101
- token2char_end={
102
- str(i): e
103
- for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
104
- },
105
- window_candidates=dataset_elem["window_candidates"],
106
- window_candidates_scores=dataset_elem.get("window_candidates_scores"),
107
- )
108
- )
109
-
110
- if "window_labels" in dataset_elem:
111
- window_labels = [
112
- (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
113
- ]
114
-
115
- new_dataset_elem["window_labels"] = window_labels
116
-
117
- if not all(
118
- [
119
- s in new_dataset_elem["token2char_start"].values()
120
- for s, _, _ in new_dataset_elem["window_labels"]
121
- ]
122
- ):
123
- print(
124
- "Mismatching token start char mapping with labels",
125
- new_dataset_elem["token2char_start"],
126
- new_dataset_elem["window_labels"],
127
- dataset_elem["tokens"],
128
- )
129
- continue
130
-
131
- if not all(
132
- [
133
- e in new_dataset_elem["token2char_end"].values()
134
- for _, e, _ in new_dataset_elem["window_labels"]
135
- ]
136
- ):
137
- print(
138
- "Mismatching token end char mapping with labels",
139
- new_dataset_elem["token2char_end"],
140
- new_dataset_elem["window_labels"],
141
- dataset_elem["tokens"],
142
- )
143
- continue
144
-
145
- yield new_dataset_elem
146
-
147
-
148
- def preprocess_sample(
149
- relik_sample: RelikReaderSample,
150
- tokenizer,
151
- lowercase_policy: float,
152
- add_topic: bool = False,
153
- ) -> None:
154
- if len(relik_sample.tokens) == 0:
155
- return
156
-
157
- if lowercase_policy > 0:
158
- lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
159
- relik_sample.tokens = [
160
- t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
161
- ]
162
-
163
- tokenization_out = tokenizer(
164
- relik_sample.tokens,
165
- return_offsets_mapping=True,
166
- add_special_tokens=False,
167
- )
168
-
169
- window_tokens = tokenization_out.input_ids
170
- window_tokens = flatten(window_tokens)
171
-
172
- offsets_mapping = [
173
- [
174
- (
175
- ss + relik_sample.token2char_start[str(i)],
176
- se + relik_sample.token2char_start[str(i)],
177
- )
178
- for ss, se in tokenization_out.offset_mapping[i]
179
- ]
180
- for i in range(len(relik_sample.tokens))
181
- ]
182
-
183
- offsets_mapping = flatten(offsets_mapping)
184
-
185
- assert len(offsets_mapping) == len(window_tokens)
186
-
187
- window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
188
-
189
- topic_offset = 0
190
- if add_topic:
191
- topic_tokens = tokenizer(
192
- relik_sample.doc_topic, add_special_tokens=False
193
- ).input_ids
194
- topic_offset = len(topic_tokens)
195
- relik_sample.topic_tokens = topic_offset
196
- window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
197
-
198
- relik_sample._d.update(
199
- dict(
200
- tokens=window_tokens,
201
- token2char_start={
202
- str(i): s
203
- for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
204
- },
205
- token2char_end={
206
- str(i): e
207
- for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
208
- },
209
- )
210
- )
211
-
212
- if "window_labels" in relik_sample._d:
213
- relik_sample.window_labels = [
214
- (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
215
- ]
216
-
217
-
218
- class TokenizationOutput(NamedTuple):
219
- input_ids: torch.Tensor
220
- attention_mask: torch.Tensor
221
- token_type_ids: torch.Tensor
222
- prediction_mask: torch.Tensor
223
- special_symbols_mask: torch.Tensor
224
-
225
-
226
- class RelikDataset(IterableDataset):
227
- def __init__(
228
- self,
229
- dataset_path: Optional[str],
230
- materialize_samples: bool,
231
- transformer_model: Union[str, PreTrainedTokenizer],
232
- special_symbols: List[str],
233
- shuffle_candidates: Optional[Union[bool, float]] = False,
234
- for_inference: bool = False,
235
- noise_param: float = 0.1,
236
- sorting_fields: Optional[List[str]] = None,
237
- tokens_per_batch: int = 2048,
238
- batch_size: int = None,
239
- max_batch_size: int = 128,
240
- section_size: int = 50_000,
241
- prebatch: bool = True,
242
- random_drop_gold_candidates: float = 0.0,
243
- use_nme: bool = True,
244
- max_subwords_per_candidate: int = 22,
245
- mask_by_instances: bool = False,
246
- min_length: int = 5,
247
- max_length: int = 2048,
248
- model_max_length: int = 1000,
249
- split_on_cand_overload: bool = True,
250
- skip_empty_training_samples: bool = False,
251
- drop_last: bool = False,
252
- samples: Optional[Iterator[RelikReaderSample]] = None,
253
- lowercase_policy: float = 0.0,
254
- **kwargs,
255
- ):
256
- super().__init__(**kwargs)
257
- self.dataset_path = dataset_path
258
- self.materialize_samples = materialize_samples
259
- self.samples: Optional[List[RelikReaderSample]] = None
260
- if self.materialize_samples:
261
- self.samples = list()
262
-
263
- if isinstance(transformer_model, str):
264
- self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
265
- else:
266
- self.tokenizer = transformer_model
267
- self.special_symbols = special_symbols
268
- self.shuffle_candidates = shuffle_candidates
269
- self.for_inference = for_inference
270
- self.noise_param = noise_param
271
- self.batching_fields = ["input_ids"]
272
- self.sorting_fields = (
273
- sorting_fields if sorting_fields is not None else self.batching_fields
274
- )
275
-
276
- self.tokens_per_batch = tokens_per_batch
277
- self.batch_size = batch_size
278
- self.max_batch_size = max_batch_size
279
- self.section_size = section_size
280
- self.prebatch = prebatch
281
-
282
- self.random_drop_gold_candidates = random_drop_gold_candidates
283
- self.use_nme = use_nme
284
- self.max_subwords_per_candidate = max_subwords_per_candidate
285
- self.mask_by_instances = mask_by_instances
286
- self.min_length = min_length
287
- self.max_length = max_length
288
- self.model_max_length = (
289
- model_max_length
290
- if model_max_length < self.tokenizer.model_max_length
291
- else self.tokenizer.model_max_length
292
- )
293
-
294
- # retrocompatibility workaround
295
- self.transformer_model = (
296
- transformer_model
297
- if isinstance(transformer_model, str)
298
- else transformer_model.name_or_path
299
- )
300
- self.split_on_cand_overload = split_on_cand_overload
301
- self.skip_empty_training_samples = skip_empty_training_samples
302
- self.drop_last = drop_last
303
- self.lowercase_policy = lowercase_policy
304
- self.samples = samples
305
-
306
- def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
307
- return AutoTokenizer.from_pretrained(
308
- transformer_model,
309
- additional_special_tokens=[ss for ss in special_symbols],
310
- add_prefix_space=True,
311
- )
312
-
313
- @property
314
- def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
315
- fields_batchers = {
316
- "input_ids": lambda x: batchify(
317
- x, padding_value=self.tokenizer.pad_token_id
318
- ),
319
- "attention_mask": lambda x: batchify(x, padding_value=0),
320
- "token_type_ids": lambda x: batchify(x, padding_value=0),
321
- "prediction_mask": lambda x: batchify(x, padding_value=1),
322
- "global_attention": lambda x: batchify(x, padding_value=0),
323
- "token2word": None,
324
- "sample": None,
325
- "special_symbols_mask": lambda x: batchify(x, padding_value=False),
326
- "start_labels": lambda x: batchify(x, padding_value=-100),
327
- "end_labels": lambda x: batchify(x, padding_value=-100),
328
- "predictable_candidates_symbols": None,
329
- "predictable_candidates": None,
330
- "patch_offset": None,
331
- "optimus_labels": None,
332
- }
333
-
334
- if "roberta" in self.transformer_model:
335
- del fields_batchers["token_type_ids"]
336
-
337
- return fields_batchers
338
-
339
- def _build_input_ids(
340
- self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
341
- ) -> List[int]:
342
- return (
343
- [self.tokenizer.cls_token_id]
344
- + sentence_input_ids
345
- + [self.tokenizer.sep_token_id]
346
- + flatten(candidates_input_ids)
347
- + [self.tokenizer.sep_token_id]
348
- )
349
-
350
- def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
351
- special_symbols_mask = input_ids >= (
352
- len(self.tokenizer) - len(self.special_symbols)
353
- )
354
- special_symbols_mask[0] = True
355
- return special_symbols_mask
356
-
357
- def _build_tokenizer_essentials(
358
- self, input_ids, original_sequence, sample
359
- ) -> TokenizationOutput:
360
- input_ids = torch.tensor(input_ids, dtype=torch.long)
361
- attention_mask = torch.ones_like(input_ids)
362
-
363
- total_sequence_len = len(input_ids)
364
- predictable_sentence_len = len(original_sequence)
365
-
366
- # token type ids
367
- token_type_ids = torch.cat(
368
- [
369
- input_ids.new_zeros(
370
- predictable_sentence_len + 2
371
- ), # original sentence bpes + CLS and SEP
372
- input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
373
- ]
374
- )
375
-
376
- # prediction mask -> boolean on tokens that are predictable
377
-
378
- prediction_mask = torch.tensor(
379
- [1]
380
- + ([0] * predictable_sentence_len)
381
- + ([1] * (total_sequence_len - predictable_sentence_len - 1))
382
- )
383
-
384
- # add topic tokens to the prediction mask so that they cannot be predicted
385
- # or optimized during training
386
- topic_tokens = getattr(sample, "topic_tokens", None)
387
- if topic_tokens is not None:
388
- prediction_mask[1 : 1 + topic_tokens] = 1
389
-
390
- # If mask by instances is active the prediction mask is applied to everything
391
- # that is not indicated as an instance in the training set.
392
- if self.mask_by_instances:
393
- char_start2token = {
394
- cs: int(tok) for tok, cs in sample.token2char_start.items()
395
- }
396
- char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
397
- instances_mask = torch.ones_like(prediction_mask)
398
- for _, span_info in sample.instance_id2span_data.items():
399
- span_info = span_info[0]
400
- token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS
401
- token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS
402
- instances_mask[token_start : token_end + 1] = 0
403
-
404
- prediction_mask += instances_mask
405
- prediction_mask[prediction_mask > 1] = 1
406
-
407
- assert len(prediction_mask) == len(input_ids)
408
-
409
- # special symbols mask
410
- special_symbols_mask = self._get_special_symbols_mask(input_ids)
411
-
412
- return TokenizationOutput(
413
- input_ids,
414
- attention_mask,
415
- token_type_ids,
416
- prediction_mask,
417
- special_symbols_mask,
418
- )
419
-
420
- def _build_labels(
421
- self,
422
- sample,
423
- tokenization_output: TokenizationOutput,
424
- predictable_candidates: List[str],
425
- ) -> Tuple[torch.Tensor, torch.Tensor]:
426
- start_labels = [0] * len(tokenization_output.input_ids)
427
- end_labels = [0] * len(tokenization_output.input_ids)
428
-
429
- char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
430
- char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
431
- for cs, ce, gold_candidate_title in sample.window_labels:
432
- if gold_candidate_title not in predictable_candidates:
433
- if self.use_nme:
434
- gold_candidate_title = NME_SYMBOL
435
- else:
436
- continue
437
- # +1 is to account for the CLS token
438
- start_bpe = char_start2token[cs] + 1
439
- end_bpe = char_end2token[ce] + 1
440
- class_index = predictable_candidates.index(gold_candidate_title)
441
- if (
442
- start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
443
- ): # prevent from having entities that ends with the same label
444
- start_labels[start_bpe] = class_index + 1 # +1 for the NONE class
445
- end_labels[end_bpe] = class_index + 1 # +1 for the NONE class
446
- else:
447
- print(
448
- "Found entity with the same last subword, it will not be included."
449
- )
450
- print(
451
- cs,
452
- ce,
453
- gold_candidate_title,
454
- start_labels,
455
- end_labels,
456
- sample.doc_id,
457
- )
458
-
459
- ignored_labels_indices = tokenization_output.prediction_mask == 1
460
-
461
- start_labels = torch.tensor(start_labels, dtype=torch.long)
462
- start_labels[ignored_labels_indices] = -100
463
-
464
- end_labels = torch.tensor(end_labels, dtype=torch.long)
465
- end_labels[ignored_labels_indices] = -100
466
-
467
- return start_labels, end_labels
468
-
469
- def produce_sample_bag(
470
- self, sample, predictable_candidates: List[str], candidates_starting_offset: int
471
- ) -> Optional[Tuple[dict, list, int]]:
472
- # input sentence tokenization
473
- input_subwords = sample.tokens[1:-1] # removing special tokens
474
- candidates_symbols = self.special_symbols[candidates_starting_offset:]
475
-
476
- predictable_candidates = list(predictable_candidates)
477
- original_predictable_candidates = list(predictable_candidates)
478
-
479
- # add NME as a possible candidate
480
- if self.use_nme:
481
- predictable_candidates.insert(0, NME_SYMBOL)
482
-
483
- # candidates encoding
484
- candidates_symbols = candidates_symbols[: len(predictable_candidates)]
485
- candidates_encoding_result = self.tokenizer.batch_encode_plus(
486
- [
487
- "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
488
- for cs, ct in zip(candidates_symbols, predictable_candidates)
489
- ],
490
- add_special_tokens=False,
491
- ).input_ids
492
-
493
- if (
494
- self.max_subwords_per_candidate is not None
495
- and self.max_subwords_per_candidate > 0
496
- ):
497
- candidates_encoding_result = [
498
- cer[: self.max_subwords_per_candidate]
499
- for cer in candidates_encoding_result
500
- ]
501
-
502
- # drop candidates if the number of input tokens is too long for the model
503
- if (
504
- sum(map(len, candidates_encoding_result))
505
- + len(input_subwords)
506
- + 20 # + 20 special tokens
507
- > self.model_max_length
508
- ):
509
- acceptable_tokens_from_candidates = (
510
- self.model_max_length - 20 - len(input_subwords)
511
- )
512
- i = 0
513
- cum_len = 0
514
- while (
515
- cum_len + len(candidates_encoding_result[i])
516
- < acceptable_tokens_from_candidates
517
- ):
518
- cum_len += len(candidates_encoding_result[i])
519
- i += 1
520
-
521
- candidates_encoding_result = candidates_encoding_result[:i]
522
- candidates_symbols = candidates_symbols[:i]
523
- predictable_candidates = predictable_candidates[:i]
524
-
525
- # final input_ids build
526
- input_ids = self._build_input_ids(
527
- sentence_input_ids=input_subwords,
528
- candidates_input_ids=candidates_encoding_result,
529
- )
530
-
531
- # complete input building (e.g. attention / prediction mask)
532
- tokenization_output = self._build_tokenizer_essentials(
533
- input_ids, input_subwords, sample
534
- )
535
-
536
- output_dict = {
537
- "input_ids": tokenization_output.input_ids,
538
- "attention_mask": tokenization_output.attention_mask,
539
- "token_type_ids": tokenization_output.token_type_ids,
540
- "prediction_mask": tokenization_output.prediction_mask,
541
- "special_symbols_mask": tokenization_output.special_symbols_mask,
542
- "sample": sample,
543
- "predictable_candidates_symbols": candidates_symbols,
544
- "predictable_candidates": predictable_candidates,
545
- }
546
-
547
- # labels creation
548
- if sample.window_labels is not None:
549
- start_labels, end_labels = self._build_labels(
550
- sample,
551
- tokenization_output,
552
- predictable_candidates,
553
- )
554
- output_dict.update(start_labels=start_labels, end_labels=end_labels)
555
-
556
- if (
557
- "roberta" in self.transformer_model
558
- or "longformer" in self.transformer_model
559
- ):
560
- del output_dict["token_type_ids"]
561
-
562
- predictable_candidates_set = set(predictable_candidates)
563
- remaining_candidates = [
564
- candidate
565
- for candidate in original_predictable_candidates
566
- if candidate not in predictable_candidates_set
567
- ]
568
- total_used_candidates = (
569
- candidates_starting_offset
570
- + len(predictable_candidates)
571
- - (1 if self.use_nme else 0)
572
- )
573
-
574
- if self.use_nme:
575
- assert predictable_candidates[0] == NME_SYMBOL
576
-
577
- return output_dict, remaining_candidates, total_used_candidates
578
-
579
- def __iter__(self):
580
- dataset_iterator = self.dataset_iterator_func()
581
-
582
- current_dataset_elements = []
583
-
584
- i = None
585
- for i, dataset_elem in enumerate(dataset_iterator, start=1):
586
- if (
587
- self.section_size is not None
588
- and len(current_dataset_elements) == self.section_size
589
- ):
590
- for batch in self.materialize_batches(current_dataset_elements):
591
- yield batch
592
- current_dataset_elements = []
593
-
594
- current_dataset_elements.append(dataset_elem)
595
-
596
- if i % 50_000 == 0:
597
- logger.info(f"Processed: {i} number of elements")
598
-
599
- if len(current_dataset_elements) != 0:
600
- for batch in self.materialize_batches(current_dataset_elements):
601
- yield batch
602
-
603
- if i is not None:
604
- logger.info(f"Dataset finished: {i} number of elements processed")
605
- else:
606
- logger.warning("Dataset empty")
607
-
608
- def dataset_iterator_func(self):
609
- skipped_instances = 0
610
- data_samples = (
611
- load_relik_reader_samples(self.dataset_path)
612
- if self.samples is None
613
- else self.samples
614
- )
615
- for sample in data_samples:
616
- preprocess_sample(
617
- sample, self.tokenizer, lowercase_policy=self.lowercase_policy
618
- )
619
- current_patch = 0
620
- sample_bag, used_candidates = None, None
621
- remaining_candidates = list(sample.window_candidates)
622
-
623
- if not self.for_inference:
624
- # randomly drop gold candidates at training time
625
- if (
626
- self.random_drop_gold_candidates > 0.0
627
- and np.random.uniform() < self.random_drop_gold_candidates
628
- and len(set(ct for _, _, ct in sample.window_labels)) > 1
629
- ):
630
- # selecting candidates to drop
631
- np.random.shuffle(sample.window_labels)
632
- n_dropped_candidates = np.random.randint(
633
- 0, len(sample.window_labels) - 1
634
- )
635
- dropped_candidates = [
636
- label_elem[-1]
637
- for label_elem in sample.window_labels[:n_dropped_candidates]
638
- ]
639
- dropped_candidates = set(dropped_candidates)
640
-
641
- # saving NMEs because they should not be dropped
642
- if NME_SYMBOL in dropped_candidates:
643
- dropped_candidates.remove(NME_SYMBOL)
644
-
645
- # sample update
646
- sample.window_labels = [
647
- (s, e, _l)
648
- if _l not in dropped_candidates
649
- else (s, e, NME_SYMBOL)
650
- for s, e, _l in sample.window_labels
651
- ]
652
- remaining_candidates = [
653
- wc
654
- for wc in remaining_candidates
655
- if wc not in dropped_candidates
656
- ]
657
-
658
- # shuffle candidates
659
- if (
660
- isinstance(self.shuffle_candidates, bool)
661
- and self.shuffle_candidates
662
- ) or (
663
- isinstance(self.shuffle_candidates, float)
664
- and np.random.uniform() < self.shuffle_candidates
665
- ):
666
- np.random.shuffle(remaining_candidates)
667
-
668
- while len(remaining_candidates) != 0:
669
- sample_bag = self.produce_sample_bag(
670
- sample,
671
- predictable_candidates=remaining_candidates,
672
- candidates_starting_offset=used_candidates
673
- if used_candidates is not None
674
- else 0,
675
- )
676
- if sample_bag is not None:
677
- sample_bag, remaining_candidates, used_candidates = sample_bag
678
- if (
679
- self.for_inference
680
- or not self.skip_empty_training_samples
681
- or (
682
- (
683
- sample_bag.get("start_labels") is not None
684
- and torch.any(sample_bag["start_labels"] > 1).item()
685
- )
686
- or (
687
- sample_bag.get("optimus_labels") is not None
688
- and len(sample_bag["optimus_labels"]) > 0
689
- )
690
- )
691
- ):
692
- sample_bag["patch_offset"] = current_patch
693
- current_patch += 1
694
- yield sample_bag
695
- else:
696
- skipped_instances += 1
697
- if skipped_instances % 1000 == 0 and skipped_instances != 0:
698
- logger.info(
699
- f"Skipped {skipped_instances} instances since they did not have any gold labels..."
700
- )
701
-
702
- # Just use the first fitting candidates if split on
703
- # cand is not True
704
- if not self.split_on_cand_overload:
705
- break
706
-
707
- def preshuffle_elements(self, dataset_elements: List):
708
- # This shuffling is done so that when using the sorting function,
709
- # if it is deterministic given a collection and its order, we will
710
- # make the whole operation not deterministic anymore.
711
- # Basically, the aim is not to build every time the same batches.
712
- if not self.for_inference:
713
- dataset_elements = np.random.permutation(dataset_elements)
714
-
715
- sorting_fn = (
716
- lambda elem: add_noise_to_value(
717
- sum(len(elem[k]) for k in self.sorting_fields),
718
- noise_param=self.noise_param,
719
- )
720
- if not self.for_inference
721
- else sum(len(elem[k]) for k in self.sorting_fields)
722
- )
723
-
724
- dataset_elements = sorted(dataset_elements, key=sorting_fn)
725
-
726
- if self.for_inference:
727
- return dataset_elements
728
-
729
- ds = list(chunks(dataset_elements, 64))
730
- np.random.shuffle(ds)
731
- return flatten(ds)
732
-
733
- def materialize_batches(
734
- self, dataset_elements: List[Dict[str, Any]]
735
- ) -> Generator[Dict[str, Any], None, None]:
736
- if self.prebatch:
737
- dataset_elements = self.preshuffle_elements(dataset_elements)
738
-
739
- current_batch = []
740
-
741
- # function that creates a batch from the 'current_batch' list
742
- def output_batch() -> Dict[str, Any]:
743
- assert (
744
- len(
745
- set([len(elem["predictable_candidates"]) for elem in current_batch])
746
- )
747
- == 1
748
- ), " ".join(
749
- map(
750
- str, [len(elem["predictable_candidates"]) for elem in current_batch]
751
- )
752
- )
753
-
754
- batch_dict = dict()
755
-
756
- de_values_by_field = {
757
- fn: [de[fn] for de in current_batch if fn in de]
758
- for fn in self.fields_batcher
759
- }
760
-
761
- # in case you provide fields batchers but in the batch
762
- # there are no elements for that field
763
- de_values_by_field = {
764
- fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
765
- }
766
-
767
- assert len(set([len(v) for v in de_values_by_field.values()]))
768
-
769
- # todo: maybe we should report the user about possible
770
- # fields filtering due to "None" instances
771
- de_values_by_field = {
772
- fn: fvs
773
- for fn, fvs in de_values_by_field.items()
774
- if all([fv is not None for fv in fvs])
775
- }
776
-
777
- for field_name, field_values in de_values_by_field.items():
778
- field_batch = (
779
- self.fields_batcher[field_name](field_values)
780
- if self.fields_batcher[field_name] is not None
781
- else field_values
782
- )
783
-
784
- batch_dict[field_name] = field_batch
785
-
786
- return batch_dict
787
-
788
- max_len_discards, min_len_discards = 0, 0
789
-
790
- should_token_batch = self.batch_size is None
791
-
792
- curr_pred_elements = -1
793
- for de in dataset_elements:
794
- if (
795
- should_token_batch
796
- and self.max_batch_size != -1
797
- and len(current_batch) == self.max_batch_size
798
- ) or (not should_token_batch and len(current_batch) == self.batch_size):
799
- yield output_batch()
800
- current_batch = []
801
- curr_pred_elements = -1
802
-
803
- too_long_fields = [
804
- k
805
- for k in de
806
- if self.max_length != -1
807
- and torch.is_tensor(de[k])
808
- and len(de[k]) > self.max_length
809
- ]
810
- if len(too_long_fields) > 0:
811
- max_len_discards += 1
812
- continue
813
-
814
- too_short_fields = [
815
- k
816
- for k in de
817
- if self.min_length != -1
818
- and torch.is_tensor(de[k])
819
- and len(de[k]) < self.min_length
820
- ]
821
- if len(too_short_fields) > 0:
822
- min_len_discards += 1
823
- continue
824
-
825
- if should_token_batch:
826
- de_len = sum(len(de[k]) for k in self.batching_fields)
827
-
828
- future_max_len = max(
829
- de_len,
830
- max(
831
- [
832
- sum(len(bde[k]) for k in self.batching_fields)
833
- for bde in current_batch
834
- ],
835
- default=0,
836
- ),
837
- )
838
-
839
- future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
840
-
841
- num_predictable_candidates = len(de["predictable_candidates"])
842
-
843
- if len(current_batch) > 0 and (
844
- future_tokens_per_batch >= self.tokens_per_batch
845
- or (
846
- num_predictable_candidates != curr_pred_elements
847
- and curr_pred_elements != -1
848
- )
849
- ):
850
- yield output_batch()
851
- current_batch = []
852
-
853
- current_batch.append(de)
854
- curr_pred_elements = len(de["predictable_candidates"])
855
-
856
- if len(current_batch) != 0 and not self.drop_last:
857
- yield output_batch()
858
-
859
- if max_len_discards > 0:
860
- if self.for_inference:
861
- logger.warning(
862
- f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
863
- f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
864
- f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
865
- f"sample length exceeds the maximum length supported by the current model."
866
- )
867
- else:
868
- logger.warning(
869
- f"During iteration, {max_len_discards} elements were "
870
- f"discarded since longer than max length {self.max_length}"
871
- )
872
-
873
- if min_len_discards > 0:
874
- if self.for_inference:
875
- logger.warning(
876
- f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
877
- f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
878
- f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
879
- f"sample length is shorter than the minimum length supported by the current model."
880
- )
881
- else:
882
- logger.warning(
883
- f"During iteration, {min_len_discards} elements were "
884
- f"discarded since shorter than min length {self.min_length}"
885
- )
886
-
887
- @staticmethod
888
- def convert_tokens_to_char_annotations(
889
- sample: RelikReaderSample,
890
- remove_nmes: bool = True,
891
- ) -> RelikReaderSample:
892
- """
893
- Converts the token annotations to char annotations.
894
-
895
- Args:
896
- sample (:obj:`RelikReaderSample`):
897
- The sample to convert.
898
- remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
899
- Whether to remove the NMEs from the annotations.
900
- Returns:
901
- :obj:`RelikReaderSample`: The converted sample.
902
- """
903
- char_annotations = set()
904
- for (
905
- predicted_entity,
906
- predicted_spans,
907
- ) in sample.predicted_window_labels.items():
908
- if predicted_entity == NME_SYMBOL and remove_nmes:
909
- continue
910
-
911
- for span_start, span_end in predicted_spans:
912
- span_start = sample.token2char_start[str(span_start)]
913
- span_end = sample.token2char_end[str(span_end)]
914
-
915
- char_annotations.add((span_start, span_end, predicted_entity))
916
-
917
- char_probs_annotations = dict()
918
- for (
919
- span_start,
920
- span_end,
921
- ), candidates_probs in sample.span_title_probabilities.items():
922
- span_start = sample.token2char_start[str(span_start)]
923
- span_end = sample.token2char_end[str(span_end)]
924
- char_probs_annotations[(span_start, span_end)] = {
925
- title for title, _ in candidates_probs
926
- }
927
-
928
- sample.predicted_window_labels_chars = char_annotations
929
- sample.probs_window_labels_chars = char_probs_annotations
930
-
931
- return sample
932
-
933
- @staticmethod
934
- def merge_patches_predictions(sample) -> None:
935
- sample._d["predicted_window_labels"] = dict()
936
- predicted_window_labels = sample._d["predicted_window_labels"]
937
-
938
- sample._d["span_title_probabilities"] = dict()
939
- span_title_probabilities = sample._d["span_title_probabilities"]
940
-
941
- span2title = dict()
942
- for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
943
- # selecting span predictions
944
- for predicted_title, predicted_spans in patch_info[
945
- "predicted_window_labels"
946
- ].items():
947
- for pred_span in predicted_spans:
948
- pred_span = tuple(pred_span)
949
- curr_title = span2title.get(pred_span)
950
- if curr_title is None or curr_title == NME_SYMBOL:
951
- span2title[pred_span] = predicted_title
952
- # else:
953
- # print("Merging at patch level")
954
-
955
- # selecting span predictions probability
956
- for predicted_span, titles_probabilities in patch_info[
957
- "span_title_probabilities"
958
- ].items():
959
- if predicted_span not in span_title_probabilities:
960
- span_title_probabilities[predicted_span] = titles_probabilities
961
-
962
- for span, title in span2title.items():
963
- if title not in predicted_window_labels:
964
- predicted_window_labels[title] = list()
965
- predicted_window_labels[title].append(span)
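A hedged construction sketch for the dataset above. The special-symbol strings are hypothetical, and the import path follows the file location shown here (the YAML configs reference it as relik.reader.relik_reader_data instead):

    from relik.reader.data.relik_reader_data import RelikDataset

    dataset = RelikDataset(
        dataset_path="relik/reader/data/testa.jsonl",
        materialize_samples=False,
        transformer_model="microsoft/deberta-v3-base",
        special_symbols=[f"[E-{i}]" for i in range(101)],  # hypothetical special-symbol names
        for_inference=True,
    )
    for batch in dataset:  # IterableDataset: each item is an already-collated dict of tensors
        input_ids = batch["input_ids"]
        break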
relik/reader/data/relik_reader_data_utils.py DELETED
@@ -1,51 +0,0 @@
1
- from typing import List
2
-
3
- import numpy as np
4
- import torch
5
-
6
-
7
- def flatten(lsts: List[list]) -> list:
8
- acc_lst = list()
9
- for lst in lsts:
10
- acc_lst.extend(lst)
11
- return acc_lst
12
-
13
-
14
- def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
15
- return torch.nn.utils.rnn.pad_sequence(
16
- tensors, batch_first=True, padding_value=padding_value
17
- )
18
-
19
-
20
- def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
21
- x = max([t.shape[0] for t in tensors])
22
- y = max([t.shape[1] for t in tensors])
23
- out_matrix = torch.zeros((len(tensors), x, y))
24
- out_matrix += padding_value
25
- for i, tensor in enumerate(tensors):
26
- out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
27
- return out_matrix
28
-
29
-
30
- def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
31
- x = max([t.shape[0] for t in tensors])
32
- y = max([t.shape[1] for t in tensors])
33
- rest = tensors[0].shape[2]
34
- out_matrix = torch.zeros((len(tensors), x, y, rest))
35
- out_matrix += padding_value
36
- for i, tensor in enumerate(tensors):
37
- out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
38
- return out_matrix
39
-
40
-
41
- def chunks(lst: list, chunk_size: int) -> List[list]:
42
- chunks_acc = list()
43
- for i in range(0, len(lst), chunk_size):
44
- chunks_acc.append(lst[i : i + chunk_size])
45
- return chunks_acc
46
-
47
-
48
- def add_noise_to_value(value: int, noise_param: float):
49
- noise_value = value * noise_param
50
- noise = np.random.uniform(-noise_value, noise_value)
51
- return max(1, value + noise)
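For example, batchify pads a list of 1-D tensors to a common length (a small check of the behaviour defined above):

    import torch
    from relik.reader.data.relik_reader_data_utils import batchify

    batch = batchify([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=0)
    # batch.shape == (2, 3); the second row becomes [4, 0, 0]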
relik/reader/data/relik_reader_sample.py DELETED
@@ -1,49 +0,0 @@
1
- import json
2
- from typing import Iterable
3
-
4
-
5
- class RelikReaderSample:
6
- def __init__(self, **kwargs):
7
- super().__setattr__("_d", {})
8
- self._d = kwargs
9
-
10
- def __getattribute__(self, item):
11
- return super(RelikReaderSample, self).__getattribute__(item)
12
-
13
- def __getattr__(self, item):
14
- if item.startswith("__") and item.endswith("__"):
15
- # this is likely some python library-specific variable (such as __deepcopy__ for copy)
16
- # better follow standard behavior here
17
- raise AttributeError(item)
18
- elif item in self._d:
19
- return self._d[item]
20
- else:
21
- return None
22
-
23
- def __setattr__(self, key, value):
24
- if key in self._d:
25
- self._d[key] = value
26
- else:
27
- super().__setattr__(key, value)
28
-
29
- def to_jsons(self) -> str:
30
- if "predicted_window_labels" in self._d:
31
- new_obj = {
32
- k: v
33
- for k, v in self._d.items()
34
- if k != "predicted_window_labels" and k != "span_title_probabilities"
35
- }
36
- new_obj["predicted_window_labels"] = [
37
- [ss, se, pred_title]
38
- for (ss, se), pred_title in self.predicted_window_labels_chars
39
- ]
- return json.dumps(new_obj)
40
- else:
41
- return json.dumps(self._d)
42
-
43
-
44
- def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
45
- with open(path) as f:
46
- for line in f:
47
- jsonl_line = json.loads(line.strip())
48
- relik_reader_sample = RelikReaderSample(**jsonl_line)
49
- yield relik_reader_sample
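Samples are read lazily from JSONL; a short sketch using the validation path referenced in the data config:

    from relik.reader.data.relik_reader_sample import load_relik_reader_samples

    for sample in load_relik_reader_samples("relik/reader/data/testa.jsonl"):
        print(sample.doc_id, sample.offset)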
relik/reader/lightning_modules/__init__.py DELETED
File without changes
relik/reader/lightning_modules/relik_reader_pl_module.py DELETED
@@ -1,50 +0,0 @@
1
- from typing import Any, Optional
2
-
3
- import lightning
4
- from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
5
-
6
- from relik.reader.relik_reader_core import RelikReaderCoreModel
7
-
8
-
9
- class RelikReaderPLModule(lightning.LightningModule):
10
- def __init__(
11
- self,
12
- cfg: dict,
13
- transformer_model: str,
14
- additional_special_symbols: int,
15
- num_layers: Optional[int] = None,
16
- activation: str = "gelu",
17
- linears_hidden_size: Optional[int] = 512,
18
- use_last_k_layers: int = 1,
19
- training: bool = False,
20
- *args: Any,
21
- **kwargs: Any
22
- ):
23
- super().__init__(*args, **kwargs)
24
- self.save_hyperparameters()
25
- self.relik_reader_core_model = RelikReaderCoreModel(
26
- transformer_model,
27
- additional_special_symbols,
28
- num_layers,
29
- activation,
30
- linears_hidden_size,
31
- use_last_k_layers,
32
- training=training,
33
- )
34
- self.optimizer_factory = None
35
-
36
- def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
37
- relik_output = self.relik_reader_core_model(**batch)
38
- self.log("train-loss", relik_output["loss"])
39
- return relik_output["loss"]
40
-
41
- def validation_step(
42
- self, batch: dict, *args: Any, **kwargs: Any
43
- ) -> Optional[STEP_OUTPUT]:
44
- return
45
-
46
- def set_optimizer_factory(self, optimizer_factory) -> None:
47
- self.optimizer_factory = optimizer_factory
48
-
49
- def configure_optimizers(self) -> OptimizerLRScheduler:
50
- return self.optimizer_factory(self.relik_reader_core_model)
relik/reader/lightning_modules/relik_reader_re_pl_module.py DELETED
@@ -1,54 +0,0 @@
1
- from typing import Any, Optional
2
-
3
- import lightning
4
- from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
5
-
6
- from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
7
-
8
-
9
- class RelikReaderREPLModule(lightning.LightningModule):
10
- def __init__(
11
- self,
12
- cfg: dict,
13
- transformer_model: str,
14
- additional_special_symbols: int,
15
- num_layers: Optional[int] = None,
16
- activation: str = "gelu",
17
- linears_hidden_size: Optional[int] = 512,
18
- use_last_k_layers: int = 1,
19
- training: bool = False,
20
- *args: Any,
21
- **kwargs: Any
22
- ):
23
- super().__init__(*args, **kwargs)
24
- self.save_hyperparameters()
25
-
26
- self.relik_reader_re_model = RelikReaderForTripletExtraction(
27
- transformer_model,
28
- additional_special_symbols,
29
- num_layers,
30
- activation,
31
- linears_hidden_size,
32
- use_last_k_layers,
33
- training=training,
34
- )
35
- self.optimizer_factory = None
36
-
37
- def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
38
- relik_output = self.relik_reader_re_model(**batch)
39
- self.log("train-loss", relik_output["loss"])
40
- self.log("train-start_loss", relik_output["ned_start_loss"])
41
- self.log("train-end_loss", relik_output["ned_end_loss"])
42
- self.log("train-relation_loss", relik_output["re_loss"])
43
- return relik_output["loss"]
44
-
45
- def validation_step(
46
- self, batch: dict, *args: Any, **kwargs: Any
47
- ) -> Optional[STEP_OUTPUT]:
48
- return
49
-
50
- def set_optimizer_factory(self, optimizer_factory) -> None:
51
- self.optimizer_factory = optimizer_factory
52
-
53
- def configure_optimizers(self) -> OptimizerLRScheduler:
54
- return self.optimizer_factory(self.relik_reader_re_model)
relik/reader/pytorch_modules/__init__.py DELETED
File without changes
relik/reader/pytorch_modules/base.py DELETED
@@ -1,248 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
- from typing import Any, Dict, List
5
-
6
- import torch
7
- import transformers as tr
8
- from torch.utils.data import IterableDataset
9
- from transformers import AutoConfig
10
-
11
- from relik.common.log import get_console_logger, get_logger
12
- from relik.common.utils import get_callable_from_string
13
- from relik.reader.pytorch_modules.hf.modeling_relik import (
14
- RelikReaderConfig,
15
- RelikReaderSample,
16
- )
17
-
18
- console_logger = get_console_logger()
19
- logger = get_logger(__name__, level=logging.INFO)
20
-
21
-
22
- class RelikReaderBase(torch.nn.Module):
23
- default_reader_class: str | None = None
24
- default_data_class: str | None = None
25
-
26
- def __init__(
27
- self,
28
- transformer_model: str | tr.PreTrainedModel | None = None,
29
- additional_special_symbols: int = 0,
30
- num_layers: int | None = None,
31
- activation: str = "gelu",
32
- linears_hidden_size: int | None = 512,
33
- use_last_k_layers: int = 1,
34
- training: bool = False,
35
- device: str | torch.device | None = None,
36
- precision: int = 32,
37
- tokenizer: str | tr.PreTrainedTokenizer | None = None,
38
- dataset: IterableDataset | str | None = None,
39
- default_reader_class: tr.PreTrainedModel | str | None = None,
40
- **kwargs,
41
- ) -> None:
42
- super().__init__()
43
-
44
- self.default_reader_class = default_reader_class or self.default_reader_class
45
-
46
- if self.default_reader_class is None:
47
- raise ValueError("You must specify a default reader class.")
48
-
49
- # get the callable for the default reader class
50
- self.default_reader_class: tr.PreTrainedModel = get_callable_from_string(
51
- self.default_reader_class
52
- )
53
-
54
- if isinstance(transformer_model, str):
55
- config = AutoConfig.from_pretrained(
56
- transformer_model, trust_remote_code=True
57
- )
58
- if "relik-reader" in config.model_type:
59
- transformer_model = self.default_reader_class.from_pretrained(
60
- transformer_model, **kwargs
61
- )
62
- else:
63
- reader_config = RelikReaderConfig(
64
- transformer_model=transformer_model,
65
- additional_special_symbols=additional_special_symbols,
66
- num_layers=num_layers,
67
- activation=activation,
68
- linears_hidden_size=linears_hidden_size,
69
- use_last_k_layers=use_last_k_layers,
70
- training=training,
71
- )
72
- transformer_model = self.default_reader_class(reader_config)
73
-
74
- self.relik_reader_model = transformer_model
75
- self.relik_reader_model_config = self.relik_reader_model.config
76
-
77
- # get the tokenizer
78
- self._tokenizer = tokenizer
79
-
80
- # and instantiate the dataset class
81
- self.dataset: IterableDataset | None = dataset
82
-
83
- # move the model to the device
84
- self.to(device or torch.device("cpu"))
85
-
86
- # set the precision
87
- self.precision = precision
88
-
89
- def forward(self, **kwargs) -> Dict[str, Any]:
90
- return self.relik_reader_model(**kwargs)
91
-
92
- def _read(self, *args, **kwargs) -> Any:
93
- raise NotImplementedError
94
-
95
- @torch.no_grad()
96
- @torch.inference_mode()
97
- def read(
98
- self,
99
- text: List[str] | List[List[str]] | None = None,
100
- samples: List[RelikReaderSample] | None = None,
101
- input_ids: torch.Tensor | None = None,
102
- attention_mask: torch.Tensor | None = None,
103
- token_type_ids: torch.Tensor | None = None,
104
- prediction_mask: torch.Tensor | None = None,
105
- special_symbols_mask: torch.Tensor | None = None,
106
- candidates: List[List[str]] | None = None,
107
- max_length: int = 1000,
108
- max_batch_size: int = 128,
109
- token_batch_size: int = 2048,
110
- precision: int | str | None = None,
111
- progress_bar: bool = False,
112
- *args,
113
- **kwargs,
114
- ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
115
- """
116
- Reads the given text.
117
-
118
- Args:
119
- text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`):
120
- The text to read in tokens. If a list of list of tokens is provided, each
121
- inner list is considered a sentence.
122
- samples (:obj:`List[RelikReaderSample]`, `optional`):
123
- The samples to read. If provided, `text` and `candidates` are ignored.
124
- input_ids (:obj:`torch.Tensor`, `optional`):
125
- The input ids of the text.
126
- attention_mask (:obj:`torch.Tensor`, `optional`):
127
- The attention mask of the text.
128
- token_type_ids (:obj:`torch.Tensor`, `optional`):
129
- The token type ids of the text.
130
- prediction_mask (:obj:`torch.Tensor`, `optional`):
131
- The prediction mask of the text.
132
- special_symbols_mask (:obj:`torch.Tensor`, `optional`):
133
- The special symbols mask of the text.
134
- candidates (:obj:`List[List[str]]`, `optional`):
135
- The candidates of the text.
136
- max_length (:obj:`int`, `optional`, defaults to 1000):
137
- The maximum length of the text.
138
- max_batch_size (:obj:`int`, `optional`, defaults to 128):
139
- The maximum batch size.
140
- token_batch_size (:obj:`int`, `optional`):
141
- The maximum number of tokens per batch.
142
- precision (:obj:`int` or :obj:`str`, `optional`):
143
- The precision to use. If not provided, the default is 32 bit.
144
- progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`):
145
- Whether to show a progress bar.
146
-
147
- Returns:
148
- The predicted labels for each sample.
149
- """
150
- if text is None and input_ids is None and samples is None:
151
- raise ValueError(
152
- "Either `text` or `input_ids` or `samples` must be provided."
153
- )
154
- if (input_ids is None and samples is None) and (
155
- text is None or candidates is None
156
- ):
157
- raise ValueError(
158
- "`text` and `candidates` must be provided to return the predictions when "
159
- "`input_ids` and `samples` is not provided."
160
- )
161
- if text is not None and samples is None:
162
- if len(text) != len(candidates):
163
- raise ValueError("`text` and `candidates` must have the same length.")
164
- if isinstance(text[0], str): # change to list of text
165
- text = [text]
166
- candidates = [candidates]
167
-
168
- samples = [
169
- RelikReaderSample(tokens=t, candidates=c)
170
- for t, c in zip(text, candidates)
171
- ]
172
-
173
- return self._read(
174
- samples,
175
- input_ids,
176
- attention_mask,
177
- token_type_ids,
178
- prediction_mask,
179
- special_symbols_mask,
180
- max_length,
181
- max_batch_size,
182
- token_batch_size,
183
- precision or self.precision,
184
- progress_bar,
185
- *args,
186
- **kwargs,
187
- )
188
-
189
- @property
190
- def device(self) -> torch.device:
191
- """
192
- The device of the model.
193
- """
194
- return next(self.parameters()).device
195
-
196
- @property
197
- def tokenizer(self) -> tr.PreTrainedTokenizer:
198
- """
199
- The tokenizer.
200
- """
201
- if self._tokenizer:
202
- return self._tokenizer
203
-
204
- self._tokenizer = tr.AutoTokenizer.from_pretrained(
205
- self.relik_reader_model.config.name_or_path
206
- )
207
- return self._tokenizer
208
-
209
- def save_pretrained(
210
- self,
211
- output_dir: str | os.PathLike,
212
- model_name: str | None = None,
213
- push_to_hub: bool = False,
214
- **kwargs,
215
- ) -> None:
216
- """
217
- Saves the model to the given path.
218
-
219
- Args:
220
- output_dir (`str` or :obj:`os.PathLike`):
221
- The path to save the model to.
222
- model_name (`str`, `optional`):
223
- The name of the model. If not provided, the model will be saved as
224
- `default_reader_class.__name__`.
225
- push_to_hub (`bool`, `optional`, defaults to `False`):
226
- Whether to push the model to the HuggingFace Hub.
227
- **kwargs:
228
- Additional keyword arguments to pass to the `save_pretrained` method
229
- """
230
- # create the output directory
231
- output_dir = Path(output_dir)
232
- output_dir.mkdir(parents=True, exist_ok=True)
233
-
234
- model_name = model_name or self.default_reader_class.__name__
235
-
236
- logger.info(f"Saving reader to {output_dir / model_name}")
237
-
238
- # save the model
239
- self.relik_reader_model.register_for_auto_class()
240
- self.relik_reader_model.save_pretrained(
241
- output_dir / model_name, push_to_hub=push_to_hub, **kwargs
242
- )
243
-
244
- if self.tokenizer:
245
- logger.info("Saving also the tokenizer")
246
- self.tokenizer.save_pretrained(
247
- output_dir / model_name, push_to_hub=push_to_hub, **kwargs
248
- )
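The read() entry point above accepts either pre-tokenized text plus candidates or ready-made samples; a hedged sketch, assuming `reader` is an already-loaded concrete subclass (for example the span reader):

    predictions = reader.read(
        text=[["Obama", "went", "to", "Rome", "."]],
        candidates=[["Barack Obama", "Rome"]],
        progress_bar=False,
    )
    for sample in predictions:
        print(sample.predicted_window_labels_chars)  # assuming the subclass attaches char-level predictions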
relik/reader/pytorch_modules/hf/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .configuration_relik import RelikReaderConfig
2
- from .modeling_relik import RelikReaderREModel
relik/reader/pytorch_modules/hf/configuration_relik.py DELETED
@@ -1,33 +0,0 @@
1
- from typing import Optional
2
-
3
- from transformers import AutoConfig
4
- from transformers.configuration_utils import PretrainedConfig
5
-
6
-
7
- class RelikReaderConfig(PretrainedConfig):
8
- model_type = "relik-reader"
9
-
10
- def __init__(
11
- self,
12
- transformer_model: str = "microsoft/deberta-v3-base",
13
- additional_special_symbols: int = 101,
14
- num_layers: Optional[int] = None,
15
- activation: str = "gelu",
16
- linears_hidden_size: Optional[int] = 512,
17
- use_last_k_layers: int = 1,
18
- training: bool = False,
19
- default_reader_class: Optional[str] = None,
20
- **kwargs
21
- ) -> None:
22
- self.transformer_model = transformer_model
23
- self.additional_special_symbols = additional_special_symbols
24
- self.num_layers = num_layers
25
- self.activation = activation
26
- self.linears_hidden_size = linears_hidden_size
27
- self.use_last_k_layers = use_last_k_layers
28
- self.training = training
29
- self.default_reader_class = default_reader_class
30
- super().__init__(**kwargs)
31
-
32
-
33
- AutoConfig.register("relik-reader", RelikReaderConfig)
 
relik/reader/pytorch_modules/hf/modeling_relik.py DELETED
@@ -1,981 +0,0 @@
1
- from typing import Any, Dict, Optional
2
-
3
- import torch
4
- from transformers import AutoModel, PreTrainedModel
5
- from transformers.activations import ClippedGELUActivation, GELUActivation
6
- from transformers.configuration_utils import PretrainedConfig
7
- from transformers.modeling_utils import PoolerEndLogits
8
-
9
- from .configuration_relik import RelikReaderConfig
10
-
11
-
12
- class RelikReaderSample:
13
- def __init__(self, **kwargs):
14
- super().__setattr__("_d", {})
15
- self._d = kwargs
16
-
17
- def __getattribute__(self, item):
18
- return super(RelikReaderSample, self).__getattribute__(item)
19
-
20
- def __getattr__(self, item):
21
- if item.startswith("__") and item.endswith("__"):
22
- # this is likely some python library-specific variable (such as __deepcopy__ for copy)
23
- # better follow standard behavior here
24
- raise AttributeError(item)
25
- elif item in self._d:
26
- return self._d[item]
27
- else:
28
- return None
29
-
30
- def __setattr__(self, key, value):
31
- if key in self._d:
32
- self._d[key] = value
33
- else:
34
- super().__setattr__(key, value)
35
-
36
-
37
- activation2functions = {
38
- "relu": torch.nn.ReLU(),
39
- "gelu": GELUActivation(),
40
- "gelu_10": ClippedGELUActivation(-10, 10),
41
- }
42
-
43
-
44
- class PoolerEndLogitsBi(PoolerEndLogits):
45
- def __init__(self, config: PretrainedConfig):
46
- super().__init__(config)
47
- self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
48
-
49
- def forward(
50
- self,
51
- hidden_states: torch.FloatTensor,
52
- start_states: Optional[torch.FloatTensor] = None,
53
- start_positions: Optional[torch.LongTensor] = None,
54
- p_mask: Optional[torch.FloatTensor] = None,
55
- ) -> torch.FloatTensor:
56
- if p_mask is not None:
57
- p_mask = p_mask.unsqueeze(-1)
58
- logits = super().forward(
59
- hidden_states,
60
- start_states,
61
- start_positions,
62
- p_mask,
63
- )
64
- return logits
65
-
66
-
67
- class RelikReaderSpanModel(PreTrainedModel):
68
- config_class = RelikReaderConfig
69
-
70
- def __init__(self, config: RelikReaderConfig, *args, **kwargs):
71
- super().__init__(config)
72
- # Transformer model declaration
73
- self.config = config
74
- self.transformer_model = (
75
- AutoModel.from_pretrained(self.config.transformer_model)
76
- if self.config.num_layers is None
77
- else AutoModel.from_pretrained(
78
- self.config.transformer_model, num_hidden_layers=self.config.num_layers
79
- )
80
- )
81
- self.transformer_model.resize_token_embeddings(
82
- self.transformer_model.config.vocab_size
83
- + self.config.additional_special_symbols
84
- )
85
-
86
- self.activation = self.config.activation
87
- self.linears_hidden_size = self.config.linears_hidden_size
88
- self.use_last_k_layers = self.config.use_last_k_layers
89
-
90
- # named entity detection layers
91
- self.ned_start_classifier = self._get_projection_layer(
92
- self.activation, last_hidden=2, layer_norm=False
93
- )
94
- self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
95
-
96
- # END entity disambiguation layer
97
- self.ed_start_projector = self._get_projection_layer(self.activation)
98
- self.ed_end_projector = self._get_projection_layer(self.activation)
99
-
100
- self.training = self.config.training
101
-
102
- # criterion
103
- self.criterion = torch.nn.CrossEntropyLoss()
104
-
105
- def _get_projection_layer(
106
- self,
107
- activation: str,
108
- last_hidden: Optional[int] = None,
109
- input_hidden=None,
110
- layer_norm: bool = True,
111
- ) -> torch.nn.Sequential:
112
- head_components = [
113
- torch.nn.Dropout(0.1),
114
- torch.nn.Linear(
115
- self.transformer_model.config.hidden_size * self.use_last_k_layers
116
- if input_hidden is None
117
- else input_hidden,
118
- self.linears_hidden_size,
119
- ),
120
- activation2functions[activation],
121
- torch.nn.Dropout(0.1),
122
- torch.nn.Linear(
123
- self.linears_hidden_size,
124
- self.linears_hidden_size if last_hidden is None else last_hidden,
125
- ),
126
- ]
127
-
128
- if layer_norm:
129
- head_components.append(
130
- torch.nn.LayerNorm(
131
- self.linears_hidden_size if last_hidden is None else last_hidden,
132
- self.transformer_model.config.layer_norm_eps,
133
- )
134
- )
135
-
136
- return torch.nn.Sequential(*head_components)
137
-
138
- def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
139
- mask = mask.unsqueeze(-1)
140
- if next(self.parameters()).dtype == torch.float16:
141
- logits = logits * (1 - mask) - 65500 * mask
142
- else:
143
- logits = logits * (1 - mask) - 1e30 * mask
144
- return logits
145
-
146
- def _get_model_features(
147
- self,
148
- input_ids: torch.Tensor,
149
- attention_mask: torch.Tensor,
150
- token_type_ids: Optional[torch.Tensor],
151
- ):
152
- model_input = {
153
- "input_ids": input_ids,
154
- "attention_mask": attention_mask,
155
- "output_hidden_states": self.use_last_k_layers > 1,
156
- }
157
-
158
- if token_type_ids is not None:
159
- model_input["token_type_ids"] = token_type_ids
160
-
161
- model_output = self.transformer_model(**model_input)
162
-
163
- if self.use_last_k_layers > 1:
164
- model_features = torch.cat(
165
- model_output[1][-self.use_last_k_layers :], dim=-1
166
- )
167
- else:
168
- model_features = model_output[0]
169
-
170
- return model_features
171
-
172
- def compute_ned_end_logits(
173
- self,
174
- start_predictions,
175
- start_labels,
176
- model_features,
177
- prediction_mask,
178
- batch_size,
179
- ) -> Optional[torch.Tensor]:
180
- # todo: maybe when constraining on the spans,
181
- # we should not use a prediction_mask for the end tokens.
182
- # at least we should not during training imo
183
- start_positions = start_labels if self.training else start_predictions
184
- start_positions_indices = (
185
- torch.arange(start_positions.size(1), device=start_positions.device)
186
- .unsqueeze(0)
187
- .expand(batch_size, -1)[start_positions > 0]
188
- ).to(start_positions.device)
189
-
190
- if len(start_positions_indices) > 0:
191
- expanded_features = torch.cat(
192
- [
193
- model_features[i].unsqueeze(0).expand(x, -1, -1)
194
- for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
195
- if x > 0
196
- ],
197
- dim=0,
198
- ).to(start_positions_indices.device)
199
-
200
- expanded_prediction_mask = torch.cat(
201
- [
202
- prediction_mask[i].unsqueeze(0).expand(x, -1)
203
- for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
204
- if x > 0
205
- ],
206
- dim=0,
207
- ).to(expanded_features.device)
208
-
209
- end_logits = self.ned_end_classifier(
210
- hidden_states=expanded_features,
211
- start_positions=start_positions_indices,
212
- p_mask=expanded_prediction_mask,
213
- )
214
-
215
- return end_logits
216
-
217
- return None
218
-
219
- def compute_classification_logits(
220
- self,
221
- model_features,
222
- special_symbols_mask,
223
- prediction_mask,
224
- batch_size,
225
- start_positions=None,
226
- end_positions=None,
227
- ) -> torch.Tensor:
228
- if start_positions is None or end_positions is None:
229
- start_positions = torch.zeros_like(prediction_mask)
230
- end_positions = torch.zeros_like(prediction_mask)
231
-
232
- model_start_features = self.ed_start_projector(model_features)
233
- model_end_features = self.ed_end_projector(model_features)
234
- model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
235
-
236
- model_ed_features = torch.cat(
237
- [model_start_features, model_end_features], dim=-1
238
- )
239
-
240
- # computing ed features
241
- classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
242
- special_symbols_representation = model_ed_features[special_symbols_mask].view(
243
- batch_size, classes_representations, -1
244
- )
245
-
246
- logits = torch.bmm(
247
- model_ed_features,
248
- torch.permute(special_symbols_representation, (0, 2, 1)),
249
- )
250
-
251
- logits = self._mask_logits(logits, prediction_mask)
252
-
253
- return logits
254
-
255
- def forward(
256
- self,
257
- input_ids: torch.Tensor,
258
- attention_mask: torch.Tensor,
259
- token_type_ids: Optional[torch.Tensor] = None,
260
- prediction_mask: Optional[torch.Tensor] = None,
261
- special_symbols_mask: Optional[torch.Tensor] = None,
262
- start_labels: Optional[torch.Tensor] = None,
263
- end_labels: Optional[torch.Tensor] = None,
264
- use_predefined_spans: bool = False,
265
- *args,
266
- **kwargs,
267
- ) -> Dict[str, Any]:
268
- batch_size, seq_len = input_ids.shape
269
-
270
- model_features = self._get_model_features(
271
- input_ids, attention_mask, token_type_ids
272
- )
273
-
274
- ned_start_labels = None
275
-
276
- # named entity detection if required
277
- if use_predefined_spans: # no need to compute spans
278
- ned_start_logits, ned_start_probabilities, ned_start_predictions = (
279
- None,
280
- None,
281
- torch.clone(start_labels)
282
- if start_labels is not None
283
- else torch.zeros_like(input_ids),
284
- )
285
- ned_end_logits, ned_end_probabilities, ned_end_predictions = (
286
- None,
287
- None,
288
- torch.clone(end_labels)
289
- if end_labels is not None
290
- else torch.zeros_like(input_ids),
291
- )
292
-
293
- ned_start_predictions[ned_start_predictions > 0] = 1
294
- ned_end_predictions[ned_end_predictions > 0] = 1
295
-
296
- else: # compute spans
297
- # start boundary prediction
298
- ned_start_logits = self.ned_start_classifier(model_features)
299
- ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
300
- ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
301
- ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
302
-
303
- # end boundary prediction
304
- ned_start_labels = (
305
- torch.zeros_like(start_labels) if start_labels is not None else None
306
- )
307
-
308
- if ned_start_labels is not None:
309
- ned_start_labels[start_labels == -100] = -100
310
- ned_start_labels[start_labels > 0] = 1
311
-
312
- ned_end_logits = self.compute_ned_end_logits(
313
- ned_start_predictions,
314
- ned_start_labels,
315
- model_features,
316
- prediction_mask,
317
- batch_size,
318
- )
319
-
320
- if ned_end_logits is not None:
321
- ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
322
- ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
323
- else:
324
- ned_end_logits, ned_end_probabilities = None, None
325
- ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
326
-
327
- # flattening end predictions
328
- # (flattening can happen only if the
329
- # end boundaries were not predicted using the gold labels)
330
- if not self.training:
331
- flattened_end_predictions = torch.clone(ned_start_predictions)
332
- flattened_end_predictions[flattened_end_predictions > 0] = 0
333
-
334
- batch_start_predictions = list()
335
- for elem_idx in range(batch_size):
336
- batch_start_predictions.append(
337
- torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
338
- )
339
-
340
- # check that the total number of start predictions
341
- # is equal to the end predictions
342
- total_start_predictions = sum(map(len, batch_start_predictions))
343
- total_end_predictions = len(ned_end_predictions)
344
- assert (
345
- total_start_predictions == 0
346
- or total_start_predictions == total_end_predictions
347
- ), (
348
- f"Total number of start predictions = {total_start_predictions}. "
349
- f"Total number of end predictions = {total_end_predictions}"
350
- )
351
-
352
- curr_end_pred_num = 0
353
- for elem_idx, bsp in enumerate(batch_start_predictions):
354
- for sp in bsp:
355
- ep = ned_end_predictions[curr_end_pred_num].item()
356
- if ep < sp:
357
- ep = sp
358
-
359
- # if we already set this span throw it (no overlap)
360
- if flattened_end_predictions[elem_idx, ep] == 1:
361
- ned_start_predictions[elem_idx, sp] = 0
362
- else:
363
- flattened_end_predictions[elem_idx, ep] = 1
364
-
365
- curr_end_pred_num += 1
366
-
367
- ned_end_predictions = flattened_end_predictions
368
-
369
- start_position, end_position = (
370
- (start_labels, end_labels)
371
- if self.training
372
- else (ned_start_predictions, ned_end_predictions)
373
- )
374
-
375
- # Entity disambiguation
376
- ed_logits = self.compute_classification_logits(
377
- model_features,
378
- special_symbols_mask,
379
- prediction_mask,
380
- batch_size,
381
- start_position,
382
- end_position,
383
- )
384
- ed_probabilities = torch.softmax(ed_logits, dim=-1)
385
- ed_predictions = torch.argmax(ed_probabilities, dim=-1)
386
-
387
- # output build
388
- output_dict = dict(
389
- batch_size=batch_size,
390
- ned_start_logits=ned_start_logits,
391
- ned_start_probabilities=ned_start_probabilities,
392
- ned_start_predictions=ned_start_predictions,
393
- ned_end_logits=ned_end_logits,
394
- ned_end_probabilities=ned_end_probabilities,
395
- ned_end_predictions=ned_end_predictions,
396
- ed_logits=ed_logits,
397
- ed_probabilities=ed_probabilities,
398
- ed_predictions=ed_predictions,
399
- )
400
-
401
- # compute loss if labels
402
- if start_labels is not None and end_labels is not None and self.training:
403
- # named entity detection loss
404
-
405
- # start
406
- if ned_start_logits is not None:
407
- ned_start_loss = self.criterion(
408
- ned_start_logits.view(-1, ned_start_logits.shape[-1]),
409
- ned_start_labels.view(-1),
410
- )
411
- else:
412
- ned_start_loss = 0
413
-
414
- # end
415
- if ned_end_logits is not None:
416
- ned_end_labels = torch.zeros_like(end_labels)
417
- ned_end_labels[end_labels == -100] = -100
418
- ned_end_labels[end_labels > 0] = 1
419
-
420
- ned_end_loss = self.criterion(
421
- ned_end_logits,
422
- (
423
- torch.arange(
424
- ned_end_labels.size(1), device=ned_end_labels.device
425
- )
426
- .unsqueeze(0)
427
- .expand(batch_size, -1)[ned_end_labels > 0]
428
- ).to(ned_end_labels.device),
429
- )
430
-
431
- else:
432
- ned_end_loss = 0
433
-
434
- # entity disambiguation loss
435
- start_labels[ned_start_labels != 1] = -100
436
- ed_labels = torch.clone(start_labels)
437
- ed_labels[end_labels > 0] = end_labels[end_labels > 0]
438
- ed_loss = self.criterion(
439
- ed_logits.view(-1, ed_logits.shape[-1]),
440
- ed_labels.view(-1),
441
- )
442
-
443
- output_dict["ned_start_loss"] = ned_start_loss
444
- output_dict["ned_end_loss"] = ned_end_loss
445
- output_dict["ed_loss"] = ed_loss
446
-
447
- output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
448
-
449
- return output_dict
450
-
451
-
452
- class RelikReaderREModel(PreTrainedModel):
453
- config_class = RelikReaderConfig
454
-
455
- def __init__(self, config, *args, **kwargs):
456
- super().__init__(config)
457
- # Transformer model declaration
458
- # self.transformer_model_name = transformer_model
459
- self.config = config
460
- self.transformer_model = (
461
- AutoModel.from_pretrained(config.transformer_model)
462
- if config.num_layers is None
463
- else AutoModel.from_pretrained(
464
- config.transformer_model, num_hidden_layers=config.num_layers
465
- )
466
- )
467
- self.transformer_model.resize_token_embeddings(
468
- self.transformer_model.config.vocab_size + config.additional_special_symbols
469
- )
470
-
471
- # named entity detection layers
472
- self.ned_start_classifier = self._get_projection_layer(
473
- config.activation, last_hidden=2, layer_norm=False
474
- )
475
-
476
- self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
477
-
478
- self.entity_type_loss = (
479
- config.entity_type_loss if hasattr(config, "entity_type_loss") else False
480
- )
481
- self.relation_disambiguation_loss = (
482
- config.relation_disambiguation_loss
483
- if hasattr(config, "relation_disambiguation_loss")
484
- else False
485
- )
486
-
487
- input_hidden_ents = 2 * self.transformer_model.config.hidden_size
488
-
489
- self.re_subject_projector = self._get_projection_layer(
490
- config.activation, input_hidden=input_hidden_ents
491
- )
492
- self.re_object_projector = self._get_projection_layer(
493
- config.activation, input_hidden=input_hidden_ents
494
- )
495
- self.re_relation_projector = self._get_projection_layer(config.activation)
496
-
497
- if self.entity_type_loss or self.relation_disambiguation_loss:
498
- self.re_entities_projector = self._get_projection_layer(
499
- config.activation,
500
- input_hidden=2 * self.transformer_model.config.hidden_size,
501
- )
502
- self.re_definition_projector = self._get_projection_layer(
503
- config.activation,
504
- )
505
-
506
- self.re_classifier = self._get_projection_layer(
507
- config.activation,
508
- input_hidden=config.linears_hidden_size,
509
- last_hidden=2,
510
- layer_norm=False,
511
- )
512
-
513
- if self.entity_type_loss or self.relation_disambiguation_loss:
514
- self.re_ed_classifier = self._get_projection_layer(
515
- config.activation,
516
- input_hidden=config.linears_hidden_size,
517
- last_hidden=2,
518
- layer_norm=False,
519
- )
520
-
521
- self.training = config.training
522
-
523
- # criterion
524
- self.criterion = torch.nn.CrossEntropyLoss()
525
-
526
- def _get_projection_layer(
527
- self,
528
- activation: str,
529
- last_hidden: Optional[int] = None,
530
- input_hidden=None,
531
- layer_norm: bool = True,
532
- ) -> torch.nn.Sequential:
533
- head_components = [
534
- torch.nn.Dropout(0.1),
535
- torch.nn.Linear(
536
- self.transformer_model.config.hidden_size
537
- * self.config.use_last_k_layers
538
- if input_hidden is None
539
- else input_hidden,
540
- self.config.linears_hidden_size,
541
- ),
542
- activation2functions[activation],
543
- torch.nn.Dropout(0.1),
544
- torch.nn.Linear(
545
- self.config.linears_hidden_size,
546
- self.config.linears_hidden_size if last_hidden is None else last_hidden,
547
- ),
548
- ]
549
-
550
- if layer_norm:
551
- head_components.append(
552
- torch.nn.LayerNorm(
553
- self.config.linears_hidden_size
554
- if last_hidden is None
555
- else last_hidden,
556
- self.transformer_model.config.layer_norm_eps,
557
- )
558
- )
559
-
560
- return torch.nn.Sequential(*head_components)
561
-
562
- def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
563
- mask = mask.unsqueeze(-1)
564
- if next(self.parameters()).dtype == torch.float16:
565
- logits = logits * (1 - mask) - 65500 * mask
566
- else:
567
- logits = logits * (1 - mask) - 1e30 * mask
568
- return logits
569
-
570
- def _get_model_features(
571
- self,
572
- input_ids: torch.Tensor,
573
- attention_mask: torch.Tensor,
574
- token_type_ids: Optional[torch.Tensor],
575
- ):
576
- model_input = {
577
- "input_ids": input_ids,
578
- "attention_mask": attention_mask,
579
- "output_hidden_states": self.config.use_last_k_layers > 1,
580
- }
581
-
582
- if token_type_ids is not None:
583
- model_input["token_type_ids"] = token_type_ids
584
-
585
- model_output = self.transformer_model(**model_input)
586
-
587
- if self.config.use_last_k_layers > 1:
588
- model_features = torch.cat(
589
- model_output[1][-self.config.use_last_k_layers :], dim=-1
590
- )
591
- else:
592
- model_features = model_output[0]
593
-
594
- return model_features
595
-
596
- def compute_ned_end_logits(
597
- self,
598
- start_predictions,
599
- start_labels,
600
- model_features,
601
- prediction_mask,
602
- batch_size,
603
- ) -> Optional[torch.Tensor]:
604
- # todo: maybe when constraining on the spans,
605
- # we should not use a prediction_mask for the end tokens.
606
- # at least we should not during training imo
607
- start_positions = start_labels if self.training else start_predictions
608
- start_positions_indices = (
609
- torch.arange(start_positions.size(1), device=start_positions.device)
610
- .unsqueeze(0)
611
- .expand(batch_size, -1)[start_positions > 0]
612
- ).to(start_positions.device)
613
-
614
- if len(start_positions_indices) > 0:
615
- expanded_features = torch.cat(
616
- [
617
- model_features[i].unsqueeze(0).expand(x, -1, -1)
618
- for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
619
- if x > 0
620
- ],
621
- dim=0,
622
- ).to(start_positions_indices.device)
623
-
624
- expanded_prediction_mask = torch.cat(
625
- [
626
- prediction_mask[i].unsqueeze(0).expand(x, -1)
627
- for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
628
- if x > 0
629
- ],
630
- dim=0,
631
- ).to(expanded_features.device)
632
-
633
- # mask all tokens before start_positions_indices ie, mask all tokens with
634
- # indices < start_positions_indices with 1, ie. [range(x) for x in start_positions_indices]
635
- expanded_prediction_mask = torch.stack(
636
- [
637
- torch.cat(
638
- [
639
- torch.ones(x, device=expanded_features.device),
640
- expanded_prediction_mask[i, x:],
641
- ]
642
- )
643
- for i, x in enumerate(start_positions_indices)
644
- if x > 0
645
- ],
646
- dim=0,
647
- ).to(expanded_features.device)
648
-
649
- end_logits = self.ned_end_classifier(
650
- hidden_states=expanded_features,
651
- start_positions=start_positions_indices,
652
- p_mask=expanded_prediction_mask,
653
- )
654
-
655
- return end_logits
656
-
657
- return None
658
-
659
- def compute_relation_logits(
660
- self,
661
- model_entity_features,
662
- special_symbols_features,
663
- ) -> torch.Tensor:
664
- model_subject_features = self.re_subject_projector(model_entity_features)
665
- model_object_features = self.re_object_projector(model_entity_features)
666
- special_symbols_start_representation = self.re_relation_projector(
667
- special_symbols_features
668
- )
669
- re_logits = torch.einsum(
670
- "bse,bde,bfe->bsdfe",
671
- model_subject_features,
672
- model_object_features,
673
- special_symbols_start_representation,
674
- )
675
- re_logits = self.re_classifier(re_logits)
676
-
677
- return re_logits
678
-
679
- def compute_entity_logits(
680
- self,
681
- model_entity_features,
682
- special_symbols_features,
683
- ) -> torch.Tensor:
684
- model_ed_features = self.re_entities_projector(model_entity_features)
685
- special_symbols_ed_representation = self.re_definition_projector(
686
- special_symbols_features
687
- )
688
- logits = torch.einsum(
689
- "bce,bde->bcde",
690
- model_ed_features,
691
- special_symbols_ed_representation,
692
- )
693
- logits = self.re_ed_classifier(logits)
694
- start_logits = self._mask_logits(
695
- logits,
696
- (model_entity_features == -100)
697
- .all(2)
698
- .long()
699
- .unsqueeze(2)
700
- .repeat(1, 1, torch.sum(model_entity_features, dim=1)[0].item()),
701
- )
702
-
703
- return logits
704
-
705
- def compute_loss(self, logits, labels, mask=None):
706
- logits = logits.view(-1, logits.shape[-1])
707
- labels = labels.view(-1).long()
708
- if mask is not None:
709
- return self.criterion(logits[mask], labels[mask])
710
- return self.criterion(logits, labels)
711
-
712
- def compute_ned_end_loss(self, ned_end_logits, end_labels):
713
- if ned_end_logits is None:
714
- return 0
715
- ned_end_labels = torch.zeros_like(end_labels)
716
- ned_end_labels[end_labels == -100] = -100
717
- ned_end_labels[end_labels > 0] = 1
718
- return self.compute_loss(ned_end_logits, ned_end_labels)
719
-
720
- def compute_ned_type_loss(
721
- self,
722
- disambiguation_labels,
723
- re_ned_entities_logits,
724
- ned_type_logits,
725
- re_entities_logits,
726
- entity_types,
727
- ):
728
- if self.entity_type_loss and self.relation_disambiguation_loss:
729
- return self.compute_loss(disambiguation_labels, re_ned_entities_logits)
730
- if self.entity_type_loss:
731
- return self.compute_loss(
732
- disambiguation_labels[:, :, :entity_types], ned_type_logits
733
- )
734
- if self.relation_disambiguation_loss:
735
- return self.compute_loss(disambiguation_labels, re_entities_logits)
736
- return 0
737
-
738
- def compute_relation_loss(self, relation_labels, re_logits):
739
- return self.compute_loss(
740
- re_logits, relation_labels, relation_labels.view(-1) != -100
741
- )
742
-
743
- def forward(
744
- self,
745
- input_ids: torch.Tensor,
746
- attention_mask: torch.Tensor,
747
- token_type_ids: torch.Tensor,
748
- prediction_mask: Optional[torch.Tensor] = None,
749
- special_symbols_mask: Optional[torch.Tensor] = None,
750
- special_symbols_mask_entities: Optional[torch.Tensor] = None,
751
- start_labels: Optional[torch.Tensor] = None,
752
- end_labels: Optional[torch.Tensor] = None,
753
- disambiguation_labels: Optional[torch.Tensor] = None,
754
- relation_labels: Optional[torch.Tensor] = None,
755
- is_validation: bool = False,
756
- is_prediction: bool = False,
757
- *args,
758
- **kwargs,
759
- ) -> Dict[str, Any]:
760
- batch_size = input_ids.shape[0]
761
-
762
- model_features = self._get_model_features(
763
- input_ids, attention_mask, token_type_ids
764
- )
765
-
766
- # named entity detection
767
- if is_prediction and start_labels is not None:
768
- ned_start_logits, ned_start_probabilities, ned_start_predictions = (
769
- None,
770
- None,
771
- torch.zeros_like(start_labels),
772
- )
773
- ned_end_logits, ned_end_probabilities, ned_end_predictions = (
774
- None,
775
- None,
776
- torch.zeros_like(end_labels),
777
- )
778
-
779
- ned_start_predictions[start_labels > 0] = 1
780
- ned_end_predictions[end_labels > 0] = 1
781
- ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
782
- else:
783
- # start boundary prediction
784
- ned_start_logits = self.ned_start_classifier(model_features)
785
- ned_start_logits = self._mask_logits(
786
- ned_start_logits, prediction_mask
787
- ) # why?
788
- ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
789
- ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
790
-
791
- # end boundary prediction
792
- ned_start_labels = (
793
- torch.zeros_like(start_labels) if start_labels is not None else None
794
- )
795
-
796
- # start_labels contain entity id at their position, we just need 1 for start of entity
797
- if ned_start_labels is not None:
798
- ned_start_labels[start_labels > 0] = 1
799
-
800
- # compute end logits only if there are any start predictions.
801
- # For each start prediction, n end predictions are made
802
- ned_end_logits = self.compute_ned_end_logits(
803
- ned_start_predictions,
804
- ned_start_labels,
805
- model_features,
806
- prediction_mask,
807
- batch_size,
808
- )
809
- # For each start prediction, n end predictions are made based on
810
- # binary classification ie. argmax at each position.
811
- ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
812
- ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
813
- if is_prediction or is_validation:
814
- end_preds_count = ned_end_predictions.sum(1)
815
- # If there are no end predictions for a start prediction, remove the start prediction
816
- ned_start_predictions[ned_start_predictions == 1] = (
817
- end_preds_count != 0
818
- ).long()
819
- ned_end_predictions = ned_end_predictions[end_preds_count != 0]
820
-
821
- if end_labels is not None:
822
- end_labels = end_labels[~(end_labels == -100).all(2)]
823
-
824
- start_position, end_position = (
825
- (start_labels, end_labels)
826
- if (not is_prediction and not is_validation)
827
- else (ned_start_predictions, ned_end_predictions)
828
- )
829
-
830
- start_counts = (start_position > 0).sum(1)
831
- ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
832
-
833
- # We can only predict relations if we have start and end predictions
834
- if (end_position > 0).sum() > 0:
835
- ends_count = (end_position > 0).sum(1)
836
- model_subject_features = torch.cat(
837
- [
838
- torch.repeat_interleave(
839
- model_features[start_position > 0], ends_count, dim=0
840
- ), # start position features
841
- torch.repeat_interleave(model_features, start_counts, dim=0)[
842
- end_position > 0
843
- ], # end position features
844
- ],
845
- dim=-1,
846
- )
847
- ents_count = torch.nn.utils.rnn.pad_sequence(
848
- torch.split(ends_count, start_counts.tolist()),
849
- batch_first=True,
850
- padding_value=0,
851
- ).sum(1)
852
- model_subject_features = torch.nn.utils.rnn.pad_sequence(
853
- torch.split(model_subject_features, ents_count.tolist()),
854
- batch_first=True,
855
- padding_value=-100,
856
- )
857
-
858
- if is_validation or is_prediction:
859
- model_subject_features = model_subject_features[:, :30, :]
860
-
861
- # entity disambiguation. Here relation_disambiguation_loss would only be useful to
862
- # reduce the number of candidate relations for the next step, but currently unused.
863
- if self.entity_type_loss or self.relation_disambiguation_loss:
864
- (re_ned_entities_logits) = self.compute_entity_logits(
865
- model_subject_features,
866
- model_features[
867
- special_symbols_mask | special_symbols_mask_entities
868
- ].view(batch_size, -1, model_features.shape[-1]),
869
- )
870
- entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
871
- ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
872
- re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
873
-
874
- if self.entity_type_loss:
875
- ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
876
- ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
877
- ned_type_predictions = ned_type_predictions.argmax(dim=-1)
878
-
879
- re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
880
- re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
881
- else:
882
- (
883
- ned_type_logits,
884
- ned_type_probabilities,
885
- re_entities_logits,
886
- re_entities_probabilities,
887
- ) = (None, None, None, None)
888
- ned_type_predictions, re_entities_predictions = (
889
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
890
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
891
- )
892
-
893
- # Compute relation logits
894
- re_logits = self.compute_relation_logits(
895
- model_subject_features,
896
- model_features[special_symbols_mask].view(
897
- batch_size, -1, model_features.shape[-1]
898
- ),
899
- )
900
-
901
- re_probabilities = torch.softmax(re_logits, dim=-1)
902
- # we set a threshold instead of argmax in case it needs to be tweaked
903
- re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
904
- # re_predictions = re_probabilities.argmax(dim=-1)
905
- re_probabilities = re_probabilities[:, :, :, :, 1]
906
-
907
- else:
908
- (
909
- ned_type_logits,
910
- ned_type_probabilities,
911
- re_entities_logits,
912
- re_entities_probabilities,
913
- ) = (None, None, None, None)
914
- ned_type_predictions, re_entities_predictions = (
915
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
916
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
917
- )
918
- re_logits, re_probabilities, re_predictions = (
919
- torch.zeros(
920
- [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
921
- ).to(input_ids.device),
922
- torch.zeros(
923
- [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
924
- ).to(input_ids.device),
925
- torch.zeros(
926
- [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
927
- ).to(input_ids.device),
928
- )
929
-
930
- # output build
931
- output_dict = dict(
932
- batch_size=batch_size,
933
- ned_start_logits=ned_start_logits,
934
- ned_start_probabilities=ned_start_probabilities,
935
- ned_start_predictions=ned_start_predictions,
936
- ned_end_logits=ned_end_logits,
937
- ned_end_probabilities=ned_end_probabilities,
938
- ned_end_predictions=ned_end_predictions,
939
- ned_type_logits=ned_type_logits,
940
- ned_type_probabilities=ned_type_probabilities,
941
- ned_type_predictions=ned_type_predictions,
942
- re_entities_logits=re_entities_logits,
943
- re_entities_probabilities=re_entities_probabilities,
944
- re_entities_predictions=re_entities_predictions,
945
- re_logits=re_logits,
946
- re_probabilities=re_probabilities,
947
- re_predictions=re_predictions,
948
- )
949
-
950
- if (
951
- start_labels is not None
952
- and end_labels is not None
953
- and relation_labels is not None
954
- ):
955
- ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
956
- ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
957
- if self.entity_type_loss or self.relation_disambiguation_loss:
958
- ned_type_loss = self.compute_ned_type_loss(
959
- disambiguation_labels,
960
- re_ned_entities_logits,
961
- ned_type_logits,
962
- re_entities_logits,
963
- entity_types,
964
- )
965
- relation_loss = self.compute_relation_loss(relation_labels, re_logits)
966
- # compute loss. We can skip the relation loss if we are in the first epochs (optional)
967
- if self.entity_type_loss or self.relation_disambiguation_loss:
968
- output_dict["loss"] = (
969
- ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
970
- ) / 4
971
- output_dict["ned_type_loss"] = ned_type_loss
972
- else:
973
- output_dict["loss"] = (
974
- ned_start_loss + ned_end_loss + relation_loss
975
- ) / 3
976
-
977
- output_dict["ned_start_loss"] = ned_start_loss
978
- output_dict["ned_end_loss"] = ned_end_loss
979
- output_dict["re_loss"] = relation_loss
980
-
981
- return output_dict
 
relik/reader/pytorch_modules/optim/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from relik.reader.pytorch_modules.optim.adamw_with_warmup import (
2
- AdamWWithWarmupOptimizer,
3
- )
4
- from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import (
5
- LayerWiseLRDecayOptimizer,
6
- )
 
relik/reader/pytorch_modules/optim/adamw_with_warmup.py DELETED
@@ -1,66 +0,0 @@
1
- from typing import List
2
-
3
- import torch
4
- import transformers
5
- from torch.optim import AdamW
6
-
7
-
8
- class AdamWWithWarmupOptimizer:
9
- def __init__(
10
- self,
11
- lr: float,
12
- warmup_steps: int,
13
- total_steps: int,
14
- weight_decay: float,
15
- no_decay_params: List[str],
16
- ):
17
- self.lr = lr
18
- self.warmup_steps = warmup_steps
19
- self.total_steps = total_steps
20
- self.weight_decay = weight_decay
21
- self.no_decay_params = no_decay_params
22
-
23
- def group_params(self, module: torch.nn.Module) -> list:
24
- if self.no_decay_params is not None:
25
- optimizer_grouped_parameters = [
26
- {
27
- "params": [
28
- p
29
- for n, p in module.named_parameters()
30
- if not any(nd in n for nd in self.no_decay_params)
31
- ],
32
- "weight_decay": self.weight_decay,
33
- },
34
- {
35
- "params": [
36
- p
37
- for n, p in module.named_parameters()
38
- if any(nd in n for nd in self.no_decay_params)
39
- ],
40
- "weight_decay": 0.0,
41
- },
42
- ]
43
-
44
- else:
45
- optimizer_grouped_parameters = [
46
- {"params": module.parameters(), "weight_decay": self.weight_decay}
47
- ]
48
-
49
- return optimizer_grouped_parameters
50
-
51
- def __call__(self, module: torch.nn.Module):
52
- optimizer_grouped_parameters = self.group_params(module)
53
- optimizer = AdamW(
54
- optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay
55
- )
56
- scheduler = transformers.get_linear_schedule_with_warmup(
57
- optimizer, self.warmup_steps, self.total_steps
58
- )
59
- return {
60
- "optimizer": optimizer,
61
- "lr_scheduler": {
62
- "scheduler": scheduler,
63
- "interval": "step",
64
- "frequency": 1,
65
- },
66
- }
 
relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py DELETED
@@ -1,104 +0,0 @@
1
- import collections
2
- from typing import List
3
-
4
- import torch
5
- import transformers
6
- from torch.optim import AdamW
7
-
8
-
9
- class LayerWiseLRDecayOptimizer:
10
- def __init__(
11
- self,
12
- lr: float,
13
- warmup_steps: int,
14
- total_steps: int,
15
- weight_decay: float,
16
- lr_decay: float,
17
- no_decay_params: List[str],
18
- total_reset: int,
19
- ):
20
- self.lr = lr
21
- self.warmup_steps = warmup_steps
22
- self.total_steps = total_steps
23
- self.weight_decay = weight_decay
24
- self.lr_decay = lr_decay
25
- self.no_decay_params = no_decay_params
26
- self.total_reset = total_reset
27
-
28
- def group_layers(self, module) -> dict:
29
- grouped_layers = collections.defaultdict(list)
30
- module_named_parameters = list(module.named_parameters())
31
- for ln, lp in module_named_parameters:
32
- if "embeddings" in ln:
33
- grouped_layers["embeddings"].append((ln, lp))
34
- elif "encoder.layer" in ln:
35
- layer_num = ln.split("transformer_model.encoder.layer.")[-1]
36
- layer_num = layer_num[0 : layer_num.index(".")]
37
- grouped_layers[layer_num].append((ln, lp))
38
- else:
39
- grouped_layers["head"].append((ln, lp))
40
-
41
- depth = len(grouped_layers) - 1
42
- final_dict = dict()
43
- for key, value in grouped_layers.items():
44
- if key == "head":
45
- final_dict[0] = value
46
- elif key == "embeddings":
47
- final_dict[depth] = value
48
- else:
49
- # -1 because layer number starts from zero
50
- final_dict[depth - int(key) - 1] = value
51
-
52
- assert len(module_named_parameters) == sum(
53
- len(v) for _, v in final_dict.items()
54
- )
55
-
56
- return final_dict
57
-
58
- def group_params(self, module) -> list:
59
- optimizer_grouped_params = []
60
- for inverse_depth, layer in self.group_layers(module).items():
61
- layer_lr = self.lr * (self.lr_decay**inverse_depth)
62
- layer_wd_params = {
63
- "params": [
64
- lp
65
- for ln, lp in layer
66
- if not any(nd in ln for nd in self.no_decay_params)
67
- ],
68
- "weight_decay": self.weight_decay,
69
- "lr": layer_lr,
70
- }
71
- layer_no_wd_params = {
72
- "params": [
73
- lp
74
- for ln, lp in layer
75
- if any(nd in ln for nd in self.no_decay_params)
76
- ],
77
- "weight_decay": 0,
78
- "lr": layer_lr,
79
- }
80
-
81
- if len(layer_wd_params) != 0:
82
- optimizer_grouped_params.append(layer_wd_params)
83
- if len(layer_no_wd_params) != 0:
84
- optimizer_grouped_params.append(layer_no_wd_params)
85
-
86
- return optimizer_grouped_params
87
-
88
- def __call__(self, module: torch.nn.Module):
89
- optimizer_grouped_parameters = self.group_params(module)
90
- optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
91
- scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
92
- optimizer,
93
- self.warmup_steps,
94
- self.total_steps,
95
- num_cycles=self.total_reset,
96
- )
97
- return {
98
- "optimizer": optimizer,
99
- "lr_scheduler": {
100
- "scheduler": scheduler,
101
- "interval": "step",
102
- "frequency": 1,
103
- },
104
- }
 
relik/reader/pytorch_modules/span.py DELETED
@@ -1,367 +0,0 @@
1
- import collections
2
- import contextlib
3
- import logging
4
- from typing import Any, Dict, Iterator, List
5
-
6
- import torch
7
- import transformers as tr
8
- from lightning_fabric.utilities import move_data_to_device
9
- from torch.utils.data import DataLoader, IterableDataset
10
- from tqdm import tqdm
11
-
12
- from relik.common.log import get_console_logger, get_logger
13
- from relik.common.utils import get_callable_from_string
14
- from relik.reader.data.relik_reader_sample import RelikReaderSample
15
- from relik.reader.pytorch_modules.base import RelikReaderBase
16
- from relik.reader.utils.special_symbols import get_special_symbols
17
- from relik.retriever.pytorch_modules import PRECISION_MAP
18
-
19
- console_logger = get_console_logger()
20
- logger = get_logger(__name__, level=logging.INFO)
21
-
22
-
23
- class RelikReaderForSpanExtraction(RelikReaderBase):
24
- """
25
- A class for the RelikReader model for span extraction.
26
-
27
- Args:
28
- transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
29
- The transformer model to use. If `None`, the default model is used.
30
- additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
31
- The number of additional special symbols to add to the tokenizer.
32
- num_layers (:obj:`int`, `optional`):
33
- The number of layers to use. If `None`, all layers are used.
34
- activation (:obj:`str`, `optional`, defaults to "gelu"):
35
- The activation function to use.
36
- linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
37
- The hidden size of the linears.
38
- use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
39
- The number of last layers to use.
40
- training (:obj:`bool`, `optional`, defaults to False):
41
- Whether the model is in training mode.
42
- device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
43
- The device to use. If `None`, the default device is used.
44
- tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
45
- The tokenizer to use. If `None`, the default tokenizer is used.
46
- dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
47
- The dataset to use. If `None`, the default dataset is used.
48
- dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
49
- The keyword arguments to pass to the dataset class.
50
- default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
51
- The default reader class to use. If `None`, the default reader class is used.
52
- **kwargs:
53
- Keyword arguments.
54
- """
55
-
56
- default_reader_class: str = (
57
- "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel"
58
- )
59
- default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset"
60
-
61
- def __init__(
62
- self,
63
- transformer_model: str | tr.PreTrainedModel | None = None,
64
- additional_special_symbols: int = 0,
65
- num_layers: int | None = None,
66
- activation: str = "gelu",
67
- linears_hidden_size: int | None = 512,
68
- use_last_k_layers: int = 1,
69
- training: bool = False,
70
- device: str | torch.device | None = None,
71
- tokenizer: str | tr.PreTrainedTokenizer | None = None,
72
- dataset: IterableDataset | str | None = None,
73
- dataset_kwargs: Dict[str, Any] | None = None,
74
- default_reader_class: tr.PreTrainedModel | str | None = None,
75
- **kwargs,
76
- ):
77
- super().__init__(
78
- transformer_model=transformer_model,
79
- additional_special_symbols=additional_special_symbols,
80
- num_layers=num_layers,
81
- activation=activation,
82
- linears_hidden_size=linears_hidden_size,
83
- use_last_k_layers=use_last_k_layers,
84
- training=training,
85
- device=device,
86
- tokenizer=tokenizer,
87
- dataset=dataset,
88
- default_reader_class=default_reader_class,
89
- **kwargs,
90
- )
91
- # and instantiate the dataset class
92
- self.dataset = dataset
93
- if self.dataset is None:
94
- default_data_kwargs = dict(
95
- dataset_path=None,
96
- materialize_samples=False,
97
- transformer_model=self.tokenizer,
98
- special_symbols=get_special_symbols(
99
- self.relik_reader_model.config.additional_special_symbols
100
- ),
101
- for_inference=True,
102
- )
103
- # merge the default data kwargs with the ones passed to the model
104
- default_data_kwargs.update(dataset_kwargs or {})
105
- self.dataset = get_callable_from_string(self.default_data_class)(
106
- **default_data_kwargs
107
- )
108
-
109
- @torch.no_grad()
110
- @torch.inference_mode()
111
- def _read(
112
- self,
113
- samples: List[RelikReaderSample] | None = None,
114
- input_ids: torch.Tensor | None = None,
115
- attention_mask: torch.Tensor | None = None,
116
- token_type_ids: torch.Tensor | None = None,
117
- prediction_mask: torch.Tensor | None = None,
118
- special_symbols_mask: torch.Tensor | None = None,
119
- max_length: int = 1000,
120
- max_batch_size: int = 128,
121
- token_batch_size: int = 2048,
122
- precision: str = 32,
123
- annotation_type: str = "char",
124
- progress_bar: bool = False,
125
- *args: object,
126
- **kwargs: object,
127
- ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
128
- """
129
- A wrapper around the forward method that returns the predicted labels for each sample.
130
-
131
- Args:
132
- samples (:obj:`List[RelikReaderSample]`, `optional`):
133
- The samples to read. If provided, `text` and `candidates` are ignored.
134
- input_ids (:obj:`torch.Tensor`, `optional`):
135
- The input ids of the text. If `samples` is provided, this is ignored.
136
- attention_mask (:obj:`torch.Tensor`, `optional`):
137
- The attention mask of the text. If `samples` is provided, this is ignored.
138
- token_type_ids (:obj:`torch.Tensor`, `optional`):
139
- The token type ids of the text. If `samples` is provided, this is ignored.
140
- prediction_mask (:obj:`torch.Tensor`, `optional`):
141
- The prediction mask of the text. If `samples` is provided, this is ignored.
142
- special_symbols_mask (:obj:`torch.Tensor`, `optional`):
143
- The special symbols mask of the text. If `samples` is provided, this is ignored.
144
- max_length (:obj:`int`, `optional`, defaults to 1000):
145
- The maximum length of the text.
146
- max_batch_size (:obj:`int`, `optional`, defaults to 128):
147
- The maximum batch size.
148
- token_batch_size (:obj:`int`, `optional`):
149
- The token batch size.
150
- progress_bar (:obj:`bool`, `optional`, defaults to False):
151
- Whether to show a progress bar.
152
- precision (:obj:`str`, `optional`, defaults to 32):
153
- The precision to use for the model.
154
- annotation_type (:obj:`str`, `optional`, defaults to "char"):
155
- The annotation type to use. It can be either "char", "token" or "word".
156
- *args:
157
- Positional arguments.
158
- **kwargs:
159
- Keyword arguments.
160
-
161
- Returns:
162
- :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
163
- The predicted labels for each sample.
164
- """
165
-
166
- precision = precision or self.precision
167
- if samples is not None:
168
-
169
- def _read_iterator():
170
- def samples_it():
171
- for i, sample in enumerate(samples):
172
- assert sample._mixin_prediction_position is None
173
- sample._mixin_prediction_position = i
174
- yield sample
175
-
176
- next_prediction_position = 0
177
- position2predicted_sample = {}
178
-
179
- # instantiate dataset
180
- if self.dataset is None:
181
- raise ValueError(
182
- "You need to pass a dataset to the model in order to predict"
183
- )
184
- self.dataset.samples = samples_it()
185
- self.dataset.model_max_length = max_length
186
- self.dataset.tokens_per_batch = token_batch_size
187
- self.dataset.max_batch_size = max_batch_size
188
-
189
- # instantiate dataloader
190
- iterator = DataLoader(
191
- self.dataset, batch_size=None, num_workers=0, shuffle=False
192
- )
193
- if progress_bar:
194
- iterator = tqdm(iterator, desc="Predicting with RelikReader")
195
-
196
- # fucking autocast only wants pure strings like 'cpu' or 'cuda'
197
- # we need to convert the model device to that
198
- device_type_for_autocast = str(self.device).split(":")[0]
199
- # autocast doesn't work with CPU and stuff different from bfloat16
200
- autocast_mngr = (
201
- contextlib.nullcontext()
202
- if device_type_for_autocast == "cpu"
203
- else (
204
- torch.autocast(
205
- device_type=device_type_for_autocast,
206
- dtype=PRECISION_MAP[precision],
207
- )
208
- )
209
- )
210
-
211
- with autocast_mngr:
212
- for batch in iterator:
213
- batch = move_data_to_device(batch, self.device)
214
- batch_out = self._batch_predict(**batch)
215
-
216
- for sample in batch_out:
217
- if (
218
- sample._mixin_prediction_position
219
- >= next_prediction_position
220
- ):
221
- position2predicted_sample[
222
- sample._mixin_prediction_position
223
- ] = sample
224
-
225
- # yield
226
- while next_prediction_position in position2predicted_sample:
227
- yield position2predicted_sample[next_prediction_position]
228
- del position2predicted_sample[next_prediction_position]
229
- next_prediction_position += 1
230
-
231
- outputs = list(_read_iterator())
232
- for sample in outputs:
233
- self.dataset.merge_patches_predictions(sample)
234
- self.dataset.convert_tokens_to_char_annotations(sample)
235
-
236
- else:
237
- outputs = list(
238
- self._batch_predict(
239
- input_ids,
240
- attention_mask,
241
- token_type_ids,
242
- prediction_mask,
243
- special_symbols_mask,
244
- *args,
245
- **kwargs,
246
- )
247
- )
248
- return outputs
249
-
250
- def _batch_predict(
251
- self,
252
- input_ids: torch.Tensor,
253
- attention_mask: torch.Tensor,
254
- token_type_ids: torch.Tensor | None = None,
255
- prediction_mask: torch.Tensor | None = None,
256
- special_symbols_mask: torch.Tensor | None = None,
257
- sample: List[RelikReaderSample] | None = None,
258
- top_k: int = 5, # the amount of top-k most probable entities to predict
259
- *args,
260
- **kwargs,
261
- ) -> Iterator[RelikReaderSample]:
262
- """
263
- A wrapper around the forward method that returns the predicted labels for each sample.
264
- It also adds the predicted labels to the samples.
265
-
266
- Args:
267
- input_ids (:obj:`torch.Tensor`):
268
- The input ids of the text.
269
- attention_mask (:obj:`torch.Tensor`):
270
- The attention mask of the text.
271
- token_type_ids (:obj:`torch.Tensor`, `optional`):
272
- The token type ids of the text.
273
- prediction_mask (:obj:`torch.Tensor`, `optional`):
274
- The prediction mask of the text.
275
- special_symbols_mask (:obj:`torch.Tensor`, `optional`):
276
- The special symbols mask of the text.
277
- sample (:obj:`List[RelikReaderSample]`, `optional`):
278
- The samples to read. If provided, `text` and `candidates` are ignored.
279
- top_k (:obj:`int`, `optional`, defaults to 5):
280
- The amount of top-k most probable entities to predict.
281
- *args:
282
- Positional arguments.
283
- **kwargs:
284
- Keyword arguments.
285
-
286
- Returns:
287
- The predicted labels for each sample.
288
- """
289
- forward_output = self.forward(
290
- input_ids=input_ids,
291
- attention_mask=attention_mask,
292
- token_type_ids=token_type_ids,
293
- prediction_mask=prediction_mask,
294
- special_symbols_mask=special_symbols_mask,
295
- )
296
-
297
- ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
298
- ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
299
- ed_predictions = forward_output["ed_predictions"].cpu().numpy()
300
- ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
301
-
302
- batch_predictable_candidates = kwargs["predictable_candidates"]
303
- patch_offset = kwargs["patch_offset"]
304
- for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
305
- sample,
306
- ned_start_predictions,
307
- ned_end_predictions,
308
- ed_predictions,
309
- ed_probabilities,
310
- batch_predictable_candidates,
311
- patch_offset,
312
- ):
313
- ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
314
- ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
315
-
316
- final_class2predicted_spans = collections.defaultdict(list)
317
- spans2predicted_probabilities = dict()
318
- for start_token_index, end_token_index in zip(
319
- ne_start_indices, ne_end_indices
320
- ):
321
- # predicted candidate
322
- token_class = edp[start_token_index + 1] - 1
323
- predicted_candidate_title = pred_cands[token_class]
324
- final_class2predicted_spans[predicted_candidate_title].append(
325
- [start_token_index, end_token_index]
326
- )
327
-
328
- # candidates probabilities
329
- classes_probabilities = edpr[start_token_index + 1]
330
- classes_probabilities_best_indices = classes_probabilities.argsort()[
331
- ::-1
332
- ]
333
- titles_2_probs = []
334
- top_k = (
335
- min(
336
- top_k,
337
- len(classes_probabilities_best_indices),
338
- )
339
- if top_k != -1
340
- else len(classes_probabilities_best_indices)
341
- )
342
- for i in range(top_k):
343
- titles_2_probs.append(
344
- (
345
- pred_cands[classes_probabilities_best_indices[i] - 1],
346
- classes_probabilities[
347
- classes_probabilities_best_indices[i]
348
- ].item(),
349
- )
350
- )
351
- spans2predicted_probabilities[
352
- (start_token_index, end_token_index)
353
- ] = titles_2_probs
354
-
355
- if "patches" not in ts._d:
356
- ts._d["patches"] = dict()
357
-
358
- ts._d["patches"][po] = dict()
359
- sample_patch = ts._d["patches"][po]
360
-
361
- sample_patch["predicted_window_labels"] = final_class2predicted_spans
362
- sample_patch["span_title_probabilities"] = spans2predicted_probabilities
363
-
364
- # additional info
365
- sample_patch["predictable_candidates"] = pred_cands
366
-
367
- yield ts
 
relik/reader/relik_reader.py DELETED
@@ -1,629 +0,0 @@
1
- import collections
2
- import logging
3
- from pathlib import Path
4
- from typing import Any, Callable, Dict, Iterator, List, Union
5
-
6
- import torch
7
- import transformers as tr
8
- from tqdm import tqdm
9
- from transformers import AutoConfig
10
-
11
- from relik.common.log import get_console_logger, get_logger
12
- from relik.reader.data.relik_reader_data_utils import batchify, flatten
13
- from relik.reader.data.relik_reader_sample import RelikReaderSample
14
- from relik.reader.pytorch_modules.hf.modeling_relik import (
15
- RelikReaderConfig,
16
- RelikReaderSpanModel,
17
- )
18
- from relik.reader.relik_reader_predictor import RelikReaderPredictor
19
- from relik.reader.utils.save_load_utilities import load_model_and_conf
20
- from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols
21
-
22
- console_logger = get_console_logger()
23
- logger = get_logger(__name__, level=logging.INFO)
24
-
25
-
26
- class RelikReaderForSpanExtraction(torch.nn.Module):
27
- def __init__(
28
- self,
29
- transformer_model: str | tr.PreTrainedModel | None = None,
30
- additional_special_symbols: int = 0,
31
- num_layers: int | None = None,
32
- activation: str = "gelu",
33
- linears_hidden_size: int | None = 512,
34
- use_last_k_layers: int = 1,
35
- training: bool = False,
36
- device: str | torch.device | None = None,
37
- tokenizer: str | tr.PreTrainedTokenizer | None = None,
38
- **kwargs,
39
- ) -> None:
40
- super().__init__()
41
-
42
- if isinstance(transformer_model, str):
43
- config = AutoConfig.from_pretrained(
44
- transformer_model, trust_remote_code=True
45
- )
46
- if "relik-reader" in config.model_type:
47
- transformer_model = RelikReaderSpanModel.from_pretrained(
48
- transformer_model, **kwargs
49
- )
50
- else:
51
- reader_config = RelikReaderConfig(
52
- transformer_model=transformer_model,
53
- additional_special_symbols=additional_special_symbols,
54
- num_layers=num_layers,
55
- activation=activation,
56
- linears_hidden_size=linears_hidden_size,
57
- use_last_k_layers=use_last_k_layers,
58
- training=training,
59
- )
60
- transformer_model = RelikReaderSpanModel(reader_config)
61
-
62
- self.relik_reader_model = transformer_model
63
-
64
- self._tokenizer = tokenizer
65
-
66
- # move the model to the device
67
- self.to(device or torch.device("cpu"))
68
-
69
- def forward(
70
- self,
71
- input_ids: torch.Tensor,
72
- attention_mask: torch.Tensor,
73
- token_type_ids: torch.Tensor,
74
- prediction_mask: torch.Tensor | None = None,
75
- special_symbols_mask: torch.Tensor | None = None,
76
- special_symbols_mask_entities: torch.Tensor | None = None,
77
- start_labels: torch.Tensor | None = None,
78
- end_labels: torch.Tensor | None = None,
79
- disambiguation_labels: torch.Tensor | None = None,
80
- relation_labels: torch.Tensor | None = None,
81
- is_validation: bool = False,
82
- is_prediction: bool = False,
83
- *args,
84
- **kwargs,
85
- ) -> Dict[str, Any]:
86
- return self.relik_reader_model(
87
- input_ids,
88
- attention_mask,
89
- token_type_ids,
90
- prediction_mask,
91
- special_symbols_mask,
92
- special_symbols_mask_entities,
93
- start_labels,
94
- end_labels,
95
- disambiguation_labels,
96
- relation_labels,
97
- is_validation,
98
- is_prediction,
99
- *args,
100
- **kwargs,
101
- )
102
-
103
- def batch_predict(
104
- self,
105
- input_ids: torch.Tensor,
106
- attention_mask: torch.Tensor,
107
- token_type_ids: torch.Tensor | None = None,
108
- prediction_mask: torch.Tensor | None = None,
109
- special_symbols_mask: torch.Tensor | None = None,
110
- sample: List[RelikReaderSample] | None = None,
111
- top_k: int = 5, # the amount of top-k most probable entities to predict
112
- *args,
113
- **kwargs,
114
- ) -> Iterator[RelikReaderSample]:
115
- """
116
-
117
-
118
- Args:
119
- input_ids:
120
- attention_mask:
121
- token_type_ids:
122
- prediction_mask:
123
- special_symbols_mask:
124
- sample:
125
- top_k:
126
- *args:
127
- **kwargs:
128
-
129
- Returns:
130
-
131
- """
132
- forward_output = self.forward(
133
- input_ids,
134
- attention_mask,
135
- token_type_ids,
136
- prediction_mask,
137
- special_symbols_mask,
138
- )
139
-
140
- ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
141
- ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
142
- ed_predictions = forward_output["ed_predictions"].cpu().numpy()
143
- ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
144
-
145
- batch_predictable_candidates = kwargs["predictable_candidates"]
146
- patch_offset = kwargs["patch_offset"]
147
- for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
148
- sample,
149
- ned_start_predictions,
150
- ned_end_predictions,
151
- ed_predictions,
152
- ed_probabilities,
153
- batch_predictable_candidates,
154
- patch_offset,
155
- ):
156
- ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
157
- ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
158
-
159
- final_class2predicted_spans = collections.defaultdict(list)
160
- spans2predicted_probabilities = dict()
161
- for start_token_index, end_token_index in zip(
162
- ne_start_indices, ne_end_indices
163
- ):
164
- # predicted candidate
165
- token_class = edp[start_token_index + 1] - 1
166
- predicted_candidate_title = pred_cands[token_class]
167
- final_class2predicted_spans[predicted_candidate_title].append(
168
- [start_token_index, end_token_index]
169
- )
170
-
171
- # candidates probabilities
172
- classes_probabilities = edpr[start_token_index + 1]
173
- classes_probabilities_best_indices = classes_probabilities.argsort()[
174
- ::-1
175
- ]
176
- titles_2_probs = []
177
- top_k = (
178
- min(
179
- top_k,
180
- len(classes_probabilities_best_indices),
181
- )
182
- if top_k != -1
183
- else len(classes_probabilities_best_indices)
184
- )
185
- for i in range(top_k):
186
- titles_2_probs.append(
187
- (
188
- pred_cands[classes_probabilities_best_indices[i] - 1],
189
- classes_probabilities[
190
- classes_probabilities_best_indices[i]
191
- ].item(),
192
- )
193
- )
194
- spans2predicted_probabilities[
195
- (start_token_index, end_token_index)
196
- ] = titles_2_probs
197
-
198
- if "patches" not in ts._d:
199
- ts._d["patches"] = dict()
200
-
201
- ts._d["patches"][po] = dict()
202
- sample_patch = ts._d["patches"][po]
203
-
204
- sample_patch["predicted_window_labels"] = final_class2predicted_spans
205
- sample_patch["span_title_probabilities"] = spans2predicted_probabilities
206
-
207
- # additional info
208
- sample_patch["predictable_candidates"] = pred_cands
209
-
210
- yield ts
211
-
212
- def _build_input(self, text: List[str], candidates: List[List[str]]) -> list[str]:
213
- candidates_symbols = get_special_symbols(len(candidates))
214
- candidates = [
215
- [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
216
- for cs, ct in zip(candidates_symbols, candidates)
217
- ]
218
- return (
219
- [self.tokenizer.cls_token]
220
- + text
221
- + [self.tokenizer.sep_token]
222
- + flatten(candidates)
223
- + [self.tokenizer.sep_token]
224
- )
225
-
226
- @staticmethod
227
- def _compute_offsets(offsets_mapping):
228
- offsets_mapping = offsets_mapping.numpy()
229
- token2word = []
230
- word2token = {}
231
- count = 0
232
- for i, offset in enumerate(offsets_mapping):
233
- if offset[0] == 0:
234
- token2word.append(i - count)
235
- word2token[i - count] = [i]
236
- else:
237
- token2word.append(token2word[-1])
238
- word2token[token2word[-1]].append(i)
239
- count += 1
240
- return token2word, word2token
241
-
242
- @staticmethod
243
- def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
244
- triplets = []
245
- entities = []
246
- for entity in sample.predicted_entities:
247
- if sample.entity_candidates:
248
- entities.append(
249
- (
250
- sample.token2word[entity[0] - 1],
251
- sample.token2word[entity[1] - 1] + 1,
252
- sample.entity_candidates[entity[2]],
253
- )
254
- )
255
- else:
256
- entities.append(
257
- (
258
- sample.token2word[entity[0] - 1],
259
- sample.token2word[entity[1] - 1] + 1,
260
- -1,
261
- )
262
- )
263
- for predicted_triplet, predicted_triplet_probabilities in zip(
264
- sample.predicted_relations, sample.predicted_relations_probabilities
265
- ):
266
- subject, object_, relation = predicted_triplet
267
- subject = entities[subject]
268
- object_ = entities[object_]
269
- relation = sample.candidates[relation]
270
- triplets.append(
271
- {
272
- "subject": {
273
- "start": subject[0],
274
- "end": subject[1],
275
- "type": subject[2],
276
- "name": " ".join(sample.tokens[subject[0] : subject[1]]),
277
- },
278
- "relation": {
279
- "name": relation,
280
- "probability": float(predicted_triplet_probabilities.round(2)),
281
- },
282
- "object": {
283
- "start": object_[0],
284
- "end": object_[1],
285
- "type": object_[2],
286
- "name": " ".join(sample.tokens[object_[0] : object_[1]]),
287
- },
288
- }
289
- )
290
- sample.predicted_entities = entities
291
- sample.predicted_relations = triplets
292
- sample.predicted_relations_probabilities = None
293
-
294
- @torch.no_grad()
295
- @torch.inference_mode()
296
- def read(
297
- self,
298
- text: List[str] | List[List[str]] | None = None,
299
- samples: List[RelikReaderSample] | None = None,
300
- input_ids: torch.Tensor | None = None,
301
- attention_mask: torch.Tensor | None = None,
302
- token_type_ids: torch.Tensor | None = None,
303
- prediction_mask: torch.Tensor | None = None,
304
- special_symbols_mask: torch.Tensor | None = None,
305
- special_symbols_mask_entities: torch.Tensor | None = None,
306
- candidates: List[List[str]] | None = None,
307
- max_length: int | None = 1024,
308
- max_batch_size: int | None = 64,
309
- token_batch_size: int | None = None,
310
- progress_bar: bool = False,
311
- *args,
312
- **kwargs,
313
- ) -> List[List[RelikReaderSample]]:
314
- """
315
- Reads the given text.
316
- Args:
317
- text: The text to read in tokens.
318
- samples:
319
- input_ids: The input ids of the text.
320
- attention_mask: The attention mask of the text.
321
- token_type_ids: The token type ids of the text.
322
- prediction_mask: The prediction mask of the text.
323
- special_symbols_mask: The special symbols mask of the text.
324
- special_symbols_mask_entities: The special symbols mask entities of the text.
325
- candidates: The candidates of the text.
326
- max_length: The maximum length of the text.
327
- max_batch_size: The maximum batch size.
328
- token_batch_size: The maximum number of tokens per batch.
329
- progress_bar:
330
- Returns:
331
- The predicted labels for each sample.
332
- """
333
- if text is None and input_ids is None and samples is None:
334
- raise ValueError(
335
- "Either `text` or `input_ids` or `samples` must be provided."
336
- )
337
- if (input_ids is None and samples is None) and (
338
- text is None or candidates is None
339
- ):
340
- raise ValueError(
341
- "`text` and `candidates` must be provided to return the predictions when "
342
- "`input_ids` and `samples` is not provided."
343
- )
344
- if text is not None and samples is None:
345
- if len(text) != len(candidates):
346
- raise ValueError("`text` and `candidates` must have the same length.")
347
- if isinstance(text[0], str): # change to list of text
348
- text = [text]
349
- candidates = [candidates]
350
-
351
- samples = [
352
- RelikReaderSample(tokens=t, candidates=c)
353
- for t, c in zip(text, candidates)
354
- ]
355
-
356
- if samples is not None:
357
- # function that creates a batch from the 'current_batch' list
358
- def output_batch() -> Dict[str, Any]:
359
- assert (
360
- len(
361
- set(
362
- [
363
- len(elem["predictable_candidates"])
364
- for elem in current_batch
365
- ]
366
- )
367
- )
368
- == 1
369
- ), " ".join(
370
- map(
371
- str,
372
- [len(elem["predictable_candidates"]) for elem in current_batch],
373
- )
374
- )
375
-
376
- batch_dict = dict()
377
-
378
- de_values_by_field = {
379
- fn: [de[fn] for de in current_batch if fn in de]
380
- for fn in self.fields_batcher
381
- }
382
-
383
- # in case you provide fields batchers but in the batch
384
- # there are no elements for that field
385
- de_values_by_field = {
386
- fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
387
- }
388
-
389
- assert len(set([len(v) for v in de_values_by_field.values()]))
390
-
391
- # todo: maybe we should report the user about possible
392
- # fields filtering due to "None" instances
393
- de_values_by_field = {
394
- fn: fvs
395
- for fn, fvs in de_values_by_field.items()
396
- if all([fv is not None for fv in fvs])
397
- }
398
-
399
- for field_name, field_values in de_values_by_field.items():
400
- field_batch = (
401
- self.fields_batcher[field_name]([fv[0] for fv in field_values])
402
- if self.fields_batcher[field_name] is not None
403
- else field_values
404
- )
405
-
406
- batch_dict[field_name] = field_batch
407
-
408
- batch_dict = {
409
- k: v.to(self.device) if isinstance(v, torch.Tensor) else v
410
- for k, v in batch_dict.items()
411
- }
412
- return batch_dict
413
-
414
- current_batch = []
415
- predictions = []
416
- current_cand_len = -1
417
-
418
- for sample in tqdm(samples, disable=not progress_bar):
419
- sample.candidates = [NME_SYMBOL] + sample.candidates
420
- inputs_text = self._build_input(sample.tokens, sample.candidates)
421
- model_inputs = self.tokenizer(
422
- inputs_text,
423
- is_split_into_words=True,
424
- add_special_tokens=False,
425
- padding=False,
426
- truncation=True,
427
- max_length=max_length or self.tokenizer.model_max_length,
428
- return_offsets_mapping=True,
429
- return_tensors="pt",
430
- )
431
- model_inputs["special_symbols_mask"] = (
432
- model_inputs["input_ids"] > self.tokenizer.vocab_size
433
- )
434
- # prediction mask is 0 until the first special symbol
435
- model_inputs["token_type_ids"] = (
436
- torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
437
- ).long()
438
- # shift prediction_mask to the left
439
- model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
440
- shifts=-1, dims=1
441
- )
442
- model_inputs["prediction_mask"][:, -1] = 1
443
- model_inputs["prediction_mask"][:, 0] = 1
444
-
445
- assert (
446
- len(model_inputs["special_symbols_mask"])
447
- == len(model_inputs["prediction_mask"])
448
- == len(model_inputs["input_ids"])
449
- )
450
-
451
- model_inputs["sample"] = sample
452
-
453
- # compute cand_len using special_symbols_mask
454
- model_inputs["predictable_candidates"] = sample.candidates[
455
- : model_inputs["special_symbols_mask"].sum().item()
456
- ]
457
- # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]])
458
- offsets = model_inputs.pop("offset_mapping")
459
- offsets = offsets[model_inputs["prediction_mask"] == 0]
460
- sample.token2word, sample.word2token = self._compute_offsets(offsets)
461
- future_max_len = max(
462
- len(model_inputs["input_ids"]),
463
- max([len(b["input_ids"]) for b in current_batch], default=0),
464
- )
465
- future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
466
-
467
- if len(current_batch) > 0 and (
468
- (
469
- len(model_inputs["predictable_candidates"]) != current_cand_len
470
- and current_cand_len != -1
471
- )
472
- or (
473
- isinstance(token_batch_size, int)
474
- and future_tokens_per_batch >= token_batch_size
475
- )
476
- or len(current_batch) == max_batch_size
477
- ):
478
- batch_inputs = output_batch()
479
- current_batch = []
480
- predictions.extend(list(self.batch_predict(**batch_inputs)))
481
- current_cand_len = len(model_inputs["predictable_candidates"])
482
- current_batch.append(model_inputs)
483
-
484
- if current_batch:
485
- batch_inputs = output_batch()
486
- predictions.extend(list(self.batch_predict(**batch_inputs)))
487
- else:
488
- predictions = list(
489
- self.batch_predict(
490
- input_ids,
491
- attention_mask,
492
- token_type_ids,
493
- prediction_mask,
494
- special_symbols_mask,
495
- special_symbols_mask_entities,
496
- *args,
497
- **kwargs,
498
- )
499
- )
500
- return predictions
501
-
502
- @property
503
- def device(self) -> torch.device:
504
- """
505
- The device of the model.
506
- """
507
- return next(self.parameters()).device
508
-
509
- @property
510
- def tokenizer(self) -> tr.PreTrainedTokenizer:
511
- """
512
- The tokenizer.
513
- """
514
- if self._tokenizer:
515
- return self._tokenizer
516
-
517
- self._tokenizer = tr.AutoTokenizer.from_pretrained(
518
- self.relik_reader_model.config.name_or_path
519
- )
520
- return self._tokenizer
521
-
522
- @property
523
- def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
524
- fields_batchers = {
525
- "input_ids": lambda x: batchify(
526
- x, padding_value=self.tokenizer.pad_token_id
527
- ),
528
- "attention_mask": lambda x: batchify(x, padding_value=0),
529
- "token_type_ids": lambda x: batchify(x, padding_value=0),
530
- "prediction_mask": lambda x: batchify(x, padding_value=1),
531
- "global_attention": lambda x: batchify(x, padding_value=0),
532
- "token2word": None,
533
- "sample": None,
534
- "special_symbols_mask": lambda x: batchify(x, padding_value=False),
535
- "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
536
- }
537
- if "roberta" in self.relik_reader_model.config.model_type:
538
- del fields_batchers["token_type_ids"]
539
-
540
- return fields_batchers
541
-
542
- def save_pretrained(
543
- self,
544
- output_dir: str,
545
- model_name: str | None = None,
546
- push_to_hub: bool = False,
547
- **kwargs,
548
- ) -> None:
549
- """
550
- Saves the model to the given path.
551
- Args:
552
- output_dir: The path to save the model to.
553
- model_name: The name of the model.
554
- push_to_hub: Whether to push the model to the hub.
555
- """
556
- # create the output directory
557
- output_dir = Path(output_dir)
558
- output_dir.mkdir(parents=True, exist_ok=True)
559
-
560
- model_name = model_name or "relik-reader-for-span-extraction"
561
-
562
- logger.info(f"Saving reader to {output_dir / model_name}")
563
-
564
- # save the model
565
- self.relik_reader_model.register_for_auto_class()
566
- self.relik_reader_model.save_pretrained(
567
- output_dir / model_name, push_to_hub=push_to_hub, **kwargs
568
- )
569
-
570
- logger.info("Saving reader to disk done.")
571
-
572
- if self.tokenizer:
573
- self.tokenizer.save_pretrained(
574
- output_dir / model_name, push_to_hub=push_to_hub, **kwargs
575
- )
576
- logger.info("Saving tokenizer to disk done.")
577
-
578
-
579
- class RelikReader:
580
- def __init__(self, model_path: str, predict_nmes: bool = False):
581
- model, model_conf = load_model_and_conf(model_path)
582
- model.training = False
583
- model.eval()
584
-
585
- val_dataset_conf = model_conf.data.val_dataset
586
- val_dataset_conf.special_symbols = get_special_symbols(
587
- model_conf.model.entities_per_forward
588
- )
589
- val_dataset_conf.transformer_model = model_conf.model.model.transformer_model
590
-
591
- self.predictor = RelikReaderPredictor(
592
- model,
593
- dataset_conf=model_conf.data.val_dataset,
594
- predict_nmes=predict_nmes,
595
- )
596
- self.model_path = model_path
597
-
598
- def link_entities(
599
- self,
600
- dataset_path_or_samples: str | Iterator[RelikReaderSample],
601
- token_batch_size: int = 2048,
602
- progress_bar: bool = False,
603
- ) -> List[RelikReaderSample]:
604
- data_input = (
605
- (dataset_path_or_samples, None)
606
- if isinstance(dataset_path_or_samples, str)
607
- else (None, dataset_path_or_samples)
608
- )
609
- return self.predictor.predict(
610
- *data_input,
611
- dataset_conf=None,
612
- token_batch_size=token_batch_size,
613
- progress_bar=progress_bar,
614
- )
615
-
616
- # def save_pretrained(self, path: Union[str, Path]):
617
- # self.predictor.save(path)
618
-
619
-
620
- def main():
621
- rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True)
622
- predictions = rr.link_entities(
623
- "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl"
624
- )
625
- print(predictions)
626
-
627
-
628
- if __name__ == "__main__":
629
- main()