CarlosMalaga
committed on
Commit
•
2f044c1
1
Parent(s):
3376207
Upload 201 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- relik/__init__.py +8 -0
- relik/common/__init__.py +0 -0
- relik/common/__pycache__/__init__.cpython-310.pyc +0 -0
- relik/common/__pycache__/log.cpython-310.pyc +0 -0
- relik/common/__pycache__/torch_utils.cpython-310.pyc +0 -0
- relik/common/__pycache__/upload.cpython-310.pyc +0 -0
- relik/common/__pycache__/utils.cpython-310.pyc +0 -0
- relik/common/log.py +174 -0
- relik/common/torch_utils.py +82 -0
- relik/common/upload.py +144 -0
- relik/common/utils.py +610 -0
- relik/inference/__init__.py +0 -0
- relik/inference/__pycache__/__init__.cpython-310.pyc +0 -0
- relik/inference/__pycache__/annotator.cpython-310.pyc +0 -0
- relik/inference/annotator.py +840 -0
- relik/inference/data/__init__.py +0 -0
- relik/inference/data/__pycache__/__init__.cpython-310.pyc +0 -0
- relik/inference/data/__pycache__/objects.cpython-310.pyc +0 -0
- relik/inference/data/objects.py +88 -0
- relik/inference/data/splitters/__init__.py +0 -0
- relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc +0 -0
- relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc +0 -0
- relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc +0 -0
- relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc +0 -0
- relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc +0 -0
- relik/inference/data/splitters/base_sentence_splitter.py +55 -0
- relik/inference/data/splitters/blank_sentence_splitter.py +29 -0
- relik/inference/data/splitters/spacy_sentence_splitter.py +153 -0
- relik/inference/data/splitters/window_based_splitter.py +62 -0
- relik/inference/data/tokenizers/__init__.py +87 -0
- relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc +0 -0
- relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc +0 -0
- relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc +0 -0
- relik/inference/data/tokenizers/base_tokenizer.py +84 -0
- relik/inference/data/tokenizers/spacy_tokenizer.py +194 -0
- relik/inference/data/window/__init__.py +0 -0
- relik/inference/data/window/__pycache__/__init__.cpython-310.pyc +0 -0
- relik/inference/data/window/__pycache__/manager.cpython-310.pyc +0 -0
- relik/inference/data/window/manager.py +431 -0
- relik/inference/gerbil.py +269 -0
- relik/inference/serve/__init__.py +0 -0
- relik/inference/serve/backend/__init__.py +0 -0
- relik/inference/serve/backend/fastapi.py +122 -0
- relik/inference/serve/backend/ray.py +165 -0
- relik/inference/serve/backend/utils.py +38 -0
- relik/inference/serve/frontend/__init__.py +0 -0
- relik/inference/serve/frontend/relik_front.py +229 -0
- relik/inference/serve/frontend/relik_re_front.py +251 -0
- relik/inference/serve/frontend/style.css +33 -0
- relik/inference/serve/frontend/utils.py +132 -0
relik/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from relik.inference.annotator import Relik
from pathlib import Path

# Read the package version out of the sibling version.py file and re-export
# it as `__version__` (standard packaging convention).
VERSION = {}  # type: ignore
with open(Path(__file__).parent / "version.py", "r") as version_file:
    # version.py is expected to define a VERSION variable; exec fills the
    # VERSION dict used as its globals.
    exec(version_file.read(), VERSION)

__version__ = VERSION["VERSION"]
|
relik/common/__init__.py
ADDED
File without changes
|
relik/common/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (180 Bytes). View file
|
|
relik/common/__pycache__/log.cpython-310.pyc
ADDED
Binary file (4.39 kB). View file
|
|
relik/common/__pycache__/torch_utils.cpython-310.pyc
ADDED
Binary file (1.11 kB). View file
|
|
relik/common/__pycache__/upload.cpython-310.pyc
ADDED
Binary file (4.04 kB). View file
|
|
relik/common/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (14.8 kB). View file
|
|
relik/common/log.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import threading
|
5 |
+
from logging.config import dictConfig
|
6 |
+
from typing import Any, Dict, Optional
|
7 |
+
|
8 |
+
from art import text2art, tprint
|
9 |
+
from colorama import Fore, Style, init
|
10 |
+
from rich import get_console
|
11 |
+
|
12 |
+
_lock = threading.Lock()
|
13 |
+
_default_handler: Optional[logging.Handler] = None
|
14 |
+
|
15 |
+
_default_log_level = logging.WARNING
|
16 |
+
|
17 |
+
# fancy logger
|
18 |
+
_console = get_console()
|
19 |
+
|
20 |
+
|
21 |
+
class ColorfulFormatter(logging.Formatter):
    """
    Formatter that colors log messages by level name.

    Levels without an entry in ``COLORS`` (e.g. INFO) are emitted without
    coloring. Also injects the distributed-training rank into each record so
    ``%(rank)d`` can be used in format strings.
    """

    # Map from level name to the colorama color prefix for that level.
    COLORS = {
        "WARNING": Fore.YELLOW,
        "ERROR": Fore.RED,
        "CRITICAL": Fore.RED + Style.BRIGHT,
        "DEBUG": Fore.CYAN,
        # "INFO": Fore.GREEN,
    }

    def format(self, record):
        # LOCAL_RANK is set by torch distributed launchers; default to 0
        # for single-process runs.
        record.rank = int(os.getenv("LOCAL_RANK", "0"))
        log_message = super().format(record)
        # Prepend the level color (empty string for unmapped levels) and
        # always reset the foreground color afterwards.
        return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
38 |
+
|
39 |
+
|
40 |
+
# Default logging configuration in `logging.config.dictConfig` format.
# Two console handlers are defined: a plain one attached to the root logger
# and a colored one (ColorfulFormatter) reserved for the "relik" logger.
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
    "version": 1,
    "formatters": {
        "simple": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
        },
        "colorful": {
            # "()" is the dictConfig key for a custom formatter factory.
            "()": ColorfulFormatter,
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
        },
    },
    "filters": {},
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "filters": [],
            "stream": sys.stdout,
        },
        "color_console": {
            "class": "logging.StreamHandler",
            "formatter": "colorful",
            "filters": [],
            "stream": sys.stdout,
        },
    },
    # Root level is overridable through the LOG_LEVEL environment variable.
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
    "loggers": {
        "relik": {
            "handlers": ["color_console"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
}
|
75 |
+
|
76 |
+
|
77 |
+
def configure_logging(**kwargs):
    """
    Configure the logging module with the library defaults.

    Keyword arguments shallow-override the corresponding top-level keys of
    ``DEFAULT_LOGGING_CONFIG`` before the configuration is applied through
    :func:`logging.config.dictConfig`.
    """
    init()  # Initialize colorama so ANSI colors work cross-platform.
    # Bug fix: merge into a *copy* of the defaults. The original code did
    # `DEFAULT_LOGGING_CONFIG.update(kwargs)` through an alias, permanently
    # mutating the shared module-level dict on every call with kwargs.
    logger_config = {**DEFAULT_LOGGING_CONFIG, **kwargs}
    dictConfig(logger_config)
|
85 |
+
|
86 |
+
|
87 |
+
def _get_library_name() -> str:
    """Return the top-level package name this module belongs to."""
    top_level, _, _ = __name__.partition(".")
    return top_level
|
89 |
+
|
90 |
+
|
91 |
+
def _get_library_root_logger() -> logging.Logger:
    """Return the logger rooted at the library's top-level package."""
    root_name = _get_library_name()
    return logging.getLogger(root_name)
|
93 |
+
|
94 |
+
|
95 |
+
def _configure_library_root_logger() -> None:
    """
    Attach a default stderr StreamHandler to the library root logger,
    exactly once per process.

    Thread-safe: the module-level ``_lock`` guards the check-and-set of
    ``_default_handler``. Subsequent calls are no-ops.
    """
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        # Flush through to stderr immediately so records are not buffered.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_default_log_level)
        # Don't forward records to the (possibly differently configured)
        # root logger.
        library_root_logger.propagate = False
|
110 |
+
|
111 |
+
|
112 |
+
def _reset_library_root_logger() -> None:
    """
    Undo :func:`_configure_library_root_logger`: detach the default handler
    and reset the library root logger level to ``NOTSET``.

    No-op when the library logger was never configured.
    """
    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None
|
123 |
+
|
124 |
+
|
125 |
+
def set_log_level(level: int, logger: Optional[logging.Logger] = None) -> None:
    """
    Set the log level.

    Args:
        level (:obj:`int`):
            Logging level.
        logger (:obj:`logging.Logger`, optional):
            Logger to set the log level on. Defaults to the library root
            logger, configuring it first if necessary.
    """
    if not logger:
        # Make sure the library logger has its default handler attached
        # before changing its level.
        _configure_library_root_logger()
        logger = _get_library_root_logger()
    logger.setLevel(level)
|
138 |
+
|
139 |
+
|
140 |
+
def get_logger(
    name: Optional[str] = None,
    level: Optional[int] = None,
    formatter: Optional[str] = None,
    **kwargs,
) -> logging.Logger:
    """
    Return a logger with the specified name, configuring the library root
    logger and its default handler on first use.

    Args:
        name (:obj:`str`, optional):
            Logger name; defaults to the library's top-level package name.
        level (:obj:`int`, optional):
            If given, set the library root logger to this level.
        formatter (:obj:`str`, optional):
            Format string for the default handler; a standard
            "asctime - levelname - name - message" format is used when
            omitted.
        **kwargs:
            Forwarded to :func:`configure_logging`.

    Returns:
        :obj:`logging.Logger`: The requested logger.
    """
    configure_logging(**kwargs)

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()

    if level is not None:
        set_log_level(level)

    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        )
    elif isinstance(formatter, str):
        # Bug fix: the original passed a caller-supplied format *string*
        # straight to `setFormatter`, which expects a Formatter instance
        # and would fail at emit time. Wrap strings into a Formatter.
        formatter = logging.Formatter(formatter)
    _default_handler.setFormatter(formatter)

    return logging.getLogger(name)
|
167 |
+
|
168 |
+
|
169 |
+
def get_console_logger():
    """Return the shared rich console used for fancy terminal output."""
    return _console
|
171 |
+
|
172 |
+
|
173 |
+
def print_relik_text_art(text: str = "relik", font: str = "larry3d", **kwargs):
    """
    Print `text` as ASCII art using the `art` package's `tprint`.

    Args:
        text: Text to render. Defaults to "relik".
        font: art font name. Defaults to "larry3d".
        **kwargs: Forwarded to `art.tprint`.
    """
    tprint(text, font=font, **kwargs)
|
relik/common/torch_utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import tempfile
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import transformers as tr
|
6 |
+
|
7 |
+
from relik.common.utils import is_package_available
|
8 |
+
|
9 |
+
# check if ORT is available
|
10 |
+
if is_package_available("onnxruntime"):
|
11 |
+
from optimum.onnxruntime import (
|
12 |
+
ORTModel,
|
13 |
+
ORTModelForCustomTasks,
|
14 |
+
ORTModelForSequenceClassification,
|
15 |
+
ORTOptimizer,
|
16 |
+
)
|
17 |
+
from optimum.onnxruntime.configuration import AutoOptimizationConfig
|
18 |
+
|
19 |
+
# from relik.retriever.pytorch_modules import PRECISION_MAP
|
20 |
+
|
21 |
+
|
22 |
+
def get_autocast_context(
    device: str | torch.device, precision: str
) -> contextlib.AbstractContextManager:
    """
    Build an autocast (mixed-precision) context manager for `device`.

    Args:
        device: Target device ("cuda", "cuda:0", "cpu", "mps", or a
            torch.device).
        precision: Key into relik's PRECISION_MAP (project-defined mapping
            from precision names to torch dtypes — not visible here).

    Returns:
        A `torch.autocast` context, or a no-op `nullcontext` where autocast
        is not supported for the device/dtype combination.
    """
    # torch.autocast only accepts bare device-type strings like "cpu" or
    # "cuda", so strip any device index suffix such as ":0".
    device_type_for_autocast = str(device).split(":")[0]

    # Imported lazily, presumably to avoid a circular import with
    # relik.retriever — TODO confirm.
    from relik.retriever.pytorch_modules import PRECISION_MAP

    # Autocast on CPU/MPS only supports bfloat16; for any other dtype on
    # those devices fall back to a do-nothing context.
    autocast_manager = (
        contextlib.nullcontext()
        if device_type_for_autocast in ["cpu", "mps"]
        and PRECISION_MAP[precision] != torch.bfloat16
        else (
            torch.autocast(
                device_type=device_type_for_autocast,
                dtype=PRECISION_MAP[precision],
            )
        )
    )
    return autocast_manager
|
44 |
+
|
45 |
+
|
46 |
+
# def load_ort_optimized_hf_model(
|
47 |
+
# hf_model: tr.PreTrainedModel,
|
48 |
+
# provider: str = "CPUExecutionProvider",
|
49 |
+
# ort_model_type: callable = "ORTModelForCustomTasks",
|
50 |
+
# ) -> ORTModel:
|
51 |
+
# """
|
52 |
+
# Load an optimized ONNX Runtime HF model.
|
53 |
+
#
|
54 |
+
# Args:
|
55 |
+
# hf_model (`tr.PreTrainedModel`):
|
56 |
+
# The HF model to optimize.
|
57 |
+
# provider (`str`, optional):
|
58 |
+
# The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider".
|
59 |
+
#
|
60 |
+
# Returns:
|
61 |
+
# `ORTModel`: The optimized HF model.
|
62 |
+
# """
|
63 |
+
# if isinstance(hf_model, ORTModel):
|
64 |
+
# return hf_model
|
65 |
+
# temp_dir = tempfile.mkdtemp()
|
66 |
+
# hf_model.save_pretrained(temp_dir)
|
67 |
+
# ort_model = ort_model_type.from_pretrained(
|
68 |
+
# temp_dir, export=True, provider=provider, use_io_binding=True
|
69 |
+
# )
|
70 |
+
# if is_package_available("onnxruntime"):
|
71 |
+
# optimizer = ORTOptimizer.from_pretrained(ort_model)
|
72 |
+
# optimization_config = AutoOptimizationConfig.O4()
|
73 |
+
# optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config)
|
74 |
+
# ort_model = ort_model_type.from_pretrained(
|
75 |
+
# temp_dir,
|
76 |
+
# export=True,
|
77 |
+
# provider=provider,
|
78 |
+
# use_io_binding=bool(provider == "CUDAExecutionProvider"),
|
79 |
+
# )
|
80 |
+
# return ort_model
|
81 |
+
# else:
|
82 |
+
# raise ValueError("onnxruntime is not installed. Please install Ray with `pip install relik[serve]`.")
|
relik/common/upload.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import tempfile
|
6 |
+
import zipfile
|
7 |
+
from datetime import datetime
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Optional, Union
|
10 |
+
|
11 |
+
import huggingface_hub
|
12 |
+
|
13 |
+
from relik.common.log import get_logger
|
14 |
+
from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
|
15 |
+
|
16 |
+
logger = get_logger(__name__, level=logging.DEBUG)
|
17 |
+
|
18 |
+
|
19 |
+
def create_info_file(tmpdir: Path):
    """
    Write an `info.json` file (md5 of model.zip plus upload timestamp)
    into `tmpdir`.

    Assumes `tmpdir / "model.zip"` already exists (see `zip_run`).
    """
    logger.debug("Computing md5 of model.zip")
    md5 = get_md5(tmpdir / "model.zip")
    date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)

    logger.debug("Dumping info.json file")
    with (tmpdir / "info.json").open("w") as f:
        json.dump(dict(md5=md5, upload_date=date), f, indent=2)
|
27 |
+
|
28 |
+
|
29 |
+
def zip_run(
    dir_path: Union[str, os.PathLike],
    tmpdir: Union[str, os.PathLike],
    zip_name: str = "model.zip",
) -> Path:
    """
    Zip the contents of `dir_path` into `tmpdir / zip_name`, preserving the
    directory structure inside the archive.

    Args:
        dir_path: Directory to archive.
        tmpdir: Destination directory for the zip file.
        zip_name: Name of the archive. Defaults to "model.zip".

    Returns:
        Path: Path of the created zip file.
    """
    logger.debug(f"zipping {dir_path} to {tmpdir}")
    # creates a zip version of the provided dir_path
    run_dir = Path(dir_path)
    # Bug fix: the signature allows a plain `str` for `tmpdir`, but
    # `str / str` is undefined — coerce to Path before joining.
    zip_path = Path(tmpdir) / zip_name

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        # Fully zip the run directory, maintaining its structure.
        # NOTE(review): the "*.*" glob skips files without an extension —
        # confirm that is intended before widening it to rglob("*").
        for file in run_dir.rglob("*.*"):
            if file.is_dir():
                continue
            zip_file.write(file, arcname=file.relative_to(run_dir))

    return zip_path
|
48 |
+
|
49 |
+
|
50 |
+
def get_logged_in_username():
    """
    Return the username of the currently logged-in HuggingFace user.

    Raises:
        ValueError: If no HuggingFace token is stored locally.
    """
    # NOTE(review): HfFolder.get_token is deprecated in recent
    # huggingface_hub releases in favor of huggingface_hub.get_token —
    # confirm the pinned version still supports it.
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )
    api = huggingface_hub.HfApi()
    user = api.whoami(token=token)
    return user["name"]
|
59 |
+
|
60 |
+
|
61 |
+
def upload(
    model_dir: Union[str, os.PathLike],
    model_name: str,
    filenames: Optional[list[str]] = None,
    organization: Optional[str] = None,
    repo_name: Optional[str] = None,
    commit: Optional[str] = None,
    archive: bool = False,
):
    """
    Upload a model directory to the HuggingFace Hub.

    Args:
        model_dir: Local directory containing the model files.
        model_name: Model name; used as the repo id unless `repo_name` is
            given.
        filenames: Optional subset of files to upload; when None the whole
            directory is copied.
        organization: Optional org namespace to upload under.
        repo_name: Optional repo name overriding `model_name`.
        commit: Optional commit message.
        archive: When True, zip the directory and upload the archive plus
            an info.json metadata file.

    Raises:
        ValueError: If no HuggingFace token is stored locally.
    """
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )

    repo_id = repo_name or model_name
    if organization is not None:
        repo_id = f"{organization}/{repo_id}"
    with tempfile.TemporaryDirectory() as tmpdir:
        api = huggingface_hub.HfApi()
        repo_url = api.create_repo(
            token=token,
            repo_id=repo_id,
            exist_ok=True,
        )
        # NOTE(review): huggingface_hub.Repository is deprecated in newer
        # hub versions — confirm the pinned version still ships it.
        repo = huggingface_hub.Repository(
            str(tmpdir), clone_from=repo_url, use_auth_token=token
        )

        tmp_path = Path(tmpdir)
        if archive:
            # Zip the model_dir and add the metadata file next to it.
            logger.debug(f"Zipping {model_dir} to {tmp_path}")
            zip_run(model_dir, tmp_path)
            create_info_file(tmp_path)
        else:
            # For a transformers-style model we don't need to zip it;
            # just copy the files to the tmpdir.
            logger.debug(f"Copying {model_dir} to {tmpdir}")
            # copy only the files that are needed
            if filenames is not None:
                for filename in filenames:
                    # Bug fix: the loop variable was not interpolated in the
                    # original command, so a literal bogus path was copied.
                    # NOTE(review): os.system with interpolated paths is
                    # shell-injection prone; consider shutil.copy instead.
                    os.system(f"cp {model_dir}/{filename} {tmpdir}")
            else:
                os.system(f"cp -r {model_dir}/* {tmpdir}")

        # push_to_hub automatically puts large files (>10MB) into git lfs
        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
