diff --git a/relik/__init__.py b/relik/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9da9aebaf01ea90203d6d02bc2fb82e6fa55dfd0 --- /dev/null +++ b/relik/__init__.py @@ -0,0 +1,8 @@ +from relik.inference.annotator import Relik +from pathlib import Path + +VERSION = {} # type: ignore +with open(Path(__file__).parent / "version.py", "r") as version_file: + exec(version_file.read(), VERSION) + +__version__ = VERSION["VERSION"] diff --git a/relik/common/__init__.py b/relik/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/common/__pycache__/__init__.cpython-310.pyc b/relik/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9c22f01aee443fdf371da8b384e2b573840baa5 Binary files /dev/null and b/relik/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/common/__pycache__/log.cpython-310.pyc b/relik/common/__pycache__/log.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb8a663dec46edda8ba3d810a40ec530a288fe6e Binary files /dev/null and b/relik/common/__pycache__/log.cpython-310.pyc differ diff --git a/relik/common/__pycache__/torch_utils.cpython-310.pyc b/relik/common/__pycache__/torch_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f268c2c907a9f7e0bb777823cefa6d4c34c31c0b Binary files /dev/null and b/relik/common/__pycache__/torch_utils.cpython-310.pyc differ diff --git a/relik/common/__pycache__/upload.cpython-310.pyc b/relik/common/__pycache__/upload.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f6c0a3477b08a74c9db452159f5f749ac9299c Binary files /dev/null and b/relik/common/__pycache__/upload.cpython-310.pyc differ diff --git a/relik/common/__pycache__/utils.cpython-310.pyc b/relik/common/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4979603754ccdab9a366000142b80e4833adcd5 Binary files /dev/null and b/relik/common/__pycache__/utils.cpython-310.pyc differ diff --git a/relik/common/log.py b/relik/common/log.py new file mode 100644 index 0000000000000000000000000000000000000000..c3195822c85cc389db74dd454a8b8d759cfafb5c --- /dev/null +++ b/relik/common/log.py @@ -0,0 +1,174 @@ +import logging +import os +import sys +import threading +from logging.config import dictConfig +from typing import Any, Dict, Optional + +from art import text2art, tprint +from colorama import Fore, Style, init +from rich import get_console + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +_default_log_level = logging.WARNING + +# fancy logger +_console = get_console() + + +class ColorfulFormatter(logging.Formatter): + """ + Formatter to add coloring to log messages by log type + """ + + COLORS = { + "WARNING": Fore.YELLOW, + "ERROR": Fore.RED, + "CRITICAL": Fore.RED + Style.BRIGHT, + "DEBUG": Fore.CYAN, + # "INFO": Fore.GREEN, + } + + def format(self, record): + record.rank = int(os.getenv("LOCAL_RANK", "0")) + log_message = super().format(record) + return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET + + +DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { + "version": 1, + "formatters": { + "simple": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", + }, + "colorful": { + "()": ColorfulFormatter, + "format": "[%(asctime)s] 
[%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", + }, + }, + "filters": {}, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "simple", + "filters": [], + "stream": sys.stdout, + }, + "color_console": { + "class": "logging.StreamHandler", + "formatter": "colorful", + "filters": [], + "stream": sys.stdout, + }, + }, + "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, + "loggers": { + "relik": { + "handlers": ["color_console"], + "level": "DEBUG", + "propagate": False, + }, + }, +} + + +def configure_logging(**kwargs): + """Configure with default logging""" + init() # Initialize colorama + # merge DEFAULT_LOGGING_CONFIG with kwargs + logger_config = DEFAULT_LOGGING_CONFIG + if kwargs: + logger_config.update(kwargs) + dictConfig(logger_config) + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_default_log_level) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def set_log_level(level: int, logger: logging.Logger = None) -> None: + """ + Set the log level. + Args: + level (:obj:`int`): + Logging level. + logger (:obj:`logging.Logger`): + Logger to set the log level. + """ + if not logger: + _configure_library_root_logger() + logger = _get_library_root_logger() + logger.setLevel(level) + + +def get_logger( + name: Optional[str] = None, + level: Optional[int] = None, + formatter: Optional[str] = None, + **kwargs, +) -> logging.Logger: + """ + Return a logger with the specified name. 
+ """ + + configure_logging(**kwargs) + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + + if level is not None: + set_log_level(level) + + if formatter is None: + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(name)s - %(message)s" + ) + _default_handler.setFormatter(formatter) + + return logging.getLogger(name) + + +def get_console_logger(): + return _console + + +def print_relik_text_art(text: str = "relik", font: str = "larry3d", **kwargs): + tprint(text, font=font, **kwargs) diff --git a/relik/common/torch_utils.py b/relik/common/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73946c6c60e720461f9df93fd46085b5f40a01ca --- /dev/null +++ b/relik/common/torch_utils.py @@ -0,0 +1,82 @@ +import contextlib +import tempfile + +import torch +import transformers as tr + +from relik.common.utils import is_package_available + +# check if ORT is available +if is_package_available("onnxruntime"): + from optimum.onnxruntime import ( + ORTModel, + ORTModelForCustomTasks, + ORTModelForSequenceClassification, + ORTOptimizer, + ) + from optimum.onnxruntime.configuration import AutoOptimizationConfig + +# from relik.retriever.pytorch_modules import PRECISION_MAP + + +def get_autocast_context( + device: str | torch.device, precision: str +) -> contextlib.AbstractContextManager: + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(device).split(":")[0] + + from relik.retriever.pytorch_modules import PRECISION_MAP + + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_manager = ( + contextlib.nullcontext() + if device_type_for_autocast in ["cpu", "mps"] + and PRECISION_MAP[precision] != torch.bfloat16 + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[precision], + ) + ) + ) + return autocast_manager + + +# def load_ort_optimized_hf_model( +# hf_model: tr.PreTrainedModel, +# provider: str = "CPUExecutionProvider", +# ort_model_type: callable = "ORTModelForCustomTasks", +# ) -> ORTModel: +# """ +# Load an optimized ONNX Runtime HF model. +# +# Args: +# hf_model (`tr.PreTrainedModel`): +# The HF model to optimize. +# provider (`str`, optional): +# The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider". +# +# Returns: +# `ORTModel`: The optimized HF model. +# """ +# if isinstance(hf_model, ORTModel): +# return hf_model +# temp_dir = tempfile.mkdtemp() +# hf_model.save_pretrained(temp_dir) +# ort_model = ort_model_type.from_pretrained( +# temp_dir, export=True, provider=provider, use_io_binding=True +# ) +# if is_package_available("onnxruntime"): +# optimizer = ORTOptimizer.from_pretrained(ort_model) +# optimization_config = AutoOptimizationConfig.O4() +# optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config) +# ort_model = ort_model_type.from_pretrained( +# temp_dir, +# export=True, +# provider=provider, +# use_io_binding=bool(provider == "CUDAExecutionProvider"), +# ) +# return ort_model +# else: +# raise ValueError("onnxruntime is not installed. 
Please install Ray with `pip install relik[serve]`.") diff --git a/relik/common/upload.py b/relik/common/upload.py new file mode 100644 index 0000000000000000000000000000000000000000..357b6595c11610ef416078e9a186da98204930f1 --- /dev/null +++ b/relik/common/upload.py @@ -0,0 +1,144 @@ +import argparse +import json +import logging +import os +import tempfile +import zipfile +from datetime import datetime +from pathlib import Path +from typing import Optional, Union + +import huggingface_hub + +from relik.common.log import get_logger +from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5 + +logger = get_logger(__name__, level=logging.DEBUG) + + +def create_info_file(tmpdir: Path): + logger.debug("Computing md5 of model.zip") + md5 = get_md5(tmpdir / "model.zip") + date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT) + + logger.debug("Dumping info.json file") + with (tmpdir / "info.json").open("w") as f: + json.dump(dict(md5=md5, upload_date=date), f, indent=2) + + +def zip_run( + dir_path: Union[str, os.PathLike], + tmpdir: Union[str, os.PathLike], + zip_name: str = "model.zip", +) -> Path: + logger.debug(f"zipping {dir_path} to {tmpdir}") + # creates a zip version of the provided dir_path + run_dir = Path(dir_path) + zip_path = tmpdir / zip_name + + with zipfile.ZipFile(zip_path, "w") as zip_file: + # fully zip the run directory maintaining its structure + for file in run_dir.rglob("*.*"): + if file.is_dir(): + continue + + zip_file.write(file, arcname=file.relative_to(run_dir)) + + return zip_path + + +def get_logged_in_username(): + token = huggingface_hub.HfFolder.get_token() + if token is None: + raise ValueError( + "No HuggingFace token found. You need to execute `huggingface-cli login` first!" + ) + api = huggingface_hub.HfApi() + user = api.whoami(token=token) + return user["name"] + + +def upload( + model_dir: Union[str, os.PathLike], + model_name: str, + filenames: Optional[list[str]] = None, + organization: Optional[str] = None, + repo_name: Optional[str] = None, + commit: Optional[str] = None, + archive: bool = False, +): + token = huggingface_hub.HfFolder.get_token() + if token is None: + raise ValueError( + "No HuggingFace token found. You need to execute `huggingface-cli login` first!" 
+ ) + + repo_id = repo_name or model_name + if organization is not None: + repo_id = f"{organization}/{repo_id}" + with tempfile.TemporaryDirectory() as tmpdir: + api = huggingface_hub.HfApi() + repo_url = api.create_repo( + token=token, + repo_id=repo_id, + exist_ok=True, + ) + repo = huggingface_hub.Repository( + str(tmpdir), clone_from=repo_url, use_auth_token=token + ) + + tmp_path = Path(tmpdir) + if archive: + # otherwise we zip the model_dir + logger.debug(f"Zipping {model_dir} to {tmp_path}") + zip_run(model_dir, tmp_path) + create_info_file(tmp_path) + else: + # if the user wants to upload a transformers model, we don't need to zip it + # we just need to copy the files to the tmpdir + logger.debug(f"Copying {model_dir} to {tmpdir}") + # copy only the files that are needed + if filenames is not None: + for filename in filenames: + os.system(f"cp {model_dir}/{filename} {tmpdir}") + else: + os.system(f"cp -r {model_dir}/* {tmpdir}") + + # this method automatically puts large files (>10MB) into git lfs + repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "model_dir", help="The directory of the model you want to upload" + ) + parser.add_argument("model_name", help="The model you want to upload") + parser.add_argument( + "--organization", + help="the name of the organization where you want to upload the model", + ) + parser.add_argument( + "--repo_name", + help="Optional name to use when uploading to the HuggingFace repository", + ) + parser.add_argument( + "--commit", help="Commit message to use when pushing to the HuggingFace Hub" + ) + parser.add_argument( + "--archive", + action="store_true", + help=""" + Whether to compress the model directory before uploading it. + If True, the model directory will be zipped and the zip file will be uploaded. 
+ If False, the model directory will be uploaded as is.""", + ) + return parser.parse_args() + + +def main(): + upload(**vars(parse_args())) + + +if __name__ == "__main__": + main() diff --git a/relik/common/utils.py b/relik/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d014c08b79c940d0701292ea3adb1886462cf25 --- /dev/null +++ b/relik/common/utils.py @@ -0,0 +1,610 @@ +import importlib.util +import json +import logging +import os +import shutil +import tarfile +import tempfile +from functools import partial +from hashlib import sha256 +from pathlib import Path +from typing import Any, BinaryIO, Dict, List, Optional, Union +from urllib.parse import urlparse +from zipfile import ZipFile, is_zipfile + +import huggingface_hub +import requests +import tqdm +from filelock import FileLock +from transformers.utils.hub import cached_file as hf_cached_file + +from relik.common.log import get_logger + +# name constants +WEIGHTS_NAME = "weights.pt" +ONNX_WEIGHTS_NAME = "weights.onnx" +CONFIG_NAME = "config.yaml" +LABELS_NAME = "labels.json" + +# SAPIENZANLP_USER_NAME = "sapienzanlp" +SAPIENZANLP_USER_NAME = "riccorl" +SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}" +SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = ( + f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip" +) +# path constants +HF_CACHE_DIR = Path(os.getenv("HF_HOME", Path.home() / ".cache/huggingface/hub")) +SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", HF_CACHE_DIR) +SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S" + +logger = get_logger(__name__) + + +def sapienzanlp_model_urls(model_id: str) -> str: + """ + Returns the URL for a possible SapienzaNLP valid model. + + Args: + model_id (:obj:`str`): + A SapienzaNLP model id. + + Returns: + :obj:`str`: The url for the model id. + """ + # check if there is already the namespace of the user + if "/" in model_id: + return model_id + return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id) + + +def is_package_available(package_name: str) -> bool: + """ + Check if a package is available. + + Args: + package_name (`str`): The name of the package to check. + """ + return importlib.util.find_spec(package_name) is not None + + +def load_json(path: Union[str, Path]) -> Any: + """ + Load a json file provided in input. + + Args: + path (`Union[str, Path]`): The path to the json file to load. + + Returns: + `Any`: The loaded json file. + """ + with open(path, encoding="utf8") as f: + return json.load(f) + + +def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None): + """ + Dump input to json file. + + Args: + document (`Any`): The document to dump. + path (`Union[str, Path]`): The path to dump the document to. + indent (`Optional[int]`): The indent to use for the json file. + + """ + with open(path, "w", encoding="utf8") as outfile: + json.dump(document, outfile, indent=indent) + + +def get_md5(path: Path): + """ + Get the MD5 value of a path. + """ + import hashlib + + with path.open("rb") as fin: + data = fin.read() + return hashlib.md5(data).hexdigest() + + +def file_exists(path: Union[str, os.PathLike]) -> bool: + """ + Check if the file at :obj:`path` exists. + + Args: + path (:obj:`str`, :obj:`os.PathLike`): + Path to check. + + Returns: + :obj:`bool`: :obj:`True` if the file exists. + """ + return Path(path).exists() + + +def dir_exists(path: Union[str, os.PathLike]) -> bool: + """ + Check if the directory at :obj:`path` exists. + + Args: + path (:obj:`str`, :obj:`os.PathLike`): + Path to check. 
+ + Returns: + :obj:`bool`: :obj:`True` if the directory exists. + """ + return Path(path).is_dir() + + +def is_remote_url(url_or_filename: Union[str, Path]): + """ + Returns :obj:`True` if the input path is an url. + + Args: + url_or_filename (:obj:`str`, :obj:`Path`): + path to check. + + Returns: + :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise. + + """ + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def url_to_filename(resource: str, etag: str = None) -> str: + """ + Convert a `resource` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the resources's, delimited + by a period. + """ + resource_bytes = resource.encode("utf-8") + resource_hash = sha256(resource_bytes) + filename = resource_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + return filename + + +def download_resource( + url: str, + temp_file: BinaryIO, + headers=None, +): + """ + Download remote file. + """ + + if headers is None: + headers = {} + + r = requests.get(url, stream=True, headers=headers) + r.raise_for_status() + content_length = r.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm( + unit="B", + unit_scale=True, + total=total, + desc="Downloading", + disable=logger.level in [logging.NOTSET], + ) + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def download_and_cache( + url: Union[str, Path], + cache_dir: Union[str, Path] = None, + force_download: bool = False, +): + if cache_dir is None: + cache_dir = SAPIENZANLP_CACHE_DIR + if isinstance(url, Path): + url = str(url) + + # check if cache dir exists + Path(cache_dir).mkdir(parents=True, exist_ok=True) + + # check if file is private + headers = {} + try: + r = requests.head(url, allow_redirects=False, timeout=10) + r.raise_for_status() + except requests.exceptions.HTTPError: + if r.status_code == 401: + hf_token = huggingface_hub.HfFolder.get_token() + if hf_token is None: + raise ValueError( + "You need to login to HuggingFace to download this model " + "(use the `huggingface-cli login` command)" + ) + headers["Authorization"] = f"Bearer {hf_token}" + + etag = None + try: + r = requests.head(url, allow_redirects=True, timeout=10, headers=headers) + r.raise_for_status() + etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") + # We favor a custom header indicating the etag of the linked resource, and + # we fallback to the regular etag header. + # If we don't have any of those, raise an error. + if etag is None: + raise OSError( + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + ) + # In case of a redirect, + # save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + if 300 <= r.status_code <= 399: + url = r.headers["Location"] + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + # Otherwise, our Internet connection is down. 
+ # etag is None + pass + + # get filename from the url + filename = url_to_filename(url, etag) + # get cache path to put the file + cache_path = cache_dir / filename + + # the file is already here, return it + if file_exists(cache_path) and not force_download: + logger.info( + f"{url} found in cache, set `force_download=True` to force the download" + ) + return cache_path + + cache_path = str(cache_path) + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + ".lock" + with FileLock(lock_path): + # If the download just completed while the lock was activated. + if file_exists(cache_path) and not force_download: + # Even if returning early like here, the lock will be released. + return cache_path + + temp_file_manager = partial( + tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False + ) + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise, you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + logger.info( + f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}" + ) + download_resource(url, temp_file, headers) + + logger.info(f"storing {url} in cache at {cache_path}") + os.replace(temp_file.name, cache_path) + + # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it. + umask = os.umask(0o666) + os.umask(umask) + os.chmod(cache_path, 0o666 & ~umask) + + logger.info(f"creating metadata file for {cache_path}") + meta = {"url": url} # , "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) + + return cache_path + + +def download_from_hf( + path_or_repo_id: Union[str, Path], + filenames: List[str], + cache_dir: Union[str, Path] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + subfolder: str = "", + repo_type: str = "model", +): + if isinstance(path_or_repo_id, Path): + path_or_repo_id = str(path_or_repo_id) + + downloaded_paths = [] + for filename in filenames: + downloaded_path = hf_cached_file( + path_or_repo_id, + filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + ) + downloaded_paths.append(downloaded_path) + + # we want the folder where the files are downloaded + # the best guess is the parent folder of the first file + probably_the_folder = Path(downloaded_paths[0]).parent + return probably_the_folder + + +def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str: + """ + Resolve a model name or directory to a model archive name or directory. + + Args: + model_name_or_dir (:obj:`str` or :obj:`os.PathLike`): + A model name or directory. + + Returns: + :obj:`str`: The model archive name or directory. 
+ """ + if is_remote_url(model_name_or_dir): + # if model_name_or_dir is a URL + # download it and try to load + model_archive = model_name_or_dir + elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file(): + # if model_name_or_dir is a local directory or + # an archive file try to load it + model_archive = model_name_or_dir + else: + # probably model_name_or_dir is a sapienzanlp model id + # guess the url and try to download + model_name_or_dir_ = model_name_or_dir + # raise ValueError(f"Providing a model id is not supported yet.") + model_archive = sapienzanlp_model_urls(model_name_or_dir_) + + return model_archive + + +def from_cache( + url_or_filename: Union[str, Path], + cache_dir: Union[str, Path] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + subfolder: str = "", + filenames: Optional[List[str]] = None, +) -> Path: + """ + Given something that could be either a local path or a URL (or a SapienzaNLP model id), + determine which one and return a path to the corresponding file. + + Args: + url_or_filename (:obj:`str` or :obj:`Path`): + A path to a local file or a URL (or a SapienzaNLP model id). + cache_dir (:obj:`str` or :obj:`Path`, `optional`): + Path to a directory in which a downloaded file will be cached. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to re-download the file even if it already exists. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received files. Attempts to resume the download if such a file + exists. + proxies (:obj:`Dict[str, str]`, `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (:obj:`Union[bool, str]`, `optional`): + Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from + :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token. + revision (:obj:`str`, `optional`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to raise an error if the file to be downloaded is local. + subfolder (:obj:`str`, `optional`): + In case the relevant file is in a subfolder of the URL, specify it here. + filenames (:obj:`List[str]`, `optional`): + List of filenames to look for in the directory structure. + + Returns: + :obj:`Path`: Path to the cached file. 
+ """ + + url_or_filename = model_name_or_path_resolver(url_or_filename) + + if cache_dir is None: + cache_dir = SAPIENZANLP_CACHE_DIR + + if file_exists(url_or_filename): + logger.info(f"{url_or_filename} is a local path or file") + output_path = url_or_filename + elif is_remote_url(url_or_filename): + # URL, so get it from the cache (downloading if necessary) + output_path = download_and_cache( + url_or_filename, + cache_dir=cache_dir, + force_download=force_download, + ) + else: + if filenames is None: + filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME] + output_path = download_from_hf( + url_or_filename, + filenames, + cache_dir, + force_download, + resume_download, + proxies, + use_auth_token, + revision, + local_files_only, + subfolder, + ) + + # if is_hf_hub_url(url_or_filename): + # HuggingFace Hub + # output_path = hf_hub_download_url(url_or_filename) + # elif is_remote_url(url_or_filename): + # # URL, so get it from the cache (downloading if necessary) + # output_path = download_and_cache( + # url_or_filename, + # cache_dir=cache_dir, + # force_download=force_download, + # ) + # elif file_exists(url_or_filename): + # logger.info(f"{url_or_filename} is a local path or file") + # # File, and it exists. + # output_path = url_or_filename + # elif urlparse(url_or_filename).scheme == "": + # # File, but it doesn't exist. + # raise EnvironmentError(f"file {url_or_filename} not found") + # else: + # # Something unknown + # raise ValueError( + # f"unable to parse {url_or_filename} as a URL or as a local path" + # ) + + if dir_exists(output_path) or ( + not is_zipfile(output_path) and not tarfile.is_tarfile(output_path) + ): + return Path(output_path) + + # Path where we extract compressed archives + # for now it will extract it in the same folder + # maybe implement extraction in the sapienzanlp folder + # when using local archive path? + logger.info("Extracting compressed archive") + output_dir, output_file = os.path.split(output_path) + output_extract_dir_name = output_file.replace(".", "-") + "-extracted" + output_path_extracted = os.path.join(output_dir, output_extract_dir_name) + + # already extracted, do not extract + if ( + os.path.isdir(output_path_extracted) + and os.listdir(output_path_extracted) + and not force_download + ): + return Path(output_path_extracted) + + # Prevent parallel extractions + lock_path = output_path + ".lock" + with FileLock(lock_path): + shutil.rmtree(output_path_extracted, ignore_errors=True) + os.makedirs(output_path_extracted) + if is_zipfile(output_path): + with ZipFile(output_path, "r") as zip_file: + zip_file.extractall(output_path_extracted) + zip_file.close() + elif tarfile.is_tarfile(output_path): + tar_file = tarfile.open(output_path) + tar_file.extractall(output_path_extracted) + tar_file.close() + else: + raise EnvironmentError( + f"Archive format of {output_path} could not be identified" + ) + + # remove lock file, is it safe? + os.remove(lock_path) + + return Path(output_path_extracted) + + +def is_str_a_path(maybe_path: str) -> bool: + """ + Check if a string is a path. + + Args: + maybe_path (`str`): The string to check. + + Returns: + `bool`: `True` if the string is a path, `False` otherwise. 
+ """ + # first check if it is a path + if Path(maybe_path).exists(): + return True + # check if it is a relative path + if Path(os.path.join(os.getcwd(), maybe_path)).exists(): + return True + # otherwise it is not a path + return False + + +def relative_to_absolute_path(path: str) -> os.PathLike: + """ + Convert a relative path to an absolute path. + + Args: + path (`str`): The relative path to convert. + + Returns: + `os.PathLike`: The absolute path. + """ + if not is_str_a_path(path): + raise ValueError(f"{path} is not a path") + if Path(path).exists(): + return Path(path).absolute() + if Path(os.path.join(os.getcwd(), path)).exists(): + return Path(os.path.join(os.getcwd(), path)).absolute() + raise ValueError(f"{path} is not a path") + + +def to_config(object_to_save: Any) -> Dict[str, Any]: + """ + Convert an object to a dictionary. + + Returns: + `Dict[str, Any]`: The dictionary representation of the object. + """ + + def obj_to_dict(obj): + match obj: + case dict(): + data = {} + for k, v in obj.items(): + data[k] = obj_to_dict(v) + return data + + case list() | tuple(): + return [obj_to_dict(x) for x in obj] + + case object(__dict__=_): + data = { + "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}", + } + for k, v in obj.__dict__.items(): + if not k.startswith("_"): + data[k] = obj_to_dict(v) + return data + + case _: + return obj + + return obj_to_dict(object_to_save) + + +def get_callable_from_string(callable_fn: str) -> Any: + """ + Get a callable from a string. + + Args: + callable_fn (`str`): + The string representation of the callable. + + Returns: + `Any`: The callable. + """ + # separate the function name from the module name + module_name, function_name = callable_fn.rsplit(".", 1) + # import the module + module = importlib.import_module(module_name) + # get the function + return getattr(module, function_name) diff --git a/relik/inference/__init__.py b/relik/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/__pycache__/__init__.cpython-310.pyc b/relik/inference/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6014514f4cc9ed561c684021c3c09b8ca42f886a Binary files /dev/null and b/relik/inference/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/inference/__pycache__/annotator.cpython-310.pyc b/relik/inference/__pycache__/annotator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07dc2f6732bba63967ae2d34ff208e9ea270560b Binary files /dev/null and b/relik/inference/__pycache__/annotator.cpython-310.pyc differ diff --git a/relik/inference/annotator.py b/relik/inference/annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..436a9693bf57ad6114dedab4f3e2a2ddb44e78aa --- /dev/null +++ b/relik/inference/annotator.py @@ -0,0 +1,840 @@ +import logging +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf +from pprintpp import pformat + +from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter +from relik.common.log import get_logger +from relik.common.upload import get_logged_in_username, upload +from relik.common.utils import CONFIG_NAME, from_cache +from relik.inference.data.objects import ( + AnnotationType, + RelikOutput, + Span, + TaskType, + Triples, +) +from 
relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter +from relik.inference.data.splitters.spacy_sentence_splitter import SpacySentenceSplitter +from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter +from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer +from relik.inference.data.window.manager import WindowManager +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.pytorch_modules.base import RelikReaderBase +from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction +from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.indexers.document import Document +from relik.retriever.pytorch_modules import PRECISION_MAP +from relik.retriever.pytorch_modules.model import GoldenRetriever + +# set tokenizers parallelism to False + +os.environ["TOKENIZERS_PARALLELISM"] = os.getenv("TOKENIZERS_PARALLELISM", "false") + +LOG_QUERY = os.getenv("RELIK_LOG_QUERY_ON_FILE", "false").lower() == "true" + +logger = get_logger(__name__, level=logging.INFO) +file_logger = None +if LOG_QUERY: + RELIK_LOG_PATH = Path(__file__).parent.parent.parent / "relik.log" + # create file handler which logs even debug messages + fh = logging.FileHandler(RELIK_LOG_PATH) + fh.setLevel(logging.INFO) + file_logger = get_logger("relik", level=logging.INFO) + file_logger.addHandler(fh) + + +class Relik: + """ + Relik main class. It is a wrapper around a retriever and a reader. + + Args: + retriever (:obj:`GoldenRetriever`): + The retriever to use. + reader (:obj:`RelikReaderBase`): + The reader to use. + document_index (:obj:`BaseDocumentIndex`, `optional`): + The document index to use. If `None`, the retriever's document index will be used. + device (`str`, `optional`, defaults to `cpu`): + The device to use for both the retriever and the reader. + retriever_device (`str`, `optional`, defaults to `None`): + The device to use for the retriever. If `None`, the `device` argument will be used. + document_index_device (`str`, `optional`, defaults to `None`): + The device to use for the document index. If `None`, the `device` argument will be used. + reader_device (`str`, `optional`, defaults to `None`): + The device to use for the reader. If `None`, the `device` argument will be used. + precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `32`): + The precision to use for both the retriever and the reader. + retriever_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`): + The precision to use for the retriever. If `None`, the `precision` argument will be used. + document_index_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`): + The precision to use for the document index. If `None`, the `precision` argument will be used. + reader_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`): + The precision to use for the reader. If `None`, the `precision` argument will be used. + metadata_fields (`list[str]`, `optional`, defaults to `None`): + The fields to add to the candidates for the reader. + top_k (`int`, `optional`, defaults to `None`): + The number of candidates to retrieve for each window. + window_size (`int`, `optional`, defaults to `None`): + The size of the window. If `None`, the whole text will be annotated. + window_stride (`int`, `optional`, defaults to `None`): + The stride of the window. 
If `None`, there will be no overlap between windows. + **kwargs: + Additional keyword arguments to pass to the retriever and the reader. + """ + + def __init__( + self, + retriever: GoldenRetriever | DictConfig | Dict | None = None, + reader: RelikReaderBase | DictConfig | None = None, + device: str | None = None, + retriever_device: str | None = None, + document_index_device: str | None = None, + reader_device: str | None = None, + precision: int | str | torch.dtype | None = None, + retriever_precision: int | str | torch.dtype | None = None, + document_index_precision: int | str | torch.dtype | None = None, + reader_precision: int | str | torch.dtype | None = None, + task: TaskType | str = TaskType.SPAN, + metadata_fields: list[str] | None = None, + top_k: int | None = None, + window_size: int | str | None = None, + window_stride: int | None = None, + retriever_kwargs: Dict[str, Any] | None = None, + reader_kwargs: Dict[str, Any] | None = None, + **kwargs, + ) -> None: + # parse task into a TaskType + if isinstance(task, str): + try: + task = TaskType(task.lower()) + except ValueError: + raise ValueError( + f"Task `{task}` not recognized. " + f"Please choose one of {list(TaskType)}." + ) + self.task = task + + # organize devices + if device is not None: + if retriever_device is None: + retriever_device = device + if document_index_device is None: + document_index_device = device + if reader_device is None: + reader_device = device + + # organize precision + if precision is not None: + if retriever_precision is None: + retriever_precision = precision + if document_index_precision is None: + document_index_precision = precision + if reader_precision is None: + reader_precision = precision + + # retriever + self.retriever: Dict[TaskType, GoldenRetriever] = { + TaskType.SPAN: None, + TaskType.TRIPLET: None, + } + + if retriever: + # check retriever type, it can be a GoldenRetriever, a DictConfig or a Dict + if not isinstance(retriever, (GoldenRetriever, DictConfig, Dict)): + raise ValueError( + f"`retriever` must be a `GoldenRetriever`, a `DictConfig` or " + f"a `Dict`, got `{type(retriever)}`." + ) + + # we need to check weather the DictConfig is a DictConfig for an instance of GoldenRetriever + # or a primitive Dict + if isinstance(retriever, DictConfig): + # then it is probably a primitive Dict + if "_target_" not in retriever: + retriever = OmegaConf.to_container(retriever, resolve=True) + # convert the key to TaskType + try: + retriever = { + TaskType(k.lower()): v for k, v in retriever.items() + } + except ValueError as e: + raise ValueError( + f"Please choose a valid task type (one of {list(TaskType)}) for each retriever." 
+ ) from e + + if isinstance(retriever, Dict): + # convert the key to TaskType + retriever = {TaskType(k): v for k, v in retriever.items()} + else: + retriever = {task: retriever} + + # instantiate each retriever + if self.task in [TaskType.SPAN, TaskType.BOTH]: + self.retriever[TaskType.SPAN] = self._instantiate_retriever( + retriever[TaskType.SPAN], + retriever_device, + retriever_precision, + None, + document_index_device, + document_index_precision, + ) + if self.task in [TaskType.TRIPLET, TaskType.BOTH]: + self.retriever[TaskType.TRIPLET] = self._instantiate_retriever( + retriever[TaskType.TRIPLET], + retriever_device, + retriever_precision, + None, + document_index_device, + document_index_precision, + ) + + # clean up None retrievers from the dictionary + self.retriever = { + task_type: r for task_type, r in self.retriever.items() if r is not None + } + # torch compile + # self.retriever = {task_type: torch.compile(r, backend="onnxrt") for task_type, r in self.retriever.items()} + + # reader + self.reader: RelikReaderBase | None = None + if reader: + reader = ( + hydra.utils.instantiate( + reader, + device=reader_device, + precision=reader_precision, + ) + if isinstance(reader, DictConfig) + else reader + ) + reader.training = False + reader.eval() + if reader_device is not None: + logger.info(f"Moving reader to `{reader_device}`.") + reader.to(reader_device) + if reader_precision is not None and reader.precision != PRECISION_MAP[reader_precision]: + logger.info( + f"Setting precision of reader to `{PRECISION_MAP[reader_precision]}`." + ) + reader.to(PRECISION_MAP[reader_precision]) + self.reader = reader + # self.reader = torch.compile(self.reader, backend="tvm") + + # windowization stuff + self.tokenizer = SpacyTokenizer(language="en") # TODO: parametrize? + self.sentence_splitter: BaseSentenceSplitter | None = None + self.window_manager: WindowManager | None = None + + if metadata_fields is None: + metadata_fields = [] + self.metadata_fields = metadata_fields + + # inference params + self.top_k = top_k + self.window_size = window_size + self.window_stride = window_stride + + @staticmethod + def _instantiate_retriever( + retriever, + retriever_device, + retriever_precision, + document_index, + document_index_device, + document_index_precision, + ): + if not isinstance(retriever, GoldenRetriever): + # convert to DictConfig + retriever = hydra.utils.instantiate( + OmegaConf.create(retriever), + device=retriever_device, + precision=retriever_precision, + index_device=document_index_device, + index_precision=document_index_precision, + ) + retriever.training = False + retriever.eval() + if document_index is not None: + if retriever.document_index is not None: + logger.info( + "The Retriever already has a document index, replacing it with the provided one." + "If you want to keep using the old one, please do not provide a document index." + ) + retriever.document_index = document_index + # we override the device and the precision of the document index if provided + if document_index_device is not None: + logger.info(f"Moving document index to `{document_index_device}`.") + retriever.document_index.to(document_index_device) + if document_index_precision is not None: + logger.info( + f"Setting precision of document index to `{PRECISION_MAP[document_index_precision]}`." 
+ ) + retriever.document_index.to(PRECISION_MAP[document_index_precision]) + # retriever.document_index = document_index + # now we can move the retriever to the right device and set the precision + if retriever_device is not None: + logger.info(f"Moving retriever to `{retriever_device}`.") + retriever.to(retriever_device) + if retriever_precision is not None: + logger.info( + f"Setting precision of retriever to `{PRECISION_MAP[retriever_precision]}`." + ) + retriever.to(PRECISION_MAP[retriever_precision]) + return retriever + + def __call__( + self, + text: str | List[str] | None = None, + windows: List[RelikReaderSample] | None = None, + candidates: List[str] + | List[Document] + | Dict[TaskType, List[Document]] + | None = None, + mentions: List[List[int]] | List[List[List[int]]] | None = None, + top_k: int | None = None, + window_size: int | None = None, + window_stride: int | None = None, + is_split_into_words: bool = False, + retriever_batch_size: int | None = 32, + reader_batch_size: int | None = 32, + return_also_windows: bool = False, + annotation_type: str | AnnotationType = AnnotationType.CHAR, + progress_bar: bool = False, + **kwargs, + ) -> Union[RelikOutput, list[RelikOutput]]: + """ + Annotate a text with entities. + + Args: + text (`str` or `list`): + The text to annotate. If a list is provided, each element of the list + will be annotated separately. + candidates (`list[str]`, `list[Document]`, `optional`, defaults to `None`): + The candidates to use for the reader. If `None`, the candidates will be + retrieved from the retriever. + mentions (`list[list[int]]` or `list[list[list[int]]]`, `optional`, defaults to `None`): + The mentions to use for the reader. If `None`, the mentions will be + predicted by the reader. + top_k (`int`, `optional`, defaults to `None`): + The number of candidates to retrieve for each window. + window_size (`int`, `optional`, defaults to `None`): + The size of the window. If `None`, the whole text will be annotated. + window_stride (`int`, `optional`, defaults to `None`): + The stride of the window. If `None`, there will be no overlap between windows. + retriever_batch_size (`int`, `optional`, defaults to `None`): + The batch size to use for the retriever. The whole input is the batch for the retriever. + reader_batch_size (`int`, `optional`, defaults to `None`): + The batch size to use for the reader. The whole input is the batch for the reader. + return_also_windows (`bool`, `optional`, defaults to `False`): + Whether to return the windows in the output. + annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`): + The type of annotation to return. If `char`, the spans will be in terms of + character offsets. If `word`, the spans will be in terms of word offsets. + **kwargs: + Additional keyword arguments to pass to the retriever and the reader. + + Returns: + `RelikOutput` or `list[RelikOutput]`: + The annotated text. If a list was provided as input, a list of + `RelikOutput` objects will be returned. + """ + + if text is None and windows is None: + raise ValueError( + "Either `text` or `windows` must be provided. Both are `None`." + ) + + if isinstance(annotation_type, str): + try: + annotation_type = AnnotationType(annotation_type) + except ValueError: + raise ValueError( + f"Annotation type {annotation_type} not recognized. " + f"Please choose one of {list(AnnotationType)}." 
+ ) + + if top_k is None: + top_k = self.top_k or 100 + if window_size is None: + window_size = self.window_size + if window_stride is None: + window_stride = self.window_stride + + if text: + if isinstance(text, str): + text = [text] + if mentions is not None: + mentions = [mentions] + if file_logger is not None: + file_logger.info("Annotating the following text:") + for t in text: + file_logger.info(f" {t}") + + if self.window_manager is None: + if window_size == "none": + self.sentence_splitter = BlankSentenceSplitter() + elif window_size == "sentence": + self.sentence_splitter = SpacySentenceSplitter() + else: + self.sentence_splitter = WindowSentenceSplitter( + window_size=window_size, window_stride=window_stride + ) + self.window_manager = WindowManager( + self.tokenizer, self.sentence_splitter + ) + + if ( + window_size not in ["sentence", "none"] + and window_stride is not None + and window_size < window_stride + ): + raise ValueError( + f"Window size ({window_size}) must be greater than window stride ({window_stride})" + ) + + if windows is None: + # windows were provided, use them + windows, blank_windows = self.window_manager.create_windows( + text, + window_size, + window_stride, + is_split_into_words=is_split_into_words, + mentions=mentions + ) + else: + blank_windows = [] + text = {w.doc_id: w.text for w in windows} + + if candidates is not None and any( + r is not None for r in self.retriever.values() + ): + logger.info( + "Both candidates and a retriever were provided. " + "Retriever will be ignored." + ) + + windows_candidates = {TaskType.SPAN: None, TaskType.TRIPLET: None} + if candidates is not None: + # again, check if candidates is a dict + if isinstance(candidates, Dict): + if self.task not in candidates: + raise ValueError( + f"Task `{self.task}` not found in `candidates`." + f"Please choose one of {list(TaskType)}." + ) + else: + candidates = {self.task: candidates} + + for task_type, _candidates in candidates.items(): + if isinstance(_candidates, list): + _candidates = [ + [ + c if isinstance(c, Document) else Document(c) + for c in _candidates[w.doc_id] + ] + for w in windows + ] + windows_candidates[task_type] = _candidates + + else: + # retrieve candidates first + if self.retriever is None: + raise ValueError( + "No retriever was provided, please provide a retriever or candidates." 
+ ) + start_retr = time.time() + for task_type, retriever in self.retriever.items(): + retriever_out = retriever.retrieve( + [w.text for w in windows], + text_pair=[w.doc_topic.text if w.doc_topic is not None else None for w in windows], + k=top_k, + batch_size=retriever_batch_size, + progress_bar=progress_bar, + **kwargs, + ) + windows_candidates[task_type] = [ + [p.document for p in predictions] for predictions in retriever_out + ] + end_retr = time.time() + logger.info(f"Retrieval took {end_retr - start_retr} seconds.") + + # clean up None's + windows_candidates = { + t: c for t, c in windows_candidates.items() if c is not None + } + + # add passage to the windows + for task_type, task_candidates in windows_candidates.items(): + for window, candidates in zip(windows, task_candidates): + # construct the candidates for the reader + formatted_candidates = [] + for candidate in candidates: + window_candidate_text = candidate.text + for field in self.metadata_fields: + window_candidate_text += f"{candidate.metadata.get(field, '')}" + formatted_candidates.append(window_candidate_text) + # create a member for the windows that is named like the task + setattr(window, f"{task_type.value}_candidates", formatted_candidates) + + for task_type, task_candidates in windows_candidates.items(): + for window in blank_windows: + setattr(window, f"{task_type.value}_candidates", []) + setattr(window, "predicted_spans", []) + setattr(window, "predicted_triples", []) + if self.reader is not None: + start_read = time.time() + windows = self.reader.read( + samples=windows, + max_batch_size=reader_batch_size, + annotation_type=annotation_type, + progress_bar=progress_bar, + **kwargs, + ) + end_read = time.time() + logger.info(f"Reading took {end_read - start_read} seconds.") + # TODO: check merging behavior without a reader + # do we want to merge windows if there is no reader? 
+ + if self.window_size is not None and self.window_size not in ["sentence", "none"]: + start_w = time.time() + windows = windows + blank_windows + windows.sort(key=lambda x: (x.doc_id, x.offset)) + merged_windows = self.window_manager.merge_windows(windows) + end_w = time.time() + logger.info(f"Merging took {end_w - start_w} seconds.") + else: + merged_windows = windows + else: + windows = windows + blank_windows + windows.sort(key=lambda x: (x.doc_id, x.offset)) + merged_windows = windows + + # transform predictions into RelikOutput objects + output = [] + for w in merged_windows: + span_labels = [] + triples_labels = [] + # span extraction should always be present + if getattr(w, "predicted_spans", None) is not None: + span_labels = sorted( + [ + Span(start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]) + if annotation_type == AnnotationType.CHAR + else Span(start=ss, end=se, label=sl, text=w.words[ss:se]) + for ss, se, sl in w.predicted_spans + ], + key=lambda x: x.start, + ) + # triple extraction is optional, if here add it + if getattr(w, "predicted_triples", None) is not None: + triples_labels = [ + Triples( + subject=span_labels[subj], + label=label, + object=span_labels[obj], + confidence=conf, + ) + for subj, label, obj, conf in w.predicted_triples + ] + # create the output + sample_output = RelikOutput( + text=text[w.doc_id], + tokens=w.words, + spans=span_labels, + triples=triples_labels, + candidates={ + task_type: [ + r.document_index.documents.get_document_from_text(c) + for c in getattr(w, f"{task_type.value}_candidates", []) + if r.document_index.documents.get_document_from_text(c) is not None + ] + for task_type, r in self.retriever.items() + }, + ) + output.append(sample_output) + + # add windows to the output if requested + # do we want to force windows to be returned if there is no reader? + if return_also_windows: + for i, sample_output in enumerate(output): + sample_output.windows = [w for w in windows if w.doc_id == i] + + # if only one text was provided, return a single RelikOutput object + if len(output) == 1: + return output[0] + + return output + + @classmethod + def from_pretrained( + cls, + model_name_or_dir: Union[str, os.PathLike], + config_file_name: str = CONFIG_NAME, + *args, + **kwargs, + ) -> "Relik": + """ + Instantiate a `Relik` from a pretrained model. + + Args: + model_name_or_dir (`str` or `os.PathLike`): + The name or path of the model to load. + config_file_name (`str`, `optional`, defaults to `config.yaml`): + The name of the configuration file to load. + *args: + Additional positional arguments to pass to `OmegaConf.merge`. + **kwargs: + Additional keyword arguments to pass to `OmegaConf.merge`. + + Returns: + `Relik`: + The instantiated `Relik`. + + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + + model_dir = from_cache( + model_name_or_dir, + filenames=[config_file_name], + cache_dir=cache_dir, + force_download=force_download, + ) + + config_path = model_dir / config_file_name + if not config_path.exists(): + raise FileNotFoundError( + f"Model configuration file not found at {config_path}." + ) + + # overwrite config with config_kwargs + config = OmegaConf.load(config_path) + # if kwargs is not None: + config = OmegaConf.merge(config, OmegaConf.create(kwargs)) + # do we want to print the config? 
I like it + logger.info(f"Loading Relik from {model_name_or_dir}") + logger.info(pformat(OmegaConf.to_container(config))) + + # load relik from config + relik = hydra.utils.instantiate(config, _recursive_=False, *args) + + return relik + + def save_pretrained( + self, + output_dir: Union[str, os.PathLike], + config: Optional[Dict[str, Any]] = None, + config_file_name: Optional[str] = None, + save_weights: bool = False, + push_to_hub: bool = False, + model_id: Optional[str] = None, + organization: Optional[str] = None, + repo_name: Optional[str] = None, + retriever_model_id: Optional[str] = None, + reader_model_id: Optional[str] = None, + **kwargs, + ): + """ + Save the configuration of Relik to the specified directory as a YAML file. + + Args: + output_dir (`str`): + The directory to save the configuration file to. + config (`Optional[Dict[str, Any]]`, `optional`): + The configuration to save. If `None`, the current configuration will be + saved. Defaults to `None`. + config_file_name (`Optional[str]`, `optional`): + The name of the configuration file. Defaults to `config.yaml`. + save_weights (`bool`, `optional`): + Whether to save the weights of the model. Defaults to `False`. + push_to_hub (`bool`, `optional`): + Whether to push the saved model to the hub. Defaults to `False`. + model_id (`Optional[str]`, `optional`): + The id of the model to push to the hub. If `None`, the name of the + directory will be used. Defaults to `None`. + organization (`Optional[str]`, `optional`): + The organization to push the model to. Defaults to `None`. + repo_name (`Optional[str]`, `optional`): + The name of the repository to push the model to. Defaults to `None`. + retriever_model_id (`Optional[str]`, `optional`): + The id of the retriever model to push to the hub. If `None`, the name of the + directory will be used. Defaults to `None`. + reader_model_id (`Optional[str]`, `optional`): + The id of the reader model to push to the hub. If `None`, the name of the + directory will be used. Defaults to `None`. + **kwargs: + Additional keyword arguments to pass to `OmegaConf.save`. 
+ """ + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + retrievers_names: Dict[TaskType, Dict | None] = { + TaskType.SPAN: { + "question_encoder_name": None, + "passage_encoder_name": None, + "document_index_name": None, + }, + TaskType.TRIPLET: { + "question_encoder_name": None, + "passage_encoder_name": None, + "document_index_name": None, + }, + } + + if save_weights: + # save weights + # retriever + model_id = model_id or output_dir.name + retriever_model_id = retriever_model_id or f"retriever-{model_id}" + for task_type, retriever in self.retriever.items(): + if retriever is None: + continue + task_retriever_model_id = f"{retriever_model_id}-{task_type.value}" + question_encoder_name = f"{task_retriever_model_id}-question-encoder" + passage_encoder_name = f"{task_retriever_model_id}-passage-encoder" + document_index_name = f"{task_retriever_model_id}-index" + logger.info( + f"Saving retriever to {output_dir / task_retriever_model_id}" + ) + retriever.save_pretrained( + output_dir / task_retriever_model_id, + question_encoder_name=question_encoder_name, + passage_encoder_name=passage_encoder_name, + document_index_name=document_index_name, + push_to_hub=push_to_hub, + organization=organization, + **kwargs, + ) + retrievers_names[task_type] = { + "reader_model_id": task_retriever_model_id, + "question_encoder_name": question_encoder_name, + "passage_encoder_name": passage_encoder_name, + "document_index_name": document_index_name, + } + + # reader + reader_model_id = reader_model_id or f"reader-{model_id}" + logger.info(f"Saving reader to {output_dir / reader_model_id}") + self.reader.save_pretrained( + output_dir / reader_model_id, + push_to_hub=push_to_hub, + organization=organization, + **kwargs, + ) + + if push_to_hub: + user = organization or get_logged_in_username() + # we need to update the config with the model ids that will + # result from the push to hub + for task_type, retriever_names in retrievers_names.items(): + retriever_names[ + "question_encoder_name" + ] = f"{user}/{retriever_names['question_encoder_name']}" + retriever_names[ + "passage_encoder_name" + ] = f"{user}/{retriever_names['passage_encoder_name']}" + retriever_names[ + "document_index_name" + ] = f"{user}/{retriever_names['document_index_name']}" + # question_encoder_name = f"{user}/{question_encoder_name}" + # passage_encoder_name = f"{user}/{passage_encoder_name}" + # document_index_name = f"{user}/{document_index_name}" + reader_model_id = f"{user}/{reader_model_id}" + else: + for task_type, retriever_names in retrievers_names.items(): + retriever_names["question_encoder_name"] = ( + output_dir / retriever_names["question_encoder_name"] + ) + retriever_names["passage_encoder_name"] = ( + output_dir / retriever_names["passage_encoder_name"] + ) + retriever_names["document_index_name"] = ( + output_dir / retriever_names["document_index_name"] + ) + reader_model_id = output_dir / reader_model_id + else: + # save config only + for task_type, retriever_names in retrievers_names.items(): + retriever = self.retriever.get(task_type, None) + if retriever is None: + continue + retriever_names[ + "question_encoder_name" + ] = retriever.question_encoder.name_or_path + retriever_names[ + "passage_encoder_name" + ] = retriever.passage_encoder.name_or_path + retriever_names[ + "document_index_name" + ] = retriever.document_index.name_or_path + + reader_model_id = self.reader.name_or_path + + if config is None: + # create a default config + config = { + 
"_target_": f"{self.__class__.__module__}.{self.__class__.__name__}" + } + if self.retriever is not None: + config["retriever"] = {} + for task_type, retriever in self.retriever.items(): + if retriever is None: + continue + config["retriever"][task_type.value] = { + "_target_": f"{retriever.__class__.__module__}.{retriever.__class__.__name__}", + } + if retriever.question_encoder is not None: + config["retriever"][task_type.value][ + "question_encoder" + ] = retrievers_names[task_type]["question_encoder_name"] + if ( + retriever.passage_encoder is not None + and not retriever.passage_encoder_is_question_encoder + ): + config["retriever"][task_type.value][ + "passage_encoder" + ] = retrievers_names[task_type]["passage_encoder_name"] + if retriever.document_index is not None: + config["retriever"][task_type.value][ + "document_index" + ] = retrievers_names[task_type]["document_index_name"] + if self.reader is not None: + config["reader"] = { + "_target_": f"{self.reader.__class__.__module__}.{self.reader.__class__.__name__}", + "transformer_model": reader_model_id, + } + + # these are model-specific and should be saved + config["task"] = self.task + config["metadata_fields"] = self.metadata_fields + config["top_k"] = self.top_k + config["window_size"] = self.window_size + config["window_stride"] = self.window_stride + + config_file_name = config_file_name or CONFIG_NAME + + logger.info(f"Saving relik config to {output_dir / config_file_name}") + # pretty print the config + logger.info(pformat(config)) + OmegaConf.save(config, output_dir / config_file_name) + + if push_to_hub: + # push to hub + logger.info("Pushing to hub") + model_id = model_id or output_dir.name + upload( + output_dir, + model_id, + filenames=[config_file_name], + organization=organization, + repo_name=repo_name, + ) diff --git a/relik/inference/data/__init__.py b/relik/inference/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/data/__pycache__/__init__.cpython-310.pyc b/relik/inference/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06df015c132fc0b51ab0c4f3c6a2a00bbf71b829 Binary files /dev/null and b/relik/inference/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/inference/data/__pycache__/objects.cpython-310.pyc b/relik/inference/data/__pycache__/objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f82fc9965835219c466fc11161a9c15402805b1c Binary files /dev/null and b/relik/inference/data/__pycache__/objects.cpython-310.pyc differ diff --git a/relik/inference/data/objects.py b/relik/inference/data/objects.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a945b8e38e4b05f290bb6e90db0d9b9be8cabd --- /dev/null +++ b/relik/inference/data/objects.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, NamedTuple, Optional + +from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample +from relik.retriever.indexers.document import Document + + +@dataclass +class Word: + """ + A word representation that includes text, index in the sentence, POS tag, lemma, + dependency relation, and similar information. + + # Parameters + text : `str`, optional + The text representation. + index : `int`, optional + The word offset in the sentence. + lemma : `str`, optional + The lemma of this word. 
+ pos : `str`, optional + The coarse-grained part of speech of this word. + dep : `str`, optional + The dependency relation for this word. + + input_id : `int`, optional + Integer representation of the word, used to pass it to a model. + token_type_id : `int`, optional + Token type id used by some transformers. + attention_mask: `int`, optional + Attention mask used by transformers, indicates to the model which tokens should + be attended to, and which should not. + """ + + text: str + i: int + idx: Optional[int] = None + idx_end: Optional[int] = None + # preprocessing fields + lemma: Optional[str] = None + pos: Optional[str] = None + dep: Optional[str] = None + head: Optional[int] = None + + def __str__(self): + return self.text + + def __repr__(self): + return self.__str__() + + +class Span(NamedTuple): + start: int + end: int + label: str + text: str + + +class Triples(NamedTuple): + subject: Span + label: str + object: Span + confidence: float + +@dataclass +class RelikOutput: + text: str + tokens: List[str] + spans: List[Span] + triples: List[Triples] + candidates: Dict[TaskType, List[Document]] + windows: Optional[List[RelikReaderSample]] = None + + +from enum import Enum + + +class AnnotationType(Enum): + CHAR = "char" + WORD = "word" + + +class TaskType(Enum): + SPAN = "span" + TRIPLET = "triplet" + BOTH = "both" diff --git a/relik/inference/data/splitters/__init__.py b/relik/inference/data/splitters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc b/relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee9ce7a008635178d2c91efe56d805e513ded1a8 Binary files /dev/null and b/relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc b/relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1dd4cb2ea1a9e84936437dc5915d4bddd2ac0c2 Binary files /dev/null and b/relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc differ diff --git a/relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc b/relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..358d2d27ea8fe229150b868ec81e706132924fc4 Binary files /dev/null and b/relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc differ diff --git a/relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc b/relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9f5d397af66de905b46d00993d5a267e5bbd2cf Binary files /dev/null and b/relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc differ diff --git a/relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc b/relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be8bfa88a58145fe4268b19e7b753a610eb75564 Binary files /dev/null and b/relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc differ diff 
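# --- Editor's illustrative sketch (not part of the patch): constructing the
# output objects defined in objects.py by hand. The labels ("Person",
# "sold_to") and the character offsets are made-up example values.
from relik.inference.data.objects import AnnotationType, Span, TaskType, Triples

subject = Span(start=0, end=4, label="Person", text="Mary")
obj = Span(start=22, end=26, label="Person", text="John")
triple = Triples(subject=subject, label="sold_to", object=obj, confidence=0.9)

assert triple.subject.text == "Mary" and triple.object.text == "John"
assert TaskType("span") is TaskType.SPAN      # enum members round-trip from their values
assert AnnotationType.CHAR.value == "char"    # char- vs word-level annotation offsets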
--git a/relik/inference/data/splitters/base_sentence_splitter.py b/relik/inference/data/splitters/base_sentence_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..5e69ee453b4f507aa37b8bd2ecfdc8aa4f1428d0 --- /dev/null +++ b/relik/inference/data/splitters/base_sentence_splitter.py @@ -0,0 +1,55 @@ +from typing import List, Union + + +class BaseSentenceSplitter: + """ + A `BaseSentenceSplitter` splits strings into sentences. + """ + + def __call__(self, *args, **kwargs): + """ + Calls :meth:`split_sentences`. + """ + return self.split_sentences(*args, **kwargs) + + def split_sentences( + self, text: str, max_len: int = 0, *args, **kwargs + ) -> List[str]: + """ + Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence. + """ + raise NotImplementedError + + def split_sentences_batch( + self, texts: List[str], *args, **kwargs + ) -> List[List[str]]: + """ + Default implementation is to just iterate over the texts and call `split_sentences`. + """ + return [self.split_sentences(text) for text in texts] + + @staticmethod + def check_is_batched( + texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool + ): + """ + Check if input is batched or a single sample. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to check. + is_split_into_words (:obj:`bool`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise. + """ + return bool( + (not is_split_into_words and isinstance(texts, (list, tuple))) + or ( + is_split_into_words + and isinstance(texts, (list, tuple)) + and texts + and isinstance(texts[0], (list, tuple)) + ) + ) diff --git a/relik/inference/data/splitters/blank_sentence_splitter.py b/relik/inference/data/splitters/blank_sentence_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..452ae8f952b32e49b2f89fe79ecb51613a0981b0 --- /dev/null +++ b/relik/inference/data/splitters/blank_sentence_splitter.py @@ -0,0 +1,29 @@ +from typing import List, Union + + +class BlankSentenceSplitter: + """ + A `BlankSentenceSplitter` splits strings into sentences. + """ + + def __call__(self, *args, **kwargs): + """ + Calls :meth:`split_sentences`. + """ + return self.split_sentences(*args, **kwargs) + + def split_sentences( + self, text: str, max_len: int = 0, *args, **kwargs + ) -> List[str]: + """ + Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence. + """ + return [text] + + def split_sentences_batch( + self, texts: List[str], *args, **kwargs + ) -> List[List[str]]: + """ + Default implementation is to just iterate over the texts and call `split_sentences`. 
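# --- Editor's illustrative sketch (not part of the patch): how the helpers
# defined above behave. check_is_batched treats a plain string as a single
# sample and a list of strings as a batch; BlankSentenceSplitter is the no-op
# splitter that keeps the whole text as one "sentence".
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter

assert not BaseSentenceSplitter.check_is_batched("one text", is_split_into_words=False)
assert BaseSentenceSplitter.check_is_batched(["text one", "text two"], is_split_into_words=False)
# pre-tokenized input counts as batched only when it is a list of token lists
assert BaseSentenceSplitter.check_is_batched([["already", "split"]], is_split_into_words=True)

splitter = BlankSentenceSplitter()
assert splitter("First sentence. Second sentence.") == ["First sentence. Second sentence."]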
+ """ + return [self.split_sentences(text) for text in texts] diff --git a/relik/inference/data/splitters/spacy_sentence_splitter.py b/relik/inference/data/splitters/spacy_sentence_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..01aeefe009784574f221920e9ca792954fce108b --- /dev/null +++ b/relik/inference/data/splitters/spacy_sentence_splitter.py @@ -0,0 +1,153 @@ +from typing import Any, Iterable, List, Optional, Union + +import spacy + +from relik.inference.data.objects import Word +from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter +from relik.inference.data.tokenizers.spacy_tokenizer import load_spacy + +SPACY_LANGUAGE_MAPPER = { + "cs": "xx_sent_ud_sm", + "da": "xx_sent_ud_sm", + "de": "xx_sent_ud_sm", + "fa": "xx_sent_ud_sm", + "fi": "xx_sent_ud_sm", + "fr": "xx_sent_ud_sm", + "el": "el_core_news_sm", + "en": "xx_sent_ud_sm", + "es": "xx_sent_ud_sm", + "ga": "xx_sent_ud_sm", + "hr": "xx_sent_ud_sm", + "id": "xx_sent_ud_sm", + "it": "xx_sent_ud_sm", + "ja": "ja_core_news_sm", + "lv": "xx_sent_ud_sm", + "lt": "xx_sent_ud_sm", + "mr": "xx_sent_ud_sm", + "nb": "xx_sent_ud_sm", + "nl": "xx_sent_ud_sm", + "no": "xx_sent_ud_sm", + "pl": "pl_core_news_sm", + "pt": "xx_sent_ud_sm", + "ro": "xx_sent_ud_sm", + "ru": "xx_sent_ud_sm", + "sk": "xx_sent_ud_sm", + "sr": "xx_sent_ud_sm", + "sv": "xx_sent_ud_sm", + "te": "xx_sent_ud_sm", + "vi": "xx_sent_ud_sm", + "zh": "zh_core_web_sm", +} + + +class SpacySentenceSplitter(BaseSentenceSplitter): + """ + A :obj:`SentenceSplitter` that uses spaCy's built-in sentence boundary detection. + + Args: + language (:obj:`str`, optional, defaults to :obj:`en`): + Language of the text to tokenize. + model_type (:obj:`str`, optional, defaults to :obj:`statistical`): + Three different type of sentence splitter: + - ``dependency``: sentence splitter uses a dependency parse to detect sentence boundaries, + slow, but accurate. + - ``statistical``: + - ``rule_based``: It's fast and has a small memory footprint, since it uses punctuation to detect + sentence boundaries. + """ + + def __init__(self, language: str = "en", model_type: str = "statistical") -> None: + # we need spacy's dependency parser if we're not using rule-based sentence boundary detection. + # self.spacy = get_spacy_model(language, parse=not rule_based, ner=False) + dep = bool(model_type == "dependency") + if language in SPACY_LANGUAGE_MAPPER: + self.spacy = load_spacy(SPACY_LANGUAGE_MAPPER[language], parse=dep) + else: + self.spacy = spacy.blank(language) + # force type to rule_based since there is no pre-trained model + model_type = "rule_based" + if model_type == "dependency": + # dependency type must declared at model init + pass + elif model_type == "statistical": + if not self.spacy.has_pipe("senter"): + self.spacy.enable_pipe("senter") + elif model_type == "rule_based": + # we use `sentencizer`, a built-in spacy module for rule-based sentence boundary detection. + # depending on the spacy version, it could be called 'sentencizer' or 'sbd' + if not self.spacy.has_pipe("sentencizer"): + self.spacy.add_pipe("sentencizer") + else: + raise ValueError( + f"type {model_type} not supported. Choose between `dependency`, `statistical` or `rule_based`" + ) + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + max_length: Optional[int] = None, + is_split_into_words: bool = False, + **kwargs, + ) -> Union[List[str], List[List[str]]]: + """ + Tokenize the input into single words using SpaCy models. 
+ + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + max_len (:obj:`int`, optional, defaults to :obj:`0`): + Maximum length of a single text. If the text is longer than `max_len`, it will be split + into multiple sentences. + + Returns: + :obj:`List[List[str]]`: The input doc split into sentences. + """ + # check if input is batched or a single sample + is_batched = self.check_is_batched(texts, is_split_into_words) + + if is_batched: + sents = self.split_sentences_batch(texts) + else: + sents = self.split_sentences(texts, max_length) + return sents + + @staticmethod + def chunked(iterable, n: int) -> Iterable[List[Any]]: + """ + Chunks a list into n sized chunks. + + Args: + iterable (:obj:`List[Any]`): + List to chunk. + n (:obj:`int`): + Size of the chunks. + + Returns: + :obj:`Iterable[List[Any]]`: The input list chunked into n sized chunks. + """ + return [iterable[i : i + n] for i in range(0, len(iterable), n)] + + def split_sentences( + self, text: str | List[Word], max_length: Optional[int] = None, *args, **kwargs + ) -> List[str]: + """ + Splits a `text` into smaller sentences. + + Args: + text (:obj:`str`): + Text to split. + max_length (:obj:`int`, optional, defaults to :obj:`0`): + Maximum length of a single sentence. If the text is longer than `max_len`, it will be split + into multiple sentences. + + Returns: + :obj:`List[str]`: The input text split into sentences. + """ + sentences = [sent for sent in self.spacy(text).sents] + if max_length is not None and max_length > 0: + sentences = [ + chunk + for sentence in sentences + for chunk in self.chunked(sentence, max_length) + ] + return sentences diff --git a/relik/inference/data/splitters/window_based_splitter.py b/relik/inference/data/splitters/window_based_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..55ac48a6aec14614393f0d70b000d63ab1afb67b --- /dev/null +++ b/relik/inference/data/splitters/window_based_splitter.py @@ -0,0 +1,62 @@ +from typing import List, Union + +from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter + + +class WindowSentenceSplitter(BaseSentenceSplitter): + """ + A :obj:`WindowSentenceSplitter` that splits a text into windows of a given size. + """ + + def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None: + super(WindowSentenceSplitter, self).__init__() + self.window_size = window_size + self.window_stride = window_stride + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs, + ) -> Union[List[str], List[List[str]]]: + """ + Tokenize the input into single words using SpaCy models. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + + Returns: + :obj:`List[List[str]]`: The input doc split into sentences. + """ + return self.split_sentences(texts) + + def split_sentences(self, text: str | List, *args, **kwargs) -> List[List]: + """ + Splits a `text` into sentences. + + Args: + text (:obj:`str`): + Text to split. + + Returns: + :obj:`List[str]`: The input text split into sentences. + """ + + if isinstance(text, str): + text = text.split() + sentences = [] + for i in range(0, len(text), self.window_stride): + # if the last stride is smaller than the window size, then we can + # include more tokens form the previous window. 
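# Worked example (editor's note, not part of the patch): with window_size=5,
# window_stride=3 and a 7-token text [t0 .. t6], the loop below visits i = 0, 3, 6:
#   i=0 -> window t0..t4
#   i=3 -> 3 + 5 > 7, one overflowing token, so i is shifted back to 2 -> window t2..t6
#   i=6 -> four overflowing tokens, which is >= the stride, so the loop stops
# The last window is therefore shifted left to stay full-sized instead of
# producing a short tail window.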
+ if i != 0 and i + self.window_size > len(text): + overflowing_tokens = i + self.window_size - len(text) + if overflowing_tokens >= self.window_stride: + break + i -= overflowing_tokens + involved_token_indices = list( + range(i, min(i + self.window_size, len(text))) + ) + window_tokens = [text[j] for j in involved_token_indices] + sentences.append(window_tokens) + return sentences diff --git a/relik/inference/data/tokenizers/__init__.py b/relik/inference/data/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffab94f389dc1a7feb3b90b69fddf167e8c1df9 --- /dev/null +++ b/relik/inference/data/tokenizers/__init__.py @@ -0,0 +1,87 @@ +SPACY_LANGUAGE_MAPPER = { + "ca": "ca_core_news_sm", + "da": "da_core_news_sm", + "de": "de_core_news_sm", + "el": "el_core_news_sm", + "en": "en_core_web_sm", + "es": "es_core_news_sm", + "fr": "fr_core_news_sm", + "it": "it_core_news_sm", + "ja": "ja_core_news_sm", + "lt": "lt_core_news_sm", + "mk": "mk_core_news_sm", + "nb": "nb_core_news_sm", + "nl": "nl_core_news_sm", + "pl": "pl_core_news_sm", + "pt": "pt_core_news_sm", + "ro": "ro_core_news_sm", + "ru": "ru_core_news_sm", + "xx": "xx_sent_ud_sm", + "zh": "zh_core_web_sm", + "ca_core_news_sm": "ca_core_news_sm", + "ca_core_news_md": "ca_core_news_md", + "ca_core_news_lg": "ca_core_news_lg", + "ca_core_news_trf": "ca_core_news_trf", + "da_core_news_sm": "da_core_news_sm", + "da_core_news_md": "da_core_news_md", + "da_core_news_lg": "da_core_news_lg", + "da_core_news_trf": "da_core_news_trf", + "de_core_news_sm": "de_core_news_sm", + "de_core_news_md": "de_core_news_md", + "de_core_news_lg": "de_core_news_lg", + "de_dep_news_trf": "de_dep_news_trf", + "el_core_news_sm": "el_core_news_sm", + "el_core_news_md": "el_core_news_md", + "el_core_news_lg": "el_core_news_lg", + "en_core_web_sm": "en_core_web_sm", + "en_core_web_md": "en_core_web_md", + "en_core_web_lg": "en_core_web_lg", + "en_core_web_trf": "en_core_web_trf", + "es_core_news_sm": "es_core_news_sm", + "es_core_news_md": "es_core_news_md", + "es_core_news_lg": "es_core_news_lg", + "es_dep_news_trf": "es_dep_news_trf", + "fr_core_news_sm": "fr_core_news_sm", + "fr_core_news_md": "fr_core_news_md", + "fr_core_news_lg": "fr_core_news_lg", + "fr_dep_news_trf": "fr_dep_news_trf", + "it_core_news_sm": "it_core_news_sm", + "it_core_news_md": "it_core_news_md", + "it_core_news_lg": "it_core_news_lg", + "ja_core_news_sm": "ja_core_news_sm", + "ja_core_news_md": "ja_core_news_md", + "ja_core_news_lg": "ja_core_news_lg", + "ja_dep_news_trf": "ja_dep_news_trf", + "lt_core_news_sm": "lt_core_news_sm", + "lt_core_news_md": "lt_core_news_md", + "lt_core_news_lg": "lt_core_news_lg", + "mk_core_news_sm": "mk_core_news_sm", + "mk_core_news_md": "mk_core_news_md", + "mk_core_news_lg": "mk_core_news_lg", + "nb_core_news_sm": "nb_core_news_sm", + "nb_core_news_md": "nb_core_news_md", + "nb_core_news_lg": "nb_core_news_lg", + "nl_core_news_sm": "nl_core_news_sm", + "nl_core_news_md": "nl_core_news_md", + "nl_core_news_lg": "nl_core_news_lg", + "pl_core_news_sm": "pl_core_news_sm", + "pl_core_news_md": "pl_core_news_md", + "pl_core_news_lg": "pl_core_news_lg", + "pt_core_news_sm": "pt_core_news_sm", + "pt_core_news_md": "pt_core_news_md", + "pt_core_news_lg": "pt_core_news_lg", + "ro_core_news_sm": "ro_core_news_sm", + "ro_core_news_md": "ro_core_news_md", + "ro_core_news_lg": "ro_core_news_lg", + "ru_core_news_sm": "ru_core_news_sm", + "ru_core_news_md": "ru_core_news_md", + "ru_core_news_lg": "ru_core_news_lg", + "xx_ent_wiki_sm": 
"xx_ent_wiki_sm", + "xx_sent_ud_sm": "xx_sent_ud_sm", + "zh_core_web_sm": "zh_core_web_sm", + "zh_core_web_md": "zh_core_web_md", + "zh_core_web_lg": "zh_core_web_lg", + "zh_core_web_trf": "zh_core_web_trf", +} + +from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer diff --git a/relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc b/relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..057c3f02df3595686935765206ddabd34779357f Binary files /dev/null and b/relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc b/relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9ca692ca496fff2012bf79b257a7d8879ad8f2f Binary files /dev/null and b/relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc differ diff --git a/relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc b/relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bac9569f361e2a2e73b656ba34ad6aa3da9c8a92 Binary files /dev/null and b/relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc differ diff --git a/relik/inference/data/tokenizers/base_tokenizer.py b/relik/inference/data/tokenizers/base_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1fed161b3eca085656e85d44cb9a64739f3d1e4c --- /dev/null +++ b/relik/inference/data/tokenizers/base_tokenizer.py @@ -0,0 +1,84 @@ +from typing import List, Union + +from relik.inference.data.objects import Word + + +class BaseTokenizer: + """ + A :obj:`Tokenizer` splits strings of text into single words, optionally adds + pos tags and perform lemmatization. + """ + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs + ) -> List[List[Word]]: + """ + Tokenize the input into single words. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`List[List[Word]]`: The input text tokenized in single words. + """ + raise NotImplementedError + + def tokenize(self, text: str) -> List[Word]: + """ + Implements splitting words into tokens. + + Args: + text (:obj:`str`): + Text to tokenize. + + Returns: + :obj:`List[Word]`: The input text tokenized in single words. + + """ + raise NotImplementedError + + def tokenize_batch(self, texts: List[str]) -> List[List[Word]]: + """ + Implements batch splitting words into tokens. + + Args: + texts (:obj:`List[str]`): + Batch of text to tokenize. + + Returns: + :obj:`List[List[Word]]`: The input batch tokenized in single words. + + """ + return [self.tokenize(text) for text in texts] + + @staticmethod + def check_is_batched( + texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool + ): + """ + Check if input is batched or a single sample. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to check. 
+ is_split_into_words (:obj:`bool`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise. + """ + return bool( + (not is_split_into_words and isinstance(texts, (list, tuple))) + or ( + is_split_into_words + and isinstance(texts, (list, tuple)) + and texts + and isinstance(texts[0], (list, tuple)) + ) + ) diff --git a/relik/inference/data/tokenizers/spacy_tokenizer.py b/relik/inference/data/tokenizers/spacy_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed49032febe4ef8f8953f7016430d36488ad0d1 --- /dev/null +++ b/relik/inference/data/tokenizers/spacy_tokenizer.py @@ -0,0 +1,194 @@ +import logging +from copy import deepcopy +from typing import Dict, List, Tuple, Union, Any + +import spacy + +# from ipa.common.utils import load_spacy +from spacy.cli.download import download as spacy_download +from spacy.tokens import Doc + +from relik.common.log import get_logger +from relik.inference.data.objects import Word +from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER +from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer + +logger = get_logger(level=logging.DEBUG) + +# Spacy and Stanza stuff + +LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {} + + +def load_spacy( + language: str, + pos_tags: bool = False, + lemma: bool = False, + parse: bool = False, + split_on_spaces: bool = False, +) -> spacy.Language: + """ + Download and load spacy model. + + Args: + language (:obj:`str`, defaults to :obj:`en`): + Language of the text to tokenize. + pos_tags (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs POS tagging with spacy model. + lemma (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs lemmatization with spacy model. + parse (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs dependency parsing with spacy model. + split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, will split by spaces without performing tokenization. + + Returns: + :obj:`spacy.Language`: The spacy model loaded. + """ + exclude = ["vectors", "textcat", "ner"] + if not pos_tags: + exclude.append("tagger") + if not lemma: + exclude.append("lemmatizer") + if not parse: + exclude.append("parser") + + # check if the model is already loaded + # if so, there is no need to reload it + spacy_params = (language, pos_tags, lemma, parse, split_on_spaces) + if spacy_params not in LOADED_SPACY_MODELS: + try: + spacy_tagger = spacy.load(language, exclude=exclude) + except OSError: + logger.warning( + "Spacy model '%s' not found. Downloading and installing.", language + ) + spacy_download(language) + spacy_tagger = spacy.load(language, exclude=exclude) + + # if everything is disabled, return only the tokenizer + # for faster tokenization + # TODO: is it really faster? + # TODO: check split_on_spaces behaviour if we don't do this if + if len(exclude) >= 6 and split_on_spaces: + spacy_tagger = spacy_tagger.tokenizer + LOADED_SPACY_MODELS[spacy_params] = spacy_tagger + + return LOADED_SPACY_MODELS[spacy_params] + + +class SpacyTokenizer(BaseTokenizer): + """ + A :obj:`Tokenizer` that uses SpaCy to tokenizer and preprocess the text. It returns :obj:`Word` objects. + + Args: + language (:obj:`str`, optional, defaults to :obj:`en`): + Language of the text to tokenize. 
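# --- Editor's illustrative sketch (not part of the patch): load_spacy memoises
# pipelines in LOADED_SPACY_MODELS, so requesting the same configuration twice
# returns the very same spaCy object instead of reloading it from disk. The
# first call may download the model.
from relik.inference.data.tokenizers.spacy_tokenizer import load_spacy

nlp_a = load_spacy("en_core_web_sm", pos_tags=True, lemma=True)
nlp_b = load_spacy("en_core_web_sm", pos_tags=True, lemma=True)
assert nlp_a is nlp_b  # served from the cache, keyed on (language, flags)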
+ return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs POS tagging with spacy model. + return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs lemmatization with spacy model. + return_deps (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs dependency parsing with spacy model. + use_gpu (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, will load the Stanza model on GPU. + """ + + def __init__( + self, + language: str = "en", + return_pos_tags: bool = False, + return_lemmas: bool = False, + return_deps: bool = False, + use_gpu: bool = False, + ): + super().__init__() + if language not in SPACY_LANGUAGE_MAPPER: + raise ValueError( + f"`{language}` language not supported. The supported " + f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}." + ) + if use_gpu: + # load the model on GPU + # if the GPU is not available or not correctly configured, + # it will rise an error + spacy.require_gpu() + self.spacy = load_spacy( + SPACY_LANGUAGE_MAPPER[language], + return_pos_tags, + return_lemmas, + return_deps, + ) + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs, + ) -> Union[List[Word], List[List[Word]]]: + """ + Tokenize the input into single words using SpaCy models. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`List[List[Word]]`: The input text tokenized in single words. + + Example:: + + >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer + + >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True) + >>> spacy_tokenizer("Mary sold the car to John.") + + """ + # check if input is batched or a single sample + is_batched = self.check_is_batched(texts, is_split_into_words) + + if is_batched: + tokenized = self.tokenize_batch(texts, is_split_into_words) + else: + tokenized = self.tokenize(texts, is_split_into_words) + + return tokenized + + def tokenize(self, text: Union[str, List[str]], is_split_into_words: bool) -> Doc: + if is_split_into_words: + if isinstance(text, str): + text = text.split(" ") + elif isinstance(text, list): + text = text + else: + raise ValueError( + f"text must be either `str` or `list`, found: `{type(text)}`" + ) + spaces = [True] * len(text) + return self.spacy(Doc(self.spacy.vocab, words=text, spaces=spaces)) + return self.spacy(text) + + def tokenize_batch( + self, texts: Union[List[str], List[List[str]]], is_split_into_words: bool + ) -> list[Any] | list[Doc]: + try: + if is_split_into_words: + if isinstance(texts[0], str): + texts = [text.split(" ") for text in texts] + elif isinstance(texts[0], list): + texts = texts + else: + raise ValueError( + f"text must be either `str` or `list`, found: `{type(texts[0])}`" + ) + spaces = [[True] * len(text) for text in texts] + texts = [ + Doc(self.spacy.vocab, words=text, spaces=space) + for text, space in zip(texts, spaces) + ] + return list(self.spacy.pipe(texts)) + except AttributeError: + # a WhitespaceSpacyTokenizer has no `pipe()` method, we use simple for loop + return [self.spacy(tokens) for tokens in texts] diff --git a/relik/inference/data/window/__init__.py b/relik/inference/data/window/__init__.py 
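# --- Editor's illustrative sketch (not part of the patch): tokenizing with the
# constructor arguments the class actually accepts (return_pos_tags /
# return_lemmas; the docstring example above uses pos_tags / lemma, which do
# not match the __init__ signature). As implemented, a single string yields a
# spaCy Doc, so the usual spaCy token attributes are available.
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
doc = tokenizer("Mary sold the car to John.")
for token in doc:
    print(token.text, token.lemma_, token.pos_)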
new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/data/window/__pycache__/__init__.cpython-310.pyc b/relik/inference/data/window/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4f8eba49d42899dab267e80db28e356bae4d6c Binary files /dev/null and b/relik/inference/data/window/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/inference/data/window/__pycache__/manager.cpython-310.pyc b/relik/inference/data/window/__pycache__/manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0f7d792c82e53378e86b44b838addedb303d49 Binary files /dev/null and b/relik/inference/data/window/__pycache__/manager.cpython-310.pyc differ diff --git a/relik/inference/data/window/manager.py b/relik/inference/data/window/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..54374142b831a47d7aa5f93809ccc6dee788e164 --- /dev/null +++ b/relik/inference/data/window/manager.py @@ -0,0 +1,431 @@ +import collections +import itertools +from typing import Dict, List, Optional, Set, Tuple + +from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter +from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter +from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer +from relik.reader.data.relik_reader_sample import RelikReaderSample + + +class WindowManager: + def __init__( + self, tokenizer: BaseTokenizer, splitter: BaseSentenceSplitter | None = None + ) -> None: + self.tokenizer = tokenizer + self.splitter = splitter or BlankSentenceSplitter() + + def create_windows( + self, + documents: str | List[str], + window_size: int | None = None, + stride: int | None = None, + max_length: int | None = None, + doc_id: str | int | None = None, + doc_topic: str | None = None, + is_split_into_words: bool = False, + mentions: List[List[List[int]]] = None, + ) -> Tuple[List[RelikReaderSample], List[RelikReaderSample]]: + """ + Create windows from a list of documents. + + Args: + documents (:obj:`str` or :obj:`List[str]`): + The document(s) to split in windows. + window_size (:obj:`int`): + The size of the window. + stride (:obj:`int`): + The stride between two windows. + max_length (:obj:`int`, `optional`): + The maximum length of a window. + doc_id (:obj:`str` or :obj:`int`, `optional`): + The id of the document(s). + doc_topic (:obj:`str`, `optional`): + The topic of the document(s). + is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether the input is already pre-tokenized (e.g., split into words). If :obj:`False`, the + input will first be tokenized using the tokenizer, then the tokens will be split into words. + mentions (:obj:`List[List[List[int]]]`, `optional`): + The mentions of the document(s). + + Returns: + :obj:`List[RelikReaderSample]`: The windows created from the documents. 
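# --- Editor's illustrative sketch (not part of the patch): windowing a single
# document with the manager defined here. The window/stride sizes are arbitrary
# example values and the input text is a placeholder.
from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.window.manager import WindowManager

manager = WindowManager(
    tokenizer=SpacyTokenizer(language="en"),
    splitter=WindowSentenceSplitter(window_size=32, window_stride=16),
)
# Without `mentions`, the second list (windows left without matching spans) stays empty.
windows, blank_windows = manager.create_windows(
    "Some long document to annotate.", window_size=32, stride=16
)
for sample in windows:
    print(sample.doc_id, sample.window_id, sample.offset, len(sample.tokens))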
+ """ + # normalize input + if isinstance(documents, str) or is_split_into_words: + documents = [documents] + + # batch tokenize + documents_tokens = self.tokenizer( + documents, is_split_into_words=is_split_into_words + ) + + # set splitter params + if hasattr(self.splitter, "window_size"): + self.splitter.window_size = window_size or self.splitter.window_size + if hasattr(self.splitter, "window_stride"): + self.splitter.window_stride = stride or self.splitter.window_stride + + windowed_documents, windowed_blank_documents = [], [] + + if mentions is not None: + assert len(documents) == len( + mentions + ), f"documents and mentions should have the same length, got {len(documents)} and {len(mentions)}" + doc_iter = zip(documents, documents_tokens, mentions) + else: + doc_iter = zip(documents, documents_tokens, itertools.repeat([])) + + for infered_doc_id, (document, document_tokens, document_mentions) in enumerate( + doc_iter + ): + if doc_topic is None: + doc_topic = document_tokens[0] if len(document_tokens) > 0 else "" + + if doc_id is None: + doc_id = infered_doc_id + + splitted_document = self.splitter(document_tokens, max_length=max_length) + + document_windows = [] + for window_id, window in enumerate(splitted_document): + window_text_start = window[0].idx + window_text_end = window[-1].idx + len(window[-1].text) + if isinstance(document, str): + text = document[window_text_start:window_text_end] + else: + # window_text_start = window[0].idx + # window_text_end = window[-1].i + text = " ".join([w.text for w in window]) + sample = RelikReaderSample( + doc_id=doc_id, + window_id=window_id, + text=text, + tokens=[w.text for w in window], + words=[w.text for w in window], + doc_topic=doc_topic, + offset=window_text_start, + spans=[ + [m[0], m[1]] for m in document_mentions + if window_text_end > m[0] >= window_text_start and window_text_end >= m[1] >= window_text_start + ], + token2char_start={str(i): w.idx for i, w in enumerate(window)}, + token2char_end={ + str(i): w.idx + len(w.text) for i, w in enumerate(window) + }, + char2token_start={ + str(w.idx): w.i for i, w in enumerate(window) + }, + char2token_end={ + str(w.idx + len(w.text)): w.i for i, w in enumerate(window) + }, + ) + if mentions is not None and len(sample.spans) == 0: + windowed_blank_documents.append(sample) + else: + document_windows.append(sample) + + windowed_documents.extend(document_windows) + if mentions is not None: + return windowed_documents, windowed_blank_documents + else: + return windowed_documents, windowed_blank_documents + + def merge_windows( + self, windows: List[RelikReaderSample] + ) -> List[RelikReaderSample]: + windows_by_doc_id = collections.defaultdict(list) + for window in windows: + windows_by_doc_id[window.doc_id].append(window) + + merged_window_by_doc = { + doc_id: self._merge_doc_windows(doc_windows) + for doc_id, doc_windows in windows_by_doc_id.items() + } + + return list(merged_window_by_doc.values()) + + def _merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample: + if len(windows) == 1: + return windows[0] + + if len(windows) > 0 and getattr(windows[0], "offset", None) is not None: + windows = sorted(windows, key=(lambda x: x.offset)) + + window_accumulator = windows[0] + + for next_window in windows[1:]: + window_accumulator = self._merge_window_pair( + window_accumulator, next_window + ) + + return window_accumulator + + @staticmethod + def _merge_tokens( + window1: RelikReaderSample, window2: RelikReaderSample + ) -> Tuple[list, dict, dict]: + w1_tokens = 
window1.tokens[1:-1] + w2_tokens = window2.tokens[1:-1] + + # find intersection if any + tokens_intersection = 0 + for k in reversed(range(1, len(w1_tokens))): + if w1_tokens[-k:] == w2_tokens[:k]: + tokens_intersection = k + break + + final_tokens = ( + [window1.tokens[0]] # CLS + + w1_tokens + + w2_tokens[tokens_intersection:] + + [window1.tokens[-1]] # SEP + ) + + w2_starting_offset = len(w1_tokens) - tokens_intersection + + def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict: + final_t2c = dict() + final_t2c.update(t2c1) + for t, c in t2c2.items(): + t = int(t) + if t < tokens_intersection: + continue + final_t2c[str(t + w2_starting_offset)] = c + return final_t2c + + return ( + final_tokens, + merge_char_mapping(window1.token2char_start, window2.token2char_start), + merge_char_mapping(window1.token2char_end, window2.token2char_end), + ) + + @staticmethod + def _merge_words( + window1: RelikReaderSample, window2: RelikReaderSample + ) -> Tuple[list, dict, dict]: + w1_words = window1.words + w2_words = window2.words + + # find intersection if any + words_intersection = 0 + for k in reversed(range(1, len(w1_words))): + if w1_words[-k:] == w2_words[:k]: + words_intersection = k + break + + final_words = w1_words + w2_words[words_intersection:] + + w2_starting_offset = len(w1_words) - words_intersection + + def merge_word_mapping(t2c1: dict, t2c2: dict) -> dict: + final_t2c = dict() + if t2c1 is None: + t2c1 = dict() + if t2c2 is None: + t2c2 = dict() + final_t2c.update(t2c1) + for t, c in t2c2.items(): + t = int(t) + if t < words_intersection: + continue + final_t2c[str(t + w2_starting_offset)] = c + return final_t2c + + return ( + final_words, + merge_word_mapping(window1.token2word_start, window2.token2word_start), + merge_word_mapping(window1.token2word_end, window2.token2word_end), + ) + + @staticmethod + def _merge_span_annotation( + span_annotation1: List[list], span_annotation2: List[list] + ) -> List[list]: + uniq_store = set() + final_span_annotation_store = [] + for span_annotation in itertools.chain(span_annotation1, span_annotation2): + span_annotation_id = tuple(span_annotation) + if span_annotation_id not in uniq_store: + uniq_store.add(span_annotation_id) + final_span_annotation_store.append(span_annotation) + return sorted(final_span_annotation_store, key=lambda x: x[0]) + + @staticmethod + def _merge_predictions( + window1: RelikReaderSample, window2: RelikReaderSample + ) -> Tuple[Set[Tuple[int, int, str]], dict]: + # a RelikReaderSample should have a filed called `predicted_spans` + # that stores the span-level predictions, or a filed called + # `predicted_triples` that stores the triple-level predictions + + # span predictions + merged_span_predictions: Set = set() + merged_span_probabilities = dict() + # triple predictions + merged_triplet_predictions: Set = set() + merged_triplet_probs: Dict = dict() + + if ( + getattr(window1, "predicted_spans", None) is not None + and getattr(window2, "predicted_spans", None) is not None + ): + merged_span_predictions = set(window1.predicted_spans).union( + set(window2.predicted_spans) + ) + merged_span_predictions = sorted(merged_span_predictions) + # probabilities + for span_prediction, predicted_probs in itertools.chain( + window1.probs_window_labels_chars.items() + if window1.probs_window_labels_chars is not None + else [], + window2.probs_window_labels_chars.items() + if window2.probs_window_labels_chars is not None + else [], + ): + if span_prediction not in merged_span_probabilities: + 
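# Editor's note: when two overlapping windows predict the same span, the
# probabilities seen first in the chain above (i.e. those of the accumulated
# left-hand window) are kept, and duplicates from the second window are
# skipped by this check.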
merged_span_probabilities[span_prediction] = predicted_probs + + if ( + getattr(window1, "predicted_triples", None) is not None + and getattr(window2, "predicted_triples", None) is not None + ): + # try to merge the triples predictions + # add offset to the second window + window1_triplets = [ + ( + merged_span_predictions.index(window1.predicted_spans[t[0]]), + t[1], + merged_span_predictions.index(window1.predicted_spans[t[2]]), + t[3] + ) + for t in window1.predicted_triples + ] + window2_triplets = [ + ( + merged_span_predictions.index(window2.predicted_spans[t[0]]), + t[1], + merged_span_predictions.index(window2.predicted_spans[t[2]]), + t[3] + ) + for t in window2.predicted_triples + ] + merged_triplet_predictions = set(window1_triplets).union( + set(window2_triplets) + ) + merged_triplet_predictions = sorted(merged_triplet_predictions) + # for now no triplet probs, we don't need them for the moment + + return ( + merged_span_predictions, + merged_span_probabilities, + merged_triplet_predictions, + merged_triplet_probs, + ) + + @staticmethod + def _merge_candidates(window1: RelikReaderSample, window2: RelikReaderSample): + candidates = [] + windows_candidates = [] + + # TODO: retro-compatibility + if getattr(window1, "candidates", None) is not None: + candidates = window1.candidates + if getattr(window2, "candidates", None) is not None: + candidates += window2.candidates + + # TODO: retro-compatibility + if getattr(window1, "windows_candidates", None) is not None: + windows_candidates = window1.windows_candidates + if getattr(window2, "windows_candidates", None) is not None: + windows_candidates += window2.windows_candidates + + # TODO: add programmatically + span_candidates = [] + if getattr(window1, "span_candidates", None) is not None: + span_candidates = window1.span_candidates + if getattr(window2, "span_candidates", None) is not None: + span_candidates += window2.span_candidates + + triplet_candidates = [] + if getattr(window1, "triplet_candidates", None) is not None: + triplet_candidates = window1.triplet_candidates + if getattr(window2, "triplet_candidates", None) is not None: + triplet_candidates += window2.triplet_candidates + + # make them unique + candidates = list(set(candidates)) + windows_candidates = list(set(windows_candidates)) + + span_candidates = list(set(span_candidates)) + triplet_candidates = list(set(triplet_candidates)) + + return candidates, windows_candidates, span_candidates, triplet_candidates + + def _merge_window_pair( + self, + window1: RelikReaderSample, + window2: RelikReaderSample, + ) -> RelikReaderSample: + merging_output = dict() + + if getattr(window1, "doc_id", None) is not None: + assert window1.doc_id == window2.doc_id + + if getattr(window1, "offset", None) is not None: + assert ( + window1.offset < window2.offset + ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})" + + merging_output["doc_id"] = window1.doc_id + merging_output["offset"] = window2.offset + + m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens( + window1, window2 + ) + + m_words, m_token2word_start, m_token2word_end = self._merge_words( + window1, window2 + ) + + ( + m_candidates, + m_windows_candidates, + m_span_candidates, + m_triplet_candidates, + ) = self._merge_candidates(window1, window2) + + window_labels = None + if getattr(window1, "window_labels", None) is not None: + window_labels = self._merge_span_annotation( + window1.window_labels, window2.window_labels + ) + + ( + predicted_spans, + 
predicted_spans_probs, + predicted_triples, + predicted_triples_probs, + ) = self._merge_predictions(window1, window2) + + merging_output.update( + dict( + tokens=m_tokens, + words=m_words, + token2char_start=m_token2char_start, + token2char_end=m_token2char_end, + token2word_start=m_token2word_start, + token2word_end=m_token2word_end, + window_labels=window_labels, + candidates=m_candidates, + span_candidates=m_span_candidates, + triplet_candidates=m_triplet_candidates, + windows_candidates=m_windows_candidates, + predicted_spans=predicted_spans, + predicted_spans_probs=predicted_spans_probs, + predicted_triples=predicted_triples, + predicted_triples_probs=predicted_triples_probs, + ) + ) + + return RelikReaderSample(**merging_output) diff --git a/relik/inference/gerbil.py b/relik/inference/gerbil.py new file mode 100644 index 0000000000000000000000000000000000000000..86b5d2fdb1d7e0a31a9c746b0e9f1b9a945675c8 --- /dev/null +++ b/relik/inference/gerbil.py @@ -0,0 +1,269 @@ +import argparse +import json +import logging +import os +from pathlib import Path +import re +import sys +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Iterator, List, Optional, Tuple +from urllib import parse + +from relik.inference.annotator import Relik +from relik.inference.data.objects import RelikOutput + +# sys.path += ['../'] +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) + + +logger = logging.getLogger(__name__) + + +class GerbilAlbyManager: + def __init__( + self, + annotator: Optional[Relik] = None, + response_logger_dir: Optional[str] = None, + ) -> None: + self.annotator = annotator + self.response_logger_dir = response_logger_dir + self.predictions_counter = 0 + self.labels_mapping = None + + def annotate(self, document: str): + relik_output: RelikOutput = self.annotator( + document, retriever_batch_size=2, reader_batch_size=1 + ) + annotations = [(ss, se, l) for ss, se, l, _ in relik_output.spans] + if self.labels_mapping is not None: + return [ + (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations + ] + return annotations + + def set_mapping_file(self, mapping_file_path: str): + with open(mapping_file_path) as f: + labels_mapping = json.load(f) + self.labels_mapping = {v: k for k, v in labels_mapping.items()} + + def write_response_bundle( + self, + document: str, + new_document: str, + annotations: list, + mapped_annotations: list, + ) -> None: + if self.response_logger_dir is None: + return + + if not os.path.isdir(self.response_logger_dir): + os.mkdir(self.response_logger_dir) + + with open( + f"{self.response_logger_dir}/{self.predictions_counter}.json", "w" + ) as f: + out_json_obj = dict( + document=document, + new_document=new_document, + annotations=annotations, + mapped_annotations=mapped_annotations, + ) + + out_json_obj["span_annotations"] = [ + (ss, se, document[ss:se], label) for (ss, se, label) in annotations + ] + + out_json_obj["span_mapped_annotations"] = [ + (ss, se, new_document[ss:se], label) + for (ss, se, label) in mapped_annotations + ] + + json.dump(out_json_obj, f, indent=2) + + self.predictions_counter += 1 + + +manager = GerbilAlbyManager() + + +def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]: + pattern_subs = { + "-LPR- ": " (", + "-RPR-": ")", + "\n\n": "\n", + "-LRB-": "(", + "-RRB-": ")", + '","': ",", + } + + document_acc = document + curr_offset = 0 + char2offset = [] + + matchings = re.finditer("({})".format("|".join(pattern_subs)), document) + for 
span_matching in sorted(matchings, key=lambda x: x.span()[0]): + span_start, span_end = span_matching.span() + span_start -= curr_offset + span_end -= curr_offset + + span_text = document_acc[span_start:span_end] + span_sub = pattern_subs[span_text] + document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:] + + offset = len(span_text) - len(span_sub) + curr_offset += offset + + char2offset.append((span_start + len(span_sub), curr_offset)) + + return document_acc, char2offset + + +def map_back_annotations( + annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]] +) -> Iterator[Tuple[int, int, str]]: + def map_char(char_idx: int) -> int: + current_offset = 0 + for offset_idx, offset_value in char_mapping: + if char_idx >= offset_idx: + current_offset = offset_value + else: + break + return char_idx + current_offset + + for ss, se, label in annotations: + yield map_char(ss), map_char(se), label + + +def annotate(document: str) -> List[Tuple[int, int, str]]: + new_document, mapping = preprocess_document(document) + logger.info("Mapping: " + str(mapping)) + logger.info("Document: " + str(document)) + annotations = [ + (cs, ce, label.replace(" ", "_")) + for cs, ce, label in manager.annotate(new_document) + ] + logger.info("New document: " + str(new_document)) + mapped_annotations = ( + list(map_back_annotations(annotations, mapping)) + if len(mapping) > 0 + else annotations + ) + + logger.info( + "Annotations: " + + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations]) + ) + + manager.write_response_bundle( + document, new_document, mapped_annotations, annotations + ) + + if not all( + [ + new_document[ss:se] == document[mss:mse] + for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) + ] + ): + diff_mappings = [ + (new_document[ss:se], document[mss:mse]) + for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) + ] + return None + assert all( + [ + document[mss:mse] == new_document[ss:se] + for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) + ] + ), (mapped_annotations, annotations) + + return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations] + + +class GetHandler(BaseHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length) + self.send_response(200) + self.end_headers() + doc_text = read_json(post_data) + # try: + response = annotate(doc_text) + + self.wfile.write(bytes(json.dumps(response), "utf-8")) + return + + +def read_json(post_data): + data = json.loads(post_data.decode("utf-8")) + # logger.info("received data:", data) + text = data["text"] + # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]] + return text + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--relik-model-name", required=True) + parser.add_argument("--responses-log-dir") + parser.add_argument("--log-file", default="experiments/logging.txt") + parser.add_argument("--mapping-file") + return parser.parse_args() + + +def main(): + args = parse_args() + + responses_log_dir = Path(args.responses_log_dir) + responses_log_dir.mkdir(parents=True, exist_ok=True) + + # init manager + manager.response_logger_dir = args.responses_log_dir + manager.annotator = Relik.from_pretrained( + args.relik_model_name, + device="cuda", + # document_index_device="cpu", + # document_index_precision="fp32", + # reader_device="cpu", + precision="fp16", # , 
reader_device="cpu", reader_precision="fp32" + dataset_kwargs={"use_nme": True} + ) + + # print("Debugging, not using you relik model but an hardcoded one.") + # manager.annotator = Relik( + # question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder", + # document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder", + # reader="relik/reader/models/relik-reader-deberta-base-new-data", + # window_size=32, + # window_stride=16, + # candidates_preprocessing_fn=(lambda x: x.split("")[0].strip()), + # ) + + if args.mapping_file is not None: + manager.set_mapping_file(args.mapping_file) + + # port = 6654 + port = 5555 + server = HTTPServer(("localhost", port), GetHandler) + logger.info(f"Starting server at http://localhost:{port}") + + # Create a file handler and set its level + file_handler = logging.FileHandler(args.log_file) + file_handler.setLevel(logging.DEBUG) + + # Create a log formatter and set it on the handler + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + # Add the file handler to the logger + logger.addHandler(file_handler) + + try: + server.serve_forever() + except KeyboardInterrupt: + exit(0) + + +if __name__ == "__main__": + main() diff --git a/relik/inference/serve/__init__.py b/relik/inference/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/serve/backend/__init__.py b/relik/inference/serve/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/serve/backend/fastapi.py b/relik/inference/serve/backend/fastapi.py new file mode 100644 index 0000000000000000000000000000000000000000..13409090eec1d74f994c5084d59109b3a6885494 --- /dev/null +++ b/relik/inference/serve/backend/fastapi.py @@ -0,0 +1,122 @@ +import logging +import os +from pathlib import Path +from typing import List, Union +import psutil + +import torch + +from relik.common.utils import is_package_available +from relik.inference.annotator import Relik + +if not is_package_available("fastapi"): + raise ImportError( + "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." 
+ ) +from fastapi import FastAPI, HTTPException, APIRouter + + +from relik.common.log import get_logger +from relik.inference.serve.backend.utils import ( + RayParameterManager, + ServerParameterManager, +) + +logger = get_logger(__name__, level=logging.INFO) + +VERSION = {} # type: ignore +with open( + Path(__file__).parent.parent.parent.parent / "version.py", "r" +) as version_file: + exec(version_file.read(), VERSION) + +# Env variables for server +SERVER_MANAGER = ServerParameterManager() +RAY_MANAGER = RayParameterManager() + + +class RelikServer: + def __init__( + self, + relik_pretrained: str | None = None, + device: str = "cpu", + retriever_device: str | None = None, + document_index_device: str | None = None, + reader_device: str | None = None, + precision: str | int | torch.dtype = 32, + retriever_precision: str | int | torch.dtype | None = None, + document_index_precision: str | int | torch.dtype | None = None, + reader_precision: str | int | torch.dtype | None = None, + annotation_type: str = "char", + **kwargs, + ): + num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False)) + torch.set_num_threads(num_threads) + logger.info(f"Torch is running on {num_threads} threads.") + # parameters + logger.info(f"RELIK_PRETRAINED: {relik_pretrained}") + self.relik_pretrained = relik_pretrained + logger.info(f"DEVICE: {device}") + self.device = device + if retriever_device is not None: + logger.info(f"RETRIEVER_DEVICE: {retriever_device}") + self.retriever_device = retriever_device or device + if document_index_device is not None: + logger.info(f"INDEX_DEVICE: {document_index_device}") + self.document_index_device = document_index_device or retriever_device + if reader_device is not None: + logger.info(f"READER_DEVICE: {reader_device}") + self.reader_device = reader_device + logger.info(f"PRECISION: {precision}") + self.precision = precision + if retriever_precision is not None: + logger.info(f"RETRIEVER_PRECISION: {retriever_precision}") + self.retriever_precision = retriever_precision or precision + if document_index_precision is not None: + logger.info(f"INDEX_PRECISION: {document_index_precision}") + self.document_index_precision = document_index_precision or precision + if reader_precision is not None: + logger.info(f"READER_PRECISION: {reader_precision}") + self.reader_precision = reader_precision or precision + logger.info(f"ANNOTATION_TYPE: {annotation_type}") + self.annotation_type = annotation_type + + self.relik = Relik.from_pretrained( + self.relik_pretrained, + device=self.device, + retriever_device=self.retriever_device, + document_index_device=self.document_index_device, + reader_device=self.reader_device, + precision=self.precision, + retriever_precision=self.retriever_precision, + document_index_precision=self.document_index_precision, + reader_precision=self.reader_precision, + ) + + self.router = APIRouter() + self.router.add_api_route("/api/relik", self.relik_endpoint, methods=["POST"]) + + logger.info("RelikServer initialized.") + + # @serve.batch() + async def __call__(self, text: List[str]) -> List: + return self.relik(text, annotation_type=self.annotation_type) + + # @app.post("/api/relik") + async def relik_endpoint(self, text: Union[str, List[str]]): + try: + # get predictions for the retriever + return await self(text) + except Exception as e: + # log the entire stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=f"Server Error: {e}") + + +app = FastAPI( + title="ReLiK", + version=VERSION["VERSION"], + 
description="ReLiK REST API", +) +server = RelikServer(**vars(SERVER_MANAGER)) +app.include_router(server.router) diff --git a/relik/inference/serve/backend/ray.py b/relik/inference/serve/backend/ray.py new file mode 100644 index 0000000000000000000000000000000000000000..24758c2a586ecb3f2871412cf8753c2ffb672a5d --- /dev/null +++ b/relik/inference/serve/backend/ray.py @@ -0,0 +1,165 @@ +import logging +import os +from pathlib import Path +from typing import List, Union +import psutil + +import torch + +from relik.common.utils import is_package_available +from relik.inference.annotator import Relik + +if not is_package_available("fastapi"): + raise ImportError( + "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." + ) +from fastapi import FastAPI, HTTPException + +if not is_package_available("ray"): + raise ImportError( + "Ray is not installed. Please install Ray with `pip install relik[serve]`." + ) +from ray import serve + +from relik.common.log import get_logger +from relik.inference.serve.backend.utils import ( + RayParameterManager, + ServerParameterManager, +) + +logger = get_logger(__name__, level=logging.INFO) + +VERSION = {} # type: ignore +with open( + Path(__file__).parent.parent.parent.parent / "version.py", "r" +) as version_file: + exec(version_file.read(), VERSION) + +# Env variables for server +SERVER_MANAGER = ServerParameterManager() +RAY_MANAGER = RayParameterManager() + +app = FastAPI( + title="ReLiK", + version=VERSION["VERSION"], + description="ReLiK REST API", +) + + +@serve.deployment( + ray_actor_options={ + "num_gpus": RAY_MANAGER.num_gpus + if ( + SERVER_MANAGER.device == "cuda" + or SERVER_MANAGER.retriever_device == "cuda" + or SERVER_MANAGER.reader_device == "cuda" + ) + else 0 + }, + autoscaling_config={ + "min_replicas": RAY_MANAGER.min_replicas, + "max_replicas": RAY_MANAGER.max_replicas, + }, +) +@serve.ingress(app) +class RelikServer: + def __init__( + self, + relik_pretrained: str | None = None, + device: str = "cpu", + retriever_device: str | None = None, + document_index_device: str | None = None, + reader_device: str | None = None, + precision: str | int | torch.dtype = 32, + retriever_precision: str | int | torch.dtype | None = None, + document_index_precision: str | int | torch.dtype | None = None, + reader_precision: str | int | torch.dtype | None = None, + annotation_type: str = "char", + retriever_batch_size: int = 32, + reader_batch_size: int = 32, + relik_config_override: dict | None = None, + **kwargs, + ): + num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False)) + torch.set_num_threads(num_threads) + logger.info(f"Torch is running on {num_threads} threads.") + + # parameters + logger.info(f"RELIK_PRETRAINED: {relik_pretrained}") + self.relik_pretrained = relik_pretrained + + if relik_config_override is None: + relik_config_override = {} + logger.info(f"RELIK_CONFIG_OVERRIDE: {relik_config_override}") + self.relik_config_override = relik_config_override + + logger.info(f"DEVICE: {device}") + self.device = device + + if retriever_device is not None: + logger.info(f"RETRIEVER_DEVICE: {retriever_device}") + self.retriever_device = retriever_device or device + + if document_index_device is not None: + logger.info(f"INDEX_DEVICE: {document_index_device}") + self.document_index_device = document_index_device or retriever_device + + if reader_device is not None: + logger.info(f"READER_DEVICE: {reader_device}") + self.reader_device = reader_device + + logger.info(f"PRECISION: {precision}") + 
self.precision = precision + + if retriever_precision is not None: + logger.info(f"RETRIEVER_PRECISION: {retriever_precision}") + self.retriever_precision = retriever_precision or precision + + if document_index_precision is not None: + logger.info(f"INDEX_PRECISION: {document_index_precision}") + self.document_index_precision = document_index_precision or precision + + if reader_precision is not None: + logger.info(f"READER_PRECISION: {reader_precision}") + self.reader_precision = reader_precision or precision + + logger.info(f"ANNOTATION_TYPE: {annotation_type}") + self.annotation_type = annotation_type + + self.relik = Relik.from_pretrained( + self.relik_pretrained, + device=self.device, + retriever_device=self.retriever_device, + document_index_device=self.document_index_device, + reader_device=self.reader_device, + precision=self.precision, + retriever_precision=self.retriever_precision, + document_index_precision=self.document_index_precision, + reader_precision=self.reader_precision, + **self.relik_config_override, + ) + + self.retriever_batch_size = retriever_batch_size + self.reader_batch_size = reader_batch_size + + # @serve.batch() + async def handle_batch(self, text: List[str]) -> List: + return self.relik( + text, + annotation_type=self.annotation_type, + retriever_batch_size=self.retriever_batch_size, + reader_batch_size=self.reader_batch_size, + ) + + @app.post("/api/relik") + async def relik_endpoint(self, text: Union[str, List[str]]): + try: + # get predictions for the retriever + return await self.handle_batch(text) + except Exception as e: + # log the entire stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=f"Server Error: {e}") + + +server = RelikServer.bind(**vars(SERVER_MANAGER)) diff --git a/relik/inference/serve/backend/utils.py b/relik/inference/serve/backend/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f81da69f4938935fcdcc450d71555be89a3fde44 --- /dev/null +++ b/relik/inference/serve/backend/utils.py @@ -0,0 +1,38 @@ +import ast +import os +from dataclasses import dataclass + + +@dataclass +class ServerParameterManager: + relik_pretrained: str = os.environ.get("RELIK_PRETRAINED", None) + device: str = os.environ.get("DEVICE", "cpu") + retriever_device: str | None = os.environ.get("RETRIEVER_DEVICE", None) + document_index_device: str | None = os.environ.get("INDEX_DEVICE", None) + reader_device: str | None = os.environ.get("READER_DEVICE", None) + precision: int | str | None = os.environ.get("PRECISION", "fp32") + retriever_precision: int | str | None = os.environ.get("RETRIEVER_PRECISION", None) + document_index_precision: int | str | None = os.environ.get("INDEX_PRECISION", None) + reader_precision: int | str | None = os.environ.get("READER_PRECISION", None) + annotation_type: str = os.environ.get("ANNOTATION_TYPE", "char") + question_encoder: str = os.environ.get("QUESTION_ENCODER", None) + passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None) + document_index: str = os.environ.get("DOCUMENT_INDEX", None) + reader_encoder: str = os.environ.get("READER_ENCODER", None) + top_k: int = int(os.environ.get("TOP_K", 100)) + use_faiss: bool = os.environ.get("USE_FAISS", False) + retriever_batch_size: int = int(os.environ.get("RETRIEVER_BATCH_SIZE", 32)) + reader_batch_size: int = int(os.environ.get("READER_BATCH_SIZE", 32)) + window_size: int = int(os.environ.get("WINDOW_SIZE", 32)) + window_stride: int = int(os.environ.get("WINDOW_SIZE", 16)) + split_on_spaces: bool = 
os.environ.get("SPLIT_ON_SPACES", False) + # relik_config_override: dict = ast.literal_eval( + # os.environ.get("RELIK_CONFIG_OVERRIDE", None) + # ) + + +class RayParameterManager: + def __init__(self) -> None: + self.num_gpus = int(os.environ.get("NUM_GPUS", 1)) + self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1)) + self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1)) diff --git a/relik/inference/serve/frontend/__init__.py b/relik/inference/serve/frontend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/serve/frontend/relik_front.py b/relik/inference/serve/frontend/relik_front.py new file mode 100644 index 0000000000000000000000000000000000000000..e0edfd186b080aac9b0029fd2258a1cafe77c09c --- /dev/null +++ b/relik/inference/serve/frontend/relik_front.py @@ -0,0 +1,229 @@ +import os +from pathlib import Path + +import requests +import streamlit as st +from spacy import displacy +from streamlit_extras.badges import badge +from streamlit_extras.stylable_container import stylable_container + +RELIK = os.getenv("RELIK", "localhost:8000/api/entities") + +import random + + +def get_random_color(ents): + colors = {} + random_colors = generate_pastel_colors(len(ents)) + for ent in ents: + colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1)) + return colors + + +def floatrange(start, stop, steps): + if int(steps) == 1: + return [stop] + return [ + start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps) + ] + + +def hsl_to_rgb(h, s, l): + def hue_2_rgb(v1, v2, v_h): + while v_h < 0.0: + v_h += 1.0 + while v_h > 1.0: + v_h -= 1.0 + if 6 * v_h < 1.0: + return v1 + (v2 - v1) * 6.0 * v_h + if 2 * v_h < 1.0: + return v2 + if 3 * v_h < 2.0: + return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0 + return v1 + + # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1." + # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1." + + r, b, g = (l * 255,) * 3 + if s != 0.0: + if l < 0.5: + var_2 = l * (1.0 + s) + else: + var_2 = (l + s) - (s * l) + var_1 = 2.0 * l - var_2 + r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0)) + g = 255 * hue_2_rgb(var_1, var_2, h) + b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0)) + + return int(round(r)), int(round(g)), int(round(b)) + + +def generate_pastel_colors(n): + """Return different pastel colours. 
+ + Input: + n (integer) : The number of colors to return + + Output: + A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc']) + + Example: + >>> print generate_pastel_colors(5) + ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0'] + """ + if n == 0: + return [] + + # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space) + start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue + saturation = 1.0 + lightness = 0.8 + # We take points around the chromatic circle (hue): + # (Note: we generate n+1 colors, then drop the last one ([:-1]) because + # it equals the first one (hue 0 = hue 1)) + return [ + "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) + for hue in floatrange(start_hue, start_hue + 1, n + 1) + ][:-1] + + +def set_sidebar(css): + white_link_wrapper = "{}" + with st.sidebar: + st.markdown(f"", unsafe_allow_html=True) + st.image( + "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg", + use_column_width=True, + ) + st.markdown("## ReLiK") + st.write( + f""" + - {white_link_wrapper.format("#", "  Paper")} + - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "  GitHub")} + - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "  Docker Hub")} + """, + unsafe_allow_html=True, + ) + st.markdown("## Sapienza NLP") + st.write( + f""" + - {white_link_wrapper.format("https://nlp.uniroma1.it", "  Webpage")} + - {white_link_wrapper.format("https://github.com/SapienzaNLP", "  GitHub")} + - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "  Twitter")} + - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "  LinkedIn")} + """, + unsafe_allow_html=True, + ) + + +def get_el_annotations(response): + # swap labels key with ents + response["ents"] = response.pop("labels") + label_in_text = set(l["label"] for l in response["ents"]) + options = {"ents": label_in_text, "colors": get_random_color(label_in_text)} + return response, options + + +def set_intro(css): + # intro + st.markdown("# ReLik") + st.markdown( + "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget" + ) + # st.markdown( + # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API " + # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by " + # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), " + # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)." 
+ # ) + badge(type="github", name="sapienzanlp/relik") + badge(type="pypi", name="relik") + + +def run_client(): + with open(Path(__file__).parent / "style.css") as f: + css = f.read() + + st.set_page_config( + page_title="ReLik", + page_icon="🦮", + layout="wide", + ) + set_sidebar(css) + set_intro(css) + + # text input + text = st.text_area( + "Enter Text Below:", + value="Obama went to Rome for a quick vacation.", + height=200, + max_chars=500, + ) + + with stylable_container( + key="annotate_button", + css_styles=""" + button { + background-color: #802433; + color: white; + border-radius: 25px; + } + """, + ): + submit = st.button("Annotate") + # submit = st.button("Run") + + # ReLik API call + if submit: + text = text.strip() + if text: + st.markdown("####") + st.markdown("#### Entity Linking") + with st.spinner(text="In progress"): + response = requests.post(RELIK, json=text) + if response.status_code != 200: + st.error("Error: {}".format(response.status_code)) + else: + response = response.json() + + # Entity Linking + # with stylable_container( + # key="container_with_border", + # css_styles=""" + # { + # border: 1px solid rgba(49, 51, 63, 0.2); + # border-radius: 0.5rem; + # padding: 0.5rem; + # padding-bottom: 2rem; + # } + # """, + # ): + # st.markdown("##") + dict_of_ents, options = get_el_annotations(response=response) + display = displacy.render( + dict_of_ents, manual=True, style="ent", options=options + ) + display = display.replace("\n", " ") + # wsd_display = re.sub( + # r"(wiki::\d+\w)", + # r"\g<1>".format( + # language.upper() + # ), + # wsd_display, + # ) + with st.container(): + st.write(display, unsafe_allow_html=True) + + st.markdown("####") + st.markdown("#### Relation Extraction") + + with st.container(): + st.write("Coming :)", unsafe_allow_html=True) + + else: + st.error("Please enter some text.") + + +if __name__ == "__main__": + run_client() diff --git a/relik/inference/serve/frontend/relik_re_front.py b/relik/inference/serve/frontend/relik_re_front.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc6176fd2fca9e370f0d947d0851ea602950d22 --- /dev/null +++ b/relik/inference/serve/frontend/relik_re_front.py @@ -0,0 +1,251 @@ +import os +from datetime import datetime as dt +from pathlib import Path + +import requests +import spacy +import streamlit as st +import streamlit.components.v1 as components +from pyvis.network import Network +from spacy import displacy +from spacy.tokens import Doc +from streamlit_extras.badges import badge +from streamlit_extras.stylable_container import stylable_container +from utils import get_random_color, visualize_parser + +from relik import Relik + +# RELIK = os.getenv("RELIK", "localhost:8000/api/relik") + +state_variables = {"has_run_free": False, "html_free": ""} + + +def init_state_variables(): + for k, v in state_variables.items(): + if k not in st.session_state: + st.session_state[k] = v + + +def free_reset_session(): + for k in state_variables: + del st.session_state[k] + + +def generate_graph(dict_ents, response, filename, options): + g = Network( + width="720px", + height="600px", + directed=True, + notebook=False, + bgcolor="#222222", + font_color="white", + ) + g.barnes_hut( + gravity=-3000, + central_gravity=0.3, + spring_length=50, + spring_strength=0.001, + damping=0.09, + overlap=0, + ) + for ent in dict_ents: + g.add_node( + dict_ents[ent][0], + label=dict_ents[ent][1], + color=options["colors"][dict_ents[ent][0]], + title=dict_ents[ent][0], + size=15, + labelHighlightBold=True, + ) + + for 
rel in response.triples: + g.add_edge( + dict_ents[(rel.subject.start, rel.subject.end)][0], + dict_ents[(rel.object.start, rel.object.end)][0], + label=rel.label, + title=rel.label, + ) + g.show(filename, notebook=False) + + +def set_sidebar(css): + white_link_wrapper = ( + "{}" + ) + with st.sidebar: + st.markdown(f"", unsafe_allow_html=True) + st.image( + "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg", + use_column_width=True, + ) + st.markdown("## ReLiK") + st.write( + f""" + - {white_link_wrapper.format("#", "  Paper")} + - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "  GitHub")} + - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "  Docker Hub")} + """, + unsafe_allow_html=True, + ) + st.markdown("## Sapienza NLP") + st.write( + f""" + - {white_link_wrapper.format("https://nlp.uniroma1.it", "  Webpage")} + - {white_link_wrapper.format("https://github.com/SapienzaNLP", "  GitHub")} + - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "  Twitter")} + - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "  LinkedIn")} + """, + unsafe_allow_html=True, + ) + + +def get_span_annotations(response): + el_link_wrapper = ( + "" + " " + "{}" + ) + tokens = response.tokens + labels = ["O"] * len(tokens) + dict_ents = {} + # make BIO labels + for idx, span in enumerate(response.spans): + labels[span.start] = ( + "B-" + span.label + str(idx) + if span.label == "NME" + else "B-" + el_link_wrapper.format(span.label.replace(" ", "_"), span.label) + ) + for i in range(span.start + 1, span.end): + labels[i] = ( + "I-" + span.label + str(idx) + if span.label == "NME" + else "I-" + + el_link_wrapper.format(span.label.replace(" ", "_"), span.label) + ) + dict_ents[(span.start, span.end)] = ( + span.label + str(idx), + " ".join(tokens[span.start : span.end]), + ) + unique_labels = set(w[2:] for w in labels if w != "O") + options = {"ents": unique_labels, "colors": get_random_color(unique_labels)} + return tokens, labels, options, dict_ents + + +@st.cache_resource() +def load_model(): + return Relik.from_pretrained("riccorl/relik-relation-extraction-nyt-small") + + +def set_intro(css): + # intro + st.markdown("# ReLik") + st.markdown( + "### Retrieve, Read and LinK: Fast and Accurate Entity Linking " + "and Relation Extraction on an Academic Budget" + ) + # st.markdown( + # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API " + # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal + # _Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), + # which will be presented at LREC 2022 by " + # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), " + # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), + # and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)." 
+ # ) + badge(type="github", name="sapienzanlp/relik") + badge(type="pypi", name="relik") + + +def run_client(): + with open(Path(__file__).parent / "style.css") as f: + css = f.read() + + st.set_page_config( + page_title="ReLik", + page_icon="🦮", + layout="wide", + ) + set_sidebar(css) + set_intro(css) + + # text input + text = st.text_area( + "Enter Text Below:", + value="Michael Jordan was one of the best players in the NBA.", + height=200, + max_chars=1500, + ) + + with stylable_container( + key="annotate_button", + css_styles=""" + button { + background-color: #802433; + color: white; + border-radius: 25px; + } + """, + ): + submit = st.button("Annotate") + + if "relik_model" not in st.session_state.keys(): + st.session_state["relik_model"] = load_model() + relik_model = st.session_state["relik_model"] + init_state_variables() + # ReLik API call + + # spacy for span visualization + nlp = spacy.blank("xx") + + if submit: + text = text.strip() + if text: + st.session_state["filename"] = str(dt.now().timestamp() * 1000) + ".html" + + with st.spinner(text="In progress"): + response = relik_model(text, annotation_type="word", num_workers=0) + # response = requests.post(RELIK, json=text) + # if response.status_code != 200: + # st.error("Error: {}".format(response.status_code)) + # else: + # response = response.json() + + # EL + st.markdown("####") + st.markdown("#### Entities") + tokens, labels, options, dict_ents = get_span_annotations( + response=response + ) + doc = Doc(nlp.vocab, words=tokens, ents=labels) + display_el = displacy.render(doc, style="ent", options=options) + display_el = display_el.replace("\n", " ") + # heuristic, prevents split of annotation decorations + display_el = display_el.replace( + "border-radius: 0.35em;", + "border-radius: 0.35em; white-space: nowrap;", + ) + with st.container(): + st.write(display_el, unsafe_allow_html=True) + + # RE + generate_graph( + dict_ents, response, st.session_state["filename"], options + ) + HtmlFile = open(st.session_state["filename"], "r", encoding="utf-8") + source_code = HtmlFile.read() + st.session_state["html_free"] = source_code + os.remove(st.session_state["filename"]) + st.session_state["has_run_free"] = True + else: + st.error("Please enter some text.") + + if st.session_state["has_run_free"]: + st.markdown("#### Relations") + components.html(st.session_state["html_free"], width=720, height=600) + + +if __name__ == "__main__": + run_client() diff --git a/relik/inference/serve/frontend/style.css b/relik/inference/serve/frontend/style.css new file mode 100644 index 0000000000000000000000000000000000000000..31f0d182cfd9b2636d5db5cbd0e7a1339ed5d1c3 --- /dev/null +++ b/relik/inference/serve/frontend/style.css @@ -0,0 +1,33 @@ +/* Sidebar */ +.eczjsme11 { + background-color: #802433; +} + +.st-emotion-cache-10oheav h2 { + color: white; +} + +.st-emotion-cache-10oheav li { + color: white; +} + +/* Main */ +a:link { + text-decoration: none; + color: white; +} + +a:visited { + text-decoration: none; + color: white; +} + +a:hover { + text-decoration: none; + color: rgba(255, 255, 255, 0.871); +} + +a:active { + text-decoration: none; + color: white; +} \ No newline at end of file diff --git a/relik/inference/serve/frontend/utils.py b/relik/inference/serve/frontend/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13781f3672e43eeca6df4e3a3c11ababd69e2b1b --- /dev/null +++ b/relik/inference/serve/frontend/utils.py @@ -0,0 +1,132 @@ +import base64 +import random +from typing import Dict, List, Optional, Union 
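The Streamlit front end above consumes only three fields of the word-level prediction: `tokens`, `spans` (each with `start`, `end`, `label`) and `triples` (each with `subject`, `object`, `label`). The sketch below shows how such a prediction maps onto the node and edge lists used for the relation graph; the `pred` object and the helper name are illustrative, while the field accesses mirror `generate_graph` and `get_span_annotations` above.

```python
# Sketch: derive graph nodes and edges from a word-level ReLiK prediction.
# `pred` is assumed to expose .tokens, .spans and .triples as used in
# relik_re_front.py above; the function itself is illustrative.
def prediction_to_graph(pred):
    nodes = {}
    for span in pred.spans:
        # one node per predicted mention, keyed by its token offsets
        nodes[(span.start, span.end)] = {
            "text": " ".join(pred.tokens[span.start : span.end]),
            "entity": span.label,
        }
    edges = [
        ((t.subject.start, t.subject.end), (t.object.start, t.object.end), t.label)
        for t in pred.triples
    ]
    return nodes, edges
```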
+ +import spacy +import streamlit as st +from spacy import displacy + + +def get_html(html: str): + """Convert HTML so it can be rendered.""" + WRAPPER = """
<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>
""" + # Newlines seem to mess with the rendering + html = html.replace("\n", " ") + return WRAPPER.format(html) + + +def get_svg(svg: str, style: str = "", wrap: bool = True): + """Convert an SVG to a base64-encoded image.""" + b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8") + html = f'' + return get_html(html) if wrap else html + + +def visualize_parser( + doc: Union[spacy.tokens.Doc, List[Dict[str, str]]], + *, + title: Optional[str] = None, + key: Optional[str] = None, + manual: bool = False, + displacy_options: Optional[Dict] = None, +) -> None: + """Visualizer for dependency parses. + + doc (Doc, List): The document to visualize. + key (str): Key used for the streamlit component for selecting labels. + title (str): The title displayed at the top of the parser visualization. + manual (bool): Flag signifying whether the doc argument is a Doc object or a List of Dicts containing parse information. + displacy_options (Dict): Dictionary of options to be passed to the displacy render method for generating the HTML to be rendered. + See: https://spacy.io/api/top-level#options-dep + """ + if displacy_options is None: + displacy_options = dict() + if title: + st.header(title) + docs = [doc] + # add selected options to options provided by user + # `options` from `displacy_options` are overwritten by user provided + # options from the checkboxes + for sent in docs: + html = displacy.render( + sent, options=displacy_options, style="dep", manual=manual + ) + # Double newlines seem to mess with the rendering + html = html.replace("\n\n", "\n") + st.write(get_svg(html), unsafe_allow_html=True) + + +def get_random_color(ents): + colors = {} + random_colors = generate_pastel_colors(len(ents)) + for ent in ents: + colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1)) + return colors + + +def floatrange(start, stop, steps): + if int(steps) == 1: + return [stop] + return [ + start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps) + ] + + +def hsl_to_rgb(h, s, l): + def hue_2_rgb(v1, v2, v_h): + while v_h < 0.0: + v_h += 1.0 + while v_h > 1.0: + v_h -= 1.0 + if 6 * v_h < 1.0: + return v1 + (v2 - v1) * 6.0 * v_h + if 2 * v_h < 1.0: + return v2 + if 3 * v_h < 2.0: + return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0 + return v1 + + # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1." + # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1." + + r, b, g = (l * 255,) * 3 + if s != 0.0: + if l < 0.5: + var_2 = l * (1.0 + s) + else: + var_2 = (l + s) - (s * l) + var_1 = 2.0 * l - var_2 + r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0)) + g = 255 * hue_2_rgb(var_1, var_2, h) + b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0)) + + return int(round(r)), int(round(g)), int(round(b)) + + +def generate_pastel_colors(n): + """Return different pastel colours. 
+ + Input: + n (integer) : The number of colors to return + + Output: + A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc']) + + Example: + >>> print generate_pastel_colors(5) + ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0'] + """ + if n == 0: + return [] + + # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space) + start_hue = 0.0 # 0=red 1/3=0.333=green 2/3=0.666=blue + saturation = 1.0 + lightness = 0.9 + # We take points around the chromatic circle (hue): + # (Note: we generate n+1 colors, then drop the last one ([:-1]) because + # it equals the first one (hue 0 = hue 1)) + return [ + "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) + for hue in floatrange(start_hue, start_hue + 1, n + 1) + ][:-1] diff --git a/relik/reader/__init__.py b/relik/reader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/__pycache__/__init__.cpython-310.pyc b/relik/reader/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..015f773b14b826b47166dffa2ef3854d79907d40 Binary files /dev/null and b/relik/reader/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/__pycache__/relik_reader_core.cpython-310.pyc b/relik/reader/__pycache__/relik_reader_core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9646e708959fdcc62aa4b35d174978d37e7ab5dd Binary files /dev/null and b/relik/reader/__pycache__/relik_reader_core.cpython-310.pyc differ diff --git a/relik/reader/__pycache__/relik_reader_predictor.cpython-310.pyc b/relik/reader/__pycache__/relik_reader_predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfcd1bc802a25a0296783704477eb12129e217b9 Binary files /dev/null and b/relik/reader/__pycache__/relik_reader_predictor.cpython-310.pyc differ diff --git a/relik/reader/conf/base.yaml b/relik/reader/conf/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..011070cfc7fb6e50ee102dc754af25fe20f7dee7 --- /dev/null +++ b/relik/reader/conf/base.yaml @@ -0,0 +1,14 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: relik-reader-deberta-base-retriever-relik-entity-linking-aida-wikipedia-det # -start-end-mask-0.001 # used to name the model in wandb and output dir +project_name: relik-reader # used to name the project in wandb +offline: false # if true, wandb will not be used + +defaults: + - _self_ + - training: base + - model: base + - data: base diff --git a/relik/reader/conf/config.yaml b/relik/reader/conf/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e743b167f57caac164ea8b82b5add9487ab4454 --- /dev/null +++ b/relik/reader/conf/config.yaml @@ -0,0 +1,14 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: relik-reader-deberta-base-retriever-relik-entity-linking-aida-wikipedia-twin-no-pere # -start-end-mask-0.001 # used to name the model in wandb and output dir +project_name: relik-reader # used to name the project in wandb +offline: false # if true, wandb will not be used + +defaults: + - _self_ + - training: base + - model: base + - data: base diff --git 
a/relik/reader/conf/data/base.yaml b/relik/reader/conf/data/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a596b053a3704b3a72bd7adfb655ba9a26dd132e --- /dev/null +++ b/relik/reader/conf/data/base.yaml @@ -0,0 +1,21 @@ +train_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/train_windowed_candidates.jsonl" +val_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: 0.5 + random_drop_gold_candidates: 0.05 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + special_symbols: null + +val_dataset: + _target_: "relik.reader.data.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + for_inference: True + special_symbols: null diff --git a/relik/reader/conf/data/bio.yaml b/relik/reader/conf/data/bio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3ca9bacb5f54349cf08f8c11ff128ab8f3d54ba --- /dev/null +++ b/relik/reader/conf/data/bio.yaml @@ -0,0 +1,36 @@ +train_dataset_path: "data/reader/bio-rel/train.balanced.candidates.jsonl" +val_dataset_path: "data/reader/bio-rel/dev.en.relik.sent.re.el_people.candidates_windowed.jsonl" +test_dataset_path: "data/reader/bio-rel/test.en.relik.sent.re.el_people.candidates_windowed.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + max_length: 1024 + max_triplets: 8 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null + section_size: 1000000 + use_nme: False + sorting_fields: + - "predictable_candidates" +val_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + use_nme: False + max_triplets: 8 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null diff --git a/relik/reader/conf/data/bio_orig.yaml b/relik/reader/conf/data/bio_orig.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98a8fa73132ad26c07890c13eefd742ff6d88020 --- /dev/null +++ b/relik/reader/conf/data/bio_orig.yaml @@ -0,0 +1,36 @@ +train_dataset_path: "data/reader/bio-rel/train.en.relik.sent.re.el_orig.candidates,jsonl" +val_dataset_path: "data/reader/bio-rel/dev.en.relik.sent.re.el_orig.candidates,jsonl" +test_dataset_path: "data/reader/bio-rel/test.en.relik.sent.re.el_orig.candidates,jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + max_length: 1024 + max_triplets: 16 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null + section_size: 1000000 + use_nme: False + sorting_fields: + - "predictable_candidates" 
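The dataset blocks in these configs follow Hydra's `_target_` convention, so the training code can turn them into concrete dataset objects with `hydra.utils.instantiate` once the full config is composed. A minimal sketch of that pattern follows; the `config_path`/`config_name` values and the keyword overrides are assumptions for illustration, while the `_target_` class and the `entities_per_forward` key are the ones declared in this diff.

```python
# Sketch: instantiating a reader dataset from the Hydra configs above.
# config_path/config_name and the overrides are illustrative assumptions.
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

from relik.reader.data.relik_reader_data import RelikDataset


@hydra.main(config_path="../conf", config_name="config", version_base="1.3")
def main(cfg: DictConfig) -> None:
    # interpolations such as ${model.model.transformer_model} are resolved
    # when the defaults (training/model/data) are composed
    train_dataset = instantiate(
        cfg.data.train_dataset,
        dataset_path=cfg.data.train_dataset_path,
        special_symbols=RelikDataset.get_special_symbols(cfg.model.entities_per_forward),
    )
    print(type(train_dataset).__name__)


if __name__ == "__main__":
    main()
```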
+val_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + use_nme: False + max_triplets: 16 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null diff --git a/relik/reader/conf/data/crossre.yaml b/relik/reader/conf/data/crossre.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2d810dd6798b84f9acea0cd1be5d55d4d010c54 --- /dev/null +++ b/relik/reader/conf/data/crossre.yaml @@ -0,0 +1,36 @@ +train_dataset_path: "data/reader/crossre/train.swapped.candidates.jsonl" +val_dataset_path: "data/reader/crossre/dev.swapped.candidates.jsonl" +test_dataset_path: "data/reader/crossre/test.swapped.candidates.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + max_length: 1024 + max_triplets: 8 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null + section_size: 1000000 + use_nme: False + sorting_fields: + - "predictable_candidates" +val_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + use_nme: False + max_triplets: 8 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null diff --git a/relik/reader/conf/data/large.yaml b/relik/reader/conf/data/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81dddebe3ba892dd16e2aefb20dec2aaa18766c1 --- /dev/null +++ b/relik/reader/conf/data/large.yaml @@ -0,0 +1,21 @@ +train_dataset_path: "/home/carlos/amr-parsing-master/sentence-similarity/retriever/dataset/generative/single/reader/intervention_question_variations_predictions.jsonl" +val_dataset_path: "/home/carlos/amr-parsing-master/sentence-similarity/retriever/dataset/generative/single/reader/intervention_question_variations_test_predictions.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: 0.5 + random_drop_gold_candidates: 0.05 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 2048 + special_symbols: null + +val_dataset: + _target_: "relik.reader.data.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + for_inference: True + special_symbols: null diff --git a/relik/reader/conf/data/nyt.yaml b/relik/reader/conf/data/nyt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..efc9d43b04dc2c4626a964f74e31bb95d1f82f2e --- /dev/null +++ b/relik/reader/conf/data/nyt.yaml @@ -0,0 +1,36 @@ +train_dataset_path: "data/reader/nyt/train.relik.candidates.jsonl" +val_dataset_path: "data/reader/nyt/valid.relik.candidates.jsonl" +test_dataset_path: "data/reader/nyt/test.relik.candidates.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + 
noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + max_length: 1024 + max_triplets: 24 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null + section_size: 1000000 + use_nme: False + sorting_fields: + - "predictable_candidates" +val_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + use_nme: False + max_triplets: 24 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null diff --git a/relik/reader/conf/data/nyt_bio.yaml b/relik/reader/conf/data/nyt_bio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b35317d52fac72623717f35657fa37c920971bd --- /dev/null +++ b/relik/reader/conf/data/nyt_bio.yaml @@ -0,0 +1,36 @@ +train_dataset_path: "data/reader/bio_nyt/train.relik.candidates_people.jsonl" +val_dataset_path: "data/reader/bio_nyt/valid.relik.candidates_people.jsonl" +test_dataset_path: "data/reader/bio_nyt/test.relik.candidates_people.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + max_length: 1024 + max_triplets: 8 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null + section_size: 1000000 + use_nme: False + sorting_fields: + - "predictable_candidates" +val_dataset: + _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + use_nme: False + max_triplets: 8 + max_spans: + min_length: -1 + special_symbols: null + special_symbols_types: null diff --git a/relik/reader/conf/data/re.yaml b/relik/reader/conf/data/re.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17c18ee886021bc0157edb156020409fdd799fbc --- /dev/null +++ b/relik/reader/conf/data/re.yaml @@ -0,0 +1,54 @@ +train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl" +val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl" +test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl" + +relations_definitions: + /people/person/nationality: "nationality" + /sports/sports_team/location: "sports team location" + /location/country/administrative_divisions: "administrative divisions" + /business/company/major_shareholders: "shareholders" + /people/ethnicity/people: "ethnicity" + /people/ethnicity/geographic_distribution: "geographic distributi6on" + /business/company_shareholder/major_shareholder_of: "major shareholder" + /location/location/contains: "location" + /business/company/founders: "founders" + /business/person/company: "company" + /business/company/advisors: "advisor" + /people/deceased_person/place_of_death: "place of death" + /business/company/industry: "industry" + /people/person/ethnicity: "ethnic background" + /people/person/place_of_birth: "place of birth" + /location/administrative_division/country: "country of an administration division" + /people/person/place_lived: "place lived" + /sports/sports_team_location/teams: "sports team" + /people/person/children: "child" + /people/person/religion: "religion" + /location/neighborhood/neighborhood_of: "neighborhood" + 
/location/country/capital: "capital" + /business/company/place_founded: "company founded location" + /people/person/profession: "occupation" + +train_dataset: + _target_: "relik.reader.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + min_length: -1 + special_symbols: null + relations_definitions: ${data.relations_definitions} + sorting_fields: + - "predictable_candidates" +val_dataset: + _target_: "relik.reader.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + min_length: -1 + special_symbols: null + relations_definitions: ${data.relations_definitions} diff --git a/relik/reader/conf/data/small.yaml b/relik/reader/conf/data/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..940f7149b2284b8979ae52329f08c51f6fe0f5b8 --- /dev/null +++ b/relik/reader/conf/data/small.yaml @@ -0,0 +1,22 @@ +train_dataset_path: "/home/carlos/amr-parsing-master/sentence-similarity/retriever/dataset/generative/single/reader/training/econie-training.100.jsonl" +val_dataset_path: "/home/carlos/amr-parsing-master/sentence-similarity/retriever/dataset/generative/single/reader/test/econie-test.100.jsonl" + +train_dataset: + _target_: "relik.reader.data.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: 0.5 + random_drop_gold_candidates: 0.05 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 2048 + special_symbols: null + prebatch: False + +val_dataset: + _target_: "relik.reader.data.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + for_inference: True + special_symbols: null diff --git a/relik/reader/conf/large.yaml b/relik/reader/conf/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5996ca086ef95df5c6d2bae4dfef3c2e5c787143 --- /dev/null +++ b/relik/reader/conf/large.yaml @@ -0,0 +1,14 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: relik-reader-deberta-large-retriever-relik-entity-linking-aida-wikipedia # -start-end-mask-0.001 # used to name the model in wandb and output dir +project_name: relik-reader # used to name the project in wandb +offline: false # if true, wandb will not be used + +defaults: + - _self_ + - training: base + - model: base + - data: base diff --git a/relik/reader/conf/model/base.yaml b/relik/reader/conf/model/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3216a62834d08d0cf8fa5329f166ce8ab849c603 --- /dev/null +++ b/relik/reader/conf/model/base.yaml @@ -0,0 +1,15 @@ +model: + transformer_model: "microsoft/deberta-v3-base" + +optimizer: + lr: 0.0001 + warmup_steps: 5000 + total_steps: ${training.trainer.max_steps} + total_reset: 1 + weight_decay: 0.0 + lr_decay: 0.8 + no_decay_params: + - "bias" + - LayerNorm.weight + +entities_per_forward: 100 diff --git a/relik/reader/conf/model/bio.yaml b/relik/reader/conf/model/bio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..250267742ec99079e8f67f7d549ca0e194319a9d --- /dev/null 
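The `no_decay_params` entries in the optimizer configs above follow the standard transformer fine-tuning recipe: weight decay is applied everywhere except to biases and LayerNorm weights. The reader's actual optimizer construction is not part of this diff, so the snippet below is only an assumption-level illustration of how such a config is conventionally mapped to AdamW parameter groups.

```python
# Sketch: mapping the `weight_decay` / `no_decay_params` keys above to AdamW
# parameter groups. Illustrative only; ReLiK's own optimizer code is not shown here.
import torch


def build_param_groups(model, weight_decay=0.0, no_decay_params=("bias", "LayerNorm.weight")):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # parameters whose name matches any no-decay pattern are exempted
        (no_decay if any(nd in name for nd in no_decay_params) else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# usage (illustrative): optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)
```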
+++ b/relik/reader/conf/model/bio.yaml @@ -0,0 +1,14 @@ +model: + transformer_model: "microsoft/deberta-v3-small" + +optimizer: + lr: 0.00005 + warmup_steps: 25000 + total_steps: ${training.trainer.max_steps} + weight_decay: 0.01 + no_decay_params: + - "bias" + - LayerNorm.weight + +relations_per_forward: 16 +entities_per_forward: diff --git a/relik/reader/conf/model/crossre.yaml b/relik/reader/conf/model/crossre.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7724cc713a5c6bea395d8e1ecb9290698fc5c292 --- /dev/null +++ b/relik/reader/conf/model/crossre.yaml @@ -0,0 +1,14 @@ +model: + transformer_model: "/root/relik-sapienzanlp/models/relik-reader-deberta-small-retriever-relik-e5-small-hierarchy-top8-balanced-2024-02-01-06-43-07-{val_f1:.2f}/relik-reader-deberta-small-retriever-relik-e5-small-hierarchy-top8-balanced-2024-02-01-06-43-07-{val_f1:.2f}/" + +optimizer: + lr: 0.00005 + warmup_steps: 250 + total_steps: ${training.trainer.max_steps} + weight_decay: 0.01 + no_decay_params: + - "bias" + - LayerNorm.weight + +relations_per_forward: 12 +entities_per_forward: diff --git a/relik/reader/conf/model/large.yaml b/relik/reader/conf/model/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4596a82eee9bc6a4b8ce72a8adc4473b020cd0d --- /dev/null +++ b/relik/reader/conf/model/large.yaml @@ -0,0 +1,15 @@ +model: + transformer_model: "microsoft/deberta-v3-large" + +optimizer: + lr: 0.0001 + warmup_steps: 5000 + total_steps: ${training.trainer.max_steps} + total_reset: 1 + weight_decay: 0.0 + lr_decay: 0.9 + no_decay_params: + - "bias" + - LayerNorm.weight + +entities_per_forward: 100 diff --git a/relik/reader/conf/model/nyt.yaml b/relik/reader/conf/model/nyt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d978848db0ccd095bbc0799010d6a28e5b690fe --- /dev/null +++ b/relik/reader/conf/model/nyt.yaml @@ -0,0 +1,22 @@ +model: + transformer_model: "microsoft/deberta-v3-large" + +optimizer: + lr: + - 0.0001 + - 0.00002 + warmup_steps: 500 + total_steps: ${training.trainer.max_steps} + total_reset: 1 + weight_decay: 0.01 + lr_decay: 0.9 + no_decay_params: + - "bias" + - LayerNorm.weight + other_lr_params: + - "re_subject_projector" + - "re_object_projector" + - "re_relation_projector" + - "re_classifier" +relations_per_forward: 24 +entities_per_forward: diff --git a/relik/reader/conf/model/nyt_bio.yaml b/relik/reader/conf/model/nyt_bio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c6fd777b3b7cace22c14b17cbd0733ea8ce7cfa --- /dev/null +++ b/relik/reader/conf/model/nyt_bio.yaml @@ -0,0 +1,14 @@ +model: + transformer_model: "/root/relik-sapienzanlp/models/relik-reader-deberta-small-retriever-relik-e5-small-hierarchy-top8-balanced-2024-02-01-06-43-07-{val_f1:.2f}/relik-reader-deberta-small-retriever-relik-e5-small-hierarchy-top8-balanced-2024-02-01-06-43-07-{val_f1:.2f}/" + +optimizer: + lr: 0.00005 + warmup_steps: 2500 + total_steps: ${training.trainer.max_steps} + weight_decay: 0.01 + no_decay_params: + - "bias" + - LayerNorm.weight + +relations_per_forward: 12 +entities_per_forward: diff --git a/relik/reader/conf/model/small.yaml b/relik/reader/conf/model/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a4d943f9e980477c7ad8f21281b922e544e73e0 --- /dev/null +++ b/relik/reader/conf/model/small.yaml @@ -0,0 +1,15 @@ +model: + transformer_model: 
"/home/carlos/amr-parsing-master/sentence-similarity/relik-main/experiments/models/relik-small-best/17-47-48/wandb/run-20240430_174753-g4pz35kp/files/hf_model/hf_model" + +optimizer: + lr: 0.0001 + warmup_steps: 500 + total_steps: ${training.trainer.max_steps} + total_reset: 1 + weight_decay: 0.0 + lr_decay: 0.9 + no_decay_params: + - "bias" + - LayerNorm.weight + +entities_per_forward: 100 diff --git a/relik/reader/conf/small.yaml b/relik/reader/conf/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bb59f0ad5effb9997b08714f10f915864c013a7 --- /dev/null +++ b/relik/reader/conf/small.yaml @@ -0,0 +1,14 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: relik-reader-deberta-small-io-10 # -start-end-mask-0.001 # used to name the model in wandb and output dir +project_name: relik-io # used to name the project in wandb +offline: false # if true, wandb will not be used + +defaults: + - _self_ + - training: small + - model: small + - data: small diff --git a/relik/reader/conf/training/base.yaml b/relik/reader/conf/training/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..985a981d734e76734626d81953ae8f8b04c31a7d --- /dev/null +++ b/relik/reader/conf/training/base.yaml @@ -0,0 +1,13 @@ +seed: 94 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 50000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + accumulate_grad_batches: 4 diff --git a/relik/reader/conf/training/bio.yaml b/relik/reader/conf/training/bio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6d69ee6daeb0b5bd9cc10e194027dca8731862b --- /dev/null +++ b/relik/reader/conf/training/bio.yaml @@ -0,0 +1,15 @@ +seed: 15 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 150000 + val_check_interval: 5000 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + +save_model_path: /root/relik-sapienzanlp/models/${model_name}-${now:%Y-%m-%d-%H-%M-%S}-{val_f1:.2f} +ckpt_path: \ No newline at end of file diff --git a/relik/reader/conf/training/crossre.yaml b/relik/reader/conf/training/crossre.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0128a536198301a9370cb7e90bea2fc12f898476 --- /dev/null +++ b/relik/reader/conf/training/crossre.yaml @@ -0,0 +1,15 @@ +seed: 15 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 2500 + check_val_every_n_epoch: 20 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + +save_model_path: /root/relik-sapienzanlp/models/${model_name}-${now:%Y-%m-%d-%H-%M-%S} +ckpt_path: \ No newline at end of file diff --git a/relik/reader/conf/training/large.yaml b/relik/reader/conf/training/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c387d2d581e763dc9f7a8014c61e6717be6ad4cf --- /dev/null +++ b/relik/reader/conf/training/large.yaml @@ -0,0 +1,13 @@ +seed: 94 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 50000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + accumulate_grad_batches: 8 diff --git a/relik/reader/conf/training/nyt.yaml b/relik/reader/conf/training/nyt.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..cba11f5313f6c16bb6e94fb1c676f01a99ded21b --- /dev/null +++ b/relik/reader/conf/training/nyt.yaml @@ -0,0 +1,15 @@ +seed: 15 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 5000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + accumulate_grad_batches: 16 + +ckpt_path: diff --git a/relik/reader/conf/training/nyt_bio.yaml b/relik/reader/conf/training/nyt_bio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a547c18a94adffb7f5bfc2cfe4c10eef71e2f37b --- /dev/null +++ b/relik/reader/conf/training/nyt_bio.yaml @@ -0,0 +1,15 @@ +seed: 15 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 25000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + +save_model_path: /root/relik-sapienzanlp/models/${model_name}-${now:%Y-%m-%d-%H-%M-%S} +ckpt_path: \ No newline at end of file diff --git a/relik/reader/conf/training/re.yaml b/relik/reader/conf/training/re.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8701ae3fca48830649022644a743783a1016bd5b --- /dev/null +++ b/relik/reader/conf/training/re.yaml @@ -0,0 +1,12 @@ +seed: 15 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 100000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 \ No newline at end of file diff --git a/relik/reader/conf/training/small.yaml b/relik/reader/conf/training/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a1e3ff8a4b32f7b031d56112b003ccc3949298a --- /dev/null +++ b/relik/reader/conf/training/small.yaml @@ -0,0 +1,13 @@ +seed: 94 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 20_000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 + accumulate_grad_batches: 8 diff --git a/relik/reader/data/__init__.py b/relik/reader/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/data/__pycache__/__init__.cpython-310.pyc b/relik/reader/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d3973a7387266803778751b549a8faf2b49c30a Binary files /dev/null and b/relik/reader/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/data/__pycache__/patches.cpython-310.pyc b/relik/reader/data/__pycache__/patches.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..367e5ee51e421c0f6e9ad80e788f908dfa79e5aa Binary files /dev/null and b/relik/reader/data/__pycache__/patches.cpython-310.pyc differ diff --git a/relik/reader/data/__pycache__/relik_reader_data.cpython-310.pyc b/relik/reader/data/__pycache__/relik_reader_data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a73c5255eae4c48624f0fc2285162a605da7a30 Binary files /dev/null and b/relik/reader/data/__pycache__/relik_reader_data.cpython-310.pyc differ diff --git a/relik/reader/data/__pycache__/relik_reader_data_utils.cpython-310.pyc b/relik/reader/data/__pycache__/relik_reader_data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76b0e11cd5488562fdf69b83b1cc99fdaa9fd936 Binary files /dev/null and 
b/relik/reader/data/__pycache__/relik_reader_data_utils.cpython-310.pyc differ diff --git a/relik/reader/data/__pycache__/relik_reader_sample.cpython-310.pyc b/relik/reader/data/__pycache__/relik_reader_sample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e897311368722f6c70f4dfe58b5be4266a3094 Binary files /dev/null and b/relik/reader/data/__pycache__/relik_reader_sample.cpython-310.pyc differ diff --git a/relik/reader/data/patches.py b/relik/reader/data/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d03dbdf08d0e205787ce2b8176c6bd47d2dfca --- /dev/null +++ b/relik/reader/data/patches.py @@ -0,0 +1,51 @@ +from typing import List + +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.utils.special_symbols import NME_SYMBOL + + +def merge_patches_predictions(sample) -> None: + sample._d["predicted_window_labels"] = dict() + predicted_window_labels = sample._d["predicted_window_labels"] + + sample._d["span_title_probabilities"] = dict() + span_title_probabilities = sample._d["span_title_probabilities"] + + span2title = dict() + for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): + # selecting span predictions + for predicted_title, predicted_spans in patch_info[ + "predicted_window_labels" + ].items(): + for pred_span in predicted_spans: + pred_span = tuple(pred_span) + curr_title = span2title.get(pred_span) + if curr_title is None or curr_title == NME_SYMBOL: + span2title[pred_span] = predicted_title + # else: + # print("Merging at patch level") + + # selecting span predictions probability + for predicted_span, titles_probabilities in patch_info[ + "span_title_probabilities" + ].items(): + if predicted_span not in span_title_probabilities: + span_title_probabilities[predicted_span] = titles_probabilities + + for span, title in span2title.items(): + if title not in predicted_window_labels: + predicted_window_labels[title] = list() + predicted_window_labels[title].append(span) + + +def remove_duplicate_samples( + samples: List[RelikReaderSample], +) -> List[RelikReaderSample]: + seen_sample = set() + samples_store = [] + for sample in samples: + sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}" + if sample_id not in seen_sample: + seen_sample.add(sample_id) + samples_store.append(sample) + return samples_store diff --git a/relik/reader/data/relik_reader_data.py b/relik/reader/data/relik_reader_data.py new file mode 100644 index 0000000000000000000000000000000000000000..86a65d5fc8dc1f9f0c79fd801155db2085cb4d81 --- /dev/null +++ b/relik/reader/data/relik_reader_data.py @@ -0,0 +1,991 @@ +import logging +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import numpy as np +import torch +from torch.utils.data import IterableDataset +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer + +from relik.reader.data.relik_reader_data_utils import ( + add_noise_to_value, + batchify, + chunks, + flatten, +) +from relik.reader.data.relik_reader_sample import ( + RelikReaderSample, + load_relik_reader_samples, +) +from relik.reader.utils.special_symbols import NME_SYMBOL + +logger = logging.getLogger(__name__) + + +def preprocess_sample( + relik_sample: RelikReaderSample, + tokenizer, + lowercase_policy: float, + add_topic: bool = False, +) -> None: + if len(relik_sample.tokens) == 0: + return + + if lowercase_policy > 0: + lc_tokens = 
np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy + relik_sample.tokens = [ + t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens) + ] + + tokenization_out = tokenizer( + relik_sample.tokens, + return_offsets_mapping=True, + add_special_tokens=False, + ) + + window_tokens = tokenization_out.input_ids + window_tokens = flatten(window_tokens) + + offsets_mapping = [ + [ + ( + ss + relik_sample.token2char_start[str(i)], + se + relik_sample.token2char_start[str(i)], + ) + for ss, se in tokenization_out.offset_mapping[i] + ] + for i in range(len(relik_sample.tokens)) + ] + + offsets_mapping = flatten(offsets_mapping) + + assert len(offsets_mapping) == len(window_tokens) + + window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id] + + topic_offset = 0 + if add_topic: + topic_tokens = tokenizer( + relik_sample.doc_topic, add_special_tokens=False + ).input_ids + topic_offset = len(topic_tokens) + relik_sample.topic_tokens = topic_offset + window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:] + + token2char_start = { + str(i): s for i, (s, _) in enumerate(offsets_mapping, start=topic_offset) + } + token2char_end = { + str(i): e for i, (_, e) in enumerate(offsets_mapping, start=topic_offset) + } + token2word_start = { + str(i): int(relik_sample._d["char2token_start"][str(s)]) + for i, (s, _) in enumerate(offsets_mapping, start=topic_offset) + if str(s) in relik_sample._d["char2token_start"] + } + token2word_end = { + str(i): int(relik_sample._d["char2token_end"][str(e)]) + for i, (_, e) in enumerate(offsets_mapping, start=topic_offset) + if str(e) in relik_sample._d["char2token_end"] + } + relik_sample._d.update( + dict( + tokens=window_tokens, + token2char_start=token2char_start, + token2char_end=token2char_end, + token2word_start=token2word_start, + token2word_end=token2word_end, + ) + ) + + if "window_labels" in relik_sample._d: + relik_sample.window_labels = [ + (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels + ] + + +class TokenizationOutput(NamedTuple): + input_ids: torch.Tensor + attention_mask: torch.Tensor + token_type_ids: torch.Tensor + prediction_mask: torch.Tensor + special_symbols_mask: torch.Tensor + + +class RelikDataset(IterableDataset): + def __init__( + self, + dataset_path: Optional[str], + materialize_samples: bool, + transformer_model: Union[str, PreTrainedTokenizer], + special_symbols: List[str], + shuffle_candidates: Optional[Union[bool, float]] = False, + for_inference: bool = False, + noise_param: float = 0.1, + sorting_fields: Optional[str] = None, + tokens_per_batch: int = 2048, + batch_size: int = None, + max_batch_size: int = 128, + section_size: int = 50_000, + prebatch: bool = True, + random_drop_gold_candidates: float = 0.0, + use_nme: bool = True, + max_subwords_per_candidate: bool = 22, + mask_by_instances: bool = False, + min_length: int = 5, + max_length: int = 2048, + model_max_length: int = 1000, + split_on_cand_overload: bool = True, + skip_empty_training_samples: bool = False, + drop_last: bool = False, + samples: Optional[Iterator[RelikReaderSample]] = None, + lowercase_policy: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.dataset_path = dataset_path + self.materialize_samples = materialize_samples + self.samples: Optional[List[RelikReaderSample]] = None + if self.materialize_samples: + self.samples = list() + + if isinstance(transformer_model, str): + self.tokenizer = self._build_tokenizer(transformer_model, special_symbols) + else: + 
self.tokenizer = transformer_model + self.special_symbols = special_symbols + self.shuffle_candidates = shuffle_candidates + self.for_inference = for_inference + self.noise_param = noise_param + self.batching_fields = ["input_ids"] + self.sorting_fields = ( + sorting_fields if sorting_fields is not None else self.batching_fields + ) + + self.tokens_per_batch = tokens_per_batch + self.batch_size = batch_size + self.max_batch_size = max_batch_size + self.section_size = section_size + self.prebatch = prebatch + + self.random_drop_gold_candidates = random_drop_gold_candidates + self.use_nme = use_nme + self.max_subwords_per_candidate = max_subwords_per_candidate + self.mask_by_instances = mask_by_instances + self.min_length = min_length + self.max_length = max_length + self.model_max_length = ( + model_max_length + if model_max_length < self.tokenizer.model_max_length + else self.tokenizer.model_max_length + ) + + # retrocompatibility workaround + self.transformer_model = ( + transformer_model + if isinstance(transformer_model, str) + else transformer_model.name_or_path + ) + self.split_on_cand_overload = split_on_cand_overload + self.skip_empty_training_samples = skip_empty_training_samples + self.drop_last = drop_last + self.lowercase_policy = lowercase_policy + self.samples = samples + + def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]): + return AutoTokenizer.from_pretrained( + transformer_model, + additional_special_tokens=[ss for ss in special_symbols], + add_prefix_space=True, + ) + + @staticmethod + def get_special_symbols(num_entities: int) -> List[str]: + return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)] + + @property + def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: + fields_batchers = { + "input_ids": lambda x: batchify( + x, padding_value=self.tokenizer.pad_token_id + ), + "attention_mask": lambda x: batchify(x, padding_value=0), + "token_type_ids": lambda x: batchify(x, padding_value=0), + "prediction_mask": lambda x: batchify(x, padding_value=1), + "global_attention": lambda x: batchify(x, padding_value=0), + "token2word": None, + "sample": None, + "special_symbols_mask": lambda x: batchify(x, padding_value=False), + "start_labels": lambda x: batchify(x, padding_value=-100), + "end_labels": lambda x: batchify(x, padding_value=-100), + "predictable_candidates_symbols": None, + "predictable_candidates": None, + "patch_offset": None, + "optimus_labels": None, + } + + if ( + isinstance(self.transformer_model, str) + and "roberta" in self.transformer_model + ) or ( + isinstance(self.transformer_model, PreTrainedTokenizer) + and "roberta" in self.transformer_model.config.model_type + ): + del fields_batchers["token_type_ids"] + + return fields_batchers + + def _build_input_ids( + self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]] + ) -> List[int]: + return ( + [self.tokenizer.cls_token_id] + + sentence_input_ids + + [self.tokenizer.sep_token_id] + + flatten(candidates_input_ids) + + [self.tokenizer.sep_token_id] + ) + + def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor: + special_symbols_mask = input_ids >= ( + len(self.tokenizer) - len(self.special_symbols) + ) + special_symbols_mask[0] = True + return special_symbols_mask + + def _build_tokenizer_essentials( + self, input_ids, original_sequence, sample + ) -> TokenizationOutput: + input_ids = torch.tensor(input_ids, dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + + total_sequence_len = 
len(input_ids) + predictable_sentence_len = len(original_sequence) + + # token type ids + token_type_ids = torch.cat( + [ + input_ids.new_zeros( + predictable_sentence_len + 2 + ), # original sentence bpes + CLS and SEP + input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2), + ] + ) + + # prediction mask -> boolean on tokens that are predictable + + prediction_mask = torch.tensor( + [1] + + ([0] * predictable_sentence_len) + + ([1] * (total_sequence_len - predictable_sentence_len - 1)) + ) + + # add topic tokens to the prediction mask so that they cannot be predicted + # or optimized during training + topic_tokens = getattr(sample, "topic_tokens", None) + if topic_tokens is not None: + prediction_mask[1 : 1 + topic_tokens] = 1 + + # If mask by instances is active the prediction mask is applied to everything + # that is not indicated as an instance in the training set. + if self.mask_by_instances: + char_start2token = { + cs: int(tok) for tok, cs in sample.token2char_start.items() + } + char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()} + instances_mask = torch.ones_like(prediction_mask) + for _, span_info in sample.instance_id2span_data.items(): + span_info = span_info[0] + token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS + token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS + instances_mask[token_start : token_end + 1] = 0 + + prediction_mask += instances_mask + prediction_mask[prediction_mask > 1] = 1 + + assert len(prediction_mask) == len(input_ids) + + # special symbols mask + special_symbols_mask = self._get_special_symbols_mask(input_ids) + + return TokenizationOutput( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + ) + + def _build_labels( + self, + sample, + tokenization_output: TokenizationOutput, + predictable_candidates: List[str], + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_labels = [0] * len(tokenization_output.input_ids) + end_labels = [0] * len(tokenization_output.input_ids) + + char_start2token = {v: int(k) for k, v in sample.token2char_start.items()} + char_end2token = {v: int(k) for k, v in sample.token2char_end.items()} + for cs, ce, gold_candidate_title in sample.window_labels: + if gold_candidate_title not in predictable_candidates: + if self.use_nme: + gold_candidate_title = NME_SYMBOL + else: + continue + # +1 is to account for the CLS token + start_bpe = char_start2token[cs] + 1 + end_bpe = char_end2token[ce] + 1 + class_index = predictable_candidates.index(gold_candidate_title) + if ( + start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0 + ): # prevent from having entities that ends with the same label + start_labels[start_bpe] = class_index + 1 # +1 for the NONE class + end_labels[end_bpe] = class_index + 1 # +1 for the NONE class + else: + print( + "Found entity with the same last subword, it will not be included." 
+ ) + print( + cs, + ce, + gold_candidate_title, + start_labels, + end_labels, + sample.doc_id, + ) + + ignored_labels_indices = tokenization_output.prediction_mask == 1 + + start_labels = torch.tensor(start_labels, dtype=torch.long) + start_labels[ignored_labels_indices] = -100 + + end_labels = torch.tensor(end_labels, dtype=torch.long) + end_labels[ignored_labels_indices] = -100 + + return start_labels, end_labels + + def produce_sample_bag( + self, sample, predictable_candidates: List[str], candidates_starting_offset: int + ) -> Optional[Tuple[dict, list, int]]: + # input sentence tokenization + input_subwords = sample.tokens[1:-1] # removing special tokens + candidates_symbols = self.special_symbols[candidates_starting_offset:] + + predictable_candidates = list(predictable_candidates) + original_predictable_candidates = list(predictable_candidates) + + # add NME as a possible candidate + if self.use_nme: + predictable_candidates.insert(0, NME_SYMBOL) + + # candidates encoding + candidates_symbols = candidates_symbols[: len(predictable_candidates)] + candidates_encoding_result = self.tokenizer.batch_encode_plus( + [ + "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL + for cs, ct in zip(candidates_symbols, predictable_candidates) + ], + add_special_tokens=False, + ).input_ids + + if ( + self.max_subwords_per_candidate is not None + and self.max_subwords_per_candidate > 0 + ): + candidates_encoding_result = [ + cer[: self.max_subwords_per_candidate] + for cer in candidates_encoding_result + ] + + # drop candidates if the number of input tokens is too long for the model + if ( + sum(map(len, candidates_encoding_result)) + + len(input_subwords) + + 20 # + 20 special tokens + > self.model_max_length + ): + acceptable_tokens_from_candidates = ( + self.model_max_length - 20 - len(input_subwords) + ) + i = 0 + cum_len = 0 + while ( + cum_len + len(candidates_encoding_result[i]) + < acceptable_tokens_from_candidates + ): + cum_len += len(candidates_encoding_result[i]) + i += 1 + + candidates_encoding_result = candidates_encoding_result[:i] + candidates_symbols = candidates_symbols[:i] + predictable_candidates = predictable_candidates[:i] + + # final input_ids build + input_ids = self._build_input_ids( + sentence_input_ids=input_subwords, + candidates_input_ids=candidates_encoding_result, + ) + + # complete input building (e.g. 
attention / prediction mask) + tokenization_output = self._build_tokenizer_essentials( + input_ids, input_subwords, sample + ) + + output_dict = { + "input_ids": tokenization_output.input_ids, + "attention_mask": tokenization_output.attention_mask, + "token_type_ids": tokenization_output.token_type_ids, + "prediction_mask": tokenization_output.prediction_mask, + "special_symbols_mask": tokenization_output.special_symbols_mask, + "sample": sample, + "predictable_candidates_symbols": candidates_symbols, + "predictable_candidates": predictable_candidates, + } + + # labels creation + if sample.window_labels is not None: + start_labels, end_labels = self._build_labels( + sample, + tokenization_output, + predictable_candidates, + ) + output_dict.update(start_labels=start_labels, end_labels=end_labels) + + if ( + "roberta" in self.transformer_model + or "longformer" in self.transformer_model + ): + del output_dict["token_type_ids"] + + predictable_candidates_set = set(predictable_candidates) + remaining_candidates = [ + candidate + for candidate in original_predictable_candidates + if candidate not in predictable_candidates_set + ] + total_used_candidates = ( + candidates_starting_offset + + len(predictable_candidates) + - (1 if self.use_nme else 0) + ) + + if self.use_nme: + assert predictable_candidates[0] == NME_SYMBOL + + return output_dict, remaining_candidates, total_used_candidates + + def __iter__(self): + dataset_iterator = self.dataset_iterator_func() + i = None + for batch in self.materialize_batches(dataset_iterator): + if i is None: + i = 0 + i += batch["input_ids"].shape[0] + yield batch + if i is not None: + logger.debug(f"Dataset finished: {i} number of elements processed") + else: + logger.warning("Dataset empty") + + def iter_all(self): + dataset_iterator = self.dataset_iterator_func() + + current_dataset_elements = [] + + i = None + for i, dataset_elem in enumerate(dataset_iterator, start=1): + if ( + self.section_size is not None + and len(current_dataset_elements) == self.section_size + ): + for batch in self.materialize_batches(current_dataset_elements): + yield batch + current_dataset_elements = [] + + current_dataset_elements.append(dataset_elem) + + if i % 50_000 == 0: + logger.info(f"Processed: {i} number of elements") + + if len(current_dataset_elements) != 0: + for batch in self.materialize_batches(current_dataset_elements): + yield batch + + if i is not None: + logger.debug(f"Dataset finished: {i} number of elements processed") + else: + logger.warning("Dataset empty") + + def dataset_iterator_func(self): + skipped_instances = 0 + data_samples = ( + load_relik_reader_samples(self.dataset_path) + if self.samples is None + else self.samples + ) + for sample in data_samples: + preprocess_sample( + sample, self.tokenizer, lowercase_policy=self.lowercase_policy + ) + current_patch = 0 + sample_bag, used_candidates = None, None + # TODO: compatibility shit + sample.window_candidates = sample.span_candidates + + remaining_candidates = list(sample.window_candidates) + + if not self.for_inference: + # randomly drop gold candidates at training time + if ( + self.random_drop_gold_candidates > 0.0 + and np.random.uniform() < self.random_drop_gold_candidates + and len(set(ct for _, _, ct in sample.window_labels)) > 1 + ): + # selecting candidates to drop + np.random.shuffle(sample.window_labels) + n_dropped_candidates = np.random.randint( + 0, len(sample.window_labels) - 1 + ) + dropped_candidates = [ + label_elem[-1] + for label_elem in sample.window_labels[:n_dropped_candidates] 
+ ] + dropped_candidates = set(dropped_candidates) + + # saving NMEs because they should not be dropped + if NME_SYMBOL in dropped_candidates: + dropped_candidates.remove(NME_SYMBOL) + + # sample update + sample.window_labels = [ + ( + (s, e, _l) + if _l not in dropped_candidates + else (s, e, NME_SYMBOL) + ) + for s, e, _l in sample.window_labels + ] + remaining_candidates = [ + wc + for wc in remaining_candidates + if wc not in dropped_candidates + ] + + # shuffle candidates + if ( + isinstance(self.shuffle_candidates, bool) + and self.shuffle_candidates + ) or ( + isinstance(self.shuffle_candidates, float) + and np.random.uniform() < self.shuffle_candidates + ): + np.random.shuffle(remaining_candidates) + + while len(remaining_candidates) != 0: + sample_bag = self.produce_sample_bag( + sample, + predictable_candidates=remaining_candidates, + candidates_starting_offset=( + used_candidates if used_candidates is not None else 0 + ), + ) + if sample_bag is not None: + sample_bag, remaining_candidates, used_candidates = sample_bag + if ( + self.for_inference + or not self.skip_empty_training_samples + or ( + ( + sample_bag.get("start_labels") is not None + and torch.any(sample_bag["start_labels"] > 1).item() + ) + or ( + sample_bag.get("optimus_labels") is not None + and len(sample_bag["optimus_labels"]) > 0 + ) + ) + ): + sample_bag["patch_offset"] = current_patch + current_patch += 1 + yield sample_bag + else: + skipped_instances += 1 + if skipped_instances % 1000 == 0 and skipped_instances != 0: + logger.info( + f"Skipped {skipped_instances} instances since they did not have any gold labels..." + ) + + # Just use the first fitting candidates if split on + # cand is not True + if not self.split_on_cand_overload: + break + + def preshuffle_elements(self, dataset_elements: List): + # This shuffling is done so that when using the sorting function, + # if it is deterministic given a collection and its order, we will + # make the whole operation not deterministic anymore. + # Basically, the aim is not to build every time the same batches. 
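# --- Editor's note: illustrative sketch, not part of this patch. ------------
# The comment above describes the prebatching trick used by
# preshuffle_elements: permute the elements, sort them by a length key
# perturbed with multiplicative noise (the same idea as add_noise_to_value in
# relik_reader_data_utils), then shuffle fixed-size chunks so that elements of
# similar length stay close together while the resulting batches still change
# between iterations. A self-contained rendition, with made-up names and sizes:

import numpy as np

def noisy_length_order(elements, length_fn, noise_param=0.1, chunk_size=64):
    # random permutation first, so the subsequent sort is not deterministic
    elements = list(np.random.permutation(elements))

    def noisy_len(elem):
        value = length_fn(elem)
        noise = np.random.uniform(-value * noise_param, value * noise_param)
        return max(1, value + noise)

    # sort by noisy length, then shuffle whole chunks of similar-length items
    elements = sorted(elements, key=noisy_len)
    chunked = [elements[i:i + chunk_size] for i in range(0, len(elements), chunk_size)]
    np.random.shuffle(chunked)
    return [elem for chunk in chunked for elem in chunk]

# e.g. noisy_length_order(samples, lambda s: len(s["input_ids"])) keeps
# similar-length samples adjacent without producing identical batches twice.
# ----------------------------------------------------------------------------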
+ if not self.for_inference: + dataset_elements = np.random.permutation(dataset_elements) + + def sorting_fn(elem): + return ( + add_noise_to_value( + sum(len(elem[k]) for k in self.sorting_fields), + noise_param=self.noise_param, + ) + if not self.for_inference + else sum(len(elem[k]) for k in self.sorting_fields) + ) + + dataset_elements = sorted(dataset_elements, key=sorting_fn) + + if self.for_inference: + return dataset_elements + + ds = list(chunks(dataset_elements, 64)) + np.random.shuffle(ds) + return flatten(ds) + + def materialize_batches( + self, dataset_elements: List[Dict[str, Any]] + ) -> Generator[Dict[str, Any], None, None]: + if self.prebatch: + dataset_elements = self.preshuffle_elements(dataset_elements) + + current_batch = [] + + # function that creates a batch from the 'current_batch' list + def output_batch() -> Dict[str, Any]: + assert ( + len( + set([len(elem["predictable_candidates"]) for elem in current_batch]) + ) + == 1 + ), " ".join( + map( + str, [len(elem["predictable_candidates"]) for elem in current_batch] + ) + ) + + batch_dict = dict() + + de_values_by_field = { + fn: [de[fn] for de in current_batch if fn in de] + for fn in self.fields_batcher + } + + # in case you provide fields batchers but in the batch + # there are no elements for that field + de_values_by_field = { + fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 + } + + assert len(set([len(v) for v in de_values_by_field.values()])) + + # todo: maybe we should report the user about possible + # fields filtering due to "None" instances + de_values_by_field = { + fn: fvs + for fn, fvs in de_values_by_field.items() + if all([fv is not None for fv in fvs]) + } + + for field_name, field_values in de_values_by_field.items(): + field_batch = ( + self.fields_batcher[field_name](field_values) + if self.fields_batcher[field_name] is not None + else field_values + ) + + batch_dict[field_name] = field_batch + + return batch_dict + + max_len_discards, min_len_discards = 0, 0 + + should_token_batch = self.batch_size is None + + curr_pred_elements = -1 + for de in dataset_elements: + if ( + should_token_batch + and self.max_batch_size != -1 + and len(current_batch) == self.max_batch_size + ) or (not should_token_batch and len(current_batch) == self.batch_size): + yield output_batch() + current_batch = [] + curr_pred_elements = -1 + + too_long_fields = [ + k + for k in de + if self.max_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) > self.max_length + ] + if len(too_long_fields) > 0: + max_len_discards += 1 + continue + + too_short_fields = [ + k + for k in de + if self.min_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) < self.min_length + ] + if len(too_short_fields) > 0: + min_len_discards += 1 + continue + + if should_token_batch: + de_len = sum(len(de[k]) for k in self.batching_fields) + + future_max_len = max( + de_len, + max( + [ + sum(len(bde[k]) for k in self.batching_fields) + for bde in current_batch + ], + default=0, + ), + ) + + future_tokens_per_batch = future_max_len * (len(current_batch) + 1) + + num_predictable_candidates = len(de["predictable_candidates"]) + + if len(current_batch) > 0 and ( + future_tokens_per_batch >= self.tokens_per_batch + or ( + num_predictable_candidates != curr_pred_elements + and curr_pred_elements != -1 + ) + ): + yield output_batch() + current_batch = [] + + current_batch.append(de) + curr_pred_elements = len(de["predictable_candidates"]) + + if len(current_batch) != 0 and not self.drop_last: + yield output_batch() + + if 
max_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were " + f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the " + f"sample length exceeds the maximum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {max_len_discards} elements were " + f"discarded since longer than max length {self.max_length}" + ) + + if min_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were " + f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the " + f"sample length is shorter than the minimum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {min_len_discards} elements were " + f"discarded since shorter than min length {self.min_length}" + ) + + @staticmethod + def convert_to_char_annotations( + sample: RelikReaderSample, + remove_nmes: bool = True, + ) -> RelikReaderSample: + """ + Converts the annotations to char annotations. + + Args: + sample (:obj:`RelikReaderSample`): + The sample to convert. + remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to remove the NMEs from the annotations. + Returns: + :obj:`RelikReaderSample`: The converted sample. + """ + char_annotations = set() + for ( + predicted_entity, + predicted_spans, + ) in sample.predicted_window_labels.items(): + if predicted_entity == NME_SYMBOL and remove_nmes: + continue + + for span_start, span_end in predicted_spans: + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + + char_annotations.add((span_start, span_end, predicted_entity)) + + char_probs_annotations = dict() + for ( + span_start, + span_end, + ), candidates_probs in sample.span_title_probabilities.items(): + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + # TODO: which one is kept if there are multiple candidates with same title? + # and where is the order? + char_probs_annotations[(span_start, span_end)] = { + title for title, _ in candidates_probs + } + + sample.predicted_window_labels_chars = char_annotations + sample.probs_window_labels_chars = char_probs_annotations + + # try-out for a new format + sample.predicted_spans = char_annotations + sample.predicted_spans_probabilities = char_probs_annotations + + return sample + + @staticmethod + def convert_to_word_annotations( + sample: RelikReaderSample, + remove_nmes: bool = True, + ) -> RelikReaderSample: + """ + Converts the annotations to tokens annotations. + + Args: + sample (:obj:`RelikReaderSample`): + The sample to convert. + remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to remove the NMEs from the annotations. + + Returns: + :obj:`RelikReaderSample`: The converted sample. 
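        Note (editor's illustrative addition, not part of this patch):
            A sketch of the word-boundary snapping performed below, using
            made-up token-to-word mappings. A sub-word span is widened until
            both of its ends fall on tokens that start/end a word, then it is
            mapped to word indices (end exclusive)::

                token2word_start = {"0": 0, "2": 1}  # token 1 continues word 0
                token2word_end = {"1": 0, "3": 1}
                span_start, span_end = 1, 3          # starts mid-word
                while str(span_start) not in token2word_start:
                    span_start -= 1                  # snaps back to token 0
                word_span = (token2word_start[str(span_start)],
                             token2word_end[str(span_end)] + 1)  # -> (0, 2)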
+ """ + word_annotations = set() + for ( + predicted_entity, + predicted_spans, + ) in sample.predicted_window_labels.items(): + if predicted_entity == NME_SYMBOL and remove_nmes: + continue + + for span_start, span_end in predicted_spans: + if str(span_start) not in sample.token2word_start: + # span_start is in the middle of a word + # retrieve the first token of the word + while str(span_start) not in sample.token2word_start: + span_start -= 1 + # skip + if span_start < 0: + break + if str(span_end) not in sample.token2word_end: + # span_end is in the middle of a word + # retrieve the last token of the word + while str(span_end) not in sample.token2word_end: + span_end += 1 + # skip + if span_end >= len(sample.tokens): + break + + if span_start < 0 or span_end >= len(sample.tokens): + continue + + span_start = sample.token2word_start[str(span_start)] + span_end = sample.token2word_end[str(span_end)] + + word_annotations.add((span_start, span_end + 1, predicted_entity)) + + word_probs_annotations = dict() + for ( + span_start, + span_end, + ), candidates_probs in sample.span_title_probabilities.items(): + for span_start, span_end in predicted_spans: + if str(span_start) not in sample.token2word_start: + # span_start is in the middle of a word + # retrieve the first token of the word + while str(span_start) not in sample.token2word_start: + span_start -= 1 + # skip + if span_start < 0: + break + if str(span_end) not in sample.token2word_end: + # span_end is in the middle of a word + # retrieve the last token of the word + while str(span_end) not in sample.token2word_end: + span_end += 1 + # skip + if span_end >= len(sample.tokens): + break + + if span_start < 0 or span_end >= len(sample.tokens): + continue + span_start = sample.token2word_start[str(span_start)] + span_end = sample.token2word_end[str(span_end)] + word_probs_annotations[(span_start, span_end + 1)] = { + title for title, _ in candidates_probs + } + + sample.predicted_window_labels_words = word_annotations + sample.probs_window_labels_words = word_probs_annotations + + # try-out for a new format + sample.predicted_spans = word_annotations + sample.predicted_spans_probabilities = word_probs_annotations + return sample + + @staticmethod + def merge_patches_predictions(sample) -> None: + sample._d["predicted_window_labels"] = dict() + predicted_window_labels = sample._d["predicted_window_labels"] + + sample._d["span_title_probabilities"] = dict() + span_title_probabilities = sample._d["span_title_probabilities"] + + span2title = dict() + for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): + # selecting span predictions + for predicted_title, predicted_spans in patch_info[ + "predicted_window_labels" + ].items(): + for pred_span in predicted_spans: + pred_span = tuple(pred_span) + curr_title = span2title.get(pred_span) + if curr_title is None or curr_title == NME_SYMBOL: + span2title[pred_span] = predicted_title + # else: + # print("Merging at patch level") + + # selecting span predictions probability + for predicted_span, titles_probabilities in patch_info[ + "span_title_probabilities" + ].items(): + if predicted_span not in span_title_probabilities: + span_title_probabilities[predicted_span] = titles_probabilities + + for span, title in span2title.items(): + if title not in predicted_window_labels: + predicted_window_labels[title] = list() + predicted_window_labels[title].append(span) diff --git a/relik/reader/data/relik_reader_data_utils.py b/relik/reader/data/relik_reader_data_utils.py new file mode 100644 
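# --- Editor's note: illustrative sketch, not part of this patch. ------------
# merge_patches_predictions above resolves conflicts between the patches a
# window was split into: patches are visited in order and, for every predicted
# span, the first non-NME title wins. A self-contained rendition of that
# policy on plain dicts (the NME constant and function name below are
# assumptions made only for the example):

NME = "--NME--"

def merge_patches(patches: dict) -> dict:
    """patches maps patch index -> {title: [(start, end), ...]}."""
    span2title = {}
    for _, title2spans in sorted(patches.items()):
        for title, spans in title2spans.items():
            for span in spans:
                span = tuple(span)
                # keep the first confident (non-NME) title seen for this span
                if span not in span2title or span2title[span] == NME:
                    span2title[span] = title
    merged = {}
    for span, title in span2title.items():
        merged.setdefault(title, []).append(span)
    return merged

# e.g. merge_patches({0: {NME: [(0, 2)]}, 1: {"Rome": [(0, 2)]}})
# -> {"Rome": [(0, 2)]}
# ----------------------------------------------------------------------------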
index 0000000000000000000000000000000000000000..3c7446bee296d14653a35895bf9ec8071c87e5af --- /dev/null +++ b/relik/reader/data/relik_reader_data_utils.py @@ -0,0 +1,51 @@ +from typing import List + +import numpy as np +import torch + + +def flatten(lsts: List[list]) -> list: + acc_lst = list() + for lst in lsts: + acc_lst.extend(lst) + return acc_lst + + +def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor: + return torch.nn.utils.rnn.pad_sequence( + tensors, batch_first=True, padding_value=padding_value + ) + + +def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: + x = max([t.shape[0] for t in tensors]) + y = max([t.shape[1] for t in tensors]) + out_matrix = torch.zeros((len(tensors), x, y)) + out_matrix += padding_value + for i, tensor in enumerate(tensors): + out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor + return out_matrix + + +def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: + x = max([t.shape[0] for t in tensors]) + y = max([t.shape[1] for t in tensors]) + rest = tensors[0].shape[2] + out_matrix = torch.zeros((len(tensors), x, y, rest)) + out_matrix += padding_value + for i, tensor in enumerate(tensors): + out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor + return out_matrix + + +def chunks(lst: list, chunk_size: int) -> List[list]: + chunks_acc = list() + for i in range(0, len(lst), chunk_size): + chunks_acc.append(lst[i : i + chunk_size]) + return chunks_acc + + +def add_noise_to_value(value: int, noise_param: float): + noise_value = value * noise_param + noise = np.random.uniform(-noise_value, noise_value) + return max(1, value + noise) diff --git a/relik/reader/data/relik_reader_re_data.py b/relik/reader/data/relik_reader_re_data.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb564cb6949c097eb683925b554a8b141f58822 --- /dev/null +++ b/relik/reader/data/relik_reader_re_data.py @@ -0,0 +1,1155 @@ +import logging +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import numpy as np +import torch +import tqdm +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer, PreTrainedTokenizer + +from relik.reader.data.relik_reader_data_utils import ( + add_noise_to_value, + batchify, + batchify_matrices, + batchify_tensor, + chunks, + flatten, +) +from relik.reader.data.relik_reader_sample import ( + RelikReaderSample, + load_relik_reader_samples, +) +from relik.reader.utils.special_symbols import NME_SYMBOL + +logger = logging.getLogger(__name__) + + +class TokenizationOutput(NamedTuple): + input_ids: torch.Tensor + attention_mask: torch.Tensor + token_type_ids: torch.Tensor + prediction_mask: torch.Tensor + special_symbols_mask: torch.Tensor + special_symbols_mask_entities: torch.Tensor + + +class RelikREDataset(IterableDataset): + def __init__( + self, + dataset_path: str, + materialize_samples: bool, + transformer_model: Union[str, PreTrainedTokenizer], + special_symbols: List[str], + shuffle_candidates: Optional[Union[bool, float]] = False, + flip_candidates: Optional[Union[bool, float]] = False, + for_inference: bool = False, + special_symbols_types=None, + noise_param: float = 0.1, + sorting_fields: Optional[str] = None, + tokens_per_batch: int = 2048, + batch_size: int = None, + max_batch_size: int = 128, + section_size: int = 500_000, + prebatch: bool = True, + add_gold_candidates: bool = True, + use_nme: bool 
= False, + min_length: int = -1, + max_length: int = 2048, + max_triplets: int = 50, + max_spans: int = 100, + model_max_length: int = 2048, + skip_empty_training_samples: bool = True, + drop_last: bool = False, + samples: Optional[Iterator[RelikReaderSample]] = None, + **kwargs, + ): + super().__init__(**kwargs) + # mutable default arguments + if special_symbols_types is None: + special_symbols_types = [] + + self.dataset_path = dataset_path + self.materialize_samples = materialize_samples + self.samples: Optional[List[RelikReaderSample]] = samples + if self.materialize_samples and self.samples is None: + self.samples = list() + + if isinstance(transformer_model, str): + self.tokenizer = self._build_tokenizer( + transformer_model, special_symbols + special_symbols_types + ) + else: + self.tokenizer = transformer_model + self.special_symbols = special_symbols + self.special_symbols_types = special_symbols_types + self.shuffle_candidates = shuffle_candidates + self.flip_candidates = flip_candidates + self.for_inference = for_inference + self.noise_param = noise_param + self.batching_fields = ["input_ids"] + self.sorting_fields = ( + sorting_fields if sorting_fields is not None else self.batching_fields + ) + self.add_gold_candidates = add_gold_candidates + self.use_nme = use_nme + self.min_length = min_length + self.max_length = max_length + self.model_max_length = ( + model_max_length + if model_max_length < self.tokenizer.model_max_length + else self.tokenizer.model_max_length + ) + self.transformer_model = transformer_model + self.skip_empty_training_samples = skip_empty_training_samples + self.drop_last = drop_last + + self.tokens_per_batch = tokens_per_batch + self.batch_size = batch_size + self.max_batch_size = max_batch_size + self.max_triplets = max_triplets + self.max_spans = max_spans + self.section_size = section_size + self.prebatch = prebatch + + def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]): + return AutoTokenizer.from_pretrained( + transformer_model, + additional_special_tokens=[ss for ss in special_symbols], + add_prefix_space=True, + ) + + @staticmethod + def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]: + if use_nme: + return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)] + else: + return [f"[R-{i}]" for i in range(num_entities)] + + @staticmethod + def get_special_symbols(num_entities: int) -> List[str]: + return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)] + + @property + def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: + fields_batchers = { + "input_ids": lambda x: batchify( + x, padding_value=self.tokenizer.pad_token_id + ), + "attention_mask": lambda x: batchify(x, padding_value=0), + "token_type_ids": lambda x: batchify(x, padding_value=0), + "prediction_mask": lambda x: batchify(x, padding_value=1), + "global_attention": lambda x: batchify(x, padding_value=0), + "token2word": None, + "sample": None, + "special_symbols_mask": lambda x: batchify(x, padding_value=False), + "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), + "start_labels": lambda x: batchify(x, padding_value=-100), + "end_labels": lambda x: batchify_matrices(x, padding_value=-100), + "disambiguation_labels": lambda x: batchify(x, padding_value=-100), + "relation_labels": lambda x: batchify_tensor(x, padding_value=-100), + "predictable_candidates": None, + } + if ( + isinstance(self.transformer_model, str) + and "roberta" in self.transformer_model + ) or ( + 
isinstance(self.transformer_model, PreTrainedTokenizer) + and "roberta" in self.transformer_model.config.model_type + ): + del fields_batchers["token_type_ids"] + + return fields_batchers + + def _build_input_ids( + self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]] + ) -> List[int]: + return ( + [self.tokenizer.cls_token_id] + + sentence_input_ids + + [self.tokenizer.sep_token_id] + + flatten(candidates_input_ids) + + [self.tokenizer.sep_token_id] + ) + + def _build_input(self, text: List[str], candidates: List[List[str]]) -> List[int]: + return ( + text + + [self.tokenizer.sep_token] + + flatten(candidates) + + [self.tokenizer.sep_token] + ) + + def _build_tokenizer_essentials( + self, input_ids, original_sequence, ents=0 + ) -> TokenizationOutput: + input_ids = torch.tensor(input_ids, dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + + if len(self.special_symbols_types) > 0: + # special symbols mask + special_symbols_mask = input_ids >= self.tokenizer.vocab_size + # select only the first N true values where N is len(entities_definitions) + special_symbols_mask_entities = special_symbols_mask.clone() + special_symbols_mask_entities[ + special_symbols_mask_entities.cumsum(0) > ents + ] = False + token_type_ids = (torch.cumsum(special_symbols_mask, dim=0) > 0).long() + special_symbols_mask = special_symbols_mask ^ special_symbols_mask_entities + else: + special_symbols_mask = input_ids >= self.tokenizer.vocab_size + special_symbols_mask_entities = special_symbols_mask.clone() + token_type_ids = (torch.cumsum(special_symbols_mask, dim=0) > 0).long() + + prediction_mask = token_type_ids.roll(shifts=-1, dims=0) + prediction_mask[-1] = 1 + prediction_mask[0] = 1 + + assert len(prediction_mask) == len(input_ids) + + return TokenizationOutput( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + ) + + @staticmethod + def _subindex(lst, target_values, dims): + for i, sublist in enumerate(lst): + match = all(sublist[dim] == target_values[dim] for dim in dims) + if match: + return i + + def _build_labels( + self, + sample, + tokenization_output: TokenizationOutput, + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_labels = [0] * len(tokenization_output.input_ids) + end_labels = [] + end_labels_tensor = [0] * len(tokenization_output.input_ids) + + sample.entities.sort(key=lambda x: (x[0], x[1])) + + prev_start_bpe = -1 + entities_untyped = list(set([(ce[0], ce[1]) for ce in sample.entities])) + entities_untyped.sort(key=lambda x: (x[0], x[1])) + if len(self.special_symbols_types) > 0: + sample.entities = [(ce[0], ce[1], ce[2]) for ce in sample.entities] + disambiguation_labels = torch.zeros( + len(entities_untyped), + len(sample.span_candidates) + len(sample.triplet_candidates), + ) + else: + sample.entities = [(ce[0], ce[1], "") for ce in sample.entities] + disambiguation_labels = torch.zeros( + len(entities_untyped), len(sample.triplet_candidates) + ) + ignored_labels_indices = tokenization_output.prediction_mask == 1 + offset = 0 + for idx, c_ent in enumerate(sample.entities): + while len(sample.word2token[c_ent[0]]) == 0: + c_ent = (c_ent[0] + 1, c_ent[1], c_ent[2]) + if len(sample.word2token) == c_ent[0]: + c_ent = None + break + if c_ent is None: + continue + while len(sample.word2token[c_ent[1] - 1]) == 0: + c_ent = (c_ent[0], c_ent[1] + 1, c_ent[2]) + if len(sample.word2token) == c_ent[1]: + c_ent = None + break + if c_ent is None: + continue + start_bpe = 
sample.word2token[c_ent[0]][0] + 1 + end_bpe = sample.word2token[c_ent[1] - 1][-1] + 1 + class_index = idx + start_labels[start_bpe] = class_index + 1 # +1 for the NONE class + if start_bpe != prev_start_bpe: + end_labels.append(end_labels_tensor.copy()) + end_labels[-1][:start_bpe] = [-100] * start_bpe + end_labels[-1][end_bpe] = class_index + 1 + elif end_labels[-1][end_bpe] == 0: + end_labels[-1][end_bpe] = class_index + 1 + else: + offset += 1 + prev_start_bpe = start_bpe + continue + if len(self.special_symbols_types) > 0: + if c_ent[2] in sample.span_candidates: + entity_type_idx = sample.span_candidates.index(c_ent[2]) + else: + entity_type_idx = 0 + disambiguation_labels[idx - offset, entity_type_idx] = 1 + prev_start_bpe = start_bpe + + start_labels = torch.tensor(start_labels, dtype=torch.long) + start_labels[ignored_labels_indices] = -100 + + end_labels = torch.tensor(end_labels, dtype=torch.long) + end_labels[ignored_labels_indices.repeat(len(end_labels), 1)] = -100 + + relation_labels = torch.zeros( + len(entities_untyped), len(entities_untyped), len(sample.triplet_candidates) + ) + + for re in sample.triplets: + if re["relation"]["name"] not in sample.triplet_candidates: + re_class_index = len(sample.triplet_candidates) - 1 + else: + re_class_index = sample.triplet_candidates.index(re["relation"]["name"]) + + subject_class_index = self._subindex( + entities_untyped, (re["subject"]["start"], re["subject"]["end"]), (0, 1) + ) + object_class_index = self._subindex( + entities_untyped, (re["object"]["start"], re["object"]["end"]), (0, 1) + ) + + relation_labels[subject_class_index, object_class_index, re_class_index] = 1 + + if len(self.special_symbols_types) > 0: + disambiguation_labels[ + subject_class_index, re_class_index + len(sample.span_candidates) + ] = 1 + disambiguation_labels[ + object_class_index, re_class_index + len(sample.span_candidates) + ] = 1 + else: + disambiguation_labels[subject_class_index, re_class_index] = 1 + disambiguation_labels[object_class_index, re_class_index] = 1 + return start_labels, end_labels, disambiguation_labels, relation_labels + + def __iter__(self): + dataset_iterator = self.dataset_iterator_func() + current_dataset_elements = [] + i = None + for i, dataset_elem in enumerate(dataset_iterator, start=1): + if ( + self.section_size is not None + and len(current_dataset_elements) == self.section_size + ): + for batch in self.materialize_batches(current_dataset_elements): + yield batch + current_dataset_elements = [] + current_dataset_elements.append(dataset_elem) + if i % 50_000 == 0: + logger.info(f"Processed: {i} number of elements") + if len(current_dataset_elements) != 0: + for batch in self.materialize_batches(current_dataset_elements): + yield batch + if i is not None: + logger.debug(f"Dataset finished: {i} number of elements processed") + else: + logger.warning("Dataset empty") + + def dataset_iterator_func(self): + data_samples = ( + load_relik_reader_samples(self.dataset_path) + if self.samples is None + or (isinstance(self.samples, list) and len(self.samples) == 0) + else self.samples + ) + if self.materialize_samples: + data_acc = [] + # take care of the tqdm nesting + # for sample in tqdm.tqdm(data_samples, desc="Reading dataset"): + for sample in data_samples: + if self.materialize_samples and sample.materialize is not None: + # tokenization_output = sample.materialize["tokenization_output"] + materialized = sample.materialize + del sample.materialize + yield { + "input_ids": materialized["tokenization_output"].input_ids, + 
"attention_mask": materialized[ + "tokenization_output" + ].attention_mask, + "token_type_ids": materialized[ + "tokenization_output" + ].token_type_ids, + "prediction_mask": materialized[ + "tokenization_output" + ].prediction_mask, + "special_symbols_mask": materialized[ + "tokenization_output" + ].special_symbols_mask, + "special_symbols_mask_entities": materialized[ + "tokenization_output" + ].special_symbols_mask_entities, + "sample": sample, + "start_labels": materialized["start_labels"], + "end_labels": materialized["end_labels"], + "disambiguation_labels": materialized["disambiguation_labels"], + "relation_labels": materialized["relation_labels"], + "predictable_candidates": materialized["candidates_symbols"], + } + sample.materialize = materialized + data_acc.append(sample) + continue + candidates_symbols = self.special_symbols + candidates_entities_symbols = self.special_symbols_types + + # sample.candidates = sample.candidates[: self.max_candidates] + + if len(self.special_symbols_types) > 0: + # sample.span_candidates = sample.span_candidates[ + # : self.max_ent_candidates + # ] + # add NME as a possible candidate + assert sample.span_candidates is not None + if self.use_nme: + sample.span_candidates.insert(0, NME_SYMBOL) + # sample.candidates.insert(0, NME_SYMBOL) + + sample.triplet_candidates = sample.triplet_candidates[ + : min(len(candidates_symbols), self.max_triplets) + ] + + if len(self.special_symbols_types) > 0: + sample.span_candidates = sample.span_candidates[ + : min(len(candidates_entities_symbols), self.max_spans) + ] + # training time sample mods + if not self.for_inference: + # check whether the sample has labels if not skip + if ( + sample.triplets is None or len(sample.triplets) == 0 + ) and self.skip_empty_training_samples: + logger.warning( + "Sample {} has no labels, skipping".format(sample.id) + ) + continue + + # add gold candidates if missing + if self.add_gold_candidates: + candidates_set = set(sample.triplet_candidates) + candidates_to_add = set() + for candidate_title in sample.triplets: + if candidate_title["relation"]["name"] not in candidates_set: + candidates_to_add.add(candidate_title["relation"]["name"]) + if len(candidates_to_add) > 0: + # replacing last candidates with the gold ones + # this is done in order to preserve the ordering + candidates_to_add = list(candidates_to_add) + added_gold_candidates = 0 + gold_candidates_titles_set = set( + set(ct["relation"]["name"] for ct in sample.triplets) + ) + for i in reversed(range(len(sample.triplet_candidates))): + if ( + sample.triplet_candidates[i] + not in gold_candidates_titles_set + and sample.triplet_candidates[i] != NME_SYMBOL + ): + sample.triplet_candidates[i] = candidates_to_add[ + added_gold_candidates + ] + added_gold_candidates += 1 + if len(candidates_to_add) == added_gold_candidates: + break + + candidates_still_to_add = ( + len(candidates_to_add) - added_gold_candidates + ) + while ( + len(sample.triplet_candidates) + <= min(len(candidates_symbols), self.max_triplets) + and candidates_still_to_add != 0 + ): + sample.triplet_candidates.append( + candidates_to_add[added_gold_candidates] + ) + added_gold_candidates += 1 + candidates_still_to_add -= 1 + + def shuffle_cands(shuffle_candidates, candidates): + if ( + isinstance(shuffle_candidates, bool) and shuffle_candidates + ) or ( + isinstance(shuffle_candidates, float) + and np.random.uniform() < shuffle_candidates + ): + np.random.shuffle(candidates) + if NME_SYMBOL in candidates: + candidates.remove(NME_SYMBOL) + 
candidates.insert(0, NME_SYMBOL) + return candidates + + def flip_cands(flip_candidates, candidates): + # flip candidates + if (isinstance(flip_candidates, bool) and flip_candidates) or ( + isinstance(flip_candidates, float) + and np.random.uniform() < flip_candidates + ): + for i in range(len(candidates) - 1): + if np.random.uniform() < 0.5: + candidates[i], candidates[i + 1] = ( + candidates[i + 1], + candidates[i], + ) + if NME_SYMBOL in candidates: + candidates.remove(NME_SYMBOL) + candidates.insert(0, NME_SYMBOL) + return candidates + + if self.shuffle_candidates: + sample.triplet_candidates = shuffle_cands( + self.shuffle_candidates, sample.triplet_candidates + ) + if len(self.special_symbols_types) > 0: + sample.span_candidates = shuffle_cands( + self.shuffle_candidates, sample.span_candidates + ) + elif self.flip_candidates: + sample.triplet_candidates = flip_cands( + self.flip_candidates, sample.triplet_candidates + ) + if len(self.special_symbols_types) > 0: + sample.span_candidates = flip_cands( + self.flip_candidates, sample.span_candidates + ) + + # candidates encoding + candidates_symbols = candidates_symbols[: len(sample.triplet_candidates)] + + candidates_encoding = [ + ["{} {}".format(cs, ct)] if ct != NME_SYMBOL else [NME_SYMBOL] + for cs, ct in zip(candidates_symbols, sample.triplet_candidates) + ] + if len(self.special_symbols_types) > 0: + candidates_entities_symbols = candidates_entities_symbols[ + : len(sample.span_candidates) + ] + candidates_types_encoding = [ + ["{} {}".format(cs, ct)] if ct != NME_SYMBOL else [NME_SYMBOL] + for cs, ct in zip( + candidates_entities_symbols, sample.span_candidates + ) + ] + candidates_encoding = ( + candidates_types_encoding + + [[self.tokenizer.sep_token]] + + candidates_encoding + ) + + pretoken_input = self._build_input(sample.words, candidates_encoding) + input_tokenized = self.tokenizer( + pretoken_input, + return_offsets_mapping=True, + add_special_tokens=False, + ) + + window_tokens = input_tokenized.input_ids + window_tokens = flatten(window_tokens) + + offsets_mapping = [ + [ + ( + ss + sample.token2char_start[str(i)], + se + sample.token2char_start[str(i)], + ) + for ss, se in input_tokenized.offset_mapping[i] + ] + for i in range(len(sample.words)) + ] + + offsets_mapping = flatten(offsets_mapping) + + token2char_start = {str(i): s for i, (s, _) in enumerate(offsets_mapping)} + token2char_end = {str(i): e for i, (_, e) in enumerate(offsets_mapping)} + token2word_start = { + str(i): int(sample._d["char2token_start"][str(s)]) + for i, (s, _) in enumerate(offsets_mapping) + if str(s) in sample._d["char2token_start"] + } + token2word_end = { + str(i): int(sample._d["char2token_end"][str(e)]) + for i, (_, e) in enumerate(offsets_mapping) + if str(e) in sample._d["char2token_end"] + } + # invert token2word_start and token2word_end + word2token_start = {str(v): int(k) for k, v in token2word_start.items()} + word2token_end = {str(v): int(k) for k, v in token2word_end.items()} + + sample._d.update( + dict( + tokens=window_tokens, + token2char_start=token2char_start, + token2char_end=token2char_end, + token2word_start=token2word_start, + token2word_end=token2word_end, + word2token_start=word2token_start, + word2token_end=word2token_end, + ) + ) + + input_subwords = flatten(input_tokenized["input_ids"][: len(sample.words)]) + offsets = input_tokenized["offset_mapping"][: len(sample.words)] + token2word = [] + word2token = {} + count = 0 + for i, offset in enumerate(offsets): + word2token[i] = [] + for token in offset: + 
token2word.append(i) + word2token[i].append(count) + count += 1 + + sample.token2word = token2word + sample.word2token = word2token + candidates_encoding_result = input_tokenized["input_ids"][ + len(sample.words) + 1 : -1 + ] + + i = 0 + cum_len = 0 + # drop candidates if the number of input tokens is too long for the model + if ( + sum(map(len, candidates_encoding_result)) + + len(input_subwords) + + 20 # + 20 special tokens + > self.model_max_length + ): + if self.for_inference: + acceptable_tokens_from_candidates = ( + self.model_max_length - 20 - len(input_subwords) + ) + + while ( + cum_len + len(candidates_encoding_result[i]) + < acceptable_tokens_from_candidates + ): + cum_len += len(candidates_encoding_result[i]) + i += 1 + + assert i > 0 + + candidates_encoding_result = candidates_encoding_result[:i] + if len(self.special_symbols_types) > 0: + candidates_symbols = candidates_symbols[ + : i - len(sample.span_candidates) + ] + sample.triplet_candidates = sample.triplet_candidates[ + : i - len(sample.span_candidates) + ] + else: + candidates_symbols = candidates_symbols[:i] + sample.triplet_candidates = sample.triplet_candidates[:i] + else: + gold_candidates_set = set( + [wl["relation"]["name"] for wl in sample.triplets] + ) + gold_candidates_indices = [ + i + for i, wc in enumerate(sample.triplet_candidates) + if wc in gold_candidates_set + ] + if len(self.special_symbols_types) > 0: + gold_candidates_indices = [ + i + len(sample.span_candidates) + for i in gold_candidates_indices + ] + # add entities indices + gold_candidates_indices = gold_candidates_indices + list( + range(len(sample.span_candidates)) + ) + necessary_taken_tokens = sum( + map( + len, + [ + candidates_encoding_result[i] + for i in gold_candidates_indices + ], + ) + ) + + acceptable_tokens_from_candidates = ( + self.model_max_length + - 20 + - len(input_subwords) + - necessary_taken_tokens + ) + if acceptable_tokens_from_candidates <= 0: + logger.warning( + "Sample {} has no candidates after truncation due to max length".format( + sample.id + ) + ) + continue + # assert acceptable_tokens_from_candidates > 0 + + i = 0 + cum_len = 0 + while ( + cum_len + len(candidates_encoding_result[i]) + < acceptable_tokens_from_candidates + ): + if i not in gold_candidates_indices: + cum_len += len(candidates_encoding_result[i]) + i += 1 + + new_indices = sorted( + list(set(list(range(i)) + gold_candidates_indices)) + ) + # np.random.shuffle(new_indices) + + candidates_encoding_result = [ + candidates_encoding_result[i] for i in new_indices + ] + if len(self.special_symbols_types) > 0: + sample.triplet_candidates = [ + sample.triplet_candidates[i - len(sample.span_candidates)] + for i in new_indices[len(sample.span_candidates) :] + ] + candidates_symbols = candidates_symbols[ + : i - len(sample.span_candidates) + ] + else: + candidates_symbols = [ + candidates_symbols[i] for i in new_indices + ] + sample.triplet_candidates = [ + sample.triplet_candidates[i] for i in new_indices + ] + if len(sample.triplet_candidates) == 0: + logger.warning( + "Sample {} has no candidates after truncation due to max length".format( + sample.sample_id + ) + ) + continue + + # final input_ids build + input_ids = self._build_input_ids( + sentence_input_ids=input_subwords, + candidates_input_ids=candidates_encoding_result, + ) + + # complete input building (e.g. 
attention / prediction mask) + tokenization_output = self._build_tokenizer_essentials( + input_ids, + input_subwords, + min(len(sample.span_candidates), len(self.special_symbols_types)) + if sample.span_candidates is not None + else 0, + ) + # labels creation + start_labels, end_labels, disambiguation_labels, relation_labels = ( + None, + None, + None, + None, + ) + if sample.entities is not None and len(sample.entities) > 0: + ( + start_labels, + end_labels, + disambiguation_labels, + relation_labels, + ) = self._build_labels( + sample, + tokenization_output, + ) + if self.materialize_samples: + sample.materialize = { + "tokenization_output": tokenization_output, + "start_labels": start_labels, + "end_labels": end_labels, + "disambiguation_labels": disambiguation_labels, + "relation_labels": relation_labels, + "candidates_symbols": candidates_symbols, + } + data_acc.append(sample) + yield { + "input_ids": tokenization_output.input_ids, + "attention_mask": tokenization_output.attention_mask, + "token_type_ids": tokenization_output.token_type_ids, + "prediction_mask": tokenization_output.prediction_mask, + "special_symbols_mask": tokenization_output.special_symbols_mask, + "special_symbols_mask_entities": tokenization_output.special_symbols_mask_entities, + "sample": sample, + "start_labels": start_labels, + "end_labels": end_labels, + "disambiguation_labels": disambiguation_labels, + "relation_labels": relation_labels, + "predictable_candidates": candidates_symbols, + } + if self.materialize_samples: + self.samples = data_acc + + def preshuffle_elements(self, dataset_elements: List): + # This shuffling is done so that when using the sorting function, + # if it is deterministic given a collection and its order, we will + # make the whole operation not deterministic anymore. + # Basically, the aim is not to build every time the same batches. 
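# --- Editor's note: illustrative sketch, not part of this patch. ------------
# After prebatching, materialize_batches (a few lines further down) groups
# elements under a token budget rather than a fixed batch size: an element
# joins the current batch only while (longest sequence so far) * (batch size
# + 1) stays under tokens_per_batch, with an additional max_batch_size cap.
# A minimal stand-alone version of that grouping over sequence lengths
# (function and parameter names are assumptions for the example only):

from typing import Iterable, List

def token_budget_batches(lengths: Iterable[int],
                         tokens_per_batch: int = 2048,
                         max_batch_size: int = 128) -> List[List[int]]:
    batches, current = [], []
    for length in lengths:
        # padded size of the batch if this element were added
        future_max = max([length] + current)
        if current and (
            len(current) == max_batch_size
            or future_max * (len(current) + 1) >= tokens_per_batch
        ):
            batches.append(current)
            current = []
        current.append(length)
    if current:
        batches.append(current)
    return batches

# e.g. token_budget_batches([100, 120, 900, 50], tokens_per_batch=1024)
# -> [[100, 120], [900], [50]]
# ----------------------------------------------------------------------------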
+ if not self.for_inference: + dataset_elements = np.random.permutation(dataset_elements) + + sorting_fn = ( + lambda elem: add_noise_to_value( + sum(len(elem[k]) for k in self.sorting_fields), + noise_param=self.noise_param, + ) + if not self.for_inference + else sum(len(elem[k]) for k in self.sorting_fields) + ) + + dataset_elements = sorted(dataset_elements, key=sorting_fn) + + if self.for_inference: + return dataset_elements + + ds = list(chunks(dataset_elements, 64)) # todo: modified + np.random.shuffle(ds) + return flatten(ds) + + def materialize_batches( + self, dataset_elements: List[Dict[str, Any]] + ) -> Generator[Dict[str, Any], None, None]: + if self.prebatch: + dataset_elements = self.preshuffle_elements(dataset_elements) + + current_batch = [] + + # function that creates a batch from the 'current_batch' list + def output_batch() -> Dict[str, Any]: + assert ( + len( + set([len(elem["predictable_candidates"]) for elem in current_batch]) + ) + == 1 + ), " ".join( + map( + str, [len(elem["predictable_candidates"]) for elem in current_batch] + ) + ) + + batch_dict = dict() + + de_values_by_field = { + fn: [de[fn] for de in current_batch if fn in de] + for fn in self.fields_batcher + } + + # in case you provide fields batchers but in the batch + # there are no elements for that field + de_values_by_field = { + fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 + } + + assert len(set([len(v) for v in de_values_by_field.values()])) + + # todo: maybe we should report the user about possible + # fields filtering due to "None" instances + de_values_by_field = { + fn: fvs + for fn, fvs in de_values_by_field.items() + if all([fv is not None for fv in fvs]) + } + + for field_name, field_values in de_values_by_field.items(): + field_batch = ( + self.fields_batcher[field_name](field_values) + if self.fields_batcher[field_name] is not None + else field_values + ) + + batch_dict[field_name] = field_batch + + return batch_dict + + max_len_discards, min_len_discards = 0, 0 + + should_token_batch = self.batch_size is None + + curr_pred_elements = -1 + for de in dataset_elements: + if ( + should_token_batch + and self.max_batch_size != -1 + and len(current_batch) == self.max_batch_size + ) or (not should_token_batch and len(current_batch) == self.batch_size): + yield output_batch() + current_batch = [] + curr_pred_elements = -1 + + # todo support max length (and min length) as dicts + + too_long_fields = [ + k + for k in de + if self.max_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) > self.max_length + ] + if len(too_long_fields) > 0: + max_len_discards += 1 + continue + + too_short_fields = [ + k + for k in de + if self.min_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) < self.min_length + ] + if len(too_short_fields) > 0: + min_len_discards += 1 + continue + + if should_token_batch: + de_len = sum(len(de[k]) for k in self.batching_fields) + + future_max_len = max( + de_len, + max( + [ + sum(len(bde[k]) for k in self.batching_fields) + for bde in current_batch + ], + default=0, + ), + ) + + future_tokens_per_batch = future_max_len * (len(current_batch) + 1) + + num_predictable_candidates = len(de["predictable_candidates"]) + + if len(current_batch) > 0 and ( + future_tokens_per_batch >= self.tokens_per_batch + or ( + num_predictable_candidates != curr_pred_elements + and curr_pred_elements != -1 + ) + ): + yield output_batch() + current_batch = [] + + current_batch.append(de) + curr_pred_elements = len(de["predictable_candidates"]) + + if len(current_batch) 
!= 0 and not self.drop_last: + yield output_batch() + + if max_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were " + f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the " + f"sample length exceeds the maximum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {max_len_discards} elements were " + f"discarded since longer than max length {self.max_length}" + ) + + if min_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were " + f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the " + f"sample length is shorter than the minimum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {min_len_discards} elements were " + f"discarded since shorter than min length {self.min_length}" + ) + + @staticmethod + def _new_output_format(sample: RelikReaderSample) -> RelikReaderSample: + # try-out for a new format + + # set of span tuples (start, end, type) for each entity + predicted_spans = set() + for prediction in sample.predicted_entities: + predicted_spans.add( + ( + prediction[0], + prediction[1], + prediction[2], + ) + ) + + # sort the spans by start so that we can use the index of the span to get the entity + predicted_spans = sorted(predicted_spans, key=lambda x: x[0]) + predicted_triples = [] + # now search for the spans in each triplet + for prediction in sample.predicted_relations: + # get the index of the entity that has the same start and end + start_entity_index = [ + i + for i, p in enumerate(predicted_spans) + if p[:2] + == (prediction["subject"]["start"], prediction["subject"]["end"]) + ][0] + end_entity_index = [ + i + for i, p in enumerate(predicted_spans) + if p[:2] == (prediction["object"]["start"], prediction["object"]["end"]) + ][0] + + predicted_triples.append( + ( + start_entity_index, + prediction["relation"]["name"], + end_entity_index, + prediction["relation"]["probability"], + ) + ) + sample.predicted_spans = predicted_spans + sample.predicted_triples = predicted_triples + return sample + + @staticmethod + def _convert_annotations(sample: RelikReaderSample) -> RelikReaderSample: + triplets = [] + entities = [] + + for entity in sample.predicted_entities: + span_start = entity[0] - 1 + span_end = entity[1] - 1 + if str(span_start) not in sample.token2word_start: + # span_start is in the middle of a word + # retrieve the first token of the word + while str(span_start) not in sample.token2word_start: + span_start -= 1 + # skip + if span_start < 0: + break + if str(span_end) not in sample.token2word_end: + # span_end is in the middle of a word + # retrieve the last token of the word + while str(span_end) not in sample.token2word_end: + span_end += 1 + # skip + if span_end >= len(sample.tokens): + break + + if span_start < 0 or span_end >= len(sample.tokens): + continue + + entities.append( + ( + sample.token2word_start[str(span_start)], + sample.token2word_end[str(span_end)] + 1, + sample.span_candidates[entity[2]] + if sample.span_candidates and len(entity) > 2 + else "NME", + ) + ) + for predicted_triplet, 
predicted_triplet_probabilities in zip( + sample.predicted_relations, sample.predicted_relations_probabilities + ): + subject, object_, relation = predicted_triplet + subject = entities[subject] + object_ = entities[object_] + relation = sample.triplet_candidates[relation] + triplets.append( + { + "subject": { + "start": subject[0], + "end": subject[1], + "type": subject[2], + # "name": " ".join(sample.tokens[subject[0] : subject[1]]), + }, + "relation": { + "name": relation, + "probability": float(predicted_triplet_probabilities.round(2)), + }, + "object": { + "start": object_[0], + "end": object_[1], + "type": object_[2], + # "name": " ".join(sample.tokens[object_[0] : object_[1]]), + }, + } + ) + # convert to list since we need to modify the sample down the road + sample.predicted_entities = entities + sample.predicted_relations = triplets + del sample._d["predicted_relations_probabilities"] + + return sample + + @staticmethod + def convert_to_word_annotations(sample: RelikReaderSample) -> RelikReaderSample: + sample = RelikREDataset._convert_annotations(sample) + return RelikREDataset._new_output_format(sample) + + @staticmethod + def convert_to_char_annotations( + sample: RelikReaderSample, + remove_nmes: bool = True, + ) -> RelikReaderSample: + RelikREDataset._convert_annotations(sample) + if "token2char_start" in sample._d: + entities = [] + for entity in sample.predicted_entities: + entity = list(entity) + token_start = sample.word2token_start[str(entity[0])] + entity[0] = sample.token2char_start[str(token_start)] + token_end = sample.word2token_end[str(entity[1] - 1)] + entity[1] = sample.token2char_end[str(token_end)] + entities.append(entity) + sample.predicted_entities = entities + for triplet in sample.predicted_relations: + triplet["subject"]["start"] = sample.token2char_start[ + str(sample.word2token_start[str(triplet["subject"]["start"])]) + ] + triplet["subject"]["end"] = sample.token2char_end[ + str(sample.word2token_end[str(triplet["subject"]["end"] - 1)]) + ] + triplet["object"]["start"] = sample.token2char_start[ + str(sample.word2token_start[str(triplet["object"]["start"])]) + ] + triplet["object"]["end"] = sample.token2char_end[ + str(sample.word2token_end[str(triplet["object"]["end"] - 1)]) + ] + + sample = RelikREDataset._new_output_format(sample) + + return sample + + @staticmethod + def merge_patches_predictions(sample) -> None: + pass + + +def main(): + special_symbols = [NME_SYMBOL] + [f"R-{i}" for i in range(50)] + + relik_dataset = RelikREDataset( + "/home/huguetcabot/alby-re/alby/data/nyt-alby+/valid.jsonl", + materialize_samples=False, + transformer_model="microsoft/deberta-v3-base", + special_symbols=special_symbols, + shuffle_candidates=False, + flip_candidates=False, + for_inference=True, + ) + + for batch in relik_dataset: + print(batch) + exit(0) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/data/relik_reader_sample.py b/relik/reader/data/relik_reader_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..698181ef88ae28ba6a9c458644e936e68ca1c5d4 --- /dev/null +++ b/relik/reader/data/relik_reader_sample.py @@ -0,0 +1,62 @@ +import json +import numpy as np +from typing import Iterable + + +class NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super(NpEncoder, self).default(obj) + + +class RelikReaderSample: + def __init__(self, 
**kwargs): + super().__setattr__("_d", {}) + self._d = kwargs + + def __getattribute__(self, item): + return super(RelikReaderSample, self).__getattribute__(item) + + def __getattr__(self, item): + if item.startswith("__") and item.endswith("__"): + # this is likely some python library-specific variable (such as __deepcopy__ for copy) + # better follow standard behavior here + raise AttributeError(item) + elif item in self._d: + return self._d[item] + else: + return None + + def __setattr__(self, key, value): + if key in self._d: + self._d[key] = value + else: + super().__setattr__(key, value) + + def to_jsons(self) -> str: + if "predicted_window_labels" in self._d: + new_obj = { + k: v + for k, v in self._d.items() + if k != "predicted_window_labels" and k != "span_title_probabilities" + } + new_obj["predicted_window_labels"] = [ + [ss, se, pred_title] + for (ss, se, pred_title) in self.predicted_window_labels_chars + ] + return new_obj + else: + return json.dumps(self._d, cls=NpEncoder) + + +def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]: + with open(path) as f: + for line in f: + jsonl_line = json.loads(line.strip()) + relik_reader_sample = RelikReaderSample(**jsonl_line) + yield relik_reader_sample diff --git a/relik/reader/lightning_modules/__init__.py b/relik/reader/lightning_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/lightning_modules/__pycache__/__init__.cpython-310.pyc b/relik/reader/lightning_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1828485daa43b4fccc9c240bf6c2a24ee0d5bcd Binary files /dev/null and b/relik/reader/lightning_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/lightning_modules/__pycache__/relik_reader_pl_module.cpython-310.pyc b/relik/reader/lightning_modules/__pycache__/relik_reader_pl_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe065c997f03af5e7e532803756f3666f703fb1 Binary files /dev/null and b/relik/reader/lightning_modules/__pycache__/relik_reader_pl_module.cpython-310.pyc differ diff --git a/relik/reader/lightning_modules/relik_reader_pl_module.py b/relik/reader/lightning_modules/relik_reader_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..141ad2e4c7feaa05c38feb760a976e637761dfe6 --- /dev/null +++ b/relik/reader/lightning_modules/relik_reader_pl_module.py @@ -0,0 +1,51 @@ +from typing import Any, Optional + +import lightning +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler + +# from relik.reader.relik_reader_core import RelikReaderCoreModel +from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction + + +class RelikReaderPLModule(lightning.LightningModule): + def __init__( + self, + cfg: dict, + transformer_model: str, + additional_special_symbols: int, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + *args: Any, + **kwargs: Any + ): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.relik_reader_core_model = RelikReaderForSpanExtraction( + transformer_model, + additional_special_symbols, + num_layers, + activation, + linears_hidden_size, + use_last_k_layers, + training=training, + ) + self.optimizer_factory = None + + def training_step(self, batch: dict, 
*args: Any, **kwargs: Any) -> STEP_OUTPUT: + relik_output = self.relik_reader_core_model(**batch) + self.log("train-loss", relik_output["loss"]) + return relik_output["loss"] + + def validation_step( + self, batch: dict, *args: Any, **kwargs: Any + ) -> Optional[STEP_OUTPUT]: + return + + def set_optimizer_factory(self, optimizer_factory) -> None: + self.optimizer_factory = optimizer_factory + + def configure_optimizers(self) -> OptimizerLRScheduler: + return self.optimizer_factory(self.relik_reader_core_model) diff --git a/relik/reader/lightning_modules/relik_reader_re_pl_module.py b/relik/reader/lightning_modules/relik_reader_re_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..b099287f780af26936e7dc0c91e2684e28f746ee --- /dev/null +++ b/relik/reader/lightning_modules/relik_reader_re_pl_module.py @@ -0,0 +1,61 @@ +from typing import Any, Optional + +import lightning +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler + +from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction + + +class RelikReaderREPLModule(lightning.LightningModule): + def __init__( + self, + cfg: dict, + transformer_model: str, + additional_special_symbols: int, + additional_special_symbols_types: Optional[int] = 0, + entity_type_loss: bool = None, + add_entity_embedding: bool = None, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + *args: Any, + **kwargs: Any + ): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + self.relik_reader_re_model = RelikReaderForTripletExtraction( + transformer_model, + additional_special_symbols, + additional_special_symbols_types, + entity_type_loss, + add_entity_embedding, + num_layers, + activation, + linears_hidden_size, + use_last_k_layers, + training=training, + **kwargs, + ) + self.optimizer_factory = None + + def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + relik_output = self.relik_reader_re_model(**batch) + self.log("train-loss", relik_output["loss"]) + self.log("train-start_loss", relik_output["ned_start_loss"]) + self.log("train-end_loss", relik_output["ned_end_loss"]) + self.log("train-relation_loss", relik_output["re_loss"]) + return relik_output["loss"] + + def validation_step( + self, batch: dict, *args: Any, **kwargs: Any + ) -> Optional[STEP_OUTPUT]: + return + + def set_optimizer_factory(self, optimizer_factory) -> None: + self.optimizer_factory = optimizer_factory + + def configure_optimizers(self) -> OptimizerLRScheduler: + return self.optimizer_factory(self.relik_reader_re_model) diff --git a/relik/reader/pytorch_modules/__init__.py b/relik/reader/pytorch_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/pytorch_modules/__pycache__/__init__.cpython-310.pyc b/relik/reader/pytorch_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..049454eba01623c75e855ad21ef79ec927b84703 Binary files /dev/null and b/relik/reader/pytorch_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/__pycache__/base.cpython-310.pyc b/relik/reader/pytorch_modules/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53637486c68f78590fbd84ad7920a51db9bf2143 Binary files /dev/null and 
b/relik/reader/pytorch_modules/__pycache__/base.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/__pycache__/span.cpython-310.pyc b/relik/reader/pytorch_modules/__pycache__/span.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5564ab14d9253d122d22bd12740c80b952148e3a Binary files /dev/null and b/relik/reader/pytorch_modules/__pycache__/span.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/__pycache__/triplet.cpython-310.pyc b/relik/reader/pytorch_modules/__pycache__/triplet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ca2057f2ff581abb6032dce9e14722e6738365 Binary files /dev/null and b/relik/reader/pytorch_modules/__pycache__/triplet.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/base.py b/relik/reader/pytorch_modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6f91cb3142602f1be8ed28f4a789165289981ec4 --- /dev/null +++ b/relik/reader/pytorch_modules/base.py @@ -0,0 +1,270 @@ +import logging +import os +from pathlib import Path +from typing import Any, Dict, List + +import torch +import transformers as tr +from torch.utils.data import IterableDataset +from transformers import AutoConfig + +from relik.common.log import get_logger + +# from relik.common.torch_utils import load_ort_optimized_hf_model +from relik.common.utils import get_callable_from_string +from relik.inference.data.objects import AnnotationType +from relik.reader.pytorch_modules.hf.modeling_relik import ( + RelikReaderConfig, + RelikReaderSample, +) +from relik.retriever.pytorch_modules import PRECISION_MAP + +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderBase(torch.nn.Module): + default_reader_class: str | None = None + default_data_class: str | None = None + + def __init__( + self, + transformer_model: str | tr.PreTrainedModel | None = None, + additional_special_symbols: int = 0, + num_layers: int | None = None, + activation: str = "gelu", + linears_hidden_size: int | None = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: str | torch.device | None = None, + precision: int = 32, + tokenizer: str | tr.PreTrainedTokenizer | None = None, + dataset: IterableDataset | str | None = None, + default_reader_class: tr.PreTrainedModel | str | None = None, + **kwargs, + ) -> None: + super().__init__() + + self.default_reader_class = default_reader_class or self.default_reader_class + + if self.default_reader_class is None: + raise ValueError("You must specify a default reader class.") + + # get the callable for the default reader class + self.default_reader_class: tr.PreTrainedModel = get_callable_from_string( + self.default_reader_class + ) + + if isinstance(transformer_model, str): + config = AutoConfig.from_pretrained( + transformer_model, trust_remote_code=True + ) + if "relik-reader" in config.model_type: + transformer_model = self.default_reader_class.from_pretrained( + transformer_model, config=config, **kwargs + ) + else: + reader_config = RelikReaderConfig( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + **kwargs, + ) + transformer_model = self.default_reader_class(reader_config) + + self.relik_reader_model = transformer_model + + self.relik_reader_model_config = self.relik_reader_model.config + # self.name_or_path = 
self.relik_reader_model_config.name_or_path + self.name_or_path = self.relik_reader_model.config.transformer_model + + # get the tokenizer + self._tokenizer = tokenizer + + # and instantiate the dataset class + self.dataset: IterableDataset | None = dataset + + # move the model to the device + self.to(device or torch.device("cpu")) + + # set the precision + self.precision = precision + self.to(PRECISION_MAP[precision]) + + def forward(self, **kwargs) -> Dict[str, Any]: + return self.relik_reader_model(**kwargs) + + def _read(self, *args, **kwargs) -> Any: + raise NotImplementedError + + @torch.no_grad() + @torch.inference_mode() + def read( + self, + text: List[str] | List[List[str]] | None = None, + samples: List[RelikReaderSample] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + candidates: List[List[str]] | None = None, + max_length: int = 1000, + max_batch_size: int = 128, + token_batch_size: int = 2048, + precision: int | str | None = None, + annotation_type: str | AnnotationType = AnnotationType.CHAR, + progress_bar: bool = False, + *args, + **kwargs, + ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: + """ + Reads the given text. + + Args: + text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): + The text to read in tokens. If a list of list of tokens is provided, each + inner list is considered a sentence. + samples (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + input_ids (:obj:`torch.Tensor`, `optional`): + The input ids of the text. + attention_mask (:obj:`torch.Tensor`, `optional`): + The attention mask of the text. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. + candidates (:obj:`List[List[str]]`, `optional`): + The candidates of the text. + max_length (:obj:`int`, `optional`, defaults to 1024): + The maximum length of the text. + max_batch_size (:obj:`int`, `optional`, defaults to 128): + The maximum batch size. + token_batch_size (:obj:`int`, `optional`): + The maximum number of tokens per batch. + precision (:obj:`int` or :obj:`str`, `optional`): + The precision to use. If not provided, the default is 32 bit. + annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`): + The type of annotation to return. If `char`, the spans will be in terms of + character offsets. If `word`, the spans will be in terms of word offsets. + progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to show a progress bar. + + Returns: + The predicted labels for each sample. + """ + if isinstance(annotation_type, str): + try: + annotation_type = AnnotationType(annotation_type) + except ValueError: + raise ValueError( + f"Annotation type `{annotation_type}` not recognized. " + f"Please choose one of {list(AnnotationType)}." + ) + + if text is None and input_ids is None and samples is None: + raise ValueError( + "Either `text` or `input_ids` or `samples` must be provided." 
+ ) + if (input_ids is None and samples is None) and ( + text is None or candidates is None + ): + raise ValueError( + "`text` and `candidates` must be provided to return the predictions when " + "`input_ids` and `samples` is not provided." + ) + if text is not None and samples is None: + if len(text) != len(candidates): + raise ValueError("`text` and `candidates` must have the same length.") + if isinstance(text[0], str): # change to list of text + text = [text] + candidates = [candidates] + + samples = [ + RelikReaderSample(tokens=t, candidates=c) + for t, c in zip(text, candidates) + ] + + return self._read( + samples, + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + max_length, + max_batch_size, + token_batch_size, + precision or self.precision, + annotation_type, + progress_bar, + *args, + **kwargs, + ) + + @property + def device(self) -> torch.device: + """ + The device of the model. + """ + return next(self.parameters()).device + + @property + def tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The tokenizer. + """ + if self._tokenizer: + return self._tokenizer + + self._tokenizer = tr.AutoTokenizer.from_pretrained( + self.relik_reader_model.config.name_or_path if self.relik_reader_model.config.name_or_path else self.relik_reader_model.config.transformer_model + ) + return self._tokenizer + + def save_pretrained( + self, + output_dir: str | os.PathLike, + model_name: str | None = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + """ + Saves the model to the given path. + + Args: + output_dir (`str` or :obj:`os.PathLike`): + The path to save the model to. + model_name (`str`, `optional`): + The name of the model. If not provided, the model will be saved as + `default_reader_class.__name__`. + push_to_hub (`bool`, `optional`, defaults to `False`): + Whether to push the model to the HuggingFace Hub. 
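
# --- Illustrative sketch (added by the editor, not part of this patch): one way the
# --- `read` API documented above might be invoked. The checkpoint path, tokens and
# --- candidate strings below are invented for illustration only.
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction

reader = RelikReaderForSpanExtraction("path/to/a/relik-reader-checkpoint")  # hypothetical checkpoint
samples = reader.read(
    text=[["Rome", "is", "the", "capital", "of", "Italy", "."]],   # one pre-tokenized sentence
    candidates=[["Rome", "Italy", "capital city"]],                 # its candidate entities
    annotation_type="char",
    progress_bar=False,
)
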
+ **kwargs: + Additional keyword arguments to pass to the `save_pretrained` method + """ + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + model_name = model_name or output_dir.name + + logger.info(f"Saving reader to {output_dir / model_name}") + + # save the model + self.relik_reader_model.register_for_auto_class() + self.relik_reader_model.save_pretrained( + str(output_dir / model_name), push_to_hub=push_to_hub, **kwargs + ) + + if self.tokenizer: + logger.info("Saving also the tokenizer") + self.tokenizer.save_pretrained( + str(output_dir / model_name), push_to_hub=push_to_hub, **kwargs + ) diff --git a/relik/reader/pytorch_modules/hf/__init__.py b/relik/reader/pytorch_modules/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c158e6ab6dcd3ab43e60751218600fbb0a5ed5 --- /dev/null +++ b/relik/reader/pytorch_modules/hf/__init__.py @@ -0,0 +1,2 @@ +from .configuration_relik import RelikReaderConfig +from .modeling_relik import RelikReaderREModel diff --git a/relik/reader/pytorch_modules/hf/__pycache__/__init__.cpython-310.pyc b/relik/reader/pytorch_modules/hf/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d919462ae035e0b364e1653717a362b63aadfc1b Binary files /dev/null and b/relik/reader/pytorch_modules/hf/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/hf/__pycache__/configuration_relik.cpython-310.pyc b/relik/reader/pytorch_modules/hf/__pycache__/configuration_relik.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a50d40ea81191a6364947145b6c16a92f2eedf6 Binary files /dev/null and b/relik/reader/pytorch_modules/hf/__pycache__/configuration_relik.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/hf/__pycache__/modeling_relik.cpython-310.pyc b/relik/reader/pytorch_modules/hf/__pycache__/modeling_relik.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29aa1ab7abd806e1b1c7e84f197aa8c5d7a4e613 Binary files /dev/null and b/relik/reader/pytorch_modules/hf/__pycache__/modeling_relik.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/hf/configuration_relik.py b/relik/reader/pytorch_modules/hf/configuration_relik.py new file mode 100644 index 0000000000000000000000000000000000000000..2f5b4b4b9fb221b1b73529e25bb6f54fd9c93bbd --- /dev/null +++ b/relik/reader/pytorch_modules/hf/configuration_relik.py @@ -0,0 +1,44 @@ +from typing import Optional + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig + + +class RelikReaderConfig(PretrainedConfig): + model_type = "relik-reader" + + def __init__( + self, + transformer_model: str = "microsoft/deberta-v3-base", + additional_special_symbols: int = 101, + additional_special_symbols_types: Optional[int] = 0, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + entity_type_loss: bool = False, + add_entity_embedding: bool = None, + training: bool = False, + default_reader_class: Optional[str] = None, + **kwargs + ) -> None: + # TODO: add name_or_path to kwargs + self.transformer_model = transformer_model + self.additional_special_symbols = additional_special_symbols + self.additional_special_symbols_types = additional_special_symbols_types + self.num_layers = num_layers + self.activation = activation + self.linears_hidden_size = 
linears_hidden_size + self.use_last_k_layers = use_last_k_layers + self.entity_type_loss = entity_type_loss + self.add_entity_embedding = ( + True + if add_entity_embedding is None and entity_type_loss + else add_entity_embedding + ) + self.training = training + self.default_reader_class = default_reader_class + super().__init__(**kwargs) + + +AutoConfig.register("relik-reader", RelikReaderConfig) diff --git a/relik/reader/pytorch_modules/hf/modeling_relik.py b/relik/reader/pytorch_modules/hf/modeling_relik.py new file mode 100644 index 0000000000000000000000000000000000000000..8842471716547c1481bd93a1ce767527cd195a7f --- /dev/null +++ b/relik/reader/pytorch_modules/hf/modeling_relik.py @@ -0,0 +1,980 @@ +from typing import Any, Dict, Optional + +import torch +from transformers import AutoModel, PreTrainedModel +from transformers.activations import ClippedGELUActivation, GELUActivation +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PoolerEndLogits + +from .configuration_relik import RelikReaderConfig + +torch.set_float32_matmul_precision('medium') + +class RelikReaderSample: + def __init__(self, **kwargs): + super().__setattr__("_d", {}) + self._d = kwargs + + def __getattribute__(self, item): + return super(RelikReaderSample, self).__getattribute__(item) + + def __getattr__(self, item): + if item.startswith("__") and item.endswith("__"): + # this is likely some python library-specific variable (such as __deepcopy__ for copy) + # better follow standard behavior here + raise AttributeError(item) + elif item in self._d: + return self._d[item] + else: + return None + + def __setattr__(self, key, value): + if key in self._d: + self._d[key] = value + else: + super().__setattr__(key, value) + + +activation2functions = { + "relu": torch.nn.ReLU(), + "gelu": GELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), +} + + +class PoolerEndLogitsBi(PoolerEndLogits): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.dense_1 = torch.nn.Linear(config.hidden_size, 2) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + if p_mask is not None: + p_mask = p_mask.unsqueeze(-1) + logits = super().forward( + hidden_states, + start_states, + start_positions, + p_mask, + ) + return logits + + +class RelikReaderSpanModel(PreTrainedModel): + config_class = RelikReaderConfig + + def __init__(self, config: RelikReaderConfig, *args, **kwargs): + super().__init__(config) + # Transformer model declaration + self.config = config + self.transformer_model = ( + AutoModel.from_pretrained(self.config.transformer_model) + if self.config.num_layers is None + else AutoModel.from_pretrained( + self.config.transformer_model, num_hidden_layers=self.config.num_layers + ) + ) + self.transformer_model.resize_token_embeddings( + self.transformer_model.config.vocab_size + + self.config.additional_special_symbols, + pad_to_multiple_of=8, + ) + + self.activation = self.config.activation + self.linears_hidden_size = self.config.linears_hidden_size + self.use_last_k_layers = self.config.use_last_k_layers + + # named entity detection layers + self.ned_start_classifier = self._get_projection_layer( + self.activation, last_hidden=2, layer_norm=False + ) + self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) + + # END entity disambiguation layer 
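
# --- Illustrative sketch (added by the editor, not part of this patch): building the
# --- HF config registered above and instantiating the span model with it. The values
# --- shown are illustrative, not mandated by this patch.
from relik.reader.pytorch_modules.hf.configuration_relik import RelikReaderConfig
from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSpanModel

config = RelikReaderConfig(
    transformer_model="microsoft/deberta-v3-base",
    additional_special_symbols=101,   # illustrative number of extra special tokens
    linears_hidden_size=512,
)
model = RelikReaderSpanModel(config)  # downloads the underlying transformer weights
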
+ self.ed_start_projector = self._get_projection_layer(self.activation) + self.ed_end_projector = self._get_projection_layer(self.activation) + + self.training = self.config.training + + # criterion + self.criterion = torch.nn.CrossEntropyLoss() + + def _get_projection_layer( + self, + activation: str, + last_hidden: Optional[int] = None, + input_hidden=None, + layer_norm: bool = True, + ) -> torch.nn.Sequential: + head_components = [ + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.transformer_model.config.hidden_size * self.use_last_k_layers + if input_hidden is None + else input_hidden, + self.linears_hidden_size, + ), + activation2functions[activation], + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.linears_hidden_size, + self.linears_hidden_size if last_hidden is None else last_hidden, + ), + ] + + if layer_norm: + head_components.append( + torch.nn.LayerNorm( + self.linears_hidden_size if last_hidden is None else last_hidden, + self.transformer_model.config.layer_norm_eps, + ) + ) + + return torch.nn.Sequential(*head_components) + + def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + mask = mask.unsqueeze(-1) + if next(self.parameters()).dtype == torch.float16: + logits = logits * (1 - mask) - 65500 * mask + else: + logits = logits * (1 - mask) - 1e30 * mask + return logits + + def _get_model_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor], + ): + model_input = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": self.use_last_k_layers > 1, + } + + if token_type_ids is not None: + model_input["token_type_ids"] = token_type_ids + + model_output = self.transformer_model(**model_input) + + if self.use_last_k_layers > 1: + model_features = torch.cat( + model_output[1][-self.use_last_k_layers :], dim=-1 + ) + else: + model_features = model_output[0] + + return model_features + + def compute_ned_end_logits( + self, + start_predictions, + start_labels, + model_features, + prediction_mask, + batch_size, + ) -> Optional[torch.Tensor]: + # todo: maybe when constraining on the spans, + # we should not use a prediction_mask for the end tokens. 
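
# --- Toy illustration (added by the editor, not part of this patch): the additive
# --- masking implemented by `_mask_logits` above. Masked positions receive a very
# --- large negative value, so softmax assigns them (numerically) zero probability;
# --- the real method also unsqueezes the mask to match a trailing class dimension.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
mask = torch.tensor([[0.0, 1.0, 0.0]])   # 1.0 marks a position to exclude
masked = logits * (1 - mask) - 1e30 * mask
print(torch.softmax(masked, dim=-1))     # ~[[0.818, 0.000, 0.182]]
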
+ # at least we should not during training imo + start_positions = start_labels if self.training else start_predictions + start_positions_indices = ( + torch.arange(start_positions.size(1), device=start_positions.device) + .unsqueeze(0) + .expand(batch_size, -1)[start_positions > 0] + ).to(start_positions.device) + + if len(start_positions_indices) > 0: + expanded_features = model_features.repeat_interleave( + torch.sum(start_positions > 0, dim=-1), dim=0 + ) + expanded_prediction_mask = prediction_mask.repeat_interleave( + torch.sum(start_positions > 0, dim=-1), dim=0 + ) + end_logits = self.ned_end_classifier( + hidden_states=expanded_features, + start_positions=start_positions_indices, + p_mask=expanded_prediction_mask, + ) + + return end_logits + + return None + + def compute_classification_logits( + self, + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_positions=None, + end_positions=None, + ) -> torch.Tensor: + if start_positions is None or end_positions is None: + start_positions = torch.zeros_like(prediction_mask) + end_positions = torch.zeros_like(prediction_mask) + + model_start_features = self.ed_start_projector(model_features) + model_end_features = self.ed_end_projector(model_features) + model_end_features[start_positions > 0] = model_end_features[end_positions > 0] + + model_ed_features = torch.cat( + [model_start_features, model_end_features], dim=-1 + ) + + # computing ed features + classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item() + special_symbols_representation = model_ed_features[special_symbols_mask].view( + batch_size, classes_representations, -1 + ) + + logits = torch.bmm( + model_ed_features, + torch.permute(special_symbols_representation, (0, 2, 1)), + ) + + logits = self._mask_logits(logits, prediction_mask) + + return logits + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + use_predefined_spans: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + batch_size, seq_len = input_ids.shape + + model_features = self._get_model_features( + input_ids, attention_mask, token_type_ids + ) + + ned_start_labels = None + + # named entity detection if required + if use_predefined_spans: # no need to compute spans + ned_start_logits, ned_start_probabilities, ned_start_predictions = ( + None, + None, + torch.clone(start_labels) + if start_labels is not None + else torch.zeros_like(input_ids), + ) + ned_end_logits, ned_end_probabilities, ned_end_predictions = ( + None, + None, + torch.clone(end_labels) + if end_labels is not None + else torch.zeros_like(input_ids), + ) + + ned_start_predictions[ned_start_predictions > 0] = 1 + ned_end_predictions[ned_end_predictions > 0] = 1 + + else: # compute spans + # start boundary prediction + ned_start_logits = self.ned_start_classifier(model_features) + ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) + ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) + ned_start_predictions = ned_start_probabilities.argmax(dim=-1) + + # end boundary prediction + ned_start_labels = ( + torch.zeros_like(start_labels) if start_labels is not None else None + ) + + if ned_start_labels is not None: + ned_start_labels[start_labels == -100] = -100 + ned_start_labels[start_labels > 
0] = 1 + + ned_end_logits = self.compute_ned_end_logits( + ned_start_predictions, + ned_start_labels, + model_features, + prediction_mask, + batch_size, + ) + + if ned_end_logits is not None: + ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) + ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) + else: + ned_end_logits, ned_end_probabilities = None, None + ned_end_predictions = ned_start_predictions.new_zeros(batch_size) + + # flattening end predictions + # (flattening can happen only if the + # end boundaries were not predicted using the gold labels) + if not self.training and ned_end_logits is not None: + flattened_end_predictions = torch.zeros_like(ned_start_predictions) + + row_indices, start_positions = torch.where(ned_start_predictions > 0) + ned_end_predictions[ned_end_predictions torch.cat((end_spans_repeated[:1], cummax_values[:-1]))) + end_spans_repeated[0] = True + + ned_start_predictions[row_indices[~end_spans_repeated], start_positions[~end_spans_repeated]] = 0 + + row_indices, start_positions, ned_end_predictions = row_indices[end_spans_repeated], start_positions[end_spans_repeated], ned_end_predictions[end_spans_repeated] + + flattened_end_predictions[row_indices, ned_end_predictions] = 1 + + total_start_predictions, total_end_predictions = ned_start_predictions.sum(), flattened_end_predictions.sum() + + assert ( + total_start_predictions == 0 + or total_start_predictions == total_end_predictions + ), ( + f"Total number of start predictions = {total_start_predictions}. " + f"Total number of end predictions = {total_end_predictions}" + ) + ned_end_predictions = flattened_end_predictions + else: + ned_end_predictions = torch.zeros_like(ned_start_predictions) + + start_position, end_position = ( + (start_labels, end_labels) + if self.training + else (ned_start_predictions, ned_end_predictions) + ) + + # Entity disambiguation + ed_logits = self.compute_classification_logits( + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_position, + end_position, + ) + ed_probabilities = torch.softmax(ed_logits, dim=-1) + ed_predictions = torch.argmax(ed_probabilities, dim=-1) + + # output build + output_dict = dict( + batch_size=batch_size, + ned_start_logits=ned_start_logits, + ned_start_probabilities=ned_start_probabilities, + ned_start_predictions=ned_start_predictions, + ned_end_logits=ned_end_logits, + ned_end_probabilities=ned_end_probabilities, + ned_end_predictions=ned_end_predictions, + ed_logits=ed_logits, + ed_probabilities=ed_probabilities, + ed_predictions=ed_predictions, + ) + + # compute loss if labels + if start_labels is not None and end_labels is not None and self.training: + # named entity detection loss + + # start + if ned_start_logits is not None: + ned_start_loss = self.criterion( + ned_start_logits.view(-1, ned_start_logits.shape[-1]), + ned_start_labels.view(-1), + ) + else: + ned_start_loss = 0 + + # end + if ned_end_logits is not None: + ned_end_labels = torch.zeros_like(end_labels) + ned_end_labels[end_labels == -100] = -100 + ned_end_labels[end_labels > 0] = 1 + + ned_end_loss = self.criterion( + ned_end_logits, + ( + torch.arange( + ned_end_labels.size(1), device=ned_end_labels.device + ) + .unsqueeze(0) + .expand(batch_size, -1)[ned_end_labels > 0] + ).to(ned_end_labels.device), + ) + + else: + ned_end_loss = 0 + + # entity disambiguation loss + start_labels[ned_start_labels != 1] = -100 + ed_labels = torch.clone(start_labels) + ed_labels[end_labels > 0] = end_labels[end_labels > 0] + ed_loss = 
self.criterion( + ed_logits.view(-1, ed_logits.shape[-1]), + ed_labels.view(-1), + ) + + output_dict["ned_start_loss"] = ned_start_loss + output_dict["ned_end_loss"] = ned_end_loss + output_dict["ed_loss"] = ed_loss + + output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss + + return output_dict + + +class RelikReaderREModel(PreTrainedModel): + config_class = RelikReaderConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config) + # Transformer model declaration + # self.transformer_model_name = transformer_model + self.config = config + self.transformer_model = ( + AutoModel.from_pretrained(config.transformer_model) + if config.num_layers is None + else AutoModel.from_pretrained( + config.transformer_model, num_hidden_layers=config.num_layers + ) + ) + self.transformer_model.resize_token_embeddings( + self.transformer_model.config.vocab_size + + config.additional_special_symbols + + config.additional_special_symbols_types, + pad_to_multiple_of=8, + ) + + # named entity detection layers + self.ned_start_classifier = self._get_projection_layer( + config.activation, last_hidden=2, layer_norm=False + ) + + self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config) + + self.relation_disambiguation_loss = ( + config.relation_disambiguation_loss + if hasattr(config, "relation_disambiguation_loss") + else False + ) + + if self.config.entity_type_loss and self.config.add_entity_embedding: + input_hidden_ents = 3 * self.transformer_model.config.hidden_size + else: + input_hidden_ents = 2 * self.transformer_model.config.hidden_size + + self.re_subject_projector = self._get_projection_layer( + config.activation, input_hidden=input_hidden_ents + ) + self.re_object_projector = self._get_projection_layer( + config.activation, input_hidden=input_hidden_ents + ) + self.re_relation_projector = self._get_projection_layer(config.activation) + + if self.config.entity_type_loss or self.relation_disambiguation_loss: + self.re_entities_projector = self._get_projection_layer( + config.activation, + input_hidden=2 * self.transformer_model.config.hidden_size, + ) + self.re_definition_projector = self._get_projection_layer( + config.activation, + ) + + self.re_classifier = self._get_projection_layer( + config.activation, + input_hidden=config.linears_hidden_size, + last_hidden=2, + layer_norm=False, + ) + + self.training = config.training + + # criterion + self.criterion = torch.nn.CrossEntropyLoss() + self.criterion_type = torch.nn.BCEWithLogitsLoss() + + def _get_projection_layer( + self, + activation: str, + last_hidden: Optional[int] = None, + input_hidden=None, + layer_norm: bool = True, + ) -> torch.nn.Sequential: + head_components = [ + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.transformer_model.config.hidden_size + * self.config.use_last_k_layers + if input_hidden is None + else input_hidden, + self.config.linears_hidden_size, + ), + activation2functions[activation], + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.config.linears_hidden_size, + self.config.linears_hidden_size if last_hidden is None else last_hidden, + ), + ] + + if layer_norm: + head_components.append( + torch.nn.LayerNorm( + self.config.linears_hidden_size + if last_hidden is None + else last_hidden, + self.transformer_model.config.layer_norm_eps, + ) + ) + + return torch.nn.Sequential(*head_components) + + def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + mask = mask.unsqueeze(-1) + if next(self.parameters()).dtype == torch.float16: + logits = logits 
* (1 - mask) - 65500 * mask + else: + logits = logits * (1 - mask) - 1e30 * mask + return logits + + def _get_model_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor], + ): + model_input = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": self.config.use_last_k_layers > 1, + } + + if token_type_ids is not None: + model_input["token_type_ids"] = token_type_ids + + model_output = self.transformer_model(**model_input) + + if self.config.use_last_k_layers > 1: + model_features = torch.cat( + model_output[1][-self.config.use_last_k_layers :], dim=-1 + ) + else: + model_features = model_output[0] + + return model_features + + def compute_ned_end_logits( + self, + start_predictions, + start_labels, + model_features, + prediction_mask, + batch_size, + mask_preceding: bool = False, + ) -> Optional[torch.Tensor]: + # todo: maybe when constraining on the spans, + # we should not use a prediction_mask for the end tokens. + # at least we should not during training imo + start_positions = start_labels if self.training else start_predictions + start_positions_indices = ( + torch.arange(start_positions.size(1), device=start_positions.device) + .unsqueeze(0) + .expand(batch_size, -1)[start_positions > 0] + ).to(start_positions.device) + + if len(start_positions_indices) > 0: + expanded_features = model_features.repeat_interleave( + torch.sum(start_positions > 0, dim=-1), dim=0 + ) + expanded_prediction_mask = prediction_mask.repeat_interleave( + torch.sum(start_positions > 0, dim=-1), dim=0 + ) + if mask_preceding: + expanded_prediction_mask[ + torch.arange( + expanded_prediction_mask.shape[1], + device=expanded_prediction_mask.device, + ) + < start_positions_indices.unsqueeze(1) + ] = 1 + end_logits = self.ned_end_classifier( + hidden_states=expanded_features, + start_positions=start_positions_indices, + p_mask=expanded_prediction_mask, + ) + + return end_logits + + return None + + def compute_relation_logits( + self, + model_entity_features, + special_symbols_features, + ) -> torch.Tensor: + model_subject_features = self.re_subject_projector(model_entity_features) + model_object_features = self.re_object_projector(model_entity_features) + special_symbols_start_representation = self.re_relation_projector( + special_symbols_features + ) + re_logits = torch.einsum( + "bse,bde,bfe->bsdfe", + model_subject_features, + model_object_features, + special_symbols_start_representation, + ) + re_logits = self.re_classifier(re_logits) + + return re_logits + + def compute_entity_logits( + self, + model_entity_features, + special_symbols_features, + ) -> torch.Tensor: + model_ed_features = self.re_entities_projector(model_entity_features) + special_symbols_ed_representation = self.re_definition_projector( + special_symbols_features + ) + + logits = torch.bmm( + model_ed_features, + torch.permute(special_symbols_ed_representation, (0, 2, 1)), + ) + logits = self._mask_logits( + logits, (model_entity_features == -100).all(2).long() + ) + return logits + + def compute_loss(self, logits, labels, mask=None): + logits = logits.reshape(-1, logits.shape[-1]) + labels = labels.reshape(-1).long() + if mask is not None: + return self.criterion(logits[mask], labels[mask]) + return self.criterion(logits, labels) + + def compute_ned_type_loss( + self, + disambiguation_labels, + re_ned_entities_logits, + ned_type_logits, + re_entities_logits, + entity_types, + mask, + ): + if self.config.entity_type_loss and 
self.relation_disambiguation_loss: + return self.criterion_type( + re_ned_entities_logits[disambiguation_labels != -100], + disambiguation_labels[disambiguation_labels != -100], + ) + if self.config.entity_type_loss: + return self.criterion_type( + ned_type_logits[mask], + disambiguation_labels[:, :, :entity_types][mask], + ) + + if self.relation_disambiguation_loss: + return self.criterion_type( + re_entities_logits[disambiguation_labels != -100], + disambiguation_labels[disambiguation_labels != -100], + ) + return 0 + + def compute_relation_loss(self, relation_labels, re_logits): + return self.compute_loss( + re_logits, relation_labels, relation_labels.view(-1) != -100 + ) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + special_symbols_mask_entities: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + disambiguation_labels: Optional[torch.Tensor] = None, + relation_labels: Optional[torch.Tensor] = None, + relation_threshold: float = 0.5, + is_validation: bool = False, + is_prediction: bool = False, + use_predefined_spans: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + batch_size = input_ids.shape[0] + + model_features = self._get_model_features( + input_ids, attention_mask, token_type_ids + ) + + # named entity detection + if use_predefined_spans: + ned_start_logits, ned_start_probabilities, ned_start_predictions = ( + None, + None, + torch.zeros_like(start_labels), + ) + ned_end_logits, ned_end_probabilities, ned_end_predictions = ( + None, + None, + torch.zeros_like(end_labels), + ) + + ned_start_predictions[start_labels > 0] = 1 + ned_end_predictions[end_labels > 0] = 1 + ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)] + ned_start_labels = start_labels + ned_start_labels[start_labels > 0] = 1 + else: + # start boundary prediction + ned_start_logits = self.ned_start_classifier(model_features) + if is_validation or is_prediction: + ned_start_logits = self._mask_logits( + ned_start_logits, prediction_mask + ) # why? + ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) + ned_start_predictions = ned_start_probabilities.argmax(dim=-1) + + # end boundary prediction + ned_start_labels = ( + torch.zeros_like(start_labels) if start_labels is not None else None + ) + + # start_labels contain entity id at their position, we just need 1 for start of entity + if ned_start_labels is not None: + ned_start_labels[start_labels == -100] = -100 + ned_start_labels[start_labels > 0] = 1 + + # compute end logits only if there are any start predictions. + # For each start prediction, n end predictions are made + ned_end_logits = self.compute_ned_end_logits( + ned_start_predictions, + ned_start_labels, + model_features, + prediction_mask, + batch_size, + True, + ) + + if ned_end_logits is not None: + # For each start prediction, n end predictions are made based on + # binary classification ie. argmax at each position. 
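
# --- Toy illustration (added by the editor, not part of this patch): the per-position
# --- binary end classification described in the comment above. Shapes are assumed for
# --- the example: one row per predicted start, two logits (no-end / end) per token.
import torch

end_logits = torch.randn(3, 10, 2)                    # (num_predicted_starts, seq_len, 2)
end_probabilities = torch.softmax(end_logits, dim=-1)
end_predictions = end_probabilities.argmax(dim=-1)    # (num_predicted_starts, seq_len) of 0/1
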
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) + ned_end_predictions = ned_end_probabilities.argmax(dim=-1) + else: + ned_end_logits, ned_end_probabilities = None, None + ned_end_predictions = torch.zeros_like(ned_start_predictions) + + if is_prediction or is_validation: + end_preds_count = ned_end_predictions.sum(1) + # If there are no end predictions for a start prediction, remove the start prediction + if (end_preds_count == 0).any() and (ned_start_predictions > 0).any(): + ned_start_predictions[ned_start_predictions == 1] = ( + end_preds_count != 0 + ).long() + ned_end_predictions = ned_end_predictions[end_preds_count != 0] + + if end_labels is not None: + end_labels = end_labels[~(end_labels == -100).all(2)] + + start_position, end_position = ( + (start_labels, end_labels) + if (not is_prediction and not is_validation) + else (ned_start_predictions, ned_end_predictions) + ) + + start_counts = (start_position > 0).sum(1) + if (start_counts > 0).any(): + ned_end_predictions = ned_end_predictions.split(start_counts.tolist()) + # limit to 30 predictions per document using start_counts, by setting all po after sum is 30 to 0 + # if is_validation or is_prediction: + # ned_start_predictions[ned_start_predictions == 1] = start_counts + # We can only predict relations if we have start and end predictions + if (end_position > 0).sum() > 0: + ends_count = (end_position > 0).sum(1) + model_subject_features = torch.cat( + [ + torch.repeat_interleave( + model_features[start_position > 0], ends_count, dim=0 + ), # start position features + torch.repeat_interleave(model_features, start_counts, dim=0)[ + end_position > 0 + ], # end position features + ], + dim=-1, + ) + ents_count = torch.nn.utils.rnn.pad_sequence( + torch.split(ends_count, start_counts.tolist()), + batch_first=True, + padding_value=0, + ).sum(1) + model_subject_features = torch.nn.utils.rnn.pad_sequence( + torch.split(model_subject_features, ents_count.tolist()), + batch_first=True, + padding_value=-100, + ) + + # if is_validation or is_prediction: + # model_subject_features = model_subject_features[:, :30, :] + + # entity disambiguation. Here relation_disambiguation_loss would only be useful to + # reduce the number of candidate relations for the next step, but currently unused. 
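
# --- Toy illustration (added by the editor, not part of this patch): how `torch.split`
# --- and `pad_sequence` are combined above to regroup a flat list of span features into
# --- a padded per-document tensor. All numbers are made up.
import torch
from torch.nn.utils.rnn import pad_sequence

flat_span_features = torch.arange(12.0).view(4, 3)   # 4 spans across the whole batch, feature dim 3
ents_count = torch.tensor([1, 3])                     # document 0 has 1 span, document 1 has 3
per_document = pad_sequence(
    torch.split(flat_span_features, ents_count.tolist()),
    batch_first=True,
    padding_value=-100,
)
print(per_document.shape)                             # torch.Size([2, 3, 3])
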
+ if self.config.entity_type_loss or self.relation_disambiguation_loss: + (re_ned_entities_logits) = self.compute_entity_logits( + model_subject_features, + model_features[ + special_symbols_mask | special_symbols_mask_entities + ].view(batch_size, -1, model_features.shape[-1]), + ) + entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item() + ned_type_logits = re_ned_entities_logits[:, :, :entity_types] + re_entities_logits = re_ned_entities_logits[:, :, entity_types:] + + if self.config.entity_type_loss: + ned_type_probabilities = torch.sigmoid(ned_type_logits) + ned_type_predictions = ned_type_probabilities.argmax(dim=-1) + + if self.config.add_entity_embedding: + special_symbols_representation = model_features[ + special_symbols_mask_entities + ].view(batch_size, entity_types, -1) + + entities_representation = torch.einsum( + "bsp,bpe->bse", + ned_type_probabilities, + special_symbols_representation, + ) + model_subject_features = torch.cat( + [model_subject_features, entities_representation], dim=-1 + ) + re_entities_probabilities = torch.sigmoid(re_entities_logits) + re_entities_predictions = re_entities_probabilities.round() + else: + ( + ned_type_logits, + ned_type_probabilities, + re_entities_logits, + re_entities_probabilities, + ) = (None, None, None, None) + ned_type_predictions, re_entities_predictions = ( + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + ) + + # Compute relation logits + re_logits = self.compute_relation_logits( + model_subject_features, + model_features[special_symbols_mask].view( + batch_size, -1, model_features.shape[-1] + ), + ) + + re_probabilities = torch.softmax(re_logits, dim=-1) + # we set a thresshold instead of argmax in cause it needs to be tweaked + re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold + # re_predictions = re_probabilities.argmax(dim=-1) + re_probabilities = re_probabilities[:, :, :, :, 1] + # re_logits, re_probabilities, re_predictions = ( + # torch.zeros( + # [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + # ).to(input_ids.device), + # torch.zeros( + # [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + # ).to(input_ids.device), + # torch.zeros( + # [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + # ).to(input_ids.device), + # ) + + else: + ( + ned_type_logits, + ned_type_probabilities, + re_entities_logits, + re_entities_probabilities, + ) = (None, None, None, None) + ned_type_predictions, re_entities_predictions = ( + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + ) + re_logits, re_probabilities, re_predictions = ( + torch.zeros( + [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + ).to(input_ids.device), + torch.zeros( + [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + ).to(input_ids.device), + torch.zeros( + [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + ).to(input_ids.device), + ) + + # output build + output_dict = dict( + batch_size=batch_size, + ned_start_logits=ned_start_logits, + ned_start_probabilities=ned_start_probabilities, + ned_start_predictions=ned_start_predictions, + ned_end_logits=ned_end_logits, + ned_end_probabilities=ned_end_probabilities, + ned_end_predictions=ned_end_predictions, + ned_type_logits=ned_type_logits, + 
ned_type_probabilities=ned_type_probabilities, + ned_type_predictions=ned_type_predictions, + re_entities_logits=re_entities_logits, + re_entities_probabilities=re_entities_probabilities, + re_entities_predictions=re_entities_predictions, + re_logits=re_logits, + re_probabilities=re_probabilities, + re_predictions=re_predictions, + ) + + if ( + start_labels is not None + and end_labels is not None + and relation_labels is not None + and is_prediction is False + ): + ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels) + end_labels[end_labels > 0] = 1 + ned_end_loss = self.compute_loss(ned_end_logits, end_labels) + if self.config.entity_type_loss or self.relation_disambiguation_loss: + ned_type_loss = self.compute_ned_type_loss( + disambiguation_labels, + re_ned_entities_logits, + ned_type_logits, + re_entities_logits, + entity_types, + (model_subject_features != -100).all(2), + ) + relation_loss = self.compute_relation_loss(relation_labels, re_logits) + # compute loss. We can skip the relation loss if we are in the first epochs (optional) + if self.config.entity_type_loss or self.relation_disambiguation_loss: + output_dict["loss"] = ( + ned_start_loss + ned_end_loss + relation_loss + ned_type_loss + ) / 4 + output_dict["ned_type_loss"] = ned_type_loss + else: + # output_dict["loss"] = ((1 / 4) * (ned_start_loss + ned_end_loss)) + ( + # (1 / 2) * relation_loss + # ) + output_dict["loss"] = ((1 / 16) * (ned_start_loss + ned_end_loss)) + ( + (7 / 8) * relation_loss + ) + + output_dict["ned_start_loss"] = ned_start_loss + output_dict["ned_end_loss"] = ned_end_loss + output_dict["re_loss"] = relation_loss + + return output_dict diff --git a/relik/reader/pytorch_modules/optim/__init__.py b/relik/reader/pytorch_modules/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..369091133267cfa05240306fbfe5ea3b537d5d9c --- /dev/null +++ b/relik/reader/pytorch_modules/optim/__init__.py @@ -0,0 +1,6 @@ +from relik.reader.pytorch_modules.optim.adamw_with_warmup import ( + AdamWWithWarmupOptimizer, +) +from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import ( + LayerWiseLRDecayOptimizer, +) diff --git a/relik/reader/pytorch_modules/optim/__pycache__/__init__.cpython-310.pyc b/relik/reader/pytorch_modules/optim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1a2120fdce7683957262bb9aca7a1d5b996aff3 Binary files /dev/null and b/relik/reader/pytorch_modules/optim/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/optim/__pycache__/adamw_with_warmup.cpython-310.pyc b/relik/reader/pytorch_modules/optim/__pycache__/adamw_with_warmup.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8b664b3a2ec9bcf31fab1713ed3549eb8e57bfe Binary files /dev/null and b/relik/reader/pytorch_modules/optim/__pycache__/adamw_with_warmup.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/optim/__pycache__/layer_wise_lr_decay.cpython-310.pyc b/relik/reader/pytorch_modules/optim/__pycache__/layer_wise_lr_decay.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4671056032df114088455e7be7a053565d7c75a6 Binary files /dev/null and b/relik/reader/pytorch_modules/optim/__pycache__/layer_wise_lr_decay.cpython-310.pyc differ diff --git a/relik/reader/pytorch_modules/optim/adamw_with_warmup.py b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py new file mode 100644 index 
0000000000000000000000000000000000000000..dfaecc4ca3d1c366f25962db4d0024a5b986fd50 --- /dev/null +++ b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py @@ -0,0 +1,66 @@ +from typing import List + +import torch +import transformers +from torch.optim import AdamW + + +class AdamWWithWarmupOptimizer: + def __init__( + self, + lr: float, + warmup_steps: int, + total_steps: int, + weight_decay: float, + no_decay_params: List[str], + ): + self.lr = lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.weight_decay = weight_decay + self.no_decay_params = no_decay_params + + def group_params(self, module: torch.nn.Module) -> list: + if self.no_decay_params is not None: + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in module.named_parameters() + if not any(nd in n for nd in self.no_decay_params) + ], + "weight_decay": self.weight_decay, + }, + { + "params": [ + p + for n, p in module.named_parameters() + if any(nd in n for nd in self.no_decay_params) + ], + "weight_decay": 0.0, + }, + ] + + else: + optimizer_grouped_parameters = [ + {"params": module.parameters(), "weight_decay": self.weight_decay} + ] + + return optimizer_grouped_parameters + + def __call__(self, module: torch.nn.Module): + optimizer_grouped_parameters = self.group_params(module) + optimizer = AdamW( + optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay + ) + scheduler = transformers.get_linear_schedule_with_warmup( + optimizer, self.warmup_steps, self.total_steps + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + }, + } diff --git a/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..d179096153f356196a921c50083c96b3dcd5f246 --- /dev/null +++ b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py @@ -0,0 +1,104 @@ +import collections +from typing import List + +import torch +import transformers +from torch.optim import AdamW + + +class LayerWiseLRDecayOptimizer: + def __init__( + self, + lr: float, + warmup_steps: int, + total_steps: int, + weight_decay: float, + lr_decay: float, + no_decay_params: List[str], + total_reset: int, + ): + self.lr = lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.weight_decay = weight_decay + self.lr_decay = lr_decay + self.no_decay_params = no_decay_params + self.total_reset = total_reset + + def group_layers(self, module) -> dict: + grouped_layers = collections.defaultdict(list) + module_named_parameters = list(module.named_parameters()) + for ln, lp in module_named_parameters: + if "embeddings" in ln: + grouped_layers["embeddings"].append((ln, lp)) + elif "encoder.layer" in ln: + layer_num = ln.split("transformer_model.encoder.layer.")[-1] + layer_num = layer_num[0 : layer_num.index(".")] + grouped_layers[layer_num].append((ln, lp)) + else: + grouped_layers["head"].append((ln, lp)) + + depth = len(grouped_layers) - 1 + final_dict = dict() + for key, value in grouped_layers.items(): + if key == "head": + final_dict[0] = value + elif key == "embeddings": + final_dict[depth] = value + else: + # -1 because layer number starts from zero + final_dict[depth - int(key) - 1] = value + + assert len(module_named_parameters) == sum( + len(v) for _, v in final_dict.items() + ) + + return final_dict + + def group_params(self, module) -> list: + optimizer_grouped_params = [] + for inverse_depth, 
layer in self.group_layers(module).items(): + layer_lr = self.lr * (self.lr_decay**inverse_depth) + layer_wd_params = { + "params": [ + lp + for ln, lp in layer + if not any(nd in ln for nd in self.no_decay_params) + ], + "weight_decay": self.weight_decay, + "lr": layer_lr, + } + layer_no_wd_params = { + "params": [ + lp + for ln, lp in layer + if any(nd in ln for nd in self.no_decay_params) + ], + "weight_decay": 0, + "lr": layer_lr, + } + + if len(layer_wd_params) != 0: + optimizer_grouped_params.append(layer_wd_params) + if len(layer_no_wd_params) != 0: + optimizer_grouped_params.append(layer_no_wd_params) + + return optimizer_grouped_params + + def __call__(self, module: torch.nn.Module): + optimizer_grouped_parameters = self.group_params(module) + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer, + self.warmup_steps, + self.total_steps, + num_cycles=self.total_reset, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + }, + } diff --git a/relik/reader/pytorch_modules/span.py b/relik/reader/pytorch_modules/span.py new file mode 100644 index 0000000000000000000000000000000000000000..fb30c66fcd5c1a698bebd90acccea58e7fe920ba --- /dev/null +++ b/relik/reader/pytorch_modules/span.py @@ -0,0 +1,370 @@ +import collections +import logging +from typing import Any, Dict, Iterator, List + +import torch +import transformers as tr +from lightning_fabric.utilities import move_data_to_device +from torch.utils.data import DataLoader, IterableDataset +from tqdm import tqdm + +from relik.common.torch_utils import get_autocast_context +from relik.common.log import get_logger +from relik.common.utils import get_callable_from_string +from relik.inference.data.objects import AnnotationType +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.pytorch_modules.base import RelikReaderBase + +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderForSpanExtraction(RelikReaderBase): + """ + A class for the RelikReader model for span extraction. + + Args: + transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): + The transformer model to use. If `None`, the default model is used. + additional_special_symbols (:obj:`int`, `optional`, defaults to 0): + The number of additional special symbols to add to the tokenizer. + num_layers (:obj:`int`, `optional`): + The number of layers to use. If `None`, all layers are used. + activation (:obj:`str`, `optional`, defaults to "gelu"): + The activation function to use. + linears_hidden_size (:obj:`int`, `optional`, defaults to 512): + The hidden size of the linears. + use_last_k_layers (:obj:`int`, `optional`, defaults to 1): + The number of last layers to use. + training (:obj:`bool`, `optional`, defaults to False): + Whether the model is in training mode. + device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`): + The device to use. If `None`, the default device is used. + tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`): + The tokenizer to use. If `None`, the default tokenizer is used. + dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`): + The dataset to use. If `None`, the default dataset is used. + dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`): + The keyword arguments to pass to the dataset class. 
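# --- Illustrative sketch (not part of the patch): using the optimizer factories ---
# The two classes defined above (AdamWWithWarmupOptimizer and
# LayerWiseLRDecayOptimizer) are factories: they are constructed with the
# schedule hyper-parameters and then *called* on a module, returning the
# {"optimizer", "lr_scheduler"} dict that Lightning's configure_optimizers
# expects. The toy module and the hyper-parameter values below are made up for
# the demo; in this project they normally come from the Hydra config
# (see cfg.model.optimizer in the training scripts later in this patch).
import torch

from relik.reader.pytorch_modules.optim import AdamWWithWarmupOptimizer

toy_module = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.LayerNorm(32),
    torch.nn.Linear(32, 2),
)

optimizer_factory = AdamWWithWarmupOptimizer(
    lr=1e-5,
    warmup_steps=100,
    total_steps=1_000,
    weight_decay=0.01,
    no_decay_params=["bias", "LayerNorm.weight"],  # name substrings excluded from weight decay
)

opt_config = optimizer_factory(toy_module)
optimizer = opt_config["optimizer"]                  # AdamW with two parameter groups
scheduler = opt_config["lr_scheduler"]["scheduler"]  # linear warmup schedule, stepped per batch

# LayerWiseLRDecayOptimizer is used the same way, but it additionally scales the
# learning rate per transformer layer, so it expects the wrapped module to expose
# "embeddings" / "encoder.layer.N" parameter names.
# -----------------------------------------------------------------------------------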
+ default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): + The default reader class to use. If `None`, the default reader class is used. + **kwargs: + Keyword arguments. + """ + + default_reader_class: str = ( + "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel" + ) + default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset" + + def __init__( + self, + transformer_model: str | tr.PreTrainedModel | None = None, + additional_special_symbols: int = 0, + num_layers: int | None = None, + activation: str = "gelu", + linears_hidden_size: int | None = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: str | torch.device | None = None, + tokenizer: str | tr.PreTrainedTokenizer | None = None, + dataset: IterableDataset | str | None = None, + dataset_kwargs: Dict[str, Any] | None = None, + default_reader_class: tr.PreTrainedModel | str | None = None, + **kwargs, + ): + super().__init__( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + device=device, + tokenizer=tokenizer, + dataset=dataset, + default_reader_class=default_reader_class, + **kwargs, + ) + # and instantiate the dataset class + self.dataset = dataset + if self.dataset is None: + self.default_data_class = get_callable_from_string(self.default_data_class) + default_data_kwargs = dict( + dataset_path=None, + materialize_samples=False, + transformer_model=self.tokenizer, + special_symbols=self.default_data_class.get_special_symbols( + self.relik_reader_model.config.additional_special_symbols + ), + for_inference=True, + use_nme=kwargs.get("use_nme", True), + ) + # merge the default data kwargs with the ones passed to the model + default_data_kwargs.update(dataset_kwargs or {}) + self.dataset = self.default_data_class(**default_data_kwargs) + + @torch.no_grad() + @torch.inference_mode() + def _read( + self, + samples: List[RelikReaderSample] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + max_length: int = 1000, + max_batch_size: int = 128, + token_batch_size: int = 2048, + precision: str = 32, + annotation_type: AnnotationType = AnnotationType.CHAR, + progress_bar: bool = False, + *args: object, + **kwargs: object, + ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: + """ + A wrapper around the forward method that returns the predicted labels for each sample. + + Args: + samples (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + input_ids (:obj:`torch.Tensor`, `optional`): + The input ids of the text. If `samples` is provided, this is ignored. + attention_mask (:obj:`torch.Tensor`, `optional`): + The attention mask of the text. If `samples` is provided, this is ignored. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. If `samples` is provided, this is ignored. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. If `samples` is provided, this is ignored. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. If `samples` is provided, this is ignored. 
+ max_length (:obj:`int`, `optional`, defaults to 1000): + The maximum length of the text. + max_batch_size (:obj:`int`, `optional`, defaults to 128): + The maximum batch size. + token_batch_size (:obj:`int`, `optional`): + The token batch size. + progress_bar (:obj:`bool`, `optional`, defaults to False): + Whether to show a progress bar. + precision (:obj:`str`, `optional`, defaults to 32): + The precision to use for the model. + annotation_type (`AnnotationType`, `optional`, defaults to `AnnotationType.CHAR`): + The type of annotation to return. If `char`, the spans will be in terms of + character offsets. If `word`, the spans will be in terms of word offsets. + *args: + Positional arguments. + **kwargs: + Keyword arguments. + + Returns: + :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`: + The predicted labels for each sample. + """ + + precision = precision or self.precision + if samples is not None: + + def _read_iterator(): + def samples_it(): + for i, sample in enumerate(samples): + assert sample._mixin_prediction_position is None + sample._mixin_prediction_position = i + if sample.spans is not None and len(sample.spans) > 0: + sample.window_labels = [[s[0], s[1], ""] for s in sample.spans] + yield sample + + next_prediction_position = 0 + position2predicted_sample = {} + + # instantiate dataset + if self.dataset is None: + raise ValueError( + "You need to pass a dataset to the model in order to predict" + ) + self.dataset.samples = samples_it() + self.dataset.model_max_length = max_length + self.dataset.tokens_per_batch = token_batch_size + self.dataset.max_batch_size = max_batch_size + + # instantiate dataloader + iterator = DataLoader( + self.dataset, batch_size=None, num_workers=0, shuffle=False + ) + if progress_bar: + iterator = tqdm(iterator, desc="Predicting with RelikReader") + + with get_autocast_context(self.device, precision): + for batch in iterator: + batch = move_data_to_device(batch, self.device) + batch.update(kwargs) + batch_out = self._batch_predict(**batch) + + for sample in batch_out: + if ( + sample._mixin_prediction_position + >= next_prediction_position + ): + position2predicted_sample[ + sample._mixin_prediction_position + ] = sample + + # yield + while next_prediction_position in position2predicted_sample: + yield position2predicted_sample[next_prediction_position] + del position2predicted_sample[next_prediction_position] + next_prediction_position += 1 + + outputs = list(_read_iterator()) + for sample in outputs: + self.dataset.merge_patches_predictions(sample) + if annotation_type == AnnotationType.CHAR: + self.dataset.convert_to_char_annotations(sample) + elif annotation_type == AnnotationType.WORD: + self.dataset.convert_to_word_annotations(sample) + else: + raise ValueError( + f"Annotation type {annotation_type} not recognized. " + f"Please choose one of {list(AnnotationType)}." 
+ ) + + else: + outputs = list( + self._batch_predict( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + *args, + **kwargs, + ) + ) + return outputs + + def _batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + sample: List[RelikReaderSample] | None = None, + top_k: int = 5, # the amount of top-k most probable entities to predict + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + """ + A wrapper around the forward method that returns the predicted labels for each sample. + It also adds the predicted labels to the samples. + + Args: + input_ids (:obj:`torch.Tensor`): + The input ids of the text. + attention_mask (:obj:`torch.Tensor`): + The attention mask of the text. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. + sample (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + top_k (:obj:`int`, `optional`, defaults to 5): + The amount of top-k most probable entities to predict. + *args: + Positional arguments. + **kwargs: + Keyword arguments. + + Returns: + The predicted labels for each sample. + """ + forward_output = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + prediction_mask=prediction_mask, + special_symbols_mask=special_symbols_mask, + *args, + **kwargs, + ) + + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() + ed_predictions = forward_output["ed_predictions"].cpu().numpy() + ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() + + batch_predictable_candidates = kwargs["predictable_candidates"] + patch_offset = kwargs["patch_offset"] + for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( + sample, + ned_start_predictions, + ned_end_predictions, + ed_predictions, + ed_probabilities, + batch_predictable_candidates, + patch_offset, + ): + ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] + ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] + + final_class2predicted_spans = collections.defaultdict(list) + spans2predicted_probabilities = dict() + for start_token_index, end_token_index in zip( + ne_start_indices, ne_end_indices + ): + # predicted candidate + token_class = edp[start_token_index + 1] - 1 + predicted_candidate_title = pred_cands[token_class] + final_class2predicted_spans[predicted_candidate_title].append( + [start_token_index, end_token_index] + ) + + # candidates probabilities + classes_probabilities = edpr[start_token_index + 1] + classes_probabilities_best_indices = classes_probabilities.argsort()[ + ::-1 + ] + titles_2_probs = [] + top_k = ( + min( + top_k, + len(classes_probabilities_best_indices), + ) + if top_k != -1 + else len(classes_probabilities_best_indices) + ) + for i in range(top_k): + titles_2_probs.append( + ( + pred_cands[classes_probabilities_best_indices[i] - 1], + classes_probabilities[ + classes_probabilities_best_indices[i] + ].item(), + ) + ) + spans2predicted_probabilities[ + (start_token_index, end_token_index) + ] = 
titles_2_probs + + if "patches" not in ts._d: + ts._d["patches"] = dict() + + ts._d["patches"][po] = dict() + sample_patch = ts._d["patches"][po] + + sample_patch["predicted_window_labels"] = final_class2predicted_spans + sample_patch["span_title_probabilities"] = spans2predicted_probabilities + + # additional info + sample_patch["predictable_candidates"] = pred_cands + + # try-out for a new format + sample_patch["predicted_spans"] = final_class2predicted_spans + sample_patch[ + "predicted_spans_probabilities" + ] = spans2predicted_probabilities + + yield ts diff --git a/relik/reader/pytorch_modules/triplet.py b/relik/reader/pytorch_modules/triplet.py new file mode 100644 index 0000000000000000000000000000000000000000..959b9c1b1952d2bbdff7354fa040209448f50962 --- /dev/null +++ b/relik/reader/pytorch_modules/triplet.py @@ -0,0 +1,396 @@ +import contextlib +import logging +from typing import Any, Dict, Iterator, List + +import numpy as np +import torch +import transformers as tr +from lightning_fabric.utilities import move_data_to_device +from torch.utils.data import DataLoader, IterableDataset +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.common.torch_utils import get_autocast_context +from relik.common.utils import get_callable_from_string +from relik.inference.data.objects import AnnotationType +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.pytorch_modules.base import RelikReaderBase +from relik.retriever.pytorch_modules import PRECISION_MAP + +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderForTripletExtraction(RelikReaderBase): + """ + A class for the RelikReader model for triplet extraction. + + Args: + transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): + The transformer model to use. If `None`, the default model is used. + additional_special_symbols (:obj:`int`, `optional`, defaults to 0): + The number of additional special symbols to add to the tokenizer. + num_layers (:obj:`int`, `optional`): + The number of layers to use. If `None`, all layers are used. + activation (:obj:`str`, `optional`, defaults to "gelu"): + The activation function to use. + linears_hidden_size (:obj:`int`, `optional`, defaults to 512): + The hidden size of the linears. + use_last_k_layers (:obj:`int`, `optional`, defaults to 1): + The number of last layers to use. + training (:obj:`bool`, `optional`, defaults to False): + Whether the model is in training mode. + device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`): + The device to use. If `None`, the default device is used. + tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`): + The tokenizer to use. If `None`, the default tokenizer is used. + dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`): + The dataset to use. If `None`, the default dataset is used. + dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`): + The keyword arguments to pass to the dataset class. + default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): + The default reader class to use. If `None`, the default reader class is used. + **kwargs: + Keyword arguments. 
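# --- Illustrative sketch (not part of the patch): using RelikReaderForSpanExtraction ---
# End-to-end use of the span reader defined in span.py above: load a trained
# reader, read samples from a JSONL dataset, and inspect the predictions. The
# model and dataset paths are placeholders; the call pattern mirrors
# relik/reader/trainer/predict.py later in this patch.
from relik.reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction

reader = RelikReaderForSpanExtraction(
    "path/to/hf_model",  # placeholder: a directory produced by save_pretrained
    training=False,
    device="cuda",
)

samples = list(load_relik_reader_samples("path/to/dataset.jsonl"))  # placeholder path
predicted_samples = reader.read(samples=samples, progress_bar=True)

for sample in predicted_samples:
    # _batch_predict above patches its decoded spans onto each sample; the exact
    # attribute names depend on merge_patches_predictions, e.g. predicted_spans
    # or predicted_window_labels
    print(getattr(sample, "predicted_spans", None))
# ----------------------------------------------------------------------------------------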
+ """ + + default_reader_class: str = ( + "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderREModel" + ) + default_data_class: str = "relik.reader.data.relik_reader_re_data.RelikREDataset" + + def __init__( + self, + transformer_model: str | tr.PreTrainedModel | None = None, + additional_special_symbols: int = 0, + additional_special_symbols_types: int = 0, + entity_type_loss: bool | None = None, + add_entity_embedding: bool | None = None, + num_layers: int | None = None, + activation: str = "gelu", + linears_hidden_size: int | None = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: str | torch.device | None = None, + tokenizer: str | tr.PreTrainedTokenizer | None = None, + dataset: IterableDataset | str | None = None, + dataset_kwargs: Dict[str, Any] | None = None, + default_reader_class: tr.PreTrainedModel | str | None = None, + **kwargs, + ): + super().__init__( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + additional_special_symbols_types=additional_special_symbols_types, + entity_type_loss=entity_type_loss, + add_entity_embedding=add_entity_embedding, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + device=device, + tokenizer=tokenizer, + dataset=dataset, + default_reader_class=default_reader_class, + **kwargs, + ) + # and instantiate the dataset class + self.dataset = dataset + if self.dataset is None and training is False: + self.default_data_class = get_callable_from_string(self.default_data_class) + default_data_kwargs = dict( + dataset_path=None, + materialize_samples=False, + transformer_model=self.tokenizer, + special_symbols=self.default_data_class.get_special_symbols_re( + self.relik_reader_model.config.additional_special_symbols, + use_nme=kwargs.get("use_nme_re", False), + ), + special_symbols_types=self.default_data_class.get_special_symbols( + self.relik_reader_model.config.additional_special_symbols_types - 1 + ) + if self.relik_reader_model.config.additional_special_symbols_types > 0 + else [], + for_inference=True, + use_nme=kwargs.get("use_nme", False), + ) + # merge the default data kwargs with the ones passed to the model + default_data_kwargs.update(dataset_kwargs or {}) + self.dataset = self.default_data_class(**default_data_kwargs) + + @torch.no_grad() + @torch.inference_mode() + def _read( + self, + samples: List[RelikReaderSample] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + max_length: int = 2048, + max_batch_size: int = 128, + token_batch_size: int = 2048, + precision: str = 32, + annotation_type: AnnotationType = AnnotationType.CHAR, + progress_bar: bool = False, + *args: object, + **kwargs: object, + ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: + """ + A wrapper around the forward method that returns the predicted labels for each sample. + + Args: + samples (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + input_ids (:obj:`torch.Tensor`, `optional`): + The input ids of the text. If `samples` is provided, this is ignored. + attention_mask (:obj:`torch.Tensor`, `optional`): + The attention mask of the text. If `samples` is provided, this is ignored. 
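# --- Illustrative sketch (not part of the patch): using RelikReaderForTripletExtraction ---
# Typical use of the triplet reader whose __init__ is defined just above: wrap a
# pretrained RelikReaderREModel and read samples to obtain entities and relation
# triplets. Paths are placeholders; the call pattern mirrors
# relik/reader/trainer/predict_re.py later in this patch.
from relik.reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderREModel
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction

model = RelikReaderREModel.from_pretrained("path/to/relik_reader_re_model")  # placeholder
reader = RelikReaderForTripletExtraction(model, training=False, device="cuda")

samples = list(load_relik_reader_samples("path/to/test.jsonl"))  # placeholder
predicted_samples = reader.read(samples=samples, progress_bar=True)

for sample in predicted_samples:
    # _batch_predict below stores its decoded output on the sample's _d dict
    print(sample._d.get("predicted_entities"), sample._d.get("predicted_triples"))
# ------------------------------------------------------------------------------------------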
+ token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. If `samples` is provided, this is ignored. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. If `samples` is provided, this is ignored. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. If `samples` is provided, this is ignored. + max_length (:obj:`int`, `optional`, defaults to 1000): + The maximum length of the text. + max_batch_size (:obj:`int`, `optional`, defaults to 128): + The maximum batch size. + token_batch_size (:obj:`int`, `optional`): + The token batch size. + progress_bar (:obj:`bool`, `optional`, defaults to False): + Whether to show a progress bar. + precision (:obj:`str`, `optional`, defaults to 32): + The precision to use for the model. + annotation_type (`AnnotationType`, `optional`, defaults to `AnnotationType.CHAR`): + The type of annotation to return. If `char`, the spans will be in terms of + character offsets. If `word`, the spans will be in terms of word offsets. + *args: + Positional arguments. + **kwargs: + Keyword arguments. + + Returns: + :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`: + The predicted labels for each sample. + """ + + precision = precision or self.precision + if samples is not None: + + def _read_iterator(): + def samples_it(): + for i, sample in enumerate(samples): + assert sample._mixin_prediction_position is None + sample._mixin_prediction_position = i + if sample.spans is not None and len(sample.spans) > 0: + entities = [] + offset_span = sample.char2token_start[str(sample.offset)] + for span_start, span_end in sample.spans: + if str(span_start) not in sample.char2token_start: + # span_start is in the middle of a word + # retrieve the first token of the word + while str(span_start) not in sample.char2token_start: + span_start -= 1 + # skip + if span_start < 0: + break + if str(span_end) not in sample.char2token_end: + # span_end is in the middle of a word + # retrieve the last token of the word + while str(span_end) not in sample.char2token_end: + span_end += 1 + # skip + if span_end >= int(list(sample.char2token_end.keys())[-1]): + break + + if span_start < 0 or span_end > int(list(sample.char2token_end.keys())[-1]): + continue + entities.append([sample.char2token_start[str(span_start)]-offset_span, sample.char2token_end[str(span_end)]+1-offset_span, ""]) + sample.entities = entities + yield sample + + next_prediction_position = 0 + position2predicted_sample = {} + + # instantiate dataset + if self.dataset is None: + raise ValueError( + "You need to pass a dataset to the model in order to predict" + ) + self.dataset.samples = samples_it() + self.dataset.model_max_length = max_length + self.dataset.tokens_per_batch = token_batch_size + self.dataset.max_batch_size = max_batch_size + + # instantiate dataloader + iterator = DataLoader( + self.dataset, batch_size=None, num_workers=0, shuffle=False + ) + if progress_bar: + iterator = tqdm(iterator, desc="Predicting with RelikReader") + + with get_autocast_context(self.device, precision): + for batch in iterator: + batch = move_data_to_device(batch, self.device) + batch.update(kwargs) + batch_out = self._batch_predict(**batch) + + for sample in batch_out: + if ( + sample._mixin_prediction_position + >= next_prediction_position + ): + position2predicted_sample[ + sample._mixin_prediction_position + ] = sample + + # yield + while next_prediction_position in position2predicted_sample: + yield 
position2predicted_sample[next_prediction_position] + del position2predicted_sample[next_prediction_position] + next_prediction_position += 1 + + outputs = list(_read_iterator()) + for sample in outputs: + self.dataset.merge_patches_predictions(sample) + if annotation_type == AnnotationType.CHAR: + self.dataset.convert_to_char_annotations(sample) + elif annotation_type == AnnotationType.WORD: + self.dataset.convert_to_word_annotations(sample) + else: + raise ValueError( + f"Annotation type {annotation_type} not recognized. " + f"Please choose one of {list(AnnotationType)}." + ) + + else: + outputs = list( + self._batch_predict( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + *args, + **kwargs, + ) + ) + return outputs + + def _batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + special_symbols_mask_entities: torch.Tensor | None = None, + sample: List[RelikReaderSample] | None = None, + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + """ + A wrapper around the forward method that returns the predicted labels for each sample. + It also adds the predicted labels to the samples. + + Args: + input_ids (:obj:`torch.Tensor`): + The input ids of the text. + attention_mask (:obj:`torch.Tensor`): + The attention mask of the text. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. + sample (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + top_k (:obj:`int`, `optional`, defaults to 5): + The amount of top-k most probable entities to predict. + *args: + Positional arguments. + **kwargs: + Keyword arguments. + + Returns: + The predicted labels for each sample. 
+ """ + forward_output = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + prediction_mask=prediction_mask, + special_symbols_mask=special_symbols_mask, + special_symbols_mask_entities=special_symbols_mask_entities, + is_prediction=True, + *args, + **kwargs, + ) + + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"] # .cpu().numpy() + ed_predictions = forward_output["re_entities_predictions"].cpu().numpy() + ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy() + re_predictions = forward_output["re_predictions"].cpu().numpy() + re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy() + + for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip( + sample, + ned_start_predictions, + ned_end_predictions, + re_predictions, + re_probabilities, + ed_predictions, + ned_type_predictions, + ): + ne_end = ne_end.cpu().numpy() + entities = [] + if self.relik_reader_model.config.entity_type_loss: + starts = np.argwhere(ne_st) + i = 0 + for start, end in zip(starts, ne_end): + ends = np.argwhere(end) + for e in ends: + entities.append([start[0], e[0], ne_et[i]]) + i += 1 + # if i == len(ne_et): + # break + # if i == len(ne_et): + # break + else: + starts = np.argwhere(ne_st) + for start, end in zip(starts, ne_end): + ends = np.argwhere(end) + for e in ends: + entities.append([start[0], e[0]]) + + edp = edp[: len(entities)] + re_pred = re_pred[: len(entities), : len(entities)] + re_prob = re_prob[: len(entities), : len(entities)] + possible_re = np.argwhere(re_pred) + predicted_triplets = [] + predicted_triplets_prob = [] + for i, j, r in possible_re: + if self.relik_reader_model.relation_disambiguation_loss: + if not ( + i != j + and edp[i, r] == 1 + and edp[j, r] == 1 + and edp[i, 0] == 0 + and edp[j, 0] == 0 + ): + continue + predicted_triplets.append([i, j, r]) + predicted_triplets_prob.append(re_prob[i, j, r]) + + ts._d["predicted_relations"] = predicted_triplets + ts._d["predicted_entities"] = entities + ts._d["predicted_relations_probabilities"] = predicted_triplets_prob + + # try-out for a new format + ts._d["predicted_triples"] = predicted_triplets + + yield ts diff --git a/relik/reader/relik_reader_core.py b/relik/reader/relik_reader_core.py new file mode 100644 index 0000000000000000000000000000000000000000..4e2c3bfe759c75a34c4c829077bda2414c8d4be7 --- /dev/null +++ b/relik/reader/relik_reader_core.py @@ -0,0 +1,497 @@ +import collections +from typing import Any, Dict, Iterator, List, Optional + +import torch +from transformers import AutoModel +from transformers.activations import ClippedGELUActivation, GELUActivation +from transformers.modeling_utils import PoolerEndLogits + +from relik.reader.data.relik_reader_sample import RelikReaderSample + +activation2functions = { + "relu": torch.nn.ReLU(), + "gelu": GELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), +} + + +class RelikReaderCoreModel(torch.nn.Module): + def __init__( + self, + transformer_model: str, + additional_special_symbols: int, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + ) -> None: + super().__init__() + + # Transformer model declaration + self.transformer_model_name = transformer_model + self.transformer_model = ( + AutoModel.from_pretrained(transformer_model) + if num_layers is None + else 
AutoModel.from_pretrained( + transformer_model, num_hidden_layers=num_layers + ) + ) + # self.transformer_model.resize_token_embeddings( + # self.transformer_model.config.vocab_size + additional_special_symbols + # ) + + self.activation = activation + self.linears_hidden_size = linears_hidden_size + self.use_last_k_layers = use_last_k_layers + + # named entity detection layers + self.ned_start_classifier = self._get_projection_layer( + self.activation, last_hidden=2, layer_norm=False + ) + self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) + + # END entity disambiguation layer + self.ed_start_projector = self._get_projection_layer(self.activation) + self.ed_end_projector = self._get_projection_layer(self.activation) + + self.training = training + + # criterion + self.criterion = torch.nn.CrossEntropyLoss() + + def _get_projection_layer( + self, + activation: str, + last_hidden: Optional[int] = None, + input_hidden=None, + layer_norm: bool = True, + ) -> torch.nn.Sequential: + head_components = [ + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.transformer_model.config.hidden_size * self.use_last_k_layers + if input_hidden is None + else input_hidden, + self.linears_hidden_size, + ), + activation2functions[activation], + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.linears_hidden_size, + self.linears_hidden_size if last_hidden is None else last_hidden, + ), + ] + + if layer_norm: + head_components.append( + torch.nn.LayerNorm( + self.linears_hidden_size if last_hidden is None else last_hidden, + self.transformer_model.config.layer_norm_eps, + ) + ) + + return torch.nn.Sequential(*head_components) + + def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + mask = mask.unsqueeze(-1) + if next(self.parameters()).dtype == torch.float16: + logits = logits * (1 - mask) - 65500 * mask + else: + logits = logits * (1 - mask) - 1e30 * mask + return logits + + def _get_model_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor], + ): + model_input = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": self.use_last_k_layers > 1, + } + + if token_type_ids is not None: + model_input["token_type_ids"] = token_type_ids + + model_output = self.transformer_model(**model_input) + + if self.use_last_k_layers > 1: + model_features = torch.cat( + model_output[1][-self.use_last_k_layers :], dim=-1 + ) + else: + model_features = model_output[0] + + return model_features + + def compute_ned_end_logits( + self, + start_predictions, + start_labels, + model_features, + prediction_mask, + batch_size, + ) -> Optional[torch.Tensor]: + # todo: maybe when constraining on the spans, + # we should not use a prediction_mask for the end tokens. 
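# --- Illustrative sketch (not part of the patch): concatenating the last k layers ---
# _get_model_features above requests output_hidden_states when
# use_last_k_layers > 1 and concatenates the last k hidden layers along the
# feature dimension, which is why the projection heads are sized
# hidden_size * use_last_k_layers. A tiny, self-contained demo of that operation
# with fake hidden states; all sizes are arbitrary.
import torch

num_layers, batch_size, seq_len, hidden_size = 12, 2, 8, 16
use_last_k_layers = 3

# stand-in for model_output.hidden_states: one tensor per layer (plus embeddings)
hidden_states = tuple(
    torch.randn(batch_size, seq_len, hidden_size) for _ in range(num_layers + 1)
)

model_features = torch.cat(hidden_states[-use_last_k_layers:], dim=-1)
assert model_features.shape == (batch_size, seq_len, hidden_size * use_last_k_layers)
# -------------------------------------------------------------------------------------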
+ # at least we should not during training imo + start_positions = start_labels if self.training else start_predictions + start_positions_indices = ( + torch.arange(start_positions.size(1), device=start_positions.device) + .unsqueeze(0) + .expand(batch_size, -1)[start_positions > 0] + ).to(start_positions.device) + + if len(start_positions_indices) > 0: + expanded_features = torch.cat( + [ + model_features[i].unsqueeze(0).expand(x, -1, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(start_positions_indices.device) + + expanded_prediction_mask = torch.cat( + [ + prediction_mask[i].unsqueeze(0).expand(x, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(expanded_features.device) + + end_logits = self.ned_end_classifier( + hidden_states=expanded_features, + start_positions=start_positions_indices, + p_mask=expanded_prediction_mask, + ) + + return end_logits + + return None + + def compute_classification_logits( + self, + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_positions=None, + end_positions=None, + ) -> torch.Tensor: + if start_positions is None or end_positions is None: + start_positions = torch.zeros_like(prediction_mask) + end_positions = torch.zeros_like(prediction_mask) + + model_start_features = self.ed_start_projector(model_features) + model_end_features = self.ed_end_projector(model_features) + model_end_features[start_positions > 0] = model_end_features[end_positions > 0] + + model_ed_features = torch.cat( + [model_start_features, model_end_features], dim=-1 + ) + + # computing ed features + classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item() + special_symbols_representation = model_ed_features[special_symbols_mask].view( + batch_size, classes_representations, -1 + ) + + logits = torch.bmm( + model_ed_features, + torch.permute(special_symbols_representation, (0, 2, 1)), + ) + + logits = self._mask_logits(logits, prediction_mask) + + return logits + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + use_predefined_spans: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + batch_size, seq_len = input_ids.shape + + model_features = self._get_model_features( + input_ids, attention_mask, token_type_ids + ) + + # named entity detection if required + if use_predefined_spans: # no need to compute spans + ned_start_logits, ned_start_probabilities, ned_start_predictions = ( + None, + None, + torch.clone(start_labels) + if start_labels is not None + else torch.zeros_like(input_ids), + ) + ned_end_logits, ned_end_probabilities, ned_end_predictions = ( + None, + None, + torch.clone(end_labels) + if end_labels is not None + else torch.zeros_like(input_ids), + ) + + ned_start_predictions[ned_start_predictions > 0] = 1 + ned_end_predictions[ned_end_predictions > 0] = 1 + + else: # compute spans + # start boundary prediction + ned_start_logits = self.ned_start_classifier(model_features) + ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) + ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) + ned_start_predictions = ned_start_probabilities.argmax(dim=-1) + + # end boundary prediction + ned_start_labels = ( + 
torch.zeros_like(start_labels) if start_labels is not None else None + ) + + if ned_start_labels is not None: + ned_start_labels[start_labels == -100] = -100 + ned_start_labels[start_labels > 0] = 1 + + ned_end_logits = self.compute_ned_end_logits( + ned_start_predictions, + ned_start_labels, + model_features, + prediction_mask, + batch_size, + ) + + if ned_end_logits is not None: + ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) + ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) + else: + ned_end_logits, ned_end_probabilities = None, None + ned_end_predictions = ned_start_predictions.new_zeros(batch_size) + + # flattening end predictions + # (flattening can happen only if the + # end boundaries were not predicted using the gold labels) + if not self.training: + flattened_end_predictions = torch.clone(ned_start_predictions) + flattened_end_predictions[flattened_end_predictions > 0] = 0 + + batch_start_predictions = list() + for elem_idx in range(batch_size): + batch_start_predictions.append( + torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist() + ) + + # check that the total number of start predictions + # is equal to the end predictions + total_start_predictions = sum(map(len, batch_start_predictions)) + total_end_predictions = len(ned_end_predictions) + assert ( + total_start_predictions == 0 + or total_start_predictions == total_end_predictions + ), ( + f"Total number of start predictions = {total_start_predictions}. " + f"Total number of end predictions = {total_end_predictions}" + ) + + curr_end_pred_num = 0 + for elem_idx, bsp in enumerate(batch_start_predictions): + for sp in bsp: + ep = ned_end_predictions[curr_end_pred_num].item() + if ep < sp: + ep = sp + + # if we already set this span throw it (no overlap) + if flattened_end_predictions[elem_idx, ep] == 1: + ned_start_predictions[elem_idx, sp] = 0 + else: + flattened_end_predictions[elem_idx, ep] = 1 + + curr_end_pred_num += 1 + + ned_end_predictions = flattened_end_predictions + + start_position, end_position = ( + (start_labels, end_labels) + if self.training + else (ned_start_predictions, ned_end_predictions) + ) + + # Entity disambiguation + ed_logits = self.compute_classification_logits( + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_position, + end_position, + ) + ed_probabilities = torch.softmax(ed_logits, dim=-1) + ed_predictions = torch.argmax(ed_probabilities, dim=-1) + + # output build + output_dict = dict( + batch_size=batch_size, + ned_start_logits=ned_start_logits, + ned_start_probabilities=ned_start_probabilities, + ned_start_predictions=ned_start_predictions, + ned_end_logits=ned_end_logits, + ned_end_probabilities=ned_end_probabilities, + ned_end_predictions=ned_end_predictions, + ed_logits=ed_logits, + ed_probabilities=ed_probabilities, + ed_predictions=ed_predictions, + ) + + # compute loss if labels + if start_labels is not None and end_labels is not None and self.training: + # named entity detection loss + + # start + if ned_start_logits is not None: + ned_start_loss = self.criterion( + ned_start_logits.view(-1, ned_start_logits.shape[-1]), + ned_start_labels.view(-1), + ) + else: + ned_start_loss = 0 + + # end + if ned_end_logits is not None: + ned_end_labels = torch.zeros_like(end_labels) + ned_end_labels[end_labels == -100] = -100 + ned_end_labels[end_labels > 0] = 1 + + ned_end_loss = self.criterion( + ned_end_logits, + ( + torch.arange( + ned_end_labels.size(1), device=ned_end_labels.device + ) + .unsqueeze(0) + 
.expand(batch_size, -1)[ned_end_labels > 0] + ).to(ned_end_labels.device), + ) + + else: + ned_end_loss = 0 + + # entity disambiguation loss + start_labels[ned_start_labels != 1] = -100 + ed_labels = torch.clone(start_labels) + ed_labels[end_labels > 0] = end_labels[end_labels > 0] + ed_loss = self.criterion( + ed_logits.view(-1, ed_logits.shape[-1]), + ed_labels.view(-1), + ) + + output_dict["ned_start_loss"] = ned_start_loss + output_dict["ned_end_loss"] = ned_end_loss + output_dict["ed_loss"] = ed_loss + + output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss + + return output_dict + + def batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + sample: Optional[List[RelikReaderSample]] = None, + top_k: int = 5, # the amount of top-k most probable entities to predict + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + forward_output = self.forward( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + ) + + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() + ed_predictions = forward_output["ed_predictions"].cpu().numpy() + ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() + + batch_predictable_candidates = kwargs["predictable_candidates"] + patch_offset = kwargs["patch_offset"] + for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( + sample, + ned_start_predictions, + ned_end_predictions, + ed_predictions, + ed_probabilities, + batch_predictable_candidates, + patch_offset, + ): + ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] + ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] + + final_class2predicted_spans = collections.defaultdict(list) + spans2predicted_probabilities = dict() + for start_token_index, end_token_index in zip( + ne_start_indices, ne_end_indices + ): + # predicted candidate + token_class = edp[start_token_index + 1] - 1 + predicted_candidate_title = pred_cands[token_class] + final_class2predicted_spans[predicted_candidate_title].append( + [start_token_index, end_token_index] + ) + + # candidates probabilities + classes_probabilities = edpr[start_token_index + 1] + classes_probabilities_best_indices = classes_probabilities.argsort()[ + ::-1 + ] + titles_2_probs = [] + top_k = ( + min( + top_k, + len(classes_probabilities_best_indices), + ) + if top_k != -1 + else len(classes_probabilities_best_indices) + ) + for i in range(top_k): + titles_2_probs.append( + ( + pred_cands[classes_probabilities_best_indices[i] - 1], + classes_probabilities[ + classes_probabilities_best_indices[i] + ].item(), + ) + ) + spans2predicted_probabilities[ + (start_token_index, end_token_index) + ] = titles_2_probs + + if "patches" not in ts._d: + ts._d["patches"] = dict() + + ts._d["patches"][po] = dict() + sample_patch = ts._d["patches"][po] + + sample_patch["predicted_window_labels"] = final_class2predicted_spans + sample_patch["span_title_probabilities"] = spans2predicted_probabilities + + # additional info + sample_patch["predictable_candidates"] = pred_cands + + yield ts diff --git a/relik/reader/relik_reader_predictor.py b/relik/reader/relik_reader_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..456cba867c59d65bce5b4ce183a351b8b09f314a --- /dev/null +++ 
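# --- Illustrative sketch (not part of the patch): top-k candidate ranking ---
# batch_predict above ranks the candidate entities of a predicted span by sorting
# its probability row in descending order and keeping the top-k titles; the "- 1"
# offset maps class indices back to candidate titles, the extra leading class
# presumably being the NME / no-entity prediction. A small numpy demo of that
# ranking with made-up probabilities and candidate names.
import numpy as np

pred_cands = ["Rome", "Rome_(film)", "A.S._Roma"]           # made-up candidate titles
classes_probabilities = np.array([0.05, 0.70, 0.05, 0.20])  # index 0: assumed NME/none class

top_k = 2
best_indices = classes_probabilities.argsort()[::-1]  # descending order
titles_2_probs = [
    (pred_cands[i - 1], float(classes_probabilities[i]))
    for i in best_indices[:top_k]
]
print(titles_2_probs)  # [('Rome', 0.7), ('A.S._Roma', 0.2)]
# -----------------------------------------------------------------------------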
b/relik/reader/relik_reader_predictor.py @@ -0,0 +1,177 @@ +import logging +from typing import Iterable, Iterator, List, Optional + +import hydra +import torch +from lightning.pytorch.utilities import move_data_to_device +from torch.utils.data import DataLoader, IterableDataset +from tqdm import tqdm + +from relik.reader.data.patches import merge_patches_predictions +from relik.reader.data.relik_reader_sample import ( + RelikReaderSample, + load_relik_reader_samples, +) +from relik.reader.relik_reader_core import RelikReaderCoreModel +from relik.reader.utils.special_symbols import NME_SYMBOL + +logger = logging.getLogger(__name__) + + +def convert_tokens_to_char_annotations( + sample: RelikReaderSample, remove_nmes: bool = False +): + char_annotations = set() + + for ( + predicted_entity, + predicted_spans, + ) in sample.predicted_window_labels.items(): + if predicted_entity == NME_SYMBOL and remove_nmes: + continue + + for span_start, span_end in predicted_spans: + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + + char_annotations.add((span_start, span_end, predicted_entity)) + + char_probs_annotations = dict() + for ( + span_start, + span_end, + ), candidates_probs in sample.span_title_probabilities.items(): + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + char_probs_annotations[(span_start, span_end)] = { + title for title, _ in candidates_probs + } + + sample.predicted_window_labels_chars = char_annotations + sample.probs_window_labels_chars = char_probs_annotations + + +class RelikReaderPredictor: + def __init__( + self, + relik_reader_core: RelikReaderCoreModel, + dataset_conf: Optional[dict] = None, + predict_nmes: bool = False, + dataloader: Optional[DataLoader] = None, + ) -> None: + self.relik_reader_core = relik_reader_core + self.dataset_conf = dataset_conf + self.predict_nmes = predict_nmes + self.dataloader: DataLoader | None = dataloader + + if self.dataset_conf is not None and self.dataset is None: + # instantiate dataset + self.dataset = hydra.utils.instantiate( + dataset_conf, + dataset_path=None, + samples=None, + ) + + def predict( + self, + path: Optional[str], + samples: Optional[Iterable[RelikReaderSample]], + dataset_conf: Optional[dict], + token_batch_size: int = 1024, + progress_bar: bool = False, + **kwargs, + ) -> List[RelikReaderSample]: + annotated_samples = list( + self._predict(path, samples, dataset_conf, token_batch_size, progress_bar) + ) + for sample in annotated_samples: + merge_patches_predictions(sample) + convert_tokens_to_char_annotations( + sample, remove_nmes=not self.predict_nmes + ) + return annotated_samples + + def _predict( + self, + path: Optional[str], + samples: Optional[Iterable[RelikReaderSample]], + dataset_conf: dict, + token_batch_size: int = 1024, + progress_bar: bool = False, + **kwargs, + ) -> Iterator[RelikReaderSample]: + assert ( + path is not None or samples is not None + ), "Either predict on a path or on an iterable of samples" + + next_prediction_position = 0 + position2predicted_sample = {} + + if self.dataloader is not None: + iterator = self.dataloader + for i, sample in enumerate(self.dataloader.dataset.samples): + sample._mixin_prediction_position = i + else: + samples = load_relik_reader_samples(path) if samples is None else samples + + # setup infrastructure to re-yield in order + def samples_it(): + for i, sample in enumerate(samples): + assert sample._mixin_prediction_position is None + 
sample._mixin_prediction_position = i + yield sample + + # instantiate dataset + if getattr(self, "dataset", None) is not None: + dataset = self.dataset + dataset.samples = samples_it() + dataset.tokens_per_batch = token_batch_size + else: + dataset = hydra.utils.instantiate( + dataset_conf, + dataset_path=None, + samples=samples_it(), + tokens_per_batch=token_batch_size, + ) + + # instantiate dataloader + iterator = DataLoader( + dataset, batch_size=None, num_workers=0, shuffle=False + ) + if progress_bar: + iterator = tqdm(iterator, desc="Predicting") + + model_device = next(self.relik_reader_core.parameters()).device + + with torch.inference_mode(): + for batch in iterator: + # do batch predict + with torch.autocast( + "cpu" if model_device == torch.device("cpu") else "cuda" + ): + batch = move_data_to_device(batch, model_device) + batch_out = self.relik_reader_core._batch_predict(**batch) + # update prediction position + for sample in batch_out: + if sample._mixin_prediction_position >= next_prediction_position: + position2predicted_sample[ + sample._mixin_prediction_position + ] = sample + + # yield + while next_prediction_position in position2predicted_sample: + yield position2predicted_sample[next_prediction_position] + del position2predicted_sample[next_prediction_position] + next_prediction_position += 1 + + if len(position2predicted_sample) > 0: + logger.warning( + "It seems samples have been discarded in your dataset. " + "This means that you WON'T have a prediction for each input sample. " + "Prediction order will also be partially disrupted" + ) + for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]): + yield v + + if progress_bar: + iterator.close() diff --git a/relik/reader/trainer/__init__.py b/relik/reader/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/trainer/predict-reader.py b/relik/reader/trainer/predict-reader.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc9078495016e124fd745427f461a938b00d574 --- /dev/null +++ b/relik/reader/trainer/predict-reader.py @@ -0,0 +1,58 @@ +from relik.retriever import GoldenRetriever + +from relik.retriever.indexers.inmemory import InMemoryDocumentIndex +from relik.retriever.indexers.document import DocumentStore +from relik.retriever import GoldenRetriever +from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction +from relik.reader.utils.strong_matching_eval import StrongMatching +from relik.reader.data.relik_reader_data import RelikDataset + +from relik.inference.annotator import Relik +from relik.inference.data.objects import ( + AnnotationType, + RelikOutput, + Span, + TaskType, + Triples, +) + +def load_model(): + + retriever = GoldenRetriever( + question_encoder="/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/question_encoder", + document_index=InMemoryDocumentIndex( + documents=DocumentStore.from_file( + "/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/document_index/documents.jsonl" + ), + metadata_fields=["definition"], + separator=' ', + device="cuda" + ), + device="cuda" + + ) + retriever.index() + + reader = RelikReaderForSpanExtraction("/home/carlos/amr-parsing-master/sentence-similarity/relik-main/experiments/relik-reader-deberta-small-io/2024-04-26/12-56-49/wandb/run-20240426_125654-vfznbu4r/files/hf_model/hf_model", +
dataset_kwargs={"use_nme": True}) + + relik = Relik(reader=reader, retriever=retriever, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu") + + relik() + + val_dataset: RelikDataset = hydra.utils.instantiate( + cfg.data.val_dataset, + dataset_path=to_absolute_path(cfg.data.val_dataset_path), + ) + + predicted_samples = relik.predict( + dataset_path, token_batch_size=token_batch_size + ) + + eval_dict = StrongMatching()(predicted_samples) + pprint(eval_dict) + + if output_path is not None: + with open(output_path, "w") as f: + for sample in predicted_samples: + f.write(sample.to_jsons() + "\n") diff --git a/relik/reader/trainer/predict.py b/relik/reader/trainer/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..667aaa4907075a32a78df7707c686fec59c43818 --- /dev/null +++ b/relik/reader/trainer/predict.py @@ -0,0 +1,70 @@ +import argparse +from pprint import pprint +from typing import Optional + +from relik.reader.relik_reader_predictor import RelikReaderPredictor +from relik.reader.utils.strong_matching_eval import StrongMatching +from relik.reader.relik_reader_core import RelikReaderCoreModel +from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction +import hydra +from omegaconf import DictConfig +from relik.reader.data.relik_reader_sample import load_relik_reader_samples +import json +# @hydra.main(config_path="config.yaml", config_name="") # Specify your config path and name here +def predict( + model_path: str, + dataset_path: str, + token_batch_size: int, + is_eval: bool, + output_path: Optional[str], +) -> None: + relik_reader = RelikReaderForSpanExtraction(model_path,training=False, device="cuda") + samples = list(load_relik_reader_samples(dataset_path)) + predicted_samples = relik_reader.read( + samples=samples, progress_bar=True + ) + if True: + eval_dict = StrongMatching()(predicted_samples) + pprint(eval_dict) + if output_path is not None: + with open(output_path, "w") as f: + gold_text = "" + for sample in predicted_samples: + text = sample.to_jsons() + # json.dump(text, f) + # f.write("\n") + gold_text += str(text["window_labels"]) + "\t" + str(text["predicted_window_labels"]) + "\n" + + f.write(gold_text) + + + +def parse_arg() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", + required=True, + ) + parser.add_argument("--dataset-path", "-i", required=True) + parser.add_argument("--is-eval", action="store_true") + parser.add_argument( + "--output-path", + "-o", + ) + parser.add_argument("--token-batch-size", default=4096) + return parser.parse_args() + + +def main(): + args = parse_arg() + predict( + args.model_path, + args.dataset_path, + token_batch_size=args.token_batch_size, + is_eval=args.is_eval, + output_path=args.output_path, + ) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/predict_re.py b/relik/reader/trainer/predict_re.py new file mode 100644 index 0000000000000000000000000000000000000000..0944ca59cbcd4751da932e3d3d44cad9a8e2a4e2 --- /dev/null +++ b/relik/reader/trainer/predict_re.py @@ -0,0 +1,86 @@ +import argparse + +import torch + +from relik.reader.data.relik_reader_sample import load_relik_reader_samples +from relik.reader.pytorch_modules.hf.modeling_relik import ( + RelikReaderConfig, + RelikReaderREModel, +) +from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction +from relik.reader.utils.relation_matching_eval import StrongMatching + + +def eval(model_path, data_path, 
is_eval, output_path=None): + if model_path.endswith(".ckpt"): + # if it is a lightning checkpoint we load the model state dict and the tokenizer from the config + model_dict = torch.load(model_path) + + additional_special_symbols = model_dict["hyper_parameters"][ + "additional_special_symbols" + ] + from transformers import AutoTokenizer + + from relik.reader.utils.special_symbols import get_special_symbols_re + + special_symbols = get_special_symbols_re(additional_special_symbols - 1) + tokenizer = AutoTokenizer.from_pretrained( + model_dict["hyper_parameters"]["transformer_model"], + additional_special_tokens=special_symbols, + add_prefix_space=True, + ) + config_model = RelikReaderConfig( + model_dict["hyper_parameters"]["transformer_model"], + len(special_symbols), + training=False, + ) + model = RelikReaderREModel(config_model) + model_dict["state_dict"] = { + k.replace("relik_reader_re_model.", ""): v + for k, v in model_dict["state_dict"].items() + } + model.load_state_dict(model_dict["state_dict"], strict=False) + reader = RelikReaderForTripletExtraction( + model, training=False, device="cuda", tokenizer=tokenizer + ) + else: + # if it is a huggingface model we load the model directly. Note that it could even be a string from the hub + model = RelikReaderREModel.from_pretrained(model_path) + reader = RelikReaderForTripletExtraction( + model, training=False, device="cuda" + ) # , dataset_kwargs={"use_nme": True}) if we want to use NME + + samples = list(load_relik_reader_samples(data_path)) + + predicted_samples = reader.read(samples=samples, progress_bar=True) + if is_eval: + strong_matching_metric = StrongMatching() + predicted_samples = list(predicted_samples) + for k, v in strong_matching_metric(predicted_samples).items(): + print(f"test_{k}", v) + if output_path is not None: + with open(output_path, "w") as f: + for sample in predicted_samples: + f.write(sample.to_jsons() + "\n") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + type=str, + default="/root/relik/experiments/relik_reader_re_small", + ) + parser.add_argument( + "--data_path", + type=str, + default="/root/relik/data/re/test.jsonl", + ) + parser.add_argument("--is-eval", action="store_true") + parser.add_argument("--output_path", type=str, default=None) + args = parser.parse_args() + eval(args.model_path, args.data_path, args.is_eval, args.output_path) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/train.py b/relik/reader/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..cac94ba0f2d4db02b0860c020858e1c36738f94e --- /dev/null +++ b/relik/reader/trainer/train.py @@ -0,0 +1,118 @@ +from pathlib import Path +from pprint import pprint +import hydra +import lightning +from hydra.utils import to_absolute_path +from lightning import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers.wandb import WandbLogger +from omegaconf import DictConfig, OmegaConf, open_dict +import torch +from torch.utils.data import DataLoader + +from relik.reader.data.relik_reader_data import RelikDataset +from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule +from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer +from relik.reader.utils.special_symbols import get_special_symbols +from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback +from relik.reader.utils.shuffle_train_callback import 
ShuffleTrainCallback + +@hydra.main(config_path="../conf", config_name="config") +def train(cfg: DictConfig) -> None: + + lightning.seed_everything(cfg.training.seed) + # check if deterministic algorithms are available + # torch.use_deterministic_algorithms(True, warn_only=True) + + # log the configuration + pprint(OmegaConf.to_container(cfg, resolve=True)) + + special_symbols = get_special_symbols(cfg.model.entities_per_forward) + + # model declaration + model = RelikReaderPLModule( + cfg=OmegaConf.to_container(cfg), + transformer_model=cfg.model.model.transformer_model, + additional_special_symbols=len(special_symbols), + training=True, + ) + + # optimizer declaration + opt_conf = cfg.model.optimizer + electra_optimizer_factory = LayerWiseLRDecayOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + total_reset=opt_conf.total_reset, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + lr_decay=opt_conf.lr_decay, + ) + + model.set_optimizer_factory(electra_optimizer_factory) + + # datasets declaration + train_dataset: RelikDataset = hydra.utils.instantiate( + cfg.data.train_dataset, + dataset_path=to_absolute_path(cfg.data.train_dataset_path), + special_symbols=special_symbols, + ) + + # update of validation dataset config with special_symbols since they + # are required even from the EvaluationCallback dataset_config + with open_dict(cfg): + cfg.data.val_dataset.special_symbols = special_symbols + + val_dataset: RelikDataset = hydra.utils.instantiate( + cfg.data.val_dataset, + dataset_path=to_absolute_path(cfg.data.val_dataset_path), + ) + + # callbacks declaration + callbacks = [ + ELStrongMatchingCallback( + to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset + ), + ModelCheckpoint( + "model", + filename="{epoch}-{val_core_f1:.2f}", + monitor="val_core_f1", + mode="max", + ), + LearningRateMonitor(), + ShuffleTrainCallback(), + ] + + wandb_logger = WandbLogger( + cfg.model_name, project=cfg.project_name, offline=cfg.offline + ) + + # trainer declaration + trainer: Trainer = hydra.utils.instantiate( + cfg.training.trainer, + callbacks=callbacks, + logger=wandb_logger, + ) + + # Trainer fit + trainer.fit( + model=model, + train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=1), + val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), + ) + + # if cfg.training.save_model_path: + experiment_path = Path(wandb_logger.experiment.dir) + model = RelikReaderPLModule.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) + model.relik_reader_core_model._tokenizer = train_dataset.tokenizer + model.relik_reader_core_model.save_pretrained(experiment_path / "hf_model") + + +def main(): + train() + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/train_cie.py b/relik/reader/trainer/train_cie.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3baef577b42c4e416346f573cad6c2e6249a05 --- /dev/null +++ b/relik/reader/trainer/train_cie.py @@ -0,0 +1,135 @@ +import hydra +import lightning +from hydra.utils import to_absolute_path +from lightning import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers.wandb import WandbLogger +from omegaconf import DictConfig, OmegaConf, open_dict +from torch.utils.data import DataLoader + +from relik.reader.data.relik_reader_re_data import RelikREDataset +from 
relik.reader.lightning_modules.relik_reader_re_pl_module import ( + RelikReaderREPLModule, +) +from relik.reader.pytorch_modules.optim import ( + AdamWWithWarmupOptimizer, + LayerWiseLRDecayOptimizer, +) +from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback +from relik.reader.utils.special_symbols import ( + get_special_symbols, + get_special_symbols_re, +) + + +@hydra.main(config_path="../conf", config_name="config_cie") +def train(cfg: DictConfig) -> None: + lightning.seed_everything(cfg.training.seed) + + special_symbols = get_special_symbols_re(cfg.model.relations_per_forward) + special_symbols_types = get_special_symbols(cfg.model.entities_per_forward) + # datasets declaration + train_dataset: RelikREDataset = hydra.utils.instantiate( + cfg.data.train_dataset, + dataset_path=to_absolute_path(cfg.data.train_dataset_path), + special_symbols=special_symbols, + special_symbols_types=special_symbols_types, + ) + + # update of validation dataset config with special_symbols since they + # are required even from the EvaluationCallback dataset_config + with open_dict(cfg): + cfg.data.val_dataset.special_symbols = special_symbols + cfg.data.val_dataset.special_symbols_types = special_symbols_types + + val_dataset: RelikREDataset = hydra.utils.instantiate( + cfg.data.val_dataset, + dataset_path=to_absolute_path(cfg.data.val_dataset_path), + ) + + if val_dataset.materialize_samples: + list(val_dataset.dataset_iterator_func()) + + # model declaration + model = RelikReaderREPLModule( + cfg=OmegaConf.to_container(cfg), + transformer_model=cfg.model.model.transformer_model, + additional_special_symbols=len(special_symbols), + additional_special_symbols_types=len(special_symbols_types), + entity_type_loss=True, + add_entity_embedding=True, + training=True, + ) + model.relik_reader_re_model._tokenizer = train_dataset.tokenizer + # optimizer declaration + opt_conf = cfg.model.optimizer + + if "total_reset" not in opt_conf: + optimizer_factory = AdamWWithWarmupOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + ) + else: + optimizer_factory = LayerWiseLRDecayOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + total_reset=opt_conf.total_reset, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + lr_decay=opt_conf.lr_decay, + ) + + model.set_optimizer_factory(optimizer_factory) + + # callbacks declaration + callbacks = [ + REStrongMatchingCallback( + to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset + ), + ModelCheckpoint( + "model", + filename="{epoch}-{val_f1:.2f}", + monitor="val_f1", + mode="max", + ), + LearningRateMonitor(), + ] + + wandb_logger = WandbLogger( + cfg.model_name, project=cfg.project_name + ) # , offline=True) + + # trainer declaration + trainer: Trainer = hydra.utils.instantiate( + cfg.training.trainer, + callbacks=callbacks, + logger=wandb_logger, + ) + + # Trainer fit + trainer.fit( + model=model, + train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), + val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), + ckpt_path=cfg.training.ckpt_path if cfg.training.ckpt_path else None, + ) + + # Load best checkpoint + if cfg.training.save_model_path: + model = RelikReaderREPLModule.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) + model.relik_reader_re_model._tokenizer = 
train_dataset.tokenizer + model.relik_reader_re_model.save_pretrained(cfg.training.save_model_path) + + +def main(): + train() + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/train_re.py b/relik/reader/trainer/train_re.py new file mode 100644 index 0000000000000000000000000000000000000000..8b79735bc8ef7fb28cb25d7b773278aac27914ea --- /dev/null +++ b/relik/reader/trainer/train_re.py @@ -0,0 +1,123 @@ +import hydra +import lightning +from hydra.utils import to_absolute_path +from lightning import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers.wandb import WandbLogger +from omegaconf import DictConfig, OmegaConf, open_dict +from torch.utils.data import DataLoader + +from relik.reader.data.relik_reader_re_data import RelikREDataset +from relik.reader.lightning_modules.relik_reader_re_pl_module import ( + RelikReaderREPLModule, +) +from relik.reader.pytorch_modules.optim import ( + AdamWWithWarmupOptimizer, + LayerWiseLRDecayOptimizer, +) +from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback +from relik.reader.utils.special_symbols import get_special_symbols_re + + +@hydra.main(config_path="../conf", config_name="config") +def train(cfg: DictConfig) -> None: + lightning.seed_everything(cfg.training.seed) + + special_symbols = get_special_symbols_re(cfg.model.relations_per_forward) + + # datasets declaration + train_dataset: RelikREDataset = hydra.utils.instantiate( + cfg.data.train_dataset, + dataset_path=to_absolute_path(cfg.data.train_dataset_path), + special_symbols=special_symbols, + ) + + # update of validation dataset config with special_symbols since they + # are required even from the EvaluationCallback dataset_config + with open_dict(cfg): + cfg.data.val_dataset.special_symbols = special_symbols + + val_dataset: RelikREDataset = hydra.utils.instantiate( + cfg.data.val_dataset, + dataset_path=to_absolute_path(cfg.data.val_dataset_path), + ) + + if val_dataset.materialize_samples: + list(val_dataset.dataset_iterator_func()) + # model declaration + model = RelikReaderREPLModule( + cfg=OmegaConf.to_container(cfg), + transformer_model=cfg.model.model.transformer_model, + additional_special_symbols=len(special_symbols), + training=True, + ) + model.relik_reader_re_model._tokenizer = train_dataset.tokenizer + # optimizer declaration + opt_conf = cfg.model.optimizer + + if "total_reset" not in opt_conf: + optimizer_factory = AdamWWithWarmupOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + ) + else: + optimizer_factory = LayerWiseLRDecayOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + total_reset=opt_conf.total_reset, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + lr_decay=opt_conf.lr_decay, + ) + + model.set_optimizer_factory(optimizer_factory) + # callbacks declaration + callbacks = [ + REStrongMatchingCallback( + to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset + ), + ModelCheckpoint( + "model", + filename="{epoch}-{val_f1:.2f}", + monitor="val_f1", + mode="max", + ), + LearningRateMonitor(), + ] + + wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name) + + # trainer declaration + trainer: Trainer = hydra.utils.instantiate( + cfg.training.trainer, + callbacks=callbacks, + logger=wandb_logger, + ) + + # Trainer 
fit + trainer.fit( + model=model, + train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), + val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), + ckpt_path=cfg.training.ckpt_path if cfg.training.ckpt_path else None, + ) + + # Load best checkpoint + if cfg.training.save_model_path: + model = RelikReaderREPLModule.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) + model.relik_reader_re_model._tokenizer = train_dataset.tokenizer + model.relik_reader_re_model.save_pretrained(cfg.training.save_model_path) + + +def main(): + train() + + +if __name__ == "__main__": + main() diff --git a/relik/reader/utils/__init__.py b/relik/reader/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/utils/__pycache__/__init__.cpython-310.pyc b/relik/reader/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2571eb0b43c544667b711f84e117ccf5eca20d19 Binary files /dev/null and b/relik/reader/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/reader/utils/__pycache__/metrics.cpython-310.pyc b/relik/reader/utils/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b63cc1d4c41fc0e70b81c1685b0112507f75c0 Binary files /dev/null and b/relik/reader/utils/__pycache__/metrics.cpython-310.pyc differ diff --git a/relik/reader/utils/__pycache__/shuffle_train_callback.cpython-310.pyc b/relik/reader/utils/__pycache__/shuffle_train_callback.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..924bb8060cb1c6293230ec9928aff022fdb7ad54 Binary files /dev/null and b/relik/reader/utils/__pycache__/shuffle_train_callback.cpython-310.pyc differ diff --git a/relik/reader/utils/__pycache__/special_symbols.cpython-310.pyc b/relik/reader/utils/__pycache__/special_symbols.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c035f66a594ab0801325feb5d3ea2654fb94f297 Binary files /dev/null and b/relik/reader/utils/__pycache__/special_symbols.cpython-310.pyc differ diff --git a/relik/reader/utils/__pycache__/strong_matching_eval.cpython-310.pyc b/relik/reader/utils/__pycache__/strong_matching_eval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e025b842124f447a4b7342c089c6bec4795aba8e Binary files /dev/null and b/relik/reader/utils/__pycache__/strong_matching_eval.cpython-310.pyc differ diff --git a/relik/reader/utils/metrics.py b/relik/reader/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..fa17bf5d23cc888d6da0c6f40cf7bd3c20d77a66 --- /dev/null +++ b/relik/reader/utils/metrics.py @@ -0,0 +1,18 @@ +def safe_divide(num: float, den: float) -> float: + if den == 0: + return 0 + else: + return num / den + + +def f1_measure(precision: float, recall: float) -> float: + if precision == 0 or recall == 0: + return 0.0 + return safe_divide(2 * precision * recall, (precision + recall)) + + +def compute_metrics(total_correct, total_preds, total_gold): + precision = safe_divide(total_correct, total_preds) + recall = safe_divide(total_correct, total_gold) + f1 = f1_measure(precision, recall) + return precision, recall, f1 diff --git a/relik/reader/utils/relation_matching_eval.py b/relik/reader/utils/relation_matching_eval.py new file mode 100644 index 
0000000000000000000000000000000000000000..7b9ac00ee91c5f199d61e910dfe17fdbf9652f49 --- /dev/null +++ b/relik/reader/utils/relation_matching_eval.py @@ -0,0 +1,361 @@ +from typing import Dict, List + +from collections import defaultdict + +from lightning.pytorch.callbacks import Callback + +from relik.reader.data.relik_reader_re_data import RelikREDataset +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.relik_reader_predictor import RelikReaderPredictor +from relik.reader.utils.metrics import compute_metrics + + +class StrongMatching: + def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: + # accumulators + correct_predictions, total_predictions, total_gold = ( + 0, + 0, + 0, + ) + correct_predictions_strict, total_predictions_strict = ( + 0, + 0, + ) + correct_predictions_bound, total_predictions_bound = ( + 0, + 0, + ) + correct_span_predictions, total_span_predictions, total_gold_spans = ( + 0, + 0, + 0, + ) + ( + correct_span_in_triplets_predictions, + total_span_in_triplets_predictions, + total_gold_spans_in_triplets, + ) = ( + 0, + 0, + 0, + ) + + # collect data from samples + for sample in predicted_samples: + if sample.triplets is None: + sample.triplets = [] + + if sample.span_candidates: + predicted_annotations_strict = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.predicted_relations + ] + ) + gold_annotations_strict = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.triplets + ] + ) + predicted_spans_strict = set((ss, se, st) for (ss, se, st) in sample.predicted_entities) + gold_spans_strict = set(sample.entities) + predicted_spans_in_triplets = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + ) + for triplet in sample.predicted_relations + ] + + [ + ( + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.predicted_relations + ] + ) + gold_spans_in_triplets = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + ) + for triplet in sample.triplets + ] + + [ + ( + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.triplets + ] + ) + # strict + correct_span_predictions += len( + predicted_spans_strict.intersection(gold_spans_strict) + ) + total_span_predictions += len(predicted_spans_strict) + + correct_span_in_triplets_predictions += len( + predicted_spans_in_triplets.intersection(gold_spans_in_triplets) + ) + total_span_in_triplets_predictions += len(predicted_spans_in_triplets) + total_gold_spans_in_triplets += len(gold_spans_in_triplets) + + correct_predictions_strict += len( + predicted_annotations_strict.intersection(gold_annotations_strict) + ) + total_predictions_strict += len(predicted_annotations_strict) + + predicted_annotations = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + -1, + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + -1, + ) + for triplet in sample.predicted_relations + ] + ) + gold_annotations = set( + [ + ( + 
triplet["subject"]["start"], + triplet["subject"]["end"], + -1, + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + -1, + ) + for triplet in sample.triplets + ] + ) + predicted_spans = set( + [(ss, se) for (ss, se, _) in sample.predicted_entities] + ) + gold_spans = set([(ss, se) for (ss, se, _) in sample.entities]) + total_gold_spans += len(gold_spans) + + correct_predictions_bound += len(predicted_spans.intersection(gold_spans)) + total_predictions_bound += len(predicted_spans) + + total_predictions += len(predicted_annotations) + total_gold += len(gold_annotations) + # correct relation extraction + correct_predictions += len( + predicted_annotations.intersection(gold_annotations) + ) + + span_precision, span_recall, span_f1 = compute_metrics( + correct_span_predictions, total_span_predictions, total_gold_spans + ) + bound_precision, bound_recall, bound_f1 = compute_metrics( + correct_predictions_bound, total_predictions_bound, total_gold_spans + ) + + precision, recall, f1 = compute_metrics( + correct_predictions, total_predictions, total_gold + ) + + if sample.span_candidates: + precision_strict, recall_strict, f1_strict = compute_metrics( + correct_predictions_strict, total_predictions_strict, total_gold + ) + ( + span_in_triplet_precisiion, + span_in_triplet_recall, + span_in_triplet_f1, + ) = compute_metrics( + correct_span_in_triplets_predictions, + total_span_in_triplets_predictions, + total_gold_spans_in_triplets, + ) + return { + "span-precision-strict": span_precision, + "span-recall-strict": span_recall, + "span-f1-strict": span_f1, + "span-precision": bound_precision, + "span-recall": bound_recall, + "span-f1": bound_f1, + "span-in-triplet-precision": span_in_triplet_precisiion, + "span-in-triplet-recall": span_in_triplet_recall, + "span-in-triplet-f1": span_in_triplet_f1, + "precision": precision, + "recall": recall, + "f1": f1, + "precision-strict": precision_strict, + "recall-strict": recall_strict, + "f1-strict": f1_strict, + } + else: + return { + "span-precision": bound_precision, + "span-recall": bound_recall, + "span-f1": bound_f1, + "precision": precision, + "recall": recall, + "f1": f1, + } + +class StrongMatchingPerRelation: + def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: + correct_predictions, total_predictions, total_gold = ( + defaultdict(int), + defaultdict(int), + defaultdict(int), + ) + correct_predictions_strict, total_predictions_strict = ( + defaultdict(int), + defaultdict(int), + ) + # collect data from samples + for sample in predicted_samples: + if sample.triplets is None: + sample.triplets = [] + + if sample.span_candidates: + gold_annotations_strict = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.triplets + ] + ) + # compute correct preds per triplet["relation"]["name"] + for triplet in sample.predicted_relations: + predicted_annotations_strict = ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + if predicted_annotations_strict in gold_annotations_strict: + correct_predictions_strict[triplet["relation"]["name"]] += 1 + total_predictions_strict[triplet["relation"]["name"]] += 1 + gold_annotations = set( + [ + ( + 
triplet["subject"]["start"], + triplet["subject"]["end"], + -1, + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + -1, + ) + for triplet in sample.triplets + ] + ) + for triplet in sample.predicted_relations: + predicted_annotations = ( + triplet["subject"]["start"], + triplet["subject"]["end"], + -1, + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + -1, + ) + if predicted_annotations in gold_annotations: + correct_predictions[triplet["relation"]["name"]] += 1 + total_predictions[triplet["relation"]["name"]] += 1 + for triplet in sample.triplets: + total_gold[triplet["relation"]["name"]] += 1 + metrics = {} + metrics_non_zero = 0 + for relation in total_gold.keys(): + precision, recall, f1 = compute_metrics( + correct_predictions[relation], + total_predictions[relation], + total_gold[relation], + ) + metrics[f"{relation}-precision"] = precision + metrics[f"{relation}-recall"] = recall + metrics[f"{relation}-f1"] = f1 + precision_strict, recall_strict, f1_strict = compute_metrics( + correct_predictions_strict[relation], + total_predictions_strict[relation], + total_gold[relation], + ) + metrics[f"{relation}-precision-strict"] = precision_strict + metrics[f"{relation}-recall-strict"] = recall_strict + metrics[f"{relation}-f1-strict"] = f1_strict + if metrics[f"{relation}-f1-strict"] > 0: + metrics_non_zero += 1 + # print in a readable way + print( + f"{relation} precision: {precision:.4f} recall: {recall:.4f} f1: {f1:.4f} precision_strict: {precision_strict:.4f} recall_strict: {recall_strict:.4f} f1_strict: {f1_strict:.4f} support: {total_gold[relation]}" + ) + print(f"metrics_non_zero: {metrics_non_zero}") + return metrics + +class REStrongMatchingCallback(Callback): + def __init__(self, dataset_path: str, dataset_conf, log_metric: str = "val_") -> None: + super().__init__() + self.dataset_path = dataset_path + self.dataset_conf = dataset_conf + self.strong_matching_metric = StrongMatching() + self.log_metric = log_metric + + def on_validation_epoch_start(self, trainer, pl_module) -> None: + dataloader = trainer.val_dataloaders + if ( + self.dataset_path == dataloader.dataset.dataset_path + and dataloader.dataset.samples is not None + and len(dataloader.dataset.samples) > 0 + ): + relik_reader_predictor = RelikReaderPredictor( + pl_module.relik_reader_re_model, dataloader=trainer.val_dataloaders + ) + else: + relik_reader_predictor = RelikReaderPredictor( + pl_module.relik_reader_re_model + ) + predicted_samples = relik_reader_predictor._predict( + self.dataset_path, + None, + self.dataset_conf, + ) + predicted_samples = list(predicted_samples) + for sample in predicted_samples: + RelikREDataset._convert_annotations(sample) + for k, v in self.strong_matching_metric(predicted_samples).items(): + pl_module.log(f"{self.log_metric}{k}", v) diff --git a/relik/reader/utils/save_load_utilities.py b/relik/reader/utils/save_load_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..1e635650c1f69c0e223d268f97ec9d6e0677742c --- /dev/null +++ b/relik/reader/utils/save_load_utilities.py @@ -0,0 +1,76 @@ +import argparse +import os +from typing import Tuple + +import omegaconf +import torch + +from relik.common.utils import from_cache +from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule +from relik.reader.relik_reader_core import RelikReaderCoreModel + +CKPT_FILE_NAME = "model.ckpt" +CONFIG_FILE_NAME = "cfg.yaml" + + +def convert_pl_module(pl_module_ckpt_path: 
str, output_dir: str) -> None: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + else: + print(f"{output_dir} already exists, aborting operation") + exit(1) + + relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint( + pl_module_ckpt_path + ) + torch.save( + relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}" + ) + with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f: + omegaconf.OmegaConf.save( + omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f + ) + + +def load_model_and_conf( + model_dir_path: str, +) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]: + # TODO: quick workaround to load the model from HF hub + model_dir = from_cache( + model_dir_path, + filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME], + cache_dir=None, + force_download=False, + ) + + ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}" + model = torch.load(ckpt_path, map_location=torch.device("cpu")) + + model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}" + model_conf = omegaconf.OmegaConf.load(model_cfg_path) + return model, model_conf + + +def parse_arg() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt", + help="Path to the pytorch lightning ckpt you want to convert.", + required=True, + ) + parser.add_argument( + "--output-dir", + "-o", + help="The output dir to store the bare models and the config.", + required=True, + ) + return parser.parse_args() + + +def main(): + args = parse_arg() + convert_pl_module(args.ckpt, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/utils/shuffle_train_callback.py b/relik/reader/utils/shuffle_train_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..601b55f32fbef281da0070c615ba973e0d1322e5 --- /dev/null +++ b/relik/reader/utils/shuffle_train_callback.py @@ -0,0 +1,22 @@ +from typing import Dict, List + +from collections import defaultdict + +from lightning.pytorch.callbacks import Callback + +from relik.common.log import get_logger + +import os + +logger = get_logger() + +class ShuffleTrainCallback(Callback): + def __init__(self, shuffle_every: int = 1, data_path: str = None): + self.shuffle_every = shuffle_every + self.data_path = data_path + + def on_train_epoch_end(self, trainer, pl_module): + if (trainer.current_epoch + 1) % self.shuffle_every == 0: + logger.info("Shuffling train dataset") + os.system(f"shuf {self.data_path} > {self.data_path}.shuf") + os.system(f"mv {self.data_path}.shuf {self.data_path}") \ No newline at end of file diff --git a/relik/reader/utils/special_symbols.py b/relik/reader/utils/special_symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..b36c5145af5a0ef80aae175698cdb53e4c7da3ff --- /dev/null +++ b/relik/reader/utils/special_symbols.py @@ -0,0 +1,14 @@ +from typing import List + +NME_SYMBOL = "--NME--" + + +def get_special_symbols(num_entities: int) -> List[str]: + return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)] + + +def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]: + if use_nme: + return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)] + else: + return [f"[R-{i}]" for i in range(num_entities)] diff --git a/relik/reader/utils/strong_matching_eval.py b/relik/reader/utils/strong_matching_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7eed42767fcd8dc97a6fcae60325ef21801d5e --- /dev/null +++ b/relik/reader/utils/strong_matching_eval.py @@ -0,0 +1,178 @@ +from typing 
import Dict, List + +from lightning.pytorch.callbacks import Callback + +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.relik_reader_predictor import RelikReaderPredictor +from relik.reader.utils.metrics import f1_measure, safe_divide +from relik.reader.utils.special_symbols import NME_SYMBOL + + +class StrongMatching: + def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: + # accumulators + correct_predictions = 0 + correct_predicted_entities = 0 + correct_predictions_at_k = 0 + total_predictions = 0 + total_gold = 0 + total_entities_predictions = 0 + total_entities_gold = 0 + correct_span_predictions = 0 + miss_due_to_candidates = 0 + + # prediction index stats + avg_correct_predicted_index = [] + avg_wrong_predicted_index = [] + less_index_predictions = [] + + # collect data from samples + for sample in predicted_samples: + predicted_annotations = sample.predicted_window_labels_chars + predicted_annotations_probabilities = sample.probs_window_labels_chars + gold_annotations = { + (ss, se, entity) + for ss, se, entity in sample.window_labels + if entity != NME_SYMBOL + } + + total_predictions += len(predicted_annotations) + total_gold += len(gold_annotations) + + gold_entities = { + entity + for ss, se, entity in sample.window_labels + if entity != NME_SYMBOL + } + total_entities_gold += len(gold_entities) + + pred_entities = { + entity + for ss, se, entity in predicted_annotations + if entity != NME_SYMBOL + } + total_entities_predictions += len(pred_entities) + + # correct named entity detection + predicted_spans = {(s, e) for s, e, _ in predicted_annotations} + gold_spans = {(s, e) for s, e, _ in gold_annotations} + correct_span_predictions += len(predicted_spans.intersection(gold_spans)) + + # correct entity linking + correct_predictions += len( + predicted_annotations.intersection(gold_annotations) + ) + + for ss, se, ge in gold_annotations.difference(predicted_annotations): + if ge not in sample.window_candidates: + miss_due_to_candidates += 1 + if ge in predicted_annotations_probabilities.get((ss, se), set()): + correct_predictions_at_k += 1 + + # correct entity disambiguation + correct_predicted_entities += len( + pred_entities.intersection(gold_entities) + ) + + # indices metrics + predicted_spans_index = { + (ss, se): ent for ss, se, ent in predicted_annotations + } + gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations} + + for pred_span, pred_ent in predicted_spans_index.items(): + gold_ent = gold_spans_index.get(pred_span) + + if pred_span not in gold_spans_index: + continue + + # missing candidate + if gold_ent not in sample.window_candidates: + continue + + gold_idx = sample.window_candidates.index(gold_ent) + if gold_idx is None: + continue + pred_idx = sample.window_candidates.index(pred_ent) + + if gold_ent != pred_ent: + avg_wrong_predicted_index.append(pred_idx) + + if gold_idx is not None: + if pred_idx > gold_idx: + less_index_predictions.append(0) + else: + less_index_predictions.append(1) + + else: + avg_correct_predicted_index.append(pred_idx) + + # compute NED metrics + span_precision = safe_divide(correct_span_predictions, total_predictions) + span_recall = safe_divide(correct_span_predictions, total_gold) + span_f1 = f1_measure(span_precision, span_recall) + + # compute EL metrics + precision = safe_divide(correct_predictions, total_predictions) + recall = safe_divide(correct_predictions, total_gold) + recall_at_k = safe_divide( + (correct_predictions + correct_predictions_at_k), 
total_gold + ) + f1 = f1_measure(precision, recall) + + + # comput ED metrics + precision_entities = safe_divide(correct_predicted_entities, total_entities_predictions) + recall_entities = safe_divide(correct_predicted_entities, total_entities_gold) + span_entities_f1 = f1_measure(precision_entities, recall_entities) + + + wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold) + + out_dict = { + "span_precision": span_precision, + "span_recall": span_recall, + "span_f1": span_f1, + "entities_precision": precision_entities, + "entities_recall": recall_entities, + "entities_f1": span_entities_f1, + "core_precision": precision, + "core_recall": recall, + "core_recall-at-k": recall_at_k, + "core_f1": round(f1, 4), + "wrong-for-candidates": wrong_for_candidates, + "index_errors_avg-index": safe_divide( + sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index) + ), + "index_correct_avg-index": safe_divide( + sum(avg_correct_predicted_index), len(avg_correct_predicted_index) + ), + "index_avg-index": safe_divide( + sum(avg_correct_predicted_index + avg_wrong_predicted_index), + len(avg_correct_predicted_index + avg_wrong_predicted_index), + ), + "index_percentage-favoured-smaller-idx": safe_divide( + sum(less_index_predictions), len(less_index_predictions) + ), + } + + return {k: round(v, 5) for k, v in out_dict.items()} + + +class ELStrongMatchingCallback(Callback): + def __init__(self, dataset_path: str, dataset_conf) -> None: + super().__init__() + self.dataset_path = dataset_path + self.dataset_conf = dataset_conf + self.strong_matching_metric = StrongMatching() + + def on_validation_epoch_start(self, trainer, pl_module) -> None: + relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_core_model) + predicted_samples = relik_reader_predictor.predict( + self.dataset_path, + samples=None, + dataset_conf=self.dataset_conf, + ) + predicted_samples = list(predicted_samples) + for k, v in self.strong_matching_metric(predicted_samples).items(): + pl_module.log(f"val_{k}", v) diff --git a/relik/retriever/__init__.py b/relik/retriever/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42a3df6b991b0af65ec5974fc4faa381b8e555b7 --- /dev/null +++ b/relik/retriever/__init__.py @@ -0,0 +1 @@ +from relik.retriever.pytorch_modules.model import GoldenRetriever diff --git a/relik/retriever/__pycache__/__init__.cpython-310.pyc b/relik/retriever/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b8b04e2ee1bc520f07c84516a44b066a814243c Binary files /dev/null and b/relik/retriever/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/retriever/callbacks/__init__.py b/relik/retriever/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/callbacks/base.py b/relik/retriever/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3c9720d7f59a64e11512a898382221e3b134b815 --- /dev/null +++ b/relik/retriever/callbacks/base.py @@ -0,0 +1,168 @@ +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import hydra +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from omegaconf import DictConfig +from torch.utils.data import DataLoader, Dataset + +from relik.common.log import get_logger +from relik.retriever.data.base.datasets import BaseDataset + +logger = 
get_logger(__name__) + + +STAGES_COMPATIBILITY_MAP = { + "train": RunningStage.TRAINING, + "val": RunningStage.VALIDATING, + "test": RunningStage.TESTING, +} + +DEFAULT_STAGES = { + RunningStage.VALIDATING, + RunningStage.TESTING, + RunningStage.SANITY_CHECKING, + RunningStage.PREDICTING, +} + + +class PredictionCallback(pl.Callback): + def __init__( + self, + batch_size: int = 32, + stages: Optional[Set[Union[str, RunningStage]]] = None, + other_callbacks: Optional[ + Union[List[DictConfig], List["NLPTemplateCallback"]] + ] = None, + datasets: Optional[Union[DictConfig, BaseDataset]] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + *args, + **kwargs, + ): + super().__init__() + # parameters + self.batch_size = batch_size + self.datasets = datasets + self.dataloaders = dataloaders + + # callback initialization + if stages is None: + stages = DEFAULT_STAGES + + # compatibily stuff + stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages} + self.stages = [RunningStage(stage) for stage in stages] + self.other_callbacks = other_callbacks or [] + for i, callback in enumerate(self.other_callbacks): + if isinstance(callback, DictConfig): + self.other_callbacks[i] = hydra.utils.instantiate( + callback, _recursive_=False + ) + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> Any: + # it should return the predictions + raise NotImplementedError + + def on_validation_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ): + predictions = self(trainer, pl_module) + for callback in self.other_callbacks: + callback( + trainer=trainer, + pl_module=pl_module, + callback=self, + predictions=predictions, + ) + + def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + predictions = self(trainer, pl_module) + for callback in self.other_callbacks: + callback( + trainer=trainer, + pl_module=pl_module, + callback=self, + predictions=predictions, + ) + + @staticmethod + def _get_datasets_and_dataloaders( + dataset: Optional[Union[Dataset, DictConfig]], + dataloader: Optional[DataLoader], + trainer: pl.Trainer, + dataloader_kwargs: Optional[Dict[str, Any]] = None, + collate_fn: Optional[Callable] = None, + collate_fn_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[List[Dataset], List[DataLoader]]: + """ + Get the datasets and dataloaders from the datamodule or from the dataset provided. + + Args: + dataset (`Optional[Union[Dataset, DictConfig]]`): + The dataset to use. If `None`, the datamodule is used. + dataloader (`Optional[DataLoader]`): + The dataloader to use. If `None`, the datamodule is used. + trainer (`pl.Trainer`): + The trainer that contains the datamodule. + dataloader_kwargs (`Optional[Dict[str, Any]]`): + The kwargs to pass to the dataloader. + collate_fn (`Optional[Callable]`): + The collate function to use. + collate_fn_kwargs (`Optional[Dict[str, Any]]`): + The kwargs to pass to the collate function. + + Returns: + `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders. 
+ """ + # if a dataset is provided, use it + if dataset is not None: + dataloader_kwargs = dataloader_kwargs or {} + # get dataset + if isinstance(dataset, DictConfig): + dataset = hydra.utils.instantiate(dataset, _recursive_=False) + datasets = [dataset] if not isinstance(dataset, list) else dataset + if dataloader is not None: + dataloaders = ( + [dataloader] if isinstance(dataloader, DataLoader) else dataloader + ) + else: + collate_fn = collate_fn or partial( + datasets[0].collate_fn, **collate_fn_kwargs + ) + dataloader_kwargs["collate_fn"] = collate_fn + dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)] + else: + # get the dataloaders and datasets from the datamodule + datasets = ( + trainer.datamodule.test_datasets + if trainer.state.stage == RunningStage.TESTING + else trainer.datamodule.val_datasets + ) + dataloaders = ( + trainer.test_dataloaders + if trainer.state.stage == RunningStage.TESTING + else trainer.val_dataloaders + ) + return datasets, dataloaders + + +class NLPTemplateCallback: + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + callback: PredictionCallback, + predictions: Dict[str, Any], + *args, + **kwargs, + ) -> Any: + raise NotImplementedError diff --git a/relik/retriever/callbacks/evaluation_callbacks.py b/relik/retriever/callbacks/evaluation_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..868d013730029b5c4a4112fd4900ad38cf55806c --- /dev/null +++ b/relik/retriever/callbacks/evaluation_callbacks.py @@ -0,0 +1,275 @@ +import logging +from typing import Dict, List, Optional + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from sklearn.metrics import label_ranking_average_precision_score + +from relik.common.log import get_logger +from relik.retriever.callbacks.base import DEFAULT_STAGES, NLPTemplateCallback + +logger = get_logger(__name__, level=logging.INFO) + + +class RecallAtKEvaluationCallback(NLPTemplateCallback): + """ + Computes the recall at k for the predictions. Recall at k is computed as the number of + correct predictions in the top k predictions divided by the total number of correct + predictions. + + Args: + k (`int`): + The number of predictions to consider. + prefix (`str`, `optional`): + The prefix to add to the metrics. + verbose (`bool`, `optional`, defaults to `False`): + Whether to log the metrics. + prog_bar (`bool`, `optional`, defaults to `True`): + Whether to log the metrics to the progress bar. + """ + + def __init__( + self, + k: int = 100, + prefix: Optional[str] = None, + verbose: bool = False, + prog_bar: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.k = k + self.prefix = prefix + self.verbose = verbose + self.prog_bar = prog_bar + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + *args, + **kwargs, + ) -> dict: + """ + Computes the recall at k for the predictions. + + Args: + trainer (:obj:`lightning.trainer.trainer.Trainer`): + The trainer object. + pl_module (:obj:`lightning.core.lightning.LightningModule`): + The lightning module. + predictions (:obj:`Dict`): + The predictions. + + Returns: + :obj:`Dict`: The computed metrics. + """ + if self.verbose: + logger.info(f"Computing recall@{self.k}") + + # metrics to return + metrics = {} + + stage = trainer.state.stage + if stage not in DEFAULT_STAGES: + raise ValueError( + f"Stage {stage} not supported, only `validate` and `test` are supported." 
+ ) + + for dataloader_idx, samples in predictions.items(): + hits, total = 0, 0 + for sample in samples: + # compute the recall at k + # cut the predictions to the first k elements + predictions = sample["predictions"][: self.k] + hits += len(set(predictions) & set(sample["gold"])) + total += len(set(sample["gold"])) + + # compute the mean recall at k + recall_at_k = hits / total + metrics[f"recall@{self.k}_{dataloader_idx}"] = recall_at_k + metrics[f"recall@{self.k}"] = sum(metrics.values()) / len(metrics) + + if self.prefix is not None: + metrics = {f"{self.prefix}_{k}": v for k, v in metrics.items()} + else: + metrics = {f"{stage.value}_{k}": v for k, v in metrics.items()} + pl_module.log_dict( + metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar + ) + + if self.verbose: + logger.info( + f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}" + ) + + return metrics + + +class AvgRankingEvaluationCallback(NLPTemplateCallback): + """ + Computes the average ranking of the gold label in the predictions. Average ranking is + computed as the average of the rank of the gold label in the predictions. + + Args: + k (`int`): + The number of predictions to consider. + prefix (`str`, `optional`): + The prefix to add to the metrics. + stages (`List[str]`, `optional`): + The stages to compute the metrics on. Defaults to `["validate", "test"]`. + verbose (`bool`, `optional`, defaults to `False`): + Whether to log the metrics. + """ + + def __init__( + self, + k: int, + prefix: Optional[str] = None, + stages: Optional[List[str]] = None, + verbose: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.k = k + self.prefix = prefix + self.verbose = verbose + self.stages = ( + [RunningStage(stage) for stage in stages] if stages else DEFAULT_STAGES + ) + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + *args, + **kwargs, + ) -> dict: + """ + Computes the average ranking of the gold label in the predictions. + + Args: + trainer (:obj:`lightning.trainer.trainer.Trainer`): + The trainer object. + pl_module (:obj:`lightning.core.lightning.LightningModule`): + The lightning module. + predictions (:obj:`Dict`): + The predictions. + + Returns: + :obj:`Dict`: The computed metrics. + """ + if not predictions: + logger.warning("No predictions to compute the AVG Ranking metrics.") + return {} + + if self.verbose: + logger.info(f"Computing AVG Ranking@{self.k}") + + # metrics to return + metrics = {} + + stage = trainer.state.stage + if stage not in self.stages: + raise ValueError( + f"Stage `{stage}` not supported, only `validate` and `test` are supported." 
+ ) + + for dataloader_idx, samples in predictions.items(): + rankings = [] + for sample in samples: + window_candidates = sample["predictions"][: self.k] + window_labels = sample["gold"] + for wl in window_labels: + if wl in window_candidates: + rankings.append(window_candidates.index(wl) + 1) + + avg_ranking = sum(rankings) / len(rankings) if len(rankings) > 0 else 0 + metrics[f"avg_ranking@{self.k}_{dataloader_idx}"] = avg_ranking + if len(metrics) == 0: + metrics[f"avg_ranking@{self.k}"] = 0 + else: + metrics[f"avg_ranking@{self.k}"] = sum(metrics.values()) / len(metrics) + + prefix = self.prefix or stage.value + metrics = { + f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32) + for k, v in metrics.items() + } + pl_module.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False) + + if self.verbose: + logger.info( + f"AVG Ranking@{self.k} on {prefix}: {metrics[f'{prefix}_avg_ranking@{self.k}']}" + ) + + return metrics + + +class LRAPEvaluationCallback(NLPTemplateCallback): + def __init__( + self, + k: int = 100, + prefix: Optional[str] = None, + verbose: bool = False, + prog_bar: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.k = k + self.prefix = prefix + self.verbose = verbose + self.prog_bar = prog_bar + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + *args, + **kwargs, + ) -> dict: + if self.verbose: + logger.info(f"Computing recall@{self.k}") + + # metrics to return + metrics = {} + + stage = trainer.state.stage + if stage not in DEFAULT_STAGES: + raise ValueError( + f"Stage {stage} not supported, only `validate` and `test` are supported." + ) + + for dataloader_idx, samples in predictions.items(): + scores = [sample["scores"][: self.k] for sample in samples] + golds = [sample["gold"] for sample in samples] + + # compute the mean recall at k + lrap = label_ranking_average_precision_score(golds, scores) + metrics[f"lrap@{self.k}_{dataloader_idx}"] = lrap + metrics[f"lrap@{self.k}"] = sum(metrics.values()) / len(metrics) + + prefix = self.prefix or stage.value + metrics = { + f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32) + for k, v in metrics.items() + } + pl_module.log_dict( + metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar + ) + + if self.verbose: + logger.info( + f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}" + ) + + return metrics diff --git a/relik/retriever/callbacks/prediction_callbacks.py b/relik/retriever/callbacks/prediction_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f84f5eeeedd80e10d98da32e236aa1a44e40f0 --- /dev/null +++ b/relik/retriever/callbacks/prediction_callbacks.py @@ -0,0 +1,200 @@ +import logging +import time +from pathlib import Path +from typing import List, Optional, Set + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from omegaconf import DictConfig +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.datasets import GoldenRetrieverDataset +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.pytorch_modules.model import GoldenRetriever + +logger = get_logger(__name__, level=logging.INFO) + + 
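+# Overview of the callback defined below: it (re)indexes the passages with the current retriever weights via retriever.index(), retrieves the top-k passages for every question in the validation/test dataloaders, and collects per-sample gold, predicted, correct and wrong passages together with their scores. +# The returned dict is keyed by dataloader index and is what the attached NLPTemplateCallback metrics (e.g. recall@k, average ranking, LRAP) consume.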
+class GoldenRetrieverPredictionCallback(PredictionCallback): + def __init__( + self, + k: int | None = None, + batch_size: int = 32, + num_workers: int = 8, + document_index: BaseDocumentIndex | None = None, + precision: str | int = 32, + force_reindex: bool = True, + retriever_dir: Optional[Path] = None, + stages: Set[str | RunningStage] | None = None, + other_callbacks: List[DictConfig] | List[NLPTemplateCallback] | None = None, + dataset: DictConfig | BaseDataset | None = None, + dataloader: DataLoader | None = None, + *args, + **kwargs, + ): + super().__init__(batch_size, stages, other_callbacks, dataset, dataloader) + self.k = k + self.num_workers = num_workers + self.document_index = document_index + self.precision = precision + self.force_reindex = force_reindex + self.retriever_dir = retriever_dir + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + datasets: DictConfig + | BaseDataset + | List[DictConfig] + | List[BaseDataset] + | None = None, + dataloaders: DataLoader | List[DataLoader] | None = None, + *args, + **kwargs, + ) -> dict: + stage = trainer.state.stage + logger.info(f"Computing predictions for stage {stage.value}") + if stage not in self.stages: + raise ValueError( + f"Stage `{stage}` not supported, only {self.stages} are supported" + ) + + self.datasets, self.dataloaders = self._get_datasets_and_dataloaders( + datasets, + dataloaders, + trainer, + dataloader_kwargs=dict( + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + shuffle=False, + ), + ) + + # set the model to eval mode + pl_module.eval() + # get the retriever + retriever: GoldenRetriever = pl_module.model + + # here we will store the samples with predictions for each dataloader + dataloader_predictions = {} + # compute the passage embeddings index for each dataloader + for dataloader_idx, dataloader in enumerate(self.dataloaders): + current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx] + logger.info( + f"Computing passage embeddings for dataset {current_dataset.name}" + ) + + tokenizer = current_dataset.tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + truncation=True, + padding=True, + max_length=current_dataset.max_passage_length, + return_tensors="pt", + ) + ) + + # check if we need to reindex the passages and + # also if we need to load the retriever from disk + if (self.retriever_dir is not None and trainer.current_epoch == 0) or ( + self.retriever_dir is not None and stage == RunningStage.TESTING + ): + force_reindex = False + else: + force_reindex = self.force_reindex + + if ( + not force_reindex + and self.retriever_dir is not None + and stage == RunningStage.TESTING + ): + retriever = retriever.from_pretrained(self.retriever_dir) + + # you never know :) + retriever.eval() + + retriever.index( + batch_size=self.batch_size, + num_workers=self.num_workers, + max_length=current_dataset.max_passage_length, + collate_fn=collate_fn, + precision=self.precision, + compute_on_cpu=False, + force_reindex=force_reindex, + ) + + # now compute the question embeddings and compute the top-k accuracy + predictions = [] + start = time.time() + for batch in tqdm( + dataloader, + desc=f"Computing predictions for dataset {current_dataset.name}", + ): + batch = batch.to(pl_module.device) + # get the top-k indices + retriever_output = retriever.retrieve( + **batch.questions, k=self.k, precision=self.precision + ) + # compute recall at k + for batch_idx, retrieved_samples in 
enumerate(retriever_output): + # get the positive passages + gold_passages = batch["positives"][batch_idx] + # get the index of the gold passages in the retrieved passages + gold_passage_indices = [] + for passage in gold_passages: + try: + gold_passage_indices.append( + retriever.get_index_from_passage(passage) + ) + except ValueError: + logger.warning( + f"Passage `{passage}` not found in the index. " + "We will skip it, but the results might not reflect the " + "actual performance." + ) + pass + retrieved_indices = [r.document.id for r in retrieved_samples if r] + retrieved_passages = [ + retriever.get_passage_from_index(i) for i in retrieved_indices + ] + retrieved_scores = [r.score for r in retrieved_samples] + # correct predictions are the passages that are in the top-k and are gold + correct_indices = set(gold_passage_indices) & set(retrieved_indices) + # wrong predictions are the passages that are in the top-k and are not gold + wrong_indices = set(retrieved_indices) - set(gold_passage_indices) + # add the predictions to the list + prediction_output = dict( + sample_idx=batch.sample_idx[batch_idx], + gold=gold_passages, + predictions=retrieved_passages, + scores=retrieved_scores, + correct=[ + retriever.get_passage_from_index(i) for i in correct_indices + ], + wrong=[ + retriever.get_passage_from_index(i) for i in wrong_indices + ], + ) + predictions.append(prediction_output) + end = time.time() + logger.info(f"Time to retrieve: {str(end - start)}") + + dataloader_predictions[dataloader_idx] = predictions + + # if pl_module_original_device != pl_module.device: + # pl_module.to(pl_module_original_device) + + # return the predictions + return dataloader_predictions diff --git a/relik/retriever/callbacks/training_callbacks.py b/relik/retriever/callbacks/training_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..955ad3e0eec4202bed1f78a8987ea7effd5ae8cb --- /dev/null +++ b/relik/retriever/callbacks/training_callbacks.py @@ -0,0 +1,213 @@ +import logging +import random +import time +from copy import deepcopy +from pathlib import Path +from typing import List, Optional, Sequence, Set, Union + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from omegaconf import DictConfig +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.retriever.callbacks.prediction_callbacks import ( + GoldenRetrieverPredictionCallback, +) +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.utils import HardNegativesManager + +logger = get_logger(__name__, level=logging.INFO) + + +class NegativeAugmentationCallback(GoldenRetrieverPredictionCallback): + """ + Callback that computes the predictions of a retriever model on a dataset and computes the + negative examples for the training set. + + Args: + k (:obj:`int`, `optional`, defaults to 100): + The number of top-k retrieved passages to + consider for the evaluation. + batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to use for the evaluation. + num_workers (:obj:`int`, `optional`, defaults to 0): + The number of workers to use for the evaluation. + force_reindex (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to force the reindexing of the dataset. + retriever_dir (:obj:`Path`, `optional`): + The path to the retriever directory. If not specified, the retriever will be + initialized from scratch. 
+ stages (:obj:`Set[str]`, `optional`): + The stages to run the callback on. If not specified, the callback will be run on + train, validation and test. + other_callbacks (:obj:`List[DictConfig]`, `optional`): + A list of other callbacks to run on the same stages. + dataset (:obj:`Union[DictConfig, BaseDataset]`, `optional`): + The dataset to use for the evaluation. If not specified, the dataset will be + initialized from scratch. + metrics_to_monitor (:obj:`List[str]`, `optional`): + The metrics to monitor for the evaluation. + threshold (:obj:`float`, `optional`, defaults to 0.8): + The threshold to consider. If the recall score of the retriever is above the + threshold, the negative examples will be added to the training set. + max_negatives (:obj:`int`, `optional`, defaults to 5): + The maximum number of negative examples to add to the training set. + add_with_probability (:obj:`float`, `optional`, defaults to 1.0): + The probability with which to add the negative examples to the training set. + refresh_every_n_epochs (:obj:`int`, `optional`, defaults to 1): + The number of epochs after which to refresh the index. + """ + + def __init__( + self, + k: int = 100, + batch_size: int = 32, + num_workers: int = 0, + force_reindex: bool = False, + retriever_dir: Optional[Path] = None, + stages: Sequence[Union[str, RunningStage]] = None, + other_callbacks: Optional[List[DictConfig]] = None, + dataset: Optional[Union[DictConfig, BaseDataset]] = None, + metrics_to_monitor: List[str] = None, + threshold: float = 0.8, + max_negatives: int = 5, + add_with_probability: float = 1.0, + refresh_every_n_epochs: int = 1, + *args, + **kwargs, + ): + super().__init__( + k=k, + batch_size=batch_size, + num_workers=num_workers, + force_reindex=force_reindex, + retriever_dir=retriever_dir, + stages=stages, + other_callbacks=other_callbacks, + dataset=dataset, + *args, + **kwargs, + ) + if metrics_to_monitor is None: + metrics_to_monitor = ["val_loss"] + self.metrics_to_monitor = metrics_to_monitor + self.threshold = threshold + self.max_negatives = max_negatives + self.add_with_probability = add_with_probability + self.refresh_every_n_epochs = refresh_every_n_epochs + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> dict: + """ + Computes the predictions of a retriever model on a dataset and computes the negative + examples for the training set. + + Args: + trainer (:obj:`pl.Trainer`): + The trainer object. + pl_module (:obj:`pl.LightningModule`): + The lightning module. + + Returns: + A dictionary containing the negative examples. + """ + stage = trainer.state.stage + if stage not in self.stages: + return {} + + if self.metrics_to_monitor not in trainer.logged_metrics: + logger.warning( + f"Metric `{self.metrics_to_monitor}` not found in trainer.logged_metrics. 
" + f"Available metrics: {trainer.logged_metrics.keys()}" + ) + return {} + + if trainer.logged_metrics[self.metrics_to_monitor] < self.threshold: + return {} + + if trainer.current_epoch % self.refresh_every_n_epochs != 0: + return {} + + # if all( + # [ + # trainer.logged_metrics.get(metric) is None + # for metric in self.metrics_to_monitor + # ] + # ): + # raise ValueError( + # f"No metric from {self.metrics_to_monitor} not found in trainer.logged_metrics" + # f"Available metrics: {trainer.logged_metrics.keys()}" + # ) + + # if all( + # [ + # trainer.logged_metrics.get(metric) < self.threshold + # for metric in self.metrics_to_monitor + # if trainer.logged_metrics.get(metric) is not None + # ] + # ): + # return {} + + if trainer.current_epoch % self.refresh_every_n_epochs != 0: + return {} + + logger.info( + f"At least one metric from {self.metrics_to_monitor} is above threshold " + f"{self.threshold}. Computing hard negatives." + ) + + # make a copy of the dataset to avoid modifying the original one + trainer.datamodule.train_dataset.hn_manager = None + dataset_copy = deepcopy(trainer.datamodule.train_dataset) + predictions = super().__call__( + trainer, + pl_module, + datasets=dataset_copy, + dataloaders=DataLoader( + dataset_copy.to_torch_dataset(), + shuffle=False, + batch_size=None, + num_workers=self.num_workers, + pin_memory=True, + collate_fn=lambda x: x, + ), + *args, + **kwargs, + ) + logger.info(f"Computing hard negatives for epoch {trainer.current_epoch}") + # predictions is a dict with the dataloader index as key and the predictions as value + # since we only have one dataloader, we can get the predictions directly + predictions = list(predictions.values())[0] + # store the predictions in a dictionary for faster access based on the sample index + hard_negatives_list = {} + for prediction in tqdm(predictions, desc="Collecting hard negatives"): + if random.random() < 1 - self.add_with_probability: + continue + top_k_passages = prediction["predictions"] + gold_passages = prediction["gold"] + # get the ids of the max_negatives wrong passages with the highest similarity + wrong_passages = [ + passage_id + for passage_id in top_k_passages + if passage_id not in gold_passages + ][: self.max_negatives] + hard_negatives_list[prediction["sample_idx"]] = wrong_passages + + trainer.datamodule.train_dataset.hn_manager = HardNegativesManager( + tokenizer=trainer.datamodule.train_dataset.tokenizer, + max_length=trainer.datamodule.train_dataset.max_passage_length, + data=hard_negatives_list, + ) + + # normalize predictions as in the original GoldenRetrieverPredictionCallback + predictions = {0: predictions} + return predictions diff --git a/relik/retriever/callbacks/utils_callbacks.py b/relik/retriever/callbacks/utils_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef85033ebc8aa44250cf251fa99ed5736d803fd --- /dev/null +++ b/relik/retriever/callbacks/utils_callbacks.py @@ -0,0 +1,292 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage + +from relik.common.log import get_logger +from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback +from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel + +logger = get_logger(__name__, level=logging.INFO) + + +class SavePredictionsCallback(NLPTemplateCallback): + def __init__( + self, + saving_dir: str | os.PathLike | None 
= None, + verbose: bool = False, + *args, + **kwargs, + ): + super().__init__() + self.saving_dir = saving_dir + self.verbose = verbose + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + callback: PredictionCallback, + *args, + **kwargs, + ) -> dict: + # write the predictions to a file inside the experiment folder + if self.saving_dir is None and trainer.logger is None: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n" + "Skipping saving predictions." + ) + return {} + datasets = callback.datasets + for dataloader_idx, dataloader_predictions in predictions.items(): + # save to file + if self.saving_dir is not None: + prediction_folder = Path(self.saving_dir) + else: + try: + prediction_folder = ( + Path(trainer.logger.experiment.dir) / "predictions" + ) + except Exception: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n" + "Skipping saving predictions." + ) + return {} + prediction_folder.mkdir(exist_ok=True) + predictions_path = ( + prediction_folder + / f"{datasets[dataloader_idx].name}_{dataloader_idx}.json" + ) + if self.verbose: + logger.info(f"Saving predictions to {predictions_path}") + with open(predictions_path, "w") as f: + for prediction in dataloader_predictions: + for k, v in prediction.items(): + if isinstance(v, set): + prediction[k] = list(v) + f.write(json.dumps(prediction) + "\n") + + +class ResetModelCallback(pl.Callback): + def __init__( + self, + question_encoder: str, + passage_encoder: str | None = None, + verbose: bool = True, + ) -> None: + super().__init__() + self.question_encoder = question_encoder + self.passage_encoder = passage_encoder + self.verbose = verbose + + def on_train_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs + ) -> None: + if trainer.current_epoch == 0: + if self.verbose: + logger.info("Current epoch is 0, skipping resetting model") + return + + if self.verbose: + logger.info("Resetting model, optimizer and lr scheduler") + # reload model from scratch + previous_device = pl_module.device + trainer.model.model.question_encoder = GoldenRetrieverModel.from_pretrained( + self.question_encoder + ) + trainer.model.model.question_encoder.to(previous_device) + if self.passage_encoder is not None: + trainer.model.model.passage_encoder = GoldenRetrieverModel.from_pretrained( + self.passage_encoder + ) + trainer.model.model.passage_encoder.to(previous_device) + + trainer.strategy.setup_optimizers(trainer) + + +class FreeUpIndexerVRAMCallback(pl.Callback): + def __call__( + self, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> Any: + logger.info("Freeing up GPU memory") + + # remove the index from the GPU memory + # remove the embeddings from the GPU memory first + if pl_module.model.document_index is not None: + if pl_module.model.document_index.embeddings is not None: + try: + pl_module.model.document_index.embeddings.cpu() + except Exception: + logger.warning( + "Could not move embeddings to CPU. Skipping freeing up VRAM." 
+ ) + pass + pl_module.model.document_index.embeddings = None + + import gc + + gc.collect() + torch.cuda.empty_cache() + + def on_train_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs + ) -> None: + return self(pl_module) + + def on_test_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs + ) -> None: + return self(pl_module) + + +class ShuffleTrainDatasetCallback(pl.Callback): + def __init__(self, seed: int = 42, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + self.previous_epoch = -1 + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + if trainer.current_epoch != self.previous_epoch: + logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}") + + # logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}") + if trainer.current_epoch != self.previous_epoch: + trainer.datamodule.train_dataset.shuffle_data( + seed=self.seed + trainer.current_epoch + 1 + ) + self.previous_epoch = trainer.current_epoch + + +class PrefetchTrainDatasetCallback(pl.Callback): + def __init__(self, verbose: bool = True) -> None: + super().__init__() + self.verbose = verbose + # self.previous_epoch = -1 + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if trainer.datamodule.train_dataset.prefetch_batches: + if self.verbose: + # if trainer.current_epoch != self.previous_epoch: + logger.info( + f"Prefetching train dataset at epoch {trainer.current_epoch}" + ) + # if trainer.current_epoch != self.previous_epoch: + trainer.datamodule.train_dataset.prefetch() + self.previous_epoch = trainer.current_epoch + + +class SubsampleTrainDatasetCallback(pl.Callback): + def __init__(self, seed: int = 43, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + logger.info(f"Subsampling train dataset at epoch {trainer.current_epoch}") + trainer.datamodule.train_dataset.random_subsample( + seed=self.seed + trainer.current_epoch + 1 + ) + + +class SaveRetrieverCallback(pl.Callback): + def __init__( + self, + saving_dir: str | os.PathLike | None = None, + verbose: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.saving_dir = saving_dir + self.verbose = verbose + self.free_up_indexer_callback = FreeUpIndexerVRAMCallback() + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ): + if self.saving_dir is None and trainer.logger is None: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n" + "Skipping saving retriever." + ) + return + if self.saving_dir is not None: + retriever_folder = Path(self.saving_dir) + else: + try: + retriever_folder = Path(trainer.logger.experiment.dir) / "retriever" + except Exception: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the " + "retriever.\nSkipping saving retriever." 
+ ) + return + retriever_folder.mkdir(exist_ok=True, parents=True) + if self.verbose: + logger.info(f"Saving retriever to {retriever_folder}") + pl_module.model.save_pretrained(retriever_folder) + + def on_save_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: Dict[str, Any], + ): + self(trainer, pl_module) + # self.free_up_indexer_callback(pl_module) + + +class SampleNegativesDatasetCallback(pl.Callback): + def __init__(self, seed: int = 42, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + f"Sampling negatives for train dataset at epoch {trainer.current_epoch}" + trainer.datamodule.train_dataset.sample_dataset_negatives( + seed=self.seed + trainer.current_epoch + ) + + +class SubsampleDataCallback(pl.Callback): + def __init__(self, seed: int = 42, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + + def on_validation_epoch_start(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + f"Subsampling data for train dataset at epoch {trainer.current_epoch}" + if trainer.state.stage == RunningStage.SANITY_CHECKING: + return + trainer.datamodule.train_dataset.subsample_data( + seed=self.seed + trainer.current_epoch + ) + + def on_fit_start(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + f"Subsampling data for train dataset at epoch {trainer.current_epoch}" + trainer.datamodule.train_dataset.subsample_data( + seed=self.seed + trainer.current_epoch + ) diff --git a/relik/retriever/common/__init__.py b/relik/retriever/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/common/__pycache__/__init__.cpython-310.pyc b/relik/retriever/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d1d9646ec83254feeb7845a1018f9ad94f0f277 Binary files /dev/null and b/relik/retriever/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/retriever/common/__pycache__/model_inputs.cpython-310.pyc b/relik/retriever/common/__pycache__/model_inputs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11e718ad95476e5784ea560cd14007f11224ad67 Binary files /dev/null and b/relik/retriever/common/__pycache__/model_inputs.cpython-310.pyc differ diff --git a/relik/retriever/common/model_inputs.py b/relik/retriever/common/model_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..f03f9124cee0d2c9c3b6bbe217f9bf857eca7660 --- /dev/null +++ b/relik/retriever/common/model_inputs.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections import UserDict +from typing import Any, Union + +import torch +from lightning.fabric.utilities import move_data_to_device + +from relik.common.log import get_logger + +logger = get_logger(__name__) + + +class ModelInputs(UserDict): + """Model input dictionary wrapper.""" + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError(f"`ModelInputs` has no attribute `{item}`") + + def __getitem__(self, item: str) -> Any: + return self.data[item] + + def __getstate__(self): + return {"data": self.data} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + def keys(self): + """A set-like object providing a view 
on D's keys.""" + return self.data.keys() + + def values(self): + """An object providing a view on D's values.""" + return self.data.values() + + def items(self): + """A set-like object providing a view on D's items.""" + return self.data.items() + + def to(self, device: Union[str, torch.device]) -> ModelInputs: + """ + Send all tensors values to device. + Args: + device (`str` or `torch.device`): The device to put the tensors on. + Returns: + :class:`tokenizers.ModelInputs`: The same instance of :class:`~tokenizers.ModelInputs` + after modification. + """ + self.data = move_data_to_device(self.data, device) + return self diff --git a/relik/retriever/common/sampler.py b/relik/retriever/common/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..024c57b23da6db71dd76929226b005f75b9e98f5 --- /dev/null +++ b/relik/retriever/common/sampler.py @@ -0,0 +1,108 @@ +import math + +from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler + + +def identity(x): + return x + + +class SortedSampler(Sampler): + """ + Samples elements sequentially, always in the same order. + + Args: + data (`obj`: `Iterable`): + Iterable data. + sort_key (`obj`: `Callable`): + Specifies a function of one argument that is used to + extract a numerical comparison key from each list element. + + Example: + >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + """ + + def __init__(self, data, sort_key=identity): + super().__init__(data) + self.data = data + self.sort_key = sort_key + zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] + zip_ = sorted(zip_, key=lambda r: r[1]) + self.sorted_indexes = [item[0] for item in zip_] + + def __iter__(self): + return iter(self.sorted_indexes) + + def __len__(self): + return len(self.data) + + +class BucketBatchSampler(BatchSampler): + """ + `BucketBatchSampler` toggles between `sampler` batches and sorted batches. + Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between + random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice + versa. + Background: + ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like + ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar + size length to reduce the padding required for each batch while maintaining some noise + through bucketing. + **AllenNLP Implementation:** + https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py + **torchtext Implementation:** + https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225 + + Args: + sampler (`obj`: `torch.data.utils.sampler.Sampler): + batch_size (`int`): + Size of mini-batch. + drop_last (`bool`, optional, defaults to `False`): + If `True` the sampler will drop the last batch if its size would be less than `batch_size`. + sort_key (`obj`: `Callable`, optional, defaults to `identity`): + Callable to specify a comparison key for sorting. + bucket_size_multiplier (`int`, optional, defaults to `100`): + Buckets are of size `batch_size * bucket_size_multiplier`. 
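+
+    A typical way to use it (an illustrative sketch; ``dataset`` and its
+    ``"input_ids"`` field are assumptions, not taken from this file) is as the
+    ``batch_sampler`` of a :obj:`torch.utils.data.DataLoader`, with a ``sort_key``
+    that looks up the length of the sample behind each index:
+
+        >>> from torch.utils.data import DataLoader, RandomSampler
+        >>> loader = DataLoader(
+        ...     dataset,
+        ...     batch_sampler=BucketBatchSampler(
+        ...         RandomSampler(dataset),
+        ...         batch_size=32,
+        ...         sort_key=lambda i: len(dataset[i]["input_ids"]),
+        ...     ),
+        ... )
+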
+ Example: + >>> from torchnlp.random import set_seed + >>> set_seed(123) + >>> + >>> from torch.utils.data.sampler import SequentialSampler + >>> sampler = SequentialSampler(list(range(10))) + >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False)) + [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]] + >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + """ + + def __init__( + self, + sampler, + batch_size, + drop_last: bool = False, + sort_key=identity, + bucket_size_multiplier=100, + ): + super().__init__(sampler, batch_size, drop_last) + self.sort_key = sort_key + _bucket_size = batch_size * bucket_size_multiplier + if hasattr(sampler, "__len__"): + _bucket_size = min(_bucket_size, len(sampler)) + self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) + + def __iter__(self): + for bucket in self.bucket_sampler: + sorted_sampler = SortedSampler(bucket, self.sort_key) + for batch in SubsetRandomSampler( + list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last)) + ): + yield [bucket[i] for i in batch] + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return math.ceil(len(self.sampler) / self.batch_size) diff --git a/relik/retriever/conf/data/aida_dataset.yaml b/relik/retriever/conf/data/aida_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22fcc6458a2dc757b569baef37846751dd3c1c7a --- /dev/null +++ b/relik/retriever/conf/data/aida_dataset.yaml @@ -0,0 +1,47 @@ +shared_params: + passages_path: null + max_passage_length: 64 + passage_batch_size: 64 + question_batch_size: 64 + use_topics: False + +datamodule: + _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule + datasets: + train: + _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset + name: "train" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + subsample_strategy: null + subsample_portion: 0.1 + shuffle: True + use_topics: ${data.shared_params.use_topics} + + val: + - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset + name: "val" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + use_topics: ${data.shared_params.use_topics} + + test: + - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset + name: "test" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + use_topics: ${data.shared_params.use_topics} + + num_workers: + train: 4 + val: 4 + test: 4 diff --git a/relik/retriever/conf/data/dataset.yaml b/relik/retriever/conf/data/dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6040616d1f92182c2d002c065a7b89dc240f96b3 --- /dev/null +++ b/relik/retriever/conf/data/dataset.yaml @@ -0,0 +1,43 @@ +shared_params: + passages_path: null + max_passage_length: 64 + passage_batch_size: 64 + question_batch_size: 64 + +datamodule: + _target_: 
relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule + datasets: + train: + _target_: relik.retriever.data.datasets.InBatchNegativesDataset + name: "train" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + subsample_strategy: null + subsample_portion: 0.1 + shuffle: True + + val: + - _target_: relik.retriever.data.datasets.InBatchNegativesDataset + name: "val" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + + test: + - _target_: relik.retriever.data.datasets.InBatchNegativesDataset + name: "test" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + + num_workers: + train: 0 + val: 0 + test: 0 diff --git a/relik/retriever/conf/finetune_iterable_in_batch.yaml b/relik/retriever/conf/finetune_iterable_in_batch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8bec07ecfca4cc0a8d99d4971a426f2397e9dc62 --- /dev/null +++ b/relik/retriever/conf/finetune_iterable_in_batch.yaml @@ -0,0 +1,117 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: ${model.language_model} # used to name the model in wandb +project_name: relik-retriever # used to name the project in wandb + +defaults: + - _self_ + - model: golden_retriever + - index: inmemory + - loss: nce_loss + - optimizer: radamw + - scheduler: linear_scheduler + - data: dataset_v2 # iterable_in_batch_negatives #dataset_v2 + - logging: wandb_logging + - override hydra/job_logging: colorlog + - override hydra/hydra_logging: colorlog + +train: + # reproducibility + seed: 42 + set_determinism_the_old_way: False + # torch parameters + float32_matmul_precision: "medium" + # if true, only test the model + only_test: False + # if provided, initialize the model with the weights from the checkpoint + pretrain_ckpt_path: null + # if provided, start training from the checkpoint + checkpoint_path: null + + # task specific parameter + top_k: 100 + + # pl_trainer + pl_trainer: + _target_: lightning.Trainer + accelerator: gpu + devices: 1 + num_nodes: 1 + strategy: auto + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps + check_val_every_n_epoch: 1 + max_epochs: 0 + max_steps: 25_000 + deterministic: True + fast_dev_run: False + precision: 16 + reload_dataloaders_every_n_epochs: 1 + + early_stopping_callback: + # null + _target_: lightning.callbacks.EarlyStopping + monitor: validate_recall@${train.top_k} + mode: max + patience: 3 + + model_checkpoint_callback: + _target_: lightning.callbacks.ModelCheckpoint + monitor: validate_recall@${train.top_k} + mode: max + verbose: True + save_top_k: 1 + save_last: False + filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" + auto_insert_metric_name: False + + callbacks: + prediction_callback: + _target_: 
relik.retriever.callbacks.training_callbacks.GoldenRetrieverPredictionCallback + k: ${train.top_k} + batch_size: 64 + precision: 16 + index_precision: 16 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: 50 + verbose: True + prog_bar: False + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback + + hard_negatives_callback: + _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback + k: ${train.top_k} + batch_size: 64 + precision: 16 + index_precision: 16 + stages: [validate] #[validate, sanity_check] + metrics_to_monitor: + validate_recall@${train.top_k} + # - sanity_check_recall@${train.top_k} + threshold: 0.0 + max_negatives: 20 + add_with_probability: 1.0 + refresh_every_n_epochs: 1 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + prefix: "train" + + utils_callbacks: + - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback + - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback + # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback + # question_encoder: ${model.pl_module.model.question_encoder} + # passage_encoder: ${model.pl_module.model.passage_encoder} diff --git a/relik/retriever/conf/index/inmemory.yaml b/relik/retriever/conf/index/inmemory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a2247aa52ac0a6857e2f4d036162227c505acb6 --- /dev/null +++ b/relik/retriever/conf/index/inmemory.yaml @@ -0,0 +1,6 @@ +_target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex.from_file +documents: ${data.shared_params.passages_path} +metadata_fields: ["definition"] +separator: "" +device: cuda +precision: 16 diff --git a/relik/retriever/conf/logging/wandb_logging.yaml b/relik/retriever/conf/logging/wandb_logging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1908d7e09789ca0b8e4973ec7f3ca5d47d460af3 --- /dev/null +++ b/relik/retriever/conf/logging/wandb_logging.yaml @@ -0,0 +1,16 @@ +# don't forget loggers.login() for the first usage. 
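+# (that is, run `wandb login` in a shell or `wandb.login()` in Python once per machine;
+# the `loggers.login()` mentioned above presumably refers to this W&B login step)
+# Illustrative Hydra command-line overrides for this group (the entry-point script is an
+# assumption, not defined in this file):
+#   python <train_script>.py logging.log=False logging.wandb_arg.entity=<your-entity>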
+ +log: True # set to False to avoid the logging + +wandb_arg: + _target_: lightning.loggers.WandbLogger + name: ${model_name} + project: ${project_name} + save_dir: ./ + log_model: True + mode: "online" + entity: null + +watch: + log: "all" + log_freq: 100 diff --git a/relik/retriever/conf/loss/nce_loss.yaml b/relik/retriever/conf/loss/nce_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe9246b88027fd0d499b9bc3b4beaa7937b23586 --- /dev/null +++ b/relik/retriever/conf/loss/nce_loss.yaml @@ -0,0 +1 @@ +_target_: relik.retriever.pythorch_modules.losses.MultiLabelNCELoss diff --git a/relik/retriever/conf/loss/nll_loss.yaml b/relik/retriever/conf/loss/nll_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e0a5010025a4a6e9da382e5408af4a58cda6185 --- /dev/null +++ b/relik/retriever/conf/loss/nll_loss.yaml @@ -0,0 +1 @@ +_target_: torch.nn.NLLLoss diff --git a/relik/retriever/conf/optimizer/adamw.yaml b/relik/retriever/conf/optimizer/adamw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff0f84e15ebd6c60e3e6e411c30e88fb910fdb29 --- /dev/null +++ b/relik/retriever/conf/optimizer/adamw.yaml @@ -0,0 +1,4 @@ +_target_: torch.optim.AdamW +lr: 1e-5 +weight_decay: 0.01 +fused: False diff --git a/relik/retriever/conf/optimizer/radam.yaml b/relik/retriever/conf/optimizer/radam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5d2a4ecf468327bda98fd205bf483b6828cf653 --- /dev/null +++ b/relik/retriever/conf/optimizer/radam.yaml @@ -0,0 +1,3 @@ +_target_: torch.optim.RAdam +lr: 1e-5 +weight_decay: 0 diff --git a/relik/retriever/conf/optimizer/radamw.yaml b/relik/retriever/conf/optimizer/radamw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f1fc8c4696bf793180366a64baf107829cf7752 --- /dev/null +++ b/relik/retriever/conf/optimizer/radamw.yaml @@ -0,0 +1,3 @@ +_target_: relik.retriever.pytorch_modules.optim.RAdamW +lr: 1e-5 +weight_decay: 0.01 diff --git a/relik/retriever/conf/pretrain_iterable_in_batch.yaml b/relik/retriever/conf/pretrain_iterable_in_batch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8331042fbdad34820bc5d48c0afe88c238d812f --- /dev/null +++ b/relik/retriever/conf/pretrain_iterable_in_batch.yaml @@ -0,0 +1,114 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: ${model.language_model} # used to name the model in wandb +project_name: relik-retriever # used to name the project in wandb + +defaults: + - _self_ + - model: golden_retriever + - index: inmemory + - loss: nce_loss + - optimizer: radamw + - scheduler: linear_scheduler + - data: dataset_v2 + - logging: wandb_logging + - override hydra/job_logging: colorlog + - override hydra/hydra_logging: colorlog + +train: + # reproducibility + seed: 42 + set_determinism_the_old_way: False + # torch parameters + float32_matmul_precision: "medium" + # if true, only test the model + only_test: False + # if provided, initialize the model with the weights from the checkpoint + pretrain_ckpt_path: null + # if provided, start training from the checkpoint + checkpoint_path: null + + # task specific parameter + top_k: 100 + + # pl_trainer + pl_trainer: + _target_: lightning.Trainer + accelerator: gpu + devices: 1 + num_nodes: 1 + strategy: auto + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + val_check_interval: 1.0 # you can specify an int "n" here => 
validation every "n" steps + check_val_every_n_epoch: 1 + max_epochs: 0 + max_steps: 220_000 + deterministic: True + fast_dev_run: False + precision: 16 + reload_dataloaders_every_n_epochs: 1 + + early_stopping_callback: + null + # _target_: lightning.callbacks.EarlyStopping + # monitor: validate_recall@${train.top_k} + # mode: max + # patience: 15 + + model_checkpoint_callback: + _target_: lightning.callbacks.ModelCheckpoint + monitor: validate_recall@${train.top_k} + mode: max + verbose: True + save_top_k: 1 + save_last: True + filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" + auto_insert_metric_name: False + + callbacks: + prediction_callback: + _target_: relik.retriever.callbacks.training_callbacks.GoldenRetrieverPredictionCallback + k: ${train.top_k} + batch_size: 128 + precision: 16 + index_precision: 16 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: 50 + verbose: True + prog_bar: False + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback + + hard_negatives_callback: + _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback + k: ${train.top_k} + batch_size: 128 + precision: 16 + index_precision: 16 + stages: [validate] #[validate, sanity_check] + metrics_to_monitor: + validate_recall@${train.top_k} + # - sanity_check_recall@${train.top_k} + threshold: 0.0 + max_negatives: 15 + add_with_probability: 0.2 + refresh_every_n_epochs: 1 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + prefix: "train" + + utils_callbacks: + - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback + - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback diff --git a/relik/retriever/conf/scheduler/linear_scheduler.yaml b/relik/retriever/conf/scheduler/linear_scheduler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1896bff5d01ee2543639e4e379674d60682a0f6 --- /dev/null +++ b/relik/retriever/conf/scheduler/linear_scheduler.yaml @@ -0,0 +1,3 @@ +_target_: transformers.get_linear_schedule_with_warmup +num_warmup_steps: 0 +num_training_steps: ${train.pl_trainer.max_steps} diff --git a/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml new file mode 100644 index 0000000000000000000000000000000000000000..417857489486469032b7e8b19d509a1e45da043c --- /dev/null +++ b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml @@ -0,0 +1,3 @@ +_target_: transformers.get_linear_schedule_with_warmup +num_warmup_steps: 5_000 +num_training_steps: ${train.pl_trainer.max_steps} diff --git a/relik/retriever/conf/scheduler/none.yaml b/relik/retriever/conf/scheduler/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec747fa47ddb81e9bf2d282011ed32aa4c59f932 --- /dev/null +++ b/relik/retriever/conf/scheduler/none.yaml @@ -0,0 +1 @@ +null \ No newline at end of file diff --git a/relik/retriever/data/__init__.py b/relik/retriever/data/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/data/__pycache__/__init__.cpython-310.pyc b/relik/retriever/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..431b36c44365c6e490871d8b1c55b24607d2034f Binary files /dev/null and b/relik/retriever/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/retriever/data/__pycache__/labels.cpython-310.pyc b/relik/retriever/data/__pycache__/labels.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbdcf329939e734ed868193ba21a3e671e3d1ce8 Binary files /dev/null and b/relik/retriever/data/__pycache__/labels.cpython-310.pyc differ diff --git a/relik/retriever/data/base/__init__.py b/relik/retriever/data/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/data/base/__pycache__/__init__.cpython-310.pyc b/relik/retriever/data/base/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..030b84b9ef2ba23f6e68987c23f0a9ad229238ce Binary files /dev/null and b/relik/retriever/data/base/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/retriever/data/base/__pycache__/datasets.cpython-310.pyc b/relik/retriever/data/base/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a443536338286bd73f1f7a64690530fc22036d4 Binary files /dev/null and b/relik/retriever/data/base/__pycache__/datasets.cpython-310.pyc differ diff --git a/relik/retriever/data/base/datasets.py b/relik/retriever/data/base/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..9a83731a22b8bfe968804776e3346b5671bc7833 --- /dev/null +++ b/relik/retriever/data/base/datasets.py @@ -0,0 +1,89 @@ +import os +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import torch +from torch.utils.data import Dataset, IterableDataset + +from relik.common.log import get_logger + +logger = get_logger(__name__) + + +class BaseDataset(Dataset): + def __init__( + self, + name: str, + path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None, + data: Any = None, + **kwargs, + ): + super().__init__() + self.name = name + if path is None and data is None: + raise ValueError("Either `path` or `data` must be provided") + self.path = path + self.project_folder = Path(__file__).parent.parent.parent + self.data = data + + def __len__(self) -> int: + return len(self.data) + + def __getitem__( + self, index + ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + return self.data[index] + + def __repr__(self) -> str: + return f"Dataset({self.name=}, {self.path=})" + + def load( + self, + paths: Union[str, os.PathLike, List[str], List[os.PathLike]], + *args, + **kwargs, + ) -> Any: + # load data from single or multiple paths in one single dataset + raise NotImplementedError + + @staticmethod + def collate_fn(batch: Any, *args, **kwargs) -> Any: + raise NotImplementedError + + +class IterableBaseDataset(IterableDataset): + def __init__( + self, + name: str, + path: Optional[Union[str, Path, List[str], List[Path]]] = None, + data: Any = None, + *args, + **kwargs, + ): + super().__init__() + self.name = name + if path is None and data is None: + raise ValueError("Either `path` or `data` must be provided") + self.path = path + 
self.project_folder = Path(__file__).parent.parent.parent + self.data = data + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + for sample in self.data: + yield sample + + def __repr__(self) -> str: + return f"Dataset({self.name=}, {self.path=})" + + def load( + self, + paths: Union[str, os.PathLike, List[str], List[os.PathLike]], + *args, + **kwargs, + ) -> Any: + # load data from single or multiple paths in one single dataset + raise NotImplementedError + + @staticmethod + def collate_fn(batch: Any, *args, **kwargs) -> Any: + raise NotImplementedError diff --git a/relik/retriever/data/datasets.py b/relik/retriever/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb6315217e7a30617acef7180645245ed6e3331 --- /dev/null +++ b/relik/retriever/data/datasets.py @@ -0,0 +1,720 @@ +import os +from copy import deepcopy +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import datasets +import psutil +import torch +import transformers as tr +from datasets import load_dataset +from torch.utils.data import Dataset +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset +from relik.retriever.data.utils import HardNegativesManager + +logger = get_logger(__name__) + + +class SubsampleStrategyEnum(Enum): + NONE = "none" + RANDOM = "random" + IN_ORDER = "in_order" + + +class GoldenRetrieverDataset: + def __init__( + self, + name: str, + path: Union[str, os.PathLike, List[str], List[os.PathLike]] = None, + data: Any = None, + tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + # passages: Union[str, os.PathLike, List[str]] = None, + passage_batch_size: int = 32, + question_batch_size: int = 32, + max_positives: int = -1, + max_negatives: int = 0, + max_hard_negatives: int = 0, + max_question_length: int = 256, + max_passage_length: int = 64, + shuffle: bool = False, + subsample_strategy: Optional[str] = SubsampleStrategyEnum.NONE, + subsample_portion: float = 0.1, + num_proc: Optional[int] = None, + load_from_cache_file: bool = True, + keep_in_memory: bool = False, + prefetch: bool = True, + load_fn_kwargs: Optional[Dict[str, Any]] = None, + batch_fn_kwargs: Optional[Dict[str, Any]] = None, + collate_fn_kwargs: Optional[Dict[str, Any]] = None, + ): + if path is None and data is None: + raise ValueError("Either `path` or `data` must be provided") + + if tokenizer is None: + raise ValueError("A tokenizer must be provided") + + # dataset parameters + self.name = name + self.path = Path(path) or path + if path is not None and not isinstance(self.path, Sequence): + self.path = [self.path] + # self.project_folder = Path(__file__).parent.parent.parent + self.data = data + + # hyper-parameters + self.passage_batch_size = passage_batch_size + self.question_batch_size = question_batch_size + self.max_positives = max_positives + self.max_negatives = max_negatives + self.max_hard_negatives = max_hard_negatives + self.max_question_length = max_question_length + self.max_passage_length = max_passage_length + self.shuffle = shuffle + self.num_proc = num_proc + self.load_from_cache_file = load_from_cache_file + self.keep_in_memory = keep_in_memory + self.prefetch = prefetch + + self.tokenizer = tokenizer + if isinstance(self.tokenizer, str): + self.tokenizer = 
tr.AutoTokenizer.from_pretrained(self.tokenizer) + + self.padding_ops = { + "input_ids": partial( + self.pad_sequence, + value=self.tokenizer.pad_token_id, + ), + "attention_mask": partial(self.pad_sequence, value=0), + "token_type_ids": partial( + self.pad_sequence, + value=self.tokenizer.pad_token_type_id, + ), + } + + # check if subsample strategy is valid + if subsample_strategy is not None: + # subsample_strategy can be a string or a SubsampleStrategy + if isinstance(subsample_strategy, str): + try: + subsample_strategy = SubsampleStrategyEnum(subsample_strategy) + except ValueError: + raise ValueError( + f"Subsample strategy `{subsample_strategy}` is not valid. " + f"Valid strategies are: {SubsampleStrategyEnum.__members__}" + ) + if not isinstance(subsample_strategy, SubsampleStrategyEnum): + raise ValueError( + f"Subsample strategy `{subsample_strategy}` is not valid. " + f"Valid strategies are: {SubsampleStrategyEnum.__members__}" + ) + self.subsample_strategy = subsample_strategy + self.subsample_portion = subsample_portion + + # load the dataset + if data is None: + self.data: Dataset = self.load( + self.path, + tokenizer=self.tokenizer, + load_from_cache_file=load_from_cache_file, + load_fn_kwargs=load_fn_kwargs, + num_proc=num_proc, + shuffle=shuffle, + keep_in_memory=keep_in_memory, + max_positives=max_positives, + max_negatives=max_negatives, + max_hard_negatives=max_hard_negatives, + max_question_length=max_question_length, + max_passage_length=max_passage_length, + ) + else: + self.data: Dataset = data + + self.hn_manager: Optional[HardNegativesManager] = None + + # keep track of how many times the dataset has been iterated over + self.number_of_complete_iterations = 0 + + def __repr__(self) -> str: + return f"GoldenRetrieverDataset({self.name=}, {self.path=})" + + def __len__(self) -> int: + raise NotImplementedError + + def __getitem__( + self, index + ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + raise NotImplementedError + + def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset: + raise NotImplementedError + + def load( + self, + paths: Union[str, os.PathLike, List[str], List[os.PathLike]], + tokenizer: tr.PreTrainedTokenizer = None, + load_fn_kwargs: Dict = None, + load_from_cache_file: bool = True, + num_proc: Optional[int] = None, + shuffle: bool = False, + keep_in_memory: bool = True, + max_positives: int = -1, + max_negatives: int = -1, + max_hard_negatives: int = -1, + max_passages: int = -1, + max_question_length: int = 256, + max_passage_length: int = 64, + *args, + **kwargs, + ) -> Any: + # if isinstance(paths, Sequence): + # paths = [self.project_folder / path for path in paths] + # else: + # paths = [self.project_folder / paths] + + # read the data and put it in a placeholder list + for path in paths: + if not path.exists(): + raise ValueError(f"{path} does not exist") + + fn_kwargs = dict( + tokenizer=tokenizer, + max_positives=max_positives, + max_negatives=max_negatives, + max_hard_negatives=max_hard_negatives, + max_passages=max_passages, + max_question_length=max_question_length, + max_passage_length=max_passage_length, + ) + if load_fn_kwargs is not None: + fn_kwargs.update(load_fn_kwargs) + + if num_proc is None: + num_proc = psutil.cpu_count(logical=False) + + # The data is a list of dictionaries, each dictionary is a sample + # Each sample has the following keys: + # - "question": the question + # - "answers": a list of answers + # - "positive_ctxs": a list of positive passages + # - "negative_ctxs": a 
list of negative passages + # - "hard_negative_ctxs": a list of hard negative passages + # use the huggingface dataset library to load the data, by default it will load the + # data in a dict with the key being "train". + logger.info(f"Loading data for dataset {self.name}") + data = load_dataset( + "json", + data_files=[str(p) for p in paths], # datasets needs str paths and not Path + split="train", + streaming=False, # TODO maybe we can make streaming work + keep_in_memory=keep_in_memory, + ) + # add id if not present + if isinstance(data, datasets.Dataset): + data = data.add_column("sample_idx", range(len(data))) + else: + data = data.map( + lambda x, idx: x.update({"sample_idx": idx}), with_indices=True + ) + + map_kwargs = dict( + function=self.load_fn, + fn_kwargs=fn_kwargs, + ) + if isinstance(data, datasets.Dataset): + map_kwargs.update( + dict( + load_from_cache_file=load_from_cache_file, + keep_in_memory=keep_in_memory, + num_proc=num_proc, + desc="Loading data", + ) + ) + # preprocess the data + data = data.map(**map_kwargs) + + # shuffle the data + if shuffle: + data.shuffle(seed=42) + + return data + + @staticmethod + def create_batches( + data: Dataset, + batch_fn: Callable, + batch_fn_kwargs: Optional[Dict[str, Any]] = None, + prefetch: bool = True, + *args, + **kwargs, + ) -> Union[Iterable, List]: + if not prefetch: + # if we are streaming, we don't need to create batches right now + # we will create them on the fly when we need them + batched_data = ( + batch + for batch in batch_fn( + data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {}) + ) + ) + else: + batched_data = [ + batch + for batch in tqdm( + batch_fn( + data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {}) + ), + desc="Creating batches", + ) + ] + return batched_data + + @staticmethod + def collate_batches( + batched_data: Union[Iterable, List], + collate_fn: Callable, + collate_fn_kwargs: Optional[Dict[str, Any]] = None, + prefetch: bool = True, + *args, + **kwargs, + ) -> Union[Iterable, List]: + if not prefetch: + collated_data = ( + collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {})) + for batch in batched_data + ) + else: + collated_data = [ + collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {})) + for batch in tqdm(batched_data, desc="Collating batches") + ] + return collated_data + + @staticmethod + def load_fn(sample: Dict, *args, **kwargs) -> Dict: + raise NotImplementedError + + @staticmethod + def batch_fn(data: Dataset, *args, **kwargs) -> Any: + raise NotImplementedError + + @staticmethod + def collate_fn(batch: Any, *args, **kwargs) -> Any: + raise NotImplementedError + + @staticmethod + def pad_sequence( + sequence: Union[List, torch.Tensor], + length: int, + value: Any = None, + pad_to_left: bool = False, + ) -> Union[List, torch.Tensor]: + """ + Pad the input to the specified length with the given value. + + Args: + sequence (:obj:`List`, :obj:`torch.Tensor`): + Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`. + length (:obj:`int`, :obj:`str`, optional, defaults to :obj:`subtoken`): + Length after pad. + value (:obj:`Any`, optional): + Value to use as padding. + pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, pads to the left, right otherwise. + + Returns: + :obj:`List`, :obj:`torch.Tensor`: The padded sequence. 
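+
+        Example (a worked case, consistent with the padding logic below):
+            >>> GoldenRetrieverDataset.pad_sequence([1, 2, 3], length=5, value=0)
+            [1, 2, 3, 0, 0]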
+ + """ + padding = [value] * abs(length - len(sequence)) + if isinstance(sequence, torch.Tensor): + if len(sequence.shape) > 1: + raise ValueError( + f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`" + ) + padding = torch.as_tensor(padding) + if pad_to_left: + if isinstance(sequence, torch.Tensor): + return torch.cat((padding, sequence), -1) + return padding + sequence + if isinstance(sequence, torch.Tensor): + return torch.cat((sequence, padding), -1) + return sequence + padding + + def convert_to_batch( + self, samples: Any, *args, **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Convert the list of samples to a batch. + + Args: + samples (:obj:`List`): + List of samples to convert to a batch. + + Returns: + :obj:`Dict[str, torch.Tensor]`: The batch. + """ + # invert questions from list of dict to dict of list + samples = {k: [d[k] for d in samples] for k in samples[0]} + # get max length of questions + max_len = max(len(x) for x in samples["input_ids"]) + # pad the questions + for key in samples: + if key in self.padding_ops: + samples[key] = torch.as_tensor( + [self.padding_ops[key](b, max_len) for b in samples[key]] + ) + return samples + + def shuffle_data(self, seed: int = 42): + self.data = self.data.shuffle(seed=seed) + + +class InBatchNegativesDataset(GoldenRetrieverDataset): + def __len__(self) -> int: + if isinstance(self.data, datasets.Dataset): + return len(self.data) + + def __getitem__( + self, index + ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + return self.data[index] + + def to_torch_dataset(self) -> torch.utils.data.Dataset: + shuffle_this_time = self.shuffle + + if ( + self.subsample_strategy + and self.subsample_strategy != SubsampleStrategyEnum.NONE + ): + number_of_samples = int(len(self.data) * self.subsample_portion) + if self.subsample_strategy == SubsampleStrategyEnum.RANDOM: + logger.info( + f"Random subsampling {number_of_samples} samples from {len(self.data)}" + ) + data = ( + deepcopy(self.data) + .shuffle(seed=42 + self.number_of_complete_iterations) + .select(range(0, number_of_samples)) + ) + elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER: + # number_of_samples = int(len(self.data) * self.subsample_portion) + already_selected = ( + number_of_samples * self.number_of_complete_iterations + ) + logger.info( + f"Subsampling {number_of_samples} samples out of {len(self.data)}" + ) + to_select = min(already_selected + number_of_samples, len(self.data)) + logger.info( + f"Portion of data selected: {already_selected} " f"to {to_select}" + ) + data = deepcopy(self.data).select(range(already_selected, to_select)) + + # don't shuffle the data if we are subsampling, and we have still not completed + # one full iteration over the dataset + if self.number_of_complete_iterations > 0: + shuffle_this_time = False + + # reset the number of complete iterations + if to_select >= len(self.data): + # reset the number of complete iterations, + # we have completed one full iteration over the dataset + # the value is -1 because we want to start from 0 at the next iteration + self.number_of_complete_iterations = -1 + else: + raise ValueError( + f"Subsample strategy `{self.subsample_strategy}` is not valid. " + f"Valid strategies are: {SubsampleStrategyEnum.__members__}" + ) + + else: + data = data = self.data + + # do we need to shuffle the data? 
+ if self.shuffle and shuffle_this_time: + logger.info("Shuffling the data") + data = data.shuffle(seed=42 + self.number_of_complete_iterations) + + batch_fn_kwargs = { + "passage_batch_size": self.passage_batch_size, + "question_batch_size": self.question_batch_size, + "hard_negatives_manager": self.hn_manager, + } + batched_data = self.create_batches( + data, + batch_fn=self.batch_fn, + batch_fn_kwargs=batch_fn_kwargs, + prefetch=self.prefetch, + ) + + batched_data = self.collate_batches( + batched_data, self.collate_fn, prefetch=self.prefetch + ) + + # increment the number of complete iterations + self.number_of_complete_iterations += 1 + + if self.prefetch: + return BaseDataset(name=self.name, data=batched_data) + else: + return IterableBaseDataset(name=self.name, data=batched_data) + + @staticmethod + def load_fn( + sample: Dict, + tokenizer: tr.PreTrainedTokenizer, + max_positives: int, + max_negatives: int, + max_hard_negatives: int, + max_passages: int = -1, + max_question_length: int = 256, + max_passage_length: int = 128, + *args, + **kwargs, + ) -> Dict: + # remove duplicates and limit the number of passages + positives = list(set([p["text"] for p in sample["positive_ctxs"]])) + if max_positives != -1: + positives = positives[:max_positives] + negatives = list(set([n["text"] for n in sample["negative_ctxs"]])) + if max_negatives != -1: + negatives = negatives[:max_negatives] + hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]])) + if max_hard_negatives != -1: + hard_negatives = hard_negatives[:max_hard_negatives] + + question = tokenizer( + sample["question"], max_length=max_question_length, truncation=True + ) + + passage = positives + negatives + hard_negatives + if max_passages != -1: + passage = passage[:max_passages] + + passage = tokenizer(passage, max_length=max_passage_length, truncation=True) + + # invert the passage data structure from a dict of lists to a list of dicts + passage = [dict(zip(passage, t)) for t in zip(*passage.values())] + + output = dict( + question=question, + passage=passage, + positives=positives, + positive_pssgs=passage[: len(positives)], + ) + return output + + @staticmethod + def batch_fn( + data: Dataset, + passage_batch_size: int, + question_batch_size: int, + hard_negatives_manager: Optional[HardNegativesManager] = None, + *args, + **kwargs, + ) -> Dict[str, List[Dict[str, Any]]]: + def split_batch( + batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int + ) -> List[ModelInputs]: + """ + Split a batch into multiple batches of size `question_batch_size` while keeping + the same number of passages. 
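+
+            For instance (shapes only, illustrative): a batch holding 5 questions with
+            ``question_batch_size=2`` is split into three sub-batches of 2, 2 and 1
+            questions, each carrying the same ``passages`` list and its own slice of
+            ``sample_idx``, ``positives`` and ``positives_pssgs``.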
+ """ + + def split_fn(x): + return [ + x[i : i + question_batch_size] + for i in range(0, len(x), question_batch_size) + ] + + # split the sample_idx + sample_idx = split_fn(batch["sample_idx"]) + # split the questions + questions = split_fn(batch["questions"]) + # split the positives + positives = split_fn(batch["positives"]) + # split the positives_pssgs + positives_pssgs = split_fn(batch["positives_pssgs"]) + + # collect the new batches + batches = [] + for i in range(len(questions)): + batches.append( + ModelInputs( + dict( + sample_idx=sample_idx[i], + questions=questions[i], + passages=batch["passages"], + positives=positives[i], + positives_pssgs=positives_pssgs[i], + ) + ) + ) + return batches + + batch = [] + passages_in_batch = {} + + for sample in data: + if len(passages_in_batch) >= passage_batch_size: + # create the batch dict + batch_dict = ModelInputs( + dict( + sample_idx=[s["sample_idx"] for s in batch], + questions=[s["question"] for s in batch], + passages=list(passages_in_batch.values()), + positives_pssgs=[s["positive_pssgs"] for s in batch], + positives=[s["positives"] for s in batch], + ) + ) + # split the batch if needed + if len(batch) > question_batch_size: + for splited_batch in split_batch(batch_dict, question_batch_size): + yield splited_batch + else: + yield batch_dict + + # reset batch + batch = [] + passages_in_batch = {} + + batch.append(sample) + # yes it's a bit ugly but it works :) + # count the number of passages in the batch and stop if we reach the limit + # we use a set to avoid counting the same passage twice + # we use a tuple because set doesn't support lists + # we use input_ids as discriminator + passages_in_batch.update( + {tuple(passage["input_ids"]): passage for passage in sample["passage"]} + ) + # check for hard negatives and add with a probability of 0.1 + if hard_negatives_manager is not None: + if sample["sample_idx"] in hard_negatives_manager: + passages_in_batch.update( + { + tuple(passage["input_ids"]): passage + for passage in hard_negatives_manager.get( + sample["sample_idx"] + ) + } + ) + + # left over + if len(batch) > 0: + # create the batch dict + batch_dict = ModelInputs( + dict( + sample_idx=[s["sample_idx"] for s in batch], + questions=[s["question"] for s in batch], + passages=list(passages_in_batch.values()), + positives_pssgs=[s["positive_pssgs"] for s in batch], + positives=[s["positives"] for s in batch], + ) + ) + # split the batch if needed + if len(batch) > question_batch_size: + for splited_batch in split_batch(batch_dict, question_batch_size): + yield splited_batch + else: + yield batch_dict + + def collate_fn(self, batch: Any, *args, **kwargs) -> Any: + # convert questions and passages to a batch + questions = self.convert_to_batch(batch.questions) + passages = self.convert_to_batch(batch.passages) + + # build an index to map the position of the passage in the batch + passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)} + + # now we can create the labels + labels = torch.zeros( + questions["input_ids"].shape[0], passages["input_ids"].shape[0] + ) + # iterate over the questions and set the labels to 1 if the passage is positive + for sample_idx in range(len(questions["input_ids"])): + for pssg in batch["positives_pssgs"][sample_idx]: + # get the index of the positive passage + index = passage_index[tuple(pssg["input_ids"])] + # set the label to 1 + labels[sample_idx, index] = 1 + + model_inputs = ModelInputs( + { + "questions": questions, + "passages": passages, + "labels": labels, + 
"positives": batch["positives"], + "sample_idx": batch["sample_idx"], + } + ) + return model_inputs + + +class AidaInBatchNegativesDataset(InBatchNegativesDataset): + def __init__(self, use_topics: bool = False, *args, **kwargs): + if "load_fn_kwargs" not in kwargs: + kwargs["load_fn_kwargs"] = {} + kwargs["load_fn_kwargs"]["use_topics"] = use_topics + super().__init__(*args, **kwargs) + + @staticmethod + def load_fn( + sample: Dict, + tokenizer: tr.PreTrainedTokenizer, + max_positives: int, + max_negatives: int, + max_hard_negatives: int, + max_passages: int = -1, + max_question_length: int = 256, + max_passage_length: int = 128, + use_topics: bool = False, + *args, + **kwargs, + ) -> Dict: + # remove duplicates and limit the number of passages + positives = list(set([p["text"] for p in sample["positive_ctxs"]])) + if max_positives != -1: + positives = positives[:max_positives] + negatives = list(set([n["text"] for n in sample["negative_ctxs"]])) + if max_negatives != -1: + negatives = negatives[:max_negatives] + hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]])) + if max_hard_negatives != -1: + hard_negatives = hard_negatives[:max_hard_negatives] + + question = sample["question"] + + if "doc_topic" in sample and use_topics: + question = tokenizer( + question, + sample["doc_topic"], + max_length=max_question_length, + truncation=True, + ) + else: + question = tokenizer( + question, max_length=max_question_length, truncation=True + ) + + passage = positives + negatives + hard_negatives + if max_passages != -1: + passage = passage[:max_passages] + + passage = tokenizer(passage, max_length=max_passage_length, truncation=True) + + # invert the passage data structure from a dict of lists to a list of dicts + passage = [dict(zip(passage, t)) for t in zip(*passage.values())] + + output = dict( + question=question, + passage=passage, + positives=positives, + positive_pssgs=passage[: len(positives)], + ) + return output diff --git a/relik/retriever/data/labels.py b/relik/retriever/data/labels.py new file mode 100644 index 0000000000000000000000000000000000000000..be7ee2a1f42aa254ea4aeb87eda3eb3f1a43a616 --- /dev/null +++ b/relik/retriever/data/labels.py @@ -0,0 +1,185 @@ +import json +from pathlib import Path +from typing import Dict, List, Set, Union + + +class Labels: + """ + Class that contains the labels for a model. + + Args: + _labels_to_index (:obj:`Dict[str, Dict[str, int]]`): + A dictionary from :obj:`str` to :obj:`int`. + _index_to_labels (:obj:`Dict[str, Dict[int, str]]`): + A dictionary from :obj:`int` to :obj:`str`. + """ + + def __init__( + self, + _labels_to_index: Dict[str, Dict[str, int]] = None, + _index_to_labels: Dict[str, Dict[int, str]] = None, + **kwargs, + ): + self._labels_to_index = _labels_to_index or {"labels": {}} + self._index_to_labels = _index_to_labels or {"labels": {}} + # if _labels_to_index is not empty and _index_to_labels is not provided + # to the constructor, build the inverted label dictionary + if not _index_to_labels and _labels_to_index: + for namespace in self._labels_to_index: + self._index_to_labels[namespace] = { + v: k for k, v in self._labels_to_index[namespace].items() + } + + def get_index_from_label(self, label: str, namespace: str = "labels") -> int: + """ + Returns the index of a literal label. + + Args: + label (:obj:`str`): + The string representation of the label. + namespace (:obj:`str`, optional, defaults to ``labels``): + The namespace where the label belongs, e.g. ``roles`` for a SRL task. 
+ + Returns: + :obj:`int`: The index of the label. + """ + if namespace not in self._labels_to_index: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + + if label not in self._labels_to_index[namespace]: + raise ValueError(f"Provided label {label} is not in the label dictionary.") + + return self._labels_to_index[namespace][label] + + def get_label_from_index(self, index: int, namespace: str = "labels") -> str: + """ + Returns the string representation of the label index. + + Args: + index (:obj:`int`): + The index of the label. + namespace (:obj:`str`, optional, defaults to ``labels``): + The namespace where the label belongs, e.g. ``roles`` for a SRL task. + + Returns: + :obj:`str`: The string representation of the label. + """ + if namespace not in self._index_to_labels: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + + if index not in self._index_to_labels[namespace]: + raise ValueError( + f"Provided label `{index}` is not in the label dictionary." + ) + + return self._index_to_labels[namespace][index] + + def add_labels( + self, + labels: Union[str, List[str], Set[str], Dict[str, int]], + namespace: str = "labels", + ) -> List[int]: + """ + Adds the labels in input in the label dictionary. + + Args: + labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`): + The labels (single label, list of labels or set of labels) to add to the dictionary. + namespace (:obj:`str`, optional, defaults to ``labels``): + Namespace where the labels belongs. + + Returns: + :obj:`List[int]`: The index of the labels just inserted. + """ + if isinstance(labels, dict): + self._labels_to_index[namespace] = labels + self._index_to_labels[namespace] = { + v: k for k, v in self._labels_to_index[namespace].items() + } + # normalize input + if isinstance(labels, (str, list)): + labels = set(labels) + # if new namespace, add to the dictionaries + if namespace not in self._labels_to_index: + self._labels_to_index[namespace] = {} + self._index_to_labels[namespace] = {} + # returns the new indices + return [self._add_label(label, namespace) for label in labels] + + def _add_label(self, label: str, namespace: str = "labels") -> int: + """ + Adds the label in input in the label dictionary. + + Args: + label (:obj:`str`): + The label to add to the dictionary. + namespace (:obj:`str`, optional, defaults to ``labels``): + Namespace where the label belongs. + + Returns: + :obj:`List[int]`: The index of the label just inserted. + """ + if label not in self._labels_to_index[namespace]: + index = len(self._labels_to_index[namespace]) + self._labels_to_index[namespace][label] = index + self._index_to_labels[namespace][index] = label + return index + else: + return self._labels_to_index[namespace][label] + + def get_labels(self, namespace: str = "labels") -> Dict[str, int]: + """ + Returns all the labels that belongs to the input namespace. + + Args: + namespace (:obj:`str`, optional, defaults to ``labels``): + Labels namespace to retrieve. + + Returns: + :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``. + """ + if namespace not in self._labels_to_index: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + return self._labels_to_index[namespace] + + def get_label_size(self, namespace: str = "labels") -> int: + """ + Returns the number of the labels in the namespace dictionary. 
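+        For example (illustrative), after ``add_labels(["PER", "ORG"])`` on a fresh
+        instance this returns ``2`` for the default ``labels`` namespace.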
+ + Args: + namespace (:obj:`str`, optional, defaults to ``labels``): + Labels namespace to retrieve. + + Returns: + :obj:`int`: Number of labels. + """ + if namespace not in self._labels_to_index: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + return len(self._labels_to_index[namespace]) + + def get_namespaces(self) -> List[str]: + """ + Returns all the namespaces in the label dictionary. + + Returns: + :obj:`List[str]`: The namespaces in the label dictionary. + """ + return list(self._labels_to_index.keys()) + + @classmethod + def from_file(cls, file_path: Union[str, Path, dict], **kwargs): + with open(file_path, "r") as f: + labels_to_index = json.load(f) + return cls(labels_to_index, **kwargs) + + def save(self, file_path: Union[str, Path, dict], indent: int = 2, **kwargs): + with open(file_path, "w") as f: + json.dump(self._labels_to_index, f, indent=indent) diff --git a/relik/retriever/data/utils.py b/relik/retriever/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..928dfb833919ce2e30c9e90fed539c46f4bd3dec --- /dev/null +++ b/relik/retriever/data/utils.py @@ -0,0 +1,176 @@ +import json +import os +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Union + +import numpy as np +import transformers as tr +from tqdm import tqdm + + +class HardNegativesManager: + def __init__( + self, + tokenizer: tr.PreTrainedTokenizer, + data: Union[List[Dict], os.PathLike, Dict[int, List]] = None, + max_length: int = 64, + batch_size: int = 1000, + lazy: bool = False, + ) -> None: + self._db: dict = None + self.tokenizer = tokenizer + + if data is None: + self._db = {} + else: + if isinstance(data, Dict): + self._db = data + elif isinstance(data, os.PathLike): + with open(data) as f: + self._db = json.load(f) + else: + raise ValueError( + f"Data type {type(data)} not supported, only Dict and os.PathLike are supported." 
+ ) + # add the tokenizer to the class for future use + self.tokenizer = tokenizer + + # invert the db to have a passage -> sample_idx mapping + self._passage_db = defaultdict(set) + for sample_idx, passages in self._db.items(): + for passage in passages: + self._passage_db[passage].add(sample_idx) + + self._passage_hard_negatives = {} + if not lazy: + # create a dictionary of passage -> hard_negative mapping + batch_size = min(batch_size, len(self._passage_db)) + unique_passages = list(self._passage_db.keys()) + for i in tqdm( + range(0, len(unique_passages), batch_size), + desc="Tokenizing Hard Negatives", + ): + batch = unique_passages[i : i + batch_size] + tokenized_passages = self.tokenizer( + batch, + max_length=max_length, + truncation=True, + ) + for i, passage in enumerate(batch): + self._passage_hard_negatives[passage] = { + k: tokenized_passages[k][i] for k in tokenized_passages.keys() + } + + def __len__(self) -> int: + return len(self._db) + + def __getitem__(self, idx: int) -> Dict: + return self._db[idx] + + def __iter__(self): + for sample in self._db: + yield sample + + def __contains__(self, idx: int) -> bool: + return idx in self._db + + def get(self, idx: int) -> List[str]: + """Get the hard negatives for a given sample index.""" + if idx not in self._db: + raise ValueError(f"Sample index {idx} not in the database.") + + passages = self._db[idx] + + output = [] + for passage in passages: + if passage not in self._passage_hard_negatives: + self._passage_hard_negatives[passage] = self._tokenize(passage) + output.append(self._passage_hard_negatives[passage]) + + return output + + def _tokenize(self, passage: str) -> Dict: + return self.tokenizer(passage, max_length=self.max_length, truncation=True) + + +class NegativeSampler: + def __init__( + self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None + ): + if not isinstance(probabilities, np.ndarray): + probabilities = np.array(probabilities) + + if probabilities is None: + # probabilities should sum to 1 + probabilities = np.random.random(num_elements) + probabilities /= np.sum(probabilities) + self.probabilities = probabilities + + def __call__( + self, + sample_size: int, + num_samples: int = 1, + probabilities: np.array = None, + exclude: List[int] = None, + ) -> np.array: + """ + Fast sampling of `sample_size` elements from `num_elements` elements. + The sampling is done by randomly shifting the probabilities and then + finding the smallest of the negative numbers. This is much faster than + sampling from a multinomial distribution. + + Args: + sample_size (`int`): + number of elements to sample + num_samples (`int`, optional): + number of samples to draw. Defaults to 1. + probabilities (`np.array`, optional): + probabilities of each element. Defaults to None. + exclude (`List[int]`, optional): + indices of elements to exclude. Defaults to None. + + Returns: + `np.array`: array of sampled indices + """ + if probabilities is None: + probabilities = self.probabilities + + if exclude is not None: + probabilities[exclude] = 0 + # re-normalize? 
+ # probabilities /= np.sum(probabilities) + + # replicate probabilities as many times as `num_samples` + replicated_probabilities = np.tile(probabilities, (num_samples, 1)) + # get random shifting numbers & scale them correctly + random_shifts = np.random.random(replicated_probabilities.shape) + random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis] + # shift by numbers & find largest (by finding the smallest of the negative) + shifted_probabilities = random_shifts - replicated_probabilities + sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[ + :, :sample_size + ] + return sampled_indices + + +def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]: + """ + Generate batches from samples. + + Args: + samples (`Iterable[Any]`): Iterable of samples. + batch_size (`int`): Batch size. + + Returns: + `Iterable[Any]`: Iterable of batches. + """ + batch = [] + for sample in samples: + batch.append(sample) + if len(batch) == batch_size: + yield batch + batch = [] + + # leftover batch + if len(batch) > 0: + yield batch diff --git a/relik/retriever/indexers/__init__.py b/relik/retriever/indexers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/indexers/__pycache__/__init__.cpython-310.pyc b/relik/retriever/indexers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..636e12463df2a956bf362a910d640f800d3d6abc Binary files /dev/null and b/relik/retriever/indexers/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/retriever/indexers/__pycache__/base.cpython-310.pyc b/relik/retriever/indexers/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfabf2c2b23197f4736a0337f2c6adec15b7c026 Binary files /dev/null and b/relik/retriever/indexers/__pycache__/base.cpython-310.pyc differ diff --git a/relik/retriever/indexers/__pycache__/document.cpython-310.pyc b/relik/retriever/indexers/__pycache__/document.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3056ace43e427731e7d67a120c3c5b672c84ad8 Binary files /dev/null and b/relik/retriever/indexers/__pycache__/document.cpython-310.pyc differ diff --git a/relik/retriever/indexers/__pycache__/inmemory.cpython-310.pyc b/relik/retriever/indexers/__pycache__/inmemory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09cfe5fd507e622e1a905761b89b16fb1bbfbc87 Binary files /dev/null and b/relik/retriever/indexers/__pycache__/inmemory.cpython-310.pyc differ diff --git a/relik/retriever/indexers/base.py b/relik/retriever/indexers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..804760fb4a370265ffee7d47a2cc391dc8d36e84 --- /dev/null +++ b/relik/retriever/indexers/base.py @@ -0,0 +1,478 @@ +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import hydra +import numpy +import torch +from omegaconf import OmegaConf +from pprintpp import pformat + +from relik.common.log import get_logger +from relik.common.upload import upload +from relik.common.utils import ( + from_cache, + is_str_a_path, + relative_to_absolute_path, + to_config, +) +from relik.retriever.indexers.document import Document, DocumentStore + +logger = get_logger(__name__) + + +@dataclass +class IndexerOutput: + indices: Union[torch.Tensor, numpy.ndarray] + 
distances: Union[torch.Tensor, numpy.ndarray] + + +class BaseDocumentIndex: + """ + Base class for document indexes. + + Args: + documents (:obj:`str`, :obj:`List[str]`, :obj:`os.PathLike`, :obj:`List[os.PathLike]`, :obj:`DocumentStore`, `optional`): + The documents to index. If `None`, an empty document store will be created. Defaults to `None`. + embeddings (:obj:`torch.Tensor`, `optional`): + The embeddings of the documents. If `None`, the documents will not be indexed. Defaults to `None`. + name_or_path (:obj:`str`, :obj:`os.PathLike`, `optional`): + The name or directory of the retriever. + """ + + CONFIG_NAME = "config.yaml" + DOCUMENTS_FILE_NAME = "documents.jsonl" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + + def __init__( + self, + documents: str + | List[str] + | os.PathLike + | List[os.PathLike] + | DocumentStore + | None = None, + embeddings: torch.Tensor | None = None, + metadata_fields: List[str] | None = None, + separator: str | None = None, + name_or_path: str | os.PathLike | None = None, + device: str = "cpu", + ) -> None: + if metadata_fields is None: + metadata_fields = [] + + self.metadata_fields = metadata_fields + self.separator = separator + + self.document_path: List[str | os.PathLike] = [] + + if documents is not None: + if isinstance(documents, DocumentStore): + self.documents = documents + else: + documents_are_paths = False + + # normalize the documents to list if not already + if not isinstance(documents, list): + documents = [documents] + + # now check if the documents are a list of paths (either str or os.PathLike) + if isinstance(documents[0], str) or isinstance( + documents[0], os.PathLike + ): + # check if the str is a path + documents_are_paths = is_str_a_path(documents[0]) + + # if the documents are a list of paths, then we load them + if documents_are_paths: + logger.info("Loading documents from paths") + _documents = [] + for doc in documents: + with open(relative_to_absolute_path(doc)) as f: + self.document_path.append(doc) + _documents += [ + Document.from_dict(json.loads(line)) + for line in f.readlines() + ] + # remove duplicates + documents = _documents + + self.documents = DocumentStore(documents) + else: + self.documents = DocumentStore() + + self.embeddings = embeddings + self.name_or_path = name_or_path + + # store the device in case embeddings are not provided + self.device_in_init = device + + def __iter__(self): + # make this class iterable + for i in range(len(self)): + yield self[i] + + def __len__(self): + return len(self.documents) + + def __getitem__(self, index): + return self.get_passage_from_index(index) + + def to( + self, device_or_precision: str | torch.device | torch.dtype + ) -> "BaseDocumentIndex": + """ + Move the retriever to the specified device or precision. + + Args: + device_or_precision (`str` | `torch.device` | `torch.dtype`): + The device or precision to move the retriever to. + + Returns: + `BaseDocumentIndex`: The retriever. 
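+
+        Example (illustrative only)::
+
+            index.to(torch.device("cuda"))   # move the embeddings to the GPU
+            index.to(torch.float16)          # cast to half precision (no-op while on CPU)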
+ """ + if self.embeddings is not None: + if isinstance(device_or_precision, torch.dtype) and self.device != "cpu": + # if the device is a dtype, then we need to move the embeddings to cpu + # first before converting to the dtype to avoid OOM + previous_device = self.embeddings.device + self.embeddings = self.embeddings.cpu() + self.embeddings = self.embeddings.to(device_or_precision) + self.embeddings = self.embeddings.to(previous_device) + else: + if isinstance(device_or_precision, torch.device): + self.embeddings = self.embeddings.to(device_or_precision) + else: + if device_or_precision != self.embeddings.dtype and self.device != "cpu": + self.embeddings = self.embeddings.to(device_or_precision) + # self.embeddings = self.embeddings.to(device_or_precision) + return self + + @property + def device(self): + return ( + self.embeddings.device + if self.embeddings is not None + else self.device_in_init + ) + + @property + def config(self) -> Dict[str, Any]: + """ + The configuration of the document index. + + Returns: + `Dict[str, Any]`: The configuration of the retriever. + """ + + config = { + "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}", + "metadata_fields": self.metadata_fields, + "separator": self.separator, + "name_or_path": self.name_or_path, + } + if len(self.document_path) > 0: + config["documents"] = self.document_path + return config + + def index( + self, + retriever, + *args, + **kwargs, + ) -> "BaseDocumentIndex": + raise NotImplementedError + + def search(self, query: Any, k: int = 1, *args, **kwargs) -> List: + raise NotImplementedError + + def get_document_from_passage(self, passage: str) -> Document | None: + """ + Get the document label from the passage. + + Args: + passage (`str`): + The document to get the label for. + + Returns: + `str`: The document label. + """ + # get the text from the document + if self.separator: + text = passage.split(self.separator)[0] + else: + text = passage + return self.documents.get_document_from_text(text) + + def get_index_from_passage(self, passage: str) -> int: + """ + Get the index of the passage. + + Args: + passage (`str`): + The document to get the index for. + + Returns: + `int`: The index of the document. + """ + # get the text from the document + doc = self.get_document_from_passage(passage) + if doc is None: + raise ValueError(f"Document `{passage}` not found.") + return doc.id + + def get_document_from_index(self, index: int) -> Document | None: + """ + Get the document from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `str`: The document. + """ + return self.documents.get_document_from_id(index) + + def get_passage_from_index(self, index: int) -> str: + """ + Get the document from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `str`: The document. + """ + document = self.get_document_from_index(index) + # build the passage using the metadata fields + passage = document.text + for field in self.metadata_fields: + passage += f"{self.separator}{document.metadata[field]}" + return passage + + def get_passage_from_document(self, document: Document) -> str: + passage = document.text + for field in self.metadata_fields: + passage += f"{self.separator}{document.metadata[field]}" + return passage + + def get_embeddings_from_index(self, index: int) -> torch.Tensor: + """ + Get the document vector from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `torch.Tensor`: The document vector. 
+ """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + if index >= self.embeddings.shape[0]: + raise ValueError( + f"The index {index} is out of bounds. The maximum index is {len(self.embeddings) - 1}." + ) + return self.embeddings[index] + + def get_embeddings_from_passage(self, document: str) -> torch.Tensor: + """ + Get the document vector from the document label. + + Args: + document (`str`): + The document to get the vector for. + + Returns: + `torch.Tensor`: The document vector. + """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + return self.get_embeddings_from_index(self.get_index_from_passage(document)) + + def get_embeddings_from_document(self, document: str) -> torch.Tensor: + """ + Get the document vector from the document label. + + Args: + document (`str`): + The document to get the vector for. + + Returns: + `torch.Tensor`: The document vector. + """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + return self.get_embeddings_from_index(self.get_index_from_document(document)) + + def get_passages(self, documents: DocumentStore | None = None) -> List[str]: + """ + Get the passages from the document store. + + Returns: + `List[str]`: The passages. + """ + documents = documents or self.documents + # construct the passages from the documents + # return [self.get_passage_from_index(i) for i in range(len(documents))] + return [self.get_passage_from_document(doc) for doc in documents] + + def save_pretrained( + self, + output_dir: Union[str, os.PathLike], + config: Optional[Dict[str, Any]] = None, + config_file_name: str | None = None, + document_file_name: str | None = None, + embedding_file_name: str | None = None, + push_to_hub: bool = False, + model_id: str | None = None, + **kwargs, + ): + """ + Save the retriever to a directory. + + Args: + output_dir (`str`): + The directory to save the retriever to. + config (`Optional[Dict[str, Any]]`, `optional`): + The configuration to save. If `None`, the current configuration of the retriever will be + saved. Defaults to `None`. + config_file_name (`str | None`, `optional`): + The name of the configuration file. Defaults to `config.yaml`. + document_file_name (`str | None`, `optional`): + The name of the document file. Defaults to `documents.json`. + embedding_file_name (`str | None`, `optional`): + The name of the embedding file. Defaults to `embeddings.pt`. + push_to_hub (`bool`, `optional`): + Whether to push the saved retriever to the hub. Defaults to `False`. + model_id (`str | None`, `optional`): + The id of the model to push to the hub. If `None`, the name of the output + directory will be used. Defaults to `None`. + **kwargs: + Additional keyword arguments to pass to `upload`. 
+ """ + if config is None: + # create a default config + config = self.config + + config_file_name = config_file_name or self.CONFIG_NAME + document_file_name = document_file_name or self.DOCUMENTS_FILE_NAME + embedding_file_name = embedding_file_name or self.EMBEDDINGS_FILE_NAME + + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving retriever to {output_dir}") + logger.info(f"Saving config to {output_dir / config_file_name}") + # pretty print the config + OmegaConf.save(config, output_dir / config_file_name) + logger.info(pformat(config)) + + # save the current state of the retriever + embedding_path = output_dir / embedding_file_name + logger.info(f"Saving retriever state to {output_dir / embedding_path}") + torch.save(self.embeddings, embedding_path) + + # save the passage index + documents_path = output_dir / document_file_name + logger.info(f"Saving passage index to {documents_path}") + self.documents.save(documents_path) + + logger.info("Saving document index to disk done.") + + if push_to_hub: + # push to hub + logger.info("Pushing to hub") + model_id = model_id or output_dir.name + upload(output_dir, model_id, **kwargs) + + @classmethod + def from_pretrained( + cls, + name_or_dir: Union[str, os.PathLike], + device: str = "cpu", + precision: str | None = None, + config_file_name: str | None = None, + document_file_name: str | None = None, + embedding_file_name: str | None = None, + *args, + **kwargs, + ) -> "BaseDocumentIndex": + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + + config_file_name = config_file_name or cls.CONFIG_NAME + document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME + embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME + + model_dir = from_cache( + name_or_dir, + filenames=[config_file_name, document_file_name, embedding_file_name], + cache_dir=cache_dir, + force_download=force_download, + ) + + config_path = model_dir / config_file_name + if not config_path.exists(): + raise FileNotFoundError( + f"Model configuration file not found at {config_path}." 
+ ) + + config = OmegaConf.load(config_path) + # override the config with the kwargs + # if config_kwargs is not None: + config = OmegaConf.merge(config, OmegaConf.create(kwargs)) + logger.info("Loading Index from config:") + logger.info(pformat(OmegaConf.to_container(config))) + + # load the documents + documents_path = model_dir / document_file_name + + if not documents_path.exists(): + raise ValueError(f"Document file `{documents_path}` does not exist.") + logger.info(f"Loading documents from {documents_path}") + documents = DocumentStore.from_file(documents_path) + # TODO: probably is better to do the opposite and iterate over the config + # check for each possible attribute ind DocumentStore + for attr in dir(documents): + if attr.startswith("__"): + continue + if attr not in config: + continue + # set the attribute + setattr(documents, attr, config[attr]) + + # load the passage embeddings + embedding_path = model_dir / embedding_file_name + # run some checks + embeddings = None + if embedding_path.exists(): + logger.info(f"Loading embeddings from {embedding_path}") + embeddings = torch.load(embedding_path, map_location="cpu") + else: + logger.warning(f"Embedding file `{embedding_path}` does not exist.") + + document_index = hydra.utils.instantiate( + config, + documents=documents, + embeddings=embeddings, + device=device, + precision=precision, + name_or_dir=name_or_dir, + _convert_="partial", + *args, + **kwargs, + ) + + return document_index diff --git a/relik/retriever/indexers/document.py b/relik/retriever/indexers/document.py new file mode 100644 index 0000000000000000000000000000000000000000..476ad8c1868456785bf7dd451389b3dab82c4b66 --- /dev/null +++ b/relik/retriever/indexers/document.py @@ -0,0 +1,337 @@ +import csv +import json +import pickle +from pathlib import Path +from typing import Dict, List, Union + +from relik.common.log import get_logger + +logger = get_logger(__name__) + + +class Document: + def __init__( + self, + text: str, + id: int | None = None, + metadata: Dict | None = None, + **kwargs, + ): + self.text = text + # if id is not provided, we use the hash of the text + self.id = id if id is not None else hash(text) + # if metadata is not provided, we use an empty dictionary + self.metadata = metadata or {} + + def __str__(self): + return f"{self.id}:{self.text}" + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if isinstance(other, Document): + return self.id == other.id + elif isinstance(other, int): + return self.id == other + elif isinstance(other, str): + return self.text == other + else: + raise ValueError( + f"Document must be compared with a Document, an int or a str, got `{type(other)}`" + ) + + def to_dict(self): + return {"text": self.text, "id": self.id, "metadata": self.metadata} + + @classmethod + def from_dict(cls, d: Dict): + return cls(**d) + + @classmethod + def from_file(cls, file_path: Union[str, Path], **kwargs): + with open(file_path, "r") as f: + d = json.load(f) + return cls.from_dict(d) + + def save(self, file_path: Union[str, Path], **kwargs): + with open(file_path, "w") as f: + json.dump(self.to_dict(), f, indent=2) + + +class DocumentStore: + """ + A document store is a collection of documents. + + Args: + documents (:obj:`List[Document]`): + The documents to store. 
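+
+    Example (illustrative only)::
+
+        store = DocumentStore([Document("first passage", id=0)])
+        store.add_document("second passage")   # id is assigned automatically
+        assert "first passage" in store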
+ """ + + def __init__(self, documents: List[Document] = None) -> None: + if documents is None: + documents = [] + # if self.ingore_case: + # documents = [doc.lower() for doc in documents] + self._documents = documents + # build an index for the documents + self._documents_index = {doc.id: doc for doc in self._documents} + # build a reverse index for the documents + self._documents_reverse_index = {doc.text: doc for doc in self._documents} + + def __len__(self): + return len(self._documents) + + def __getitem__(self, index): + return self._documents[index] + + def __iter__(self): + return iter(self._documents) + + def __contains__(self, item): + if isinstance(item, int): + return item in self._documents_index + elif isinstance(item, str): + return item in self._documents_reverse_index + elif isinstance(item, Document): + return item.id in self._documents_index + # return item in self._documents_index + + def __str__(self): + return f"DocumentStore with {len(self)} documents" + + def __repr__(self): + return self.__str__() + + def get_document_from_id(self, id: int) -> Document | None: + """ + Retrieve a document by its ID. + + Args: + id (`int`): + The ID of the document to retrieve. + + Returns: + Optional[Document]: The document with the given ID, or None if it does not exist. + """ + if id not in self._documents_index: + logger.warning(f"Document with id `{id}` does not exist, skipping") + return self._documents_index.get(id, None) + + def get_document_from_text(self, text: str) -> Document | None: + """ + Retrieve the document by its text. + + Args: + text (`str`): + The text of the document to retrieve. + + Returns: + Optional[Document]: The document with the given text, or None if it does not exist. + """ + if text not in self._documents_reverse_index: + logger.warning(f"Document with text `{text}` does not exist, skipping") + return self._documents_reverse_index.get(text, None) + + def add_documents(self, documents: List[Document] | List[Dict]) -> List[Document]: + """ + Add a list of documents to the document store. + + Args: + documents (`List[Document]`): + The documents to add. + + Returns: + List[Document]: The documents just added. + """ + return [ + self.add_document(doc) + if isinstance(doc, Document) + else self.add_document(Document.from_dict(doc)) + for doc in documents + ] + + def add_document( + self, + text: str | Document, + id: int | None = None, + metadata: Dict | None = None, + ) -> Document: + """ + Add a document to the document store. + + Args: + text (`str`): + The text of the document to add. + id (`int`, optional, defaults to None): + The ID of the document to add. + metadata (`Dict`, optional, defaults to None): + The metadata of the document to add. + + Returns: + Document: The document just added. 
+ """ + if isinstance(text, str): + if id is None: + # get the len of the documents and add 1 + id = len(self._documents) # + 1 + text = Document(text, id, metadata) + + if text in self: + logger.warning(f"Document {text} already exists, skipping") + return self._documents_index[text.id] + + self._documents.append(text) + self._documents_index[text.id] = text + self._documents_reverse_index[text.text] = text + return text + # if id in self._documents_index: + # logger.warning(f"Document with id `{id}` already exists, skipping") + # return self._documents_index[id] + # if text_or_document in self._documents_reverse_index: + # logger.warning(f"Document with text `{text_or_document}` already exists, skipping") + # return self._documents_reverse_index[text_or_document] + # self._documents.append(Document(text_or_document, id, metadata)) + # self._documents_index[id] = self._documents[-1] + # self._documents_reverse_index[text_or_document] = self._documents[-1] + # return self._documents_index[id] + + def delete_document(self, document: int | str | Document) -> bool: + """ + Delete a document from the document store. + + Args: + document (`int`, `str` or `Document`): + The document to delete. + + Returns: + bool: True if the document has been deleted, False otherwise. + """ + if isinstance(document, int): + return self.delete_by_id(document) + elif isinstance(document, str): + return self.delete_by_text(document) + elif isinstance(document, Document): + return self.delete_by_document(document) + else: + raise ValueError( + f"Document must be an int, a str or a Document, got `{type(document)}`" + ) + + def delete_by_id(self, id: int) -> bool: + """ + Delete a document by its ID. + + Args: + id (`int`): + The ID of the document to delete. + + Returns: + bool: True if the document has been deleted, False otherwise. + """ + if id not in self._documents_index: + logger.warning(f"Document with id `{id}` does not exist, skipping") + return False + del self._documents_reverse_index[self._documents_index[id]] + del self._documents_index[id] + return True + + def delete_by_text(self, text: str) -> bool: + """ + Delete a document by its text. + + Args: + text (`str`): + The text of the document to delete. + + Returns: + bool: True if the document has been deleted, False otherwise. + """ + if text not in self._documents_reverse_index: + logger.warning(f"Document with text `{text}` does not exist, skipping") + return False + del self._documents_reverse_index[text] + del self._documents_index[self._documents_index[text]] + return True + + def delete_by_document(self, document: Document) -> bool: + """ + Delete a document by its text. + + Args: + document (:obj:`Document`): + The document to delete. + + Returns: + bool: True if the document has been deleted, False otherwise. 
+ """ + if document.id not in self._documents_index: + logger.warning(f"Document {document} does not exist, skipping") + return False + del self._documents[self._documents.index(document)] + del self._documents_index[document.id] + del self._documents_reverse_index[self._documents_index[document.id]] + + def to_dict(self): + return [doc.to_dict() for doc in self._documents] + + @classmethod + def from_dict(cls, d): + return cls([Document.from_dict(doc) for doc in d]) + + @classmethod + def from_file(cls, file_path: Union[str, Path], **kwargs): + with open(file_path, "r") as f: + # load a json lines file + d = [Document.from_dict(json.loads(line)) for line in f] + return cls(d) + + @classmethod + def from_pickle(cls, file_path: Union[str, Path], **kwargs): + with open(file_path, "rb") as handle: + d = pickle.load(handle) + return cls(d) + + @classmethod + def from_tsv( + cls, + file_path: Union[str, Path], + ingore_case: bool = False, + delimiter: str = "\t", + **kwargs, + ): + d = [] + # load a tsv/csv file and take the header into account + # the header must be `id\ttext\t[list of metadata keys]` + with open(file_path, "r", encoding="utf8") as f: + csv_reader = csv.reader(f, delimiter=delimiter, **kwargs) + header = next(csv_reader) + id, text, *metadata_keys = header + for i, row in enumerate(csv_reader): + # check if id can be casted to int + # if not, we add it to the metadata and use `i` as id + try: + s_id = int(row[header.index(id)]) + row_metadata_keys = metadata_keys + except ValueError: + row_metadata_keys = [id] + metadata_keys + s_id = i + + d.append( + Document( + text=row[header.index(text)].strip().lower() + if ingore_case + else row[header.index(text)].strip(), + id=s_id, # row[header.index(id)], + metadata={ + key: row[header.index(key)] for key in row_metadata_keys + }, + ) + ) + return cls(d) + + def save(self, file_path: Union[str, Path], **kwargs): + with open(file_path, "w") as f: + for doc in self._documents: + # save as json lines + f.write(json.dumps(doc.to_dict()) + "\n") diff --git a/relik/retriever/indexers/faiss.py b/relik/retriever/indexers/faiss.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8495c307f41a66fe69aaabf80d93ca0c704eba --- /dev/null +++ b/relik/retriever/indexers/faiss.py @@ -0,0 +1,457 @@ +import contextlib +import logging +import os +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy +import psutil +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.common.utils import is_package_available +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.indexers.document import Document, DocumentStore +from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample +from relik.retriever.pytorch_modules.model import GoldenRetriever + +if is_package_available("faiss"): + import faiss + import faiss.contrib.torch_utils + +logger = get_logger(__name__, level=logging.INFO) + + +@dataclass +class FaissOutput: + indices: Union[torch.Tensor, numpy.ndarray] + distances: Union[torch.Tensor, numpy.ndarray] + + +class FaissDocumentIndex(BaseDocumentIndex): + DOCUMENTS_FILE_NAME = "documents.json" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + INDEX_FILE_NAME = "index.faiss" + + def __init__( + self, + documents: str + | List[str] + | os.PathLike + | 
List[os.PathLike] + | DocumentStore + | None = None, + embeddings: torch.Tensor | numpy.ndarray | None = None, + metadata_fields: List[str] | None = None, + separator: str = "", + name_or_path: str | os.PathLike | None = None, + device: str = "cpu", + index=None, + index_type: str = "Flat", + nprobe: int = 1, + metric: int = faiss.METRIC_INNER_PRODUCT, + normalize: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__( + documents, embeddings, metadata_fields, separator, name_or_path, device + ) + + if embeddings is not None and documents is not None: + logger.info("Both documents and embeddings are provided.") + if len(documents) != embeddings.shape[0]: + raise ValueError( + "The number of documents and embeddings must be the same." + ) + + faiss.omp_set_num_threads(psutil.cpu_count(logical=False)) + + # params + self.index_type = index_type + self.metric = metric + self.normalize = normalize + + if index is not None: + self.embeddings = index + if self.device == "cuda": + # use a single GPU + faiss_resource = faiss.StandardGpuResources() + self.embeddings = faiss.index_cpu_to_gpu( + faiss_resource, 0, self.embeddings + ) + else: + if embeddings is not None: + # build the faiss index + logger.info("Building the index from the embeddings.") + self.embeddings = self._build_faiss_index( + embeddings=embeddings, + index_type=index_type, + nprobe=nprobe, + normalize=normalize, + metric=metric, + ) + + def to( + self, device_or_precision: str | torch.device | torch.dtype + ) -> "BaseDocumentIndex": + """ + Move the retriever to the specified device or precision. + + Args: + device_or_precision (`str` | `torch.device` | `torch.dtype`): + The device or precision to move the retriever to. + + Returns: + `BaseDocumentIndex`: The retriever. + """ + if isinstance(device_or_precision, torch.dtype): + # raise ValueError( + # "FaissDocumentIndex does not support precision conversion." + # ) + logger.warning( + "FaissDocumentIndex does not support precision conversion. Ignoring." + ) + if device_or_precision == "cuda" and self.device == "cpu": + # use a single GPU + faiss_resource = faiss.StandardGpuResources() + self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings) + elif device_or_precision == "cpu" and self.device == "cuda": + # move faiss index to CPU + self.embeddings = faiss.index_gpu_to_cpu(self.embeddings) + else: + logger.warning( + f"Provided device `{device_or_precision}` is the same as the current device `{self.device}`." 
+ ) + return self + + @property + def device(self): + # check if faiss index is on GPU + if faiss.get_num_gpus() > 0: + return "cuda" + return "cpu" + + def _build_faiss_index( + self, + embeddings: Optional[Union[torch.Tensor, numpy.ndarray]], + index_type: str, + nprobe: int, + normalize: bool, + metric: int, + ): + # build the faiss index + self.normalize = ( + normalize + and metric == faiss.METRIC_INNER_PRODUCT + and not isinstance(embeddings, torch.Tensor) + ) + if self.normalize: + index_type = f"L2norm,{index_type}" + faiss_vector_size = embeddings.shape[1] + # if self.device == "cpu": + # index_type = index_type.replace("x,", "x_HNSW32,") + # nlist = math.ceil(math.sqrt(faiss_vector_size)) * 4 + # # nlist = 8 + # index_type = index_type.replace( + # "x", str(nlist) + # ) + # print("Current nlist:", nlist) + self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric) + + # convert to GPU + if self.device == "cuda": + # use a single GPU + faiss_resource = faiss.StandardGpuResources() + self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings) + else: + # move to CPU if embeddings is a torch.Tensor + embeddings = ( + embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings + ) + + # convert to float32 if embeddings is a torch.Tensor and is float16 + if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16: + embeddings = embeddings.float() + + logger.info("Training the index.") + self.embeddings.train(embeddings) + + logger.info("Adding the embeddings to the index.") + self.embeddings.add(embeddings) + + self.embeddings.nprobe = nprobe + + # save parameters for saving/loading + self.index_type = index_type + self.metric = metric + + # clear the embeddings to free up memory + embeddings = None + + return self.embeddings + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + retriever: GoldenRetriever, + documents: Optional[List[Document]] = None, + batch_size: int = 32, + num_workers: int = 4, + max_length: Optional[int] = None, + collate_fn: Optional[Callable] = None, + encoder_precision: Optional[Union[str, int]] = None, + compute_on_cpu: bool = False, + force_reindex: bool = False, + *args, + **kwargs, + ) -> "FaissDocumentIndex": + """ + Index the documents using the encoder. + + Args: + retriever (:obj:`torch.nn.Module`): + The encoder to be used for indexing. + documents (:obj:`List[Document]`, `optional`, defaults to None): + The documents to be indexed. + batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to be used for indexing. + num_workers (:obj:`int`, `optional`, defaults to 4): + The number of workers to be used for indexing. + max_length (:obj:`int`, `optional`, defaults to None): + The maximum length of the input to the encoder. + collate_fn (:obj:`Callable`, `optional`, defaults to None): + The collate function to be used for batching. + encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): + The precision to be used for the encoder. + compute_on_cpu (:obj:`bool`, `optional`, defaults to False): + Whether to compute the embeddings on CPU. + force_reindex (:obj:`bool`, `optional`, defaults to False): + Whether to force reindexing. + + Returns: + :obj:`InMemoryIndexer`: The indexer object. + """ + + if self.embeddings is not None and not force_reindex: + logger.log( + "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." 
+ ) + if documents is None: + return self + + # release the memory + if collate_fn is None: + tokenizer = retriever.passage_tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + + if force_reindex: + if documents is not None: + self.documents.add_document(documents) + data = [k for k in self.get_passages()] + + else: + if documents is not None: + data = [k for k in self.get_passages(DocumentStore(documents))] + else: + return self + + dataloader = DataLoader( + BaseDataset(name="passage", data=data), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, + ) + + encoder = retriever.passage_encoder + + # Create empty lists to store the passage embeddings and passage index + passage_embeddings: List[torch.Tensor] = [] + + encoder_device = "cpu" if compute_on_cpu else self.device + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(encoder_device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[encoder_precision], + ) + ) + ) + with autocast_pssg_mngr: + # Iterate through each batch in the dataloader + for batch in tqdm(dataloader, desc="Indexing"): + # Move the batch to the device + batch: ModelInputs = batch.to(encoder_device) + # Compute the passage embeddings + passage_outs = encoder(**batch) + # Append the passage embeddings to the list + if self.device == "cpu": + passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) + else: + passage_embeddings.extend([c for c in passage_outs]) + + # move the passage embeddings to the CPU if not already done + passage_embeddings = [c.detach().cpu() for c in passage_embeddings] + # stack it + passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) + # convert to float32 for faiss + passage_embeddings.to(PRECISION_MAP["float32"]) + + # index the embeddings + self.embeddings = self._build_faiss_index( + embeddings=passage_embeddings, + index_type=self.index_type, + normalize=self.normalize, + metric=self.metric, + ) + # free up memory from the unused variable + del passage_embeddings + + return self + + @torch.no_grad() + @torch.inference_mode() + def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: + k = min(k, self.embeddings.ntotal) + + if self.normalize: + faiss.normalize_L2(query) + if isinstance(query, torch.Tensor) and self.device == "cpu": + query = query.detach().cpu() + # Retrieve the indices of the top k passage embeddings + retriever_out = self.embeddings.search(query, k) + + # get int values (second element of the tuple) + batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist() + # get float values (first element of the tuple) + batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist() + # Retrieve the passages corresponding to the indices + batch_docs = [ + [self.documents.get_document_from_id(i) for i in indices if i != -1] + for indices in batch_top_k + ] + # build the output object + # build the output object + batch_retrieved_samples = [ + [ + RetrievedSample(document=doc, score=score) + for doc, score in zip(docs, scores) + ] + for docs, 
scores in zip(batch_docs, batch_scores) + ] + return batch_retrieved_samples + + # def save(self, saving_dir: Union[str, os.PathLike]): + # """ + # Save the indexer to the disk. + + # Args: + # saving_dir (:obj:`Union[str, os.PathLike]`): + # The directory where the indexer will be saved. + # """ + # saving_dir = Path(saving_dir) + # # save the passage embeddings + # index_path = saving_dir / self.INDEX_FILE_NAME + # logger.info(f"Saving passage embeddings to {index_path}") + # faiss.write_index(self.embeddings, str(index_path)) + # # save the passage index + # documents_path = saving_dir / self.DOCUMENTS_FILE_NAME + # logger.info(f"Saving passage index to {documents_path}") + # self.documents.save(documents_path) + + # @classmethod + # def load( + # cls, + # loading_dir: Union[str, os.PathLike], + # device: str = "cpu", + # document_file_name: Optional[str] = None, + # embedding_file_name: Optional[str] = None, + # index_file_name: Optional[str] = None, + # **kwargs, + # ) -> "FaissDocumentIndex": + # loading_dir = Path(loading_dir) + + # document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME + # embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME + # index_file_name = index_file_name or cls.INDEX_FILE_NAME + + # # load the documents + # documents_path = loading_dir / document_file_name + + # if not documents_path.exists(): + # raise ValueError(f"Document file `{documents_path}` does not exist.") + # logger.info(f"Loading documents from {documents_path}") + # documents = Labels.from_file(documents_path) + + # index = None + # embeddings = None + # # try to load the index directly + # index_path = loading_dir / index_file_name + # if not index_path.exists(): + # # try to load the embeddings + # embedding_path = loading_dir / embedding_file_name + # # run some checks + # if embedding_path.exists(): + # logger.info(f"Loading embeddings from {embedding_path}") + # embeddings = torch.load(embedding_path, map_location="cpu") + # logger.warning( + # f"Index file `{index_path}` and embedding file `{embedding_path}` do not exist." + # ) + # else: + # logger.info(f"Loading index from {index_path}") + # index = faiss.read_index(str(embedding_path)) + + # return cls( + # documents=documents, + # embeddings=embeddings, + # index=index, + # device=device, + # **kwargs, + # ) + + def get_embeddings_from_index( + self, index: int + ) -> Union[torch.Tensor, numpy.ndarray]: + """ + Get the document vector from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `torch.Tensor`: The document vector. + """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + if index >= self.embeddings.ntotal: + raise ValueError( + f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}." 
+ ) + return self.embeddings.reconstruct(index) diff --git a/relik/retriever/indexers/inmemory.py b/relik/retriever/indexers/inmemory.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5951fd6252248bd0563d84d96958d871e8ac12 --- /dev/null +++ b/relik/retriever/indexers/inmemory.py @@ -0,0 +1,311 @@ +import contextlib +import logging +import os +import tempfile +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +import transformers as tr + +from relik.common.log import get_logger +from relik.common.torch_utils import get_autocast_context +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.indexers.document import Document, DocumentStore +from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample + + +# check if ORT is available +# if is_package_available("onnxruntime"): + +logger = get_logger(__name__, level=logging.INFO) + + +class MatrixMultiplicationModule(torch.nn.Module): + def __init__(self, embeddings): + super().__init__() + self.embeddings = torch.nn.Parameter(embeddings, requires_grad=False) + + def forward(self, query): + return torch.matmul(query, self.embeddings.T) + + +class InMemoryDocumentIndex(BaseDocumentIndex): + DOCUMENTS_FILE_NAME = "documents.jsonl" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + + def __init__( + self, + documents: str + | List[str] + | os.PathLike + | List[os.PathLike] + | DocumentStore + | None = None, + embeddings: torch.Tensor | None = None, + metadata_fields: List[str] | None = None, + separator: str | None = None, + name_or_path: str | os.PathLike | None = None, + device: str = "cpu", + precision: str | int | torch.dtype = 32, + *args, + **kwargs, + ) -> None: + """ + An in-memory indexer based on PyTorch. + + Args: + documents (:obj:`Union[List[str]]`): + The documents to be indexed. + embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`): + The embeddings of the documents. + device (:obj:`str`, `optional`, defaults to "cpu"): + The device to be used for storing the embeddings. + """ + + super().__init__( + documents, embeddings, metadata_fields, separator, name_or_path, device + ) + + if embeddings is not None and documents is not None: + logger.info("Both documents and embeddings are provided.") + if len(documents) != embeddings.shape[0]: + raise ValueError( + "The number of documents and embeddings must be the same." + ) + + # # embeddings of the documents + # self.embeddings = embeddings + # does this do anything? + del embeddings + # convert the embeddings to the desired precision + if precision is not None: + if self.embeddings is not None and device == "cpu": + if PRECISION_MAP[precision] == PRECISION_MAP[16]: + logger.info( + f"Precision `{precision}` is not supported on CPU. " + f"Using `{PRECISION_MAP[32]}` instead." + ) + precision = 32 + + if ( + self.embeddings is not None + and self.embeddings.dtype != PRECISION_MAP[precision] + ): + logger.info( + f"Index vectors are of type {self.embeddings.dtype}. " + f"Converting to {PRECISION_MAP[precision]}." 
+ ) + self.embeddings = self.embeddings.to(PRECISION_MAP[precision]) + else: + # TODO: a bit redundant, fix this eventually + if ( + device == "cpu" + and self.embeddings is not None + and self.embeddings.dtype != torch.float32 + ): + logger.info( + f"Index vectors are of type {self.embeddings.dtype}. " + f"Converting to {PRECISION_MAP[32]}." + ) + self.embeddings = self.embeddings.to(PRECISION_MAP[32]) + # move the embeddings to the desired device + if self.embeddings is not None and not self.embeddings.device == device: + self.embeddings = self.embeddings.to(device) + + # TODO: check interactions with the embeddings + # self.mm = MatrixMultiplicationModule(embeddings=self.embeddings) + # self.mm.eval() + + # precision to be used for the embeddings + self.precision = precision + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + retriever, + documents: Optional[List[Document]] = None, + batch_size: int = 32, + num_workers: int = 4, + max_length: int | None = None, + collate_fn: Optional[Callable] = None, + encoder_precision: Optional[Union[str, int]] = None, + compute_on_cpu: bool = False, + force_reindex: bool = False, + ) -> "InMemoryDocumentIndex": + """ + Index the documents using the encoder. + + Args: + retriever (:obj:`torch.nn.Module`): + The encoder to be used for indexing. + documents (:obj:`List[Document]`, `optional`, defaults to :obj:`None`): + The documents to be indexed. If not provided, the documents provided at the initialization will be used. + batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to be used for indexing. + num_workers (:obj:`int`, `optional`, defaults to 4): + The number of workers to be used for indexing. + max_length (:obj:`int`, `optional`, defaults to None): + The maximum length of the input to the encoder. + collate_fn (:obj:`Callable`, `optional`, defaults to None): + The collate function to be used for batching. + encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): + The precision to be used for the encoder. + compute_on_cpu (:obj:`bool`, `optional`, defaults to False): + Whether to compute the embeddings on CPU. + force_reindex (:obj:`bool`, `optional`, defaults to False): + Whether to force reindexing. + + Returns: + :obj:`InMemoryIndexer`: The indexer object. + """ + + if documents is None and self.documents is None: + raise ValueError("Documents must be provided.") + + if self.embeddings is not None and not force_reindex and documents is None: + logger.info( + "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." 
+ ) + return self + + if force_reindex: + if documents is not None: + self.documents.add_documents(documents) + data = [k for k in self.get_passages()] + + else: + if documents is not None: + data = [k for k in self.get_passages(DocumentStore(documents))] + # add the documents to the actual document store + self.documents.add_documents(documents) + else: + if self.embeddings is None: + data = [k for k in self.get_passages()] + + if collate_fn is None: + tokenizer = retriever.passage_tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + + dataloader = DataLoader( + BaseDataset(name="passage", data=data), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, + ) + + encoder = retriever.passage_encoder + + # Create empty lists to store the passage embeddings and passage index + passage_embeddings: List[torch.Tensor] = [] + + encoder_device = "cpu" if compute_on_cpu else encoder.device + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(encoder_device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[encoder_precision], + ) + ) + ) + with autocast_pssg_mngr: + # Iterate through each batch in the dataloader + for batch in tqdm(dataloader, desc="Indexing"): + # Move the batch to the device + batch: ModelInputs = batch.to(encoder_device) + # Compute the passage embeddings + passage_outs = encoder(**batch).pooler_output + # Append the passage embeddings to the list + if self.device == "cpu": + passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) + else: + passage_embeddings.extend([c for c in passage_outs]) + + # move the passage embeddings to the CPU if not already done + # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision + if not self.device == "cpu": # this if is to avoid unnecessary moves + passage_embeddings = [c.detach().cpu() for c in passage_embeddings] + # stack it + passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) + # move the passage embeddings to the gpu if needed + if not self.device == "cpu": + passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision]) + passage_embeddings = passage_embeddings.to(self.device) + self.embeddings = passage_embeddings + # update the matrix multiplication module + # self.mm = MatrixMultiplicationModule(embeddings=self.embeddings) + + # free up memory from the unused variable + del passage_embeddings + + return self + + @torch.no_grad() + @torch.inference_mode() + def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: + """ + Search the documents using the query. + + Args: + query (:obj:`torch.Tensor`): + The query to be used for searching. + k (:obj:`int`, `optional`, defaults to 1): + The number of documents to be retrieved. + + Returns: + :obj:`List[RetrievedSample]`: The retrieved documents. 
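A self-contained toy version of the brute-force search implemented below: dot-product similarity against every stored vector followed by a top-k selection (all sizes are arbitrary):

import torch

embeddings = torch.randn(10_000, 768)          # index vectors
query = torch.randn(4, 768)                    # a batch of 4 query encodings

similarity = torch.matmul(query, embeddings.T)            # (4, 10_000)
scores, indices = torch.topk(similarity, k=5, dim=1)      # top 5 documents per query
# indices[i] holds the positions of the best documents for query i,
# scores[i] the corresponding dot products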
+ """ + + with get_autocast_context(self.device, self.embeddings.dtype): + # move query to the same device as embeddings + query = query.to(self.embeddings.device) + if query.dtype != self.embeddings.dtype: + query = query.to(self.embeddings.dtype) + similarity = torch.matmul(query, self.embeddings.T) + # similarity = self.mm(query) + # Retrieve the indices of the top k passage embeddings + retriever_out: torch.return_types.topk = torch.topk( + similarity, k=min(k, similarity.shape[-1]), dim=1 + ) + + # get int values + batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist() + # get float values + batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist() + # Retrieve the passages corresponding to the indices + batch_docs = [ + [self.documents.get_document_from_id(i) for i in indices] + for indices in batch_top_k + ] + # build the output object + batch_retrieved_samples = [ + [ + RetrievedSample(document=doc, score=score) + for doc, score in zip(docs, scores) + ] + for docs, scores in zip(batch_docs, batch_scores) + ] + return batch_retrieved_samples diff --git a/relik/retriever/indexers/voyager.py b/relik/retriever/indexers/voyager.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d274fe76800a82cadd8f88ce20e2cc17c3a97f --- /dev/null +++ b/relik/retriever/indexers/voyager.py @@ -0,0 +1,271 @@ +import contextlib +import logging +import os +from typing import Callable, List, Optional, Union + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from voyager import Index, Space + +from relik.common.log import get_logger +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.labels import Labels +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample + +logger = get_logger(__name__, level=logging.INFO) + + +class VoyagerDocumentIndex(BaseDocumentIndex): + DOCUMENTS_FILE_NAME = "documents.json" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + + def __init__( + self, + documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None, + embeddings: Optional[torch.Tensor] = None, + device: str = "cpu", + precision: Optional[str] = None, + name_or_path: Optional[Union[str, os.PathLike]] = None, + *args, + **kwargs, + ) -> None: + """ + An in-memory indexer. + + Args: + documents (:obj:`Union[List[str], PassageManager]`): + The documents to be indexed. + embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`): + The embeddings of the documents. + device (:obj:`str`, `optional`, defaults to "cpu"): + The device to be used for storing the embeddings. + """ + + super().__init__(documents, embeddings, name_or_path) + + if embeddings is not None and documents is not None: + logger.info("Both documents and embeddings are provided.") + if documents.get_label_size() != embeddings.shape[0]: + raise ValueError( + "The number of documents and embeddings must be the same." + ) + + self.embeddings = Index( + Space.InnerProduct, + num_dimensions=embeddings.shape[1], + ef_construction=2000, + M=2048, + ) + self.embeddings.add_items(embeddings.numpy()) + + # embeddings of the documents + # self.embeddings = embeddings + # does this do anything? 
+ del embeddings + # convert the embeddings to the desired precision + # if precision is not None: + # if ( + # self.embeddings is not None + # and self.embeddings.dtype != PRECISION_MAP[precision] + # ): + # logger.info( + # f"Index vectors are of type {self.embeddings.dtype}. " + # f"Converting to {PRECISION_MAP[precision]}." + # ) + # self.embeddings = self.embeddings.to(PRECISION_MAP[precision]) + # else: + # if ( + # device == "cpu" + # and self.embeddings is not None + # and self.embeddings.dtype != torch.float32 + # ): + # logger.info( + # "Index vectors are of type {}. Converting to float32.".format( + # self.embeddings.dtype + # ) + # ) + # self.embeddings = self.embeddings.to(PRECISION_MAP[32]) + # # move the embeddings to the desired device + # if self.embeddings is not None and not self.embeddings.device == device: + # self.embeddings = self.embeddings.to(device) + + # device to store the embeddings + self.device = device + # precision to be used for the embeddings + self.precision = precision + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + retriever, + documents: Optional[List[str]] = None, + batch_size: int = 32, + num_workers: int = 4, + max_length: Optional[int] = None, + collate_fn: Optional[Callable] = None, + encoder_precision: Optional[Union[str, int]] = None, + compute_on_cpu: bool = False, + force_reindex: bool = False, + add_to_existing_index: bool = False, + ) -> "VoyagerDocumentIndex": + """ + Index the documents using the encoder. + + Args: + retriever (:obj:`torch.nn.Module`): + The encoder to be used for indexing. + documents (:obj:`List[str]`, `optional`, defaults to :obj:`None`): + The documents to be indexed. + batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to be used for indexing. + num_workers (:obj:`int`, `optional`, defaults to 4): + The number of workers to be used for indexing. + max_length (:obj:`int`, `optional`, defaults to None): + The maximum length of the input to the encoder. + collate_fn (:obj:`Callable`, `optional`, defaults to None): + The collate function to be used for batching. + encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): + The precision to be used for the encoder. + compute_on_cpu (:obj:`bool`, `optional`, defaults to False): + Whether to compute the embeddings on CPU. + force_reindex (:obj:`bool`, `optional`, defaults to False): + Whether to force reindexing. + add_to_existing_index (:obj:`bool`, `optional`, defaults to False): + Whether to add the new documents to the existing index. + + Returns: + :obj:`InMemoryIndexer`: The indexer object. + """ + + if documents is None and self.documents is None: + raise ValueError("Documents must be provided.") + + if self.embeddings is not None and not force_reindex: + logger.info( + "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." 
+ ) + if documents is None: + return self + + if collate_fn is None: + tokenizer = retriever.passage_tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + + if force_reindex: + if documents is not None: + self.documents.add_labels(documents) + data = [k for k in self.documents.get_labels()] + + else: + if documents is not None: + data = [k for k in Labels(documents).get_labels()] + else: + return self + + # if force_reindex: + # data = [k for k in self.documents.get_labels()] + + dataloader = DataLoader( + BaseDataset(name="passage", data=data), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, + ) + + encoder = retriever.passage_encoder + + # Create empty lists to store the passage embeddings and passage index + passage_embeddings: List[torch.Tensor] = [] + + encoder_device = "cpu" if compute_on_cpu else self.device + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(encoder_device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[encoder_precision], + ) + ) + ) + with autocast_pssg_mngr: + # Iterate through each batch in the dataloader + for batch in tqdm(dataloader, desc="Indexing"): + # Move the batch to the device + batch: ModelInputs = batch.to(encoder_device) + # Compute the passage embeddings + passage_outs = encoder(**batch).pooler_output + # Append the passage embeddings to the list + if self.device == "cpu": + passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) + else: + passage_embeddings.extend([c for c in passage_outs]) + + # move the passage embeddings to the CPU if not already done + # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision + if not self.device == "cpu": # this if is to avoid unnecessary moves + passage_embeddings = [c.detach().cpu() for c in passage_embeddings] + # stack it + passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) + # move the passage embeddings to the gpu if needed + if not self.device == "cpu": + passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision]) + passage_embeddings = passage_embeddings.to(self.device) + self.embeddings = passage_embeddings + + # free up memory from the unused variable + del passage_embeddings + + return self + + @torch.no_grad() + @torch.inference_mode() + def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: + # k = min(k, self.embeddings.ntotal) + + if isinstance(query, torch.Tensor) and self.device == "cpu": + query = query.detach().cpu().numpy() + # Retrieve the indices of the top k passage embeddings + retriever_out = self.embeddings.query(query, k) + + # get int values (second element of the tuple) + batch_top_k: List[List[int]] = retriever_out[0].tolist() + # get float values (first element of the tuple) + batch_scores: List[List[float]] = retriever_out[1].tolist() + # Retrieve the passages corresponding to the indices + batch_passages = [ + [self.documents.get_label_from_index(i) for i in indices if i != -1] + for indices in batch_top_k + ] + # build the output object + batch_retrieved_samples = 
[ + [ + RetrievedSample(label=passage, index=index, score=score) + for passage, index, score in zip(passages, indices, scores) + ] + for passages, indices, scores in zip( + batch_passages, batch_top_k, batch_scores + ) + ] + return batch_retrieved_samples diff --git a/relik/retriever/lightning_modules/__init__.py b/relik/retriever/lightning_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/lightning_modules/pl_data_modules.py b/relik/retriever/lightning_modules/pl_data_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..d7898d7c64a3843121c0004b586a485d081acab6 --- /dev/null +++ b/relik/retriever/lightning_modules/pl_data_modules.py @@ -0,0 +1,121 @@ +from typing import Any, List, Optional, Sequence, Union + +import hydra +import lightning as pl +import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from relik.common.log import get_logger +from relik.retriever.data.datasets import GoldenRetrieverDataset + +logger = get_logger(__name__) + + +class GoldenRetrieverPLDataModule(pl.LightningDataModule): + def __init__( + self, + train_dataset: Optional[GoldenRetrieverDataset] = None, + val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, + test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, + num_workers: Optional[Union[DictConfig, int]] = None, + datasets: Optional[DictConfig] = None, + *args, + **kwargs, + ): + super().__init__() + self.datasets = datasets + if num_workers is None: + num_workers = 0 + if isinstance(num_workers, int): + num_workers = DictConfig( + {"train": num_workers, "val": num_workers, "test": num_workers} + ) + self.num_workers = num_workers + # data + self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset + self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets + self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets + + def prepare_data(self, *args, **kwargs): + """ + Method for preparing the data before the training. This method is called only once. + It is used to download the data, tokenize the data, etc. 
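As a quick illustration of the `num_workers` normalisation in the `__init__` above, a plain int is expanded into a per-split DictConfig:

from omegaconf import DictConfig

num_workers = 4
if isinstance(num_workers, int):
    num_workers = DictConfig(
        {"train": num_workers, "val": num_workers, "test": num_workers}
    )
assert num_workers.train == num_workers.val == num_workers.test == 4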
+ """ + pass + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + # usually there is only one dataset for train + # if you need more train loader, you can follow + # the same logic as val and test datasets + if self.train_dataset is None: + self.train_dataset = hydra.utils.instantiate(self.datasets.train) + self.val_datasets = [ + hydra.utils.instantiate(dataset_cfg) + for dataset_cfg in self.datasets.val + ] + if stage == "test": + if self.test_datasets is None: + self.test_datasets = [ + hydra.utils.instantiate(dataset_cfg) + for dataset_cfg in self.datasets.test + ] + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + torch_dataset = self.train_dataset.to_torch_dataset() + return DataLoader( + # self.train_dataset.to_torch_dataset(), + torch_dataset, + shuffle=False, + batch_size=None, + num_workers=self.num_workers.train, + pin_memory=True, + collate_fn=lambda x: x, + ) + + def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + dataloaders = [] + for dataset in self.val_datasets: + torch_dataset = dataset.to_torch_dataset() + dataloaders.append( + DataLoader( + torch_dataset, + shuffle=False, + batch_size=None, + num_workers=self.num_workers.val, + pin_memory=True, + collate_fn=lambda x: x, + ) + ) + return dataloaders + + def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + dataloaders = [] + for dataset in self.test_datasets: + torch_dataset = dataset.to_torch_dataset() + dataloaders.append( + DataLoader( + torch_dataset, + shuffle=False, + batch_size=None, + num_workers=self.num_workers.test, + pin_memory=True, + collate_fn=lambda x: x, + ) + ) + return dataloaders + + def predict_dataloader(self) -> EVAL_DATALOADERS: + raise NotImplementedError + + def transfer_batch_to_device( + self, batch: Any, device: torch.device, dataloader_idx: int + ) -> Any: + return super().transfer_batch_to_device(batch, device, dataloader_idx) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" f"{self.datasets=}, " f"{self.num_workers=}, " + ) diff --git a/relik/retriever/lightning_modules/pl_modules.py b/relik/retriever/lightning_modules/pl_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef6ace6eb8ac95d92e16b1e28dc50d2492b33f4 --- /dev/null +++ b/relik/retriever/lightning_modules/pl_modules.py @@ -0,0 +1,123 @@ +from typing import Any, Union + +import hydra +import lightning as pl +import torch +from omegaconf import DictConfig + +from relik.retriever.common.model_inputs import ModelInputs + + +class GoldenRetrieverPLModule(pl.LightningModule): + def __init__( + self, + model: Union[torch.nn.Module, DictConfig], + optimizer: Union[torch.optim.Optimizer, DictConfig], + lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, DictConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(ignore=["model"]) + if isinstance(model, DictConfig): + self.model = hydra.utils.instantiate(model) + else: + self.model = model + + self.optimizer_config = optimizer + self.lr_scheduler_config = lr_scheduler + + def forward(self, **kwargs) -> dict: + """ + Method for the forward pass. + 'training_step', 'validation_step' and 'test_step' should call + this method in order to compute the output predictions and the loss. + + Returns: + output_dict: forward output containing the predictions (output logits ecc...) and the loss if any. 
+ + """ + return self.model(**kwargs) + + def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor: + forward_output = self.forward(**batch, return_loss=True) + self.log( + "loss", + forward_output["loss"], + batch_size=batch["questions"]["input_ids"].size(0), + prog_bar=True, + ) + return forward_output["loss"] + + def validation_step(self, batch: ModelInputs, batch_idx: int) -> None: + forward_output = self.forward(**batch, return_loss=True) + self.log( + "val_loss", + forward_output["loss"], + batch_size=batch["questions"]["input_ids"].size(0), + ) + + def test_step(self, batch: ModelInputs, batch_idx: int) -> Any: + forward_output = self.forward(**batch, return_loss=True) + self.log( + "test_loss", + forward_output["loss"], + batch_size=batch["questions"]["input_ids"].size(0), + ) + + def configure_optimizers(self): + if isinstance(self.optimizer_config, DictConfig): + param_optimizer = list(self.named_parameters()) + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in param_optimizer if "layer_norm_layer" in n + ], + "weight_decay": self.hparams.optimizer.weight_decay, + "lr": 1e-4, + }, + { + "params": [ + p + for n, p in param_optimizer + if all(nd not in n for nd in no_decay) + and "layer_norm_layer" not in n + ], + "weight_decay": self.hparams.optimizer.weight_decay, + }, + { + "params": [ + p + for n, p in param_optimizer + if "layer_norm_layer" not in n + and any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + optimizer = hydra.utils.instantiate( + self.optimizer_config, + # params=self.parameters(), + params=optimizer_grouped_parameters, + _convert_="partial", + ) + else: + optimizer = self.optimizer_config + + if self.lr_scheduler_config is None: + return optimizer + + if isinstance(self.lr_scheduler_config, DictConfig): + lr_scheduler = hydra.utils.instantiate( + self.lr_scheduler_config, optimizer=optimizer + ) + else: + lr_scheduler = self.lr_scheduler_config + + lr_scheduler_config = { + "scheduler": lr_scheduler, + "interval": "step", + "frequency": 1, + } + return [optimizer], [lr_scheduler_config] diff --git a/relik/retriever/pytorch_modules/__init__.py b/relik/retriever/pytorch_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b3b1f6777281b978f8a0fdc361925f38278623 --- /dev/null +++ b/relik/retriever/pytorch_modules/__init__.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass + +import torch + +from relik.retriever.indexers.document import Document + +PRECISION_MAP = { + None: torch.float32, + 32: torch.float32, + 16: torch.float16, + torch.float32: torch.float32, + torch.float16: torch.float16, + torch.bfloat16: torch.bfloat16, + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float": torch.float32, + "half": torch.float16, + "32": torch.float32, + "16": torch.float16, + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +@dataclass +class RetrievedSample: + """ + Dataclass for the output of the GoldenRetriever model. 
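PRECISION_MAP above accepts several spellings of the same dtype; a short illustration, assuming the relik package is importable:

import torch
from relik.retriever.pytorch_modules import PRECISION_MAP

for alias in (16, "16", "fp16", "float16", "half", torch.float16):
    assert PRECISION_MAP[alias] is torch.float16

embeddings = torch.randn(3, 4)
embeddings = embeddings.to(PRECISION_MAP["bf16"])   # -> torch.bfloat16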
+ """ + + score: float + document: Document diff --git a/relik/retriever/pytorch_modules/__pycache__/__init__.cpython-310.pyc b/relik/retriever/pytorch_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d743a98e26db7bba31e3d740f5095960971173 Binary files /dev/null and b/relik/retriever/pytorch_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/relik/retriever/pytorch_modules/__pycache__/hf.cpython-310.pyc b/relik/retriever/pytorch_modules/__pycache__/hf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a72140aed0545f971f9c9b8a6fdf252fbfd6c2b Binary files /dev/null and b/relik/retriever/pytorch_modules/__pycache__/hf.cpython-310.pyc differ diff --git a/relik/retriever/pytorch_modules/__pycache__/model.cpython-310.pyc b/relik/retriever/pytorch_modules/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64c94dbd661e7c1ab526029e20871fe80d8bee6a Binary files /dev/null and b/relik/retriever/pytorch_modules/__pycache__/model.cpython-310.pyc differ diff --git a/relik/retriever/pytorch_modules/hf.py b/relik/retriever/pytorch_modules/hf.py new file mode 100644 index 0000000000000000000000000000000000000000..b5868d67ce5ed97a1e66c6d1d3a606350e75267a --- /dev/null +++ b/relik/retriever/pytorch_modules/hf.py @@ -0,0 +1,88 @@ +from typing import Tuple, Union + +import torch +from transformers import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions +from transformers.models.bert.modeling_bert import BertModel + + +class GoldenRetrieverConfig(PretrainedConfig): + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class GoldenRetrieverModel(BertModel): + config_class = GoldenRetrieverConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config) + self.layer_norm_layer = torch.nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + def forward( + self, **kwargs + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + attention_mask = kwargs.get("attention_mask", None) + model_outputs = super().forward(**kwargs) + if attention_mask is None: + pooler_output = model_outputs.pooler_output + else: + token_embeddings = model_outputs.last_hidden_state + input_mask_expanded = ( + 
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + pooler_output = torch.sum( + token_embeddings * input_mask_expanded, 1 + ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + pooler_output = self.layer_norm_layer(pooler_output) + + if not kwargs.get("return_dict", True): + return (model_outputs[0], pooler_output) + model_outputs[2:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=model_outputs.last_hidden_state, + pooler_output=pooler_output, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + cross_attentions=model_outputs.cross_attentions, + ) diff --git a/relik/retriever/pytorch_modules/loss.py b/relik/retriever/pytorch_modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..643d3a486ca73ca38486094553b357f5e7c28adb --- /dev/null +++ b/relik/retriever/pytorch_modules/loss.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch +from torch.nn.modules.loss import _WeightedLoss + + +class MultiLabelNCELoss(_WeightedLoss): + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[torch.Tensor] = None, + size_average=None, + reduction: Optional[str] = "mean", + ) -> None: + super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction) + + def forward( + self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100 + ) -> torch.Tensor: + gold_scores = input.masked_fill(~(target.bool()), 0) + gold_scores_sum = gold_scores.sum(-1) # B x C + neg_logits = input.masked_fill(target.bool(), float("-inf")) # B x C x L + neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True) # B x C x 1 + norm_term = ( + torch.logaddexp(input, neg_log_sum_exp) + .masked_fill(~(target.bool()), 0) + .sum(-1) + ) + gold_log_probs = gold_scores_sum - norm_term + loss = -gold_log_probs.sum() + if self.reduction == "mean": + loss /= input.size(0) + return loss diff --git a/relik/retriever/pytorch_modules/model.py b/relik/retriever/pytorch_modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b55dea14caadc727b92df59d89f7f0fe5553c9a0 --- /dev/null +++ b/relik/retriever/pytorch_modules/model.py @@ -0,0 +1,645 @@ +import logging +import os +import platform +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +import transformers as tr +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.common.torch_utils import ( + get_autocast_context, +) # , # load_ort_optimized_hf_model +from relik.common.utils import is_package_available, to_config +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.labels import Labels +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.indexers.document import Document +from relik.retriever.indexers.inmemory import InMemoryDocumentIndex +from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample +from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel + +# check if ORT is available +if is_package_available("onnxruntime"): + from optimum.onnxruntime import ORTModel + +logger = get_logger(__name__, level=logging.INFO) + + +@dataclass +class 
GoldenRetrieverOutput(tr.file_utils.ModelOutput): + """Class for model's outputs.""" + + logits: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + question_encodings: Optional[torch.FloatTensor] = None + passages_encodings: Optional[torch.FloatTensor] = None + + +class GoldenRetriever(torch.nn.Module): + def __init__( + self, + question_encoder: Union[str, tr.PreTrainedModel], + loss_type: Optional[torch.nn.Module] = None, + passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None, + document_index: Optional[Union[str, BaseDocumentIndex]] = None, + question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + device: Optional[Union[str, torch.device]] = "cpu", + precision: Optional[Union[str, int]] = None, + index_precision: Optional[Union[str, int]] = None, + index_device: Optional[Union[str, torch.device]] = None, + *args, + **kwargs, + ): + super().__init__() + + self.passage_encoder_is_question_encoder = False + # question encoder model + if isinstance(question_encoder, str): + question_encoder = GoldenRetrieverModel.from_pretrained( + question_encoder, **kwargs + ) + self.question_encoder = question_encoder + if passage_encoder is None: + # if no passage encoder is provided, + # share the weights of the question encoder + passage_encoder = question_encoder + # keep track of the fact that the passage encoder is the same as the question encoder + self.passage_encoder_is_question_encoder = True + if isinstance(passage_encoder, str): + passage_encoder = GoldenRetrieverModel.from_pretrained( + passage_encoder, **kwargs + ) + # passage encoder model + self.passage_encoder = passage_encoder + + # loss function + self.loss_type = loss_type + + # indexer stuff + index_device = index_device or device + index_precision = index_precision or precision + if document_index is None: + # if no indexer is provided, create a new one + document_index = InMemoryDocumentIndex( + device=index_device, precision=index_precision, **kwargs + ) + if isinstance(document_index, str): + document_index = BaseDocumentIndex.from_pretrained( + document_index, device=index_device, precision=index_precision, **kwargs + ) + self.document_index = document_index + + # lazy load the tokenizer for inference + self._question_tokenizer = question_tokenizer + self._passage_tokenizer = passage_tokenizer + + # move the model to the device + self.to(device or torch.device("cpu")) + + # set the precision + self.precision = precision + + def forward( + self, + questions: Optional[Dict[str, torch.Tensor]] = None, + passages: Optional[Dict[str, torch.Tensor]] = None, + labels: Optional[torch.Tensor] = None, + question_encodings: Optional[torch.Tensor] = None, + passages_encodings: Optional[torch.Tensor] = None, + passages_per_question: Optional[List[int]] = None, + return_loss: bool = False, + return_encodings: bool = False, + *args, + **kwargs, + ) -> GoldenRetrieverOutput: + """ + Forward pass of the model. + + Args: + questions (`Dict[str, torch.Tensor]`): + The questions to encode. + passages (`Dict[str, torch.Tensor]`): + The passages to encode. + labels (`torch.Tensor`): + The labels of the sentences. + return_loss (`bool`): + Whether to compute the predictions. + question_encodings (`torch.Tensor`): + The encodings of the questions. + passages_encodings (`torch.Tensor`): + The encodings of the passages. + passages_per_question (`List[int]`): + The number of passages per question. 
+ return_loss (`bool`): + Whether to compute the loss. + return_encodings (`bool`): + Whether to return the encodings. + + Returns: + obj:`torch.Tensor`: The outputs of the model. + """ + if questions is None and question_encodings is None: + raise ValueError( + "Either `questions` or `question_encodings` must be provided" + ) + if passages is None and passages_encodings is None: + raise ValueError( + "Either `passages` or `passages_encodings` must be provided" + ) + + if question_encodings is None: + question_encodings = self.question_encoder(**questions).pooler_output + if passages_encodings is None: + passages_encodings = self.passage_encoder(**passages).pooler_output + + if passages_per_question is not None: + # multiply each question encoding with a passages_per_question encodings + concatenated_passages = torch.stack( + torch.split(passages_encodings, passages_per_question) + ).transpose(1, 2) + if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): + # normalize the encodings for cosine similarity + concatenated_passages = F.normalize(concatenated_passages, p=2, dim=2) + question_encodings = F.normalize(question_encodings, p=2, dim=1) + logits = torch.bmm( + question_encodings.unsqueeze(1), concatenated_passages + ).view(question_encodings.shape[0], -1) + else: + if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): + # normalize the encodings for cosine similarity + question_encodings = F.normalize(question_encodings, p=2, dim=1) + passages_encodings = F.normalize(passages_encodings, p=2, dim=1) + + logits = torch.matmul(question_encodings, passages_encodings.T) + + output = dict(logits=logits) + + if return_loss and labels is not None: + if self.loss_type is None: + raise ValueError( + "If `return_loss` is set to `True`, `loss_type` must be provided" + ) + if isinstance(self.loss_type, torch.nn.NLLLoss): + labels = labels.argmax(dim=1) + logits = F.log_softmax(logits, dim=1) + if len(question_encodings.size()) > 1: + logits = logits.view(question_encodings.size(0), -1) + + output["loss"] = self.loss_type(logits, labels) + + if return_encodings: + output["question_encodings"] = question_encodings + output["passages_encodings"] = passages_encodings + + return GoldenRetrieverOutput(**output) + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + batch_size: int = 32, + num_workers: int = 4, + max_length: int | None = None, + collate_fn: Optional[Callable] = None, + force_reindex: bool = False, + compute_on_cpu: bool = False, + precision: Optional[Union[str, int]] = None, + *args, + **kwargs, + ): + """ + Index the passages for later retrieval. + + Args: + batch_size (`int`): + The batch size to use for the indexing. + num_workers (`int`): + The number of workers to use for the indexing. + max_length (`int | None`): + The maximum length of the passages. + collate_fn (`Callable`): + The collate function to use for the indexing. + force_reindex (`bool`): + Whether to force reindexing even if the passages are already indexed. + compute_on_cpu (`bool`): + Whether to move the index to the CPU after the indexing. + precision (`Optional[Union[str, int]]`): + The precision to use for the model. + """ + if self.document_index is None: + raise ValueError( + "The retriever must be initialized with an indexer to index " + "the passages within the retriever." 
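A self-contained toy version of the similarity computed in `forward` above: one logit per (question, passage) pair, with the L2 normalisation that is applied when a BCE-style loss is used (sizes are arbitrary):

import torch
import torch.nn.functional as F

question_encodings = torch.randn(4, 768)
passages_encodings = torch.randn(12, 768)

logits = torch.matmul(question_encodings, passages_encodings.T)   # (4, 12) dot products
# cosine-similarity variant used with BCEWithLogitsLoss:
cosine_logits = torch.matmul(
    F.normalize(question_encodings, p=2, dim=1),
    F.normalize(passages_encodings, p=2, dim=1).T,
)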
+ ) + # TODO: add kwargs + return self.document_index.index( + retriever=self, + batch_size=batch_size, + num_workers=num_workers, + max_length=max_length, + collate_fn=collate_fn, + encoder_precision=precision or self.precision, + compute_on_cpu=compute_on_cpu, + force_reindex=force_reindex, + *args, + **kwargs, + ) + + @torch.no_grad() + @torch.inference_mode() + def retrieve( + self, + text: Optional[Union[str, List[str]]] = None, + text_pair: Optional[Union[str, List[str]]] = None, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + k: int | None = None, + max_length: int | None = None, + precision: Optional[Union[str, int]] = None, + collate_fn: Optional[Callable] = None, + batch_size: int | None = None, + num_workers: int = 4, + progress_bar: bool = False, + **kwargs, + ) -> List[List[RetrievedSample]]: + """ + Retrieve the passages for the questions. + + Args: + text (`Optional[Union[str, List[str]]]`): + The questions to retrieve the passages for. + text_pair (`Optional[Union[str, List[str]]]`): + The questions to retrieve the passages for. + input_ids (`torch.Tensor`): + The input ids of the questions. + attention_mask (`torch.Tensor`): + The attention mask of the questions. + token_type_ids (`torch.Tensor`): + The token type ids of the questions. + k (`int`): + The number of top passages to retrieve. + max_length (`int | None`): + The maximum length of the questions. + precision (`Optional[Union[str, int]]`): + The precision to use for the model. + collate_fn (`Callable`): + The collate function to use for the retrieval. + batch_size (`int`): + The batch size to use for the retrieval. + num_workers (`int`): + The number of workers to use for the retrieval. + progress_bar (`bool`): + Whether to show a progress bar. + + Returns: + `List[List[RetrievedSample]]`: The retrieved passages and their indices. + """ + if self.document_index is None: + raise ValueError( + "The indexer must be indexed before it can be used within the retriever." + ) + if text is None and input_ids is None: + raise ValueError( + "Either `text` or `input_ids` must be provided to retrieve the passages." 
+ ) + + if text is not None: + if isinstance(text, str): + text = [text] + if text_pair is not None: + if isinstance(text_pair, str): + text_pair = [text_pair] + else: + text_pair = [None] * len(text) + + if collate_fn is None: + tokenizer = self.question_tokenizer + collate_fn = partial( + self.default_collate_fn, max_length=max_length, tokenizer=tokenizer + ) + + dataloader = DataLoader( + BaseDataset(name="questions", data=list(zip(text, text_pair))), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, + ) + else: + model_inputs = ModelInputs(dict(input_ids=input_ids)) + if attention_mask is not None: + model_inputs["attention_mask"] = attention_mask + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids + + dataloader = [model_inputs] + + if progress_bar: + dataloader = tqdm(dataloader, desc="Retrieving passages") + + retrieved = [] + try: + with get_autocast_context(self.device, precision): + for batch in dataloader: + batch = batch.to(self.device) + question_encodings = self.question_encoder(**batch).pooler_output + retrieved += self.document_index.search(question_encodings, k) + except AttributeError as e: + # apparently num_workers > 0 gives some issue on MacOS as of now + if "mac" in platform.platform().lower(): + raise ValueError( + "DataLoader with num_workers > 0 is not supported on MacOS. " + "Please set num_workers=0 or try to run on a different machine." + ) from e + else: + raise e + + if progress_bar: + dataloader.close() + + return retrieved + + @staticmethod + def default_collate_fn( + x: tuple, tokenizer: tr.PreTrainedTokenizer, max_length: int | None = None + ) -> ModelInputs: + # get text and text pair + # TODO: check if only retriever is used + _text = [sample[0] for sample in x] + _text_pair = [sample[1] for sample in x] + _text_pair = None if any([t is None for t in _text_pair]) else _text_pair + return ModelInputs( + tokenizer( + _text, + text_pair=_text_pair, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + + def get_document_from_index(self, index: int) -> Document: + """ + Get the document from its ID. + + Args: + id (`int`): + The ID of the document. + + Returns: + `str`: The document. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_document_from_index(index) + + def get_document_from_passage(self, passage: str) -> Document: + """ + Get the document from its text. + + Args: + passage (`str`): + The passage of the document. + + Returns: + `str`: The document. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_document_from_passage(passage) + + def get_index_from_passage(self, passage: str) -> int: + """ + Get the index of the passage. + + Args: + passage (`str`): + The passage to get the index for. + + Returns: + `int`: The index of the passage. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_index_from_passage(passage) + + def get_passage_from_index(self, index: int) -> str: + """ + Get the passage from the index. + + Args: + index (`int`): + The index of the passage. + + Returns: + `str`: The passage. 
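A small usage sketch of `default_collate_fn` above; it assumes the relik package is installed and uses the public `bert-base-uncased` tokenizer purely as an example:

import transformers as tr
from relik.retriever.pytorch_modules.model import GoldenRetriever

tokenizer = tr.AutoTokenizer.from_pretrained("bert-base-uncased")
batch = [("who wrote hamlet?", None), ("capital of italy", None)]   # (text, text_pair) tuples
model_inputs = GoldenRetriever.default_collate_fn(batch, tokenizer=tokenizer, max_length=32)
# model_inputs carries the padded input_ids / attention_mask tensors for the question encoder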
+ """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_passage_from_index(index) + + def get_vector_from_index(self, index: int) -> torch.Tensor: + """ + Get the passage vector from the index. + + Args: + index (`int`): + The index of the passage. + + Returns: + `torch.Tensor`: The passage vector. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_embeddings_from_index(index) + + def get_vector_from_passage(self, passage: str) -> torch.Tensor: + """ + Get the passage vector from the passage. + + Args: + passage (`str`): + The passage. + + Returns: + `torch.Tensor`: The passage vector. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_embeddings_from_passage(passage) + + @property + def passage_embeddings(self) -> torch.Tensor: + """ + The passage embeddings. + """ + return self._passage_embeddings + + @property + def passage_index(self) -> Labels: + """ + The passage index. + """ + return self._passage_index + + @property + def device(self) -> torch.device: + """ + The device of the model. + """ + return next(self.parameters()).device + + @property + def question_tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The question tokenizer. + """ + if self._question_tokenizer: + return self._question_tokenizer + + if ( + self.question_encoder.config.name_or_path + == self.question_encoder.config.name_or_path + ): + if not self._question_tokenizer: + self._question_tokenizer = tr.AutoTokenizer.from_pretrained( + self.question_encoder.config.name_or_path + ) + self._passage_tokenizer = self._question_tokenizer + return self._question_tokenizer + + if not self._question_tokenizer: + self._question_tokenizer = tr.AutoTokenizer.from_pretrained( + self.question_encoder.config.name_or_path + ) + return self._question_tokenizer + + @property + def passage_tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The passage tokenizer. + """ + if self._passage_tokenizer: + return self._passage_tokenizer + + if ( + self.question_encoder.config.name_or_path + == self.passage_encoder.config.name_or_path + ): + if not self._question_tokenizer: + self._question_tokenizer = tr.AutoTokenizer.from_pretrained( + self.question_encoder.config.name_or_path + ) + self._passage_tokenizer = self._question_tokenizer + return self._passage_tokenizer + + if not self._passage_tokenizer: + self._passage_tokenizer = tr.AutoTokenizer.from_pretrained( + self.passage_encoder.config.name_or_path + ) + return self._passage_tokenizer + + def save_pretrained( + self, + output_dir: Union[str, os.PathLike], + question_encoder_name: str | None = None, + passage_encoder_name: str | None = None, + document_index_name: str | None = None, + push_to_hub: bool = False, + **kwargs, + ): + """ + Save the retriever to a directory. + + Args: + output_dir (`str`): + The directory to save the retriever to. + question_encoder_name (`str | None`): + The name of the question encoder. + passage_encoder_name (`str | None`): + The name of the passage encoder. + document_index_name (`str | None`): + The name of the document index. + push_to_hub (`bool`): + Whether to push the model to the hub. 
+ """ + + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving retriever to {output_dir}") + + question_encoder_name = question_encoder_name or "question_encoder" + passage_encoder_name = passage_encoder_name or "passage_encoder" + document_index_name = document_index_name or "document_index" + + logger.info( + f"Saving question encoder state to {output_dir / question_encoder_name}" + ) + # self.question_encoder.config._name_or_path = question_encoder_name + self.question_encoder.register_for_auto_class() + self.question_encoder.save_pretrained( + str(output_dir / question_encoder_name), push_to_hub=push_to_hub, **kwargs + ) + self.question_tokenizer.save_pretrained( + str(output_dir / question_encoder_name), push_to_hub=push_to_hub, **kwargs + ) + if not self.passage_encoder_is_question_encoder: + logger.info( + f"Saving passage encoder state to {output_dir / passage_encoder_name}" + ) + # self.passage_encoder.config._name_or_path = passage_encoder_name + self.passage_encoder.register_for_auto_class() + self.passage_encoder.save_pretrained( + str(output_dir / passage_encoder_name), + push_to_hub=push_to_hub, + **kwargs, + ) + self.passage_tokenizer.save_pretrained( + output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs + ) + + if self.document_index is not None: + # save the indexer + self.document_index.save_pretrained( + str(output_dir / document_index_name), push_to_hub=push_to_hub, **kwargs + ) + + logger.info("Saving retriever to disk done.") + + @classmethod + def to_config(cls, *args, **kwargs): + config = { + "_target_": f"{cls.__class__.__module__}.{cls.__class__.__name__}", + "question_encoder": cls.question_encoder.config.name_or_path, + "passage_encoder": cls.passage_encoder.config.name_or_path + if not cls.passage_encoder_is_question_encoder + else None, + "document_index": to_config(cls.document_index), + } + return config diff --git a/relik/retriever/pytorch_modules/optim.py b/relik/retriever/pytorch_modules/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..3fa634d1b577ac65b7480564c2a9199582a563a3 --- /dev/null +++ b/relik/retriever/pytorch_modules/optim.py @@ -0,0 +1,111 @@ +import math + +import torch +from torch.optim import Optimizer + + +class RAdamW(Optimizer): + r"""Implements RAdamW algorithm. 
+ + RAdam from `On the Variance of the Adaptive Learning Rate and Beyond + `_ + + * `Adam: A Method for Stochastic Optimization + `_ + * `Decoupled Weight Decay Regularization + `_ + * `On the Convergence of Adam and Beyond + `_ + * `On the Variance of the Adaptive Learning Rate and Beyond + `_ + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(RAdamW, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + eps = group["eps"] + lr = group["lr"] + if "rho_inf" not in group: + group["rho_inf"] = 2 / (1 - beta2) - 1 + rho_inf = group["rho_inf"] + + state["step"] += 1 + t = state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + rho_t = rho_inf - ((2 * t * (beta2**t)) / (1 - beta2**t)) + + # Perform stepweight decay + p.data.mul_(1 - lr * group["weight_decay"]) + + if rho_t >= 5: + var = exp_avg_sq.sqrt().add_(eps) + r = math.sqrt( + (1 - beta2**t) + * ((rho_t - 4) * (rho_t - 2) * rho_inf) + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) + + p.data.addcdiv_(exp_avg, var, value=-lr * r / (1 - beta1**t)) + else: + p.data.add_(exp_avg, alpha=-lr / (1 - beta1**t)) + + return loss diff --git a/relik/retriever/pytorch_modules/scheduler.py b/relik/retriever/pytorch_modules/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5edd2433612b27c1a91fe189e57c4e2d41c462b2 --- /dev/null +++ b/relik/retriever/pytorch_modules/scheduler.py @@ -0,0 +1,54 @@ +import torch +from torch.optim.lr_scheduler import LRScheduler + + +class LinearSchedulerWithWarmup(LRScheduler): + def __init__( + 
self, + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + last_epoch: int = -1, + verbose: bool = False, + **kwargs, + ): + self.num_warmup_steps = num_warmup_steps + self.num_training_steps = num_training_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + def scheduler_fn(current_step): + if current_step < self.num_warmup_steps: + return current_step / max(1, self.num_warmup_steps) + return max( + 0.0, + float(self.num_training_steps - current_step) + / float(max(1, self.num_training_steps - self.num_warmup_steps)), + ) + + return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] + + +class LinearScheduler(LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + num_training_steps: int, + last_epoch: int = -1, + verbose: bool = False, + **kwargs, + ): + self.num_training_steps = num_training_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + def scheduler_fn(current_step): + # if current_step < self.num_warmup_steps: + # return current_step / max(1, self.num_warmup_steps) + return max( + 0.0, + float(self.num_training_steps - current_step) + / float(max(1, self.num_training_steps)), + ) + + return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] diff --git a/relik/retriever/trainer/__init__.py b/relik/retriever/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b18bf79091418217ae2bb782c3796dfa8b5b56 --- /dev/null +++ b/relik/retriever/trainer/__init__.py @@ -0,0 +1 @@ +from relik.retriever.trainer.train import RetrieverTrainer diff --git a/relik/retriever/trainer/train.py b/relik/retriever/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c51545c30097ce8833fcd4cb91c09a6670b37644 --- /dev/null +++ b/relik/retriever/trainer/train.py @@ -0,0 +1,1018 @@ +from copy import deepcopy +import os +from pathlib import Path +from typing import List, Literal, Optional, Union + +import hydra +import lightning as pl +import omegaconf +import torch +from lightning import Trainer +from lightning.pytorch.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, + ModelSummary, +) +from lightning.pytorch.loggers import WandbLogger +from omegaconf import OmegaConf +from pprintpp import pformat + +from relik.common.log import get_logger +from relik.retriever.callbacks.base import NLPTemplateCallback +from relik.retriever.callbacks.evaluation_callbacks import ( + AvgRankingEvaluationCallback, + RecallAtKEvaluationCallback, +) +from relik.retriever.callbacks.prediction_callbacks import ( + GoldenRetrieverPredictionCallback, +) +from relik.retriever.callbacks.training_callbacks import NegativeAugmentationCallback +from relik.retriever.callbacks.utils_callbacks import ( + FreeUpIndexerVRAMCallback, + SavePredictionsCallback, + SaveRetrieverCallback, +) +from relik.retriever.data.datasets import GoldenRetrieverDataset +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.lightning_modules.pl_data_modules import ( + GoldenRetrieverPLDataModule, +) +from relik.retriever.lightning_modules.pl_modules import GoldenRetrieverPLModule +from relik.retriever.pytorch_modules.loss import MultiLabelNCELoss +from relik.retriever.pytorch_modules.model import GoldenRetriever +from relik.retriever.pytorch_modules.optim import RAdamW +from relik.retriever.pytorch_modules.scheduler import LinearScheduler + +logger = get_logger(__name__) + + +class 
RetrieverTrainer: + def __init__( + self, + retriever: GoldenRetriever, + train_dataset: GoldenRetrieverDataset | None = None, + val_dataset: GoldenRetrieverDataset + | list[GoldenRetrieverDataset] + | None = None, + test_dataset: GoldenRetrieverDataset + | list[GoldenRetrieverDataset] + | None = None, + num_workers: int = 4, + optimizer: torch.optim.Optimizer = RAdamW, + lr: float = 1e-5, + weight_decay: float = 0.01, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler = LinearScheduler, + num_warmup_steps: int = 0, + loss: torch.nn.Module = MultiLabelNCELoss, + callbacks: list | None = None, + accelerator: str = "auto", + devices: int = 1, + num_nodes: int = 1, + strategy: str = "auto", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + val_check_interval: float = 1.0, + check_val_every_n_epoch: int = 1, + max_steps: int | None = None, + max_epochs: int | None = None, + deterministic: bool = True, + fast_dev_run: bool = False, + precision: int | str = 16, + reload_dataloaders_every_n_epochs: int = 1, + resume_from_checkpoint_path: str | os.PathLike | None = None, + trainer_kwargs: dict | None = None, + # eval parameters + metric_to_monitor: str = "validate_recall@{top_k}", + monitor_mode: str = "max", + top_k: int | List[int] = 100, + # early stopping parameters + early_stopping: bool = True, + early_stopping_patience: int = 10, + early_stopping_kwargs: dict | None = None, + # wandb logger parameters + log_to_wandb: bool = True, + wandb_entity: str | None = None, + wandb_experiment_name: str | None = None, + wandb_project_name: str = "golden-retriever", + wandb_save_dir: str | os.PathLike = "./", # TODO: i don't like this default + wandb_log_model: bool = True, + wandb_online_mode: bool = False, + wandb_watch: str = "all", + wandb_kwargs: dict | None = None, + # checkpoint parameters + model_checkpointing: bool = True, + checkpoint_dir: str | os.PathLike | None = None, + checkpoint_filename: str | os.PathLike | None = None, + save_top_k: int = 1, + save_last: bool = False, + checkpoint_kwargs: dict | None = None, + # prediction callback parameters + prediction_batch_size: int = 128, + # hard negatives callback parameters + max_hard_negatives_to_mine: int = 15, + hard_negatives_threshold: float = 0.0, + metrics_to_monitor_for_hard_negatives: str | None = None, + mine_hard_negatives_with_probability: float = 1.0, + # other parameters + seed: int = 42, + float32_matmul_precision: str = "medium", + **kwargs, + ): + # put all the parameters in the class + self.retriever = retriever + # datasets + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.test_dataset = test_dataset + self.num_workers = num_workers + # trainer parameters + self.optimizer = optimizer + self.lr = lr + self.weight_decay = weight_decay + self.lr_scheduler = lr_scheduler + self.num_warmup_steps = num_warmup_steps + self.loss = loss + self.callbacks = callbacks + self.accelerator = accelerator + self.devices = devices + self.num_nodes = num_nodes + self.strategy = strategy + self.accumulate_grad_batches = accumulate_grad_batches + self.gradient_clip_val = gradient_clip_val + self.val_check_interval = val_check_interval + self.check_val_every_n_epoch = check_val_every_n_epoch + self.max_steps = max_steps + self.max_epochs = max_epochs + self.deterministic = deterministic + self.fast_dev_run = fast_dev_run + self.precision = precision + self.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs + self.resume_from_checkpoint_path = resume_from_checkpoint_path + 
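A hypothetical construction of the trainer, assuming `retriever`, `train_ds` and `val_ds` already exist (they are not defined in this excerpt); only a few of the parameters above are shown:

trainer = RetrieverTrainer(
    retriever=retriever,
    train_dataset=train_ds,
    val_dataset=val_ds,
    max_epochs=3,
    top_k=[50, 100],
    lr=1e-5,
    log_to_wandb=False,
)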
self.trainer_kwargs = trainer_kwargs or {} + # eval parameters + self.metric_to_monitor = metric_to_monitor + self.monitor_mode = monitor_mode + self.top_k = top_k + # early stopping parameters + self.early_stopping = early_stopping + self.early_stopping_patience = early_stopping_patience + self.early_stopping_kwargs = early_stopping_kwargs + # wandb logger parameters + self.log_to_wandb = log_to_wandb + self.wandb_entity = wandb_entity + self.wandb_experiment_name = wandb_experiment_name + self.wandb_project_name = wandb_project_name + self.wandb_save_dir = wandb_save_dir + self.wandb_log_model = wandb_log_model + self.wandb_online_mode = wandb_online_mode + self.wandb_watch = wandb_watch + self.wandb_kwargs = wandb_kwargs + # checkpoint parameters + self.model_checkpointing = model_checkpointing + self.checkpoint_dir = checkpoint_dir + self.checkpoint_filename = checkpoint_filename + self.save_top_k = save_top_k + self.save_last = save_last + self.checkpoint_kwargs = checkpoint_kwargs + # prediction callback parameters + self.prediction_batch_size = prediction_batch_size + # hard negatives callback parameters + self.max_hard_negatives_to_mine = max_hard_negatives_to_mine + self.hard_negatives_threshold = hard_negatives_threshold + self.metrics_to_monitor_for_hard_negatives = ( + metrics_to_monitor_for_hard_negatives + ) + self.mine_hard_negatives_with_probability = mine_hard_negatives_with_probability + # other parameters + self.seed = seed + self.float32_matmul_precision = float32_matmul_precision + + if self.max_epochs is None and self.max_steps is None: + raise ValueError( + "Either `max_epochs` or `max_steps` should be specified in the trainer configuration" + ) + + if self.max_epochs is not None and self.max_steps is not None: + logger.info( + "Both `max_epochs` and `max_steps` are specified in the trainer configuration. 
" + "Will use `max_epochs` for the number of training steps" + ) + self.max_steps = None + + # reproducibility + pl.seed_everything(self.seed) + # set the precision of matmul operations + torch.set_float32_matmul_precision(self.float32_matmul_precision) + + # lightning data module declaration + self.lightning_datamodule = self.configure_lightning_datamodule() + + if self.max_epochs is not None: + logger.info(f"Number of training epochs: {self.max_epochs}") + self.max_steps = ( + len(self.lightning_datamodule.train_dataloader()) * self.max_epochs + ) + + # optimizer declaration + self.optimizer, self.lr_scheduler = self.configure_optimizers() + + # lightning module declaration + self.lightning_module = self.configure_lightning_module() + + # logger and experiment declaration + # update self.wandb_kwargs + wandb_args = dict( + entity=self.wandb_entity, + project=self.wandb_project_name, + name=self.wandb_experiment_name, + save_dir=self.wandb_save_dir, + log_model=self.wandb_log_model, + offline=not self.wandb_online_mode, + watch=self.wandb_watch, + lightning_module=self.lightning_module, + ) + if self.wandb_kwargs is not None: + wandb_args.update(self.wandb_kwargs) + self.wandb_kwargs = wandb_args + self.wandb_logger: Optional[WandbLogger] = None + self.experiment_path: Optional[Path] = None + + # setup metrics to monitor for a bunch of callbacks + if isinstance(self.top_k, int): + self.top_k = [self.top_k] + # save the target top_k + self.target_top_k = self.top_k[0] + self.metric_to_monitor = self.metric_to_monitor.format(top_k=self.target_top_k) + + # explicitly configure some callbacks that will be needed not only by the + # pl.Trainer but also in this class + # model checkpoint callback + if self.save_last: + logger.warning( + "We will override the `save_last` of `ModelCheckpoint` to `False`. 
" + "Instead, we will use a separate `ModelCheckpoint` callback to save the last checkpoint" + ) + checkpoint_kwargs = dict( + monitor=self.metric_to_monitor, + mode=self.monitor_mode, + verbose=True, + save_top_k=self.save_top_k, + filename=self.checkpoint_filename, + dirpath=self.checkpoint_dir, + auto_insert_metric_name=False, + ) + if self.checkpoint_kwargs is not None: + checkpoint_kwargs.update(self.checkpoint_kwargs) + self.checkpoint_kwargs = checkpoint_kwargs + self.model_checkpoint_callback: ModelCheckpoint | None = None + self.checkpoint_path: str | os.PathLike | None = None + # last checkpoint callback + self.latest_model_checkpoint_callback: ModelCheckpoint | None = None + self.last_checkpoint_kwargs: dict | None = None + if self.save_last: + last_checkpoint_kwargs = deepcopy(self.checkpoint_kwargs) + last_checkpoint_kwargs["save_top_k"] = 1 + last_checkpoint_kwargs["filename"] = "last-{epoch}-{step}" + last_checkpoint_kwargs["monitor"] = "step" + last_checkpoint_kwargs["mode"] = "max" + self.last_checkpoint_kwargs = last_checkpoint_kwargs + + # early stopping callback + early_stopping_kwargs = dict( + monitor=self.metric_to_monitor, + mode=self.monitor_mode, + patience=self.early_stopping_patience, + ) + if self.early_stopping_kwargs is not None: + early_stopping_kwargs.update(self.early_stopping_kwargs) + self.early_stopping_kwargs = early_stopping_kwargs + self.early_stopping_callback: EarlyStopping | None = None + + # other callbacks declaration + self.callbacks_store: List[pl.Callback] = [] # self.configure_callbacks() + # add default callbacks + self.callbacks_store += [ + ModelSummary(max_depth=2), + LearningRateMonitor(logging_interval="step"), + ] + + # lazy trainer declaration + self.trainer: pl.Trainer | None = None + + def configure_lightning_datamodule(self, *args, **kwargs): + # lightning data module declaration + if self.val_dataset is not None and isinstance( + self.val_dataset, GoldenRetrieverDataset + ): + self.val_dataset = [self.val_dataset] + if self.test_dataset is not None and isinstance( + self.test_dataset, GoldenRetrieverDataset + ): + self.test_dataset = [self.test_dataset] + + self.lightning_datamodule = GoldenRetrieverPLDataModule( + train_dataset=self.train_dataset, + val_datasets=self.val_dataset, + test_datasets=self.test_dataset, + num_workers=self.num_workers, + *args, + **kwargs, + ) + return self.lightning_datamodule + + def configure_lightning_module(self, *args, **kwargs): + # add loss object to the retriever + if self.retriever.loss_type is None: + self.retriever.loss_type = self.loss() + + # lightning module declaration + self.lightning_module = GoldenRetrieverPLModule( + model=self.retriever, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + *args, + **kwargs, + ) + + return self.lightning_module + + def configure_optimizers(self, *args, **kwargs): + # check if it is the class or the instance + if isinstance(self.optimizer, type): + param_optimizer = list(self.retriever.named_parameters()) + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in param_optimizer if "layer_norm_layer" in n + ], + "weight_decay": self.weight_decay, + "lr": 1e-4, + }, + { + "params": [ + p + for n, p in param_optimizer + if all(nd not in n for nd in no_decay) + and "layer_norm_layer" not in n + ], + "weight_decay": self.weight_decay, + }, + { + "params": [ + p + for n, p in param_optimizer + if "layer_norm_layer" not in n + and any(nd in n for nd in no_decay) + ], + 
"weight_decay": 0.0, + }, + ] + self.optimizer = self.optimizer( + # params=self.retriever.parameters(), + params=optimizer_grouped_parameters, + lr=self.lr, + # weight_decay=self.weight_decay, + ) + else: + self.optimizer = self.optimizer + + # LR Scheduler declaration + # check if it is the class, the instance or a function + if self.lr_scheduler is not None: + if isinstance(self.lr_scheduler, type): + self.lr_scheduler = self.lr_scheduler( + optimizer=self.optimizer, + num_warmup_steps=self.num_warmup_steps, + num_training_steps=self.max_steps, + ) + + return self.optimizer, self.lr_scheduler + + @staticmethod + def configure_logger( + name: str, + save_dir: str | os.PathLike, + offline: bool, + entity: str, + project: str, + log_model: Literal["all"] | bool, + watch: str | None = None, + lightning_module: torch.nn.Module | None = None, + *args, + **kwargs, + ) -> WandbLogger: + """ + Configure the wandb logger + + Args: + name (`str`): + The name of the experiment + save_dir (`str`, `os.PathLike`): + The directory where to save the experiment + offline (`bool`): + Whether to run wandb offline + entity (`str`): + The wandb entity + project (`str`): + The wandb project name + log_model (`Literal["all"]`, `bool`): + Whether to log the model to wandb + watch (`str`, optional, defaults to `None`): + The mode to watch the model + lightning_module (`torch.nn.Module`, optional, defaults to `None`): + The lightning module to watch + *args: + Additional args + **kwargs: + Additional kwargs + + Returns: + `lightning.loggers.WandbLogger`: + The wandb logger + """ + wandb_logger = WandbLogger( + name=name, + save_dir=save_dir, + offline=offline, + project=project, + log_model=log_model and not offline, + entity=entity, + *args, + **kwargs, + ) + if watch is not None and lightning_module is not None: + watch_kwargs = dict(model=lightning_module) + if watch is not None: + watch_kwargs["log"] = watch + wandb_logger.watch(**watch_kwargs) + return wandb_logger + + @staticmethod + def configure_early_stopping( + monitor: str, + mode: str, + patience: int = 3, + *args, + **kwargs, + ) -> EarlyStopping: + logger.info(f"Enabling EarlyStopping callback with patience: {patience}") + early_stopping_callback = EarlyStopping( + monitor=monitor, + mode=mode, + patience=patience, + *args, + **kwargs, + ) + return early_stopping_callback + + def configure_model_checkpoint( + self, + monitor: str, + mode: str, + verbose: bool = True, + save_top_k: int = 1, + save_last: bool = False, + filename: str | os.PathLike | None = None, + dirpath: str | os.PathLike | None = None, + auto_insert_metric_name: bool = False, + *args, + **kwargs, + ) -> ModelCheckpoint: + logger.info("Enabling Model Checkpointing") + if dirpath is None: + dirpath = ( + self.experiment_path / "checkpoints" if self.experiment_path else None + ) + if filename is None: + filename = ( + "checkpoint-" + monitor + "_{" + monitor + ":.4f}-epoch_{epoch:02d}" + ) + self.checkpoint_path = dirpath / filename if dirpath is not None else None + logger.info(f"Checkpoint directory: {dirpath}") + logger.info(f"Checkpoint filename: {filename}") + + kwargs = dict( + monitor=monitor, + mode=mode, + verbose=verbose, + save_top_k=save_top_k, + save_last=save_last, + filename=filename, + dirpath=dirpath, + auto_insert_metric_name=auto_insert_metric_name, + *args, + **kwargs, + ) + + # update the kwargs + # TODO: this is bad + # kwargs.update( + # dirpath=self.checkpoint_dir, + # filename=self.checkpoint_filename, + # ) + # modelcheckpoint_kwargs = dict( + # 
dirpath=self.checkpoint_dir, + # filename=self.checkpoint_filename, + # ) + # modelcheckpoint_kwargs.update(kwargs) + self.model_checkpoint_callback = ModelCheckpoint(**kwargs) + return self.model_checkpoint_callback + + def configure_hard_negatives_callback(self): + metrics_to_monitor = ( + self.metrics_to_monitor_for_hard_negatives or self.metric_to_monitor + ) + hard_negatives_callback = NegativeAugmentationCallback( + k=self.target_top_k, + batch_size=self.prediction_batch_size, + precision=self.precision, + stages=["validate"], + metrics_to_monitor=metrics_to_monitor, + threshold=self.hard_negatives_threshold, + max_negatives=self.max_hard_negatives_to_mine, + add_with_probability=self.mine_hard_negatives_with_probability, + refresh_every_n_epochs=1, + ) + return hard_negatives_callback + + def training_callbacks(self): + if self.model_checkpointing: + self.model_checkpoint_callback = self.configure_model_checkpoint( + **self.checkpoint_kwargs + ) + self.callbacks_store.append(self.model_checkpoint_callback) + if self.save_last: + self.latest_model_checkpoint_callback = self.configure_model_checkpoint( + **self.last_checkpoint_kwargs + ) + self.callbacks_store.append(self.latest_model_checkpoint_callback) + + self.callbacks_store.append(SaveRetrieverCallback()) + if self.early_stopping: + self.early_stopping_callback = self.configure_early_stopping( + **self.early_stopping_kwargs + ) + return self.callbacks_store + + def configure_metrics_callbacks( + self, save_predictions: bool = False + ) -> List[NLPTemplateCallback]: + """ + Configure the metrics callbacks for the trainer. This method is called + by the `eval_callbacks` method, and it is used to configure the callbacks + that will be used to evaluate the model during training. + + Args: + save_predictions (`bool`, optional, defaults to `False`): + Whether to save the predictions to disk or not + + Returns: + `List[NLPTemplateCallback]`: + The list of callbacks to use for evaluation + """ + # prediction callback + metrics_callbacks: List[NLPTemplateCallback] = [ + RecallAtKEvaluationCallback(k, verbose=True) for k in self.top_k + ] + metrics_callbacks += [ + AvgRankingEvaluationCallback(k, verbose=True) for k in self.top_k + ] + if save_predictions: + metrics_callbacks.append(SavePredictionsCallback()) + return metrics_callbacks + + def configure_prediction_callbacks( + self, + batch_size: int = 64, + precision: int | str = 32, + k: int | None = None, + force_reindex: bool = True, + metrics_callbacks: list[NLPTemplateCallback] | None = None, + *args, + **kwargs, + ): + if k is None: + # we need the largest k for the prediction callback + # get the max top_k for the prediction callback + k = sorted(self.top_k, reverse=True)[0] + if metrics_callbacks is None: + metrics_callbacks = self.configure_metrics_callbacks() + + prediction_callback = GoldenRetrieverPredictionCallback( + batch_size=batch_size, + precision=precision, + k=k, + force_reindex=force_reindex, + other_callbacks=metrics_callbacks, + *args, + **kwargs, + ) + return prediction_callback + + def train(self, *args, **kwargs): + """ + Train the model + + Args: + *args: + Additional args + **kwargs: + Additional kwargs + + Returns: + `None` + """ + if self.log_to_wandb: + logger.info("Instantiating Wandb Logger") + # log the args to wandb + # logger.info(pformat(self.wandb_kwargs)) + self.wandb_logger = self.configure_logger(**self.wandb_kwargs) + self.experiment_path = Path(self.wandb_logger.experiment.dir) + + # set-up training specific callbacks + self.callbacks_store = 
self.training_callbacks() + # add the evaluation callbacks + self.callbacks_store.append( + self.configure_prediction_callbacks( + batch_size=self.prediction_batch_size, + precision=self.precision, + ) + ) + # add the hard negatives callback after the evaluation callback + if self.max_hard_negatives_to_mine > 0: + self.callbacks_store.append(self.configure_hard_negatives_callback()) + + self.callbacks_store.append(FreeUpIndexerVRAMCallback()) + + if self.trainer is None: + logger.info("Instantiating the Trainer") + self.trainer = pl.Trainer( + accelerator=self.accelerator, + devices=self.devices, + num_nodes=self.num_nodes, + strategy=self.strategy, + accumulate_grad_batches=self.accumulate_grad_batches, + max_epochs=self.max_epochs, + max_steps=self.max_steps, + gradient_clip_val=self.gradient_clip_val, + val_check_interval=self.val_check_interval, + check_val_every_n_epoch=self.check_val_every_n_epoch, + deterministic=self.deterministic, + fast_dev_run=self.fast_dev_run, + precision=self.precision, + reload_dataloaders_every_n_epochs=self.reload_dataloaders_every_n_epochs, + callbacks=self.callbacks_store, + logger=self.wandb_logger, + **self.trainer_kwargs, + ) + + # # save this class as config to file + # if self.experiment_path is not None: + # logger.info("Saving the configuration to file") + # self.experiment_path.mkdir(parents=True, exist_ok=True) + # OmegaConf.save( + # OmegaConf.create(to_config(self)), + # self.experiment_path / "trainer_config.yaml", + # ) + self.trainer.fit( + self.lightning_module, + datamodule=self.lightning_datamodule, + ckpt_path=self.resume_from_checkpoint_path, + ) + + def test( + self, + lightning_module: GoldenRetrieverPLModule | None = None, + checkpoint_path: str | os.PathLike | None = None, + lightning_datamodule: GoldenRetrieverPLDataModule | None = None, + force_reindex: bool = False, + *args, + **kwargs, + ): + """ + Test the model + + Args: + lightning_module (`GoldenRetrieverPLModule`, optional, defaults to `None`): + The lightning module to test + checkpoint_path (`str`, `os.PathLike`, optional, defaults to `None`): + The path to the checkpoint to load + lightning_datamodule (`GoldenRetrieverPLDataModule`, optional, defaults to `None`): + The lightning data module to use for testing + *args: + Additional args + **kwargs: + Additional kwargs + + Returns: + `None` + """ + if self.test_dataset is None: + logger.warning("No test dataset provided. 
Skipping testing.") + return + + if self.trainer is None: + self.trainer = pl.Trainer( + accelerator=self.accelerator, + devices=self.devices, + num_nodes=self.num_nodes, + strategy=self.strategy, + deterministic=self.deterministic, + fast_dev_run=self.fast_dev_run, + precision=self.precision, + callbacks=[ + self.configure_prediction_callbacks( + batch_size=self.prediction_batch_size, + precision=self.precision, + force_reindex=force_reindex, + ) + ], + **self.trainer_kwargs, + ) + if lightning_module is not None: + best_lightning_module = lightning_module + else: + try: + if self.fast_dev_run: + best_lightning_module = self.lightning_module + else: + # load best model for testing + if checkpoint_path is not None: + best_model_path = checkpoint_path + elif self.checkpoint_path is not None: + best_model_path = self.checkpoint_path + elif self.model_checkpoint_callback: + best_model_path = self.model_checkpoint_callback.best_model_path + else: + raise ValueError( + "Either `checkpoint_path` or `model_checkpoint_callback` should " + "be provided to the trainer" + ) + logger.info(f"Loading best model from {best_model_path}") + + best_lightning_module = ( + GoldenRetrieverPLModule.load_from_checkpoint(best_model_path) + ) + except Exception as e: + logger.info(f"Failed to load the model from checkpoint: {e}") + logger.info("Using last model instead") + best_lightning_module = self.lightning_module + + lightning_datamodule = lightning_datamodule or self.lightning_datamodule + # module test + self.trainer.test(best_lightning_module, datamodule=lightning_datamodule) + + +def train(conf: omegaconf.DictConfig) -> None: + logger.info("Starting training with config:") + logger.info(pformat(OmegaConf.to_container(conf))) + + logger.info("Instantiating the Retriever") + retriever: GoldenRetriever = hydra.utils.instantiate( + conf.retriever, _recursive_=False + ) + + logger.info("Instantiating datasets") + train_dataset: GoldenRetrieverDataset = hydra.utils.instantiate( + conf.data.train_dataset, _recursive_=False + ) + val_dataset: GoldenRetrieverDataset = hydra.utils.instantiate( + conf.data.val_dataset, _recursive_=False + ) + test_dataset: GoldenRetrieverDataset = hydra.utils.instantiate( + conf.data.test_dataset, _recursive_=False + ) + + logger.info("Loading the document index") + document_index: BaseDocumentIndex = hydra.utils.instantiate( + conf.data.document_index, _recursive_=False + ) + retriever.document_index = document_index + + logger.info("Instantiating the Trainer") + trainer: Trainer = hydra.utils.instantiate( + conf.train, + retriever=retriever, + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, + _recursive_=False, + ) + + logger.info("Starting training") + trainer.train() + + logger.info("Starting testing") + trainer.test() + + logger.info("Training and testing completed") + + +@hydra.main(config_path="../../conf", config_name="default", version_base="1.3") +def main(conf: omegaconf.DictConfig): + train(conf) + + +def train_hydra(conf: omegaconf.DictConfig) -> None: + # reproducibility + pl.seed_everything(conf.train.seed) + torch.set_float32_matmul_precision(conf.train.float32_matmul_precision) + + logger.info(f"Starting training for [bold cyan]{conf.model_name}[/bold cyan] model") + if conf.train.pl_trainer.fast_dev_run: + logger.info( + f"Debug mode {conf.train.pl_trainer.fast_dev_run}. 
Forcing debugger configuration" + ) + # Debuggers don't like GPUs nor multiprocessing + # conf.train.pl_trainer.accelerator = "cpu" + conf.train.pl_trainer.devices = 1 + conf.train.pl_trainer.strategy = "auto" + conf.train.pl_trainer.precision = 32 + if "num_workers" in conf.data.datamodule: + conf.data.datamodule.num_workers = { + k: 0 for k in conf.data.datamodule.num_workers + } + # Switch wandb to offline mode to prevent online logging + conf.logging.log = None + # remove model checkpoint callback + conf.train.model_checkpoint_callback = None + + if "print_config" in conf and conf.print_config: + # pprint(OmegaConf.to_container(conf), console=logger, expand_all=True) + logger.info(pformat(OmegaConf.to_container(conf))) + + # data module declaration + logger.info("Instantiating the Data Module") + pl_data_module: GoldenRetrieverPLDataModule = hydra.utils.instantiate( + conf.data.datamodule, _recursive_=False + ) + # force setup to get labels initialized for the model + pl_data_module.prepare_data() + # main module declaration + pl_module: Optional[GoldenRetrieverPLModule] = None + + if not conf.train.only_test: + pl_data_module.setup("fit") + + # count the number of training steps + if ( + "max_epochs" in conf.train.pl_trainer + and conf.train.pl_trainer.max_epochs > 0 + ): + num_training_steps = ( + len(pl_data_module.train_dataloader()) + * conf.train.pl_trainer.max_epochs + ) + if "max_steps" in conf.train.pl_trainer: + logger.info( + "Both `max_epochs` and `max_steps` are specified in the trainer configuration. " + "Will use `max_epochs` for the number of training steps" + ) + conf.train.pl_trainer.max_steps = None + elif ( + "max_steps" in conf.train.pl_trainer and conf.train.pl_trainer.max_steps > 0 + ): + num_training_steps = conf.train.pl_trainer.max_steps + conf.train.pl_trainer.max_epochs = None + else: + raise ValueError( + "Either `max_epochs` or `max_steps` should be specified in the trainer configuration" + ) + logger.info(f"Expected number of training steps: {num_training_steps}") + + if "lr_scheduler" in conf.model.pl_module and conf.model.pl_module.lr_scheduler: + # set the number of warmup steps as x% of the total number of training steps + if conf.model.pl_module.lr_scheduler.num_warmup_steps is None: + if ( + "warmup_steps_ratio" in conf.model.pl_module + and conf.model.pl_module.warmup_steps_ratio is not None + ): + conf.model.pl_module.lr_scheduler.num_warmup_steps = int( + conf.model.pl_module.lr_scheduler.num_training_steps + * conf.model.pl_module.warmup_steps_ratio + ) + else: + conf.model.pl_module.lr_scheduler.num_warmup_steps = 0 + logger.info( + f"Number of warmup steps: {conf.model.pl_module.lr_scheduler.num_warmup_steps}" + ) + + logger.info("Instantiating the Model") + pl_module: GoldenRetrieverPLModule = hydra.utils.instantiate( + conf.model.pl_module, _recursive_=False + ) + if ( + "pretrain_ckpt_path" in conf.train + and conf.train.pretrain_ckpt_path is not None + ): + logger.info( + f"Loading pretrained checkpoint from {conf.train.pretrain_ckpt_path}" + ) + pl_module.load_state_dict( + torch.load(conf.train.pretrain_ckpt_path)["state_dict"], strict=False + ) + + if "compile" in conf.model.pl_module and conf.model.pl_module.compile: + try: + pl_module = torch.compile(pl_module, backend="inductor") + except Exception: + logger.info( + "Failed to compile the model, you may need to install PyTorch 2.0" + ) + + # callbacks declaration + callbacks_store = [ModelSummary(max_depth=2)] + + experiment_logger: Optional[WandbLogger] = None + experiment_path: 
Optional[Path] = None + if conf.logging.log: + logger.info("Instantiating Wandb Logger") + experiment_logger = hydra.utils.instantiate(conf.logging.wandb_arg) + if pl_module is not None: + # it may happen that the model is not instantiated if we are only testing + # in that case, we don't need to watch the model + experiment_logger.watch(pl_module, **conf.logging.watch) + experiment_path = Path(experiment_logger.experiment.dir) + # Store the YaML config separately into the wandb dir + yaml_conf: str = OmegaConf.to_yaml(cfg=conf) + (experiment_path / "hparams.yaml").write_text(yaml_conf) + # Add a Learning Rate Monitor callback to log the learning rate + callbacks_store.append(LearningRateMonitor(logging_interval="step")) + + early_stopping_callback: Optional[EarlyStopping] = None + if conf.train.early_stopping_callback is not None: + early_stopping_callback = hydra.utils.instantiate( + conf.train.early_stopping_callback + ) + callbacks_store.append(early_stopping_callback) + + model_checkpoint_callback: Optional[ModelCheckpoint] = None + if conf.train.model_checkpoint_callback is not None: + model_checkpoint_callback = hydra.utils.instantiate( + conf.train.model_checkpoint_callback, + dirpath=experiment_path / "checkpoints" if experiment_path else None, + ) + callbacks_store.append(model_checkpoint_callback) + + if "callbacks" in conf.train and conf.train.callbacks is not None: + for _, callback in conf.train.callbacks.items(): + # callback can be a list of callbacks or a single callback + if isinstance(callback, omegaconf.listconfig.ListConfig): + for cb in callback: + if cb is not None: + callbacks_store.append( + hydra.utils.instantiate(cb, _recursive_=False) + ) + else: + if callback is not None: + callbacks_store.append(hydra.utils.instantiate(callback)) + + # trainer + logger.info("Instantiating the Trainer") + trainer: Trainer = hydra.utils.instantiate( + conf.train.pl_trainer, callbacks=callbacks_store, logger=experiment_logger + ) + + if not conf.train.only_test: + # module fit + trainer.fit(pl_module, datamodule=pl_data_module) + + if conf.train.pl_trainer.fast_dev_run: + best_pl_module = pl_module + else: + # load best model for testing + if conf.train.checkpoint_path: + best_model_path = conf.evaluation.checkpoint_path + elif model_checkpoint_callback: + best_model_path = model_checkpoint_callback.best_model_path + else: + raise ValueError( + "Either `checkpoint_path` or `model_checkpoint_callback` should " + "be specified in the evaluation configuration" + ) + logger.info(f"Loading best model from {best_model_path}") + + try: + best_pl_module = GoldenRetrieverPLModule.load_from_checkpoint( + best_model_path + ) + except Exception as e: + logger.info(f"Failed to load the model from checkpoint: {e}") + logger.info("Using last model instead") + best_pl_module = pl_module + if "compile" in conf.model.pl_module and conf.model.pl_module.compile: + try: + best_pl_module = torch.compile(best_pl_module, backend="inductor") + except Exception: + logger.info( + "Failed to compile the model, you may need to install PyTorch 2.0" + ) + + # module test + trainer.test(best_pl_module, datamodule=pl_data_module) + + +if __name__ == "__main__": + main() diff --git a/relik/version.py b/relik/version.py new file mode 100644 index 0000000000000000000000000000000000000000..bed137800c980e0e82d7c8ccdf474053baed630f --- /dev/null +++ b/relik/version.py @@ -0,0 +1,13 @@ +import os + +_MAJOR = "0" +_MINOR = "1" +# On main and in a nightly release the patch should be one ahead of the last +# released 
build. +_PATCH = "0" +# This is mainly for nightly builds which have the suffix ".dev$DATE". See +# https://semver.org/#is-v123-a-semantic-version for the semantics. +_SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "") + +VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) +VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
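
Note: as a minimal usage sketch (not part of the patch), the RetrieverTrainer introduced in relik/retriever/trainer/train.py can be driven roughly as follows. The retriever, train_ds, and val_ds objects are placeholders assumed to be a GoldenRetriever and GoldenRetrieverDataset instances built elsewhere, and the keyword values are illustrative only.

from relik.retriever.trainer import RetrieverTrainer

# Hypothetical objects, built with the classes imported in train.py above:
# retriever: GoldenRetriever, train_ds / val_ds: GoldenRetrieverDataset.
trainer = RetrieverTrainer(
    retriever=retriever,
    train_dataset=train_ds,
    val_dataset=val_ds,
    max_epochs=3,           # either max_epochs or max_steps must be set
    top_k=[50, 100],        # metric_to_monitor is formatted with the first k
    log_to_wandb=False,     # skip the WandbLogger setup for this sketch
    early_stopping=False,
)
trainer.train()  # builds the lightning module and datamodule, then calls Trainer.fit
trainer.test()   # evaluates the best checkpoint when a test_dataset is provided

With log_to_wandb=False the ModelCheckpoint callback receives dirpath=None and falls back to Lightning's default location, since experiment_path is only set when the WandbLogger is created.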