|
109 |
+
|
110 |
+
|
111 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for the upload script."""
    arg_parser = argparse.ArgumentParser()
    # Positional arguments: where the model lives and what to call it.
    arg_parser.add_argument(
        "model_dir", help="The directory of the model you want to upload"
    )
    arg_parser.add_argument("model_name", help="The model you want to upload")
    # Optional repository-naming knobs.
    arg_parser.add_argument(
        "--organization",
        help="the name of the organization where you want to upload the model",
    )
    arg_parser.add_argument(
        "--repo_name",
        help="Optional name to use when uploading to the HuggingFace repository",
    )
    arg_parser.add_argument(
        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
    )
    arg_parser.add_argument(
        "--archive",
        action="store_true",
        help="""
        Whether to compress the model directory before uploading it.
        If True, the model directory will be zipped and the zip file will be uploaded.
        If False, the model directory will be uploaded as is.""",
    )
    return arg_parser.parse_args()
|
137 |
+
|
138 |
+
|
139 |
+
def main():
    """Console entry point: parse CLI arguments and forward them to `upload`."""
    namespace = parse_args()
    upload(**vars(namespace))


if __name__ == "__main__":
    main()
|
relik/common/utils.py
ADDED
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.util
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import tarfile
|
7 |
+
import tempfile
|
8 |
+
from functools import partial
|
9 |
+
from hashlib import sha256
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Any, BinaryIO, Dict, List, Optional, Union
|
12 |
+
from urllib.parse import urlparse
|
13 |
+
from zipfile import ZipFile, is_zipfile
|
14 |
+
|
15 |
+
import huggingface_hub
|
16 |
+
import requests
|
17 |
+
import tqdm
|
18 |
+
from filelock import FileLock
|
19 |
+
from transformers.utils.hub import cached_file as hf_cached_file
|
20 |
+
|
21 |
+
from relik.common.log import get_logger
|
22 |
+
|
23 |
+
# Canonical file names used when saving/loading relik models.
WEIGHTS_NAME = "weights.pt"
ONNX_WEIGHTS_NAME = "weights.onnx"
CONFIG_NAME = "config.yaml"
LABELS_NAME = "labels.json"

# SAPIENZANLP_USER_NAME = "sapienzanlp"
# HuggingFace namespace the models are currently published under.
SAPIENZANLP_USER_NAME = "riccorl"
SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
# URL template for the zipped-model artifact inside a repo.
SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
    f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
)
# path constants — HF_HOME overrides the default HuggingFace cache location.
HF_CACHE_DIR = Path(os.getenv("HF_HOME", Path.home() / ".cache/huggingface/hub"))
# NOTE: this is a str when the env var is set, a Path otherwise; callers
# must normalize (see download_and_cache).
SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", HF_CACHE_DIR)
# Timestamp format used in upload metadata (see relik.common.upload).
SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"

logger = get_logger(__name__)
|
41 |
+
|
42 |
+
|
43 |
+
def sapienzanlp_model_urls(model_id: str) -> str:
    """
    Return the HuggingFace repo identifier for a SapienzaNLP model.

    Args:
        model_id (:obj:`str`):
            A SapienzaNLP model id.

    Returns:
        :obj:`str`: The url for the model id.
    """
    # An id that already carries a user/org namespace is returned untouched.
    already_namespaced = "/" in model_id
    if already_namespaced:
        return model_id
    return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)
|
58 |
+
|
59 |
+
|
60 |
+
def is_package_available(package_name: str) -> bool:
    """
    Tell whether `package_name` is importable in the current environment.

    Args:
        package_name (`str`): The name of the package to check.

    Returns:
        `bool`: True when an import spec for the package can be found.
    """
    spec = importlib.util.find_spec(package_name)
    return spec is not None
|
68 |
+
|
69 |
+
|
70 |
+
def load_json(path: Union[str, Path]) -> Any:
    """
    Read and deserialize the JSON document stored at `path`.

    Args:
        path (`Union[str, Path]`): The path to the json file to load.

    Returns:
        `Any`: The deserialized content.
    """
    with open(path, encoding="utf8") as json_file:
        parsed = json.load(json_file)
    return parsed
|
82 |
+
|
83 |
+
|
84 |
+
def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
    """
    Serialize `document` as JSON and write it to `path`.

    Args:
        document (`Any`): The document to dump.
        path (`Union[str, Path]`): Destination file.
        indent (`Optional[int]`): Indentation for pretty-printing; compact
            single-line output when None.
    """
    with open(path, "w", encoding="utf8") as out_handle:
        json.dump(document, out_handle, indent=indent)
|
96 |
+
|
97 |
+
|
98 |
+
def get_md5(path: Path):
    """
    Compute the hex-encoded MD5 digest of the file at `path`.
    """
    import hashlib

    digest = hashlib.md5(path.read_bytes())
    return digest.hexdigest()
|
107 |
+
|
108 |
+
|
109 |
+
def file_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check whether a filesystem entry exists at :obj:`path`.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the path exists on disk (mirrors
        :meth:`pathlib.Path.exists`, so directories also count).
    """
    candidate = Path(path)
    return candidate.exists()
|
121 |
+
|
122 |
+
|
123 |
+
def dir_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check whether :obj:`path` points to an existing directory.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the directory exists.
    """
    candidate = Path(path)
    return candidate.is_dir()
|
135 |
+
|
136 |
+
|
137 |
+
def is_remote_url(url_or_filename: Union[str, Path]):
    """
    Returns :obj:`True` if the input path is an http(s) URL.

    Args:
        url_or_filename (:obj:`str`, :obj:`Path`):
            path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the input path is a remote URL,
        :obj:`False` otherwise.
    """
    # Paths are stringified first so urlparse sees a plain string.
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    scheme = urlparse(url_or_filename).scheme
    return scheme in ("http", "https")
|
153 |
+
|
154 |
+
|
155 |
+
def url_to_filename(resource: str, etag: str = None) -> str:
    """
    Convert a `resource` into a hashed filename in a repeatable way.
    If `etag` is specified, its hash is appended, delimited by a period.
    """
    parts = [sha256(resource.encode("utf-8")).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode("utf-8")).hexdigest())
    return ".".join(parts)
|
171 |
+
|
172 |
+
|
173 |
+
def download_resource(
    url: str,
    temp_file: BinaryIO,
    headers=None,
):
    """
    Stream the remote file at `url` into the open binary file `temp_file`,
    showing a byte-level progress bar.

    Args:
        url (`str`): URL of the resource to download.
        temp_file (`BinaryIO`): Writable binary file object.
        headers: Optional HTTP headers (e.g. an Authorization bearer token).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if headers is None:
        headers = {}

    r = requests.get(url, stream=True, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    # Bug fix: this module does `import tqdm`, so the progress-bar class
    # must be referenced as `tqdm.tqdm` — calling the module itself raised
    # "TypeError: 'module' object is not callable".
    progress = tqdm.tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        desc="Downloading",
        # NOTE(review): this disables the bar only when the logger level is
        # exactly NOTSET — confirm that is the intended condition.
        disable=logger.level in [logging.NOTSET],
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
|
201 |
+
|
202 |
+
|
203 |
+
def download_and_cache(
    url: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
):
    """
    Download the resource at `url` into `cache_dir` and return the cached
    file path, reusing any previous download unless `force_download` is set.

    The cache key is a hash of the URL plus (when available) the server
    ETag, so a changed remote file gets a fresh cache entry.

    Args:
        url (`Union[str, Path]`): Remote resource to fetch.
        cache_dir (`Union[str, Path]`, optional): Cache directory; defaults
            to ``SAPIENZANLP_CACHE_DIR``.
        force_download (`bool`): Re-download even when a cached copy exists.

    Returns:
        The path of the cached file (a `Path` on a cache hit, a `str` after
        a fresh download — NOTE(review): consider unifying the return type).

    Raises:
        OSError: If the remote resource exposes no ETag.
        ValueError: If the resource is private and no HF token is available.
    """
    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR
    if isinstance(url, Path):
        url = str(url)
    # Bug fix: `cache_dir` may arrive as a plain string (e.g. the
    # SAPIENZANLP_CACHE_DIR env-var default); `str / str` is invalid, so
    # normalize to Path once, up front, before any `/` joins below.
    cache_dir = Path(cache_dir)

    # check if cache dir exists
    cache_dir.mkdir(parents=True, exist_ok=True)

    # check if file is private: a 401 on a bare HEAD means we must retry
    # with an Authorization header built from the local HF token.
    headers = {}
    try:
        r = requests.head(url, allow_redirects=False, timeout=10)
        r.raise_for_status()
    except requests.exceptions.HTTPError:
        if r.status_code == 401:
            hf_token = huggingface_hub.HfFolder.get_token()
            if hf_token is None:
                raise ValueError(
                    "You need to login to HuggingFace to download this model "
                    "(use the `huggingface-cli login` command)"
                )
            headers["Authorization"] = f"Bearer {hf_token}"

    etag = None
    try:
        r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
        r.raise_for_status()
        etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
        # We favor a custom header indicating the etag of the linked
        # resource, and we fallback to the regular etag header.
        # If we don't have any of those, raise an error.
        if etag is None:
            raise OSError(
                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
            )
        # In case of a redirect, save an extra redirect on the request.get
        # call, and ensure we download the exact atomic version even if it
        # changed between the HEAD and the GET (unlikely, but hey).
        if 300 <= r.status_code <= 399:
            url = r.headers["Location"]
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # Actually raise for those subclasses of ConnectionError
        raise
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # Otherwise, our Internet connection is down; keep etag = None and
        # fall back to any cached copy below.
        pass

    # get filename from the url
    filename = url_to_filename(url, etag)
    # get cache path to put the file
    cache_path = cache_dir / filename

    # the file is already here, return it
    if file_exists(cache_path) and not force_download:
        logger.info(
            f"{url} found in cache, set `force_download=True` to force the download"
        )
        return cache_path

    cache_path = str(cache_path)
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if file_exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
        )

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
            )
            download_resource(url, temp_file, headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms
        # (ignoring umask), so re-apply the process umask to the cache file.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url}  # , "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
305 |
+
|
306 |
+
|
307 |
+
def download_from_hf(
    path_or_repo_id: Union[str, Path],
    filenames: List[str],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: str = "model",
):
    """
    Download ``filenames`` from a Hugging Face Hub repository and return the
    local folder they were cached in.

    Args:
        path_or_repo_id (:obj:`str` or :obj:`Path`):
            Repository id (or local path) to download the files from.
        filenames (:obj:`List[str]`):
            Names of the files to fetch. Must be non-empty.
        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
            Directory in which the downloaded files are cached.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to re-download the files even if they are already cached.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to resume a partially downloaded file.
        proxies (:obj:`Dict[str, str]`, `optional`):
            Proxy servers to use, passed through to the download helper.
        use_auth_token (:obj:`Union[bool, str]`, `optional`):
            Bearer token (or :obj:`True` to read the stored token) for private repos.
        revision (:obj:`str`, `optional`):
            Branch name, tag, or commit id to download from.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, only look at local cached files.
        subfolder (:obj:`str`, `optional`):
            Subfolder of the repository where the files live.
        repo_type (:obj:`str`, `optional`, defaults to ``"model"``):
            Accepted for API compatibility; currently not forwarded to the
            download helper.

    Returns:
        :obj:`Path`: The parent folder of the first downloaded file — the best
        guess for the directory all files were downloaded to.

    Raises:
        ValueError: If ``filenames`` is empty (the return value indexes the
            first downloaded file, so an empty list cannot be satisfied).
    """
    if not filenames:
        # Guard: without this, `downloaded_paths[0]` below would raise an
        # opaque IndexError instead of a clear error message.
        raise ValueError("`filenames` must contain at least one file to download.")

    if isinstance(path_or_repo_id, Path):
        path_or_repo_id = str(path_or_repo_id)

    downloaded_paths = [
        hf_cached_file(
            path_or_repo_id,
            filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            revision=revision,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        for filename in filenames
    ]

    # we want the folder where the files are downloaded;
    # the best guess is the parent folder of the first file
    return Path(downloaded_paths[0]).parent
|
343 |
+
|
344 |
+
|
345 |
+
def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
    """
    Resolve a model name or directory to a model archive name or directory.

    Args:
        model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
            A model name or directory.

    Returns:
        :obj:`str`: The model archive name or directory.
    """
    # A URL is returned untouched: the caller downloads it later.
    if is_remote_url(model_name_or_dir):
        return model_name_or_dir

    # A local directory or archive file is also used as-is.
    local_candidate = Path(model_name_or_dir)
    if local_candidate.is_dir() or local_candidate.is_file():
        return model_name_or_dir

    # Otherwise assume it is a sapienzanlp model id and map it to its URL.
    return sapienzanlp_model_urls(model_name_or_dir)
|
372 |
+
|
373 |
+
|
374 |
+
def from_cache(
    url_or_filename: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    filenames: Optional[List[str]] = None,
) -> Path:
    """
    Given something that could be either a local path or a URL (or a SapienzaNLP model id),
    determine which one and return a path to the corresponding file.

    Args:
        url_or_filename (:obj:`str` or :obj:`Path`):
            A path to a local file or a URL (or a SapienzaNLP model id).
        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
            Path to a directory in which a downloaded file will be cached.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to re-download the file even if it already exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (:obj:`Union[bool, str]`, `optional`):
            Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
            :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to raise an error if the file to be downloaded is local.
        subfolder (:obj:`str`, `optional`):
            In case the relevant file is in a subfolder of the URL, specify it here.
        filenames (:obj:`List[str]`, `optional`):
            List of filenames to look for in the directory structure.

    Returns:
        :obj:`Path`: Path to the cached file.
    """

    # Map model ids / URLs / local paths to a concrete archive name or path.
    url_or_filename = model_name_or_path_resolver(url_or_filename)

    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR

    if file_exists(url_or_filename):
        # Local path: use it directly.
        logger.info(f"{url_or_filename} is a local path or file")
        output_path = url_or_filename
    elif is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = download_and_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
        )
    else:
        # Fall back to the Hugging Face Hub with a default set of model files.
        if filenames is None:
            filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
        output_path = download_from_hf(
            url_or_filename,
            filenames,
            cache_dir,
            force_download,
            resume_download,
            proxies,
            use_auth_token,
            revision,
            local_files_only,
            subfolder,
        )

    # A directory, or a file that is neither a zip nor a tar archive,
    # needs no extraction step.
    if dir_exists(output_path) or (
        not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
    ):
        return Path(output_path)

    # Path where we extract compressed archives
    # for now it will extract it in the same folder
    # maybe implement extraction in the sapienzanlp folder
    # when using local archive path?
    logger.info("Extracting compressed archive")
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    # already extracted, do not extract
    if (
        os.path.isdir(output_path_extracted)
        and os.listdir(output_path_extracted)
        and not force_download
    ):
        return Path(output_path_extracted)

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        # Start from a clean directory so partial extractions never survive.
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
                zip_file.close()
        elif tarfile.is_tarfile(output_path):
            # NOTE(review): `extractall` on an untrusted tar archive can write
            # outside the target directory (path traversal) — consider the
            # `filter="data"` argument on Python 3.12+; verify threat model.
            tar_file = tarfile.open(output_path)
            tar_file.extractall(output_path_extracted)
            tar_file.close()
        else:
            raise EnvironmentError(
                f"Archive format of {output_path} could not be identified"
            )

    # remove lock file, is it safe?
    os.remove(lock_path)

    return Path(output_path_extracted)
|
519 |
+
|
520 |
+
|
521 |
+
def is_str_a_path(maybe_path: str) -> bool:
    """
    Check if a string is a path.

    Args:
        maybe_path (`str`): The string to check.

    Returns:
        `bool`: `True` if the string is a path, `False` otherwise.
    """
    # A string counts as a path if it exists either as given
    # or when resolved relative to the current working directory.
    candidates = (maybe_path, os.path.join(os.getcwd(), maybe_path))
    return any(Path(candidate).exists() for candidate in candidates)
|
539 |
+
|
540 |
+
|
541 |
+
def relative_to_absolute_path(path: str) -> os.PathLike:
    """
    Convert a relative path to an absolute path.

    Args:
        path (`str`): The relative path to convert.

    Returns:
        `os.PathLike`: The absolute path.
    """
    # Reject anything that does not resolve to an existing location.
    if not is_str_a_path(path):
        raise ValueError(f"{path} is not a path")
    direct = Path(path)
    if direct.exists():
        return direct.absolute()
    cwd_relative = Path(os.path.join(os.getcwd(), path))
    if cwd_relative.exists():
        return cwd_relative.absolute()
    # Unreachable if is_str_a_path and the checks above agree; kept as a
    # defensive fallback for race conditions between the two checks.
    raise ValueError(f"{path} is not a path")
|
558 |
+
|
559 |
+
|
560 |
+
def to_config(object_to_save: Any) -> Dict[str, Any]:
    """
    Convert an object to a dictionary.

    Returns:
        `Dict[str, Any]`: The dictionary representation of the object.
    """

    def _serialize(item):
        # Mappings are converted key by key.
        if isinstance(item, dict):
            return {key: _serialize(value) for key, value in item.items()}

        # Sequences become plain lists (tuples are flattened to lists too).
        if isinstance(item, (list, tuple)):
            return [_serialize(element) for element in item]

        # Any object carrying a __dict__ is turned into a hydra-style config:
        # a `_target_` pointing at its class plus its public attributes.
        if hasattr(item, "__dict__"):
            config = {
                "_target_": f"{item.__class__.__module__}.{item.__class__.__name__}",
            }
            for key, value in item.__dict__.items():
                if not key.startswith("_"):
                    config[key] = _serialize(value)
            return config

        # Primitives (int, str, None, ...) pass through unchanged.
        return item

    return _serialize(object_to_save)
|
592 |
+
|
593 |
+
|
594 |
+
def get_callable_from_string(callable_fn: str) -> Any:
    """
    Get a callable from a string.

    Args:
        callable_fn (`str`):
            The string representation of the callable.

    Returns:
        `Any`: The callable.
    """
    # Everything before the last dot names the module,
    # the remainder names the attribute to fetch from it.
    module_name, function_name = callable_fn.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), function_name)
|
relik/inference/__init__.py
ADDED
File without changes
|
relik/inference/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (183 Bytes). View file
|
|
relik/inference/__pycache__/annotator.cpython-310.pyc
ADDED
Binary file (22.7 kB). View file
|
|
relik/inference/annotator.py
ADDED
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
import hydra
|
8 |
+
import torch
|
9 |
+
from omegaconf import DictConfig, OmegaConf
|
10 |
+
from pprintpp import pformat
|
11 |
+
|
12 |
+
from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
|
13 |
+
from relik.common.log import get_logger
|
14 |
+
from relik.common.upload import get_logged_in_username, upload
|
15 |
+
from relik.common.utils import CONFIG_NAME, from_cache
|
16 |
+
from relik.inference.data.objects import (
|
17 |
+
AnnotationType,
|
18 |
+
RelikOutput,
|
19 |
+
Span,
|
20 |
+
TaskType,
|
21 |
+
Triples,
|
22 |
+
)
|
23 |
+
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
|
24 |
+
from relik.inference.data.splitters.spacy_sentence_splitter import SpacySentenceSplitter
|
25 |
+
from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter
|
26 |
+
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
|
27 |
+
from relik.inference.data.window.manager import WindowManager
|
28 |
+
from relik.reader.data.relik_reader_sample import RelikReaderSample
|
29 |
+
from relik.reader.pytorch_modules.base import RelikReaderBase
|
30 |
+
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
|
31 |
+
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
|
32 |
+
from relik.retriever.indexers.base import BaseDocumentIndex
|
33 |
+
from relik.retriever.indexers.document import Document
|
34 |
+
from relik.retriever.pytorch_modules import PRECISION_MAP
|
35 |
+
from relik.retriever.pytorch_modules.model import GoldenRetriever
|
36 |
+
|
37 |
+
# set tokenizers parallelism to False

# Disable HF tokenizers' parallelism unless the user explicitly set it:
# avoids the fork-after-parallelism deadlock warning.
os.environ["TOKENIZERS_PARALLELISM"] = os.getenv("TOKENIZERS_PARALLELISM", "false")

# Opt-in flag: when true, annotation queries are also logged to a file.
LOG_QUERY = os.getenv("RELIK_LOG_QUERY_ON_FILE", "false").lower() == "true"

logger = get_logger(__name__, level=logging.INFO)
# File logger is only created when query logging is enabled.
file_logger = None
if LOG_QUERY:
    # Log file lives at the repository root (three levels above this module).
    RELIK_LOG_PATH = Path(__file__).parent.parent.parent / "relik.log"
    # create file handler which logs even debug messages
    fh = logging.FileHandler(RELIK_LOG_PATH)
    fh.setLevel(logging.INFO)
    file_logger = get_logger("relik", level=logging.INFO)
    file_logger.addHandler(fh)
|
52 |
+
|
53 |
+
|
54 |
+
class Relik:
|
55 |
+
"""
|
56 |
+
Relik main class. It is a wrapper around a retriever and a reader.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
retriever (:obj:`GoldenRetriever`):
|
60 |
+
The retriever to use.
|
61 |
+
reader (:obj:`RelikReaderBase`):
|
62 |
+
The reader to use.
|
63 |
+
document_index (:obj:`BaseDocumentIndex`, `optional`):
|
64 |
+
The document index to use. If `None`, the retriever's document index will be used.
|
65 |
+
device (`str`, `optional`, defaults to `cpu`):
|
66 |
+
The device to use for both the retriever and the reader.
|
67 |
+
retriever_device (`str`, `optional`, defaults to `None`):
|
68 |
+
The device to use for the retriever. If `None`, the `device` argument will be used.
|
69 |
+
document_index_device (`str`, `optional`, defaults to `None`):
|
70 |
+
The device to use for the document index. If `None`, the `device` argument will be used.
|
71 |
+
reader_device (`str`, `optional`, defaults to `None`):
|
72 |
+
The device to use for the reader. If `None`, the `device` argument will be used.
|
73 |
+
precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `32`):
|
74 |
+
The precision to use for both the retriever and the reader.
|
75 |
+
retriever_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
|
76 |
+
The precision to use for the retriever. If `None`, the `precision` argument will be used.
|
77 |
+
document_index_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
|
78 |
+
The precision to use for the document index. If `None`, the `precision` argument will be used.
|
79 |
+
reader_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
|
80 |
+
The precision to use for the reader. If `None`, the `precision` argument will be used.
|
81 |
+
metadata_fields (`list[str]`, `optional`, defaults to `None`):
|
82 |
+
The fields to add to the candidates for the reader.
|
83 |
+
top_k (`int`, `optional`, defaults to `None`):
|
84 |
+
The number of candidates to retrieve for each window.
|
85 |
+
window_size (`int`, `optional`, defaults to `None`):
|
86 |
+
The size of the window. If `None`, the whole text will be annotated.
|
87 |
+
window_stride (`int`, `optional`, defaults to `None`):
|
88 |
+
The stride of the window. If `None`, there will be no overlap between windows.
|
89 |
+
**kwargs:
|
90 |
+
Additional keyword arguments to pass to the retriever and the reader.
|
91 |
+
"""
|
92 |
+
|
93 |
+
    def __init__(
        self,
        retriever: GoldenRetriever | DictConfig | Dict | None = None,
        reader: RelikReaderBase | DictConfig | None = None,
        device: str | None = None,
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int | str | torch.dtype | None = None,
        retriever_precision: int | str | torch.dtype | None = None,
        document_index_precision: int | str | torch.dtype | None = None,
        reader_precision: int | str | torch.dtype | None = None,
        task: TaskType | str = TaskType.SPAN,
        metadata_fields: list[str] | None = None,
        top_k: int | None = None,
        window_size: int | str | None = None,
        window_stride: int | None = None,
        retriever_kwargs: Dict[str, Any] | None = None,
        reader_kwargs: Dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Build a Relik pipeline from retriever/reader instances or configs.

        See the class docstring for the meaning of each argument.
        `retriever_kwargs`, `reader_kwargs` and `**kwargs` are accepted but
        not used in this constructor body.
        """
        # parse task into a TaskType
        if isinstance(task, str):
            try:
                task = TaskType(task.lower())
            except ValueError:
                raise ValueError(
                    f"Task `{task}` not recognized. "
                    f"Please choose one of {list(TaskType)}."
                )
        self.task = task

        # organize devices: the generic `device` only fills in the
        # component devices that were not set explicitly
        if device is not None:
            if retriever_device is None:
                retriever_device = device
            if document_index_device is None:
                document_index_device = device
            if reader_device is None:
                reader_device = device

        # organize precision: same fallback scheme as for devices
        if precision is not None:
            if retriever_precision is None:
                retriever_precision = precision
            if document_index_precision is None:
                document_index_precision = precision
            if reader_precision is None:
                reader_precision = precision

        # retriever: one (optional) retriever per task type
        self.retriever: Dict[TaskType, GoldenRetriever] = {
            TaskType.SPAN: None,
            TaskType.TRIPLET: None,
        }

        if retriever:
            # check retriever type, it can be a GoldenRetriever, a DictConfig or a Dict
            if not isinstance(retriever, (GoldenRetriever, DictConfig, Dict)):
                raise ValueError(
                    f"`retriever` must be a `GoldenRetriever`, a `DictConfig` or "
                    f"a `Dict`, got `{type(retriever)}`."
                )

            # we need to check weather the DictConfig is a DictConfig for an instance of GoldenRetriever
            # or a primitive Dict
            if isinstance(retriever, DictConfig):
                # then it is probably a primitive Dict
                # (no `_target_` key means it maps task names to configs)
                if "_target_" not in retriever:
                    retriever = OmegaConf.to_container(retriever, resolve=True)
                    # convert the key to TaskType
                    try:
                        retriever = {
                            TaskType(k.lower()): v for k, v in retriever.items()
                        }
                    except ValueError as e:
                        raise ValueError(
                            f"Please choose a valid task type (one of {list(TaskType)}) for each retriever."
                        ) from e

            if isinstance(retriever, Dict):
                # convert the key to TaskType
                # (TaskType(k) is a no-op when k is already a TaskType)
                retriever = {TaskType(k): v for k, v in retriever.items()}
            else:
                # single retriever: assign it to the configured task
                # NOTE(review): with task == BOTH this produces only one entry,
                # so the per-task lookups below would KeyError — confirm callers
                # always pass a dict when task is BOTH.
                retriever = {task: retriever}

            # instantiate each retriever
            if self.task in [TaskType.SPAN, TaskType.BOTH]:
                self.retriever[TaskType.SPAN] = self._instantiate_retriever(
                    retriever[TaskType.SPAN],
                    retriever_device,
                    retriever_precision,
                    None,
                    document_index_device,
                    document_index_precision,
                )
            if self.task in [TaskType.TRIPLET, TaskType.BOTH]:
                self.retriever[TaskType.TRIPLET] = self._instantiate_retriever(
                    retriever[TaskType.TRIPLET],
                    retriever_device,
                    retriever_precision,
                    None,
                    document_index_device,
                    document_index_precision,
                )

            # clean up None retrievers from the dictionary
            self.retriever = {
                task_type: r for task_type, r in self.retriever.items() if r is not None
            }
            # torch compile
            # self.retriever = {task_type: torch.compile(r, backend="onnxrt") for task_type, r in self.retriever.items()}

        # reader: instantiate from config if needed, then freeze for inference
        self.reader: RelikReaderBase | None = None
        if reader:
            reader = (
                hydra.utils.instantiate(
                    reader,
                    device=reader_device,
                    precision=reader_precision,
                )
                if isinstance(reader, DictConfig)
                else reader
            )
            reader.training = False
            reader.eval()
            if reader_device is not None:
                logger.info(f"Moving reader to `{reader_device}`.")
                reader.to(reader_device)
            # only cast when the requested precision differs from the current one
            if reader_precision is not None and reader.precision != PRECISION_MAP[reader_precision]:
                logger.info(
                    f"Setting precision of reader to `{PRECISION_MAP[reader_precision]}`."
                )
                reader.to(PRECISION_MAP[reader_precision])
            self.reader = reader
            # self.reader = torch.compile(self.reader, backend="tvm")

        # windowization stuff: splitter/manager are created lazily elsewhere
        self.tokenizer = SpacyTokenizer(language="en")  # TODO: parametrize?
        self.sentence_splitter: BaseSentenceSplitter | None = None
        self.window_manager: WindowManager | None = None

        if metadata_fields is None:
            metadata_fields = []
        self.metadata_fields = metadata_fields

        # inference params (defaults for __call__ when not overridden)
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride
|
244 |
+
|
245 |
+
    @staticmethod
    def _instantiate_retriever(
        retriever,
        retriever_device,
        retriever_precision,
        document_index,
        document_index_device,
        document_index_precision,
    ):
        """Instantiate (if given as a config) and prepare a retriever for inference.

        Optionally replaces its document index, then moves retriever and index
        to the requested devices/precisions. Returns the ready retriever.
        """
        if not isinstance(retriever, GoldenRetriever):
            # convert to DictConfig and let hydra build the GoldenRetriever
            retriever = hydra.utils.instantiate(
                OmegaConf.create(retriever),
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
        # inference mode: disable training flag and dropout/batch-norm updates
        retriever.training = False
        retriever.eval()
        if document_index is not None:
            if retriever.document_index is not None:
                logger.info(
                    "The Retriever already has a document index, replacing it with the provided one."
                    "If you want to keep using the old one, please do not provide a document index."
                )
            retriever.document_index = document_index
        # we override the device and the precision of the document index if provided
        # NOTE(review): assumes retriever.document_index is not None here when a
        # device/precision override is requested — verify against callers.
        if document_index_device is not None:
            logger.info(f"Moving document index to `{document_index_device}`.")
            retriever.document_index.to(document_index_device)
        if document_index_precision is not None:
            logger.info(
                f"Setting precision of document index to `{PRECISION_MAP[document_index_precision]}`."
            )
            retriever.document_index.to(PRECISION_MAP[document_index_precision])
        # retriever.document_index = document_index
        # now we can move the retriever to the right device and set the precision
        if retriever_device is not None:
            logger.info(f"Moving retriever to `{retriever_device}`.")
            retriever.to(retriever_device)
        if retriever_precision is not None:
            logger.info(
                f"Setting precision of retriever to `{PRECISION_MAP[retriever_precision]}`."
            )
            retriever.to(PRECISION_MAP[retriever_precision])
        return retriever
|
292 |
+
|
293 |
+
    def __call__(
        self,
        text: str | List[str] | None = None,
        windows: List[RelikReaderSample] | None = None,
        candidates: List[str]
        | List[Document]
        | Dict[TaskType, List[Document]]
        | None = None,
        mentions: List[List[int]] | List[List[List[int]]] | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        is_split_into_words: bool = False,
        retriever_batch_size: int | None = 32,
        reader_batch_size: int | None = 32,
        return_also_windows: bool = False,
        annotation_type: str | AnnotationType = AnnotationType.CHAR,
        progress_bar: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            windows (`list[RelikReaderSample]`, `optional`, defaults to `None`):
                Pre-built windows to annotate instead of `text`. At least one of
                `text` and `windows` must be provided.
            candidates (`list[str]`, `list[Document]`, `optional`, defaults to `None`):
                The candidates to use for the reader. If `None`, the candidates will be
                retrieved from the retriever.
            mentions (`list[list[int]]` or `list[list[list[int]]]`, `optional`, defaults to `None`):
                The mentions to use for the reader. If `None`, the mentions will be
                predicted by the reader.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window.
                Falls back to `self.top_k`, then to 100.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the whole text will be annotated.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, there will be no overlap between windows.
            is_split_into_words (`bool`, `optional`, defaults to `False`):
                Whether the input text is already split into words.
            retriever_batch_size (`int`, `optional`, defaults to `None`):
                The batch size to use for the retriever. The whole input is the batch for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `None`):
                The batch size to use for the reader. The whole input is the batch for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`):
                The type of annotation to return. If `char`, the spans will be in terms of
                character offsets. If `word`, the spans will be in terms of word offsets.
            progress_bar (`bool`, `optional`, defaults to `False`):
                Whether to show progress bars during retrieval and reading.
            **kwargs:
                Additional keyword arguments to pass to the retriever and the reader.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects will be returned.
        """

        if text is None and windows is None:
            raise ValueError(
                "Either `text` or `windows` must be provided. Both are `None`."
            )

        # normalize the annotation type to the enum
        if isinstance(annotation_type, str):
            try:
                annotation_type = AnnotationType(annotation_type)
            except ValueError:
                raise ValueError(
                    f"Annotation type {annotation_type} not recognized. "
                    f"Please choose one of {list(AnnotationType)}."
                )

        # fall back to the instance-level defaults when not provided per call
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        if text:
            # normalize a single text (and its mentions) to a batch of one
            if isinstance(text, str):
                text = [text]
                if mentions is not None:
                    mentions = [mentions]
            if file_logger is not None:
                file_logger.info("Annotating the following text:")
                for t in text:
                    file_logger.info(f"  {t}")

        # lazily build the window manager with a splitter matching `window_size`
        if self.window_manager is None:
            if window_size == "none":
                self.sentence_splitter = BlankSentenceSplitter()
            elif window_size == "sentence":
                self.sentence_splitter = SpacySentenceSplitter()
            else:
                self.sentence_splitter = WindowSentenceSplitter(
                    window_size=window_size, window_stride=window_stride
                )
            self.window_manager = WindowManager(
                self.tokenizer, self.sentence_splitter
            )

        # sanity check: a numeric window must not be smaller than its stride
        if (
            window_size not in ["sentence", "none"]
            and window_stride is not None
            and window_size < window_stride
        ):
            raise ValueError(
                f"Window size ({window_size}) must be greater than window stride ({window_stride})"
            )

        if windows is None:
            # no windows provided: split the input text into windows
            windows, blank_windows = self.window_manager.create_windows(
                text,
                window_size,
                window_stride,
                is_split_into_words=is_split_into_words,
                mentions=mentions
            )
        else:
            # windows were provided: recover the per-document text from them
            blank_windows = []
            text = {w.doc_id: w.text for w in windows}

        if candidates is not None and any(
            r is not None for r in self.retriever.values()
        ):
            logger.info(
                "Both candidates and a retriever were provided. "
                "Retriever will be ignored."
            )

        windows_candidates = {TaskType.SPAN: None, TaskType.TRIPLET: None}
        if candidates is not None:
            # again, check if candidates is a dict; if not, key it by the current task
            if isinstance(candidates, Dict):
                if self.task not in candidates:
                    raise ValueError(
                        f"Task `{self.task}` not found in `candidates`."
                        f"Please choose one of {list(TaskType)}."
                    )
            else:
                candidates = {self.task: candidates}

            for task_type, _candidates in candidates.items():
                if isinstance(_candidates, list):
                    # wrap plain strings into Document objects, one list per window
                    _candidates = [
                        [
                            c if isinstance(c, Document) else Document(c)
                            for c in _candidates[w.doc_id]
                        ]
                        for w in windows
                    ]
                windows_candidates[task_type] = _candidates

        else:
            # retrieve candidates first
            if self.retriever is None:
                raise ValueError(
                    "No retriever was provided, please provide a retriever or candidates."
                )
            start_retr = time.time()
            for task_type, retriever in self.retriever.items():
                retriever_out = retriever.retrieve(
                    [w.text for w in windows],
                    text_pair=[w.doc_topic.text if w.doc_topic is not None else None for w in windows],
                    k=top_k,
                    batch_size=retriever_batch_size,
                    progress_bar=progress_bar,
                    **kwargs,
                )
                windows_candidates[task_type] = [
                    [p.document for p in predictions] for predictions in retriever_out
                ]
            end_retr = time.time()
            logger.info(f"Retrieval took {end_retr - start_retr} seconds.")

        # clean up None's
        windows_candidates = {
            t: c for t, c in windows_candidates.items() if c is not None
        }

        # add passage to the windows
        # NOTE(review): the inner loop variable shadows the `candidates` parameter,
        # which is no longer needed at this point but makes the code harder to follow
        for task_type, task_candidates in windows_candidates.items():
            for window, candidates in zip(windows, task_candidates):
                # construct the candidates for the reader
                formatted_candidates = []
                for candidate in candidates:
                    # candidate text plus the configured metadata fields, concatenated
                    window_candidate_text = candidate.text
                    for field in self.metadata_fields:
                        window_candidate_text += f"{candidate.metadata.get(field, '')}"
                    formatted_candidates.append(window_candidate_text)
                # create a member for the windows that is named like the task
                setattr(window, f"{task_type.value}_candidates", formatted_candidates)

        # blank windows carry empty candidates and empty predictions
        for task_type, task_candidates in windows_candidates.items():
            for window in blank_windows:
                setattr(window, f"{task_type.value}_candidates", [])
                setattr(window, "predicted_spans", [])
                setattr(window, "predicted_triples", [])
        if self.reader is not None:
            start_read = time.time()
            windows = self.reader.read(
                samples=windows,
                max_batch_size=reader_batch_size,
                annotation_type=annotation_type,
                progress_bar=progress_bar,
                **kwargs,
            )
            end_read = time.time()
            logger.info(f"Reading took {end_read - start_read} seconds.")
            # TODO: check merging behavior without a reader
            # do we want to merge windows if there is no reader?

            if self.window_size is not None and self.window_size not in ["sentence", "none"]:
                start_w = time.time()
                windows = windows + blank_windows
                windows.sort(key=lambda x: (x.doc_id, x.offset))
                merged_windows = self.window_manager.merge_windows(windows)
                end_w = time.time()
                logger.info(f"Merging took {end_w - start_w} seconds.")
            else:
                merged_windows = windows
        else:
            windows = windows + blank_windows
            windows.sort(key=lambda x: (x.doc_id, x.offset))
            merged_windows = windows

        # transform predictions into RelikOutput objects
        output = []
        for w in merged_windows:
            span_labels = []
            triples_labels = []
            # span extraction should always be present
            if getattr(w, "predicted_spans", None) is not None:
                span_labels = sorted(
                    [
                        Span(start=ss, end=se, label=sl, text=text[w.doc_id][ss:se])
                        if annotation_type == AnnotationType.CHAR
                        else Span(start=ss, end=se, label=sl, text=w.words[ss:se])
                        for ss, se, sl in w.predicted_spans
                    ],
                    key=lambda x: x.start,
                )
            # triple extraction is optional, if here add it
            # (subject/object are indices into the sorted `span_labels` list)
            if getattr(w, "predicted_triples", None) is not None:
                triples_labels = [
                    Triples(
                        subject=span_labels[subj],
                        label=label,
                        object=span_labels[obj],
                        confidence=conf,
                    )
                    for subj, label, obj, conf in w.predicted_triples
                ]
            # create the output
            # NOTE(review): candidates are looked up back in the index by their
            # formatted text — presumably texts are unique per index; confirm
            sample_output = RelikOutput(
                text=text[w.doc_id],
                tokens=w.words,
                spans=span_labels,
                triples=triples_labels,
                candidates={
                    task_type: [
                        r.document_index.documents.get_document_from_text(c)
                        for c in getattr(w, f"{task_type.value}_candidates", [])
                        if r.document_index.documents.get_document_from_text(c) is not None
                    ]
                    for task_type, r in self.retriever.items()
                },
            )
            output.append(sample_output)

        # add windows to the output if requested
        # do we want to force windows to be returned if there is no reader?
        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output
|
575 |
+
|
576 |
+
@classmethod
|
577 |
+
def from_pretrained(
|
578 |
+
cls,
|
579 |
+
model_name_or_dir: Union[str, os.PathLike],
|
580 |
+
config_file_name: str = CONFIG_NAME,
|
581 |
+
*args,
|
582 |
+
**kwargs,
|
583 |
+
) -> "Relik":
|
584 |
+
"""
|
585 |
+
Instantiate a `Relik` from a pretrained model.
|
586 |
+
|
587 |
+
Args:
|
588 |
+
model_name_or_dir (`str` or `os.PathLike`):
|
589 |
+
The name or path of the model to load.
|
590 |
+
config_file_name (`str`, `optional`, defaults to `config.yaml`):
|
591 |
+
The name of the configuration file to load.
|
592 |
+
*args:
|
593 |
+
Additional positional arguments to pass to `OmegaConf.merge`.
|
594 |
+
**kwargs:
|
595 |
+
Additional keyword arguments to pass to `OmegaConf.merge`.
|
596 |
+
|
597 |
+
Returns:
|
598 |
+
`Relik`:
|
599 |
+
The instantiated `Relik`.
|
600 |
+
|
601 |
+
"""
|
602 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
603 |
+
force_download = kwargs.pop("force_download", False)
|
604 |
+
|
605 |
+
model_dir = from_cache(
|
606 |
+
model_name_or_dir,
|
607 |
+
filenames=[config_file_name],
|
608 |
+
cache_dir=cache_dir,
|
609 |
+
force_download=force_download,
|
610 |
+
)
|
611 |
+
|
612 |
+
config_path = model_dir / config_file_name
|
613 |
+
if not config_path.exists():
|
614 |
+
raise FileNotFoundError(
|
615 |
+
f"Model configuration file not found at {config_path}."
|
616 |
+
)
|
617 |
+
|
618 |
+
# overwrite config with config_kwargs
|
619 |
+
config = OmegaConf.load(config_path)
|
620 |
+
# if kwargs is not None:
|
621 |
+
config = OmegaConf.merge(config, OmegaConf.create(kwargs))
|
622 |
+
# do we want to print the config? I like it
|
623 |
+
logger.info(f"Loading Relik from {model_name_or_dir}")
|
624 |
+
logger.info(pformat(OmegaConf.to_container(config)))
|
625 |
+
|
626 |
+
# load relik from config
|
627 |
+
relik = hydra.utils.instantiate(config, _recursive_=False, *args)
|
628 |
+
|
629 |
+
return relik
|
630 |
+
|
631 |
+
    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        retriever_model_id: Optional[str] = None,
        reader_model_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            retriever_model_id (`Optional[str]`, `optional`):
                The id of the retriever model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            reader_model_id (`Optional[str]`, `optional`):
                The id of the reader model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            **kwargs:
                Additional keyword arguments to pass to `OmegaConf.save`.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # per-task bookkeeping of where each retriever component is saved,
        # used below when assembling the default config
        retrievers_names: Dict[TaskType, Dict | None] = {
            TaskType.SPAN: {
                "question_encoder_name": None,
                "passage_encoder_name": None,
                "document_index_name": None,
            },
            TaskType.TRIPLET: {
                "question_encoder_name": None,
                "passage_encoder_name": None,
                "document_index_name": None,
            },
        }

        if save_weights:
            # save weights
            # retriever
            model_id = model_id or output_dir.name
            retriever_model_id = retriever_model_id or f"retriever-{model_id}"
            for task_type, retriever in self.retriever.items():
                if retriever is None:
                    continue
                # component names are derived from the (task-specific) retriever id
                task_retriever_model_id = f"{retriever_model_id}-{task_type.value}"
                question_encoder_name = f"{task_retriever_model_id}-question-encoder"
                passage_encoder_name = f"{task_retriever_model_id}-passage-encoder"
                document_index_name = f"{task_retriever_model_id}-index"
                logger.info(
                    f"Saving retriever to {output_dir / task_retriever_model_id}"
                )
                retriever.save_pretrained(
                    output_dir / task_retriever_model_id,
                    question_encoder_name=question_encoder_name,
                    passage_encoder_name=passage_encoder_name,
                    document_index_name=document_index_name,
                    push_to_hub=push_to_hub,
                    organization=organization,
                    **kwargs,
                )
                retrievers_names[task_type] = {
                    "reader_model_id": task_retriever_model_id,
                    "question_encoder_name": question_encoder_name,
                    "passage_encoder_name": passage_encoder_name,
                    "document_index_name": document_index_name,
                }

            # reader
            reader_model_id = reader_model_id or f"reader-{model_id}"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                **kwargs,
            )

            if push_to_hub:
                user = organization or get_logged_in_username()
                # we need to update the config with the model ids that will
                # result from the push to hub
                for task_type, retriever_names in retrievers_names.items():
                    retriever_names[
                        "question_encoder_name"
                    ] = f"{user}/{retriever_names['question_encoder_name']}"
                    retriever_names[
                        "passage_encoder_name"
                    ] = f"{user}/{retriever_names['passage_encoder_name']}"
                    retriever_names[
                        "document_index_name"
                    ] = f"{user}/{retriever_names['document_index_name']}"
                    # question_encoder_name = f"{user}/{question_encoder_name}"
                    # passage_encoder_name = f"{user}/{passage_encoder_name}"
                    # document_index_name = f"{user}/{document_index_name}"
                reader_model_id = f"{user}/{reader_model_id}"
            else:
                # local save: point the config entries at the on-disk paths
                for task_type, retriever_names in retrievers_names.items():
                    retriever_names["question_encoder_name"] = (
                        output_dir / retriever_names["question_encoder_name"]
                    )
                    retriever_names["passage_encoder_name"] = (
                        output_dir / retriever_names["passage_encoder_name"]
                    )
                    retriever_names["document_index_name"] = (
                        output_dir / retriever_names["document_index_name"]
                    )
                reader_model_id = output_dir / reader_model_id
        else:
            # save config only: reference the components by their original names
            for task_type, retriever_names in retrievers_names.items():
                retriever = self.retriever.get(task_type, None)
                if retriever is None:
                    continue
                retriever_names[
                    "question_encoder_name"
                ] = retriever.question_encoder.name_or_path
                retriever_names[
                    "passage_encoder_name"
                ] = retriever.passage_encoder.name_or_path
                retriever_names[
                    "document_index_name"
                ] = retriever.document_index.name_or_path

            reader_model_id = self.reader.name_or_path

        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                config["retriever"] = {}
                for task_type, retriever in self.retriever.items():
                    if retriever is None:
                        continue
                    config["retriever"][task_type.value] = {
                        "_target_": f"{retriever.__class__.__module__}.{retriever.__class__.__name__}",
                    }
                    if retriever.question_encoder is not None:
                        config["retriever"][task_type.value][
                            "question_encoder"
                        ] = retrievers_names[task_type]["question_encoder_name"]
                    # the passage encoder is only written out when it is a
                    # distinct model from the question encoder
                    if (
                        retriever.passage_encoder is not None
                        and not retriever.passage_encoder_is_question_encoder
                    ):
                        config["retriever"][task_type.value][
                            "passage_encoder"
                        ] = retrievers_names[task_type]["passage_encoder_name"]
                    if retriever.document_index is not None:
                        config["retriever"][task_type.value][
                            "document_index"
                        ] = retrievers_names[task_type]["document_index_name"]
            if self.reader is not None:
                config["reader"] = {
                    "_target_": f"{self.reader.__class__.__module__}.{self.reader.__class__.__name__}",
                    "transformer_model": reader_model_id,
                }

            # these are model-specific and should be saved
            config["task"] = self.task
            config["metadata_fields"] = self.metadata_fields
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        logger.info(pformat(config))
        OmegaConf.save(config, output_dir / config_file_name)

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(
                output_dir,
                model_id,
                filenames=[config_file_name],
                organization=organization,
                repo_name=repo_name,
            )
|
relik/inference/data/__init__.py
ADDED
File without changes
|
relik/inference/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (188 Bytes). View file
|
|
relik/inference/data/__pycache__/objects.cpython-310.pyc
ADDED
Binary file (3.24 kB). View file
|
|
relik/inference/data/objects.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Dict, List, NamedTuple, Optional
|
5 |
+
|
6 |
+
from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
|
7 |
+
from relik.retriever.indexers.document import Document
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class Word:
|
12 |
+
"""
|
13 |
+
A word representation that includes text, index in the sentence, POS tag, lemma,
|
14 |
+
dependency relation, and similar information.
|
15 |
+
|
16 |
+
# Parameters
|
17 |
+
text : `str`, optional
|
18 |
+
The text representation.
|
19 |
+
index : `int`, optional
|
20 |
+
The word offset in the sentence.
|
21 |
+
lemma : `str`, optional
|
22 |
+
The lemma of this word.
|
23 |
+
pos : `str`, optional
|
24 |
+
The coarse-grained part of speech of this word.
|
25 |
+
dep : `str`, optional
|
26 |
+
The dependency relation for this word.
|
27 |
+
|
28 |
+
input_id : `int`, optional
|
29 |
+
Integer representation of the word, used to pass it to a model.
|
30 |
+
token_type_id : `int`, optional
|
31 |
+
Token type id used by some transformers.
|
32 |
+
attention_mask: `int`, optional
|
33 |
+
Attention mask used by transformers, indicates to the model which tokens should
|
34 |
+
be attended to, and which should not.
|
35 |
+
"""
|
36 |
+
|
37 |
+
text: str
|
38 |
+
i: int
|
39 |
+
idx: Optional[int] = None
|
40 |
+
idx_end: Optional[int] = None
|
41 |
+
# preprocessing fields
|
42 |
+
lemma: Optional[str] = None
|
43 |
+
pos: Optional[str] = None
|
44 |
+
dep: Optional[str] = None
|
45 |
+
head: Optional[int] = None
|
46 |
+
|
47 |
+
def __str__(self):
|
48 |
+
return self.text
|
49 |
+
|
50 |
+
def __repr__(self):
|
51 |
+
return self.__str__()
|
52 |
+
|
53 |
+
|
54 |
+
class Span(NamedTuple):
|
55 |
+
start: int
|
56 |
+
end: int
|
57 |
+
label: str
|
58 |
+
text: str
|
59 |
+
|
60 |
+
|
61 |
+
class Triples(NamedTuple):
|
62 |
+
subject: Span
|
63 |
+
label: str
|
64 |
+
object: Span
|
65 |
+
confidence: float
|
66 |
+
|
67 |
+
@dataclass
|
68 |
+
class RelikOutput:
|
69 |
+
text: str
|
70 |
+
tokens: List[str]
|
71 |
+
spans: List[Span]
|
72 |
+
triples: List[Triples]
|
73 |
+
candidates: Dict[TaskType, List[Document]]
|
74 |
+
windows: Optional[List[RelikReaderSample]] = None
|
75 |
+
|
76 |
+
|
77 |
+
from enum import Enum
|
78 |
+
|
79 |
+
|
80 |
+
class AnnotationType(Enum):
|
81 |
+
CHAR = "char"
|
82 |
+
WORD = "word"
|
83 |
+
|
84 |
+
|
85 |
+
class TaskType(Enum):
|
86 |
+
SPAN = "span"
|
87 |
+
TRIPLET = "triplet"
|
88 |
+
BOTH = "both"
|
relik/inference/data/splitters/__init__.py
ADDED
File without changes
|
relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (198 Bytes). View file
|
|
relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc
ADDED
Binary file (2.38 kB). View file
|
|
relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc
ADDED
Binary file (1.6 kB). View file
|
|
relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc
ADDED
Binary file (5.31 kB). View file
|
|
relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc
ADDED
Binary file (2.49 kB). View file
|
|
relik/inference/data/splitters/base_sentence_splitter.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
|
4 |
+
class BaseSentenceSplitter:
    """
    Abstract interface for objects that split raw text into sentences.
    """

    def __call__(self, *args, **kwargs):
        """
        Delegates to :meth:`split_sentences`.
        """
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Split a ``text`` paragraph into a list of sentence strings.
        Subclasses must override this method.
        """
        raise NotImplementedError

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """
        Split every text in ``texts`` by calling :meth:`split_sentences` on each.
        """
        return [self.split_sentences(single_text) for single_text in texts]

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        # a non-sequence can never be a batch
        if not isinstance(texts, (list, tuple)):
            return False
        # untokenized input: any sequence of texts counts as a batch
        if not is_split_into_words:
            return True
        # pre-tokenized input: batched only if the first element is itself a sequence
        return bool(texts and isinstance(texts[0], (list, tuple)))
|
relik/inference/data/splitters/blank_sentence_splitter.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
|
4 |
+
class BlankSentenceSplitter:
    """
    A trivial splitter that treats each input text as one single sentence.
    """

    def __call__(self, *args, **kwargs):
        """
        Delegates to :meth:`split_sentences`.
        """
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Return ``text`` unchanged, wrapped in a single-element list.
        ``max_len`` is accepted for interface compatibility and ignored.
        """
        return [text]

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """
        Wrap every text in ``texts`` in its own single-sentence list.
        """
        return [self.split_sentences(single_text) for single_text in texts]
|
relik/inference/data/splitters/spacy_sentence_splitter.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Iterable, List, Optional, Union
|
2 |
+
|
3 |
+
import spacy
|
4 |
+
|
5 |
+
from relik.inference.data.objects import Word
|
6 |
+
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
|
7 |
+
from relik.inference.data.tokenizers.spacy_tokenizer import load_spacy
|
8 |
+
|
9 |
+
# Maps ISO 639-1 language codes to the spaCy pipeline used for sentence
# splitting. Most languages share the multilingual `xx_sent_ud_sm` model;
# a few (el, ja, pl, zh) use a dedicated small pipeline. Languages not
# listed here fall back to `spacy.blank` in the splitter.
SPACY_LANGUAGE_MAPPER = {
    "cs": "xx_sent_ud_sm",
    "da": "xx_sent_ud_sm",
    "de": "xx_sent_ud_sm",
    "fa": "xx_sent_ud_sm",
    "fi": "xx_sent_ud_sm",
    "fr": "xx_sent_ud_sm",
    "el": "el_core_news_sm",
    "en": "xx_sent_ud_sm",
    "es": "xx_sent_ud_sm",
    "ga": "xx_sent_ud_sm",
    "hr": "xx_sent_ud_sm",
    "id": "xx_sent_ud_sm",
    "it": "xx_sent_ud_sm",
    "ja": "ja_core_news_sm",
    "lv": "xx_sent_ud_sm",
    "lt": "xx_sent_ud_sm",
    "mr": "xx_sent_ud_sm",
    "nb": "xx_sent_ud_sm",
    "nl": "xx_sent_ud_sm",
    "no": "xx_sent_ud_sm",
    "pl": "pl_core_news_sm",
    "pt": "xx_sent_ud_sm",
    "ro": "xx_sent_ud_sm",
    "ru": "xx_sent_ud_sm",
    "sk": "xx_sent_ud_sm",
    "sr": "xx_sent_ud_sm",
    "sv": "xx_sent_ud_sm",
    "te": "xx_sent_ud_sm",
    "vi": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
}
|
41 |
+
|
42 |
+
|
43 |
+
class SpacySentenceSplitter(BaseSentenceSplitter):
    """
    A :obj:`SentenceSplitter` that uses spaCy's built-in sentence boundary detection.

    Args:
        language (:obj:`str`, optional, defaults to :obj:`en`):
            Language of the text to tokenize.
        model_type (:obj:`str`, optional, defaults to :obj:`statistical`):
            Three different type of sentence splitter:
            - ``dependency``: sentence splitter uses a dependency parse to detect sentence boundaries,
                slow, but accurate.
            - ``statistical``: uses spaCy's lightweight ``senter`` pipe.
            - ``rule_based``: It's fast and has a small memory footprint, since it uses punctuation to detect
                sentence boundaries.
    """

    def __init__(self, language: str = "en", model_type: str = "statistical") -> None:
        # The dependency parser is only required for `dependency` splitting.
        needs_parser = model_type == "dependency"
        if language in SPACY_LANGUAGE_MAPPER:
            self.spacy = load_spacy(SPACY_LANGUAGE_MAPPER[language], parse=needs_parser)
        else:
            # No pre-trained model for this language: fall back to a blank
            # pipeline, which only supports rule-based splitting.
            self.spacy = spacy.blank(language)
            model_type = "rule_based"
        if model_type == "dependency":
            # Nothing to do: the parser was requested at model load time above.
            pass
        elif model_type == "statistical":
            if not self.spacy.has_pipe("senter"):
                self.spacy.enable_pipe("senter")
        elif model_type == "rule_based":
            # `sentencizer` is spaCy's built-in rule-based (punctuation-driven)
            # sentence boundary detector.
            if not self.spacy.has_pipe("sentencizer"):
                self.spacy.add_pipe("sentencizer")
        else:
            raise ValueError(
                f"type {model_type} not supported. Choose between `dependency`, `statistical` or `rule_based`"
            )

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        max_length: Optional[int] = None,
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[str], List[List[str]]]:
        """
        Split the input into sentences using SpaCy models.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to split. It can be a single string, a batch of string and pre-tokenized strings.
            max_length (:obj:`int`, optional, defaults to :obj:`None`):
                Maximum length of a single sentence. Longer sentences are chunked.

        Returns:
            :obj:`List[List[str]]`: The input doc split into sentences.
        """
        if self.check_is_batched(texts, is_split_into_words):
            return self.split_sentences_batch(texts)
        return self.split_sentences(texts, max_length)

    @staticmethod
    def chunked(iterable, n: int) -> Iterable[List[Any]]:
        """
        Chunk a sequence into pieces of at most ``n`` items.

        Args:
            iterable (:obj:`List[Any]`):
                Sequence to chunk.
            n (:obj:`int`):
                Size of each chunk.

        Returns:
            :obj:`Iterable[List[Any]]`: The input chunked into n sized pieces.
        """
        chunks = []
        start = 0
        while start < len(iterable):
            chunks.append(iterable[start : start + n])
            start += n
        return chunks

    def split_sentences(
        self, text: str | List[Word], max_length: Optional[int] = None, *args, **kwargs
    ) -> List[str]:
        """
        Split `text` into smaller sentences.

        Args:
            text (:obj:`str`):
                Text to split.
            max_length (:obj:`int`, optional, defaults to :obj:`None`):
                Maximum length of a single sentence. Sentences longer than
                `max_length` are chunked into multiple pieces.

        Returns:
            :obj:`List[str]`: The input text split into sentences.
        """
        sentences = list(self.spacy(text).sents)
        if max_length is not None and max_length > 0:
            rechunked = []
            for sentence in sentences:
                rechunked.extend(self.chunked(sentence, max_length))
            sentences = rechunked
        return sentences
|
relik/inference/data/splitters/window_based_splitter.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
|
4 |
+
|
5 |
+
|
6 |
+
class WindowSentenceSplitter(BaseSentenceSplitter):
    """
    A :obj:`WindowSentenceSplitter` that splits a text into windows of a given size.
    """

    def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None:
        super().__init__()
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[str], List[List[str]]]:
        """
        Split the input into fixed-size, possibly overlapping, token windows.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to split. It can be a single string, a batch of string and pre-tokenized strings.

        Returns:
            :obj:`List[List[str]]`: The input doc split into windows.
        """
        return self.split_sentences(texts)

    def split_sentences(self, text: str | List, *args, **kwargs) -> List[List]:
        """
        Split `text` into windows of ``window_size`` tokens, advancing by
        ``window_stride`` tokens at each step.

        Args:
            text (:obj:`str`):
                Text to split.

        Returns:
            :obj:`List[str]`: The input text split into windows.
        """
        if isinstance(text, str):
            text = text.split()
        windows = []
        for start in range(0, len(text), self.window_stride):
            # If the last stride is smaller than the window size, extend the
            # final window backwards with tokens from the previous window.
            if start != 0 and start + self.window_size > len(text):
                overflow = start + self.window_size - len(text)
                if overflow >= self.window_stride:
                    # this window would add nothing beyond the previous one
                    break
                start -= overflow
            end = min(start + self.window_size, len(text))
            windows.append([text[idx] for idx in range(start, end)])
        return windows
|
relik/inference/data/tokenizers/__init__.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Maps both short ISO 639-1 language codes and full spaCy pipeline names to
# the spaCy model name to load for tokenization. Full pipeline names map to
# themselves so callers can pass either form.
SPACY_LANGUAGE_MAPPER = {
    "ca": "ca_core_news_sm",
    "da": "da_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "nl": "nl_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "xx": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
    "ca_core_news_sm": "ca_core_news_sm",
    "ca_core_news_md": "ca_core_news_md",
    "ca_core_news_lg": "ca_core_news_lg",
    "ca_core_news_trf": "ca_core_news_trf",
    "da_core_news_sm": "da_core_news_sm",
    "da_core_news_md": "da_core_news_md",
    "da_core_news_lg": "da_core_news_lg",
    "da_core_news_trf": "da_core_news_trf",
    "de_core_news_sm": "de_core_news_sm",
    "de_core_news_md": "de_core_news_md",
    "de_core_news_lg": "de_core_news_lg",
    "de_dep_news_trf": "de_dep_news_trf",
    "el_core_news_sm": "el_core_news_sm",
    "el_core_news_md": "el_core_news_md",
    "el_core_news_lg": "el_core_news_lg",
    "en_core_web_sm": "en_core_web_sm",
    "en_core_web_md": "en_core_web_md",
    "en_core_web_lg": "en_core_web_lg",
    "en_core_web_trf": "en_core_web_trf",
    "es_core_news_sm": "es_core_news_sm",
    "es_core_news_md": "es_core_news_md",
    "es_core_news_lg": "es_core_news_lg",
    "es_dep_news_trf": "es_dep_news_trf",
    "fr_core_news_sm": "fr_core_news_sm",
    "fr_core_news_md": "fr_core_news_md",
    "fr_core_news_lg": "fr_core_news_lg",
    "fr_dep_news_trf": "fr_dep_news_trf",
    "it_core_news_sm": "it_core_news_sm",
    "it_core_news_md": "it_core_news_md",
    "it_core_news_lg": "it_core_news_lg",
    "ja_core_news_sm": "ja_core_news_sm",
    "ja_core_news_md": "ja_core_news_md",
    "ja_core_news_lg": "ja_core_news_lg",
    "ja_dep_news_trf": "ja_dep_news_trf",
    "lt_core_news_sm": "lt_core_news_sm",
    "lt_core_news_md": "lt_core_news_md",
    "lt_core_news_lg": "lt_core_news_lg",
    "mk_core_news_sm": "mk_core_news_sm",
    "mk_core_news_md": "mk_core_news_md",
    "mk_core_news_lg": "mk_core_news_lg",
    "nb_core_news_sm": "nb_core_news_sm",
    "nb_core_news_md": "nb_core_news_md",
    "nb_core_news_lg": "nb_core_news_lg",
    "nl_core_news_sm": "nl_core_news_sm",
    "nl_core_news_md": "nl_core_news_md",
    "nl_core_news_lg": "nl_core_news_lg",
    "pl_core_news_sm": "pl_core_news_sm",
    "pl_core_news_md": "pl_core_news_md",
    "pl_core_news_lg": "pl_core_news_lg",
    "pt_core_news_sm": "pt_core_news_sm",
    "pt_core_news_md": "pt_core_news_md",
    "pt_core_news_lg": "pt_core_news_lg",
    "ro_core_news_sm": "ro_core_news_sm",
    "ro_core_news_md": "ro_core_news_md",
    "ro_core_news_lg": "ro_core_news_lg",
    "ru_core_news_sm": "ru_core_news_sm",
    "ru_core_news_md": "ru_core_news_md",
    "ru_core_news_lg": "ru_core_news_lg",
    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
    "xx_sent_ud_sm": "xx_sent_ud_sm",
    "zh_core_web_sm": "zh_core_web_sm",
    "zh_core_web_md": "zh_core_web_md",
    "zh_core_web_lg": "zh_core_web_lg",
    "zh_core_web_trf": "zh_core_web_trf",
}

# NOTE(review): this import is deliberately placed after the mapper —
# `spacy_tokenizer` imports SPACY_LANGUAGE_MAPPER from this package, so
# moving it to the top of the file would create a circular import.
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
|
relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.31 kB). View file
|
|
relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc
ADDED
Binary file (3.13 kB). View file
|
|
relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc
ADDED
Binary file (6.55 kB). View file
|
|
relik/inference/data/tokenizers/base_tokenizer.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
from relik.inference.data.objects import Word
|
4 |
+
|
5 |
+
|
6 |
+
class BaseTokenizer:
    """
    A :obj:`Tokenizer` splits strings of text into single words, optionally adds
    pos tags and perform lemmatization.
    """

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize(self, text: str) -> List[Word]:
        """
        Implements splitting words into tokens.

        Args:
            text (:obj:`str`):
                Text to tokenize.

        Returns:
            :obj:`List[Word]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
        """
        Implements batch splitting words into tokens.

        Args:
            texts (:obj:`List[str]`):
                Batch of text to tokenize.

        Returns:
            :obj:`List[List[Word]]`: The input batch tokenized in single words.
        """
        return list(map(self.tokenize, texts))

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        if not isinstance(texts, (list, tuple)):
            return False
        if not is_split_into_words:
            # a list/tuple of raw strings is already a batch
            return True
        # pre-tokenized input is batched only when it is a non-empty
        # list/tuple of lists/tuples
        return bool(texts and isinstance(texts[0], (list, tuple)))
|
relik/inference/data/tokenizers/spacy_tokenizer.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from copy import deepcopy
|
3 |
+
from typing import Dict, List, Tuple, Union, Any
|
4 |
+
|
5 |
+
import spacy
|
6 |
+
|
7 |
+
# from ipa.common.utils import load_spacy
|
8 |
+
from spacy.cli.download import download as spacy_download
|
9 |
+
from spacy.tokens import Doc
|
10 |
+
|
11 |
+
from relik.common.log import get_logger
|
12 |
+
from relik.inference.data.objects import Word
|
13 |
+
from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
|
14 |
+
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
|
15 |
+
|
16 |
+
# Module-level logger for this tokenizer module.
logger = get_logger(level=logging.DEBUG)

# Spacy and Stanza stuff

# Cache of already-loaded spaCy pipelines, keyed by
# (language, pos_tags, lemma, parse, split_on_spaces) so each
# configuration is loaded at most once per process (see ``load_spacy``).
LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}
|
21 |
+
|
22 |
+
|
23 |
+
def load_spacy(
    language: str,
    pos_tags: bool = False,
    lemma: bool = False,
    parse: bool = False,
    split_on_spaces: bool = False,
) -> spacy.Language:
    """
    Download and load spacy model.

    Args:
        language (:obj:`str`, defaults to :obj:`en`):
            Language of the text to tokenize.
        pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with spacy model.
        lemma (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with spacy model.
        parse (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with spacy model.
        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will split by spaces without performing tokenization.

    Returns:
        :obj:`spacy.Language`: The spacy model loaded.
    """
    # Disable every pipeline component the caller did not ask for.
    disabled = ["vectors", "textcat", "ner"]
    if not pos_tags:
        disabled.append("tagger")
    if not lemma:
        disabled.append("lemmatizer")
    if not parse:
        disabled.append("parser")

    # Reuse an already-loaded pipeline for this exact configuration, if any.
    cache_key = (language, pos_tags, lemma, parse, split_on_spaces)
    if cache_key not in LOADED_SPACY_MODELS:
        try:
            model = spacy.load(language, exclude=disabled)
        except OSError:
            # Model not installed yet: fetch it, then retry the load.
            logger.warning(
                "Spacy model '%s' not found. Downloading and installing.", language
            )
            spacy_download(language)
            model = spacy.load(language, exclude=disabled)

        # When every component is disabled and the caller only wants
        # whitespace splitting, keep just the bare tokenizer.
        # TODO: is it really faster?
        # TODO: check split_on_spaces behaviour if we don't do this if
        if len(disabled) >= 6 and split_on_spaces:
            model = model.tokenizer
        LOADED_SPACY_MODELS[cache_key] = model

    return LOADED_SPACY_MODELS[cache_key]
|
78 |
+
|
79 |
+
|
80 |
+
class SpacyTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that uses SpaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.

    Args:
        language (:obj:`str`, optional, defaults to :obj:`en`):
            Language of the text to tokenize.
        return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with spacy model.
        return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with spacy model.
        return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with spacy model.
        use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will load the spaCy model on GPU.
    """

    def __init__(
        self,
        language: str = "en",
        return_pos_tags: bool = False,
        return_lemmas: bool = False,
        return_deps: bool = False,
        use_gpu: bool = False,
    ):
        super().__init__()
        if language not in SPACY_LANGUAGE_MAPPER:
            raise ValueError(
                f"`{language}` language not supported. The supported "
                f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
            )
        if use_gpu:
            # Load the model on GPU; raises if the GPU is not available
            # or not correctly configured.
            spacy.require_gpu()
        self.spacy = load_spacy(
            SPACY_LANGUAGE_MAPPER[language],
            pos_tags=return_pos_tags,
            lemma=return_lemmas,
            parse=return_deps,
        )

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[Word], List[List[Word]]]:
        """
        Tokenize the input into single words using SpaCy models.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

            >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True)
            >>> spacy_tokenizer("Mary sold the car to John.")

        """
        if self.check_is_batched(texts, is_split_into_words):
            return self.tokenize_batch(texts, is_split_into_words)
        return self.tokenize(texts, is_split_into_words)

    def tokenize(self, text: Union[str, List[str]], is_split_into_words: bool) -> Doc:
        # Plain text goes straight through the spaCy pipeline.
        if not is_split_into_words:
            return self.spacy(text)
        # Pre-tokenized input: build a Doc from the given words, then run
        # the pipeline on it so downstream components still apply.
        if isinstance(text, str):
            words = text.split(" ")
        elif isinstance(text, list):
            words = text
        else:
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )
        doc = Doc(self.spacy.vocab, words=words, spaces=[True] * len(words))
        return self.spacy(doc)

    def tokenize_batch(
        self, texts: Union[List[str], List[List[str]]], is_split_into_words: bool
    ) -> list[Any] | list[Doc]:
        try:
            if is_split_into_words:
                if isinstance(texts[0], str):
                    word_lists = [text.split(" ") for text in texts]
                elif isinstance(texts[0], list):
                    word_lists = texts
                else:
                    raise ValueError(
                        f"text must be either `str` or `list`, found: `{type(texts[0])}`"
                    )
                texts = [
                    Doc(self.spacy.vocab, words=words, spaces=[True] * len(words))
                    for words in word_lists
                ]
            return list(self.spacy.pipe(texts))
        except AttributeError:
            # a WhitespaceSpacyTokenizer has no `pipe()` method, we use simple for loop
            return [self.spacy(tokens) for tokens in texts]
|
relik/inference/data/window/__init__.py
ADDED
File without changes
|
relik/inference/data/window/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (195 Bytes). View file
|
|
relik/inference/data/window/__pycache__/manager.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
relik/inference/data/window/manager.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import itertools
|
3 |
+
from typing import Dict, List, Optional, Set, Tuple
|
4 |
+
|
5 |
+
from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
|
6 |
+
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
|
7 |
+
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
|
8 |
+
from relik.reader.data.relik_reader_sample import RelikReaderSample
|
9 |
+
|
10 |
+
|
11 |
+
class WindowManager:
|
12 |
+
def __init__(
|
13 |
+
self, tokenizer: BaseTokenizer, splitter: BaseSentenceSplitter | None = None
|
14 |
+
) -> None:
|
15 |
+
self.tokenizer = tokenizer
|
16 |
+
self.splitter = splitter or BlankSentenceSplitter()
|
17 |
+
|
18 |
+
def create_windows(
|
19 |
+
self,
|
20 |
+
documents: str | List[str],
|
21 |
+
window_size: int | None = None,
|
22 |
+
stride: int | None = None,
|
23 |
+
max_length: int | None = None,
|
24 |
+
doc_id: str | int | None = None,
|
25 |
+
doc_topic: str | None = None,
|
26 |
+
is_split_into_words: bool = False,
|
27 |
+
mentions: List[List[List[int]]] = None,
|
28 |
+
) -> Tuple[List[RelikReaderSample], List[RelikReaderSample]]:
|
29 |
+
"""
|
30 |
+
Create windows from a list of documents.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
documents (:obj:`str` or :obj:`List[str]`):
|
34 |
+
The document(s) to split in windows.
|
35 |
+
window_size (:obj:`int`):
|
36 |
+
The size of the window.
|
37 |
+
stride (:obj:`int`):
|
38 |
+
The stride between two windows.
|
39 |
+
max_length (:obj:`int`, `optional`):
|
40 |
+
The maximum length of a window.
|
41 |
+
doc_id (:obj:`str` or :obj:`int`, `optional`):
|
42 |
+
The id of the document(s).
|
43 |
+
doc_topic (:obj:`str`, `optional`):
|
44 |
+
The topic of the document(s).
|
45 |
+
is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
46 |
+
Whether the input is already pre-tokenized (e.g., split into words). If :obj:`False`, the
|
47 |
+
input will first be tokenized using the tokenizer, then the tokens will be split into words.
|
48 |
+
mentions (:obj:`List[List[List[int]]]`, `optional`):
|
49 |
+
The mentions of the document(s).
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
:obj:`List[RelikReaderSample]`: The windows created from the documents.
|
53 |
+
"""
|
54 |
+
# normalize input
|
55 |
+
if isinstance(documents, str) or is_split_into_words:
|
56 |
+
documents = [documents]
|
57 |
+
|
58 |
+
# batch tokenize
|
59 |
+
documents_tokens = self.tokenizer(
|
60 |
+
documents, is_split_into_words=is_split_into_words
|
61 |
+
)
|
62 |
+
|
63 |
+
# set splitter params
|
64 |
+
if hasattr(self.splitter, "window_size"):
|
65 |
+
self.splitter.window_size = window_size or self.splitter.window_size
|
66 |
+
if hasattr(self.splitter, "window_stride"):
|
67 |
+
self.splitter.window_stride = stride or self.splitter.window_stride
|
68 |
+
|
69 |
+
windowed_documents, windowed_blank_documents = [], []
|
70 |
+
|
71 |
+
if mentions is not None:
|
72 |
+
assert len(documents) == len(
|
73 |
+
mentions
|
74 |
+
), f"documents and mentions should have the same length, got {len(documents)} and {len(mentions)}"
|
75 |
+
doc_iter = zip(documents, documents_tokens, mentions)
|
76 |
+
else:
|
77 |
+
doc_iter = zip(documents, documents_tokens, itertools.repeat([]))
|
78 |
+
|
79 |
+
for infered_doc_id, (document, document_tokens, document_mentions) in enumerate(
|
80 |
+
doc_iter
|
81 |
+
):
|
82 |
+
if doc_topic is None:
|
83 |
+
doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
|
84 |
+
|
85 |
+
if doc_id is None:
|
86 |
+
doc_id = infered_doc_id
|
87 |
+
|
88 |
+
splitted_document = self.splitter(document_tokens, max_length=max_length)
|
89 |
+
|
90 |
+
document_windows = []
|
91 |
+
for window_id, window in enumerate(splitted_document):
|
92 |
+
window_text_start = window[0].idx
|
93 |
+
window_text_end = window[-1].idx + len(window[-1].text)
|
94 |
+
if isinstance(document, str):
|
95 |
+
text = document[window_text_start:window_text_end]
|
96 |
+
else:
|
97 |
+
# window_text_start = window[0].idx
|
98 |
+
# window_text_end = window[-1].i
|
99 |
+
text = " ".join([w.text for w in window])
|
100 |
+
sample = RelikReaderSample(
|
101 |
+
doc_id=doc_id,
|
102 |
+
window_id=window_id,
|
103 |
+
text=text,
|
104 |
+
tokens=[w.text for w in window],
|
105 |
+
words=[w.text for w in window],
|
106 |
+
doc_topic=doc_topic,
|
107 |
+
offset=window_text_start,
|
108 |
+
spans=[
|
109 |
+
[m[0], m[1]] for m in document_mentions
|
110 |
+
if window_text_end > m[0] >= window_text_start and window_text_end >= m[1] >= window_text_start
|
111 |
+
],
|
112 |
+
token2char_start={str(i): w.idx for i, w in enumerate(window)},
|
113 |
+
token2char_end={
|
114 |
+
str(i): w.idx + len(w.text) for i, w in enumerate(window)
|
115 |
+
},
|
116 |
+
char2token_start={
|
117 |
+
str(w.idx): w.i for i, w in enumerate(window)
|
118 |
+
},
|
119 |
+
char2token_end={
|
120 |
+
str(w.idx + len(w.text)): w.i for i, w in enumerate(window)
|
121 |
+
},
|
122 |
+
)
|
123 |
+
if mentions is not None and len(sample.spans) == 0:
|
124 |
+
windowed_blank_documents.append(sample)
|
125 |
+
else:
|
126 |
+
document_windows.append(sample)
|
127 |
+
|
128 |
+
windowed_documents.extend(document_windows)
|
129 |
+
if mentions is not None:
|
130 |
+
return windowed_documents, windowed_blank_documents
|
131 |
+
else:
|
132 |
+
return windowed_documents, windowed_blank_documents
|
133 |
+
|
134 |
+
def merge_windows(
|
135 |
+
self, windows: List[RelikReaderSample]
|
136 |
+
) -> List[RelikReaderSample]:
|
137 |
+
windows_by_doc_id = collections.defaultdict(list)
|
138 |
+
for window in windows:
|
139 |
+
windows_by_doc_id[window.doc_id].append(window)
|
140 |
+
|
141 |
+
merged_window_by_doc = {
|
142 |
+
doc_id: self._merge_doc_windows(doc_windows)
|
143 |
+
for doc_id, doc_windows in windows_by_doc_id.items()
|
144 |
+
}
|
145 |
+
|
146 |
+
return list(merged_window_by_doc.values())
|
147 |
+
|
148 |
+
def _merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
|
149 |
+
if len(windows) == 1:
|
150 |
+
return windows[0]
|
151 |
+
|
152 |
+
if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
|
153 |
+
windows = sorted(windows, key=(lambda x: x.offset))
|
154 |
+
|
155 |
+
window_accumulator = windows[0]
|
156 |
+
|
157 |
+
for next_window in windows[1:]:
|
158 |
+
window_accumulator = self._merge_window_pair(
|
159 |
+
window_accumulator, next_window
|
160 |
+
)
|
161 |
+
|
162 |
+
return window_accumulator
|
163 |
+
|
164 |
+
@staticmethod
def _merge_tokens(
    window1: RelikReaderSample, window2: RelikReaderSample
) -> Tuple[list, dict, dict]:
    """Merge the token sequences of two consecutive windows.

    Strips the first and last token of each window (treated as CLS/SEP —
    see the inline markers below), detects the longest suffix of window1
    that equals a prefix of window2, and concatenates without duplicating
    that overlap. Token->char mappings (keyed by stringified token index)
    are re-keyed for the second window accordingly.

    Returns:
        (merged_tokens, merged_token2char_start, merged_token2char_end)
    """
    # drop the special boundary tokens of each window
    w1_tokens = window1.tokens[1:-1]
    w2_tokens = window2.tokens[1:-1]

    # find intersection if any: longest k such that the last k tokens of
    # window1 equal the first k tokens of window2 (largest k tried first)
    tokens_intersection = 0
    for k in reversed(range(1, len(w1_tokens))):
        if w1_tokens[-k:] == w2_tokens[:k]:
            tokens_intersection = k
            break

    final_tokens = (
        [window1.tokens[0]]  # CLS
        + w1_tokens
        + w2_tokens[tokens_intersection:]
        + [window1.tokens[-1]]  # SEP
    )

    # how far window2's token indices shift in the merged sequence
    w2_starting_offset = len(w1_tokens) - tokens_intersection

    def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
        """Union of two token->char maps; window2 keys below the overlap are
        dropped, the rest are shifted by ``w2_starting_offset``. Keys are
        strings holding integer token indices."""
        final_t2c = dict()
        final_t2c.update(t2c1)
        for t, c in t2c2.items():
            t = int(t)
            if t < tokens_intersection:
                continue
            final_t2c[str(t + w2_starting_offset)] = c
        return final_t2c

    return (
        final_tokens,
        merge_char_mapping(window1.token2char_start, window2.token2char_start),
        merge_char_mapping(window1.token2char_end, window2.token2char_end),
    )
|
202 |
+
|
203 |
+
@staticmethod
def _merge_words(
    window1: RelikReaderSample, window2: RelikReaderSample
) -> Tuple[list, dict, dict]:
    """Merge the word sequences of two consecutive windows.

    Same overlap strategy as ``_merge_tokens`` but on whole words and
    without special boundary tokens. Token->word maps may be ``None`` on
    either side and are then treated as empty.

    Returns:
        (merged_words, merged_token2word_start, merged_token2word_end)
    """
    w1_words = window1.words
    w2_words = window2.words

    # find intersection if any: longest suffix of window1's words that is
    # also a prefix of window2's words (largest k tried first)
    words_intersection = 0
    for k in reversed(range(1, len(w1_words))):
        if w1_words[-k:] == w2_words[:k]:
            words_intersection = k
            break

    # concatenate, skipping the overlapping prefix of window2
    final_words = w1_words + w2_words[words_intersection:]

    # how far window2's word indices shift in the merged sequence
    w2_starting_offset = len(w1_words) - words_intersection

    def merge_word_mapping(t2c1: dict, t2c2: dict) -> dict:
        """Union of two token->word maps (``None`` treated as empty);
        window2 keys below the overlap are dropped, the rest shifted by
        ``w2_starting_offset``. Keys are stringified integer indices."""
        final_t2c = dict()
        if t2c1 is None:
            t2c1 = dict()
        if t2c2 is None:
            t2c2 = dict()
        final_t2c.update(t2c1)
        for t, c in t2c2.items():
            t = int(t)
            if t < words_intersection:
                continue
            final_t2c[str(t + w2_starting_offset)] = c
        return final_t2c

    return (
        final_words,
        merge_word_mapping(window1.token2word_start, window2.token2word_start),
        merge_word_mapping(window1.token2word_end, window2.token2word_end),
    )
|
240 |
+
|
241 |
+
@staticmethod
def _merge_span_annotation(
    span_annotation1: List[list], span_annotation2: List[list]
) -> List[list]:
    """Concatenate two span-annotation lists without duplicates.

    Duplicates are detected by the tuple of the full annotation; the
    result is sorted by span start (first element of each annotation).
    """
    seen = set()
    deduped = []
    for annotation in itertools.chain(span_annotation1, span_annotation2):
        key = tuple(annotation)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(annotation)
    deduped.sort(key=lambda ann: ann[0])
    return deduped
|
253 |
+
|
254 |
+
@staticmethod
def _merge_predictions(
    window1: RelikReaderSample, window2: RelikReaderSample
) -> Tuple[list, dict, list, Dict]:
    """Merge span- and triple-level predictions of two windows.

    Returns:
        (merged_span_predictions, merged_span_probabilities,
         merged_triplet_predictions, merged_triplet_probs)

    Spans are merged only when BOTH windows expose ``predicted_spans``;
    likewise triples require ``predicted_triples`` on both. When neither
    applies, the corresponding results stay empty.
    """
    # a RelikReaderSample should have a field called `predicted_spans`
    # that stores the span-level predictions, or a field called
    # `predicted_triples` that stores the triple-level predictions

    # span predictions
    merged_span_predictions: Set = set()
    merged_span_probabilities = dict()
    # triple predictions
    merged_triplet_predictions: Set = set()
    merged_triplet_probs: Dict = dict()

    if (
        getattr(window1, "predicted_spans", None) is not None
        and getattr(window2, "predicted_spans", None) is not None
    ):
        # set union deduplicates spans shared by the overlapping region,
        # then sorting turns the set into a deterministic list
        merged_span_predictions = set(window1.predicted_spans).union(
            set(window2.predicted_spans)
        )
        merged_span_predictions = sorted(merged_span_predictions)
        # probabilities: first window seen wins for a given span
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items()
            if window1.probs_window_labels_chars is not None
            else [],
            window2.probs_window_labels_chars.items()
            if window2.probs_window_labels_chars is not None
            else [],
        ):
            if span_prediction not in merged_span_probabilities:
                merged_span_probabilities[span_prediction] = predicted_probs

    if (
        getattr(window1, "predicted_triples", None) is not None
        and getattr(window2, "predicted_triples", None) is not None
    ):
        # try to merge the triples predictions
        # add offset to the second window: each triple stores subject/object
        # as INDEXES into its window's predicted_spans, so re-index both
        # windows' triples against the merged, sorted span list
        window1_triplets = [
            (
                merged_span_predictions.index(window1.predicted_spans[t[0]]),
                t[1],
                merged_span_predictions.index(window1.predicted_spans[t[2]]),
                t[3]
            )
            for t in window1.predicted_triples
        ]
        window2_triplets = [
            (
                merged_span_predictions.index(window2.predicted_spans[t[0]]),
                t[1],
                merged_span_predictions.index(window2.predicted_spans[t[2]]),
                t[3]
            )
            for t in window2.predicted_triples
        ]
        merged_triplet_predictions = set(window1_triplets).union(
            set(window2_triplets)
        )
        merged_triplet_predictions = sorted(merged_triplet_predictions)
        # for now no triplet probs, we don't need them for the moment

    return (
        merged_span_predictions,
        merged_span_probabilities,
        merged_triplet_predictions,
        merged_triplet_probs,
    )
|
325 |
+
|
326 |
+
@staticmethod
def _merge_candidates(window1: RelikReaderSample, window2: RelikReaderSample):
    """Merge the candidate lists of two windows.

    Handles the four (partly legacy) candidate attributes a sample may
    carry; missing attributes are treated as empty lists.

    Returns:
        tuple: ``(candidates, windows_candidates, span_candidates,
        triplet_candidates)``, each deduplicated via ``set`` (element
        order is therefore arbitrary, as in the original contract).
    """

    def _collect(attr: str) -> list:
        # Copy window1's list before extending so the merge can never
        # mutate the input sample in place (the previous implementation
        # aliased window1's list and `+=` extended it in place).
        merged = list(getattr(window1, attr, None) or [])
        merged += getattr(window2, attr, None) or []
        return merged

    # TODO: retro-compatibility
    candidates = _collect("candidates")
    windows_candidates = _collect("windows_candidates")
    # TODO: add programmatically
    span_candidates = _collect("span_candidates")
    triplet_candidates = _collect("triplet_candidates")

    # make them unique
    return (
        list(set(candidates)),
        list(set(windows_candidates)),
        list(set(span_candidates)),
        list(set(triplet_candidates)),
    )
|
364 |
+
|
365 |
+
def _merge_window_pair(
    self,
    window1: RelikReaderSample,
    window2: RelikReaderSample,
) -> RelikReaderSample:
    """Merge two adjacent windows of the same document into a new sample.

    ``window1`` must precede ``window2`` (asserted when offsets exist).
    Delegates tokens/words/candidates/labels/predictions merging to the
    dedicated helpers and builds a fresh :class:`RelikReaderSample`.
    """
    merging_output = dict()

    # both windows must come from the same document
    if getattr(window1, "doc_id", None) is not None:
        assert window1.doc_id == window2.doc_id

    # enforce left-to-right merge order
    if getattr(window1, "offset", None) is not None:
        assert (
            window1.offset < window2.offset
        ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"

    # merged sample keeps window2's offset (the rightmost one)
    merging_output["doc_id"] = window1.doc_id
    merging_output["offset"] = window2.offset

    m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
        window1, window2
    )

    m_words, m_token2word_start, m_token2word_end = self._merge_words(
        window1, window2
    )

    (
        m_candidates,
        m_windows_candidates,
        m_span_candidates,
        m_triplet_candidates,
    ) = self._merge_candidates(window1, window2)

    # gold span labels are merged only when present on window1
    window_labels = None
    if getattr(window1, "window_labels", None) is not None:
        window_labels = self._merge_span_annotation(
            window1.window_labels, window2.window_labels
        )

    (
        predicted_spans,
        predicted_spans_probs,
        predicted_triples,
        predicted_triples_probs,
    ) = self._merge_predictions(window1, window2)

    merging_output.update(
        dict(
            tokens=m_tokens,
            words=m_words,
            token2char_start=m_token2char_start,
            token2char_end=m_token2char_end,
            token2word_start=m_token2word_start,
            token2word_end=m_token2word_end,
            window_labels=window_labels,
            candidates=m_candidates,
            span_candidates=m_span_candidates,
            triplet_candidates=m_triplet_candidates,
            windows_candidates=m_windows_candidates,
            predicted_spans=predicted_spans,
            predicted_spans_probs=predicted_spans_probs,
            predicted_triples=predicted_triples,
            predicted_triples_probs=predicted_triples_probs,
        )
    )

    return RelikReaderSample(**merging_output)
|
relik/inference/gerbil.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
9 |
+
from typing import Iterator, List, Optional, Tuple
|
10 |
+
from urllib import parse
|
11 |
+
|
12 |
+
from relik.inference.annotator import Relik
|
13 |
+
from relik.inference.data.objects import RelikOutput
|
14 |
+
|
15 |
+
# sys.path += ['../']
# Make the parent package importable when this module is run as a script.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))


# Module-level logger; main() attaches a file handler to it at startup.
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class GerbilAlbyManager:
    """Holds the ReLiK annotator and optional response logging for GERBIL runs."""

    def __init__(
        self,
        annotator: Optional[Relik] = None,
        response_logger_dir: Optional[str] = None,
    ) -> None:
        # annotator: the ReLiK pipeline; may be set later (see main()).
        self.annotator = annotator
        # directory where response bundles are dumped, or None to disable.
        self.response_logger_dir = response_logger_dir
        # monotonically increasing counter used as the dump file name.
        self.predictions_counter = 0
        # optional inverse label mapping loaded via set_mapping_file().
        self.labels_mapping = None

    def annotate(self, document: str):
        """Run the annotator on *document* and return (start, end, label) triples.

        Labels are remapped through ``self.labels_mapping`` when it is set.
        """
        relik_output: RelikOutput = self.annotator(
            document, retriever_batch_size=2, reader_batch_size=1
        )
        # each span is (start, end, label, extra) — extra is discarded here
        annotations = [(ss, se, l) for ss, se, l, _ in relik_output.spans]
        if self.labels_mapping is not None:
            return [
                (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
            ]
        return annotations

    def set_mapping_file(self, mapping_file_path: str):
        """Load a JSON label mapping and store its INVERSE (value -> key)."""
        with open(mapping_file_path) as f:
            labels_mapping = json.load(f)
        self.labels_mapping = {v: k for k, v in labels_mapping.items()}

    def write_response_bundle(
        self,
        document: str,
        new_document: str,
        annotations: list,
        mapped_annotations: list,
    ) -> None:
        """Dump one prediction bundle as JSON for offline inspection.

        ``annotations`` spans index into *document*; ``mapped_annotations``
        spans index into *new_document*. No-op when logging is disabled.
        """
        if self.response_logger_dir is None:
            return

        if not os.path.isdir(self.response_logger_dir):
            os.mkdir(self.response_logger_dir)

        with open(
            f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
        ) as f:
            out_json_obj = dict(
                document=document,
                new_document=new_document,
                annotations=annotations,
                mapped_annotations=mapped_annotations,
            )

            # enrich spans with the surface text they cover, for readability
            out_json_obj["span_annotations"] = [
                (ss, se, document[ss:se], label) for (ss, se, label) in annotations
            ]

            out_json_obj["span_mapped_annotations"] = [
                (ss, se, new_document[ss:se], label)
                for (ss, se, label) in mapped_annotations
            ]

            json.dump(out_json_obj, f, indent=2)

        self.predictions_counter += 1
|
84 |
+
|
85 |
+
|
86 |
+
# Module-level singleton: configured by main() and used by the HTTP handler.
manager = GerbilAlbyManager()
|
87 |
+
|
88 |
+
|
89 |
+
def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
    """Normalize bracket/newline artifacts in *document*.

    Replaces patterns such as ``-LRB-``/``-RRB-`` with their literal
    characters and collapses blank lines.

    Returns:
        A pair ``(normalized_document, char2offset)`` where ``char2offset``
        is a list of ``(index_in_normalized, cumulative_offset)`` entries:
        characters at or beyond ``index`` must be shifted by ``offset`` to
        map back into the original document.
    """
    replacements = {
        "-LPR- ": " (",
        "-RPR-": ")",
        "\n\n": "\n",
        "-LRB-": "(",
        "-RRB-": ")",
        '","': ",",
    }

    normalized = document
    shift = 0  # cumulative shrinkage accumulated so far
    char2offset: List[Tuple[int, int]] = []

    pattern = "({})".format("|".join(replacements))
    for match in sorted(re.finditer(pattern, document), key=lambda m: m.start()):
        # match positions refer to the ORIGINAL document; re-base them
        # onto the partially rewritten string
        start = match.start() - shift
        end = match.end() - shift

        matched_text = normalized[start:end]
        replacement = replacements[matched_text]
        normalized = normalized[:start] + replacement + normalized[end:]

        shift += len(matched_text) - len(replacement)
        # record from which index the new cumulative shift applies
        char2offset.append((start + len(replacement), shift))

    return normalized, char2offset
|
119 |
+
|
120 |
+
|
121 |
+
def map_back_annotations(
    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
) -> Iterator[Tuple[int, int, str]]:
    """Translate annotation char indices from the preprocessed document
    back to the original one.

    ``char_mapping`` is the ``(threshold, cumulative_offset)`` list produced
    by :func:`preprocess_document`, ordered by threshold.
    """

    def to_original(position: int) -> int:
        # keep the last cumulative offset whose threshold is <= position
        applicable = 0
        for threshold, cumulative_offset in char_mapping:
            if position < threshold:
                break
            applicable = cumulative_offset
        return position + applicable

    return (
        (to_original(start), to_original(end), label)
        for start, end, label in annotations
    )
|
135 |
+
|
136 |
+
|
137 |
+
def annotate(document: str) -> List[Tuple[int, int, str]]:
    """Annotate *document* for GERBIL.

    Preprocesses the text, runs the shared ``manager`` annotator, maps span
    offsets back into the original document and verifies the mapping.

    Returns:
        A list of ``(start, length, label)`` triples, or ``None`` when the
        back-mapped spans do not select the same surface text as the
        spans in the preprocessed document.
    """
    new_document, mapping = preprocess_document(document)
    logger.info("Mapping: " + str(mapping))
    logger.info("Document: " + str(document))
    annotations = [
        (cs, ce, label.replace(" ", "_"))
        for cs, ce, label in manager.annotate(new_document)
    ]
    logger.info("New document: " + str(new_document))
    mapped_annotations = (
        list(map_back_annotations(annotations, mapping))
        if len(mapping) > 0
        else annotations
    )

    logger.info(
        "Annotations: "
        + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
    )

    manager.write_response_bundle(
        document, new_document, mapped_annotations, annotations
    )

    # Sanity check: every back-mapped span must select the same surface text
    # in the original document as its source span does in the preprocessed
    # one. (The previous version also computed an unused `diff_mappings`
    # list and carried a dead `assert` duplicating this same condition after
    # the unconditional `return None` — both removed; the failure is now
    # logged instead of being silent.)
    mismatches = [
        (new_document[ss:se], document[mss:mse])
        for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        if new_document[ss:se] != document[mss:mse]
    ]
    if mismatches:
        logger.warning("Span mismatches after back-mapping: %s", mismatches)
        return None

    # GERBIL expects (start, length, label) triples
    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
|
180 |
+
|
181 |
+
|
182 |
+
class GetHandler(BaseHTTPRequestHandler):
    """Minimal HTTP handler exposing the annotator to GERBIL via POST."""

    def do_POST(self):
        # Read exactly Content-Length bytes of the request body.
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        # Status and headers are sent before annotation runs; the JSON
        # result (possibly `null` on failure) follows in the body.
        self.send_response(200)
        self.end_headers()
        doc_text = read_json(post_data)
        # try:
        response = annotate(doc_text)

        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return
|
194 |
+
|
195 |
+
|
196 |
+
def read_json(post_data):
    """Decode a UTF-8 JSON request body and return its "text" field."""
    payload = json.loads(post_data.decode("utf-8"))
    # logger.info("received data:", payload)
    # spans = [(int(j["start"]), int(j["length"])) for j in payload["spans"]]
    return payload["text"]
|
202 |
+
|
203 |
+
|
204 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the command line arguments for the GERBIL server."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--relik-model-name", required=True)
    arg_parser.add_argument("--responses-log-dir")
    arg_parser.add_argument("--log-file", default="experiments/logging.txt")
    arg_parser.add_argument("--mapping-file")
    return arg_parser.parse_args()
|
211 |
+
|
212 |
+
|
213 |
+
def main():
    """Entry point: load the ReLiK model and serve GERBIL requests over HTTP."""
    args = parse_args()

    # --responses-log-dir is optional: only create the directory when given.
    # (The previous version called Path(args.responses_log_dir)
    # unconditionally, raising TypeError when the flag was omitted.)
    if args.responses_log_dir is not None:
        responses_log_dir = Path(args.responses_log_dir)
        responses_log_dir.mkdir(parents=True, exist_ok=True)

    # init manager
    manager.response_logger_dir = args.responses_log_dir
    manager.annotator = Relik.from_pretrained(
        args.relik_model_name,
        device="cuda",
        # document_index_device="cpu",
        # document_index_precision="fp32",
        # reader_device="cpu",
        precision="fp16",  # , reader_device="cpu", reader_precision="fp32"
        dataset_kwargs={"use_nme": True},
    )

    if args.mapping_file is not None:
        manager.set_mapping_file(args.mapping_file)

    # port = 6654
    port = 5555
    server = HTTPServer(("localhost", port), GetHandler)
    logger.info(f"Starting server at http://localhost:{port}")

    # Log everything to the requested file as well.
    file_handler = logging.FileHandler(args.log_file)
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        exit(0)
|
266 |
+
|
267 |
+
|
268 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
relik/inference/serve/__init__.py
ADDED
File without changes
|
relik/inference/serve/backend/__init__.py
ADDED
File without changes
|
relik/inference/serve/backend/fastapi.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List, Union
|
5 |
+
import psutil
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from relik.common.utils import is_package_available
|
10 |
+
from relik.inference.annotator import Relik
|
11 |
+
|
12 |
+
if not is_package_available("fastapi"):
|
13 |
+
raise ImportError(
|
14 |
+
"FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
|
15 |
+
)
|
16 |
+
from fastapi import FastAPI, HTTPException, APIRouter
|
17 |
+
|
18 |
+
|
19 |
+
from relik.common.log import get_logger
|
20 |
+
from relik.inference.serve.backend.utils import (
|
21 |
+
RayParameterManager,
|
22 |
+
ServerParameterManager,
|
23 |
+
)
|
24 |
+
|
25 |
+
logger = get_logger(__name__, level=logging.INFO)

# Load the package version by executing version.py four directories up
# (presumably the package root); VERSION["VERSION"] feeds the FastAPI metadata.
VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()
|
36 |
+
|
37 |
+
|
38 |
+
class RelikServer:
    """FastAPI-facing wrapper around a ReLiK pipeline.

    Loads a pretrained :class:`Relik` model once at construction time and
    exposes it on the POST ``/api/relik`` route registered on ``self.router``.
    """

    def __init__(
        self,
        relik_pretrained: str | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: str | int | torch.dtype = 32,
        retriever_precision: str | int | torch.dtype | None = None,
        document_index_precision: str | int | torch.dtype | None = None,
        reader_precision: str | int | torch.dtype | None = None,
        annotation_type: str = "char",
        **kwargs,
    ):
        # Cap torch threads at the physical core count unless overridden
        # through the TORCH_NUM_THREADS environment variable.
        num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False))
        torch.set_num_threads(num_threads)
        logger.info(f"Torch is running on {num_threads} threads.")
        # parameters
        logger.info(f"RELIK_PRETRAINED: {relik_pretrained}")
        self.relik_pretrained = relik_pretrained
        logger.info(f"DEVICE: {device}")
        self.device = device
        if retriever_device is not None:
            logger.info(f"RETRIEVER_DEVICE: {retriever_device}")
        # falls back to the global device when unset
        self.retriever_device = retriever_device or device
        if document_index_device is not None:
            logger.info(f"INDEX_DEVICE: {document_index_device}")
        # NOTE(review): falls back to retriever_device (which may be None),
        # not to `device` — confirm this asymmetry is intended.
        self.document_index_device = document_index_device or retriever_device
        if reader_device is not None:
            logger.info(f"READER_DEVICE: {reader_device}")
        # NOTE(review): no fallback here — may remain None when unset.
        self.reader_device = reader_device
        logger.info(f"PRECISION: {precision}")
        self.precision = precision
        if retriever_precision is not None:
            logger.info(f"RETRIEVER_PRECISION: {retriever_precision}")
        self.retriever_precision = retriever_precision or precision
        if document_index_precision is not None:
            logger.info(f"INDEX_PRECISION: {document_index_precision}")
        self.document_index_precision = document_index_precision or precision
        if reader_precision is not None:
            logger.info(f"READER_PRECISION: {reader_precision}")
        self.reader_precision = reader_precision or precision
        logger.info(f"ANNOTATION_TYPE: {annotation_type}")
        self.annotation_type = annotation_type

        self.relik = Relik.from_pretrained(
            self.relik_pretrained,
            device=self.device,
            retriever_device=self.retriever_device,
            document_index_device=self.document_index_device,
            reader_device=self.reader_device,
            precision=self.precision,
            retriever_precision=self.retriever_precision,
            document_index_precision=self.document_index_precision,
            reader_precision=self.reader_precision,
        )

        self.router = APIRouter()
        self.router.add_api_route("/api/relik", self.relik_endpoint, methods=["POST"])

        logger.info("RelikServer initialized.")

    # @serve.batch()
    async def __call__(self, text: List[str]) -> List:
        # Run the pipeline; annotation granularity was fixed at init time.
        return self.relik(text, annotation_type=self.annotation_type)

    # @app.post("/api/relik")
    async def relik_endpoint(self, text: Union[str, List[str]]):
        try:
            # get predictions for the retriever
            return await self(text)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
|
114 |
+
|
115 |
+
|
116 |
+
# FastAPI application exposing the ReLiK REST API; the server instance is
# built from environment-driven parameters and its router mounted on the app.
app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)
server = RelikServer(**vars(SERVER_MANAGER))
app.include_router(server.router)
|
relik/inference/serve/backend/ray.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List, Union
|
5 |
+
import psutil
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from relik.common.utils import is_package_available
|
10 |
+
from relik.inference.annotator import Relik
|
11 |
+
|
12 |
+
if not is_package_available("fastapi"):
|
13 |
+
raise ImportError(
|
14 |
+
"FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
|
15 |
+
)
|
16 |
+
from fastapi import FastAPI, HTTPException
|
17 |
+
|
18 |
+
if not is_package_available("ray"):
|
19 |
+
raise ImportError(
|
20 |
+
"Ray is not installed. Please install Ray with `pip install relik[serve]`."
|
21 |
+
)
|
22 |
+
from ray import serve
|
23 |
+
|
24 |
+
from relik.common.log import get_logger
|
25 |
+
from relik.inference.serve.backend.utils import (
|
26 |
+
RayParameterManager,
|
27 |
+
ServerParameterManager,
|
28 |
+
)
|
29 |
+
|
30 |
+
logger = get_logger(__name__, level=logging.INFO)

# Load the package version by executing version.py four directories up
# (presumably the package root); VERSION["VERSION"] feeds the FastAPI metadata.
VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

# FastAPI app ingressed by the Ray Serve deployment below.
app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)
|
47 |
+
|
48 |
+
|
49 |
+
@serve.deployment(
    ray_actor_options={
        # Request GPUs only when at least one component targets CUDA.
        "num_gpus": RAY_MANAGER.num_gpus
        if (
            SERVER_MANAGER.device == "cuda"
            or SERVER_MANAGER.retriever_device == "cuda"
            or SERVER_MANAGER.reader_device == "cuda"
        )
        else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class RelikServer:
    """Ray Serve deployment wrapping a ReLiK pipeline behind ``/api/relik``."""

    def __init__(
        self,
        relik_pretrained: str | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: str | int | torch.dtype = 32,
        retriever_precision: str | int | torch.dtype | None = None,
        document_index_precision: str | int | torch.dtype | None = None,
        reader_precision: str | int | torch.dtype | None = None,
        annotation_type: str = "char",
        retriever_batch_size: int = 32,
        reader_batch_size: int = 32,
        relik_config_override: dict | None = None,
        **kwargs,
    ):
        # Cap torch threads at the physical core count unless overridden
        # through the TORCH_NUM_THREADS environment variable.
        num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False))
        torch.set_num_threads(num_threads)
        logger.info(f"Torch is running on {num_threads} threads.")

        # parameters
        logger.info(f"RELIK_PRETRAINED: {relik_pretrained}")
        self.relik_pretrained = relik_pretrained

        if relik_config_override is None:
            relik_config_override = {}
        logger.info(f"RELIK_CONFIG_OVERRIDE: {relik_config_override}")
        self.relik_config_override = relik_config_override

        logger.info(f"DEVICE: {device}")
        self.device = device

        if retriever_device is not None:
            logger.info(f"RETRIEVER_DEVICE: {retriever_device}")
        # falls back to the global device when unset
        self.retriever_device = retriever_device or device

        if document_index_device is not None:
            logger.info(f"INDEX_DEVICE: {document_index_device}")
        # NOTE(review): falls back to retriever_device (which may be None),
        # not to `device` — confirm this asymmetry is intended.
        self.document_index_device = document_index_device or retriever_device

        if reader_device is not None:
            logger.info(f"READER_DEVICE: {reader_device}")
        # NOTE(review): no fallback here — may remain None when unset.
        self.reader_device = reader_device

        logger.info(f"PRECISION: {precision}")
        self.precision = precision

        if retriever_precision is not None:
            logger.info(f"RETRIEVER_PRECISION: {retriever_precision}")
        self.retriever_precision = retriever_precision or precision

        if document_index_precision is not None:
            logger.info(f"INDEX_PRECISION: {document_index_precision}")
        self.document_index_precision = document_index_precision or precision

        if reader_precision is not None:
            logger.info(f"READER_PRECISION: {reader_precision}")
        self.reader_precision = reader_precision or precision

        logger.info(f"ANNOTATION_TYPE: {annotation_type}")
        self.annotation_type = annotation_type

        self.relik = Relik.from_pretrained(
            self.relik_pretrained,
            device=self.device,
            retriever_device=self.retriever_device,
            document_index_device=self.document_index_device,
            reader_device=self.reader_device,
            precision=self.precision,
            retriever_precision=self.retriever_precision,
            document_index_precision=self.document_index_precision,
            reader_precision=self.reader_precision,
            **self.relik_config_override,
        )

        self.retriever_batch_size = retriever_batch_size
        self.reader_batch_size = reader_batch_size

    # @serve.batch()
    async def handle_batch(self, text: List[str]) -> List:
        # Run the pipeline with the batch sizes fixed at init time.
        return self.relik(
            text,
            annotation_type=self.annotation_type,
            retriever_batch_size=self.retriever_batch_size,
            reader_batch_size=self.reader_batch_size,
        )

    @app.post("/api/relik")
    async def relik_endpoint(self, text: Union[str, List[str]]):
        try:
            # get predictions for the retriever
            return await self.handle_batch(text)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
|
163 |
+
|
164 |
+
|
165 |
+
# Build the Ray Serve deployment, forwarding every field of the parameter
# manager dataclass as a keyword argument to RelikServer.__init__.
server = RelikServer.bind(**vars(SERVER_MANAGER))
|
relik/inference/serve/backend/utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
import os
from dataclasses import dataclass, field
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class ServerParameterManager:
|
8 |
+
relik_pretrained: str = os.environ.get("RELIK_PRETRAINED", None)
|
9 |
+
device: str = os.environ.get("DEVICE", "cpu")
|
10 |
+
retriever_device: str | None = os.environ.get("RETRIEVER_DEVICE", None)
|
11 |
+
document_index_device: str | None = os.environ.get("INDEX_DEVICE", None)
|
12 |
+
reader_device: str | None = os.environ.get("READER_DEVICE", None)
|
13 |
+
precision: int | str | None = os.environ.get("PRECISION", "fp32")
|
14 |
+
retriever_precision: int | str | None = os.environ.get("RETRIEVER_PRECISION", None)
|
15 |
+
document_index_precision: int | str | None = os.environ.get("INDEX_PRECISION", None)
|
16 |
+
reader_precision: int | str | None = os.environ.get("READER_PRECISION", None)
|
17 |
+
annotation_type: str = os.environ.get("ANNOTATION_TYPE", "char")
|
18 |
+
question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
|
19 |
+
passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
|
20 |
+
document_index: str = os.environ.get("DOCUMENT_INDEX", None)
|
21 |
+
reader_encoder: str = os.environ.get("READER_ENCODER", None)
|
22 |
+
top_k: int = int(os.environ.get("TOP_K", 100))
|
23 |
+
use_faiss: bool = os.environ.get("USE_FAISS", False)
|
24 |
+
retriever_batch_size: int = int(os.environ.get("RETRIEVER_BATCH_SIZE", 32))
|
25 |
+
reader_batch_size: int = int(os.environ.get("READER_BATCH_SIZE", 32))
|
26 |
+
window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
|
27 |
+
window_stride: int = int(os.environ.get("WINDOW_SIZE", 16))
|
28 |
+
split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
|
29 |
+
# relik_config_override: dict = ast.literal_eval(
|
30 |
+
# os.environ.get("RELIK_CONFIG_OVERRIDE", None)
|
31 |
+
# )
|
32 |
+
|
33 |
+
|
34 |
+
class RayParameterManager:
    """Ray Serve deployment parameters, read from environment variables."""

    def __init__(self) -> None:
        read = os.environ.get
        self.num_gpus = int(read("NUM_GPUS", 1))
        self.min_replicas = int(read("MIN_REPLICAS", 1))
        self.max_replicas = int(read("MAX_REPLICAS", 1))
|
relik/inference/serve/frontend/__init__.py
ADDED
File without changes
|
relik/inference/serve/frontend/relik_front.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import requests
|
5 |
+
import streamlit as st
|
6 |
+
from spacy import displacy
|
7 |
+
from streamlit_extras.badges import badge
|
8 |
+
from streamlit_extras.stylable_container import stylable_container
|
9 |
+
|
10 |
+
# bug fix: the default URL must include the scheme — `requests.post` raises
# requests.exceptions.MissingSchema for "localhost:8000/...". The RELIK env
# var can still override the full URL.
RELIK = os.getenv("RELIK", "http://localhost:8000/api/entities")
|
11 |
+
|
12 |
+
import random
|
13 |
+
|
14 |
+
|
15 |
+
def get_random_color(ents):
    """Assign each entity label a distinct pastel color, picked at random."""
    palette = generate_pastel_colors(len(ents))
    return {
        ent: palette.pop(random.randint(0, len(palette) - 1)) for ent in ents
    }
|
21 |
+
|
22 |
+
|
23 |
+
def floatrange(start, stop, steps):
    """Return `steps` evenly spaced floats from start to stop, inclusive."""
    if int(steps) == 1:
        return [stop]
    span = stop - start
    denom = float(steps) - 1
    return [start + float(i) * span / denom for i in range(steps)]
|
29 |
+
|
30 |
+
|
31 |
+
def hsl_to_rgb(h, s, l):
    """Convert an HSL color (components in [0, 1]) to an (r, g, b) byte tuple."""

    def channel(p, q, t):
        # Wrap the hue offset into [0, 1], then interpolate this channel.
        while t < 0.0:
            t += 1.0
        while t > 1.0:
            t -= 1.0
        if 6 * t < 1.0:
            return p + (q - p) * 6.0 * t
        if 2 * t < 1.0:
            return q
        if 3 * t < 2.0:
            return p + (q - p) * ((2.0 / 3.0) - t) * 6.0
        return p

    # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
    # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."

    # Achromatic colors (s == 0) are pure grey.
    r = g = b = l * 255
    if s != 0.0:
        q = l * (1.0 + s) if l < 0.5 else (l + s) - (s * l)
        p = 2.0 * l - q
        r = 255 * channel(p, q, h + (1.0 / 3.0))
        g = 255 * channel(p, q, h)
        b = 255 * channel(p, q, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))
|
60 |
+
|
61 |
+
|
62 |
+
def generate_pastel_colors(n):
    """Return different pastel colours.

    Input:
        n (integer) : The number of colors to return

    Output:
        A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])

    Example:
        >>> print generate_pastel_colors(5)
        ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
    """
    if n == 0:
        return []

    # Sample points around the chromatic circle in HSL space
    # (see http://en.wikipedia.org/wiki/HSL_color_space).
    # n + 1 hues are generated and the last dropped, since hue 0 == hue 1.
    start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.8
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]
|
89 |
+
|
90 |
+
|
91 |
+
def set_sidebar(css):
    """Render the Streamlit sidebar: inject the page CSS, show the
    SapienzaNLP logo and list the project and research-group links."""
    # Font-Awesome stylesheet plus an anchor template, filled in per link below.
    white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    with st.sidebar:
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
            """,
            unsafe_allow_html=True,
        )
|
118 |
+
|
119 |
+
|
120 |
+
def get_el_annotations(response):
    """Prepare an API response for displacy manual rendering.

    Renames the "labels" key to "ents" (what displacy expects) and builds a
    render-options dict mapping each label to a pastel color.
    """
    # swap labels key with ents
    response["ents"] = response.pop("labels")
    seen_labels = {ent["label"] for ent in response["ents"]}
    render_options = {"ents": seen_labels, "colors": get_random_color(seen_labels)}
    return response, render_options
|
126 |
+
|
127 |
+
|
128 |
+
def set_intro(css):
    """Render the page header: title, paper subtitle and repo/PyPI badges."""
    # intro
    st.markdown("# ReLik")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
    # )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")
|
142 |
+
|
143 |
+
|
144 |
+
def run_client():
    """Entry point for the entity-linking front-end.

    Renders the page, collects user text and sends it to the ReLiK HTTP API
    (the URL in the module-level RELIK constant), then displays the linked
    entities with displacy.
    """
    # page stylesheet shipped next to this module
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Obama went to Rome for a quick vacation.",
        height=200,
        max_chars=500,
    )

    # brand-styled "Annotate" button
    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")
    # submit = st.button("Run")

    # ReLik API call
    if submit:
        text = text.strip()
        if text:
            st.markdown("####")
            st.markdown("#### Entity Linking")
            with st.spinner(text="In progress"):
                # NOTE(review): the raw string is posted as the JSON body;
                # confirm the backend accepts a bare string payload.
                response = requests.post(RELIK, json=text)
                if response.status_code != 200:
                    st.error("Error: {}".format(response.status_code))
                else:
                    response = response.json()

                    # Entity Linking
                    # with stylable_container(
                    #     key="container_with_border",
                    #     css_styles="""
                    #         {
                    #             border: 1px solid rgba(49, 51, 63, 0.2);
                    #             border-radius: 0.5rem;
                    #             padding: 0.5rem;
                    #             padding-bottom: 2rem;
                    #         }
                    #         """,
                    # ):
                    # st.markdown("##")
                    dict_of_ents, options = get_el_annotations(response=response)
                    display = displacy.render(
                        dict_of_ents, manual=True, style="ent", options=options
                    )
                    # newlines break the embedded HTML rendering
                    display = display.replace("\n", " ")
                    # wsd_display = re.sub(
                    #     r"(wiki::\d+\w)",
                    #     r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
                    #         language.upper()
                    #     ),
                    #     wsd_display,
                    # )
                    with st.container():
                        st.write(display, unsafe_allow_html=True)

                    st.markdown("####")
                    st.markdown("#### Relation Extraction")

                    # relation extraction not wired into this page yet
                    with st.container():
                        st.write("Coming :)", unsafe_allow_html=True)

        else:
            st.error("Please enter some text.")
|
226 |
+
|
227 |
+
|
228 |
+
# Script entry point: launch the Streamlit client.
if __name__ == "__main__":
    run_client()
|
relik/inference/serve/frontend/relik_re_front.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import datetime as dt
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import requests
|
6 |
+
import spacy
|
7 |
+
import streamlit as st
|
8 |
+
import streamlit.components.v1 as components
|
9 |
+
from pyvis.network import Network
|
10 |
+
from spacy import displacy
|
11 |
+
from spacy.tokens import Doc
|
12 |
+
from streamlit_extras.badges import badge
|
13 |
+
from streamlit_extras.stylable_container import stylable_container
|
14 |
+
from utils import get_random_color, visualize_parser
|
15 |
+
|
16 |
+
from relik import Relik
|
17 |
+
|
18 |
+
# RELIK = os.getenv("RELIK", "localhost:8000/api/relik")
|
19 |
+
|
20 |
+
# Defaults for the Streamlit session-state keys managed by this page.
state_variables = {"has_run_free": False, "html_free": ""}
|
21 |
+
|
22 |
+
|
23 |
+
def init_state_variables():
    """Seed st.session_state with defaults for any managed key not yet set."""
    for key, default in state_variables.items():
        if key not in st.session_state:
            st.session_state[key] = default
|
27 |
+
|
28 |
+
|
29 |
+
def free_reset_session():
    """Remove every managed key from st.session_state."""
    for key in state_variables:
        del st.session_state[key]
|
32 |
+
|
33 |
+
|
34 |
+
def generate_graph(dict_ents, response, filename, options):
    """Render the extracted triples as an interactive pyvis network.

    Args:
        dict_ents: Mapping (start, end) -> (node id, surface text).
        response: ReLiK output whose `.triples` carry subject/object spans.
        filename: Path of the HTML file pyvis writes the graph to.
        options: displacy options; `options["colors"]` maps labels to colors.
    """
    g = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#222222",
        font_color="white",
    )
    # physics settings tuned for small entity graphs
    g.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )
    for ent in dict_ents:
        # NOTE(review): node colors are looked up by dict_ents[ent][0]; verify
        # every node id actually appears in options["colors"] (the label keys
        # built in get_span_annotations differ for non-NME spans).
        g.add_node(
            dict_ents[ent][0],
            label=dict_ents[ent][1],
            color=options["colors"][dict_ents[ent][0]],
            title=dict_ents[ent][0],
            size=15,
            labelHighlightBold=True,
        )

    # one directed edge per extracted (subject, relation, object) triple
    for rel in response.triples:
        g.add_edge(
            dict_ents[(rel.subject.start, rel.subject.end)][0],
            dict_ents[(rel.object.start, rel.object.end)][0],
            label=rel.label,
            title=rel.label,
        )
    g.show(filename, notebook=False)
|
69 |
+
|
70 |
+
|
71 |
+
def set_sidebar(css):
    """Render the Streamlit sidebar: inject the page CSS, show the
    SapienzaNLP logo and list the project and research-group links."""
    # Font-Awesome stylesheet plus an anchor template, filled in per link below.
    white_link_wrapper = (
        "<link rel='stylesheet' "
        "href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    )
    with st.sidebar:
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
            """,
            unsafe_allow_html=True,
        )
|
101 |
+
|
102 |
+
|
103 |
+
def get_span_annotations(response):
    """Turn a ReLiK response into spacy-ready annotations.

    NME ("no matching entity") spans keep a bare ``label + index`` tag; linked
    spans embed an HTML Wikipedia anchor in the label so displacy renders a
    clickable link.

    Args:
        response: ReLiK output exposing `.tokens` and `.spans`
                  (each span has `start`, `end`, `label`).

    Returns:
        A tuple ``(tokens, labels, options, dict_ents)``: the token list, the
        per-token BIO labels, displacy render options, and a mapping
        (start, end) -> (unique node id, surface text) used for the graph.
    """
    el_link_wrapper = (
        "<link rel='stylesheet' "
        "href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'>"
        "<a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands"
        " fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> "
        "{}</span></a>"
    )
    tokens = response.tokens
    labels = ["O"] * len(tokens)
    dict_ents = {}
    # make BIO labels
    for idx, span in enumerate(response.spans):
        labels[span.start] = (
            "B-" + span.label + str(idx)
            if span.label == "NME"
            else "B-" + el_link_wrapper.format(span.label.replace(" ", "_"), span.label)
        )
        # continuation tokens of the span get the matching I- tag
        for i in range(span.start + 1, span.end):
            labels[i] = (
                "I-" + span.label + str(idx)
                if span.label == "NME"
                else "I-"
                + el_link_wrapper.format(span.label.replace(" ", "_"), span.label)
            )
        # graph node: (unique id, surface text) keyed by the span's token range
        dict_ents[(span.start, span.end)] = (
            span.label + str(idx),
            " ".join(tokens[span.start : span.end]),
        )
    # strip the B-/I- prefix to get the set of distinct display labels
    unique_labels = set(w[2:] for w in labels if w != "O")
    options = {"ents": unique_labels, "colors": get_random_color(unique_labels)}
    return tokens, labels, options, dict_ents
|
135 |
+
|
136 |
+
|
137 |
+
@st.cache_resource()
def load_model():
    """Load the pretrained ReLiK relation-extraction model.

    `st.cache_resource` memoizes the model so Streamlit reruns and new
    sessions reuse the same instance instead of reloading it.
    """
    return Relik.from_pretrained("riccorl/relik-relation-extraction-nyt-small")
|
140 |
+
|
141 |
+
|
142 |
+
def set_intro(css):
    """Render the page header: title, paper subtitle and repo/PyPI badges."""
    # intro
    st.markdown("# ReLik")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking "
        "and Relation Extraction on an Academic Budget"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal
    #     _Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing),
    #     which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it),
    #     and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
    # )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")
|
159 |
+
badge(type="pypi", name="relik")
|
160 |
+
|
161 |
+
|
162 |
+
def run_client():
    """Entry point for the relation-extraction front-end.

    Renders the page, runs the locally loaded ReLiK model on the user's text,
    then shows the linked entities (displacy) and the relation graph (pyvis,
    embedded from a temporary HTML file).
    """
    # page stylesheet shipped next to this module
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Michael Jordan was one of the best players in the NBA.",
        height=200,
        max_chars=1500,
    )

    # brand-styled "Annotate" button
    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")

    # cache the model in the session so Streamlit reruns don't reload it
    if "relik_model" not in st.session_state.keys():
        st.session_state["relik_model"] = load_model()
    relik_model = st.session_state["relik_model"]
    init_state_variables()
    # ReLik API call

    # spacy for span visualization
    nlp = spacy.blank("xx")

    if submit:
        text = text.strip()
        if text:
            # unique temp-file name so concurrent sessions don't clash
            st.session_state["filename"] = str(dt.now().timestamp() * 1000) + ".html"

            with st.spinner(text="In progress"):
                response = relik_model(text, annotation_type="word", num_workers=0)

                # EL
                st.markdown("####")
                st.markdown("#### Entities")
                tokens, labels, options, dict_ents = get_span_annotations(
                    response=response
                )
                doc = Doc(nlp.vocab, words=tokens, ents=labels)
                display_el = displacy.render(doc, style="ent", options=options)
                display_el = display_el.replace("\n", " ")
                # heuristic, prevents split of annotation decorations
                display_el = display_el.replace(
                    "border-radius: 0.35em;",
                    "border-radius: 0.35em; white-space: nowrap;",
                )
                with st.container():
                    st.write(display_el, unsafe_allow_html=True)

                # RE
                generate_graph(
                    dict_ents, response, st.session_state["filename"], options
                )
                # bug fix: the pyvis HTML was read through a bare open() that
                # was never closed; use a context manager so the handle is
                # released before the file is deleted.
                with open(
                    st.session_state["filename"], "r", encoding="utf-8"
                ) as html_file:
                    st.session_state["html_free"] = html_file.read()
                os.remove(st.session_state["filename"])
                st.session_state["has_run_free"] = True
        else:
            st.error("Please enter some text.")

    # render the last computed graph even after unrelated reruns
    if st.session_state["has_run_free"]:
        st.markdown("#### Relations")
        components.html(st.session_state["html_free"], width=720, height=600)
|
248 |
+
|
249 |
+
|
250 |
+
# Script entry point: launch the Streamlit client.
if __name__ == "__main__":
    run_client()
|
relik/inference/serve/frontend/style.css
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Sidebar */
/* NOTE(review): class names like .eczjsme11 / .st-emotion-cache-* are
   Streamlit-generated and can change between Streamlit releases — confirm
   against the deployed version. */
.eczjsme11 {
    /* SapienzaNLP brand red for the sidebar background */
    background-color: #802433;
}

.st-emotion-cache-10oheav h2 {
    color: white;
}

.st-emotion-cache-10oheav li {
    color: white;
}

/* Main */
/* Keep links white and undecorated in every state. */
a:link {
    text-decoration: none;
    color: white;
}

a:visited {
    text-decoration: none;
    color: white;
}

a:hover {
    text-decoration: none;
    color: rgba(255, 255, 255, 0.871);
}

a:active {
    text-decoration: none;
    color: white;
}
|
relik/inference/serve/frontend/utils.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import random
|
3 |
+
from typing import Dict, List, Optional, Union
|
4 |
+
|
5 |
+
import spacy
|
6 |
+
import streamlit as st
|
7 |
+
from spacy import displacy
|
8 |
+
|
9 |
+
|
10 |
+
def get_html(html: str):
    """Wrap raw HTML in a scrollable, bordered container for rendering."""
    WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    # Newlines seem to mess with the rendering
    return WRAPPER.format(html.replace("\n", " "))
|
16 |
+
|
17 |
+
|
18 |
+
def get_svg(svg: str, style: str = "", wrap: bool = True):
    """Embed an SVG string as a base64-encoded <img> tag, optionally wrapped."""
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    img = f'<img src="data:image/svg+xml;base64,{encoded}" style="{style}"/>'
    if wrap:
        return get_html(img)
    return img
|
23 |
+
|
24 |
+
|
25 |
+
def visualize_parser(
    doc: Union[spacy.tokens.Doc, List[Dict[str, str]]],
    *,
    title: Optional[str] = None,
    key: Optional[str] = None,
    manual: bool = False,
    displacy_options: Optional[Dict] = None,
) -> None:
    """Visualizer for dependency parses.

    doc (Doc, List): The document to visualize.
    key (str): Key used for the streamlit component for selecting labels.
    title (str): The title displayed at the top of the parser visualization.
    manual (bool): Flag signifying whether the doc argument is a Doc object or a List of Dicts containing parse information.
    displacy_options (Dict): Dictionary of options to be passed to the displacy render method for generating the HTML to be rendered.
        See: https://spacy.io/api/top-level#options-dep
    """
    # NOTE(review): `key` is currently unused in this body — confirm intent.
    if displacy_options is None:
        displacy_options = dict()
    if title:
        st.header(title)
    # render the single document (kept as a list for the loop below)
    docs = [doc]
    # add selected options to options provided by user
    # `options` from `displacy_options` are overwritten by user provided
    # options from the checkboxes
    for sent in docs:
        html = displacy.render(
            sent, options=displacy_options, style="dep", manual=manual
        )
        # Double newlines seem to mess with the rendering
        html = html.replace("\n\n", "\n")
        st.write(get_svg(html), unsafe_allow_html=True)
|
57 |
+
|
58 |
+
|
59 |
+
def get_random_color(ents):
    """Assign each entity label a distinct pastel color, picked at random."""
    palette = generate_pastel_colors(len(ents))
    return {
        ent: palette.pop(random.randint(0, len(palette) - 1)) for ent in ents
    }
|
65 |
+
|
66 |
+
|
67 |
+
def floatrange(start, stop, steps):
    """Return `steps` evenly spaced floats from start to stop, inclusive."""
    if int(steps) == 1:
        return [stop]
    span = stop - start
    denom = float(steps) - 1
    return [start + float(i) * span / denom for i in range(steps)]
|
73 |
+
|
74 |
+
|
75 |
+
def hsl_to_rgb(h, s, l):
    """Convert an HSL color (components in [0, 1]) to an (r, g, b) byte tuple."""

    def channel(p, q, t):
        # Wrap the hue offset into [0, 1], then interpolate this channel.
        while t < 0.0:
            t += 1.0
        while t > 1.0:
            t -= 1.0
        if 6 * t < 1.0:
            return p + (q - p) * 6.0 * t
        if 2 * t < 1.0:
            return q
        if 3 * t < 2.0:
            return p + (q - p) * ((2.0 / 3.0) - t) * 6.0
        return p

    # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
    # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."

    # Achromatic colors (s == 0) are pure grey.
    r = g = b = l * 255
    if s != 0.0:
        q = l * (1.0 + s) if l < 0.5 else (l + s) - (s * l)
        p = 2.0 * l - q
        r = 255 * channel(p, q, h + (1.0 / 3.0))
        g = 255 * channel(p, q, h)
        b = 255 * channel(p, q, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))
|
104 |
+
|
105 |
+
|
106 |
+
def generate_pastel_colors(n):
    """Return different pastel colours.

    Input:
        n (integer) : The number of colors to return

    Output:
        A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])

    Example:
        >>> print generate_pastel_colors(5)
        ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
    """
    if n == 0:
        return []

    # Sample points around the chromatic circle in HSL space
    # (see http://en.wikipedia.org/wiki/HSL_color_space).
    # n + 1 hues are generated and the last dropped, since hue 0 == hue 1.
    start_hue = 0.0  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.9
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]
|