diff --git a/relik/__init__.py b/relik/__init__.py deleted file mode 100644 index 42a3df6b991b0af65ec5974fc4faa381b8e555b7..0000000000000000000000000000000000000000 --- a/relik/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from relik.retriever.pytorch_modules.model import GoldenRetriever diff --git a/relik/common/__init__.py b/relik/common/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/common/log.py b/relik/common/log.py deleted file mode 100644 index b91e1fa7bfc22b759e0da4d69563315e31ce0e60..0000000000000000000000000000000000000000 --- a/relik/common/log.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -import sys -import threading -from typing import Optional - -from rich import get_console - -_lock = threading.Lock() -_default_handler: Optional[logging.Handler] = None - -_default_log_level = logging.WARNING - -# fancy logger -_console = get_console() - - -def _get_library_name() -> str: - return __name__.split(".")[0] - - -def _get_library_root_logger() -> logging.Logger: - return logging.getLogger(_get_library_name()) - - -def _configure_library_root_logger() -> None: - global _default_handler - - with _lock: - if _default_handler: - # This library has already configured the library root logger. - return - _default_handler = logging.StreamHandler() # Set sys.stderr as stream. - _default_handler.flush = sys.stderr.flush - - # Apply our default configuration to the library root logger. - library_root_logger = _get_library_root_logger() - library_root_logger.addHandler(_default_handler) - library_root_logger.setLevel(_default_log_level) - library_root_logger.propagate = False - - -def _reset_library_root_logger() -> None: - global _default_handler - - with _lock: - if not _default_handler: - return - - library_root_logger = _get_library_root_logger() - library_root_logger.removeHandler(_default_handler) - library_root_logger.setLevel(logging.NOTSET) - _default_handler = None - - -def set_log_level(level: int, logger: logging.Logger = None) -> None: - """ - Set the log level. - Args: - level (:obj:`int`): - Logging level. - logger (:obj:`logging.Logger`): - Logger to set the log level. - """ - if not logger: - _configure_library_root_logger() - logger = _get_library_root_logger() - logger.setLevel(level) - - -def get_logger( - name: Optional[str] = None, - level: Optional[int] = None, - formatter: Optional[str] = None, -) -> logging.Logger: - """ - Return a logger with the specified name. 
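For reference, a minimal sketch of how these logging helpers are meant to be called from client code (the import path mirrors this module; the log messages are illustrative):

    import logging
    from relik.common.log import get_logger, set_log_level

    # configures the library root logger on first use and returns a named logger
    logger = get_logger(__name__, level=logging.INFO)
    logger.info("retriever loaded")

    # later, raise verbosity for the whole library root logger
    set_log_level(logging.DEBUG)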
- """ - - if name is None: - name = _get_library_name() - - _configure_library_root_logger() - - if level is not None: - set_log_level(level) - - if formatter is None: - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(name)s - %(message)s" - ) - _default_handler.setFormatter(formatter) - - return logging.getLogger(name) - - -def get_console_logger(): - return _console diff --git a/relik/common/upload.py b/relik/common/upload.py deleted file mode 100644 index b2cad77bd95f43992af3144baf296560a496556b..0000000000000000000000000000000000000000 --- a/relik/common/upload.py +++ /dev/null @@ -1,128 +0,0 @@ -import argparse -import json -import logging -import os -import tempfile -import zipfile -from datetime import datetime -from pathlib import Path -from typing import Optional, Union - -import huggingface_hub - -from relik.common.log import get_logger -from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5 - -logger = get_logger(level=logging.DEBUG) - - -def create_info_file(tmpdir: Path): - logger.debug("Computing md5 of model.zip") - md5 = get_md5(tmpdir / "model.zip") - date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT) - - logger.debug("Dumping info.json file") - with (tmpdir / "info.json").open("w") as f: - json.dump(dict(md5=md5, upload_date=date), f, indent=2) - - -def zip_run( - dir_path: Union[str, os.PathLike], - tmpdir: Union[str, os.PathLike], - zip_name: str = "model.zip", -) -> Path: - logger.debug(f"zipping {dir_path} to {tmpdir}") - # creates a zip version of the provided dir_path - run_dir = Path(dir_path) - zip_path = tmpdir / zip_name - - with zipfile.ZipFile(zip_path, "w") as zip_file: - # fully zip the run directory maintaining its structure - for file in run_dir.rglob("*.*"): - if file.is_dir(): - continue - - zip_file.write(file, arcname=file.relative_to(run_dir)) - - return zip_path - - -def upload( - model_dir: Union[str, os.PathLike], - model_name: str, - organization: Optional[str] = None, - repo_name: Optional[str] = None, - commit: Optional[str] = None, - archive: bool = False, -): - token = huggingface_hub.HfFolder.get_token() - if token is None: - print( - "No HuggingFace token found. You need to execute `huggingface-cli login` first!" 
- ) - return - - repo_id = repo_name or model_name - if organization is not None: - repo_id = f"{organization}/{repo_id}" - with tempfile.TemporaryDirectory() as tmpdir: - api = huggingface_hub.HfApi() - repo_url = api.create_repo( - token=token, - repo_id=repo_id, - exist_ok=True, - ) - repo = huggingface_hub.Repository( - str(tmpdir), clone_from=repo_url, use_auth_token=token - ) - - tmp_path = Path(tmpdir) - if archive: - # otherwise we zip the model_dir - logger.debug(f"Zipping {model_dir} to {tmp_path}") - zip_run(model_dir, tmp_path) - create_info_file(tmp_path) - else: - # if the user wants to upload a transformers model, we don't need to zip it - # we just need to copy the files to the tmpdir - logger.debug(f"Copying {model_dir} to {tmpdir}") - os.system(f"cp -r {model_dir}/* {tmpdir}") - - # this method automatically puts large files (>10MB) into git lfs - repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp") - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument( - "model_dir", help="The directory of the model you want to upload" - ) - parser.add_argument("model_name", help="The model you want to upload") - parser.add_argument( - "--organization", - help="the name of the organization where you want to upload the model", - ) - parser.add_argument( - "--repo_name", - help="Optional name to use when uploading to the HuggingFace repository", - ) - parser.add_argument( - "--commit", help="Commit message to use when pushing to the HuggingFace Hub" - ) - parser.add_argument( - "--archive", - action="store_true", - help=""" - Whether to compress the model directory before uploading it. - If True, the model directory will be zipped and the zip file will be uploaded. - If False, the model directory will be uploaded as is.""", - ) - return parser.parse_args() - - -def main(): - upload(**vars(parse_args())) - - -if __name__ == "__main__": - main() diff --git a/relik/common/utils.py b/relik/common/utils.py deleted file mode 100644 index 6e49b88feca9a8e6cc517b583fd07c924d57ed8b..0000000000000000000000000000000000000000 --- a/relik/common/utils.py +++ /dev/null @@ -1,609 +0,0 @@ -import importlib.util -import json -import logging -import os -import shutil -import tarfile -import tempfile -from functools import partial -from hashlib import sha256 -from pathlib import Path -from typing import Any, BinaryIO, Dict, List, Optional, Union -from urllib.parse import urlparse -from zipfile import ZipFile, is_zipfile - -import huggingface_hub -import requests -import tqdm -from filelock import FileLock -from transformers.utils.hub import cached_file as hf_cached_file - -from relik.common.log import get_logger - -# name constants -WEIGHTS_NAME = "weights.pt" -ONNX_WEIGHTS_NAME = "weights.onnx" -CONFIG_NAME = "config.yaml" -LABELS_NAME = "labels.json" - -# SAPIENZANLP_USER_NAME = "sapienzanlp" -SAPIENZANLP_USER_NAME = "riccorl" -SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}" -SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = ( - f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip" -) -# path constants -SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp") -SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S" - - -logger = get_logger(__name__) - - -def sapienzanlp_model_urls(model_id: str) -> str: - """ - Returns the URL for a possible SapienzaNLP valid model. - - Args: - model_id (:obj:`str`): - A SapienzaNLP model id. - - Returns: - :obj:`str`: The url for the model id. 
- """ - # check if there is already the namespace of the user - if "/" in model_id: - return model_id - return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id) - - -def is_package_available(package_name: str) -> bool: - """ - Check if a package is available. - - Args: - package_name (`str`): The name of the package to check. - """ - return importlib.util.find_spec(package_name) is not None - - -def load_json(path: Union[str, Path]) -> Any: - """ - Load a json file provided in input. - - Args: - path (`Union[str, Path]`): The path to the json file to load. - - Returns: - `Any`: The loaded json file. - """ - with open(path, encoding="utf8") as f: - return json.load(f) - - -def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None): - """ - Dump input to json file. - - Args: - document (`Any`): The document to dump. - path (`Union[str, Path]`): The path to dump the document to. - indent (`Optional[int]`): The indent to use for the json file. - - """ - with open(path, "w", encoding="utf8") as outfile: - json.dump(document, outfile, indent=indent) - - -def get_md5(path: Path): - """ - Get the MD5 value of a path. - """ - import hashlib - - with path.open("rb") as fin: - data = fin.read() - return hashlib.md5(data).hexdigest() - - -def file_exists(path: Union[str, os.PathLike]) -> bool: - """ - Check if the file at :obj:`path` exists. - - Args: - path (:obj:`str`, :obj:`os.PathLike`): - Path to check. - - Returns: - :obj:`bool`: :obj:`True` if the file exists. - """ - return Path(path).exists() - - -def dir_exists(path: Union[str, os.PathLike]) -> bool: - """ - Check if the directory at :obj:`path` exists. - - Args: - path (:obj:`str`, :obj:`os.PathLike`): - Path to check. - - Returns: - :obj:`bool`: :obj:`True` if the directory exists. - """ - return Path(path).is_dir() - - -def is_remote_url(url_or_filename: Union[str, Path]): - """ - Returns :obj:`True` if the input path is an url. - - Args: - url_or_filename (:obj:`str`, :obj:`Path`): - path to check. - - Returns: - :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise. - - """ - if isinstance(url_or_filename, Path): - url_or_filename = str(url_or_filename) - parsed = urlparse(url_or_filename) - return parsed.scheme in ("http", "https") - - -def url_to_filename(resource: str, etag: str = None) -> str: - """ - Convert a `resource` into a hashed filename in a repeatable way. - If `etag` is specified, append its hash to the resources's, delimited - by a period. - """ - resource_bytes = resource.encode("utf-8") - resource_hash = sha256(resource_bytes) - filename = resource_hash.hexdigest() - - if etag: - etag_bytes = etag.encode("utf-8") - etag_hash = sha256(etag_bytes) - filename += "." + etag_hash.hexdigest() - - return filename - - -def download_resource( - url: str, - temp_file: BinaryIO, - headers=None, -): - """ - Download remote file. 
- """ - - if headers is None: - headers = {} - - r = requests.get(url, stream=True, headers=headers) - r.raise_for_status() - content_length = r.headers.get("Content-Length") - total = int(content_length) if content_length is not None else None - progress = tqdm( - unit="B", - unit_scale=True, - total=total, - desc="Downloading", - disable=logger.level in [logging.NOTSET], - ) - for chunk in r.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - temp_file.write(chunk) - progress.close() - - -def download_and_cache( - url: Union[str, Path], - cache_dir: Union[str, Path] = None, - force_download: bool = False, -): - if cache_dir is None: - cache_dir = SAPIENZANLP_CACHE_DIR - if isinstance(url, Path): - url = str(url) - - # check if cache dir exists - Path(cache_dir).mkdir(parents=True, exist_ok=True) - - # check if file is private - headers = {} - try: - r = requests.head(url, allow_redirects=False, timeout=10) - r.raise_for_status() - except requests.exceptions.HTTPError: - if r.status_code == 401: - hf_token = huggingface_hub.HfFolder.get_token() - if hf_token is None: - raise ValueError( - "You need to login to HuggingFace to download this model " - "(use the `huggingface-cli login` command)" - ) - headers["Authorization"] = f"Bearer {hf_token}" - - etag = None - try: - r = requests.head(url, allow_redirects=True, timeout=10, headers=headers) - r.raise_for_status() - etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") - # We favor a custom header indicating the etag of the linked resource, and - # we fallback to the regular etag header. - # If we don't have any of those, raise an error. - if etag is None: - raise OSError( - "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." - ) - # In case of a redirect, - # save an extra redirect on the request.get call, - # and ensure we download the exact atomic version even if it changed - # between the HEAD and the GET (unlikely, but hey). - if 300 <= r.status_code <= 399: - url = r.headers["Location"] - except (requests.exceptions.SSLError, requests.exceptions.ProxyError): - # Actually raise for those subclasses of ConnectionError - raise - except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): - # Otherwise, our Internet connection is down. - # etag is None - pass - - # get filename from the url - filename = url_to_filename(url, etag) - # get cache path to put the file - cache_path = cache_dir / filename - - # the file is already here, return it - if file_exists(cache_path) and not force_download: - logger.info( - f"{url} found in cache, set `force_download=True` to force the download" - ) - return cache_path - - cache_path = str(cache_path) - # Prevent parallel downloads of the same file with a lock. - lock_path = cache_path + ".lock" - with FileLock(lock_path): - # If the download just completed while the lock was activated. - if file_exists(cache_path) and not force_download: - # Even if returning early like here, the lock will be released. - return cache_path - - temp_file_manager = partial( - tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False - ) - - # Download to temporary file, then copy to cache dir once finished. - # Otherwise, you get corrupt cache entries if the download gets interrupted. 
- with temp_file_manager() as temp_file: - logger.info( - f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}" - ) - download_resource(url, temp_file, headers) - - logger.info(f"storing {url} in cache at {cache_path}") - os.replace(temp_file.name, cache_path) - - # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it. - umask = os.umask(0o666) - os.umask(umask) - os.chmod(cache_path, 0o666 & ~umask) - - logger.info(f"creating metadata file for {cache_path}") - meta = {"url": url} # , "etag": etag} - meta_path = cache_path + ".json" - with open(meta_path, "w") as meta_file: - json.dump(meta, meta_file) - - return cache_path - - -def download_from_hf( - path_or_repo_id: Union[str, Path], - filenames: Optional[List[str]], - cache_dir: Union[str, Path] = None, - force_download: bool = False, - resume_download: bool = False, - proxies: Optional[Dict[str, str]] = None, - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, - subfolder: str = "", -): - if isinstance(path_or_repo_id, Path): - path_or_repo_id = str(path_or_repo_id) - - downloaded_paths = [] - for filename in filenames: - downloaded_path = hf_cached_file( - path_or_repo_id, - filename, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - use_auth_token=use_auth_token, - revision=revision, - local_files_only=local_files_only, - subfolder=subfolder, - ) - downloaded_paths.append(downloaded_path) - - # we want the folder where the files are downloaded - # the best guess is the parent folder of the first file - probably_the_folder = Path(downloaded_paths[0]).parent - return probably_the_folder - - -def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str: - """ - Resolve a model name or directory to a model archive name or directory. - - Args: - model_name_or_dir (:obj:`str` or :obj:`os.PathLike`): - A model name or directory. - - Returns: - :obj:`str`: The model archive name or directory. - """ - if is_remote_url(model_name_or_dir): - # if model_name_or_dir is a URL - # download it and try to load - model_archive = model_name_or_dir - elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file(): - # if model_name_or_dir is a local directory or - # an archive file try to load it - model_archive = model_name_or_dir - else: - # probably model_name_or_dir is a sapienzanlp model id - # guess the url and try to download - model_name_or_dir_ = model_name_or_dir - # raise ValueError(f"Providing a model id is not supported yet.") - model_archive = sapienzanlp_model_urls(model_name_or_dir_) - - return model_archive - - -def from_cache( - url_or_filename: Union[str, Path], - cache_dir: Union[str, Path] = None, - force_download: bool = False, - resume_download: bool = False, - proxies: Optional[Dict[str, str]] = None, - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, - subfolder: str = "", - filenames: Optional[List[str]] = None, -) -> Path: - """ - Given something that could be either a local path or a URL (or a SapienzaNLP model id), - determine which one and return a path to the corresponding file. - - Args: - url_or_filename (:obj:`str` or :obj:`Path`): - A path to a local file or a URL (or a SapienzaNLP model id). - cache_dir (:obj:`str` or :obj:`Path`, `optional`): - Path to a directory in which a downloaded file will be cached. 
- force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to re-download the file even if it already exists. - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to delete incompletely received files. Attempts to resume the download if such a file - exists. - proxies (:obj:`Dict[str, str]`, `optional`): - A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (:obj:`Union[bool, str]`, `optional`): - Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from - :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token. - revision (:obj:`str`, `optional`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any - identifier allowed by git. - local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to raise an error if the file to be downloaded is local. - subfolder (:obj:`str`, `optional`): - In case the relevant file is in a subfolder of the URL, specify it here. - filenames (:obj:`List[str]`, `optional`): - List of filenames to look for in the directory structure. - - Returns: - :obj:`Path`: Path to the cached file. - """ - - url_or_filename = model_name_or_path_resolver(url_or_filename) - - if cache_dir is None: - cache_dir = SAPIENZANLP_CACHE_DIR - - if file_exists(url_or_filename): - logger.info(f"{url_or_filename} is a local path or file") - output_path = url_or_filename - elif is_remote_url(url_or_filename): - # URL, so get it from the cache (downloading if necessary) - output_path = download_and_cache( - url_or_filename, - cache_dir=cache_dir, - force_download=force_download, - ) - else: - if filenames is None: - filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME] - output_path = download_from_hf( - url_or_filename, - filenames, - cache_dir, - force_download, - resume_download, - proxies, - use_auth_token, - revision, - local_files_only, - subfolder, - ) - - # if is_hf_hub_url(url_or_filename): - # HuggingFace Hub - # output_path = hf_hub_download_url(url_or_filename) - # elif is_remote_url(url_or_filename): - # # URL, so get it from the cache (downloading if necessary) - # output_path = download_and_cache( - # url_or_filename, - # cache_dir=cache_dir, - # force_download=force_download, - # ) - # elif file_exists(url_or_filename): - # logger.info(f"{url_or_filename} is a local path or file") - # # File, and it exists. - # output_path = url_or_filename - # elif urlparse(url_or_filename).scheme == "": - # # File, but it doesn't exist. - # raise EnvironmentError(f"file {url_or_filename} not found") - # else: - # # Something unknown - # raise ValueError( - # f"unable to parse {url_or_filename} as a URL or as a local path" - # ) - - if dir_exists(output_path) or ( - not is_zipfile(output_path) and not tarfile.is_tarfile(output_path) - ): - return Path(output_path) - - # Path where we extract compressed archives - # for now it will extract it in the same folder - # maybe implement extraction in the sapienzanlp folder - # when using local archive path? 
- logger.info("Extracting compressed archive") - output_dir, output_file = os.path.split(output_path) - output_extract_dir_name = output_file.replace(".", "-") + "-extracted" - output_path_extracted = os.path.join(output_dir, output_extract_dir_name) - - # already extracted, do not extract - if ( - os.path.isdir(output_path_extracted) - and os.listdir(output_path_extracted) - and not force_download - ): - return Path(output_path_extracted) - - # Prevent parallel extractions - lock_path = output_path + ".lock" - with FileLock(lock_path): - shutil.rmtree(output_path_extracted, ignore_errors=True) - os.makedirs(output_path_extracted) - if is_zipfile(output_path): - with ZipFile(output_path, "r") as zip_file: - zip_file.extractall(output_path_extracted) - zip_file.close() - elif tarfile.is_tarfile(output_path): - tar_file = tarfile.open(output_path) - tar_file.extractall(output_path_extracted) - tar_file.close() - else: - raise EnvironmentError( - f"Archive format of {output_path} could not be identified" - ) - - # remove lock file, is it safe? - os.remove(lock_path) - - return Path(output_path_extracted) - - -def is_str_a_path(maybe_path: str) -> bool: - """ - Check if a string is a path. - - Args: - maybe_path (`str`): The string to check. - - Returns: - `bool`: `True` if the string is a path, `False` otherwise. - """ - # first check if it is a path - if Path(maybe_path).exists(): - return True - # check if it is a relative path - if Path(os.path.join(os.getcwd(), maybe_path)).exists(): - return True - # otherwise it is not a path - return False - - -def relative_to_absolute_path(path: str) -> os.PathLike: - """ - Convert a relative path to an absolute path. - - Args: - path (`str`): The relative path to convert. - - Returns: - `os.PathLike`: The absolute path. - """ - if not is_str_a_path(path): - raise ValueError(f"{path} is not a path") - if Path(path).exists(): - return Path(path).absolute() - if Path(os.path.join(os.getcwd(), path)).exists(): - return Path(os.path.join(os.getcwd(), path)).absolute() - raise ValueError(f"{path} is not a path") - - -def to_config(object_to_save: Any) -> Dict[str, Any]: - """ - Convert an object to a dictionary. - - Returns: - `Dict[str, Any]`: The dictionary representation of the object. - """ - - def obj_to_dict(obj): - match obj: - case dict(): - data = {} - for k, v in obj.items(): - data[k] = obj_to_dict(v) - return data - - case list() | tuple(): - return [obj_to_dict(x) for x in obj] - - case object(__dict__=_): - data = { - "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}", - } - for k, v in obj.__dict__.items(): - if not k.startswith("_"): - data[k] = obj_to_dict(v) - return data - - case _: - return obj - - return obj_to_dict(object_to_save) - - -def get_callable_from_string(callable_fn: str) -> Any: - """ - Get a callable from a string. - - Args: - callable_fn (`str`): - The string representation of the callable. - - Returns: - `Any`: The callable. 
- """ - # separate the function name from the module name - module_name, function_name = callable_fn.rsplit(".", 1) - # import the module - module = importlib.import_module(module_name) - # get the function - return getattr(module, function_name) diff --git a/relik/inference/__init__.py b/relik/inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/inference/annotator.py b/relik/inference/annotator.py deleted file mode 100644 index 4d079830c42ac95c1bf0fda1bd86c62ff6e94aa6..0000000000000000000000000000000000000000 --- a/relik/inference/annotator.py +++ /dev/null @@ -1,428 +0,0 @@ -import os -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union - -import hydra -from omegaconf import OmegaConf -from relik.retriever.indexers.faiss import FaissDocumentIndex -from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel -from rich.pretty import pprint - -from relik.common.log import get_console_logger, get_logger -from relik.common.upload import upload -from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string -from relik.inference.data.objects import EntitySpan, RelikOutput -from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer -from relik.inference.data.window.manager import WindowManager -from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction -from relik.reader.relik_reader import RelikReader -from relik.retriever.data.utils import batch_generator -from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.pytorch_modules.model import GoldenRetriever - -logger = get_logger(__name__) -console_logger = get_console_logger() - - -class Relik: - """ - Relik main class. It is a wrapper around a retriever and a reader. - - Args: - retriever (`Optional[GoldenRetriever]`, `optional`): - The retriever to use. If `None`, a retriever will be instantiated from the - provided `question_encoder`, `passage_encoder` and `document_index`. - Defaults to `None`. - question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`): - The question encoder to use. If `retriever` is `None`, a retriever will be - instantiated from this parameter. Defaults to `None`. - passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`): - The passage encoder to use. If `retriever` is `None`, a retriever will be - instantiated from this parameter. Defaults to `None`. - document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`): - The document index to use. If `retriever` is `None`, a retriever will be - instantiated from this parameter. Defaults to `None`. - reader (`Optional[Union[str, RelikReader]]`, `optional`): - The reader to use. If `None`, a reader will be instantiated from the - provided `reader`. Defaults to `None`. - retriever_device (`str`, `optional`, defaults to `cpu`): - The device to use for the retriever. 
- - """ - - def __init__( - self, - retriever: GoldenRetriever | None = None, - question_encoder: str | GoldenRetrieverModel | None = None, - passage_encoder: str | GoldenRetrieverModel | None = None, - document_index: str | BaseDocumentIndex | None = None, - reader: str | RelikReader | None = None, - device: str = "cpu", - retriever_device: str | None = None, - document_index_device: str | None = None, - reader_device: str | None = None, - precision: int = 32, - retriever_precision: int | None = None, - document_index_precision: int | None = None, - reader_precision: int | None = None, - reader_kwargs: dict | None = None, - retriever_kwargs: dict | None = None, - candidates_preprocessing_fn: str | Callable | None = None, - top_k: int | None = None, - window_size: int | None = None, - window_stride: int | None = None, - **kwargs, - ) -> None: - # retriever - retriever_device = retriever_device or device - document_index_device = document_index_device or device - retriever_precision = retriever_precision or precision - document_index_precision = document_index_precision or precision - if retriever is None and question_encoder is None: - raise ValueError( - "Either `retriever` or `question_encoder` must be provided" - ) - if retriever is None: - self.retriever_kwargs = dict( - question_encoder=question_encoder, - passage_encoder=passage_encoder, - document_index=document_index, - device=retriever_device, - precision=retriever_precision, - index_device=document_index_device, - index_precision=document_index_precision, - ) - # overwrite default_retriever_kwargs with retriever_kwargs - self.retriever_kwargs.update(retriever_kwargs or {}) - retriever = GoldenRetriever(**self.retriever_kwargs) - retriever.training = False - retriever.eval() - self.retriever = retriever - - # reader - self.reader_device = reader_device or device - self.reader_precision = reader_precision or precision - self.reader_kwargs = reader_kwargs - if isinstance(reader, str): - reader_kwargs = reader_kwargs or {} - reader = RelikReaderForSpanExtraction(reader, **reader_kwargs) - self.reader = reader - - # windowization stuff - self.tokenizer = SpacyTokenizer(language="en") - self.window_manager: WindowManager | None = None - - # candidates preprocessing - # TODO: maybe move this logic somewhere else - candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x) - if isinstance(candidates_preprocessing_fn, str): - candidates_preprocessing_fn = get_callable_from_string( - candidates_preprocessing_fn - ) - self.candidates_preprocessing_fn = candidates_preprocessing_fn - - # inference params - self.top_k = top_k - self.window_size = window_size - self.window_stride = window_stride - - def __call__( - self, - text: Union[str, list], - top_k: Optional[int] = None, - window_size: Optional[int] = None, - window_stride: Optional[int] = None, - retriever_batch_size: Optional[int] = 32, - reader_batch_size: Optional[int] = 32, - return_also_windows: bool = False, - **kwargs, - ) -> Union[RelikOutput, list[RelikOutput]]: - """ - Annotate a text with entities. - - Args: - text (`str` or `list`): - The text to annotate. If a list is provided, each element of the list - will be annotated separately. - top_k (`int`, `optional`, defaults to `None`): - The number of candidates to retrieve for each window. - window_size (`int`, `optional`, defaults to `None`): - The size of the window. If `None`, the whole text will be annotated. - window_stride (`int`, `optional`, defaults to `None`): - The stride of the window. 
If `None`, there will be no overlap between windows. - retriever_batch_size (`int`, `optional`, defaults to `None`): - The batch size to use for the retriever. The whole input is the batch for the retriever. - reader_batch_size (`int`, `optional`, defaults to `None`): - The batch size to use for the reader. The whole input is the batch for the reader. - return_also_windows (`bool`, `optional`, defaults to `False`): - Whether to return the windows in the output. - **kwargs: - Additional keyword arguments to pass to the retriever and the reader. - - Returns: - `RelikOutput` or `list[RelikOutput]`: - The annotated text. If a list was provided as input, a list of - `RelikOutput` objects will be returned. - """ - if top_k is None: - top_k = self.top_k or 100 - if window_size is None: - window_size = self.window_size - if window_stride is None: - window_stride = self.window_stride - - if isinstance(text, str): - text = [text] - - if window_size is not None: - if self.window_manager is None: - self.window_manager = WindowManager(self.tokenizer) - - if window_size == "sentence": - # todo: implement sentence windowizer - raise NotImplementedError("Sentence windowizer not implemented yet") - - # if window_size < window_stride: - # raise ValueError( - # f"Window size ({window_size}) must be greater than window stride ({window_stride})" - # ) - - # window generator - windows = [ - window - for doc_id, t in enumerate(text) - for window in self.window_manager.create_windows( - t, - window_size=window_size, - stride=window_stride, - doc_id=doc_id, - ) - ] - - # retrieve candidates first - windows_candidates = [] - # TODO: Move batching inside retriever - for batch in batch_generator(windows, batch_size=retriever_batch_size): - retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k) - windows_candidates.extend( - [[p.label for p in predictions] for predictions in retriever_out] - ) - - # add passage to the windows - for window, candidates in zip(windows, windows_candidates): - window.window_candidates = [ - self.candidates_preprocessing_fn(c) for c in candidates - ] - - windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size) - windows = self.window_manager.merge_windows(windows) - - # transform predictions into RelikOutput objects - output = [] - for w in windows: - sample_output = RelikOutput( - text=text[w.doc_id], - labels=sorted( - [ - EntitySpan( - start=ss, end=se, label=sl, text=text[w.doc_id][ss:se] - ) - for ss, se, sl in w.predicted_window_labels_chars - ], - key=lambda x: x.start, - ), - ) - output.append(sample_output) - - if return_also_windows: - for i, sample_output in enumerate(output): - sample_output.windows = [w for w in windows if w.doc_id == i] - - # if only one text was provided, return a single RelikOutput object - if len(output) == 1: - return output[0] - - return output - - @classmethod - def from_pretrained( - cls, - model_name_or_dir: Union[str, os.PathLike], - config_kwargs: Optional[Dict] = None, - config_file_name: str = CONFIG_NAME, - *args, - **kwargs, - ) -> "Relik": - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - - model_dir = from_cache( - model_name_or_dir, - filenames=[config_file_name], - cache_dir=cache_dir, - force_download=force_download, - ) - - config_path = model_dir / config_file_name - if not config_path.exists(): - raise FileNotFoundError( - f"Model configuration file not found at {config_path}." 
- ) - - # overwrite config with config_kwargs - config = OmegaConf.load(config_path) - if config_kwargs is not None: - # TODO: check merging behavior - config = OmegaConf.merge(config, OmegaConf.create(config_kwargs)) - # do we want to print the config? I like it - pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True) - - # load relik from config - relik = hydra.utils.instantiate(config, *args, **kwargs) - - return relik - - def save_pretrained( - self, - output_dir: Union[str, os.PathLike], - config: Optional[Dict[str, Any]] = None, - config_file_name: Optional[str] = None, - save_weights: bool = False, - push_to_hub: bool = False, - model_id: Optional[str] = None, - organization: Optional[str] = None, - repo_name: Optional[str] = None, - **kwargs, - ): - """ - Save the configuration of Relik to the specified directory as a YAML file. - - Args: - output_dir (`str`): - The directory to save the configuration file to. - config (`Optional[Dict[str, Any]]`, `optional`): - The configuration to save. If `None`, the current configuration will be - saved. Defaults to `None`. - config_file_name (`Optional[str]`, `optional`): - The name of the configuration file. Defaults to `config.yaml`. - save_weights (`bool`, `optional`): - Whether to save the weights of the model. Defaults to `False`. - push_to_hub (`bool`, `optional`): - Whether to push the saved model to the hub. Defaults to `False`. - model_id (`Optional[str]`, `optional`): - The id of the model to push to the hub. If `None`, the name of the - directory will be used. Defaults to `None`. - organization (`Optional[str]`, `optional`): - The organization to push the model to. Defaults to `None`. - repo_name (`Optional[str]`, `optional`): - The name of the repository to push the model to. Defaults to `None`. - **kwargs: - Additional keyword arguments to pass to `OmegaConf.save`. 
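A hypothetical round trip using the two methods defined in this class; the paths are placeholders:

    relik = Relik.from_pretrained("path/to/a/relik/model", device="cpu")
    relik.save_pretrained("relik-model-copy", save_weights=False)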
- """ - if config is None: - # create a default config - config = { - "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}" - } - if self.retriever is not None: - if self.retriever.question_encoder is not None: - config[ - "question_encoder" - ] = self.retriever.question_encoder.name_or_path - if self.retriever.passage_encoder is not None: - config[ - "passage_encoder" - ] = self.retriever.passage_encoder.name_or_path - if self.retriever.document_index is not None: - config["document_index"] = self.retriever.document_index.name_or_dir - if self.reader is not None: - config["reader"] = self.reader.model_path - - config["retriever_kwargs"] = self.retriever_kwargs - config["reader_kwargs"] = self.reader_kwargs - # expand the fn as to be able to save it and load it later - config[ - "candidates_preprocessing_fn" - ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}" - - # these are model-specific and should be saved - config["top_k"] = self.top_k - config["window_size"] = self.window_size - config["window_stride"] = self.window_stride - - config_file_name = config_file_name or CONFIG_NAME - - # create the output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Saving relik config to {output_dir / config_file_name}") - # pretty print the config - pprint(config, console=console_logger, expand_all=True) - OmegaConf.save(config, output_dir / config_file_name) - - if save_weights: - model_id = model_id or output_dir.name - retriever_model_id = model_id + "-retriever" - # save weights - logger.info(f"Saving retriever to {output_dir / retriever_model_id}") - self.retriever.save_pretrained( - output_dir / retriever_model_id, - question_encoder_name=retriever_model_id + "-question-encoder", - passage_encoder_name=retriever_model_id + "-passage-encoder", - document_index_name=retriever_model_id + "-index", - push_to_hub=push_to_hub, - organization=organization, - repo_name=repo_name, - **kwargs, - ) - reader_model_id = model_id + "-reader" - logger.info(f"Saving reader to {output_dir / reader_model_id}") - self.reader.save_pretrained( - output_dir / reader_model_id, - push_to_hub=push_to_hub, - organization=organization, - repo_name=repo_name, - **kwargs, - ) - - if push_to_hub: - # push to hub - logger.info(f"Pushing to hub") - model_id = model_id or output_dir.name - upload(output_dir, model_id, organization=organization, repo_name=repo_name) - - -def main(): - from pprint import pprint - - document_index = FaissDocumentIndex.from_pretrained( - "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index", - config_kwargs={"_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex", "index_type": "IVFx,Flat"}, - ) - - relik = Relik( - question_encoder="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder", - document_index=document_index, - reader="/root/relik-spaces/models/relik-reader-aida-deberta-small", - device="cuda", - precision=16, - top_k=100, - window_size=32, - window_stride=16, - candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing", - ) - - input_text = """ - Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore. - The 92-year-old billionaire did not disclose the trust to the government in July 2015. 
- Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty. - Ecclestone had been due to go on trial next month. - """ - - preds = relik(input_text) - pprint(preds) - - -if __name__ == "__main__": - main() diff --git a/relik/inference/data/__init__.py b/relik/inference/data/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/inference/data/objects.py b/relik/inference/data/objects.py deleted file mode 100644 index 4b11e9641380b9e13d60de427827a73b70cbb9c1..0000000000000000000000000000000000000000 --- a/relik/inference/data/objects.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, NamedTuple, Optional - -from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample - - -@dataclass -class Word: - """ - A word representation that includes text, index in the sentence, POS tag, lemma, - dependency relation, and similar information. - - # Parameters - text : `str`, optional - The text representation. - index : `int`, optional - The word offset in the sentence. - lemma : `str`, optional - The lemma of this word. - pos : `str`, optional - The coarse-grained part of speech of this word. - dep : `str`, optional - The dependency relation for this word. - - input_id : `int`, optional - Integer representation of the word, used to pass it to a model. - token_type_id : `int`, optional - Token type id used by some transformers. - attention_mask: `int`, optional - Attention mask used by transformers, indicates to the model which tokens should - be attended to, and which should not. - """ - - text: str - index: int - start_char: Optional[int] = None - end_char: Optional[int] = None - # preprocessing fields - lemma: Optional[str] = None - pos: Optional[str] = None - dep: Optional[str] = None - head: Optional[int] = None - - def __str__(self): - return self.text - - def __repr__(self): - return self.__str__() - - -class EntitySpan(NamedTuple): - start: int - end: int - label: str - text: str - - -@dataclass -class RelikOutput: - text: str - labels: List[EntitySpan] - windows: Optional[List[RelikReaderSample]] = None diff --git a/relik/inference/data/tokenizers/__init__.py b/relik/inference/data/tokenizers/__init__.py deleted file mode 100644 index ad70314e8e0ccc18b946ff1317f6415c1892747a..0000000000000000000000000000000000000000 --- a/relik/inference/data/tokenizers/__init__.py +++ /dev/null @@ -1,89 +0,0 @@ -SPACY_LANGUAGE_MAPPER = { - "ca": "ca_core_news_sm", - "da": "da_core_news_sm", - "de": "de_core_news_sm", - "el": "el_core_news_sm", - "en": "en_core_web_sm", - "es": "es_core_news_sm", - "fr": "fr_core_news_sm", - "it": "it_core_news_sm", - "ja": "ja_core_news_sm", - "lt": "lt_core_news_sm", - "mk": "mk_core_news_sm", - "nb": "nb_core_news_sm", - "nl": "nl_core_news_sm", - "pl": "pl_core_news_sm", - "pt": "pt_core_news_sm", - "ro": "ro_core_news_sm", - "ru": "ru_core_news_sm", - "xx": "xx_sent_ud_sm", - "zh": "zh_core_web_sm", - "ca_core_news_sm": "ca_core_news_sm", - "ca_core_news_md": "ca_core_news_md", - "ca_core_news_lg": "ca_core_news_lg", - "ca_core_news_trf": "ca_core_news_trf", - "da_core_news_sm": "da_core_news_sm", - "da_core_news_md": "da_core_news_md", - "da_core_news_lg": "da_core_news_lg", - "da_core_news_trf": "da_core_news_trf", - "de_core_news_sm": "de_core_news_sm", - "de_core_news_md": "de_core_news_md", - "de_core_news_lg": 
"de_core_news_lg", - "de_dep_news_trf": "de_dep_news_trf", - "el_core_news_sm": "el_core_news_sm", - "el_core_news_md": "el_core_news_md", - "el_core_news_lg": "el_core_news_lg", - "en_core_web_sm": "en_core_web_sm", - "en_core_web_md": "en_core_web_md", - "en_core_web_lg": "en_core_web_lg", - "en_core_web_trf": "en_core_web_trf", - "es_core_news_sm": "es_core_news_sm", - "es_core_news_md": "es_core_news_md", - "es_core_news_lg": "es_core_news_lg", - "es_dep_news_trf": "es_dep_news_trf", - "fr_core_news_sm": "fr_core_news_sm", - "fr_core_news_md": "fr_core_news_md", - "fr_core_news_lg": "fr_core_news_lg", - "fr_dep_news_trf": "fr_dep_news_trf", - "it_core_news_sm": "it_core_news_sm", - "it_core_news_md": "it_core_news_md", - "it_core_news_lg": "it_core_news_lg", - "ja_core_news_sm": "ja_core_news_sm", - "ja_core_news_md": "ja_core_news_md", - "ja_core_news_lg": "ja_core_news_lg", - "ja_dep_news_trf": "ja_dep_news_trf", - "lt_core_news_sm": "lt_core_news_sm", - "lt_core_news_md": "lt_core_news_md", - "lt_core_news_lg": "lt_core_news_lg", - "mk_core_news_sm": "mk_core_news_sm", - "mk_core_news_md": "mk_core_news_md", - "mk_core_news_lg": "mk_core_news_lg", - "nb_core_news_sm": "nb_core_news_sm", - "nb_core_news_md": "nb_core_news_md", - "nb_core_news_lg": "nb_core_news_lg", - "nl_core_news_sm": "nl_core_news_sm", - "nl_core_news_md": "nl_core_news_md", - "nl_core_news_lg": "nl_core_news_lg", - "pl_core_news_sm": "pl_core_news_sm", - "pl_core_news_md": "pl_core_news_md", - "pl_core_news_lg": "pl_core_news_lg", - "pt_core_news_sm": "pt_core_news_sm", - "pt_core_news_md": "pt_core_news_md", - "pt_core_news_lg": "pt_core_news_lg", - "ro_core_news_sm": "ro_core_news_sm", - "ro_core_news_md": "ro_core_news_md", - "ro_core_news_lg": "ro_core_news_lg", - "ru_core_news_sm": "ru_core_news_sm", - "ru_core_news_md": "ru_core_news_md", - "ru_core_news_lg": "ru_core_news_lg", - "xx_ent_wiki_sm": "xx_ent_wiki_sm", - "xx_sent_ud_sm": "xx_sent_ud_sm", - "zh_core_web_sm": "zh_core_web_sm", - "zh_core_web_md": "zh_core_web_md", - "zh_core_web_lg": "zh_core_web_lg", - "zh_core_web_trf": "zh_core_web_trf", -} - -from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer -from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer -from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer diff --git a/relik/inference/data/tokenizers/base_tokenizer.py b/relik/inference/data/tokenizers/base_tokenizer.py deleted file mode 100644 index 1fed161b3eca085656e85d44cb9a64739f3d1e4c..0000000000000000000000000000000000000000 --- a/relik/inference/data/tokenizers/base_tokenizer.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import List, Union - -from relik.inference.data.objects import Word - - -class BaseTokenizer: - """ - A :obj:`Tokenizer` splits strings of text into single words, optionally adds - pos tags and perform lemmatization. - """ - - def __call__( - self, - texts: Union[str, List[str], List[List[str]]], - is_split_into_words: bool = False, - **kwargs - ) -> List[List[Word]]: - """ - Tokenize the input into single words. - - Args: - texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - Text to tag. It can be a single string, a batch of string and pre-tokenized strings. - is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True` and the input is a string, the input is split on spaces. - - Returns: - :obj:`List[List[Word]]`: The input text tokenized in single words. 
- """ - raise NotImplementedError - - def tokenize(self, text: str) -> List[Word]: - """ - Implements splitting words into tokens. - - Args: - text (:obj:`str`): - Text to tokenize. - - Returns: - :obj:`List[Word]`: The input text tokenized in single words. - - """ - raise NotImplementedError - - def tokenize_batch(self, texts: List[str]) -> List[List[Word]]: - """ - Implements batch splitting words into tokens. - - Args: - texts (:obj:`List[str]`): - Batch of text to tokenize. - - Returns: - :obj:`List[List[Word]]`: The input batch tokenized in single words. - - """ - return [self.tokenize(text) for text in texts] - - @staticmethod - def check_is_batched( - texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool - ): - """ - Check if input is batched or a single sample. - - Args: - texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - Text to check. - is_split_into_words (:obj:`bool`): - If :obj:`True` and the input is a string, the input is split on spaces. - - Returns: - :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise. - """ - return bool( - (not is_split_into_words and isinstance(texts, (list, tuple))) - or ( - is_split_into_words - and isinstance(texts, (list, tuple)) - and texts - and isinstance(texts[0], (list, tuple)) - ) - ) diff --git a/relik/inference/data/tokenizers/regex_tokenizer.py b/relik/inference/data/tokenizers/regex_tokenizer.py deleted file mode 100644 index ebe8656afb891a8318a7030375427e190d1dc383..0000000000000000000000000000000000000000 --- a/relik/inference/data/tokenizers/regex_tokenizer.py +++ /dev/null @@ -1,73 +0,0 @@ -import re -from typing import List, Union - -from overrides import overrides - -from relik.inference.data.objects import Word -from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer - - -class RegexTokenizer(BaseTokenizer): - """ - A :obj:`Tokenizer` that splits the text based on a simple regex. - """ - - def __init__(self): - super(RegexTokenizer, self).__init__() - # regex for splitting on spaces and punctuation and new lines - # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n") - self._regex = re.compile( - r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL - ) - - def __call__( - self, - texts: Union[str, List[str], List[List[str]]], - is_split_into_words: bool = False, - **kwargs, - ) -> List[List[Word]]: - """ - Tokenize the input into single words by splitting using a simple regex. - - Args: - texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - Text to tag. It can be a single string, a batch of string and pre-tokenized strings. - is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True` and the input is a string, the input is split on spaces. - - Returns: - :obj:`List[List[Word]]`: The input text tokenized in single words. 
- - Example:: - - >>> from relik.retriever.serve.tokenizers.regex_tokenizer import RegexTokenizer - - >>> regex_tokenizer = RegexTokenizer() - >>> regex_tokenizer("Mary sold the car to John.") - - """ - # check if input is batched or a single sample - is_batched = self.check_is_batched(texts, is_split_into_words) - - if is_batched: - tokenized = self.tokenize_batch(texts) - else: - tokenized = self.tokenize(texts) - - return tokenized - - @overrides - def tokenize(self, text: Union[str, List[str]]) -> List[Word]: - if not isinstance(text, (str, list)): - raise ValueError( - f"text must be either `str` or `list`, found: `{type(text)}`" - ) - - if isinstance(text, list): - text = " ".join(text) - return [ - Word(t[0], i, start_char=t[1], end_char=t[2]) - for i, t in enumerate( - (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text) - ) - ] diff --git a/relik/inference/data/tokenizers/spacy_tokenizer.py b/relik/inference/data/tokenizers/spacy_tokenizer.py deleted file mode 100644 index b949216ed5cf152ae4a7722c4a6be3f883481db2..0000000000000000000000000000000000000000 --- a/relik/inference/data/tokenizers/spacy_tokenizer.py +++ /dev/null @@ -1,228 +0,0 @@ -import logging -from typing import Dict, List, Tuple, Union - -import spacy - -# from ipa.common.utils import load_spacy -from overrides import overrides -from spacy.cli.download import download as spacy_download -from spacy.tokens import Doc - -from relik.common.log import get_logger -from relik.inference.data.objects import Word -from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER -from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer - -logger = get_logger(level=logging.DEBUG) - -# Spacy and Stanza stuff - -LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {} - - -def load_spacy( - language: str, - pos_tags: bool = False, - lemma: bool = False, - parse: bool = False, - split_on_spaces: bool = False, -) -> spacy.Language: - """ - Download and load spacy model. - - Args: - language (:obj:`str`, defaults to :obj:`en`): - Language of the text to tokenize. - pos_tags (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, performs POS tagging with spacy model. - lemma (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, performs lemmatization with spacy model. - parse (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, performs dependency parsing with spacy model. - split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, will split by spaces without performing tokenization. - - Returns: - :obj:`spacy.Language`: The spacy model loaded. - """ - exclude = ["vectors", "textcat", "ner"] - if not pos_tags: - exclude.append("tagger") - if not lemma: - exclude.append("lemmatizer") - if not parse: - exclude.append("parser") - - # check if the model is already loaded - # if so, there is no need to reload it - spacy_params = (language, pos_tags, lemma, parse, split_on_spaces) - if spacy_params not in LOADED_SPACY_MODELS: - try: - spacy_tagger = spacy.load(language, exclude=exclude) - except OSError: - logger.warning( - "Spacy model '%s' not found. Downloading and installing.", language - ) - spacy_download(language) - spacy_tagger = spacy.load(language, exclude=exclude) - - # if everything is disabled, return only the tokenizer - # for faster tokenization - # TODO: is it really faster? 
- # if len(exclude) >= 6: - # spacy_tagger = spacy_tagger.tokenizer - LOADED_SPACY_MODELS[spacy_params] = spacy_tagger - - return LOADED_SPACY_MODELS[spacy_params] - - -class SpacyTokenizer(BaseTokenizer): - """ - A :obj:`Tokenizer` that uses SpaCy to tokenizer and preprocess the text. It returns :obj:`Word` objects. - - Args: - language (:obj:`str`, optional, defaults to :obj:`en`): - Language of the text to tokenize. - return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, performs POS tagging with spacy model. - return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, performs lemmatization with spacy model. - return_deps (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, performs dependency parsing with spacy model. - split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, will split by spaces without performing tokenization. - use_gpu (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, will load the Stanza model on GPU. - """ - - def __init__( - self, - language: str = "en", - return_pos_tags: bool = False, - return_lemmas: bool = False, - return_deps: bool = False, - split_on_spaces: bool = False, - use_gpu: bool = False, - ): - super(SpacyTokenizer, self).__init__() - if language not in SPACY_LANGUAGE_MAPPER: - raise ValueError( - f"`{language}` language not supported. The supported " - f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}." - ) - if use_gpu: - # load the model on GPU - # if the GPU is not available or not correctly configured, - # it will rise an error - spacy.require_gpu() - self.spacy = load_spacy( - SPACY_LANGUAGE_MAPPER[language], - return_pos_tags, - return_lemmas, - return_deps, - split_on_spaces, - ) - self.split_on_spaces = split_on_spaces - - def __call__( - self, - texts: Union[str, List[str], List[List[str]]], - is_split_into_words: bool = False, - **kwargs, - ) -> Union[List[Word], List[List[Word]]]: - """ - Tokenize the input into single words using SpaCy models. - - Args: - texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - Text to tag. It can be a single string, a batch of string and pre-tokenized strings. - is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True` and the input is a string, the input is split on spaces. - - Returns: - :obj:`List[List[Word]]`: The input text tokenized in single words. 
- - Example:: - - >>> from ipa import SpacyTokenizer - - >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True) - >>> spacy_tokenizer("Mary sold the car to John.") - - """ - # check if input is batched or a single sample - is_batched = self.check_is_batched(texts, is_split_into_words) - if is_batched: - tokenized = self.tokenize_batch(texts) - else: - tokenized = self.tokenize(texts) - return tokenized - - @overrides - def tokenize(self, text: Union[str, List[str]]) -> List[Word]: - if self.split_on_spaces: - if isinstance(text, str): - text = text.split(" ") - spaces = [True] * len(text) - text = Doc(self.spacy.vocab, words=text, spaces=spaces) - return self._clean_tokens(self.spacy(text)) - - @overrides - def tokenize_batch( - self, texts: Union[List[str], List[List[str]]] - ) -> List[List[Word]]: - if self.split_on_spaces: - if isinstance(texts[0], str): - texts = [text.split(" ") for text in texts] - spaces = [[True] * len(text) for text in texts] - texts = [ - Doc(self.spacy.vocab, words=text, spaces=space) - for text, space in zip(texts, spaces) - ] - return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)] - - @staticmethod - def _clean_tokens(tokens: Doc) -> List[Word]: - """ - Converts spaCy tokens to :obj:`Word`. - - Args: - tokens (:obj:`spacy.tokens.Doc`): - Tokens from SpaCy model. - - Returns: - :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects. - """ - words = [ - Word( - token.text, - token.i, - token.idx, - token.idx + len(token), - token.lemma_, - token.pos_, - token.dep_, - token.head.i, - ) - for token in tokens - ] - return words - - -class WhitespaceSpacyTokenizer: - """Simple white space tokenizer for SpaCy.""" - - def __init__(self, vocab): - self.vocab = vocab - - def __call__(self, text): - if isinstance(text, str): - words = text.split(" ") - elif isinstance(text, list): - words = text - else: - raise ValueError( - f"text must be either `str` or `list`, found: `{type(text)}`" - ) - spaces = [True] * len(words) - return Doc(self.vocab, words=words, spaces=spaces) diff --git a/relik/inference/data/tokenizers/whitespace_tokenizer.py b/relik/inference/data/tokenizers/whitespace_tokenizer.py deleted file mode 100644 index 537ab6fe21eb4f9378d96d7cebfbc8cb12c36104..0000000000000000000000000000000000000000 --- a/relik/inference/data/tokenizers/whitespace_tokenizer.py +++ /dev/null @@ -1,70 +0,0 @@ -import re -from typing import List, Union - -from overrides import overrides - -from relik.inference.data.objects import Word -from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer - - -class WhitespaceTokenizer(BaseTokenizer): - """ - A :obj:`Tokenizer` that splits the text on spaces. - """ - - def __init__(self): - super(WhitespaceTokenizer, self).__init__() - self.whitespace_regex = re.compile(r"\S+") - - def __call__( - self, - texts: Union[str, List[str], List[List[str]]], - is_split_into_words: bool = False, - **kwargs, - ) -> List[List[Word]]: - """ - Tokenize the input into single words by splitting on spaces. - - Args: - texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - Text to tag. It can be a single string, a batch of string and pre-tokenized strings. - is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True` and the input is a string, the input is split on spaces. - - Returns: - :obj:`List[List[Word]]`: The input text tokenized in single words. 
- - Example:: - - >>> from nlp_preprocessing_wrappers import WhitespaceTokenizer - - >>> whitespace_tokenizer = WhitespaceTokenizer() - >>> whitespace_tokenizer("Mary sold the car to John .") - - """ - # check if input is batched or a single sample - is_batched = self.check_is_batched(texts, is_split_into_words) - - if is_batched: - tokenized = self.tokenize_batch(texts) - else: - tokenized = self.tokenize(texts) - - return tokenized - - @overrides - def tokenize(self, text: Union[str, List[str]]) -> List[Word]: - if not isinstance(text, (str, list)): - raise ValueError( - f"text must be either `str` or `list`, found: `{type(text)}`" - ) - - if isinstance(text, list): - text = " ".join(text) - return [ - Word(t[0], i, start_char=t[1], end_char=t[2]) - for i, t in enumerate( - (m.group(0), m.start(), m.end()) - for m in self.whitespace_regex.finditer(text) - ) - ] diff --git a/relik/inference/data/window/__init__.py b/relik/inference/data/window/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/inference/data/window/manager.py b/relik/inference/data/window/manager.py deleted file mode 100644 index 420609b1827f13bb332780554e3e20421908f6e9..0000000000000000000000000000000000000000 --- a/relik/inference/data/window/manager.py +++ /dev/null @@ -1,262 +0,0 @@ -import collections -import itertools -from dataclasses import dataclass -from typing import List, Optional, Set, Tuple - -from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer -from relik.reader.data.relik_reader_sample import RelikReaderSample - - -@dataclass -class Window: - doc_id: int - window_id: int - text: str - tokens: List[str] - doc_topic: Optional[str] - offset: int - token2char_start: dict - token2char_end: dict - window_candidates: Optional[List[str]] = None - - -class WindowManager: - def __init__(self, tokenizer: BaseTokenizer) -> None: - self.tokenizer = tokenizer - - def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]: - tokenized_document = self.tokenizer(document) - tokens = [] - tokens_char_mapping = [] - for token in tokenized_document: - tokens.append(token.text) - tokens_char_mapping.append((token.start_char, token.end_char)) - return tokens, tokens_char_mapping - - def create_windows( - self, - document: str, - window_size: int, - stride: int, - doc_id: int = 0, - doc_topic: str = None, - ) -> List[RelikReaderSample]: - document_tokens, tokens_char_mapping = self.tokenize(document) - if doc_topic is None: - doc_topic = document_tokens[0] if len(document_tokens) > 0 else "" - document_windows = [] - if len(document_tokens) <= window_size: - text = document - # relik_reader_sample = RelikReaderSample() - document_windows.append( - # Window( - RelikReaderSample( - doc_id=doc_id, - window_id=0, - text=text, - tokens=document_tokens, - doc_topic=doc_topic, - offset=0, - token2char_start={ - str(i): tokens_char_mapping[i][0] - for i in range(len(document_tokens)) - }, - token2char_end={ - str(i): tokens_char_mapping[i][1] - for i in range(len(document_tokens)) - }, - ) - ) - else: - for window_id, i in enumerate(range(0, len(document_tokens), stride)): - # if the last stride is smaller than the window size, then we can - # include more tokens form the previous window. 
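# NOTE: a worked example of the stride adjustment implemented just below, using
# made-up values (window_size=32, stride=16, 56 document tokens):
#   window starts come from range(0, 56, 16) -> 0, 16, 32, 48
#   i=32: 32 + 32 = 64 > 56, so overflowing_tokens = 8; 8 < stride, hence i is
#         shifted back by 8 (to 24) and a final full-size window is emitted
#   i=48: overflowing_tokens = 24 >= stride, so the loop breaks, because that
#         span is already covered by the previous window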
- if i != 0 and i + window_size > len(document_tokens): - overflowing_tokens = i + window_size - len(document_tokens) - if overflowing_tokens >= stride: - break - i -= overflowing_tokens - - involved_token_indices = list( - range(i, min(i + window_size, len(document_tokens) - 1)) - ) - window_tokens = [document_tokens[j] for j in involved_token_indices] - window_text_start = tokens_char_mapping[involved_token_indices[0]][0] - window_text_end = tokens_char_mapping[involved_token_indices[-1]][1] - text = document[window_text_start:window_text_end] - document_windows.append( - # Window( - RelikReaderSample( - # dict( - doc_id=doc_id, - window_id=window_id, - text=text, - tokens=window_tokens, - doc_topic=doc_topic, - offset=window_text_start, - token2char_start={ - str(i): tokens_char_mapping[ti][0] - for i, ti in enumerate(involved_token_indices) - }, - token2char_end={ - str(i): tokens_char_mapping[ti][1] - for i, ti in enumerate(involved_token_indices) - }, - # ) - ) - ) - return document_windows - - def merge_windows( - self, windows: List[RelikReaderSample] - ) -> List[RelikReaderSample]: - windows_by_doc_id = collections.defaultdict(list) - for window in windows: - windows_by_doc_id[window.doc_id].append(window) - - merged_window_by_doc = { - doc_id: self.merge_doc_windows(doc_windows) - for doc_id, doc_windows in windows_by_doc_id.items() - } - - return list(merged_window_by_doc.values()) - - def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample: - if len(windows) == 1: - return windows[0] - - if len(windows) > 0 and getattr(windows[0], "offset", None) is not None: - windows = sorted(windows, key=(lambda x: x.offset)) - - window_accumulator = windows[0] - - for next_window in windows[1:]: - window_accumulator = self._merge_window_pair( - window_accumulator, next_window - ) - - return window_accumulator - - def _merge_tokens( - self, window1: RelikReaderSample, window2: RelikReaderSample - ) -> Tuple[list, dict, dict]: - w1_tokens = window1.tokens[1:-1] - w2_tokens = window2.tokens[1:-1] - - # find intersection - tokens_intersection = None - for k in reversed(range(1, len(w1_tokens))): - if w1_tokens[-k:] == w2_tokens[:k]: - tokens_intersection = k - break - assert tokens_intersection is not None, ( - f"{window1.doc_id} - {window1.sent_id} - {window1.offset}" - + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n" - + f"w1 tokens: {w1_tokens}\n" - + f"w2 tokens: {w2_tokens}\n" - ) - - final_tokens = ( - [window1.tokens[0]] # CLS - + w1_tokens - + w2_tokens[tokens_intersection:] - + [window1.tokens[-1]] # SEP - ) - - w2_starting_offset = len(w1_tokens) - tokens_intersection - - def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict: - final_t2c = dict() - final_t2c.update(t2c1) - for t, c in t2c2.items(): - t = int(t) - if t < tokens_intersection: - continue - final_t2c[str(t + w2_starting_offset)] = c - return final_t2c - - return ( - final_tokens, - merge_char_mapping(window1.token2char_start, window2.token2char_start), - merge_char_mapping(window1.token2char_end, window2.token2char_end), - ) - - def _merge_span_annotation( - self, span_annotation1: List[list], span_annotation2: List[list] - ) -> List[list]: - uniq_store = set() - final_span_annotation_store = [] - for span_annotation in itertools.chain(span_annotation1, span_annotation2): - span_annotation_id = tuple(span_annotation) - if span_annotation_id not in uniq_store: - uniq_store.add(span_annotation_id) - final_span_annotation_store.append(span_annotation) - return 
sorted(final_span_annotation_store, key=lambda x: x[0]) - - def _merge_predictions( - self, - window1: RelikReaderSample, - window2: RelikReaderSample, - ) -> Tuple[Set[Tuple[int, int, str]], dict]: - merged_predictions = window1.predicted_window_labels_chars.union( - window2.predicted_window_labels_chars - ) - - span_title_probabilities = dict() - # probabilities - for span_prediction, predicted_probs in itertools.chain( - window1.probs_window_labels_chars.items(), - window2.probs_window_labels_chars.items(), - ): - if span_prediction not in span_title_probabilities: - span_title_probabilities[span_prediction] = predicted_probs - - return merged_predictions, span_title_probabilities - - def _merge_window_pair( - self, - window1: RelikReaderSample, - window2: RelikReaderSample, - ) -> RelikReaderSample: - merging_output = dict() - - if getattr(window1, "doc_id", None) is not None: - assert window1.doc_id == window2.doc_id - - if getattr(window1, "offset", None) is not None: - assert ( - window1.offset < window2.offset - ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})" - - merging_output["doc_id"] = window1.doc_id - merging_output["offset"] = window2.offset - - m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens( - window1, window2 - ) - - window_labels = None - if getattr(window1, "window_labels", None) is not None: - window_labels = self._merge_span_annotation( - window1.window_labels, window2.window_labels - ) - ( - predicted_window_labels_chars, - probs_window_labels_chars, - ) = self._merge_predictions( - window1, - window2, - ) - - merging_output.update( - dict( - tokens=m_tokens, - token2char_start=m_token2char_start, - token2char_end=m_token2char_end, - window_labels=window_labels, - predicted_window_labels_chars=predicted_window_labels_chars, - probs_window_labels_chars=probs_window_labels_chars, - ) - ) - - return RelikReaderSample(**merging_output) diff --git a/relik/inference/gerbil.py b/relik/inference/gerbil.py deleted file mode 100644 index d4c3f17cacea1d5472de99d1a974ad098585fc20..0000000000000000000000000000000000000000 --- a/relik/inference/gerbil.py +++ /dev/null @@ -1,254 +0,0 @@ -import argparse -import json -import os -import re -import sys -from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Iterator, List, Optional, Tuple - -from relik.inference.annotator import Relik -from relik.inference.data.objects import RelikOutput - -# sys.path += ['../'] -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) - - -import logging - -logger = logging.getLogger(__name__) - - -class GerbilAlbyManager: - def __init__( - self, - annotator: Optional[Relik] = None, - response_logger_dir: Optional[str] = None, - ) -> None: - self.annotator = annotator - self.response_logger_dir = response_logger_dir - self.predictions_counter = 0 - self.labels_mapping = None - - def annotate(self, document: str): - relik_output: RelikOutput = self.annotator(document) - annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels] - if self.labels_mapping is not None: - return [ - (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations - ] - return annotations - - def set_mapping_file(self, mapping_file_path: str): - with open(mapping_file_path) as f: - labels_mapping = json.load(f) - self.labels_mapping = {v: k for k, v in labels_mapping.items()} - - def write_response_bundle( - self, - document: str, - new_document: str, - annotations: list, - mapped_annotations: list, 
- ) -> None: - if self.response_logger_dir is None: - return - - if not os.path.isdir(self.response_logger_dir): - os.mkdir(self.response_logger_dir) - - with open( - f"{self.response_logger_dir}/{self.predictions_counter}.json", "w" - ) as f: - out_json_obj = dict( - document=document, - new_document=new_document, - annotations=annotations, - mapped_annotations=mapped_annotations, - ) - - out_json_obj["span_annotations"] = [ - (ss, se, document[ss:se], label) for (ss, se, label) in annotations - ] - - out_json_obj["span_mapped_annotations"] = [ - (ss, se, new_document[ss:se], label) - for (ss, se, label) in mapped_annotations - ] - - json.dump(out_json_obj, f, indent=2) - - self.predictions_counter += 1 - - -manager = GerbilAlbyManager() - - -def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]: - pattern_subs = { - "-LPR- ": " (", - "-RPR-": ")", - "\n\n": "\n", - "-LRB-": "(", - "-RRB-": ")", - '","': ",", - } - - document_acc = document - curr_offset = 0 - char2offset = [] - - matchings = re.finditer("({})".format("|".join(pattern_subs)), document) - for span_matching in sorted(matchings, key=lambda x: x.span()[0]): - span_start, span_end = span_matching.span() - span_start -= curr_offset - span_end -= curr_offset - - span_text = document_acc[span_start:span_end] - span_sub = pattern_subs[span_text] - document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:] - - offset = len(span_text) - len(span_sub) - curr_offset += offset - - char2offset.append((span_start + len(span_sub), curr_offset)) - - return document_acc, char2offset - - -def map_back_annotations( - annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]] -) -> Iterator[Tuple[int, int, str]]: - def map_char(char_idx: int) -> int: - current_offset = 0 - for offset_idx, offset_value in char_mapping: - if char_idx >= offset_idx: - current_offset = offset_value - else: - break - return char_idx + current_offset - - for ss, se, label in annotations: - yield map_char(ss), map_char(se), label - - -def annotate(document: str) -> List[Tuple[int, int, str]]: - new_document, mapping = preprocess_document(document) - logger.info("Mapping: " + str(mapping)) - logger.info("Document: " + str(document)) - annotations = [ - (cs, ce, label.replace(" ", "_")) - for cs, ce, label in manager.annotate(new_document) - ] - logger.info("New document: " + str(new_document)) - mapped_annotations = ( - list(map_back_annotations(annotations, mapping)) - if len(mapping) > 0 - else annotations - ) - - logger.info( - "Annotations: " - + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations]) - ) - - manager.write_response_bundle( - document, new_document, mapped_annotations, annotations - ) - - if not all( - [ - new_document[ss:se] == document[mss:mse] - for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) - ] - ): - diff_mappings = [ - (new_document[ss:se], document[mss:mse]) - for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) - ] - return None - assert all( - [ - document[mss:mse] == new_document[ss:se] - for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) - ] - ), (mapped_annotations, annotations) - - return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations] - - -class GetHandler(BaseHTTPRequestHandler): - def do_POST(self): - content_length = int(self.headers["Content-Length"]) - post_data = self.rfile.read(content_length) - self.send_response(200) - self.end_headers() - doc_text = read_json(post_data) - 
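# NOTE: a hedged sketch of the exchange this handler implements, inferred from
# read_json() and annotate() below; the entity labels are invented examples:
#   request body:  {"text": "Obama went to Rome ."}
#   response body: [[0, 5, "Barack_Obama"], [14, 4, "Rome"]]
# i.e. a JSON list of (character start, length, label) triples, as produced by
# annotate() and serialized with json.dumps() just below.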
# try: - response = annotate(doc_text) - - self.wfile.write(bytes(json.dumps(response), "utf-8")) - return - - -def read_json(post_data): - data = json.loads(post_data.decode("utf-8")) - # logger.info("received data:", data) - text = data["text"] - # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]] - return text - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--relik-model-name", required=True) - parser.add_argument("--responses-log-dir") - parser.add_argument("--log-file", default="logs/logging.txt") - parser.add_argument("--mapping-file") - return parser.parse_args() - - -def main(): - args = parse_args() - - # init manager - manager.response_logger_dir = args.responses_log_dir - # manager.annotator = Relik.from_pretrained(args.relik_model_name) - - print("Debugging, not using you relik model but an hardcoded one.") - manager.annotator = Relik( - question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder", - document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder", - reader="relik/reader/models/relik-reader-deberta-base-new-data", - window_size=32, - window_stride=16, - candidates_preprocessing_fn=(lambda x: x.split("")[0].strip()), - ) - - if args.mapping_file is not None: - manager.set_mapping_file(args.mapping_file) - - port = 6654 - server = HTTPServer(("localhost", port), GetHandler) - logger.info(f"Starting server at http://localhost:{port}") - - # Create a file handler and set its level - file_handler = logging.FileHandler(args.log_file) - file_handler.setLevel(logging.DEBUG) - - # Create a log formatter and set it on the handler - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - file_handler.setFormatter(formatter) - - # Add the file handler to the logger - logger.addHandler(file_handler) - - try: - server.serve_forever() - except KeyboardInterrupt: - exit(0) - - -if __name__ == "__main__": - main() diff --git a/relik/inference/preprocessing.py b/relik/inference/preprocessing.py deleted file mode 100644 index 2476fe47ea64d907a8c32c31082253c45b48720c..0000000000000000000000000000000000000000 --- a/relik/inference/preprocessing.py +++ /dev/null @@ -1,4 +0,0 @@ -def wikipedia_title_and_openings_preprocessing( - wikipedia_title_and_openings: str, sepator: str = " " -): - return wikipedia_title_and_openings.split(sepator, 1)[0] diff --git a/relik/inference/serve/__init__.py b/relik/inference/serve/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/inference/serve/backend/__init__.py b/relik/inference/serve/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/inference/serve/backend/relik.py b/relik/inference/serve/backend/relik.py deleted file mode 100644 index 038e2ef78afbccb0758162996e35cd8dc858d453..0000000000000000000000000000000000000000 --- a/relik/inference/serve/backend/relik.py +++ /dev/null @@ -1,210 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Optional, Union - -from relik.common.utils import is_package_available -from relik.inference.annotator import Relik - -if not is_package_available("fastapi"): - raise ImportError( - "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." 
- ) -from fastapi import FastAPI, HTTPException - -if not is_package_available("ray"): - raise ImportError( - "Ray is not installed. Please install Ray with `pip install relik[serve]`." - ) -from ray import serve - -from relik.common.log import get_logger -from relik.inference.serve.backend.utils import ( - RayParameterManager, - ServerParameterManager, -) -from relik.retriever.data.utils import batch_generator - -logger = get_logger(__name__, level=logging.INFO) - -VERSION = {} # type: ignore -with open( - Path(__file__).parent.parent.parent.parent / "version.py", "r" -) as version_file: - exec(version_file.read(), VERSION) - -# Env variables for server -SERVER_MANAGER = ServerParameterManager() -RAY_MANAGER = RayParameterManager() - -app = FastAPI( - title="ReLiK", - version=VERSION["VERSION"], - description="ReLiK REST API", -) - - -@serve.deployment( - ray_actor_options={ - "num_gpus": RAY_MANAGER.num_gpus - if ( - SERVER_MANAGER.retriver_device == "cuda" - or SERVER_MANAGER.reader_device == "cuda" - ) - else 0 - }, - autoscaling_config={ - "min_replicas": RAY_MANAGER.min_replicas, - "max_replicas": RAY_MANAGER.max_replicas, - }, -) -@serve.ingress(app) -class RelikServer: - def __init__( - self, - question_encoder: str, - document_index: str, - passage_encoder: Optional[str] = None, - reader_encoder: Optional[str] = None, - top_k: int = 100, - retriver_device: str = "cpu", - reader_device: str = "cpu", - index_device: Optional[str] = None, - precision: int = 32, - index_precision: Optional[int] = None, - use_faiss: bool = False, - window_batch_size: int = 32, - window_size: int = 32, - window_stride: int = 16, - split_on_spaces: bool = False, - ): - # parameters - self.question_encoder = question_encoder - self.passage_encoder = passage_encoder - self.reader_encoder = reader_encoder - self.document_index = document_index - self.top_k = top_k - self.retriver_device = retriver_device - self.index_device = index_device or retriver_device - self.reader_device = reader_device - self.precision = precision - self.index_precision = index_precision or precision - self.use_faiss = use_faiss - self.window_batch_size = window_batch_size - self.window_size = window_size - self.window_stride = window_stride - self.split_on_spaces = split_on_spaces - - # log stuff for debugging - logger.info("Initializing RelikServer with parameters:") - logger.info(f"QUESTION_ENCODER: {self.question_encoder}") - logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}") - logger.info(f"READER_ENCODER: {self.reader_encoder}") - logger.info(f"DOCUMENT_INDEX: {self.document_index}") - logger.info(f"TOP_K: {self.top_k}") - logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}") - logger.info(f"READER_DEVICE: {self.reader_device}") - logger.info(f"INDEX_DEVICE: {self.index_device}") - logger.info(f"PRECISION: {self.precision}") - logger.info(f"INDEX_PRECISION: {self.index_precision}") - logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}") - logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}") - - self.relik = Relik( - question_encoder=self.question_encoder, - passage_encoder=self.passage_encoder, - document_index=self.document_index, - reader=self.reader_encoder, - retriever_device=self.retriver_device, - document_index_device=self.index_device, - reader_device=self.reader_device, - retriever_precision=self.precision, - document_index_precision=self.index_precision, - reader_precision=self.precision, - ) - - # @serve.batch() - async def handle_batch(self, documents: List[str]) -> List: - return self.relik( - 
documents, - top_k=self.top_k, - window_size=self.window_size, - window_stride=self.window_stride, - batch_size=self.window_batch_size, - ) - - @app.post("/api/entities") - async def entities_endpoint( - self, - documents: Union[str, List[str]], - ): - try: - # normalize input - if isinstance(documents, str): - documents = [documents] - if document_topics is not None: - if isinstance(document_topics, str): - document_topics = [document_topics] - assert len(documents) == len(document_topics) - # get predictions for the retriever - return await self.handle_batch(documents, document_topics) - except Exception as e: - # log the entire stack trace - logger.exception(e) - raise HTTPException(status_code=500, detail=f"Server Error: {e}") - - @app.post("/api/gerbil") - async def gerbil_endpoint(self, documents: Union[str, List[str]]): - try: - # normalize input - if isinstance(documents, str): - documents = [documents] - - # output list - windows_passages = [] - # split documents into windows - document_windows = [ - window - for doc_id, document in enumerate(documents) - for window in self.window_manager( - self.tokenizer, - document, - window_size=self.window_size, - stride=self.window_stride, - doc_id=doc_id, - ) - ] - - # get text and topic from document windows and create new list - model_inputs = [ - (window.text, window.doc_topic) for window in document_windows - ] - - # batch generator - for batch in batch_generator( - model_inputs, batch_size=self.window_batch_size - ): - text, text_pair = zip(*batch) - batch_predictions = await self.handle_batch_retriever(text, text_pair) - windows_passages.extend( - [ - [p.label for p in predictions] - for predictions in batch_predictions - ] - ) - - # add passage to document windows - for window, passages in zip(document_windows, windows_passages): - # clean up passages (remove everything after first tag if present) - passages = [c.split(" ", 1)[0] for c in passages] - window.window_candidates = passages - - # return document windows - return document_windows - - except Exception as e: - # log the entire stack trace - logger.exception(e) - raise HTTPException(status_code=500, detail=f"Server Error: {e}") - - -server = RelikServer.bind(**vars(SERVER_MANAGER)) diff --git a/relik/inference/serve/backend/retriever.py b/relik/inference/serve/backend/retriever.py deleted file mode 100644 index e796893e76b83377a5f8b2c7afdccce21756dcbd..0000000000000000000000000000000000000000 --- a/relik/inference/serve/backend/retriever.py +++ /dev/null @@ -1,206 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Optional, Union - -from relik.common.utils import is_package_available - -if not is_package_available("fastapi"): - raise ImportError( - "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." - ) -from fastapi import FastAPI, HTTPException - -if not is_package_available("ray"): - raise ImportError( - "Ray is not installed. Please install Ray with `pip install relik[serve]`." 
- ) -from ray import serve - -from relik.common.log import get_logger -from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer -from relik.inference.data.window.manager import WindowManager -from relik.inference.serve.backend.utils import ( - RayParameterManager, - ServerParameterManager, -) -from relik.retriever.data.utils import batch_generator -from relik.retriever.pytorch_modules import GoldenRetriever - -logger = get_logger(__name__, level=logging.INFO) - -VERSION = {} # type: ignore -with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file: - exec(version_file.read(), VERSION) - -# Env variables for server -SERVER_MANAGER = ServerParameterManager() -RAY_MANAGER = RayParameterManager() - -app = FastAPI( - title="Golden Retriever", - version=VERSION["VERSION"], - description="Golden Retriever REST API", -) - - -@serve.deployment( - ray_actor_options={ - "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0 - }, - autoscaling_config={ - "min_replicas": RAY_MANAGER.min_replicas, - "max_replicas": RAY_MANAGER.max_replicas, - }, -) -@serve.ingress(app) -class GoldenRetrieverServer: - def __init__( - self, - question_encoder: str, - document_index: str, - passage_encoder: Optional[str] = None, - top_k: int = 100, - device: str = "cpu", - index_device: Optional[str] = None, - precision: int = 32, - index_precision: Optional[int] = None, - use_faiss: bool = False, - window_batch_size: int = 32, - window_size: int = 32, - window_stride: int = 16, - split_on_spaces: bool = False, - ): - # parameters - self.question_encoder = question_encoder - self.passage_encoder = passage_encoder - self.document_index = document_index - self.top_k = top_k - self.device = device - self.index_device = index_device or device - self.precision = precision - self.index_precision = index_precision or precision - self.use_faiss = use_faiss - self.window_batch_size = window_batch_size - self.window_size = window_size - self.window_stride = window_stride - self.split_on_spaces = split_on_spaces - - # log stuff for debugging - logger.info("Initializing GoldenRetrieverServer with parameters:") - logger.info(f"QUESTION_ENCODER: {self.question_encoder}") - logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}") - logger.info(f"DOCUMENT_INDEX: {self.document_index}") - logger.info(f"TOP_K: {self.top_k}") - logger.info(f"DEVICE: {self.device}") - logger.info(f"INDEX_DEVICE: {self.index_device}") - logger.info(f"PRECISION: {self.precision}") - logger.info(f"INDEX_PRECISION: {self.index_precision}") - logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}") - logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}") - - self.retriever = GoldenRetriever( - question_encoder=self.question_encoder, - passage_encoder=self.passage_encoder, - document_index=self.document_index, - device=self.device, - index_device=self.index_device, - index_precision=self.index_precision, - ) - self.retriever.eval() - - if self.split_on_spaces: - logger.info("Using WhitespaceTokenizer") - self.tokenizer = WhitespaceTokenizer() - # logger.info("Using RegexTokenizer") - # self.tokenizer = RegexTokenizer() - else: - logger.info("Using SpacyTokenizer") - self.tokenizer = SpacyTokenizer(language="en") - - self.window_manager = WindowManager(tokenizer=self.tokenizer) - - # @serve.batch() - async def handle_batch( - self, documents: List[str], document_topics: List[str] - ) -> List: - return self.retriever.retrieve( - documents, text_pair=document_topics, k=self.top_k, 
precision=self.precision - ) - - @app.post("/api/retrieve") - async def retrieve_endpoint( - self, - documents: Union[str, List[str]], - document_topics: Optional[Union[str, List[str]]] = None, - ): - try: - # normalize input - if isinstance(documents, str): - documents = [documents] - if document_topics is not None: - if isinstance(document_topics, str): - document_topics = [document_topics] - assert len(documents) == len(document_topics) - # get predictions - return await self.handle_batch(documents, document_topics) - except Exception as e: - # log the entire stack trace - logger.exception(e) - raise HTTPException(status_code=500, detail=f"Server Error: {e}") - - @app.post("/api/gerbil") - async def gerbil_endpoint(self, documents: Union[str, List[str]]): - try: - # normalize input - if isinstance(documents, str): - documents = [documents] - - # output list - windows_passages = [] - # split documents into windows - document_windows = [ - window - for doc_id, document in enumerate(documents) - for window in self.window_manager( - self.tokenizer, - document, - window_size=self.window_size, - stride=self.window_stride, - doc_id=doc_id, - ) - ] - - # get text and topic from document windows and create new list - model_inputs = [ - (window.text, window.doc_topic) for window in document_windows - ] - - # batch generator - for batch in batch_generator( - model_inputs, batch_size=self.window_batch_size - ): - text, text_pair = zip(*batch) - batch_predictions = await self.handle_batch(text, text_pair) - windows_passages.extend( - [ - [p.label for p in predictions] - for predictions in batch_predictions - ] - ) - - # add passage to document windows - for window, passages in zip(document_windows, windows_passages): - # clean up passages (remove everything after first tag if present) - passages = [c.split(" ", 1)[0] for c in passages] - window.window_candidates = passages - - # return document windows - return document_windows - - except Exception as e: - # log the entire stack trace - logger.exception(e) - raise HTTPException(status_code=500, detail=f"Server Error: {e}") - - -server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER)) diff --git a/relik/inference/serve/backend/utils.py b/relik/inference/serve/backend/utils.py deleted file mode 100644 index bdf869c1ece0e260355526ee5fcc2f00da7ef887..0000000000000000000000000000000000000000 --- a/relik/inference/serve/backend/utils.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Union - - -@dataclass -class ServerParameterManager: - retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu") - reader_device: str = os.environ.get("READER_DEVICE", "cpu") - index_device: str = os.environ.get("INDEX_DEVICE", retriver_device) - precision: Union[str, int] = os.environ.get("PRECISION", "fp32") - index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision) - question_encoder: str = os.environ.get("QUESTION_ENCODER", None) - passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None) - document_index: str = os.environ.get("DOCUMENT_INDEX", None) - reader_encoder: str = os.environ.get("READER_ENCODER", None) - top_k: int = int(os.environ.get("TOP_K", 100)) - use_faiss: bool = os.environ.get("USE_FAISS", False) - window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32)) - window_size: int = int(os.environ.get("WINDOW_SIZE", 32)) - window_stride: int = int(os.environ.get("WINDOW_SIZE", 16)) - split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False) - - -class 
RayParameterManager: - def __init__(self) -> None: - self.num_gpus = int(os.environ.get("NUM_GPUS", 1)) - self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1)) - self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1)) diff --git a/relik/inference/serve/frontend/__init__.py b/relik/inference/serve/frontend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/inference/serve/frontend/relik.py b/relik/inference/serve/frontend/relik.py deleted file mode 100644 index 5dd8bb4eb2d8c6b056c61bda359959010e688635..0000000000000000000000000000000000000000 --- a/relik/inference/serve/frontend/relik.py +++ /dev/null @@ -1,231 +0,0 @@ -import os -import re -import time -from pathlib import Path - -import requests -import streamlit as st -from spacy import displacy -from streamlit_extras.badges import badge -from streamlit_extras.stylable_container import stylable_container - -RELIK = os.getenv("RELIK", "localhost:8000/api/entities") - -import random - - -def get_random_color(ents): - colors = {} - random_colors = generate_pastel_colors(len(ents)) - for ent in ents: - colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1)) - return colors - - -def floatrange(start, stop, steps): - if int(steps) == 1: - return [stop] - return [ - start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps) - ] - - -def hsl_to_rgb(h, s, l): - def hue_2_rgb(v1, v2, v_h): - while v_h < 0.0: - v_h += 1.0 - while v_h > 1.0: - v_h -= 1.0 - if 6 * v_h < 1.0: - return v1 + (v2 - v1) * 6.0 * v_h - if 2 * v_h < 1.0: - return v2 - if 3 * v_h < 2.0: - return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0 - return v1 - - # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1." - # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1." - - r, b, g = (l * 255,) * 3 - if s != 0.0: - if l < 0.5: - var_2 = l * (1.0 + s) - else: - var_2 = (l + s) - (s * l) - var_1 = 2.0 * l - var_2 - r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0)) - g = 255 * hue_2_rgb(var_1, var_2, h) - b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0)) - - return int(round(r)), int(round(g)), int(round(b)) - - -def generate_pastel_colors(n): - """Return different pastel colours. 
- - Input: - n (integer) : The number of colors to return - - Output: - A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc']) - - Example: - >>> print generate_pastel_colors(5) - ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0'] - """ - if n == 0: - return [] - - # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space) - start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue - saturation = 1.0 - lightness = 0.8 - # We take points around the chromatic circle (hue): - # (Note: we generate n+1 colors, then drop the last one ([:-1]) because - # it equals the first one (hue 0 = hue 1)) - return [ - "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) - for hue in floatrange(start_hue, start_hue + 1, n + 1) - ][:-1] - - -def set_sidebar(css): - white_link_wrapper = "{}" - with st.sidebar: - st.markdown(f"", unsafe_allow_html=True) - st.image( - "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg", - use_column_width=True, - ) - st.markdown("## ReLiK") - st.write( - f""" - - {white_link_wrapper.format("#", "  Paper")} - - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "  GitHub")} - - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "  Docker Hub")} - """, - unsafe_allow_html=True, - ) - st.markdown("## Sapienza NLP") - st.write( - f""" - - {white_link_wrapper.format("https://nlp.uniroma1.it", "  Webpage")} - - {white_link_wrapper.format("https://github.com/SapienzaNLP", "  GitHub")} - - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "  Twitter")} - - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "  LinkedIn")} - """, - unsafe_allow_html=True, - ) - - -def get_el_annotations(response): - # swap labels key with ents - response["ents"] = response.pop("labels") - label_in_text = set(l["label"] for l in response["ents"]) - options = {"ents": label_in_text, "colors": get_random_color(label_in_text)} - return response, options - - -def set_intro(css): - # intro - st.markdown("# ReLik") - st.markdown( - "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget" - ) - # st.markdown( - # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API " - # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by " - # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), " - # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)." 
- # ) - badge(type="github", name="sapienzanlp/relik") - badge(type="pypi", name="relik") - - -def run_client(): - with open(Path(__file__).parent / "style.css") as f: - css = f.read() - - st.set_page_config( - page_title="ReLik", - page_icon="🦮", - layout="wide", - ) - set_sidebar(css) - set_intro(css) - - # text input - text = st.text_area( - "Enter Text Below:", - value="Obama went to Rome for a quick vacation.", - height=200, - max_chars=500, - ) - - with stylable_container( - key="annotate_button", - css_styles=""" - button { - background-color: #802433; - color: white; - border-radius: 25px; - } - """, - ): - submit = st.button("Annotate") - # submit = st.button("Run") - - # ReLik API call - if submit: - text = text.strip() - if text: - st.markdown("####") - st.markdown("#### Entity Linking") - with st.spinner(text="In progress"): - response = requests.post(RELIK, json=text) - if response.status_code != 200: - st.error("Error: {}".format(response.status_code)) - else: - response = response.json() - - # Entity Linking - # with stylable_container( - # key="container_with_border", - # css_styles=""" - # { - # border: 1px solid rgba(49, 51, 63, 0.2); - # border-radius: 0.5rem; - # padding: 0.5rem; - # padding-bottom: 2rem; - # } - # """, - # ): - # st.markdown("##") - dict_of_ents, options = get_el_annotations(response=response) - display = displacy.render( - dict_of_ents, manual=True, style="ent", options=options - ) - display = display.replace("\n", " ") - # wsd_display = re.sub( - # r"(wiki::\d+\w)", - # r"\g<1>".format( - # language.upper() - # ), - # wsd_display, - # ) - with st.container(): - st.write(display, unsafe_allow_html=True) - - st.markdown("####") - st.markdown("#### Relation Extraction") - - with st.container(): - st.write("Coming :)", unsafe_allow_html=True) - - else: - st.error("Please enter some text.") - - -if __name__ == "__main__": - run_client() diff --git a/relik/inference/serve/frontend/style.css b/relik/inference/serve/frontend/style.css deleted file mode 100644 index 31f0d182cfd9b2636d5db5cbd0e7a1339ed5d1c3..0000000000000000000000000000000000000000 --- a/relik/inference/serve/frontend/style.css +++ /dev/null @@ -1,33 +0,0 @@ -/* Sidebar */ -.eczjsme11 { - background-color: #802433; -} - -.st-emotion-cache-10oheav h2 { - color: white; -} - -.st-emotion-cache-10oheav li { - color: white; -} - -/* Main */ -a:link { - text-decoration: none; - color: white; -} - -a:visited { - text-decoration: none; - color: white; -} - -a:hover { - text-decoration: none; - color: rgba(255, 255, 255, 0.871); -} - -a:active { - text-decoration: none; - color: white; -} \ No newline at end of file diff --git a/relik/reader/__init__.py b/relik/reader/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/reader/conf/config.yaml b/relik/reader/conf/config.yaml deleted file mode 100644 index 05b4e524d060aed56b930d1f578424b986792975..0000000000000000000000000000000000000000 --- a/relik/reader/conf/config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -# Required to make the "experiments" dir the default one for the output of the models -hydra: - run: - dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - -model_name: relik-reader-deberta-base # used to name the model in wandb and output dir -project_name: relik-reader # used to name the project in wandb - - -defaults: - - _self_ - - training: base - - model: base - - data: base diff --git a/relik/reader/conf/data/base.yaml 
b/relik/reader/conf/data/base.yaml deleted file mode 100644 index 1964d8750bed1a162a9ec15d20be1708ccad9914..0000000000000000000000000000000000000000 --- a/relik/reader/conf/data/base.yaml +++ /dev/null @@ -1,21 +0,0 @@ -train_dataset_path: "relik/reader/data/train.jsonl" -val_dataset_path: "relik/reader/data/testa.jsonl" - -train_dataset: - _target_: "relik.reader.relik_reader_data.RelikDataset" - transformer_model: "${model.model.transformer_model}" - materialize_samples: False - shuffle_candidates: 0.5 - random_drop_gold_candidates: 0.05 - noise_param: 0.0 - for_inference: False - tokens_per_batch: 4096 - special_symbols: null - -val_dataset: - _target_: "relik.reader.relik_reader_data.RelikDataset" - transformer_model: "${model.model.transformer_model}" - materialize_samples: False - shuffle_candidates: False - for_inference: True - special_symbols: null diff --git a/relik/reader/conf/data/re.yaml b/relik/reader/conf/data/re.yaml deleted file mode 100644 index 17c18ee886021bc0157edb156020409fdd799fbc..0000000000000000000000000000000000000000 --- a/relik/reader/conf/data/re.yaml +++ /dev/null @@ -1,54 +0,0 @@ -train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl" -val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl" -test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl" - -relations_definitions: - /people/person/nationality: "nationality" - /sports/sports_team/location: "sports team location" - /location/country/administrative_divisions: "administrative divisions" - /business/company/major_shareholders: "shareholders" - /people/ethnicity/people: "ethnicity" - /people/ethnicity/geographic_distribution: "geographic distributi6on" - /business/company_shareholder/major_shareholder_of: "major shareholder" - /location/location/contains: "location" - /business/company/founders: "founders" - /business/person/company: "company" - /business/company/advisors: "advisor" - /people/deceased_person/place_of_death: "place of death" - /business/company/industry: "industry" - /people/person/ethnicity: "ethnic background" - /people/person/place_of_birth: "place of birth" - /location/administrative_division/country: "country of an administration division" - /people/person/place_lived: "place lived" - /sports/sports_team_location/teams: "sports team" - /people/person/children: "child" - /people/person/religion: "religion" - /location/neighborhood/neighborhood_of: "neighborhood" - /location/country/capital: "capital" - /business/company/place_founded: "company founded location" - /people/person/profession: "occupation" - -train_dataset: - _target_: "relik.reader.relik_reader_re_data.RelikREDataset" - transformer_model: "${model.model.transformer_model}" - materialize_samples: False - shuffle_candidates: False - flip_candidates: 1.0 - noise_param: 0.0 - for_inference: False - tokens_per_batch: 4096 - min_length: -1 - special_symbols: null - relations_definitions: ${data.relations_definitions} - sorting_fields: - - "predictable_candidates" -val_dataset: - _target_: "relik.reader.relik_reader_re_data.RelikREDataset" - transformer_model: "${model.model.transformer_model}" - materialize_samples: False - shuffle_candidates: False - flip_candidates: False - for_inference: True - min_length: -1 - special_symbols: null - relations_definitions: ${data.relations_definitions} diff --git a/relik/reader/conf/training/base.yaml b/relik/reader/conf/training/base.yaml deleted file mode 100644 index 8e366a96408bd0f8ff5184849e53d19bb477af38..0000000000000000000000000000000000000000 --- 
a/relik/reader/conf/training/base.yaml +++ /dev/null @@ -1,12 +0,0 @@ -seed: 94 - -trainer: - _target_: lightning.Trainer - devices: - - 0 - precision: "16-mixed" - max_steps: 50000 - val_check_interval: 1.0 - num_sanity_val_steps: 0 - limit_val_batches: 1 - gradient_clip_val: 1.0 diff --git a/relik/reader/conf/training/re.yaml b/relik/reader/conf/training/re.yaml deleted file mode 100644 index 8701ae3fca48830649022644a743783a1016bd5b..0000000000000000000000000000000000000000 --- a/relik/reader/conf/training/re.yaml +++ /dev/null @@ -1,12 +0,0 @@ -seed: 15 - -trainer: - _target_: lightning.Trainer - devices: - - 0 - precision: "16-mixed" - max_steps: 100000 - val_check_interval: 1.0 - num_sanity_val_steps: 0 - limit_val_batches: 1 - gradient_clip_val: 1.0 \ No newline at end of file diff --git a/relik/reader/data/__init__.py b/relik/reader/data/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/reader/data/patches.py b/relik/reader/data/patches.py deleted file mode 100644 index b0d03dbdf08d0e205787ce2b8176c6bd47d2dfca..0000000000000000000000000000000000000000 --- a/relik/reader/data/patches.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List - -from relik.reader.data.relik_reader_sample import RelikReaderSample -from relik.reader.utils.special_symbols import NME_SYMBOL - - -def merge_patches_predictions(sample) -> None: - sample._d["predicted_window_labels"] = dict() - predicted_window_labels = sample._d["predicted_window_labels"] - - sample._d["span_title_probabilities"] = dict() - span_title_probabilities = sample._d["span_title_probabilities"] - - span2title = dict() - for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): - # selecting span predictions - for predicted_title, predicted_spans in patch_info[ - "predicted_window_labels" - ].items(): - for pred_span in predicted_spans: - pred_span = tuple(pred_span) - curr_title = span2title.get(pred_span) - if curr_title is None or curr_title == NME_SYMBOL: - span2title[pred_span] = predicted_title - # else: - # print("Merging at patch level") - - # selecting span predictions probability - for predicted_span, titles_probabilities in patch_info[ - "span_title_probabilities" - ].items(): - if predicted_span not in span_title_probabilities: - span_title_probabilities[predicted_span] = titles_probabilities - - for span, title in span2title.items(): - if title not in predicted_window_labels: - predicted_window_labels[title] = list() - predicted_window_labels[title].append(span) - - -def remove_duplicate_samples( - samples: List[RelikReaderSample], -) -> List[RelikReaderSample]: - seen_sample = set() - samples_store = [] - for sample in samples: - sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}" - if sample_id not in seen_sample: - seen_sample.add(sample_id) - samples_store.append(sample) - return samples_store diff --git a/relik/reader/data/relik_reader_data.py b/relik/reader/data/relik_reader_data.py deleted file mode 100644 index 3c65646f99d37cdcf03ab7005c83eb0069da168c..0000000000000000000000000000000000000000 --- a/relik/reader/data/relik_reader_data.py +++ /dev/null @@ -1,965 +0,0 @@ -import logging -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union, -) - -import numpy as np -import torch -from torch.utils.data import IterableDataset -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizer - -from 
relik.reader.data.relik_reader_data_utils import ( - add_noise_to_value, - batchify, - chunks, - flatten, -) -from relik.reader.data.relik_reader_sample import ( - RelikReaderSample, - load_relik_reader_samples, -) -from relik.reader.utils.special_symbols import NME_SYMBOL - -logger = logging.getLogger(__name__) - - -def preprocess_dataset( - input_dataset: Iterable[dict], - transformer_model: str, - add_topic: bool, -) -> Iterable[dict]: - tokenizer = AutoTokenizer.from_pretrained(transformer_model) - for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"): - if len(dataset_elem["tokens"]) == 0: - print( - f"Dataset element with doc id: {dataset_elem['doc_id']}", - f"and offset {dataset_elem['offset']} does not contain any token", - "Skipping it", - ) - continue - - new_dataset_elem = dict( - doc_id=dataset_elem["doc_id"], - offset=dataset_elem["offset"], - ) - - tokenization_out = tokenizer( - dataset_elem["tokens"], - return_offsets_mapping=True, - add_special_tokens=False, - ) - - window_tokens = tokenization_out.input_ids - window_tokens = flatten(window_tokens) - - offsets_mapping = [ - [ - ( - ss + dataset_elem["token2char_start"][str(i)], - se + dataset_elem["token2char_start"][str(i)], - ) - for ss, se in tokenization_out.offset_mapping[i] - ] - for i in range(len(dataset_elem["tokens"])) - ] - - offsets_mapping = flatten(offsets_mapping) - - assert len(offsets_mapping) == len(window_tokens) - - window_tokens = ( - [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id] - ) - - topic_offset = 0 - if add_topic: - topic_tokens = tokenizer( - dataset_elem["doc_topic"], add_special_tokens=False - ).input_ids - topic_offset = len(topic_tokens) - new_dataset_elem["topic_tokens"] = topic_offset - window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:] - - new_dataset_elem.update( - dict( - tokens=window_tokens, - token2char_start={ - str(i): s - for i, (s, _) in enumerate(offsets_mapping, start=topic_offset) - }, - token2char_end={ - str(i): e - for i, (_, e) in enumerate(offsets_mapping, start=topic_offset) - }, - window_candidates=dataset_elem["window_candidates"], - window_candidates_scores=dataset_elem.get("window_candidates_scores"), - ) - ) - - if "window_labels" in dataset_elem: - window_labels = [ - (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"] - ] - - new_dataset_elem["window_labels"] = window_labels - - if not all( - [ - s in new_dataset_elem["token2char_start"].values() - for s, _, _ in new_dataset_elem["window_labels"] - ] - ): - print( - "Mismatching token start char mapping with labels", - new_dataset_elem["token2char_start"], - new_dataset_elem["window_labels"], - dataset_elem["tokens"], - ) - continue - - if not all( - [ - e in new_dataset_elem["token2char_end"].values() - for _, e, _ in new_dataset_elem["window_labels"] - ] - ): - print( - "Mismatching token end char mapping with labels", - new_dataset_elem["token2char_end"], - new_dataset_elem["window_labels"], - dataset_elem["tokens"], - ) - continue - - yield new_dataset_elem - - -def preprocess_sample( - relik_sample: RelikReaderSample, - tokenizer, - lowercase_policy: float, - add_topic: bool = False, -) -> None: - if len(relik_sample.tokens) == 0: - return - - if lowercase_policy > 0: - lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy - relik_sample.tokens = [ - t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens) - ] - - tokenization_out = tokenizer( - relik_sample.tokens, - 
return_offsets_mapping=True, - add_special_tokens=False, - ) - - window_tokens = tokenization_out.input_ids - window_tokens = flatten(window_tokens) - - offsets_mapping = [ - [ - ( - ss + relik_sample.token2char_start[str(i)], - se + relik_sample.token2char_start[str(i)], - ) - for ss, se in tokenization_out.offset_mapping[i] - ] - for i in range(len(relik_sample.tokens)) - ] - - offsets_mapping = flatten(offsets_mapping) - - assert len(offsets_mapping) == len(window_tokens) - - window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id] - - topic_offset = 0 - if add_topic: - topic_tokens = tokenizer( - relik_sample.doc_topic, add_special_tokens=False - ).input_ids - topic_offset = len(topic_tokens) - relik_sample.topic_tokens = topic_offset - window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:] - - relik_sample._d.update( - dict( - tokens=window_tokens, - token2char_start={ - str(i): s - for i, (s, _) in enumerate(offsets_mapping, start=topic_offset) - }, - token2char_end={ - str(i): e - for i, (_, e) in enumerate(offsets_mapping, start=topic_offset) - }, - ) - ) - - if "window_labels" in relik_sample._d: - relik_sample.window_labels = [ - (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels - ] - - -class TokenizationOutput(NamedTuple): - input_ids: torch.Tensor - attention_mask: torch.Tensor - token_type_ids: torch.Tensor - prediction_mask: torch.Tensor - special_symbols_mask: torch.Tensor - - -class RelikDataset(IterableDataset): - def __init__( - self, - dataset_path: Optional[str], - materialize_samples: bool, - transformer_model: Union[str, PreTrainedTokenizer], - special_symbols: List[str], - shuffle_candidates: Optional[Union[bool, float]] = False, - for_inference: bool = False, - noise_param: float = 0.1, - sorting_fields: Optional[str] = None, - tokens_per_batch: int = 2048, - batch_size: int = None, - max_batch_size: int = 128, - section_size: int = 50_000, - prebatch: bool = True, - random_drop_gold_candidates: float = 0.0, - use_nme: bool = True, - max_subwords_per_candidate: bool = 22, - mask_by_instances: bool = False, - min_length: int = 5, - max_length: int = 2048, - model_max_length: int = 1000, - split_on_cand_overload: bool = True, - skip_empty_training_samples: bool = False, - drop_last: bool = False, - samples: Optional[Iterator[RelikReaderSample]] = None, - lowercase_policy: float = 0.0, - **kwargs, - ): - super().__init__(**kwargs) - self.dataset_path = dataset_path - self.materialize_samples = materialize_samples - self.samples: Optional[List[RelikReaderSample]] = None - if self.materialize_samples: - self.samples = list() - - if isinstance(transformer_model, str): - self.tokenizer = self._build_tokenizer(transformer_model, special_symbols) - else: - self.tokenizer = transformer_model - self.special_symbols = special_symbols - self.shuffle_candidates = shuffle_candidates - self.for_inference = for_inference - self.noise_param = noise_param - self.batching_fields = ["input_ids"] - self.sorting_fields = ( - sorting_fields if sorting_fields is not None else self.batching_fields - ) - - self.tokens_per_batch = tokens_per_batch - self.batch_size = batch_size - self.max_batch_size = max_batch_size - self.section_size = section_size - self.prebatch = prebatch - - self.random_drop_gold_candidates = random_drop_gold_candidates - self.use_nme = use_nme - self.max_subwords_per_candidate = max_subwords_per_candidate - self.mask_by_instances = mask_by_instances - self.min_length = min_length - self.max_length = 
max_length - self.model_max_length = ( - model_max_length - if model_max_length < self.tokenizer.model_max_length - else self.tokenizer.model_max_length - ) - - # retrocompatibility workaround - self.transformer_model = ( - transformer_model - if isinstance(transformer_model, str) - else transformer_model.name_or_path - ) - self.split_on_cand_overload = split_on_cand_overload - self.skip_empty_training_samples = skip_empty_training_samples - self.drop_last = drop_last - self.lowercase_policy = lowercase_policy - self.samples = samples - - def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]): - return AutoTokenizer.from_pretrained( - transformer_model, - additional_special_tokens=[ss for ss in special_symbols], - add_prefix_space=True, - ) - - @property - def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: - fields_batchers = { - "input_ids": lambda x: batchify( - x, padding_value=self.tokenizer.pad_token_id - ), - "attention_mask": lambda x: batchify(x, padding_value=0), - "token_type_ids": lambda x: batchify(x, padding_value=0), - "prediction_mask": lambda x: batchify(x, padding_value=1), - "global_attention": lambda x: batchify(x, padding_value=0), - "token2word": None, - "sample": None, - "special_symbols_mask": lambda x: batchify(x, padding_value=False), - "start_labels": lambda x: batchify(x, padding_value=-100), - "end_labels": lambda x: batchify(x, padding_value=-100), - "predictable_candidates_symbols": None, - "predictable_candidates": None, - "patch_offset": None, - "optimus_labels": None, - } - - if "roberta" in self.transformer_model: - del fields_batchers["token_type_ids"] - - return fields_batchers - - def _build_input_ids( - self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]] - ) -> List[int]: - return ( - [self.tokenizer.cls_token_id] - + sentence_input_ids - + [self.tokenizer.sep_token_id] - + flatten(candidates_input_ids) - + [self.tokenizer.sep_token_id] - ) - - def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor: - special_symbols_mask = input_ids >= ( - len(self.tokenizer) - len(self.special_symbols) - ) - special_symbols_mask[0] = True - return special_symbols_mask - - def _build_tokenizer_essentials( - self, input_ids, original_sequence, sample - ) -> TokenizationOutput: - input_ids = torch.tensor(input_ids, dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - - total_sequence_len = len(input_ids) - predictable_sentence_len = len(original_sequence) - - # token type ids - token_type_ids = torch.cat( - [ - input_ids.new_zeros( - predictable_sentence_len + 2 - ), # original sentence bpes + CLS and SEP - input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2), - ] - ) - - # prediction mask -> boolean on tokens that are predictable - - prediction_mask = torch.tensor( - [1] - + ([0] * predictable_sentence_len) - + ([1] * (total_sequence_len - predictable_sentence_len - 1)) - ) - - # add topic tokens to the prediction mask so that they cannot be predicted - # or optimized during training - topic_tokens = getattr(sample, "topic_tokens", None) - if topic_tokens is not None: - prediction_mask[1 : 1 + topic_tokens] = 1 - - # If mask by instances is active the prediction mask is applied to everything - # that is not indicated as an instance in the training set. 
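# NOTE: illustrative layout of the mask built above (1 = the position is masked
# out and never labeled, 0 = the position can receive a span label):
#   input ids:        [CLS] tok_1 ... tok_n [SEP] cand_1 ... cand_m [SEP]
#   prediction_mask:    1     0   ...   0     1      1   ...    1     1
# Topic tokens, if present, are also set to 1. The mask_by_instances branch
# below additionally masks every token that does not fall inside a span listed
# in sample.instance_id2span_data.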
- if self.mask_by_instances: - char_start2token = { - cs: int(tok) for tok, cs in sample.token2char_start.items() - } - char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()} - instances_mask = torch.ones_like(prediction_mask) - for _, span_info in sample.instance_id2span_data.items(): - span_info = span_info[0] - token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS - token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS - instances_mask[token_start : token_end + 1] = 0 - - prediction_mask += instances_mask - prediction_mask[prediction_mask > 1] = 1 - - assert len(prediction_mask) == len(input_ids) - - # special symbols mask - special_symbols_mask = self._get_special_symbols_mask(input_ids) - - return TokenizationOutput( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - ) - - def _build_labels( - self, - sample, - tokenization_output: TokenizationOutput, - predictable_candidates: List[str], - ) -> Tuple[torch.Tensor, torch.Tensor]: - start_labels = [0] * len(tokenization_output.input_ids) - end_labels = [0] * len(tokenization_output.input_ids) - - char_start2token = {v: int(k) for k, v in sample.token2char_start.items()} - char_end2token = {v: int(k) for k, v in sample.token2char_end.items()} - for cs, ce, gold_candidate_title in sample.window_labels: - if gold_candidate_title not in predictable_candidates: - if self.use_nme: - gold_candidate_title = NME_SYMBOL - else: - continue - # +1 is to account for the CLS token - start_bpe = char_start2token[cs] + 1 - end_bpe = char_end2token[ce] + 1 - class_index = predictable_candidates.index(gold_candidate_title) - if ( - start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0 - ): # prevent from having entities that ends with the same label - start_labels[start_bpe] = class_index + 1 # +1 for the NONE class - end_labels[end_bpe] = class_index + 1 # +1 for the NONE class - else: - print( - "Found entity with the same last subword, it will not be included." 
- ) - print( - cs, - ce, - gold_candidate_title, - start_labels, - end_labels, - sample.doc_id, - ) - - ignored_labels_indices = tokenization_output.prediction_mask == 1 - - start_labels = torch.tensor(start_labels, dtype=torch.long) - start_labels[ignored_labels_indices] = -100 - - end_labels = torch.tensor(end_labels, dtype=torch.long) - end_labels[ignored_labels_indices] = -100 - - return start_labels, end_labels - - def produce_sample_bag( - self, sample, predictable_candidates: List[str], candidates_starting_offset: int - ) -> Optional[Tuple[dict, list, int]]: - # input sentence tokenization - input_subwords = sample.tokens[1:-1] # removing special tokens - candidates_symbols = self.special_symbols[candidates_starting_offset:] - - predictable_candidates = list(predictable_candidates) - original_predictable_candidates = list(predictable_candidates) - - # add NME as a possible candidate - if self.use_nme: - predictable_candidates.insert(0, NME_SYMBOL) - - # candidates encoding - candidates_symbols = candidates_symbols[: len(predictable_candidates)] - candidates_encoding_result = self.tokenizer.batch_encode_plus( - [ - "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL - for cs, ct in zip(candidates_symbols, predictable_candidates) - ], - add_special_tokens=False, - ).input_ids - - if ( - self.max_subwords_per_candidate is not None - and self.max_subwords_per_candidate > 0 - ): - candidates_encoding_result = [ - cer[: self.max_subwords_per_candidate] - for cer in candidates_encoding_result - ] - - # drop candidates if the number of input tokens is too long for the model - if ( - sum(map(len, candidates_encoding_result)) - + len(input_subwords) - + 20 # + 20 special tokens - > self.model_max_length - ): - acceptable_tokens_from_candidates = ( - self.model_max_length - 20 - len(input_subwords) - ) - i = 0 - cum_len = 0 - while ( - cum_len + len(candidates_encoding_result[i]) - < acceptable_tokens_from_candidates - ): - cum_len += len(candidates_encoding_result[i]) - i += 1 - - candidates_encoding_result = candidates_encoding_result[:i] - candidates_symbols = candidates_symbols[:i] - predictable_candidates = predictable_candidates[:i] - - # final input_ids build - input_ids = self._build_input_ids( - sentence_input_ids=input_subwords, - candidates_input_ids=candidates_encoding_result, - ) - - # complete input building (e.g. 
attention / prediction mask) - tokenization_output = self._build_tokenizer_essentials( - input_ids, input_subwords, sample - ) - - output_dict = { - "input_ids": tokenization_output.input_ids, - "attention_mask": tokenization_output.attention_mask, - "token_type_ids": tokenization_output.token_type_ids, - "prediction_mask": tokenization_output.prediction_mask, - "special_symbols_mask": tokenization_output.special_symbols_mask, - "sample": sample, - "predictable_candidates_symbols": candidates_symbols, - "predictable_candidates": predictable_candidates, - } - - # labels creation - if sample.window_labels is not None: - start_labels, end_labels = self._build_labels( - sample, - tokenization_output, - predictable_candidates, - ) - output_dict.update(start_labels=start_labels, end_labels=end_labels) - - if ( - "roberta" in self.transformer_model - or "longformer" in self.transformer_model - ): - del output_dict["token_type_ids"] - - predictable_candidates_set = set(predictable_candidates) - remaining_candidates = [ - candidate - for candidate in original_predictable_candidates - if candidate not in predictable_candidates_set - ] - total_used_candidates = ( - candidates_starting_offset - + len(predictable_candidates) - - (1 if self.use_nme else 0) - ) - - if self.use_nme: - assert predictable_candidates[0] == NME_SYMBOL - - return output_dict, remaining_candidates, total_used_candidates - - def __iter__(self): - dataset_iterator = self.dataset_iterator_func() - - current_dataset_elements = [] - - i = None - for i, dataset_elem in enumerate(dataset_iterator, start=1): - if ( - self.section_size is not None - and len(current_dataset_elements) == self.section_size - ): - for batch in self.materialize_batches(current_dataset_elements): - yield batch - current_dataset_elements = [] - - current_dataset_elements.append(dataset_elem) - - if i % 50_000 == 0: - logger.info(f"Processed: {i} number of elements") - - if len(current_dataset_elements) != 0: - for batch in self.materialize_batches(current_dataset_elements): - yield batch - - if i is not None: - logger.info(f"Dataset finished: {i} number of elements processed") - else: - logger.warning("Dataset empty") - - def dataset_iterator_func(self): - skipped_instances = 0 - data_samples = ( - load_relik_reader_samples(self.dataset_path) - if self.samples is None - else self.samples - ) - for sample in data_samples: - preprocess_sample( - sample, self.tokenizer, lowercase_policy=self.lowercase_policy - ) - current_patch = 0 - sample_bag, used_candidates = None, None - remaining_candidates = list(sample.window_candidates) - - if not self.for_inference: - # randomly drop gold candidates at training time - if ( - self.random_drop_gold_candidates > 0.0 - and np.random.uniform() < self.random_drop_gold_candidates - and len(set(ct for _, _, ct in sample.window_labels)) > 1 - ): - # selecting candidates to drop - np.random.shuffle(sample.window_labels) - n_dropped_candidates = np.random.randint( - 0, len(sample.window_labels) - 1 - ) - dropped_candidates = [ - label_elem[-1] - for label_elem in sample.window_labels[:n_dropped_candidates] - ] - dropped_candidates = set(dropped_candidates) - - # saving NMEs because they should not be dropped - if NME_SYMBOL in dropped_candidates: - dropped_candidates.remove(NME_SYMBOL) - - # sample update - sample.window_labels = [ - (s, e, _l) - if _l not in dropped_candidates - else (s, e, NME_SYMBOL) - for s, e, _l in sample.window_labels - ] - remaining_candidates = [ - wc - for wc in remaining_candidates - if wc not in 
dropped_candidates - ] - - # shuffle candidates - if ( - isinstance(self.shuffle_candidates, bool) - and self.shuffle_candidates - ) or ( - isinstance(self.shuffle_candidates, float) - and np.random.uniform() < self.shuffle_candidates - ): - np.random.shuffle(remaining_candidates) - - while len(remaining_candidates) != 0: - sample_bag = self.produce_sample_bag( - sample, - predictable_candidates=remaining_candidates, - candidates_starting_offset=used_candidates - if used_candidates is not None - else 0, - ) - if sample_bag is not None: - sample_bag, remaining_candidates, used_candidates = sample_bag - if ( - self.for_inference - or not self.skip_empty_training_samples - or ( - ( - sample_bag.get("start_labels") is not None - and torch.any(sample_bag["start_labels"] > 1).item() - ) - or ( - sample_bag.get("optimus_labels") is not None - and len(sample_bag["optimus_labels"]) > 0 - ) - ) - ): - sample_bag["patch_offset"] = current_patch - current_patch += 1 - yield sample_bag - else: - skipped_instances += 1 - if skipped_instances % 1000 == 0 and skipped_instances != 0: - logger.info( - f"Skipped {skipped_instances} instances since they did not have any gold labels..." - ) - - # Just use the first fitting candidates if split on - # cand is not True - if not self.split_on_cand_overload: - break - - def preshuffle_elements(self, dataset_elements: List): - # This shuffling is done so that when using the sorting function, - # if it is deterministic given a collection and its order, we will - # make the whole operation not deterministic anymore. - # Basically, the aim is not to build every time the same batches. - if not self.for_inference: - dataset_elements = np.random.permutation(dataset_elements) - - sorting_fn = ( - lambda elem: add_noise_to_value( - sum(len(elem[k]) for k in self.sorting_fields), - noise_param=self.noise_param, - ) - if not self.for_inference - else sum(len(elem[k]) for k in self.sorting_fields) - ) - - dataset_elements = sorted(dataset_elements, key=sorting_fn) - - if self.for_inference: - return dataset_elements - - ds = list(chunks(dataset_elements, 64)) - np.random.shuffle(ds) - return flatten(ds) - - def materialize_batches( - self, dataset_elements: List[Dict[str, Any]] - ) -> Generator[Dict[str, Any], None, None]: - if self.prebatch: - dataset_elements = self.preshuffle_elements(dataset_elements) - - current_batch = [] - - # function that creates a batch from the 'current_batch' list - def output_batch() -> Dict[str, Any]: - assert ( - len( - set([len(elem["predictable_candidates"]) for elem in current_batch]) - ) - == 1 - ), " ".join( - map( - str, [len(elem["predictable_candidates"]) for elem in current_batch] - ) - ) - - batch_dict = dict() - - de_values_by_field = { - fn: [de[fn] for de in current_batch if fn in de] - for fn in self.fields_batcher - } - - # in case you provide fields batchers but in the batch - # there are no elements for that field - de_values_by_field = { - fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 - } - - assert len(set([len(v) for v in de_values_by_field.values()])) - - # todo: maybe we should report the user about possible - # fields filtering due to "None" instances - de_values_by_field = { - fn: fvs - for fn, fvs in de_values_by_field.items() - if all([fv is not None for fv in fvs]) - } - - for field_name, field_values in de_values_by_field.items(): - field_batch = ( - self.fields_batcher[field_name](field_values) - if self.fields_batcher[field_name] is not None - else field_values - ) - - batch_dict[field_name] = 
field_batch - - return batch_dict - - max_len_discards, min_len_discards = 0, 0 - - should_token_batch = self.batch_size is None - - curr_pred_elements = -1 - for de in dataset_elements: - if ( - should_token_batch - and self.max_batch_size != -1 - and len(current_batch) == self.max_batch_size - ) or (not should_token_batch and len(current_batch) == self.batch_size): - yield output_batch() - current_batch = [] - curr_pred_elements = -1 - - too_long_fields = [ - k - for k in de - if self.max_length != -1 - and torch.is_tensor(de[k]) - and len(de[k]) > self.max_length - ] - if len(too_long_fields) > 0: - max_len_discards += 1 - continue - - too_short_fields = [ - k - for k in de - if self.min_length != -1 - and torch.is_tensor(de[k]) - and len(de[k]) < self.min_length - ] - if len(too_short_fields) > 0: - min_len_discards += 1 - continue - - if should_token_batch: - de_len = sum(len(de[k]) for k in self.batching_fields) - - future_max_len = max( - de_len, - max( - [ - sum(len(bde[k]) for k in self.batching_fields) - for bde in current_batch - ], - default=0, - ), - ) - - future_tokens_per_batch = future_max_len * (len(current_batch) + 1) - - num_predictable_candidates = len(de["predictable_candidates"]) - - if len(current_batch) > 0 and ( - future_tokens_per_batch >= self.tokens_per_batch - or ( - num_predictable_candidates != curr_pred_elements - and curr_pred_elements != -1 - ) - ): - yield output_batch() - current_batch = [] - - current_batch.append(de) - curr_pred_elements = len(de["predictable_candidates"]) - - if len(current_batch) != 0 and not self.drop_last: - yield output_batch() - - if max_len_discards > 0: - if self.for_inference: - logger.warning( - f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were " - f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" - f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the " - f"sample length exceeds the maximum length supported by the current model." - ) - else: - logger.warning( - f"During iteration, {max_len_discards} elements were " - f"discarded since longer than max length {self.max_length}" - ) - - if min_len_discards > 0: - if self.for_inference: - logger.warning( - f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were " - f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" - f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the " - f"sample length is shorter than the minimum length supported by the current model." - ) - else: - logger.warning( - f"During iteration, {min_len_discards} elements were " - f"discarded since shorter than min length {self.min_length}" - ) - - @staticmethod - def convert_tokens_to_char_annotations( - sample: RelikReaderSample, - remove_nmes: bool = True, - ) -> RelikReaderSample: - """ - Converts the token annotations to char annotations. - - Args: - sample (:obj:`RelikReaderSample`): - The sample to convert. - remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether to remove the NMEs from the annotations. - Returns: - :obj:`RelikReaderSample`: The converted sample. 
- """ - char_annotations = set() - for ( - predicted_entity, - predicted_spans, - ) in sample.predicted_window_labels.items(): - if predicted_entity == NME_SYMBOL and remove_nmes: - continue - - for span_start, span_end in predicted_spans: - span_start = sample.token2char_start[str(span_start)] - span_end = sample.token2char_end[str(span_end)] - - char_annotations.add((span_start, span_end, predicted_entity)) - - char_probs_annotations = dict() - for ( - span_start, - span_end, - ), candidates_probs in sample.span_title_probabilities.items(): - span_start = sample.token2char_start[str(span_start)] - span_end = sample.token2char_end[str(span_end)] - char_probs_annotations[(span_start, span_end)] = { - title for title, _ in candidates_probs - } - - sample.predicted_window_labels_chars = char_annotations - sample.probs_window_labels_chars = char_probs_annotations - - return sample - - @staticmethod - def merge_patches_predictions(sample) -> None: - sample._d["predicted_window_labels"] = dict() - predicted_window_labels = sample._d["predicted_window_labels"] - - sample._d["span_title_probabilities"] = dict() - span_title_probabilities = sample._d["span_title_probabilities"] - - span2title = dict() - for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): - # selecting span predictions - for predicted_title, predicted_spans in patch_info[ - "predicted_window_labels" - ].items(): - for pred_span in predicted_spans: - pred_span = tuple(pred_span) - curr_title = span2title.get(pred_span) - if curr_title is None or curr_title == NME_SYMBOL: - span2title[pred_span] = predicted_title - # else: - # print("Merging at patch level") - - # selecting span predictions probability - for predicted_span, titles_probabilities in patch_info[ - "span_title_probabilities" - ].items(): - if predicted_span not in span_title_probabilities: - span_title_probabilities[predicted_span] = titles_probabilities - - for span, title in span2title.items(): - if title not in predicted_window_labels: - predicted_window_labels[title] = list() - predicted_window_labels[title].append(span) diff --git a/relik/reader/data/relik_reader_data_utils.py b/relik/reader/data/relik_reader_data_utils.py deleted file mode 100644 index 3c7446bee296d14653a35895bf9ec8071c87e5af..0000000000000000000000000000000000000000 --- a/relik/reader/data/relik_reader_data_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List - -import numpy as np -import torch - - -def flatten(lsts: List[list]) -> list: - acc_lst = list() - for lst in lsts: - acc_lst.extend(lst) - return acc_lst - - -def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor: - return torch.nn.utils.rnn.pad_sequence( - tensors, batch_first=True, padding_value=padding_value - ) - - -def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: - x = max([t.shape[0] for t in tensors]) - y = max([t.shape[1] for t in tensors]) - out_matrix = torch.zeros((len(tensors), x, y)) - out_matrix += padding_value - for i, tensor in enumerate(tensors): - out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor - return out_matrix - - -def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: - x = max([t.shape[0] for t in tensors]) - y = max([t.shape[1] for t in tensors]) - rest = tensors[0].shape[2] - out_matrix = torch.zeros((len(tensors), x, y, rest)) - out_matrix += padding_value - for i, tensor in enumerate(tensors): - out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor - 
return out_matrix - - -def chunks(lst: list, chunk_size: int) -> List[list]: - chunks_acc = list() - for i in range(0, len(lst), chunk_size): - chunks_acc.append(lst[i : i + chunk_size]) - return chunks_acc - - -def add_noise_to_value(value: int, noise_param: float): - noise_value = value * noise_param - noise = np.random.uniform(-noise_value, noise_value) - return max(1, value + noise) diff --git a/relik/reader/data/relik_reader_sample.py b/relik/reader/data/relik_reader_sample.py deleted file mode 100644 index 3d7570411fbb939f99d73d1cc3318b21552bc7c2..0000000000000000000000000000000000000000 --- a/relik/reader/data/relik_reader_sample.py +++ /dev/null @@ -1,49 +0,0 @@ -import json -from typing import Iterable - - -class RelikReaderSample: - def __init__(self, **kwargs): - super().__setattr__("_d", {}) - self._d = kwargs - - def __getattribute__(self, item): - return super(RelikReaderSample, self).__getattribute__(item) - - def __getattr__(self, item): - if item.startswith("__") and item.endswith("__"): - # this is likely some python library-specific variable (such as __deepcopy__ for copy) - # better follow standard behavior here - raise AttributeError(item) - elif item in self._d: - return self._d[item] - else: - return None - - def __setattr__(self, key, value): - if key in self._d: - self._d[key] = value - else: - super().__setattr__(key, value) - - def to_jsons(self) -> str: - if "predicted_window_labels" in self._d: - new_obj = { - k: v - for k, v in self._d.items() - if k != "predicted_window_labels" and k != "span_title_probabilities" - } - new_obj["predicted_window_labels"] = [ - [ss, se, pred_title] - for (ss, se), pred_title in self.predicted_window_labels_chars - ] - else: - return json.dumps(self._d) - - -def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]: - with open(path) as f: - for line in f: - jsonl_line = json.loads(line.strip()) - relik_reader_sample = RelikReaderSample(**jsonl_line) - yield relik_reader_sample diff --git a/relik/reader/lightning_modules/__init__.py b/relik/reader/lightning_modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/reader/lightning_modules/relik_reader_pl_module.py b/relik/reader/lightning_modules/relik_reader_pl_module.py deleted file mode 100644 index 4e66e87b6360fe9b4a72477fd5a7fe6295b53ae9..0000000000000000000000000000000000000000 --- a/relik/reader/lightning_modules/relik_reader_pl_module.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any, Optional - -import lightning -from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler - -from relik.reader.relik_reader_core import RelikReaderCoreModel - - -class RelikReaderPLModule(lightning.LightningModule): - def __init__( - self, - cfg: dict, - transformer_model: str, - additional_special_symbols: int, - num_layers: Optional[int] = None, - activation: str = "gelu", - linears_hidden_size: Optional[int] = 512, - use_last_k_layers: int = 1, - training: bool = False, - *args: Any, - **kwargs: Any - ): - super().__init__(*args, **kwargs) - self.save_hyperparameters() - self.relik_reader_core_model = RelikReaderCoreModel( - transformer_model, - additional_special_symbols, - num_layers, - activation, - linears_hidden_size, - use_last_k_layers, - training=training, - ) - self.optimizer_factory = None - - def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - relik_output = self.relik_reader_core_model(**batch) - 
self.log("train-loss", relik_output["loss"]) - return relik_output["loss"] - - def validation_step( - self, batch: dict, *args: Any, **kwargs: Any - ) -> Optional[STEP_OUTPUT]: - return - - def set_optimizer_factory(self, optimizer_factory) -> None: - self.optimizer_factory = optimizer_factory - - def configure_optimizers(self) -> OptimizerLRScheduler: - return self.optimizer_factory(self.relik_reader_core_model) diff --git a/relik/reader/lightning_modules/relik_reader_re_pl_module.py b/relik/reader/lightning_modules/relik_reader_re_pl_module.py deleted file mode 100644 index ad1d2084f10d700d68b30e68dc29cbace1450f9b..0000000000000000000000000000000000000000 --- a/relik/reader/lightning_modules/relik_reader_re_pl_module.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Any, Optional - -import lightning -from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler - -from relik.reader.relik_reader_re import RelikReaderForTripletExtraction - - -class RelikReaderREPLModule(lightning.LightningModule): - def __init__( - self, - cfg: dict, - transformer_model: str, - additional_special_symbols: int, - num_layers: Optional[int] = None, - activation: str = "gelu", - linears_hidden_size: Optional[int] = 512, - use_last_k_layers: int = 1, - training: bool = False, - *args: Any, - **kwargs: Any - ): - super().__init__(*args, **kwargs) - self.save_hyperparameters() - - self.relik_reader_re_model = RelikReaderForTripletExtraction( - transformer_model, - additional_special_symbols, - num_layers, - activation, - linears_hidden_size, - use_last_k_layers, - training=training, - ) - self.optimizer_factory = None - - def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - relik_output = self.relik_reader_re_model(**batch) - self.log("train-loss", relik_output["loss"]) - self.log("train-start_loss", relik_output["ned_start_loss"]) - self.log("train-end_loss", relik_output["ned_end_loss"]) - self.log("train-relation_loss", relik_output["re_loss"]) - return relik_output["loss"] - - def validation_step( - self, batch: dict, *args: Any, **kwargs: Any - ) -> Optional[STEP_OUTPUT]: - return - - def set_optimizer_factory(self, optimizer_factory) -> None: - self.optimizer_factory = optimizer_factory - - def configure_optimizers(self) -> OptimizerLRScheduler: - return self.optimizer_factory(self.relik_reader_re_model) diff --git a/relik/reader/pytorch_modules/__init__.py b/relik/reader/pytorch_modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/reader/pytorch_modules/base.py b/relik/reader/pytorch_modules/base.py deleted file mode 100644 index 75db716d53d6dbdb9e9f95b63dfa1d8619769bbf..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/base.py +++ /dev/null @@ -1,248 +0,0 @@ -import logging -import os -from pathlib import Path -from typing import Any, Dict, List - -import torch -import transformers as tr -from torch.utils.data import IterableDataset -from transformers import AutoConfig - -from relik.common.log import get_console_logger, get_logger -from relik.common.utils import get_callable_from_string -from relik.reader.pytorch_modules.hf.modeling_relik import ( - RelikReaderConfig, - RelikReaderSample, -) - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class RelikReaderBase(torch.nn.Module): - default_reader_class: str | None = None - default_data_class: str | None = None - - def __init__( - self, 
- transformer_model: str | tr.PreTrainedModel | None = None, - additional_special_symbols: int = 0, - num_layers: int | None = None, - activation: str = "gelu", - linears_hidden_size: int | None = 512, - use_last_k_layers: int = 1, - training: bool = False, - device: str | torch.device | None = None, - precision: int = 32, - tokenizer: str | tr.PreTrainedTokenizer | None = None, - dataset: IterableDataset | str | None = None, - default_reader_class: tr.PreTrainedModel | str | None = None, - **kwargs, - ) -> None: - super().__init__() - - self.default_reader_class = default_reader_class or self.default_reader_class - - if self.default_reader_class is None: - raise ValueError("You must specify a default reader class.") - - # get the callable for the default reader class - self.default_reader_class: tr.PreTrainedModel = get_callable_from_string( - self.default_reader_class - ) - - if isinstance(transformer_model, str): - config = AutoConfig.from_pretrained( - transformer_model, trust_remote_code=True - ) - if "relik-reader" in config.model_type: - transformer_model = self.default_reader_class.from_pretrained( - transformer_model, **kwargs - ) - else: - reader_config = RelikReaderConfig( - transformer_model=transformer_model, - additional_special_symbols=additional_special_symbols, - num_layers=num_layers, - activation=activation, - linears_hidden_size=linears_hidden_size, - use_last_k_layers=use_last_k_layers, - training=training, - ) - transformer_model = self.default_reader_class(reader_config) - - self.relik_reader_model = transformer_model - self.relik_reader_model_config = self.relik_reader_model.config - - # get the tokenizer - self._tokenizer = tokenizer - - # and instantiate the dataset class - self.dataset: IterableDataset | None = dataset - - # move the model to the device - self.to(device or torch.device("cpu")) - - # set the precision - self.precision = precision - - def forward(self, **kwargs) -> Dict[str, Any]: - return self.relik_reader_model(**kwargs) - - def _read(self, *args, **kwargs) -> Any: - raise NotImplementedError - - @torch.no_grad() - @torch.inference_mode() - def read( - self, - text: List[str] | List[List[str]] | None = None, - samples: List[RelikReaderSample] | None = None, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - token_type_ids: torch.Tensor | None = None, - prediction_mask: torch.Tensor | None = None, - special_symbols_mask: torch.Tensor | None = None, - candidates: List[List[str]] | None = None, - max_length: int = 1000, - max_batch_size: int = 128, - token_batch_size: int = 2048, - precision: int | str | None = None, - progress_bar: bool = False, - *args, - **kwargs, - ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: - """ - Reads the given text. - - Args: - text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): - The text to read in tokens. If a list of list of tokens is provided, each - inner list is considered a sentence. - samples (:obj:`List[RelikReaderSample]`, `optional`): - The samples to read. If provided, `text` and `candidates` are ignored. - input_ids (:obj:`torch.Tensor`, `optional`): - The input ids of the text. - attention_mask (:obj:`torch.Tensor`, `optional`): - The attention mask of the text. - token_type_ids (:obj:`torch.Tensor`, `optional`): - The token type ids of the text. - prediction_mask (:obj:`torch.Tensor`, `optional`): - The prediction mask of the text. - special_symbols_mask (:obj:`torch.Tensor`, `optional`): - The special symbols mask of the text. 
- candidates (:obj:`List[List[str]]`, `optional`): - The candidates of the text. - max_length (:obj:`int`, `optional`, defaults to 1024): - The maximum length of the text. - max_batch_size (:obj:`int`, `optional`, defaults to 128): - The maximum batch size. - token_batch_size (:obj:`int`, `optional`): - The maximum number of tokens per batch. - precision (:obj:`int` or :obj:`str`, `optional`): - The precision to use. If not provided, the default is 32 bit. - progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to show a progress bar. - - Returns: - The predicted labels for each sample. - """ - if text is None and input_ids is None and samples is None: - raise ValueError( - "Either `text` or `input_ids` or `samples` must be provided." - ) - if (input_ids is None and samples is None) and ( - text is None or candidates is None - ): - raise ValueError( - "`text` and `candidates` must be provided to return the predictions when " - "`input_ids` and `samples` is not provided." - ) - if text is not None and samples is None: - if len(text) != len(candidates): - raise ValueError("`text` and `candidates` must have the same length.") - if isinstance(text[0], str): # change to list of text - text = [text] - candidates = [candidates] - - samples = [ - RelikReaderSample(tokens=t, candidates=c) - for t, c in zip(text, candidates) - ] - - return self._read( - samples, - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - max_length, - max_batch_size, - token_batch_size, - precision or self.precision, - progress_bar, - *args, - **kwargs, - ) - - @property - def device(self) -> torch.device: - """ - The device of the model. - """ - return next(self.parameters()).device - - @property - def tokenizer(self) -> tr.PreTrainedTokenizer: - """ - The tokenizer. - """ - if self._tokenizer: - return self._tokenizer - - self._tokenizer = tr.AutoTokenizer.from_pretrained( - self.relik_reader_model.config.name_or_path - ) - return self._tokenizer - - def save_pretrained( - self, - output_dir: str | os.PathLike, - model_name: str | None = None, - push_to_hub: bool = False, - **kwargs, - ) -> None: - """ - Saves the model to the given path. - - Args: - output_dir (`str` or :obj:`os.PathLike`): - The path to save the model to. - model_name (`str`, `optional`): - The name of the model. If not provided, the model will be saved as - `default_reader_class.__name__`. - push_to_hub (`bool`, `optional`, defaults to `False`): - Whether to push the model to the HuggingFace Hub. 
- **kwargs: - Additional keyword arguments to pass to the `save_pretrained` method - """ - # create the output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - model_name = model_name or self.default_reader_class.__name__ - - logger.info(f"Saving reader to {output_dir / model_name}") - - # save the model - self.relik_reader_model.register_for_auto_class() - self.relik_reader_model.save_pretrained( - output_dir / model_name, push_to_hub=push_to_hub, **kwargs - ) - - if self.tokenizer: - logger.info("Saving also the tokenizer") - self.tokenizer.save_pretrained( - output_dir / model_name, push_to_hub=push_to_hub, **kwargs - ) diff --git a/relik/reader/pytorch_modules/hf/__init__.py b/relik/reader/pytorch_modules/hf/__init__.py deleted file mode 100644 index c9c158e6ab6dcd3ab43e60751218600fbb0a5ed5..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/hf/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .configuration_relik import RelikReaderConfig -from .modeling_relik import RelikReaderREModel diff --git a/relik/reader/pytorch_modules/hf/configuration_relik.py b/relik/reader/pytorch_modules/hf/configuration_relik.py deleted file mode 100644 index 6683823926b4b09a5ad169ef4e0f5b92061d774e..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/hf/configuration_relik.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -from transformers import AutoConfig -from transformers.configuration_utils import PretrainedConfig - - -class RelikReaderConfig(PretrainedConfig): - model_type = "relik-reader" - - def __init__( - self, - transformer_model: str = "microsoft/deberta-v3-base", - additional_special_symbols: int = 101, - num_layers: Optional[int] = None, - activation: str = "gelu", - linears_hidden_size: Optional[int] = 512, - use_last_k_layers: int = 1, - training: bool = False, - default_reader_class: Optional[str] = None, - **kwargs - ) -> None: - self.transformer_model = transformer_model - self.additional_special_symbols = additional_special_symbols - self.num_layers = num_layers - self.activation = activation - self.linears_hidden_size = linears_hidden_size - self.use_last_k_layers = use_last_k_layers - self.training = training - self.default_reader_class = default_reader_class - super().__init__(**kwargs) - - -AutoConfig.register("relik-reader", RelikReaderConfig) diff --git a/relik/reader/pytorch_modules/hf/modeling_relik.py b/relik/reader/pytorch_modules/hf/modeling_relik.py deleted file mode 100644 index f79fc14e0cabe9f830187467578ff3f65351c9a2..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/hf/modeling_relik.py +++ /dev/null @@ -1,981 +0,0 @@ -from typing import Any, Dict, Optional - -import torch -from transformers import AutoModel, PreTrainedModel -from transformers.activations import ClippedGELUActivation, GELUActivation -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_utils import PoolerEndLogits - -from .configuration_relik import RelikReaderConfig - - -class RelikReaderSample: - def __init__(self, **kwargs): - super().__setattr__("_d", {}) - self._d = kwargs - - def __getattribute__(self, item): - return super(RelikReaderSample, self).__getattribute__(item) - - def __getattr__(self, item): - if item.startswith("__") and item.endswith("__"): - # this is likely some python library-specific variable (such as __deepcopy__ for copy) - # better follow standard behavior here - raise AttributeError(item) - elif item in 
self._d: - return self._d[item] - else: - return None - - def __setattr__(self, key, value): - if key in self._d: - self._d[key] = value - else: - super().__setattr__(key, value) - - -activation2functions = { - "relu": torch.nn.ReLU(), - "gelu": GELUActivation(), - "gelu_10": ClippedGELUActivation(-10, 10), -} - - -class PoolerEndLogitsBi(PoolerEndLogits): - def __init__(self, config: PretrainedConfig): - super().__init__(config) - self.dense_1 = torch.nn.Linear(config.hidden_size, 2) - - def forward( - self, - hidden_states: torch.FloatTensor, - start_states: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - p_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - if p_mask is not None: - p_mask = p_mask.unsqueeze(-1) - logits = super().forward( - hidden_states, - start_states, - start_positions, - p_mask, - ) - return logits - - -class RelikReaderSpanModel(PreTrainedModel): - config_class = RelikReaderConfig - - def __init__(self, config: RelikReaderConfig, *args, **kwargs): - super().__init__(config) - # Transformer model declaration - self.config = config - self.transformer_model = ( - AutoModel.from_pretrained(self.config.transformer_model) - if self.config.num_layers is None - else AutoModel.from_pretrained( - self.config.transformer_model, num_hidden_layers=self.config.num_layers - ) - ) - self.transformer_model.resize_token_embeddings( - self.transformer_model.config.vocab_size - + self.config.additional_special_symbols - ) - - self.activation = self.config.activation - self.linears_hidden_size = self.config.linears_hidden_size - self.use_last_k_layers = self.config.use_last_k_layers - - # named entity detection layers - self.ned_start_classifier = self._get_projection_layer( - self.activation, last_hidden=2, layer_norm=False - ) - self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) - - # END entity disambiguation layer - self.ed_start_projector = self._get_projection_layer(self.activation) - self.ed_end_projector = self._get_projection_layer(self.activation) - - self.training = self.config.training - - # criterion - self.criterion = torch.nn.CrossEntropyLoss() - - def _get_projection_layer( - self, - activation: str, - last_hidden: Optional[int] = None, - input_hidden=None, - layer_norm: bool = True, - ) -> torch.nn.Sequential: - head_components = [ - torch.nn.Dropout(0.1), - torch.nn.Linear( - self.transformer_model.config.hidden_size * self.use_last_k_layers - if input_hidden is None - else input_hidden, - self.linears_hidden_size, - ), - activation2functions[activation], - torch.nn.Dropout(0.1), - torch.nn.Linear( - self.linears_hidden_size, - self.linears_hidden_size if last_hidden is None else last_hidden, - ), - ] - - if layer_norm: - head_components.append( - torch.nn.LayerNorm( - self.linears_hidden_size if last_hidden is None else last_hidden, - self.transformer_model.config.layer_norm_eps, - ) - ) - - return torch.nn.Sequential(*head_components) - - def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - mask = mask.unsqueeze(-1) - if next(self.parameters()).dtype == torch.float16: - logits = logits * (1 - mask) - 65500 * mask - else: - logits = logits * (1 - mask) - 1e30 * mask - return logits - - def _get_model_features( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor], - ): - model_input = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "output_hidden_states": self.use_last_k_layers > 1, 
- } - - if token_type_ids is not None: - model_input["token_type_ids"] = token_type_ids - - model_output = self.transformer_model(**model_input) - - if self.use_last_k_layers > 1: - model_features = torch.cat( - model_output[1][-self.use_last_k_layers :], dim=-1 - ) - else: - model_features = model_output[0] - - return model_features - - def compute_ned_end_logits( - self, - start_predictions, - start_labels, - model_features, - prediction_mask, - batch_size, - ) -> Optional[torch.Tensor]: - # todo: maybe when constraining on the spans, - # we should not use a prediction_mask for the end tokens. - # at least we should not during training imo - start_positions = start_labels if self.training else start_predictions - start_positions_indices = ( - torch.arange(start_positions.size(1), device=start_positions.device) - .unsqueeze(0) - .expand(batch_size, -1)[start_positions > 0] - ).to(start_positions.device) - - if len(start_positions_indices) > 0: - expanded_features = torch.cat( - [ - model_features[i].unsqueeze(0).expand(x, -1, -1) - for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) - if x > 0 - ], - dim=0, - ).to(start_positions_indices.device) - - expanded_prediction_mask = torch.cat( - [ - prediction_mask[i].unsqueeze(0).expand(x, -1) - for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) - if x > 0 - ], - dim=0, - ).to(expanded_features.device) - - end_logits = self.ned_end_classifier( - hidden_states=expanded_features, - start_positions=start_positions_indices, - p_mask=expanded_prediction_mask, - ) - - return end_logits - - return None - - def compute_classification_logits( - self, - model_features, - special_symbols_mask, - prediction_mask, - batch_size, - start_positions=None, - end_positions=None, - ) -> torch.Tensor: - if start_positions is None or end_positions is None: - start_positions = torch.zeros_like(prediction_mask) - end_positions = torch.zeros_like(prediction_mask) - - model_start_features = self.ed_start_projector(model_features) - model_end_features = self.ed_end_projector(model_features) - model_end_features[start_positions > 0] = model_end_features[end_positions > 0] - - model_ed_features = torch.cat( - [model_start_features, model_end_features], dim=-1 - ) - - # computing ed features - classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item() - special_symbols_representation = model_ed_features[special_symbols_mask].view( - batch_size, classes_representations, -1 - ) - - logits = torch.bmm( - model_ed_features, - torch.permute(special_symbols_representation, (0, 2, 1)), - ) - - logits = self._mask_logits(logits, prediction_mask) - - return logits - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - prediction_mask: Optional[torch.Tensor] = None, - special_symbols_mask: Optional[torch.Tensor] = None, - start_labels: Optional[torch.Tensor] = None, - end_labels: Optional[torch.Tensor] = None, - use_predefined_spans: bool = False, - *args, - **kwargs, - ) -> Dict[str, Any]: - batch_size, seq_len = input_ids.shape - - model_features = self._get_model_features( - input_ids, attention_mask, token_type_ids - ) - - ned_start_labels = None - - # named entity detection if required - if use_predefined_spans: # no need to compute spans - ned_start_logits, ned_start_probabilities, ned_start_predictions = ( - None, - None, - torch.clone(start_labels) - if start_labels is not None - else torch.zeros_like(input_ids), - ) - ned_end_logits, ned_end_probabilities, 
ned_end_predictions = ( - None, - None, - torch.clone(end_labels) - if end_labels is not None - else torch.zeros_like(input_ids), - ) - - ned_start_predictions[ned_start_predictions > 0] = 1 - ned_end_predictions[ned_end_predictions > 0] = 1 - - else: # compute spans - # start boundary prediction - ned_start_logits = self.ned_start_classifier(model_features) - ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) - ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) - ned_start_predictions = ned_start_probabilities.argmax(dim=-1) - - # end boundary prediction - ned_start_labels = ( - torch.zeros_like(start_labels) if start_labels is not None else None - ) - - if ned_start_labels is not None: - ned_start_labels[start_labels == -100] = -100 - ned_start_labels[start_labels > 0] = 1 - - ned_end_logits = self.compute_ned_end_logits( - ned_start_predictions, - ned_start_labels, - model_features, - prediction_mask, - batch_size, - ) - - if ned_end_logits is not None: - ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) - ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) - else: - ned_end_logits, ned_end_probabilities = None, None - ned_end_predictions = ned_start_predictions.new_zeros(batch_size) - - # flattening end predictions - # (flattening can happen only if the - # end boundaries were not predicted using the gold labels) - if not self.training: - flattened_end_predictions = torch.clone(ned_start_predictions) - flattened_end_predictions[flattened_end_predictions > 0] = 0 - - batch_start_predictions = list() - for elem_idx in range(batch_size): - batch_start_predictions.append( - torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist() - ) - - # check that the total number of start predictions - # is equal to the end predictions - total_start_predictions = sum(map(len, batch_start_predictions)) - total_end_predictions = len(ned_end_predictions) - assert ( - total_start_predictions == 0 - or total_start_predictions == total_end_predictions - ), ( - f"Total number of start predictions = {total_start_predictions}. 
" - f"Total number of end predictions = {total_end_predictions}" - ) - - curr_end_pred_num = 0 - for elem_idx, bsp in enumerate(batch_start_predictions): - for sp in bsp: - ep = ned_end_predictions[curr_end_pred_num].item() - if ep < sp: - ep = sp - - # if we already set this span throw it (no overlap) - if flattened_end_predictions[elem_idx, ep] == 1: - ned_start_predictions[elem_idx, sp] = 0 - else: - flattened_end_predictions[elem_idx, ep] = 1 - - curr_end_pred_num += 1 - - ned_end_predictions = flattened_end_predictions - - start_position, end_position = ( - (start_labels, end_labels) - if self.training - else (ned_start_predictions, ned_end_predictions) - ) - - # Entity disambiguation - ed_logits = self.compute_classification_logits( - model_features, - special_symbols_mask, - prediction_mask, - batch_size, - start_position, - end_position, - ) - ed_probabilities = torch.softmax(ed_logits, dim=-1) - ed_predictions = torch.argmax(ed_probabilities, dim=-1) - - # output build - output_dict = dict( - batch_size=batch_size, - ned_start_logits=ned_start_logits, - ned_start_probabilities=ned_start_probabilities, - ned_start_predictions=ned_start_predictions, - ned_end_logits=ned_end_logits, - ned_end_probabilities=ned_end_probabilities, - ned_end_predictions=ned_end_predictions, - ed_logits=ed_logits, - ed_probabilities=ed_probabilities, - ed_predictions=ed_predictions, - ) - - # compute loss if labels - if start_labels is not None and end_labels is not None and self.training: - # named entity detection loss - - # start - if ned_start_logits is not None: - ned_start_loss = self.criterion( - ned_start_logits.view(-1, ned_start_logits.shape[-1]), - ned_start_labels.view(-1), - ) - else: - ned_start_loss = 0 - - # end - if ned_end_logits is not None: - ned_end_labels = torch.zeros_like(end_labels) - ned_end_labels[end_labels == -100] = -100 - ned_end_labels[end_labels > 0] = 1 - - ned_end_loss = self.criterion( - ned_end_logits, - ( - torch.arange( - ned_end_labels.size(1), device=ned_end_labels.device - ) - .unsqueeze(0) - .expand(batch_size, -1)[ned_end_labels > 0] - ).to(ned_end_labels.device), - ) - - else: - ned_end_loss = 0 - - # entity disambiguation loss - start_labels[ned_start_labels != 1] = -100 - ed_labels = torch.clone(start_labels) - ed_labels[end_labels > 0] = end_labels[end_labels > 0] - ed_loss = self.criterion( - ed_logits.view(-1, ed_logits.shape[-1]), - ed_labels.view(-1), - ) - - output_dict["ned_start_loss"] = ned_start_loss - output_dict["ned_end_loss"] = ned_end_loss - output_dict["ed_loss"] = ed_loss - - output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss - - return output_dict - - -class RelikReaderREModel(PreTrainedModel): - config_class = RelikReaderConfig - - def __init__(self, config, *args, **kwargs): - super().__init__(config) - # Transformer model declaration - # self.transformer_model_name = transformer_model - self.config = config - self.transformer_model = ( - AutoModel.from_pretrained(config.transformer_model) - if config.num_layers is None - else AutoModel.from_pretrained( - config.transformer_model, num_hidden_layers=config.num_layers - ) - ) - self.transformer_model.resize_token_embeddings( - self.transformer_model.config.vocab_size + config.additional_special_symbols - ) - - # named entity detection layers - self.ned_start_classifier = self._get_projection_layer( - config.activation, last_hidden=2, layer_norm=False - ) - - self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config) - - self.entity_type_loss = ( - 
config.entity_type_loss if hasattr(config, "entity_type_loss") else False - ) - self.relation_disambiguation_loss = ( - config.relation_disambiguation_loss - if hasattr(config, "relation_disambiguation_loss") - else False - ) - - input_hidden_ents = 2 * self.transformer_model.config.hidden_size - - self.re_subject_projector = self._get_projection_layer( - config.activation, input_hidden=input_hidden_ents - ) - self.re_object_projector = self._get_projection_layer( - config.activation, input_hidden=input_hidden_ents - ) - self.re_relation_projector = self._get_projection_layer(config.activation) - - if self.entity_type_loss or self.relation_disambiguation_loss: - self.re_entities_projector = self._get_projection_layer( - config.activation, - input_hidden=2 * self.transformer_model.config.hidden_size, - ) - self.re_definition_projector = self._get_projection_layer( - config.activation, - ) - - self.re_classifier = self._get_projection_layer( - config.activation, - input_hidden=config.linears_hidden_size, - last_hidden=2, - layer_norm=False, - ) - - if self.entity_type_loss or self.relation_disambiguation_loss: - self.re_ed_classifier = self._get_projection_layer( - config.activation, - input_hidden=config.linears_hidden_size, - last_hidden=2, - layer_norm=False, - ) - - self.training = config.training - - # criterion - self.criterion = torch.nn.CrossEntropyLoss() - - def _get_projection_layer( - self, - activation: str, - last_hidden: Optional[int] = None, - input_hidden=None, - layer_norm: bool = True, - ) -> torch.nn.Sequential: - head_components = [ - torch.nn.Dropout(0.1), - torch.nn.Linear( - self.transformer_model.config.hidden_size - * self.config.use_last_k_layers - if input_hidden is None - else input_hidden, - self.config.linears_hidden_size, - ), - activation2functions[activation], - torch.nn.Dropout(0.1), - torch.nn.Linear( - self.config.linears_hidden_size, - self.config.linears_hidden_size if last_hidden is None else last_hidden, - ), - ] - - if layer_norm: - head_components.append( - torch.nn.LayerNorm( - self.config.linears_hidden_size - if last_hidden is None - else last_hidden, - self.transformer_model.config.layer_norm_eps, - ) - ) - - return torch.nn.Sequential(*head_components) - - def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - mask = mask.unsqueeze(-1) - if next(self.parameters()).dtype == torch.float16: - logits = logits * (1 - mask) - 65500 * mask - else: - logits = logits * (1 - mask) - 1e30 * mask - return logits - - def _get_model_features( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor], - ): - model_input = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "output_hidden_states": self.config.use_last_k_layers > 1, - } - - if token_type_ids is not None: - model_input["token_type_ids"] = token_type_ids - - model_output = self.transformer_model(**model_input) - - if self.config.use_last_k_layers > 1: - model_features = torch.cat( - model_output[1][-self.config.use_last_k_layers :], dim=-1 - ) - else: - model_features = model_output[0] - - return model_features - - def compute_ned_end_logits( - self, - start_predictions, - start_labels, - model_features, - prediction_mask, - batch_size, - ) -> Optional[torch.Tensor]: - # todo: maybe when constraining on the spans, - # we should not use a prediction_mask for the end tokens. 
- # at least we should not during training imo - start_positions = start_labels if self.training else start_predictions - start_positions_indices = ( - torch.arange(start_positions.size(1), device=start_positions.device) - .unsqueeze(0) - .expand(batch_size, -1)[start_positions > 0] - ).to(start_positions.device) - - if len(start_positions_indices) > 0: - expanded_features = torch.cat( - [ - model_features[i].unsqueeze(0).expand(x, -1, -1) - for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) - if x > 0 - ], - dim=0, - ).to(start_positions_indices.device) - - expanded_prediction_mask = torch.cat( - [ - prediction_mask[i].unsqueeze(0).expand(x, -1) - for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) - if x > 0 - ], - dim=0, - ).to(expanded_features.device) - - # mask all tokens before start_positions_indices ie, mask all tokens with - # indices < start_positions_indices with 1, ie. [range(x) for x in start_positions_indices] - expanded_prediction_mask = torch.stack( - [ - torch.cat( - [ - torch.ones(x, device=expanded_features.device), - expanded_prediction_mask[i, x:], - ] - ) - for i, x in enumerate(start_positions_indices) - if x > 0 - ], - dim=0, - ).to(expanded_features.device) - - end_logits = self.ned_end_classifier( - hidden_states=expanded_features, - start_positions=start_positions_indices, - p_mask=expanded_prediction_mask, - ) - - return end_logits - - return None - - def compute_relation_logits( - self, - model_entity_features, - special_symbols_features, - ) -> torch.Tensor: - model_subject_features = self.re_subject_projector(model_entity_features) - model_object_features = self.re_object_projector(model_entity_features) - special_symbols_start_representation = self.re_relation_projector( - special_symbols_features - ) - re_logits = torch.einsum( - "bse,bde,bfe->bsdfe", - model_subject_features, - model_object_features, - special_symbols_start_representation, - ) - re_logits = self.re_classifier(re_logits) - - return re_logits - - def compute_entity_logits( - self, - model_entity_features, - special_symbols_features, - ) -> torch.Tensor: - model_ed_features = self.re_entities_projector(model_entity_features) - special_symbols_ed_representation = self.re_definition_projector( - special_symbols_features - ) - logits = torch.einsum( - "bce,bde->bcde", - model_ed_features, - special_symbols_ed_representation, - ) - logits = self.re_ed_classifier(logits) - start_logits = self._mask_logits( - logits, - (model_entity_features == -100) - .all(2) - .long() - .unsqueeze(2) - .repeat(1, 1, torch.sum(model_entity_features, dim=1)[0].item()), - ) - - return logits - - def compute_loss(self, logits, labels, mask=None): - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1).long() - if mask is not None: - return self.criterion(logits[mask], labels[mask]) - return self.criterion(logits, labels) - - def compute_ned_end_loss(self, ned_end_logits, end_labels): - if ned_end_logits is None: - return 0 - ned_end_labels = torch.zeros_like(end_labels) - ned_end_labels[end_labels == -100] = -100 - ned_end_labels[end_labels > 0] = 1 - return self.compute_loss(ned_end_logits, ned_end_labels) - - def compute_ned_type_loss( - self, - disambiguation_labels, - re_ned_entities_logits, - ned_type_logits, - re_entities_logits, - entity_types, - ): - if self.entity_type_loss and self.relation_disambiguation_loss: - return self.compute_loss(disambiguation_labels, re_ned_entities_logits) - if self.entity_type_loss: - return self.compute_loss( - disambiguation_labels[:, :, 
:entity_types], ned_type_logits - ) - if self.relation_disambiguation_loss: - return self.compute_loss(disambiguation_labels, re_entities_logits) - return 0 - - def compute_relation_loss(self, relation_labels, re_logits): - return self.compute_loss( - re_logits, relation_labels, relation_labels.view(-1) != -100 - ) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor, - prediction_mask: Optional[torch.Tensor] = None, - special_symbols_mask: Optional[torch.Tensor] = None, - special_symbols_mask_entities: Optional[torch.Tensor] = None, - start_labels: Optional[torch.Tensor] = None, - end_labels: Optional[torch.Tensor] = None, - disambiguation_labels: Optional[torch.Tensor] = None, - relation_labels: Optional[torch.Tensor] = None, - is_validation: bool = False, - is_prediction: bool = False, - *args, - **kwargs, - ) -> Dict[str, Any]: - batch_size = input_ids.shape[0] - - model_features = self._get_model_features( - input_ids, attention_mask, token_type_ids - ) - - # named entity detection - if is_prediction and start_labels is not None: - ned_start_logits, ned_start_probabilities, ned_start_predictions = ( - None, - None, - torch.zeros_like(start_labels), - ) - ned_end_logits, ned_end_probabilities, ned_end_predictions = ( - None, - None, - torch.zeros_like(end_labels), - ) - - ned_start_predictions[start_labels > 0] = 1 - ned_end_predictions[end_labels > 0] = 1 - ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)] - else: - # start boundary prediction - ned_start_logits = self.ned_start_classifier(model_features) - ned_start_logits = self._mask_logits( - ned_start_logits, prediction_mask - ) # why? - ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) - ned_start_predictions = ned_start_probabilities.argmax(dim=-1) - - # end boundary prediction - ned_start_labels = ( - torch.zeros_like(start_labels) if start_labels is not None else None - ) - - # start_labels contain entity id at their position, we just need 1 for start of entity - if ned_start_labels is not None: - ned_start_labels[start_labels > 0] = 1 - - # compute end logits only if there are any start predictions. - # For each start prediction, n end predictions are made - ned_end_logits = self.compute_ned_end_logits( - ned_start_predictions, - ned_start_labels, - model_features, - prediction_mask, - batch_size, - ) - # For each start prediction, n end predictions are made based on - # binary classification ie. argmax at each position. 
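As the comments above describe, end boundaries are predicted independently for each start prediction by a binary classifier over the sequence, and starts without any predicted end are then dropped. A sketch (not from this patch; shapes and values are made up) of that step:

import torch

# one row of logits per predicted start, one (no-end, end) score pair per token
ned_end_logits = torch.tensor(
    [
        [[2.0, -1.0], [0.5, 1.5], [3.0, -2.0]],   # start #1: token 1 is predicted as an end
        [[1.0, -1.0], [2.0, -1.0], [1.5, -0.5]],  # start #2: no end predicted anywhere
    ]
)
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
ned_end_predictions = ned_end_probabilities.argmax(dim=-1)  # (num_starts, seq_len), values in {0, 1}
print(ned_end_predictions)  # tensor([[0, 1, 0], [0, 0, 0]])

# starts whose row contains no predicted end are discarded, as in the forward pass above
end_preds_count = ned_end_predictions.sum(1)
ned_end_predictions = ned_end_predictions[end_preds_count != 0]
print(end_preds_count)      # tensor([1, 0])
print(ned_end_predictions)  # tensor([[0, 1, 0]])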
- ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) - ned_end_predictions = ned_end_probabilities.argmax(dim=-1) - if is_prediction or is_validation: - end_preds_count = ned_end_predictions.sum(1) - # If there are no end predictions for a start prediction, remove the start prediction - ned_start_predictions[ned_start_predictions == 1] = ( - end_preds_count != 0 - ).long() - ned_end_predictions = ned_end_predictions[end_preds_count != 0] - - if end_labels is not None: - end_labels = end_labels[~(end_labels == -100).all(2)] - - start_position, end_position = ( - (start_labels, end_labels) - if (not is_prediction and not is_validation) - else (ned_start_predictions, ned_end_predictions) - ) - - start_counts = (start_position > 0).sum(1) - ned_end_predictions = ned_end_predictions.split(start_counts.tolist()) - - # We can only predict relations if we have start and end predictions - if (end_position > 0).sum() > 0: - ends_count = (end_position > 0).sum(1) - model_subject_features = torch.cat( - [ - torch.repeat_interleave( - model_features[start_position > 0], ends_count, dim=0 - ), # start position features - torch.repeat_interleave(model_features, start_counts, dim=0)[ - end_position > 0 - ], # end position features - ], - dim=-1, - ) - ents_count = torch.nn.utils.rnn.pad_sequence( - torch.split(ends_count, start_counts.tolist()), - batch_first=True, - padding_value=0, - ).sum(1) - model_subject_features = torch.nn.utils.rnn.pad_sequence( - torch.split(model_subject_features, ents_count.tolist()), - batch_first=True, - padding_value=-100, - ) - - if is_validation or is_prediction: - model_subject_features = model_subject_features[:, :30, :] - - # entity disambiguation. Here relation_disambiguation_loss would only be useful to - # reduce the number of candidate relations for the next step, but currently unused. 
- if self.entity_type_loss or self.relation_disambiguation_loss:
- (re_ned_entities_logits) = self.compute_entity_logits(
- model_subject_features,
- model_features[
- special_symbols_mask | special_symbols_mask_entities
- ].view(batch_size, -1, model_features.shape[-1]),
- )
- entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
- ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
- re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
-
- if self.entity_type_loss:
- ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
- ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
- ned_type_predictions = ned_type_predictions.argmax(dim=-1)
-
- re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
- re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
- else:
- (
- ned_type_logits,
- ned_type_probabilities,
- re_entities_logits,
- re_entities_probabilities,
- ) = (None, None, None, None)
- ned_type_predictions, re_entities_predictions = (
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
- )
-
- # Compute relation logits
- re_logits = self.compute_relation_logits(
- model_subject_features,
- model_features[special_symbols_mask].view(
- batch_size, -1, model_features.shape[-1]
- ),
- )
-
- re_probabilities = torch.softmax(re_logits, dim=-1)
- # we set a threshold instead of argmax in case it needs to be tweaked
- re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
- # re_predictions = re_probabilities.argmax(dim=-1)
- re_probabilities = re_probabilities[:, :, :, :, 1]
-
- else:
- (
- ned_type_logits,
- ned_type_probabilities,
- re_entities_logits,
- re_entities_probabilities,
- ) = (None, None, None, None)
- ned_type_predictions, re_entities_predictions = (
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
- torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
- )
- re_logits, re_probabilities, re_predictions = (
- torch.zeros(
- [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
- ).to(input_ids.device),
- torch.zeros(
- [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
- ).to(input_ids.device),
- torch.zeros(
- [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
- ).to(input_ids.device),
- )
-
- # output build
- output_dict = dict(
- batch_size=batch_size,
- ned_start_logits=ned_start_logits,
- ned_start_probabilities=ned_start_probabilities,
- ned_start_predictions=ned_start_predictions,
- ned_end_logits=ned_end_logits,
- ned_end_probabilities=ned_end_probabilities,
- ned_end_predictions=ned_end_predictions,
- ned_type_logits=ned_type_logits,
- ned_type_probabilities=ned_type_probabilities,
- ned_type_predictions=ned_type_predictions,
- re_entities_logits=re_entities_logits,
- re_entities_probabilities=re_entities_probabilities,
- re_entities_predictions=re_entities_predictions,
- re_logits=re_logits,
- re_probabilities=re_probabilities,
- re_predictions=re_predictions,
- )
-
- if (
- start_labels is not None
- and end_labels is not None
- and relation_labels is not None
- ):
- ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
- ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
- if self.entity_type_loss or self.relation_disambiguation_loss:
- ned_type_loss = self.compute_ned_type_loss(
- disambiguation_labels,
- re_ned_entities_logits,
- ned_type_logits,
- re_entities_logits,
- 
entity_types, - ) - relation_loss = self.compute_relation_loss(relation_labels, re_logits) - # compute loss. We can skip the relation loss if we are in the first epochs (optional) - if self.entity_type_loss or self.relation_disambiguation_loss: - output_dict["loss"] = ( - ned_start_loss + ned_end_loss + relation_loss + ned_type_loss - ) / 4 - output_dict["ned_type_loss"] = ned_type_loss - else: - output_dict["loss"] = ( - ned_start_loss + ned_end_loss + relation_loss - ) / 3 - - output_dict["ned_start_loss"] = ned_start_loss - output_dict["ned_end_loss"] = ned_end_loss - output_dict["re_loss"] = relation_loss - - return output_dict diff --git a/relik/reader/pytorch_modules/optim/__init__.py b/relik/reader/pytorch_modules/optim/__init__.py deleted file mode 100644 index 369091133267cfa05240306fbfe5ea3b537d5d9c..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/optim/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from relik.reader.pytorch_modules.optim.adamw_with_warmup import ( - AdamWWithWarmupOptimizer, -) -from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import ( - LayerWiseLRDecayOptimizer, -) diff --git a/relik/reader/pytorch_modules/optim/adamw_with_warmup.py b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py deleted file mode 100644 index dfaecc4ca3d1c366f25962db4d0024a5b986fd50..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/optim/adamw_with_warmup.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import List - -import torch -import transformers -from torch.optim import AdamW - - -class AdamWWithWarmupOptimizer: - def __init__( - self, - lr: float, - warmup_steps: int, - total_steps: int, - weight_decay: float, - no_decay_params: List[str], - ): - self.lr = lr - self.warmup_steps = warmup_steps - self.total_steps = total_steps - self.weight_decay = weight_decay - self.no_decay_params = no_decay_params - - def group_params(self, module: torch.nn.Module) -> list: - if self.no_decay_params is not None: - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in module.named_parameters() - if not any(nd in n for nd in self.no_decay_params) - ], - "weight_decay": self.weight_decay, - }, - { - "params": [ - p - for n, p in module.named_parameters() - if any(nd in n for nd in self.no_decay_params) - ], - "weight_decay": 0.0, - }, - ] - - else: - optimizer_grouped_parameters = [ - {"params": module.parameters(), "weight_decay": self.weight_decay} - ] - - return optimizer_grouped_parameters - - def __call__(self, module: torch.nn.Module): - optimizer_grouped_parameters = self.group_params(module) - optimizer = AdamW( - optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay - ) - scheduler = transformers.get_linear_schedule_with_warmup( - optimizer, self.warmup_steps, self.total_steps - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": "step", - "frequency": 1, - }, - } diff --git a/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py deleted file mode 100644 index d179096153f356196a921c50083c96b3dcd5f246..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py +++ /dev/null @@ -1,104 +0,0 @@ -import collections -from typing import List - -import torch -import transformers -from torch.optim import AdamW - - -class LayerWiseLRDecayOptimizer: - def __init__( - self, - lr: float, - warmup_steps: int, - total_steps: int, - weight_decay: float, 
- lr_decay: float, - no_decay_params: List[str], - total_reset: int, - ): - self.lr = lr - self.warmup_steps = warmup_steps - self.total_steps = total_steps - self.weight_decay = weight_decay - self.lr_decay = lr_decay - self.no_decay_params = no_decay_params - self.total_reset = total_reset - - def group_layers(self, module) -> dict: - grouped_layers = collections.defaultdict(list) - module_named_parameters = list(module.named_parameters()) - for ln, lp in module_named_parameters: - if "embeddings" in ln: - grouped_layers["embeddings"].append((ln, lp)) - elif "encoder.layer" in ln: - layer_num = ln.split("transformer_model.encoder.layer.")[-1] - layer_num = layer_num[0 : layer_num.index(".")] - grouped_layers[layer_num].append((ln, lp)) - else: - grouped_layers["head"].append((ln, lp)) - - depth = len(grouped_layers) - 1 - final_dict = dict() - for key, value in grouped_layers.items(): - if key == "head": - final_dict[0] = value - elif key == "embeddings": - final_dict[depth] = value - else: - # -1 because layer number starts from zero - final_dict[depth - int(key) - 1] = value - - assert len(module_named_parameters) == sum( - len(v) for _, v in final_dict.items() - ) - - return final_dict - - def group_params(self, module) -> list: - optimizer_grouped_params = [] - for inverse_depth, layer in self.group_layers(module).items(): - layer_lr = self.lr * (self.lr_decay**inverse_depth) - layer_wd_params = { - "params": [ - lp - for ln, lp in layer - if not any(nd in ln for nd in self.no_decay_params) - ], - "weight_decay": self.weight_decay, - "lr": layer_lr, - } - layer_no_wd_params = { - "params": [ - lp - for ln, lp in layer - if any(nd in ln for nd in self.no_decay_params) - ], - "weight_decay": 0, - "lr": layer_lr, - } - - if len(layer_wd_params) != 0: - optimizer_grouped_params.append(layer_wd_params) - if len(layer_no_wd_params) != 0: - optimizer_grouped_params.append(layer_no_wd_params) - - return optimizer_grouped_params - - def __call__(self, module: torch.nn.Module): - optimizer_grouped_parameters = self.group_params(module) - optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) - scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer, - self.warmup_steps, - self.total_steps, - num_cycles=self.total_reset, - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": "step", - "frequency": 1, - }, - } diff --git a/relik/reader/pytorch_modules/span.py b/relik/reader/pytorch_modules/span.py deleted file mode 100644 index 349e42cafc1dfbc583adc46e7c8cf63d1d3752d8..0000000000000000000000000000000000000000 --- a/relik/reader/pytorch_modules/span.py +++ /dev/null @@ -1,367 +0,0 @@ -import collections -import contextlib -import logging -from typing import Any, Dict, Iterator, List - -import torch -import transformers as tr -from lightning_fabric.utilities import move_data_to_device -from torch.utils.data import DataLoader, IterableDataset -from tqdm import tqdm - -from relik.common.log import get_console_logger, get_logger -from relik.common.utils import get_callable_from_string -from relik.reader.data.relik_reader_sample import RelikReaderSample -from relik.reader.pytorch_modules.base import RelikReaderBase -from relik.reader.utils.special_symbols import get_special_symbols -from relik.retriever.pytorch_modules import PRECISION_MAP - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class RelikReaderForSpanExtraction(RelikReaderBase): - """ - A class for the 
RelikReader model for span extraction. - - Args: - transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): - The transformer model to use. If `None`, the default model is used. - additional_special_symbols (:obj:`int`, `optional`, defaults to 0): - The number of additional special symbols to add to the tokenizer. - num_layers (:obj:`int`, `optional`): - The number of layers to use. If `None`, all layers are used. - activation (:obj:`str`, `optional`, defaults to "gelu"): - The activation function to use. - linears_hidden_size (:obj:`int`, `optional`, defaults to 512): - The hidden size of the linears. - use_last_k_layers (:obj:`int`, `optional`, defaults to 1): - The number of last layers to use. - training (:obj:`bool`, `optional`, defaults to False): - Whether the model is in training mode. - device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`): - The device to use. If `None`, the default device is used. - tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`): - The tokenizer to use. If `None`, the default tokenizer is used. - dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`): - The dataset to use. If `None`, the default dataset is used. - dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`): - The keyword arguments to pass to the dataset class. - default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): - The default reader class to use. If `None`, the default reader class is used. - **kwargs: - Keyword arguments. - """ - - default_reader_class: str = ( - "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel" - ) - default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset" - - def __init__( - self, - transformer_model: str | tr.PreTrainedModel | None = None, - additional_special_symbols: int = 0, - num_layers: int | None = None, - activation: str = "gelu", - linears_hidden_size: int | None = 512, - use_last_k_layers: int = 1, - training: bool = False, - device: str | torch.device | None = None, - tokenizer: str | tr.PreTrainedTokenizer | None = None, - dataset: IterableDataset | str | None = None, - dataset_kwargs: Dict[str, Any] | None = None, - default_reader_class: tr.PreTrainedModel | str | None = None, - **kwargs, - ): - super().__init__( - transformer_model=transformer_model, - additional_special_symbols=additional_special_symbols, - num_layers=num_layers, - activation=activation, - linears_hidden_size=linears_hidden_size, - use_last_k_layers=use_last_k_layers, - training=training, - device=device, - tokenizer=tokenizer, - dataset=dataset, - default_reader_class=default_reader_class, - **kwargs, - ) - # and instantiate the dataset class - self.dataset = dataset - if self.dataset is None: - default_data_kwargs = dict( - dataset_path=None, - materialize_samples=False, - transformer_model=self.tokenizer, - special_symbols=get_special_symbols( - self.relik_reader_model.config.additional_special_symbols - ), - for_inference=True, - ) - # merge the default data kwargs with the ones passed to the model - default_data_kwargs.update(dataset_kwargs or {}) - self.dataset = get_callable_from_string(self.default_data_class)( - **default_data_kwargs - ) - - @torch.no_grad() - @torch.inference_mode() - def _read( - self, - samples: List[RelikReaderSample] | None = None, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - token_type_ids: torch.Tensor | 
None = None, - prediction_mask: torch.Tensor | None = None, - special_symbols_mask: torch.Tensor | None = None, - max_length: int = 1000, - max_batch_size: int = 128, - token_batch_size: int = 2048, - precision: str = 32, - annotation_type: str = "char", - progress_bar: bool = False, - *args: object, - **kwargs: object, - ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: - """ - A wrapper around the forward method that returns the predicted labels for each sample. - - Args: - samples (:obj:`List[RelikReaderSample]`, `optional`): - The samples to read. If provided, `text` and `candidates` are ignored. - input_ids (:obj:`torch.Tensor`, `optional`): - The input ids of the text. If `samples` is provided, this is ignored. - attention_mask (:obj:`torch.Tensor`, `optional`): - The attention mask of the text. If `samples` is provided, this is ignored. - token_type_ids (:obj:`torch.Tensor`, `optional`): - The token type ids of the text. If `samples` is provided, this is ignored. - prediction_mask (:obj:`torch.Tensor`, `optional`): - The prediction mask of the text. If `samples` is provided, this is ignored. - special_symbols_mask (:obj:`torch.Tensor`, `optional`): - The special symbols mask of the text. If `samples` is provided, this is ignored. - max_length (:obj:`int`, `optional`, defaults to 1000): - The maximum length of the text. - max_batch_size (:obj:`int`, `optional`, defaults to 128): - The maximum batch size. - token_batch_size (:obj:`int`, `optional`): - The token batch size. - progress_bar (:obj:`bool`, `optional`, defaults to False): - Whether to show a progress bar. - precision (:obj:`str`, `optional`, defaults to 32): - The precision to use for the model. - annotation_type (:obj:`str`, `optional`, defaults to "char"): - The annotation type to use. It can be either "char", "token" or "word". - *args: - Positional arguments. - **kwargs: - Keyword arguments. - - Returns: - :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`: - The predicted labels for each sample. 
- """ - - precision = precision or self.precision - if samples is not None: - - def _read_iterator(): - def samples_it(): - for i, sample in enumerate(samples): - assert sample._mixin_prediction_position is None - sample._mixin_prediction_position = i - yield sample - - next_prediction_position = 0 - position2predicted_sample = {} - - # instantiate dataset - if self.dataset is None: - raise ValueError( - "You need to pass a dataset to the model in order to predict" - ) - self.dataset.samples = samples_it() - self.dataset.model_max_length = max_length - self.dataset.tokens_per_batch = token_batch_size - self.dataset.max_batch_size = max_batch_size - - # instantiate dataloader - iterator = DataLoader( - self.dataset, batch_size=None, num_workers=0, shuffle=False - ) - if progress_bar: - iterator = tqdm(iterator, desc="Predicting with RelikReader") - - # fucking autocast only wants pure strings like 'cpu' or 'cuda' - # we need to convert the model device to that - device_type_for_autocast = str(self.device).split(":")[0] - # autocast doesn't work with CPU and stuff different from bfloat16 - autocast_mngr = ( - contextlib.nullcontext() - if device_type_for_autocast == "cpu" - else ( - torch.autocast( - device_type=device_type_for_autocast, - dtype=PRECISION_MAP[precision], - ) - ) - ) - - with autocast_mngr: - for batch in iterator: - batch = move_data_to_device(batch, self.device) - batch_out = self._batch_predict(**batch) - - for sample in batch_out: - if ( - sample._mixin_prediction_position - >= next_prediction_position - ): - position2predicted_sample[ - sample._mixin_prediction_position - ] = sample - - # yield - while next_prediction_position in position2predicted_sample: - yield position2predicted_sample[next_prediction_position] - del position2predicted_sample[next_prediction_position] - next_prediction_position += 1 - - outputs = list(_read_iterator()) - for sample in outputs: - self.dataset.merge_patches_predictions(sample) - self.dataset.convert_tokens_to_char_annotations(sample) - - else: - outputs = list( - self._batch_predict( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - *args, - **kwargs, - ) - ) - return outputs - - def _batch_predict( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor | None = None, - prediction_mask: torch.Tensor | None = None, - special_symbols_mask: torch.Tensor | None = None, - sample: List[RelikReaderSample] | None = None, - top_k: int = 5, # the amount of top-k most probable entities to predict - *args, - **kwargs, - ) -> Iterator[RelikReaderSample]: - """ - A wrapper around the forward method that returns the predicted labels for each sample. - It also adds the predicted labels to the samples. - - Args: - input_ids (:obj:`torch.Tensor`): - The input ids of the text. - attention_mask (:obj:`torch.Tensor`): - The attention mask of the text. - token_type_ids (:obj:`torch.Tensor`, `optional`): - The token type ids of the text. - prediction_mask (:obj:`torch.Tensor`, `optional`): - The prediction mask of the text. - special_symbols_mask (:obj:`torch.Tensor`, `optional`): - The special symbols mask of the text. - sample (:obj:`List[RelikReaderSample]`, `optional`): - The samples to read. If provided, `text` and `candidates` are ignored. - top_k (:obj:`int`, `optional`, defaults to 5): - The amount of top-k most probable entities to predict. - *args: - Positional arguments. - **kwargs: - Keyword arguments. - - Returns: - The predicted labels for each sample. 
- """ - forward_output = self.forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - prediction_mask=prediction_mask, - special_symbols_mask=special_symbols_mask, - ) - - ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() - ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() - ed_predictions = forward_output["ed_predictions"].cpu().numpy() - ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() - - batch_predictable_candidates = kwargs["predictable_candidates"] - patch_offset = kwargs["patch_offset"] - for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( - sample, - ned_start_predictions, - ned_end_predictions, - ed_predictions, - ed_probabilities, - batch_predictable_candidates, - patch_offset, - ): - ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] - ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] - - final_class2predicted_spans = collections.defaultdict(list) - spans2predicted_probabilities = dict() - for start_token_index, end_token_index in zip( - ne_start_indices, ne_end_indices - ): - # predicted candidate - token_class = edp[start_token_index + 1] - 1 - predicted_candidate_title = pred_cands[token_class] - final_class2predicted_spans[predicted_candidate_title].append( - [start_token_index, end_token_index] - ) - - # candidates probabilities - classes_probabilities = edpr[start_token_index + 1] - classes_probabilities_best_indices = classes_probabilities.argsort()[ - ::-1 - ] - titles_2_probs = [] - top_k = ( - min( - top_k, - len(classes_probabilities_best_indices), - ) - if top_k != -1 - else len(classes_probabilities_best_indices) - ) - for i in range(top_k): - titles_2_probs.append( - ( - pred_cands[classes_probabilities_best_indices[i] - 1], - classes_probabilities[ - classes_probabilities_best_indices[i] - ].item(), - ) - ) - spans2predicted_probabilities[ - (start_token_index, end_token_index) - ] = titles_2_probs - - if "patches" not in ts._d: - ts._d["patches"] = dict() - - ts._d["patches"][po] = dict() - sample_patch = ts._d["patches"][po] - - sample_patch["predicted_window_labels"] = final_class2predicted_spans - sample_patch["span_title_probabilities"] = spans2predicted_probabilities - - # additional info - sample_patch["predictable_candidates"] = pred_cands - - yield ts diff --git a/relik/reader/relik_reader.py b/relik/reader/relik_reader.py deleted file mode 100644 index 5acd9e8c4774593c4a61245ecf92a5559ab438f2..0000000000000000000000000000000000000000 --- a/relik/reader/relik_reader.py +++ /dev/null @@ -1,629 +0,0 @@ -import collections -import logging -from pathlib import Path -from typing import Any, Callable, Dict, Iterator, List, Union - -import torch -import transformers as tr -from tqdm import tqdm -from transformers import AutoConfig - -from relik.common.log import get_console_logger, get_logger -from relik.reader.data.relik_reader_data_utils import batchify, flatten -from relik.reader.data.relik_reader_sample import RelikReaderSample -from relik.reader.pytorch_modules.hf.modeling_relik import ( - RelikReaderConfig, - RelikReaderSpanModel, -) -from relik.reader.relik_reader_predictor import RelikReaderPredictor -from relik.reader.utils.save_load_utilities import load_model_and_conf -from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class RelikReaderForSpanExtraction(torch.nn.Module): - def 
__init__( - self, - transformer_model: str | tr.PreTrainedModel | None = None, - additional_special_symbols: int = 0, - num_layers: int | None = None, - activation: str = "gelu", - linears_hidden_size: int | None = 512, - use_last_k_layers: int = 1, - training: bool = False, - device: str | torch.device | None = None, - tokenizer: str | tr.PreTrainedTokenizer | None = None, - **kwargs, - ) -> None: - super().__init__() - - if isinstance(transformer_model, str): - config = AutoConfig.from_pretrained( - transformer_model, trust_remote_code=True - ) - if "relik-reader" in config.model_type: - transformer_model = RelikReaderSpanModel.from_pretrained( - transformer_model, **kwargs - ) - else: - reader_config = RelikReaderConfig( - transformer_model=transformer_model, - additional_special_symbols=additional_special_symbols, - num_layers=num_layers, - activation=activation, - linears_hidden_size=linears_hidden_size, - use_last_k_layers=use_last_k_layers, - training=training, - ) - transformer_model = RelikReaderSpanModel(reader_config) - - self.relik_reader_model = transformer_model - - self._tokenizer = tokenizer - - # move the model to the device - self.to(device or torch.device("cpu")) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor, - prediction_mask: torch.Tensor | None = None, - special_symbols_mask: torch.Tensor | None = None, - special_symbols_mask_entities: torch.Tensor | None = None, - start_labels: torch.Tensor | None = None, - end_labels: torch.Tensor | None = None, - disambiguation_labels: torch.Tensor | None = None, - relation_labels: torch.Tensor | None = None, - is_validation: bool = False, - is_prediction: bool = False, - *args, - **kwargs, - ) -> Dict[str, Any]: - return self.relik_reader_model( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - special_symbols_mask_entities, - start_labels, - end_labels, - disambiguation_labels, - relation_labels, - is_validation, - is_prediction, - *args, - **kwargs, - ) - - def batch_predict( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor | None = None, - prediction_mask: torch.Tensor | None = None, - special_symbols_mask: torch.Tensor | None = None, - sample: List[RelikReaderSample] | None = None, - top_k: int = 5, # the amount of top-k most probable entities to predict - *args, - **kwargs, - ) -> Iterator[RelikReaderSample]: - """ - - - Args: - input_ids: - attention_mask: - token_type_ids: - prediction_mask: - special_symbols_mask: - sample: - top_k: - *args: - **kwargs: - - Returns: - - """ - forward_output = self.forward( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - ) - - ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() - ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() - ed_predictions = forward_output["ed_predictions"].cpu().numpy() - ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() - - batch_predictable_candidates = kwargs["predictable_candidates"] - patch_offset = kwargs["patch_offset"] - for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( - sample, - ned_start_predictions, - ned_end_predictions, - ed_predictions, - ed_probabilities, - batch_predictable_candidates, - patch_offset, - ): - ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] - ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] - - final_class2predicted_spans = 
collections.defaultdict(list) - spans2predicted_probabilities = dict() - for start_token_index, end_token_index in zip( - ne_start_indices, ne_end_indices - ): - # predicted candidate - token_class = edp[start_token_index + 1] - 1 - predicted_candidate_title = pred_cands[token_class] - final_class2predicted_spans[predicted_candidate_title].append( - [start_token_index, end_token_index] - ) - - # candidates probabilities - classes_probabilities = edpr[start_token_index + 1] - classes_probabilities_best_indices = classes_probabilities.argsort()[ - ::-1 - ] - titles_2_probs = [] - top_k = ( - min( - top_k, - len(classes_probabilities_best_indices), - ) - if top_k != -1 - else len(classes_probabilities_best_indices) - ) - for i in range(top_k): - titles_2_probs.append( - ( - pred_cands[classes_probabilities_best_indices[i] - 1], - classes_probabilities[ - classes_probabilities_best_indices[i] - ].item(), - ) - ) - spans2predicted_probabilities[ - (start_token_index, end_token_index) - ] = titles_2_probs - - if "patches" not in ts._d: - ts._d["patches"] = dict() - - ts._d["patches"][po] = dict() - sample_patch = ts._d["patches"][po] - - sample_patch["predicted_window_labels"] = final_class2predicted_spans - sample_patch["span_title_probabilities"] = spans2predicted_probabilities - - # additional info - sample_patch["predictable_candidates"] = pred_cands - - yield ts - - def _build_input(self, text: List[str], candidates: List[List[str]]) -> list[str]: - candidates_symbols = get_special_symbols(len(candidates)) - candidates = [ - [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL] - for cs, ct in zip(candidates_symbols, candidates) - ] - return ( - [self.tokenizer.cls_token] - + text - + [self.tokenizer.sep_token] - + flatten(candidates) - + [self.tokenizer.sep_token] - ) - - @staticmethod - def _compute_offsets(offsets_mapping): - offsets_mapping = offsets_mapping.numpy() - token2word = [] - word2token = {} - count = 0 - for i, offset in enumerate(offsets_mapping): - if offset[0] == 0: - token2word.append(i - count) - word2token[i - count] = [i] - else: - token2word.append(token2word[-1]) - word2token[token2word[-1]].append(i) - count += 1 - return token2word, word2token - - @staticmethod - def _convert_tokens_to_word_annotations(sample: RelikReaderSample): - triplets = [] - entities = [] - for entity in sample.predicted_entities: - if sample.entity_candidates: - entities.append( - ( - sample.token2word[entity[0] - 1], - sample.token2word[entity[1] - 1] + 1, - sample.entity_candidates[entity[2]], - ) - ) - else: - entities.append( - ( - sample.token2word[entity[0] - 1], - sample.token2word[entity[1] - 1] + 1, - -1, - ) - ) - for predicted_triplet, predicted_triplet_probabilities in zip( - sample.predicted_relations, sample.predicted_relations_probabilities - ): - subject, object_, relation = predicted_triplet - subject = entities[subject] - object_ = entities[object_] - relation = sample.candidates[relation] - triplets.append( - { - "subject": { - "start": subject[0], - "end": subject[1], - "type": subject[2], - "name": " ".join(sample.tokens[subject[0] : subject[1]]), - }, - "relation": { - "name": relation, - "probability": float(predicted_triplet_probabilities.round(2)), - }, - "object": { - "start": object_[0], - "end": object_[1], - "type": object_[2], - "name": " ".join(sample.tokens[object_[0] : object_[1]]), - }, - } - ) - sample.predicted_entities = entities - sample.predicted_relations = triplets - sample.predicted_relations_probabilities = None - - @torch.no_grad() - 
@torch.inference_mode() - def read( - self, - text: List[str] | List[List[str]] | None = None, - samples: List[RelikReaderSample] | None = None, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - token_type_ids: torch.Tensor | None = None, - prediction_mask: torch.Tensor | None = None, - special_symbols_mask: torch.Tensor | None = None, - special_symbols_mask_entities: torch.Tensor | None = None, - candidates: List[List[str]] | None = None, - max_length: int | None = 1024, - max_batch_size: int | None = 64, - token_batch_size: int | None = None, - progress_bar: bool = False, - *args, - **kwargs, - ) -> List[List[RelikReaderSample]]: - """ - Reads the given text. - Args: - text: The text to read in tokens. - samples: - input_ids: The input ids of the text. - attention_mask: The attention mask of the text. - token_type_ids: The token type ids of the text. - prediction_mask: The prediction mask of the text. - special_symbols_mask: The special symbols mask of the text. - special_symbols_mask_entities: The special symbols mask entities of the text. - candidates: The candidates of the text. - max_length: The maximum length of the text. - max_batch_size: The maximum batch size. - token_batch_size: The maximum number of tokens per batch. - progress_bar: - Returns: - The predicted labels for each sample. - """ - if text is None and input_ids is None and samples is None: - raise ValueError( - "Either `text` or `input_ids` or `samples` must be provided." - ) - if (input_ids is None and samples is None) and ( - text is None or candidates is None - ): - raise ValueError( - "`text` and `candidates` must be provided to return the predictions when " - "`input_ids` and `samples` is not provided." - ) - if text is not None and samples is None: - if len(text) != len(candidates): - raise ValueError("`text` and `candidates` must have the same length.") - if isinstance(text[0], str): # change to list of text - text = [text] - candidates = [candidates] - - samples = [ - RelikReaderSample(tokens=t, candidates=c) - for t, c in zip(text, candidates) - ] - - if samples is not None: - # function that creates a batch from the 'current_batch' list - def output_batch() -> Dict[str, Any]: - assert ( - len( - set( - [ - len(elem["predictable_candidates"]) - for elem in current_batch - ] - ) - ) - == 1 - ), " ".join( - map( - str, - [len(elem["predictable_candidates"]) for elem in current_batch], - ) - ) - - batch_dict = dict() - - de_values_by_field = { - fn: [de[fn] for de in current_batch if fn in de] - for fn in self.fields_batcher - } - - # in case you provide fields batchers but in the batch - # there are no elements for that field - de_values_by_field = { - fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 - } - - assert len(set([len(v) for v in de_values_by_field.values()])) - - # todo: maybe we should report the user about possible - # fields filtering due to "None" instances - de_values_by_field = { - fn: fvs - for fn, fvs in de_values_by_field.items() - if all([fv is not None for fv in fvs]) - } - - for field_name, field_values in de_values_by_field.items(): - field_batch = ( - self.fields_batcher[field_name]([fv[0] for fv in field_values]) - if self.fields_batcher[field_name] is not None - else field_values - ) - - batch_dict[field_name] = field_batch - - batch_dict = { - k: v.to(self.device) if isinstance(v, torch.Tensor) else v - for k, v in batch_dict.items() - } - return batch_dict - - current_batch = [] - predictions = [] - current_cand_len = -1 - - for 
sample in tqdm(samples, disable=not progress_bar): - sample.candidates = [NME_SYMBOL] + sample.candidates - inputs_text = self._build_input(sample.tokens, sample.candidates) - model_inputs = self.tokenizer( - inputs_text, - is_split_into_words=True, - add_special_tokens=False, - padding=False, - truncation=True, - max_length=max_length or self.tokenizer.model_max_length, - return_offsets_mapping=True, - return_tensors="pt", - ) - model_inputs["special_symbols_mask"] = ( - model_inputs["input_ids"] > self.tokenizer.vocab_size - ) - # prediction mask is 0 until the first special symbol - model_inputs["token_type_ids"] = ( - torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0 - ).long() - # shift prediction_mask to the left - model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll( - shifts=-1, dims=1 - ) - model_inputs["prediction_mask"][:, -1] = 1 - model_inputs["prediction_mask"][:, 0] = 1 - - assert ( - len(model_inputs["special_symbols_mask"]) - == len(model_inputs["prediction_mask"]) - == len(model_inputs["input_ids"]) - ) - - model_inputs["sample"] = sample - - # compute cand_len using special_symbols_mask - model_inputs["predictable_candidates"] = sample.candidates[ - : model_inputs["special_symbols_mask"].sum().item() - ] - # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]]) - offsets = model_inputs.pop("offset_mapping") - offsets = offsets[model_inputs["prediction_mask"] == 0] - sample.token2word, sample.word2token = self._compute_offsets(offsets) - future_max_len = max( - len(model_inputs["input_ids"]), - max([len(b["input_ids"]) for b in current_batch], default=0), - ) - future_tokens_per_batch = future_max_len * (len(current_batch) + 1) - - if len(current_batch) > 0 and ( - ( - len(model_inputs["predictable_candidates"]) != current_cand_len - and current_cand_len != -1 - ) - or ( - isinstance(token_batch_size, int) - and future_tokens_per_batch >= token_batch_size - ) - or len(current_batch) == max_batch_size - ): - batch_inputs = output_batch() - current_batch = [] - predictions.extend(list(self.batch_predict(**batch_inputs))) - current_cand_len = len(model_inputs["predictable_candidates"]) - current_batch.append(model_inputs) - - if current_batch: - batch_inputs = output_batch() - predictions.extend(list(self.batch_predict(**batch_inputs))) - else: - predictions = list( - self.batch_predict( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - special_symbols_mask_entities, - *args, - **kwargs, - ) - ) - return predictions - - @property - def device(self) -> torch.device: - """ - The device of the model. - """ - return next(self.parameters()).device - - @property - def tokenizer(self) -> tr.PreTrainedTokenizer: - """ - The tokenizer. 
- """ - if self._tokenizer: - return self._tokenizer - - self._tokenizer = tr.AutoTokenizer.from_pretrained( - self.relik_reader_model.config.name_or_path - ) - return self._tokenizer - - @property - def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: - fields_batchers = { - "input_ids": lambda x: batchify( - x, padding_value=self.tokenizer.pad_token_id - ), - "attention_mask": lambda x: batchify(x, padding_value=0), - "token_type_ids": lambda x: batchify(x, padding_value=0), - "prediction_mask": lambda x: batchify(x, padding_value=1), - "global_attention": lambda x: batchify(x, padding_value=0), - "token2word": None, - "sample": None, - "special_symbols_mask": lambda x: batchify(x, padding_value=False), - "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), - } - if "roberta" in self.relik_reader_model.config.model_type: - del fields_batchers["token_type_ids"] - - return fields_batchers - - def save_pretrained( - self, - output_dir: str, - model_name: str | None = None, - push_to_hub: bool = False, - **kwargs, - ) -> None: - """ - Saves the model to the given path. - Args: - output_dir: The path to save the model to. - model_name: The name of the model. - push_to_hub: Whether to push the model to the hub. - """ - # create the output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - model_name = model_name or "relik-reader-for-span-extraction" - - logger.info(f"Saving reader to {output_dir / model_name}") - - # save the model - self.relik_reader_model.register_for_auto_class() - self.relik_reader_model.save_pretrained( - output_dir / model_name, push_to_hub=push_to_hub, **kwargs - ) - - logger.info("Saving reader to disk done.") - - if self.tokenizer: - self.tokenizer.save_pretrained( - output_dir / model_name, push_to_hub=push_to_hub, **kwargs - ) - logger.info("Saving tokenizer to disk done.") - - -class RelikReader: - def __init__(self, model_path: str, predict_nmes: bool = False): - model, model_conf = load_model_and_conf(model_path) - model.training = False - model.eval() - - val_dataset_conf = model_conf.data.val_dataset - val_dataset_conf.special_symbols = get_special_symbols( - model_conf.model.entities_per_forward - ) - val_dataset_conf.transformer_model = model_conf.model.model.transformer_model - - self.predictor = RelikReaderPredictor( - model, - dataset_conf=model_conf.data.val_dataset, - predict_nmes=predict_nmes, - ) - self.model_path = model_path - - def link_entities( - self, - dataset_path_or_samples: str | Iterator[RelikReaderSample], - token_batch_size: int = 2048, - progress_bar: bool = False, - ) -> List[RelikReaderSample]: - data_input = ( - (dataset_path_or_samples, None) - if isinstance(dataset_path_or_samples, str) - else (None, dataset_path_or_samples) - ) - return self.predictor.predict( - *data_input, - dataset_conf=None, - token_batch_size=token_batch_size, - progress_bar=progress_bar, - ) - - # def save_pretrained(self, path: Union[str, Path]): - # self.predictor.save(path) - - -def main(): - rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True) - predictions = rr.link_entities( - "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl" - ) - print(predictions) - - -if __name__ == "__main__": - main() diff --git a/relik/reader/relik_reader_core.py b/relik/reader/relik_reader_core.py deleted file mode 100644 index 1d62c5f13b3c1f7e7ba02209d2c88813d4f960ac..0000000000000000000000000000000000000000 --- 
a/relik/reader/relik_reader_core.py +++ /dev/null @@ -1,497 +0,0 @@ -import collections -from typing import Any, Dict, Iterator, List, Optional - -import torch -from transformers import AutoModel -from transformers.activations import ClippedGELUActivation, GELUActivation -from transformers.modeling_utils import PoolerEndLogits - -from relik.reader.data.relik_reader_sample import RelikReaderSample - -activation2functions = { - "relu": torch.nn.ReLU(), - "gelu": GELUActivation(), - "gelu_10": ClippedGELUActivation(-10, 10), -} - - -class RelikReaderCoreModel(torch.nn.Module): - def __init__( - self, - transformer_model: str, - additional_special_symbols: int, - num_layers: Optional[int] = None, - activation: str = "gelu", - linears_hidden_size: Optional[int] = 512, - use_last_k_layers: int = 1, - training: bool = False, - ) -> None: - super().__init__() - - # Transformer model declaration - self.transformer_model_name = transformer_model - self.transformer_model = ( - AutoModel.from_pretrained(transformer_model) - if num_layers is None - else AutoModel.from_pretrained( - transformer_model, num_hidden_layers=num_layers - ) - ) - self.transformer_model.resize_token_embeddings( - self.transformer_model.config.vocab_size + additional_special_symbols - ) - - self.activation = activation - self.linears_hidden_size = linears_hidden_size - self.use_last_k_layers = use_last_k_layers - - # named entity detection layers - self.ned_start_classifier = self._get_projection_layer( - self.activation, last_hidden=2, layer_norm=False - ) - self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) - - # END entity disambiguation layer - self.ed_start_projector = self._get_projection_layer(self.activation) - self.ed_end_projector = self._get_projection_layer(self.activation) - - self.training = training - - # criterion - self.criterion = torch.nn.CrossEntropyLoss() - - def _get_projection_layer( - self, - activation: str, - last_hidden: Optional[int] = None, - input_hidden=None, - layer_norm: bool = True, - ) -> torch.nn.Sequential: - head_components = [ - torch.nn.Dropout(0.1), - torch.nn.Linear( - self.transformer_model.config.hidden_size * self.use_last_k_layers - if input_hidden is None - else input_hidden, - self.linears_hidden_size, - ), - activation2functions[activation], - torch.nn.Dropout(0.1), - torch.nn.Linear( - self.linears_hidden_size, - self.linears_hidden_size if last_hidden is None else last_hidden, - ), - ] - - if layer_norm: - head_components.append( - torch.nn.LayerNorm( - self.linears_hidden_size if last_hidden is None else last_hidden, - self.transformer_model.config.layer_norm_eps, - ) - ) - - return torch.nn.Sequential(*head_components) - - def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - mask = mask.unsqueeze(-1) - if next(self.parameters()).dtype == torch.float16: - logits = logits * (1 - mask) - 65500 * mask - else: - logits = logits * (1 - mask) - 1e30 * mask - return logits - - def _get_model_features( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor], - ): - model_input = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "output_hidden_states": self.use_last_k_layers > 1, - } - - if token_type_ids is not None: - model_input["token_type_ids"] = token_type_ids - - model_output = self.transformer_model(**model_input) - - if self.use_last_k_layers > 1: - model_features = torch.cat( - model_output[1][-self.use_last_k_layers :], dim=-1 - ) - else: - model_features = 
model_output[0] - - return model_features - - def compute_ned_end_logits( - self, - start_predictions, - start_labels, - model_features, - prediction_mask, - batch_size, - ) -> Optional[torch.Tensor]: - # todo: maybe when constraining on the spans, - # we should not use a prediction_mask for the end tokens. - # at least we should not during training imo - start_positions = start_labels if self.training else start_predictions - start_positions_indices = ( - torch.arange(start_positions.size(1), device=start_positions.device) - .unsqueeze(0) - .expand(batch_size, -1)[start_positions > 0] - ).to(start_positions.device) - - if len(start_positions_indices) > 0: - expanded_features = torch.cat( - [ - model_features[i].unsqueeze(0).expand(x, -1, -1) - for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) - if x > 0 - ], - dim=0, - ).to(start_positions_indices.device) - - expanded_prediction_mask = torch.cat( - [ - prediction_mask[i].unsqueeze(0).expand(x, -1) - for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) - if x > 0 - ], - dim=0, - ).to(expanded_features.device) - - end_logits = self.ned_end_classifier( - hidden_states=expanded_features, - start_positions=start_positions_indices, - p_mask=expanded_prediction_mask, - ) - - return end_logits - - return None - - def compute_classification_logits( - self, - model_features, - special_symbols_mask, - prediction_mask, - batch_size, - start_positions=None, - end_positions=None, - ) -> torch.Tensor: - if start_positions is None or end_positions is None: - start_positions = torch.zeros_like(prediction_mask) - end_positions = torch.zeros_like(prediction_mask) - - model_start_features = self.ed_start_projector(model_features) - model_end_features = self.ed_end_projector(model_features) - model_end_features[start_positions > 0] = model_end_features[end_positions > 0] - - model_ed_features = torch.cat( - [model_start_features, model_end_features], dim=-1 - ) - - # computing ed features - classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item() - special_symbols_representation = model_ed_features[special_symbols_mask].view( - batch_size, classes_representations, -1 - ) - - logits = torch.bmm( - model_ed_features, - torch.permute(special_symbols_representation, (0, 2, 1)), - ) - - logits = self._mask_logits(logits, prediction_mask) - - return logits - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - prediction_mask: Optional[torch.Tensor] = None, - special_symbols_mask: Optional[torch.Tensor] = None, - start_labels: Optional[torch.Tensor] = None, - end_labels: Optional[torch.Tensor] = None, - use_predefined_spans: bool = False, - *args, - **kwargs, - ) -> Dict[str, Any]: - batch_size, seq_len = input_ids.shape - - model_features = self._get_model_features( - input_ids, attention_mask, token_type_ids - ) - - # named entity detection if required - if use_predefined_spans: # no need to compute spans - ned_start_logits, ned_start_probabilities, ned_start_predictions = ( - None, - None, - torch.clone(start_labels) - if start_labels is not None - else torch.zeros_like(input_ids), - ) - ned_end_logits, ned_end_probabilities, ned_end_predictions = ( - None, - None, - torch.clone(end_labels) - if end_labels is not None - else torch.zeros_like(input_ids), - ) - - ned_start_predictions[ned_start_predictions > 0] = 1 - ned_end_predictions[ned_end_predictions > 0] = 1 - - else: # compute spans - # start boundary prediction - ned_start_logits = 
self.ned_start_classifier(model_features) - ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) - ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) - ned_start_predictions = ned_start_probabilities.argmax(dim=-1) - - # end boundary prediction - ned_start_labels = ( - torch.zeros_like(start_labels) if start_labels is not None else None - ) - - if ned_start_labels is not None: - ned_start_labels[start_labels == -100] = -100 - ned_start_labels[start_labels > 0] = 1 - - ned_end_logits = self.compute_ned_end_logits( - ned_start_predictions, - ned_start_labels, - model_features, - prediction_mask, - batch_size, - ) - - if ned_end_logits is not None: - ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) - ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) - else: - ned_end_logits, ned_end_probabilities = None, None - ned_end_predictions = ned_start_predictions.new_zeros(batch_size) - - # flattening end predictions - # (flattening can happen only if the - # end boundaries were not predicted using the gold labels) - if not self.training: - flattened_end_predictions = torch.clone(ned_start_predictions) - flattened_end_predictions[flattened_end_predictions > 0] = 0 - - batch_start_predictions = list() - for elem_idx in range(batch_size): - batch_start_predictions.append( - torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist() - ) - - # check that the total number of start predictions - # is equal to the end predictions - total_start_predictions = sum(map(len, batch_start_predictions)) - total_end_predictions = len(ned_end_predictions) - assert ( - total_start_predictions == 0 - or total_start_predictions == total_end_predictions - ), ( - f"Total number of start predictions = {total_start_predictions}. 
" - f"Total number of end predictions = {total_end_predictions}" - ) - - curr_end_pred_num = 0 - for elem_idx, bsp in enumerate(batch_start_predictions): - for sp in bsp: - ep = ned_end_predictions[curr_end_pred_num].item() - if ep < sp: - ep = sp - - # if we already set this span throw it (no overlap) - if flattened_end_predictions[elem_idx, ep] == 1: - ned_start_predictions[elem_idx, sp] = 0 - else: - flattened_end_predictions[elem_idx, ep] = 1 - - curr_end_pred_num += 1 - - ned_end_predictions = flattened_end_predictions - - start_position, end_position = ( - (start_labels, end_labels) - if self.training - else (ned_start_predictions, ned_end_predictions) - ) - - # Entity disambiguation - ed_logits = self.compute_classification_logits( - model_features, - special_symbols_mask, - prediction_mask, - batch_size, - start_position, - end_position, - ) - ed_probabilities = torch.softmax(ed_logits, dim=-1) - ed_predictions = torch.argmax(ed_probabilities, dim=-1) - - # output build - output_dict = dict( - batch_size=batch_size, - ned_start_logits=ned_start_logits, - ned_start_probabilities=ned_start_probabilities, - ned_start_predictions=ned_start_predictions, - ned_end_logits=ned_end_logits, - ned_end_probabilities=ned_end_probabilities, - ned_end_predictions=ned_end_predictions, - ed_logits=ed_logits, - ed_probabilities=ed_probabilities, - ed_predictions=ed_predictions, - ) - - # compute loss if labels - if start_labels is not None and end_labels is not None and self.training: - # named entity detection loss - - # start - if ned_start_logits is not None: - ned_start_loss = self.criterion( - ned_start_logits.view(-1, ned_start_logits.shape[-1]), - ned_start_labels.view(-1), - ) - else: - ned_start_loss = 0 - - # end - if ned_end_logits is not None: - ned_end_labels = torch.zeros_like(end_labels) - ned_end_labels[end_labels == -100] = -100 - ned_end_labels[end_labels > 0] = 1 - - ned_end_loss = self.criterion( - ned_end_logits, - ( - torch.arange( - ned_end_labels.size(1), device=ned_end_labels.device - ) - .unsqueeze(0) - .expand(batch_size, -1)[ned_end_labels > 0] - ).to(ned_end_labels.device), - ) - - else: - ned_end_loss = 0 - - # entity disambiguation loss - start_labels[ned_start_labels != 1] = -100 - ed_labels = torch.clone(start_labels) - ed_labels[end_labels > 0] = end_labels[end_labels > 0] - ed_loss = self.criterion( - ed_logits.view(-1, ed_logits.shape[-1]), - ed_labels.view(-1), - ) - - output_dict["ned_start_loss"] = ned_start_loss - output_dict["ned_end_loss"] = ned_end_loss - output_dict["ed_loss"] = ed_loss - - output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss - - return output_dict - - def batch_predict( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - prediction_mask: Optional[torch.Tensor] = None, - special_symbols_mask: Optional[torch.Tensor] = None, - sample: Optional[List[RelikReaderSample]] = None, - top_k: int = 5, # the amount of top-k most probable entities to predict - *args, - **kwargs, - ) -> Iterator[RelikReaderSample]: - forward_output = self.forward( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - ) - - ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() - ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() - ed_predictions = forward_output["ed_predictions"].cpu().numpy() - ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() - - batch_predictable_candidates = 
kwargs["predictable_candidates"] - patch_offset = kwargs["patch_offset"] - for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( - sample, - ned_start_predictions, - ned_end_predictions, - ed_predictions, - ed_probabilities, - batch_predictable_candidates, - patch_offset, - ): - ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] - ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] - - final_class2predicted_spans = collections.defaultdict(list) - spans2predicted_probabilities = dict() - for start_token_index, end_token_index in zip( - ne_start_indices, ne_end_indices - ): - # predicted candidate - token_class = edp[start_token_index + 1] - 1 - predicted_candidate_title = pred_cands[token_class] - final_class2predicted_spans[predicted_candidate_title].append( - [start_token_index, end_token_index] - ) - - # candidates probabilities - classes_probabilities = edpr[start_token_index + 1] - classes_probabilities_best_indices = classes_probabilities.argsort()[ - ::-1 - ] - titles_2_probs = [] - top_k = ( - min( - top_k, - len(classes_probabilities_best_indices), - ) - if top_k != -1 - else len(classes_probabilities_best_indices) - ) - for i in range(top_k): - titles_2_probs.append( - ( - pred_cands[classes_probabilities_best_indices[i] - 1], - classes_probabilities[ - classes_probabilities_best_indices[i] - ].item(), - ) - ) - spans2predicted_probabilities[ - (start_token_index, end_token_index) - ] = titles_2_probs - - if "patches" not in ts._d: - ts._d["patches"] = dict() - - ts._d["patches"][po] = dict() - sample_patch = ts._d["patches"][po] - - sample_patch["predicted_window_labels"] = final_class2predicted_spans - sample_patch["span_title_probabilities"] = spans2predicted_probabilities - - # additional info - sample_patch["predictable_candidates"] = pred_cands - - yield ts diff --git a/relik/reader/relik_reader_predictor.py b/relik/reader/relik_reader_predictor.py deleted file mode 100644 index e5635d477d67febce3c937ef1945900f004bb269..0000000000000000000000000000000000000000 --- a/relik/reader/relik_reader_predictor.py +++ /dev/null @@ -1,168 +0,0 @@ -import logging -from typing import Iterable, Iterator, List, Optional - -import hydra -import torch -from lightning.pytorch.utilities import move_data_to_device -from torch.utils.data import DataLoader -from tqdm import tqdm - -from relik.reader.data.patches import merge_patches_predictions -from relik.reader.data.relik_reader_sample import ( - RelikReaderSample, - load_relik_reader_samples, -) -from relik.reader.relik_reader_core import RelikReaderCoreModel -from relik.reader.utils.special_symbols import NME_SYMBOL - -logger = logging.getLogger(__name__) - - -def convert_tokens_to_char_annotations( - sample: RelikReaderSample, remove_nmes: bool = False -): - char_annotations = set() - - for ( - predicted_entity, - predicted_spans, - ) in sample.predicted_window_labels.items(): - if predicted_entity == NME_SYMBOL and remove_nmes: - continue - - for span_start, span_end in predicted_spans: - span_start = sample.token2char_start[str(span_start)] - span_end = sample.token2char_end[str(span_end)] - - char_annotations.add((span_start, span_end, predicted_entity)) - - char_probs_annotations = dict() - for ( - span_start, - span_end, - ), candidates_probs in sample.span_title_probabilities.items(): - span_start = sample.token2char_start[str(span_start)] - span_end = sample.token2char_end[str(span_end)] - char_probs_annotations[(span_start, span_end)] = { - title for title, _ in candidates_probs - } - - 
sample.predicted_window_labels_chars = char_annotations - sample.probs_window_labels_chars = char_probs_annotations - - -class RelikReaderPredictor: - def __init__( - self, - relik_reader_core: RelikReaderCoreModel, - dataset_conf: Optional[dict] = None, - predict_nmes: bool = False, - ) -> None: - self.relik_reader_core = relik_reader_core - self.dataset_conf = dataset_conf - self.predict_nmes = predict_nmes - - if self.dataset_conf is not None: - # instantiate dataset - self.dataset = hydra.utils.instantiate( - dataset_conf, - dataset_path=None, - samples=None, - ) - - def predict( - self, - path: Optional[str], - samples: Optional[Iterable[RelikReaderSample]], - dataset_conf: Optional[dict], - token_batch_size: int = 1024, - progress_bar: bool = False, - **kwargs, - ) -> List[RelikReaderSample]: - annotated_samples = list( - self._predict(path, samples, dataset_conf, token_batch_size, progress_bar) - ) - for sample in annotated_samples: - merge_patches_predictions(sample) - convert_tokens_to_char_annotations( - sample, remove_nmes=not self.predict_nmes - ) - return annotated_samples - - def _predict( - self, - path: Optional[str], - samples: Optional[Iterable[RelikReaderSample]], - dataset_conf: dict, - token_batch_size: int = 1024, - progress_bar: bool = False, - **kwargs, - ) -> Iterator[RelikReaderSample]: - assert ( - path is not None or samples is not None - ), "Either predict on a path or on an iterable of samples" - - samples = load_relik_reader_samples(path) if samples is None else samples - - # setup infrastructure to re-yield in order - def samples_it(): - for i, sample in enumerate(samples): - assert sample._mixin_prediction_position is None - sample._mixin_prediction_position = i - yield sample - - next_prediction_position = 0 - position2predicted_sample = {} - - # instantiate dataset - if getattr(self, "dataset", None) is not None: - dataset = self.dataset - dataset.samples = samples_it() - dataset.tokens_per_batch = token_batch_size - else: - dataset = hydra.utils.instantiate( - dataset_conf, - dataset_path=None, - samples=samples_it(), - tokens_per_batch=token_batch_size, - ) - - # instantiate dataloader - iterator = DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False) - if progress_bar: - iterator = tqdm(iterator, desc="Predicting") - - model_device = next(self.relik_reader_core.parameters()).device - - with torch.inference_mode(): - for batch in iterator: - # do batch predict - with torch.autocast( - "cpu" if model_device == torch.device("cpu") else "cuda" - ): - batch = move_data_to_device(batch, model_device) - batch_out = self.relik_reader_core.batch_predict(**batch) - # update prediction position - for sample in batch_out: - if sample._mixin_prediction_position >= next_prediction_position: - position2predicted_sample[ - sample._mixin_prediction_position - ] = sample - - # yield - while next_prediction_position in position2predicted_sample: - yield position2predicted_sample[next_prediction_position] - del position2predicted_sample[next_prediction_position] - next_prediction_position += 1 - - if len(position2predicted_sample) > 0: - logger.warning( - "It seems samples have been discarded in your dataset. " - "This means that you WON'T have a prediction for each input sample. 
" - "Prediction order will also be partially disrupted" - ) - for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]): - yield v - - if progress_bar: - iterator.close() diff --git a/relik/reader/relik_reader_re.py b/relik/reader/relik_reader_re.py deleted file mode 100644 index d1efaa87110863901c522b277d6e594989b21997..0000000000000000000000000000000000000000 --- a/relik/reader/relik_reader_re.py +++ /dev/null @@ -1,556 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Callable, Dict, Iterator, List, Optional, Union - -import numpy as np -import torch -import transformers as tr -from reader.data.relik_reader_data_utils import batchify, flatten -from reader.data.relik_reader_sample import RelikReaderSample -from reader.pytorch_modules.hf.modeling_relik import ( - RelikReaderConfig, - RelikReaderREModel, -) -from tqdm import tqdm -from transformers import AutoConfig - -from relik.common.log import get_console_logger, get_logger -from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class RelikReaderForTripletExtraction(torch.nn.Module): - def __init__( - self, - transformer_model: Optional[Union[str, tr.PreTrainedModel]] = None, - additional_special_symbols: Optional[int] = 0, - num_layers: Optional[int] = None, - activation: str = "gelu", - linears_hidden_size: Optional[int] = 512, - use_last_k_layers: int = 1, - training: bool = False, - device: Optional[Union[str, torch.device]] = None, - tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, - **kwargs, - ) -> None: - super().__init__() - - if isinstance(transformer_model, str): - config = AutoConfig.from_pretrained( - transformer_model, trust_remote_code=True - ) - if "relik_reader" in config.model_type: - transformer_model = RelikReaderREModel.from_pretrained( - transformer_model, **kwargs - ) - else: - reader_config = RelikReaderConfig( - transformer_model=transformer_model, - additional_special_symbols=additional_special_symbols, - num_layers=num_layers, - activation=activation, - linears_hidden_size=linears_hidden_size, - use_last_k_layers=use_last_k_layers, - training=training, - ) - transformer_model = RelikReaderREModel(reader_config) - - self.relik_reader_re_model = transformer_model - - self._tokenizer = tokenizer - - # move the model to the device - self.to(device or torch.device("cpu")) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor, - prediction_mask: Optional[torch.Tensor] = None, - special_symbols_mask: Optional[torch.Tensor] = None, - special_symbols_mask_entities: Optional[torch.Tensor] = None, - start_labels: Optional[torch.Tensor] = None, - end_labels: Optional[torch.Tensor] = None, - disambiguation_labels: Optional[torch.Tensor] = None, - relation_labels: Optional[torch.Tensor] = None, - is_validation: bool = False, - is_prediction: bool = False, - *args, - **kwargs, - ) -> Dict[str, Any]: - return self.relik_reader_re_model( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - special_symbols_mask_entities, - start_labels, - end_labels, - disambiguation_labels, - relation_labels, - is_validation, - is_prediction, - *args, - **kwargs, - ) - - def batch_predict( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - prediction_mask: Optional[torch.Tensor] = None, - 
special_symbols_mask: Optional[torch.Tensor] = None, - special_symbols_mask_entities: Optional[torch.Tensor] = None, - sample: Optional[List[RelikReaderSample]] = None, - *args, - **kwargs, - ) -> Iterator[RelikReaderSample]: - """ - Predicts the labels for a batch of samples. - Args: - input_ids: The input ids of the batch. - attention_mask: The attention mask of the batch. - token_type_ids: The token type ids of the batch. - prediction_mask: The prediction mask of the batch. - special_symbols_mask: The special symbols mask of the batch. - special_symbols_mask_entities: The special symbols mask entities of the batch. - sample: The samples of the batch. - Returns: - The predicted labels for each sample. - """ - forward_output = self.forward( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - special_symbols_mask_entities, - is_prediction=True, - ) - ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() - ned_end_predictions = forward_output["ned_end_predictions"] # .cpu().numpy() - ed_predictions = forward_output["re_entities_predictions"].cpu().numpy() - ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy() - re_predictions = forward_output["re_predictions"].cpu().numpy() - re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy() - if sample is None: - sample = [RelikReaderSample() for _ in range(len(input_ids))] - for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip( - sample, - ned_start_predictions, - ned_end_predictions, - re_predictions, - re_probabilities, - ed_predictions, - ned_type_predictions, - ): - ne_end = ne_end.cpu().numpy() - entities = [] - if self.relik_reader_re_model.entity_type_loss: - starts = np.argwhere(ne_st) - i = 0 - for start, end in zip(starts, ne_end): - ends = np.argwhere(end) - for e in ends: - entities.append([start[0], e[0], ne_et[i]]) - i += 1 - else: - starts = np.argwhere(ne_st) - for start, end in zip(starts, ne_end): - ends = np.argwhere(end) - for e in ends: - entities.append([start[0], e[0]]) - - edp = edp[: len(entities)] - re_pred = re_pred[: len(entities), : len(entities)] - re_prob = re_prob[: len(entities), : len(entities)] - possible_re = np.argwhere(re_pred) - predicted_triplets = [] - predicted_triplets_prob = [] - for i, j, r in possible_re: - if self.relik_reader_re_model.relation_disambiguation_loss: - if not ( - i != j - and edp[i, r] == 1 - and edp[j, r] == 1 - and edp[i, 0] == 0 - and edp[j, 0] == 0 - ): - continue - predicted_triplets.append([i, j, r]) - predicted_triplets_prob.append(re_prob[i, j, r]) - - ts._d["predicted_relations"] = predicted_triplets - ts._d["predicted_entities"] = entities - ts._d["predicted_relations_probabilities"] = predicted_triplets_prob - if ts.token2word: - self._convert_tokens_to_word_annotations(ts) - yield ts - - def _build_input(self, text: List[str], candidates: List[List[str]]) -> List[int]: - candidates_symbols = get_special_symbols_re(len(candidates)) - candidates = [ - [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL] - for cs, ct in zip(candidates_symbols, candidates) - ] - return ( - [self.tokenizer.cls_token] - + text - + [self.tokenizer.sep_token] - + flatten(candidates) - + [self.tokenizer.sep_token] - ) - - @staticmethod - def _compute_offsets(offsets_mapping): - offsets_mapping = offsets_mapping.numpy() - token2word = [] - word2token = {} - count = 0 - for i, offset in enumerate(offsets_mapping): - if offset[0] == 0: - token2word.append(i - count) - word2token[i - count] = [i] - 
else: - token2word.append(token2word[-1]) - word2token[token2word[-1]].append(i) - count += 1 - return token2word, word2token - - @staticmethod - def _convert_tokens_to_word_annotations(sample: RelikReaderSample): - triplets = [] - entities = [] - for entity in sample.predicted_entities: - if sample.entity_candidates: - entities.append( - ( - sample.token2word[entity[0] - 1], - sample.token2word[entity[1] - 1] + 1, - sample.entity_candidates[entity[2]], - ) - ) - else: - entities.append( - ( - sample.token2word[entity[0] - 1], - sample.token2word[entity[1] - 1] + 1, - -1, - ) - ) - for predicted_triplet, predicted_triplet_probabilities in zip( - sample.predicted_relations, sample.predicted_relations_probabilities - ): - subject, object_, relation = predicted_triplet - subject = entities[subject] - object_ = entities[object_] - relation = sample.candidates[relation] - triplets.append( - { - "subject": { - "start": subject[0], - "end": subject[1], - "type": subject[2], - "name": " ".join(sample.tokens[subject[0] : subject[1]]), - }, - "relation": { - "name": relation, - "probability": float(predicted_triplet_probabilities.round(2)), - }, - "object": { - "start": object_[0], - "end": object_[1], - "type": object_[2], - "name": " ".join(sample.tokens[object_[0] : object_[1]]), - }, - } - ) - sample.predicted_entities = entities - sample.predicted_relations = triplets - sample.predicted_relations_probabilities = None - - @torch.no_grad() - @torch.inference_mode() - def read( - self, - text: Optional[Union[List[str], List[List[str]]]] = None, - samples: Optional[List[RelikReaderSample]] = None, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - prediction_mask: Optional[torch.Tensor] = None, - special_symbols_mask: Optional[torch.Tensor] = None, - special_symbols_mask_entities: Optional[torch.Tensor] = None, - candidates: Optional[List[List[str]]] = None, - max_length: Optional[int] = 1024, - max_batch_size: Optional[int] = 64, - token_batch_size: Optional[int] = None, - progress_bar: bool = False, - *args, - **kwargs, - ) -> List[List[RelikReaderSample]]: - """ - Reads the given text. - Args: - text: The text to read in tokens. - input_ids: The input ids of the text. - attention_mask: The attention mask of the text. - token_type_ids: The token type ids of the text. - prediction_mask: The prediction mask of the text. - special_symbols_mask: The special symbols mask of the text. - special_symbols_mask_entities: The special symbols mask entities of the text. - candidates: The candidates of the text. - max_length: The maximum length of the text. - max_batch_size: The maximum batch size. - token_batch_size: The maximum number of tokens per batch. - Returns: - The predicted labels for each sample. - """ - if text is None and input_ids is None and samples is None: - raise ValueError( - "Either `text` or `input_ids` or `samples` must be provided." - ) - if (input_ids is None and samples is None) and ( - text is None or candidates is None - ): - raise ValueError( - "`text` and `candidates` must be provided to return the predictions when `input_ids` and `samples` is not provided." 
- ) - if text is not None and samples is None: - if len(text) != len(candidates): - raise ValueError("`text` and `candidates` must have the same length.") - if isinstance(text[0], str): # change to list of text - text = [text] - candidates = [candidates] - - samples = [ - RelikReaderSample(tokens=t, candidates=c) - for t, c in zip(text, candidates) - ] - - if samples is not None: - # function that creates a batch from the 'current_batch' list - def output_batch() -> Dict[str, Any]: - assert ( - len( - set( - [ - len(elem["predictable_candidates"]) - for elem in current_batch - ] - ) - ) - == 1 - ), " ".join( - map( - str, - [len(elem["predictable_candidates"]) for elem in current_batch], - ) - ) - - batch_dict = dict() - - de_values_by_field = { - fn: [de[fn] for de in current_batch if fn in de] - for fn in self.fields_batcher - } - - # in case you provide fields batchers but in the batch - # there are no elements for that field - de_values_by_field = { - fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 - } - - assert len(set([len(v) for v in de_values_by_field.values()])) - - # todo: maybe we should report the user about possible - # fields filtering due to "None" instances - de_values_by_field = { - fn: fvs - for fn, fvs in de_values_by_field.items() - if all([fv is not None for fv in fvs]) - } - - for field_name, field_values in de_values_by_field.items(): - field_batch = ( - self.fields_batcher[field_name]([fv[0] for fv in field_values]) - if self.fields_batcher[field_name] is not None - else field_values - ) - - batch_dict[field_name] = field_batch - - batch_dict = { - k: v.to(self.device) if isinstance(v, torch.Tensor) else v - for k, v in batch_dict.items() - } - return batch_dict - - current_batch = [] - predictions = [] - current_cand_len = -1 - - for sample in tqdm(samples, disable=not progress_bar): - sample.candidates = [NME_SYMBOL] + sample.candidates - inputs_text = self._build_input(sample.tokens, sample.candidates) - model_inputs = self.tokenizer( - inputs_text, - is_split_into_words=True, - add_special_tokens=False, - padding=False, - truncation=True, - max_length=max_length or self.tokenizer.model_max_length, - return_offsets_mapping=True, - return_tensors="pt", - ) - model_inputs["special_symbols_mask"] = ( - model_inputs["input_ids"] > self.tokenizer.vocab_size - ) - # prediction mask is 0 until the first special symbol - model_inputs["token_type_ids"] = ( - torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0 - ).long() - # shift prediction_mask to the left - model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll( - shifts=-1, dims=1 - ) - model_inputs["prediction_mask"][:, -1] = 1 - model_inputs["prediction_mask"][:, 0] = 1 - - assert ( - len(model_inputs["special_symbols_mask"]) - == len(model_inputs["prediction_mask"]) - == len(model_inputs["input_ids"]) - ) - - model_inputs["sample"] = sample - - # compute cand_len using special_symbols_mask - model_inputs["predictable_candidates"] = sample.candidates[ - : model_inputs["special_symbols_mask"].sum().item() - ] - # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]]) - offsets = model_inputs.pop("offset_mapping") - offsets = offsets[model_inputs["prediction_mask"] == 0] - sample.token2word, sample.word2token = self._compute_offsets(offsets) - future_max_len = max( - len(model_inputs["input_ids"]), - max([len(b["input_ids"]) for b in current_batch], default=0), - ) - future_tokens_per_batch = future_max_len * (len(current_batch) + 1) - - if 
len(current_batch) > 0 and ( - ( - len(model_inputs["predictable_candidates"]) != current_cand_len - and current_cand_len != -1 - ) - or ( - isinstance(token_batch_size, int) - and future_tokens_per_batch >= token_batch_size - ) - or len(current_batch) == max_batch_size - ): - batch_inputs = output_batch() - current_batch = [] - predictions.extend(list(self.batch_predict(**batch_inputs))) - current_cand_len = len(model_inputs["predictable_candidates"]) - current_batch.append(model_inputs) - - if current_batch: - batch_inputs = output_batch() - predictions.extend(list(self.batch_predict(**batch_inputs))) - else: - predictions = list( - self.batch_predict( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - special_symbols_mask_entities, - *args, - **kwargs, - ) - ) - return predictions - - @property - def device(self) -> torch.device: - """ - The device of the model. - """ - return next(self.parameters()).device - - @property - def tokenizer(self) -> tr.PreTrainedTokenizer: - """ - The tokenizer. - """ - if self._tokenizer: - return self._tokenizer - - self._tokenizer = tr.AutoTokenizer.from_pretrained( - self.relik_reader_re_model.config.name_or_path - ) - return self._tokenizer - - @property - def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: - fields_batchers = { - "input_ids": lambda x: batchify( - x, padding_value=self.tokenizer.pad_token_id - ), - "attention_mask": lambda x: batchify(x, padding_value=0), - "token_type_ids": lambda x: batchify(x, padding_value=0), - "prediction_mask": lambda x: batchify(x, padding_value=1), - "global_attention": lambda x: batchify(x, padding_value=0), - "token2word": None, - "sample": None, - "special_symbols_mask": lambda x: batchify(x, padding_value=False), - "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), - } - if "roberta" in self.relik_reader_re_model.config.model_type: - del fields_batchers["token_type_ids"] - - return fields_batchers - - def save_pretrained( - self, - output_dir: str, - model_name: Optional[str] = None, - push_to_hub: bool = False, - **kwargs, - ) -> None: - """ - Saves the model to the given path. - Args: - output_dir: The path to save the model to. - model_name: The name of the model. - push_to_hub: Whether to push the model to the hub. 
- """ - # create the output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - model_name = model_name or "relik_reader_for_triplet_extraction" - - logger.info(f"Saving reader to {output_dir / model_name}") - - # save the model - self.relik_reader_re_model.register_for_auto_class() - self.relik_reader_re_model.save_pretrained( - output_dir / model_name, push_to_hub=push_to_hub, **kwargs - ) - - logger.info("Saving reader to disk done.") - - if self.tokenizer: - self.tokenizer.save_pretrained( - output_dir / model_name, push_to_hub=push_to_hub, **kwargs - ) - logger.info("Saving tokenizer to disk done.") diff --git a/relik/reader/relik_reader_re_data.py b/relik/reader/relik_reader_re_data.py deleted file mode 100644 index 2bc43b78eeba4e63447d8bf333dacf2a993716fd..0000000000000000000000000000000000000000 --- a/relik/reader/relik_reader_re_data.py +++ /dev/null @@ -1,849 +0,0 @@ -import logging -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union, -) - -import numpy as np -import torch -from reader.data.relik_reader_data_utils import ( - add_noise_to_value, - batchify, - batchify_matrices, - batchify_tensor, - chunks, - flatten, -) -from reader.data.relik_reader_sample import RelikReaderSample, load_relik_reader_samples -from torch.utils.data import IterableDataset -from transformers import AutoTokenizer - -from relik.reader.utils.special_symbols import NME_SYMBOL - -logger = logging.getLogger(__name__) - - -class TokenizationOutput(NamedTuple): - input_ids: torch.Tensor - attention_mask: torch.Tensor - token_type_ids: torch.Tensor - prediction_mask: torch.Tensor - special_symbols_mask: torch.Tensor - special_symbols_mask_entities: torch.Tensor - - -class RelikREDataset(IterableDataset): - def __init__( - self, - dataset_path: str, - materialize_samples: bool, - transformer_model: str, - special_symbols: List[str], - shuffle_candidates: Optional[Union[bool, float]], - flip_candidates: Optional[Union[bool, float]], - relations_definitions: Union[str, Dict[str, str]], - for_inference: bool, - entities_definitions: Optional[Union[str, Dict[str, str]]] = None, - special_symbols_entities: Optional[List[str]] = None, - noise_param: float = 0.1, - sorting_fields: Optional[str] = None, - tokens_per_batch: int = 2048, - batch_size: int = None, - max_batch_size: int = 128, - section_size: int = 50_000, - prebatch: bool = True, - max_candidates: int = 0, - add_gold_candidates: bool = True, - use_nme: bool = True, - min_length: int = 5, - max_length: int = 2048, - model_max_length: int = 1000, - skip_empty_training_samples: bool = True, - drop_last: bool = False, - samples: Optional[Iterator[RelikReaderSample]] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.dataset_path = dataset_path - self.materialize_samples = materialize_samples - self.samples: Optional[List[RelikReaderSample]] = None - if self.materialize_samples: - self.samples = list() - - self.tokenizer = self._build_tokenizer(transformer_model, special_symbols) - self.special_symbols = special_symbols - self.special_symbols_entities = special_symbols_entities - self.shuffle_candidates = shuffle_candidates - self.flip_candidates = flip_candidates - self.for_inference = for_inference - self.noise_param = noise_param - self.batching_fields = ["input_ids"] - self.sorting_fields = ( - sorting_fields if sorting_fields is not None else self.batching_fields - ) - - # open relations definitions file if needed - if 
type(relations_definitions) == str: - relations_definitions = { - line.split("\t")[0]: line.split("\t")[1] - for line in open(relations_definitions) - } - self.max_candidates = max_candidates - self.relations_definitions = relations_definitions - self.entities_definitions = entities_definitions - - self.add_gold_candidates = add_gold_candidates - self.use_nme = use_nme - self.min_length = min_length - self.max_length = max_length - self.model_max_length = ( - model_max_length - if model_max_length < self.tokenizer.model_max_length - else self.tokenizer.model_max_length - ) - self.transformer_model = transformer_model - self.skip_empty_training_samples = skip_empty_training_samples - self.drop_last = drop_last - self.samples = samples - - self.tokens_per_batch = tokens_per_batch - self.batch_size = batch_size - self.max_batch_size = max_batch_size - self.section_size = section_size - self.prebatch = prebatch - - def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]): - return AutoTokenizer.from_pretrained( - transformer_model, - additional_special_tokens=[ss for ss in special_symbols], - add_prefix_space=True, - ) - - @property - def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: - fields_batchers = { - "input_ids": lambda x: batchify( - x, padding_value=self.tokenizer.pad_token_id - ), - "attention_mask": lambda x: batchify(x, padding_value=0), - "token_type_ids": lambda x: batchify(x, padding_value=0), - "prediction_mask": lambda x: batchify(x, padding_value=1), - "global_attention": lambda x: batchify(x, padding_value=0), - "token2word": None, - "sample": None, - "special_symbols_mask": lambda x: batchify(x, padding_value=False), - "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), - "start_labels": lambda x: batchify(x, padding_value=-100), - "end_labels": lambda x: batchify_matrices(x, padding_value=-100), - "disambiguation_labels": lambda x: batchify(x, padding_value=-100), - "relation_labels": lambda x: batchify_tensor(x, padding_value=-100), - "predictable_candidates": None, - } - if "roberta" in self.transformer_model: - del fields_batchers["token_type_ids"] - - return fields_batchers - - def _build_input_ids( - self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]] - ) -> List[int]: - return ( - [self.tokenizer.cls_token_id] - + sentence_input_ids - + [self.tokenizer.sep_token_id] - + flatten(candidates_input_ids) - + [self.tokenizer.sep_token_id] - ) - - def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor: - special_symbols_mask = input_ids >= ( - len(self.tokenizer) - - len(self.special_symbols + self.special_symbols_entities) - ) - special_symbols_mask[0] = True - return special_symbols_mask - - def _build_tokenizer_essentials( - self, input_ids, original_sequence - ) -> TokenizationOutput: - input_ids = torch.tensor(input_ids, dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - - total_sequence_len = len(input_ids) - predictable_sentence_len = len(original_sequence) - - # token type ids - token_type_ids = torch.cat( - [ - input_ids.new_zeros( - predictable_sentence_len + 2 - ), # original sentence bpes + CLS and SEP - input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2), - ] - ) - - # prediction mask -> boolean on tokens that are predictable - - prediction_mask = torch.tensor( - [1] - + ([0] * predictable_sentence_len) - + ([1] * (total_sequence_len - predictable_sentence_len - 1)) - ) - - assert len(prediction_mask) == 
len(input_ids) - - # special symbols mask - special_symbols_mask = input_ids >= ( - len(self.tokenizer) - - len(self.special_symbols) # + self.special_symbols_entities) - ) - if self.entities_definitions is not None: - # select only the first N true values where N is len(entities_definitions) - special_symbols_mask_entities = special_symbols_mask.clone() - special_symbols_mask_entities[ - special_symbols_mask_entities.cumsum(0) > len(self.entities_definitions) - ] = False - special_symbols_mask = special_symbols_mask ^ special_symbols_mask_entities - else: - special_symbols_mask_entities = special_symbols_mask.clone() - - return TokenizationOutput( - input_ids, - attention_mask, - token_type_ids, - prediction_mask, - special_symbols_mask, - special_symbols_mask_entities, - ) - - def _build_labels( - self, - sample, - tokenization_output: TokenizationOutput, - ) -> Tuple[torch.Tensor, torch.Tensor]: - start_labels = [0] * len(tokenization_output.input_ids) - end_labels = [] - - sample.entities.sort(key=lambda x: (x[0], x[1])) - - prev_start_bpe = -1 - num_repeat_start = 0 - if self.entities_definitions: - sample.entities = [(ce[0], ce[1], ce[2]) for ce in sample.entities] - sample.entity_candidates = list(self.entities_definitions.keys()) - disambiguation_labels = torch.zeros( - len(sample.entities), - len(sample.entity_candidates) + len(sample.candidates), - ) - else: - sample.entities = [(ce[0], ce[1], "") for ce in sample.entities] - disambiguation_labels = torch.zeros( - len(sample.entities), len(sample.candidates) - ) - ignored_labels_indices = tokenization_output.prediction_mask == 1 - for idx, c_ent in enumerate(sample.entities): - start_bpe = sample.word2token[c_ent[0]][0] + 1 - end_bpe = sample.word2token[c_ent[1] - 1][-1] + 1 - class_index = idx - start_labels[start_bpe] = class_index + 1 # +1 for the NONE class - if start_bpe != prev_start_bpe: - end_labels.append([0] * len(tokenization_output.input_ids)) - # end_labels[-1][:start_bpe] = [-100] * start_bpe - end_labels[-1][end_bpe] = class_index + 1 - else: - end_labels[-1][end_bpe] = class_index + 1 - num_repeat_start += 1 - if self.entities_definitions: - entity_type_idx = sample.entity_candidates.index(c_ent[2]) - disambiguation_labels[idx, entity_type_idx] = 1 - prev_start_bpe = start_bpe - - start_labels = torch.tensor(start_labels, dtype=torch.long) - start_labels[ignored_labels_indices] = -100 - - end_labels = torch.tensor(end_labels, dtype=torch.long) - end_labels[ignored_labels_indices.repeat(len(end_labels), 1)] = -100 - - relation_labels = torch.zeros( - len(sample.entities), len(sample.entities), len(sample.candidates) - ) - - # sample.relations = [] - for re in sample.triplets: - if re["relation"]["name"] not in sample.candidates: - re_class_index = len(sample.candidates) - 1 - else: - re_class_index = sample.candidates.index( - re["relation"]["name"] - ) # should remove this +1 - if self.entities_definitions: - subject_class_index = sample.entities.index( - ( - re["subject"]["start"], - re["subject"]["end"], - re["subject"]["type"], - ) - ) - object_class_index = sample.entities.index( - (re["object"]["start"], re["object"]["end"], re["object"]["type"]) - ) - else: - subject_class_index = sample.entities.index( - (re["subject"]["start"], re["subject"]["end"], "") - ) - object_class_index = sample.entities.index( - (re["object"]["start"], re["object"]["end"], "") - ) - - relation_labels[subject_class_index, object_class_index, re_class_index] = 1 - - if self.entities_definitions: - disambiguation_labels[ - 
subject_class_index, re_class_index + len(sample.entity_candidates) - ] = 1 - disambiguation_labels[ - object_class_index, re_class_index + len(sample.entity_candidates) - ] = 1 - # sample.relations.append([re['subject']['start'], re['subject']['end'], re['subject']['type'], re['relation']['name'], re['object']['start'], re['object']['end'], re['object']['type']]) - else: - disambiguation_labels[subject_class_index, re_class_index] = 1 - disambiguation_labels[object_class_index, re_class_index] = 1 - # sample.relations.append([re['subject']['start'], re['subject']['end'], "", re['relation']['name'], re['object']['start'], re['object']['end'], ""]) - return start_labels, end_labels, disambiguation_labels, relation_labels - - def __iter__(self): - dataset_iterator = self.dataset_iterator_func() - current_dataset_elements = [] - i = None - for i, dataset_elem in enumerate(dataset_iterator, start=1): - if ( - self.section_size is not None - and len(current_dataset_elements) == self.section_size - ): - for batch in self.materialize_batches(current_dataset_elements): - yield batch - current_dataset_elements = [] - current_dataset_elements.append(dataset_elem) - if i % 50_000 == 0: - logger.info(f"Processed: {i} number of elements") - if len(current_dataset_elements) != 0: - for batch in self.materialize_batches(current_dataset_elements): - yield batch - if i is not None: - logger.info(f"Dataset finished: {i} number of elements processed") - else: - logger.warning("Dataset empty") - - def dataset_iterator_func(self): - data_samples = ( - load_relik_reader_samples(self.dataset_path) - if self.samples is None - else self.samples - ) - for sample in data_samples: - # input sentence tokenization - input_tokenized = self.tokenizer( - sample.tokens, - return_offsets_mapping=True, - add_special_tokens=False, - is_split_into_words=True, - ) - input_subwords = input_tokenized["input_ids"] - offsets = input_tokenized["offset_mapping"] - token2word = [] - word2token = {} - count = 0 - for i, offset in enumerate(offsets): - if offset[0] == 0: - token2word.append(i - count) - word2token[i - count] = [i] - else: - token2word.append(token2word[-1]) - word2token[token2word[-1]].append(i) - count += 1 - sample.token2word = token2word - sample.word2token = word2token - # input_subwords = sample.tokens[1:-1] # removing special tokens - candidates_symbols = self.special_symbols - - if self.max_candidates > 0: - # truncate candidates - sample.candidates = sample.candidates[: self.max_candidates] - - # add NME as a possible candidate - if self.use_nme: - sample.candidates.insert(0, NME_SYMBOL) - - # training time sample mods - if not self.for_inference: - # check whether the sample has labels if not skip - if ( - sample.triplets is None or len(sample.triplets) == 0 - ) and self.skip_empty_training_samples: - logger.warning( - "Sample {} has no labels, skipping".format(sample.sample_id) - ) - continue - - # add gold candidates if missing - if self.add_gold_candidates: - candidates_set = set(sample.candidates) - candidates_to_add = [] - for candidate_title in sample.triplets: - if candidate_title["relation"]["name"] not in candidates_set: - candidates_to_add.append( - candidate_title["relation"]["name"] - ) - if len(candidates_to_add) > 0: - # replacing last candidates with the gold ones - # this is done in order to preserve the ordering - added_gold_candidates = 0 - gold_candidates_titles_set = set( - set(ct["relation"]["name"] for ct in sample.triplets) - ) - for i in reversed(range(len(sample.candidates))): - if ( - 
sample.candidates[i] not in gold_candidates_titles_set - and sample.candidates[i] != NME_SYMBOL - ): - sample.candidates[i] = candidates_to_add[ - added_gold_candidates - ] - added_gold_candidates += 1 - if len(candidates_to_add) == added_gold_candidates: - break - - candidates_still_to_add = ( - len(candidates_to_add) - added_gold_candidates - ) - while ( - len(sample.candidates) <= len(candidates_symbols) - and candidates_still_to_add != 0 - ): - sample.candidates.append( - candidates_to_add[added_gold_candidates] - ) - added_gold_candidates += 1 - candidates_still_to_add -= 1 - - # shuffle candidates - if ( - isinstance(self.shuffle_candidates, bool) - and self.shuffle_candidates - ) or ( - isinstance(self.shuffle_candidates, float) - and np.random.uniform() < self.shuffle_candidates - ): - np.random.shuffle(sample.candidates) - if NME_SYMBOL in sample.candidates: - sample.candidates.remove(NME_SYMBOL) - sample.candidates.insert(0, NME_SYMBOL) - - # flip candidates - if ( - isinstance(self.flip_candidates, bool) and self.flip_candidates - ) or ( - isinstance(self.flip_candidates, float) - and np.random.uniform() < self.flip_candidates - ): - for i in range(len(sample.candidates) - 1): - if np.random.uniform() < 0.5: - sample.candidates[i], sample.candidates[i + 1] = ( - sample.candidates[i + 1], - sample.candidates[i], - ) - if NME_SYMBOL in sample.candidates: - sample.candidates.remove(NME_SYMBOL) - sample.candidates.insert(0, NME_SYMBOL) - - # candidates encoding - candidates_symbols = candidates_symbols[: len(sample.candidates)] - relations_defs = [ - "{} {}".format(cs, self.relations_definitions[ct]) - if ct != NME_SYMBOL - else NME_SYMBOL - for cs, ct in zip(candidates_symbols, sample.candidates) - ] - if self.entities_definitions is not None: - candidates_entities_symbols = list(self.special_symbols_entities) - candidates_entities_symbols = candidates_entities_symbols[ - : len(self.entities_definitions) - ] - entity_defs = [ - "{} {}".format(cs, self.entities_definitions[ct]) - for cs, ct in zip( - candidates_entities_symbols, self.entities_definitions.keys() - ) - ] - relations_defs = ( - entity_defs + [self.tokenizer.sep_token] + relations_defs - ) - - candidates_encoding_result = self.tokenizer.batch_encode_plus( - relations_defs, - add_special_tokens=False, - ).input_ids - - # drop candidates if the number of input tokens is too long for the model - if ( - sum(map(len, candidates_encoding_result)) - + len(input_subwords) - + 20 # + 20 special tokens - > self.model_max_length - ): - if self.for_inference: - acceptable_tokens_from_candidates = ( - self.model_max_length - 20 - len(input_subwords) - ) - while ( - cum_len + len(candidates_encoding_result[i]) - < acceptable_tokens_from_candidates - ): - cum_len += len(candidates_encoding_result[i]) - i += 1 - - candidates_encoding_result = candidates_encoding_result[:i] - if self.entities_definitions is not None: - candidates_symbols = candidates_symbols[ - : i - len(self.entities_definitions) - ] - sample.candidates = sample.candidates[ - : i - len(self.entities_definitions) - ] - else: - candidates_symbols = candidates_symbols[:i] - sample.candidates = sample.candidates[:i] - - else: - gold_candidates_set = set( - [wl["relation"]["name"] for wl in sample.triplets] - ) - gold_candidates_indices = [ - i - for i, wc in enumerate(sample.candidates) - if wc in gold_candidates_set - ] - if self.entities_definitions is not None: - gold_candidates_indices = [ - i + len(self.entities_definitions) - for i in gold_candidates_indices - ] - # 
add entities indices - gold_candidates_indices = gold_candidates_indices + list( - range(len(self.entities_definitions)) - ) - necessary_taken_tokens = sum( - map( - len, - [ - candidates_encoding_result[i] - for i in gold_candidates_indices - ], - ) - ) - - acceptable_tokens_from_candidates = ( - self.model_max_length - - 20 - - len(input_subwords) - - necessary_taken_tokens - ) - - assert acceptable_tokens_from_candidates > 0 - - i = 0 - cum_len = 0 - while ( - cum_len + len(candidates_encoding_result[i]) - < acceptable_tokens_from_candidates - ): - if i not in gold_candidates_indices: - cum_len += len(candidates_encoding_result[i]) - i += 1 - - new_indices = sorted( - list(set(list(range(i)) + gold_candidates_indices)) - ) - np.random.shuffle(new_indices) - - candidates_encoding_result = [ - candidates_encoding_result[i] for i in new_indices - ] - if self.entities_definitions is not None: - sample.candidates = [ - sample.candidates[i - len(self.entities_definitions)] - for i in new_indices - ] - candidates_symbols = candidates_symbols[ - : i - len(self.entities_definitions) - ] - else: - candidates_symbols = [ - candidates_symbols[i] for i in new_indices - ] - sample.window_candidates = [ - sample.window_candidates[i] for i in new_indices - ] - if len(sample.candidates) == 0: - logger.warning( - "Sample {} has no candidates after truncation due to max length".format( - sample.sample_id - ) - ) - continue - - # final input_ids build - input_ids = self._build_input_ids( - sentence_input_ids=input_subwords, - candidates_input_ids=candidates_encoding_result, - ) - - # complete input building (e.g. attention / prediction mask) - tokenization_output = self._build_tokenizer_essentials( - input_ids, input_subwords - ) - - # labels creation - start_labels, end_labels, disambiguation_labels, relation_labels = ( - None, - None, - None, - None, - ) - if sample.entities is not None and len(sample.entities) > 0: - ( - start_labels, - end_labels, - disambiguation_labels, - relation_labels, - ) = self._build_labels( - sample, - tokenization_output, - ) - - yield { - "input_ids": tokenization_output.input_ids, - "attention_mask": tokenization_output.attention_mask, - "token_type_ids": tokenization_output.token_type_ids, - "prediction_mask": tokenization_output.prediction_mask, - "special_symbols_mask": tokenization_output.special_symbols_mask, - "special_symbols_mask_entities": tokenization_output.special_symbols_mask_entities, - "sample": sample, - "start_labels": start_labels, - "end_labels": end_labels, - "disambiguation_labels": disambiguation_labels, - "relation_labels": relation_labels, - "predictable_candidates": candidates_symbols, - } - - def preshuffle_elements(self, dataset_elements: List): - # This shuffling is done so that when using the sorting function, - # if it is deterministic given a collection and its order, we will - # make the whole operation not deterministic anymore. - # Basically, the aim is not to build every time the same batches. 
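# A minimal, self-contained sketch of the batching idea described in the comment above
# (add noise to a length-based sort key, then shuffle fixed-size chunks, so batch
# composition changes across epochs while similar-length elements stay adjacent).
# The names `elements`, `noise_param` and `chunk_size` are illustrative only,
# not identifiers from this module:
import numpy as np

def noisy_length_sort(elements, noise_param=0.1, chunk_size=64):
    # jitter each element's token count so the resulting order differs between epochs
    keyed = sorted(
        elements,
        key=lambda e: len(e["input_ids"]) * (1.0 + np.random.uniform(-noise_param, noise_param)),
    )
    # shuffle contiguous chunks: similar lengths remain grouped, batch order varies
    chunked = [keyed[i : i + chunk_size] for i in range(0, len(keyed), chunk_size)]
    np.random.shuffle(chunked)
    return [elem for chunk in chunked for elem in chunk]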
- if not self.for_inference: - dataset_elements = np.random.permutation(dataset_elements) - - sorting_fn = ( - lambda elem: add_noise_to_value( - sum(len(elem[k]) for k in self.sorting_fields), - noise_param=self.noise_param, - ) - if not self.for_inference - else sum(len(elem[k]) for k in self.sorting_fields) - ) - - dataset_elements = sorted(dataset_elements, key=sorting_fn) - - if self.for_inference: - return dataset_elements - - ds = list(chunks(dataset_elements, 64)) # todo: modified - np.random.shuffle(ds) - return flatten(ds) - - def materialize_batches( - self, dataset_elements: List[Dict[str, Any]] - ) -> Generator[Dict[str, Any], None, None]: - if self.prebatch: - dataset_elements = self.preshuffle_elements(dataset_elements) - - current_batch = [] - - # function that creates a batch from the 'current_batch' list - def output_batch() -> Dict[str, Any]: - assert ( - len( - set([len(elem["predictable_candidates"]) for elem in current_batch]) - ) - == 1 - ), " ".join( - map( - str, [len(elem["predictable_candidates"]) for elem in current_batch] - ) - ) - - batch_dict = dict() - - de_values_by_field = { - fn: [de[fn] for de in current_batch if fn in de] - for fn in self.fields_batcher - } - - # in case you provide fields batchers but in the batch - # there are no elements for that field - de_values_by_field = { - fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 - } - - assert len(set([len(v) for v in de_values_by_field.values()])) - - # todo: maybe we should report the user about possible - # fields filtering due to "None" instances - de_values_by_field = { - fn: fvs - for fn, fvs in de_values_by_field.items() - if all([fv is not None for fv in fvs]) - } - - for field_name, field_values in de_values_by_field.items(): - field_batch = ( - self.fields_batcher[field_name](field_values) - if self.fields_batcher[field_name] is not None - else field_values - ) - - batch_dict[field_name] = field_batch - - return batch_dict - - max_len_discards, min_len_discards = 0, 0 - - should_token_batch = self.batch_size is None - - curr_pred_elements = -1 - for de in dataset_elements: - if ( - should_token_batch - and self.max_batch_size != -1 - and len(current_batch) == self.max_batch_size - ) or (not should_token_batch and len(current_batch) == self.batch_size): - yield output_batch() - current_batch = [] - curr_pred_elements = -1 - - # todo support max length (and min length) as dicts - - too_long_fields = [ - k - for k in de - if self.max_length != -1 - and torch.is_tensor(de[k]) - and len(de[k]) > self.max_length - ] - if len(too_long_fields) > 0: - max_len_discards += 1 - continue - - too_short_fields = [ - k - for k in de - if self.min_length != -1 - and torch.is_tensor(de[k]) - and len(de[k]) < self.min_length - ] - if len(too_short_fields) > 0: - min_len_discards += 1 - continue - - if should_token_batch: - de_len = sum(len(de[k]) for k in self.batching_fields) - - future_max_len = max( - de_len, - max( - [ - sum(len(bde[k]) for k in self.batching_fields) - for bde in current_batch - ], - default=0, - ), - ) - - future_tokens_per_batch = future_max_len * (len(current_batch) + 1) - - num_predictable_candidates = len(de["predictable_candidates"]) - - if len(current_batch) > 0 and ( - future_tokens_per_batch >= self.tokens_per_batch - or ( - num_predictable_candidates != curr_pred_elements - and curr_pred_elements != -1 - ) - ): - yield output_batch() - current_batch = [] - - current_batch.append(de) - curr_pred_elements = len(de["predictable_candidates"]) - - if len(current_batch) 
!= 0 and not self.drop_last: - yield output_batch() - - if max_len_discards > 0: - if self.for_inference: - logger.warning( - f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were " - f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" - f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the " - f"sample length exceeds the maximum length supported by the current model." - ) - else: - logger.warning( - f"During iteration, {max_len_discards} elements were " - f"discarded since longer than max length {self.max_length}" - ) - - if min_len_discards > 0: - if self.for_inference: - logger.warning( - f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were " - f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" - f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the " - f"sample length is shorter than the minimum length supported by the current model." - ) - else: - logger.warning( - f"During iteration, {min_len_discards} elements were " - f"discarded since shorter than min length {self.min_length}" - ) - - -def main(): - special_symbols = [NME_SYMBOL] + [f"R-{i}" for i in range(50)] - - relik_dataset = RelikREDataset( - "/home/huguetcabot/alby-re/alby/data/nyt-alby+/valid.jsonl", - materialize_samples=False, - transformer_model="microsoft/deberta-v3-base", - special_symbols=special_symbols, - shuffle_candidates=False, - flip_candidates=False, - for_inference=True, - ) - - for batch in relik_dataset: - print(batch) - exit(0) - - -if __name__ == "__main__": - main() diff --git a/relik/reader/trainer/__init__.py b/relik/reader/trainer/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/reader/trainer/predict.py b/relik/reader/trainer/predict.py deleted file mode 100644 index 3801bef958f9a092f8d094d2e99fe476fd4caed9..0000000000000000000000000000000000000000 --- a/relik/reader/trainer/predict.py +++ /dev/null @@ -1,57 +0,0 @@ -import argparse -from pprint import pprint -from typing import Optional - -from relik.reader.relik_reader import RelikReader -from relik.reader.utils.strong_matching_eval import StrongMatching - - -def predict( - model_path: str, - dataset_path: str, - token_batch_size: int, - is_eval: bool, - output_path: Optional[str], -) -> None: - relik_reader = RelikReader(model_path) - predicted_samples = relik_reader.link_entities( - dataset_path, token_batch_size=token_batch_size - ) - if is_eval: - eval_dict = StrongMatching()(predicted_samples) - pprint(eval_dict) - if output_path is not None: - with open(output_path, "w") as f: - for sample in predicted_samples: - f.write(sample.to_jsons() + "\n") - - -def parse_arg() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument( - "--model-path", - required=True, - ) - parser.add_argument("--dataset-path", "-i", required=True) - parser.add_argument("--is-eval", action="store_true") - parser.add_argument( - "--output-path", - "-o", - ) - parser.add_argument("--token-batch-size", default=4096) - return parser.parse_args() - - -def main(): - args = parse_arg() - predict( - args.model_path, - args.dataset_path, - token_batch_size=args.token_batch_size, - is_eval=args.is_eval, - output_path=args.output_path, - ) - - -if __name__ == "__main__": - main() diff --git 
a/relik/reader/trainer/predict_re.py b/relik/reader/trainer/predict_re.py deleted file mode 100644 index 2f21b4b15f857ec561e14f2e8c9fb50ea4f00637..0000000000000000000000000000000000000000 --- a/relik/reader/trainer/predict_re.py +++ /dev/null @@ -1,125 +0,0 @@ -import argparse - -import torch -from reader.data.relik_reader_sample import load_relik_reader_samples - -from relik.reader.pytorch_modules.hf.modeling_relik import ( - RelikReaderConfig, - RelikReaderREModel, -) -from relik.reader.relik_reader_re import RelikReaderForTripletExtraction -from relik.reader.utils.relation_matching_eval import StrongMatching - -dict_nyt = { - "/people/person/nationality": "nationality", - "/sports/sports_team/location": "sports team location", - "/location/country/administrative_divisions": "administrative divisions", - "/business/company/major_shareholders": "shareholders", - "/people/ethnicity/people": "ethnicity", - "/people/ethnicity/geographic_distribution": "geographic distribution", - "/business/company_shareholder/major_shareholder_of": "major shareholder", - "/location/location/contains": "location", - "/business/company/founders": "founders", - "/business/person/company": "company", - "/business/company/advisors": "advisor", - "/people/deceased_person/place_of_death": "place of death", - "/business/company/industry": "industry", - "/people/person/ethnicity": "ethnic background", - "/people/person/place_of_birth": "place of birth", - "/location/administrative_division/country": "country of an administrative division", - "/people/person/place_lived": "place lived", - "/sports/sports_team_location/teams": "sports team", - "/people/person/children": "child", - "/people/person/religion": "religion", - "/location/neighborhood/neighborhood_of": "neighborhood", - "/location/country/capital": "capital", - "/business/company/place_founded": "company founded location", - "/people/person/profession": "occupation", -} - - -def eval(model_path, data_path, is_eval, output_path=None): - if model_path.endswith(".ckpt"): - # if it is a lightning checkpoint we load the model state dict and the tokenizer from the config - model_dict = torch.load(model_path) - - additional_special_symbols = model_dict["hyper_parameters"][ - "additional_special_symbols" - ] - from transformers import AutoTokenizer - - from relik.reader.utils.special_symbols import get_special_symbols_re - - special_symbols = get_special_symbols_re(additional_special_symbols - 1) - tokenizer = AutoTokenizer.from_pretrained( - model_dict["hyper_parameters"]["transformer_model"], - additional_special_tokens=special_symbols, - add_prefix_space=True, - ) - config_model = RelikReaderConfig( - model_dict["hyper_parameters"]["transformer_model"], - len(special_symbols), - training=False, - ) - model = RelikReaderREModel(config_model) - model_dict["state_dict"] = { - k.replace("relik_reader_re_model.", ""): v - for k, v in model_dict["state_dict"].items() - } - model.load_state_dict(model_dict["state_dict"], strict=False) - reader = RelikReaderForTripletExtraction( - model, training=False, device="cuda", tokenizer=tokenizer - ) - else: - # if it is a huggingface model we load the model directly. 
Note that it could even be a string from the hub - model = RelikReaderREModel.from_pretrained(model_path) - reader = RelikReaderForTripletExtraction(model, training=False, device="cuda") - - samples = list(load_relik_reader_samples(data_path)) - - for sample in samples: - sample.candidates = [dict_nyt[cand] for cand in sample.candidates] - sample.triplets = [ - { - "subject": triplet["subject"], - "relation": { - "name": dict_nyt[triplet["relation"]["name"]], - "type": triplet["relation"]["type"], - }, - "object": triplet["object"], - } - for triplet in sample.triplets - ] - - predicted_samples = reader.read(samples=samples, progress_bar=True) - if is_eval: - strong_matching_metric = StrongMatching() - predicted_samples = list(predicted_samples) - for k, v in strong_matching_metric(predicted_samples).items(): - print(f"test_{k}", v) - if output_path is not None: - with open(output_path, "w") as f: - for sample in predicted_samples: - f.write(sample.to_jsons() + "\n") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_path", - type=str, - default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base", - ) - parser.add_argument( - "--data_path", - type=str, - default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl", - ) - parser.add_argument("--is-eval", action="store_true") - parser.add_argument("--output_path", type=str, default=None) - args = parser.parse_args() - eval(args.model_path, args.data_path, args.is_eval, args.output_path) - - -if __name__ == "__main__": - main() diff --git a/relik/reader/trainer/train.py b/relik/reader/trainer/train.py deleted file mode 100644 index f1983b38c02199f45c112247fc74bd09d3f1e4f0..0000000000000000000000000000000000000000 --- a/relik/reader/trainer/train.py +++ /dev/null @@ -1,98 +0,0 @@ -import hydra -import lightning -from hydra.utils import to_absolute_path -from lightning import Trainer -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers.wandb import WandbLogger -from omegaconf import DictConfig, OmegaConf, open_dict -from reader.data.relik_reader_data import RelikDataset -from reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule -from reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer -from torch.utils.data import DataLoader - -from relik.reader.utils.special_symbols import get_special_symbols -from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback - - -@hydra.main(config_path="../conf", config_name="config") -def train(cfg: DictConfig) -> None: - lightning.seed_everything(cfg.training.seed) - - special_symbols = get_special_symbols(cfg.model.entities_per_forward) - - # model declaration - model = RelikReaderPLModule( - cfg=OmegaConf.to_container(cfg), - transformer_model=cfg.model.model.transformer_model, - additional_special_symbols=len(special_symbols), - training=True, - ) - - # optimizer declaration - opt_conf = cfg.model.optimizer - electra_optimizer_factory = LayerWiseLRDecayOptimizer( - lr=opt_conf.lr, - warmup_steps=opt_conf.warmup_steps, - total_steps=opt_conf.total_steps, - total_reset=opt_conf.total_reset, - no_decay_params=opt_conf.no_decay_params, - weight_decay=opt_conf.weight_decay, - lr_decay=opt_conf.lr_decay, - ) - - model.set_optimizer_factory(electra_optimizer_factory) - - # datasets declaration - train_dataset: RelikDataset = hydra.utils.instantiate( - cfg.data.train_dataset, - dataset_path=to_absolute_path(cfg.data.train_dataset_path), - 
special_symbols=special_symbols, - ) - - # update of validation dataset config with special_symbols since they - # are required even from the EvaluationCallback dataset_config - with open_dict(cfg): - cfg.data.val_dataset.special_symbols = special_symbols - - val_dataset: RelikDataset = hydra.utils.instantiate( - cfg.data.val_dataset, - dataset_path=to_absolute_path(cfg.data.val_dataset_path), - ) - - # callbacks declaration - callbacks = [ - ELStrongMatchingCallback( - to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset - ), - ModelCheckpoint( - "model", - filename="{epoch}-{val_core_f1:.2f}", - monitor="val_core_f1", - mode="max", - ), - LearningRateMonitor(), - ] - - wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name) - - # trainer declaration - trainer: Trainer = hydra.utils.instantiate( - cfg.training.trainer, - callbacks=callbacks, - logger=wandb_logger, - ) - - # Trainer fit - trainer.fit( - model=model, - train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), - val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), - ) - - -def main(): - train() - - -if __name__ == "__main__": - main() diff --git a/relik/reader/trainer/train_re.py b/relik/reader/trainer/train_re.py deleted file mode 100644 index 550b3fc95d653bfc9503af635931aa72176ca89d..0000000000000000000000000000000000000000 --- a/relik/reader/trainer/train_re.py +++ /dev/null @@ -1,109 +0,0 @@ -import hydra -import lightning -from hydra.utils import to_absolute_path -from lightning import Trainer -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers.wandb import WandbLogger -from omegaconf import DictConfig, OmegaConf, open_dict -from reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer -from torch.utils.data import DataLoader - -from relik.reader.lightning_modules.relik_reader_re_pl_module import ( - RelikReaderREPLModule, -) -from relik.reader.relik_reader_re_data import RelikREDataset -from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback -from relik.reader.utils.special_symbols import get_special_symbols_re - - -@hydra.main(config_path="conf", config_name="config") -def train(cfg: DictConfig) -> None: - lightning.seed_everything(cfg.training.seed) - - special_symbols = get_special_symbols_re(cfg.model.entities_per_forward) - - # datasets declaration - train_dataset: RelikREDataset = hydra.utils.instantiate( - cfg.data.train_dataset, - dataset_path=to_absolute_path(cfg.data.train_dataset_path), - special_symbols=special_symbols, - ) - - # update of validation dataset config with special_symbols since they - # are required even from the EvaluationCallback dataset_config - with open_dict(cfg): - cfg.data.val_dataset.special_symbols = special_symbols - - val_dataset: RelikREDataset = hydra.utils.instantiate( - cfg.data.val_dataset, - dataset_path=to_absolute_path(cfg.data.val_dataset_path), - ) - - # model declaration - model = RelikReaderREPLModule( - cfg=OmegaConf.to_container(cfg), - transformer_model=cfg.model.model.transformer_model, - additional_special_symbols=len(special_symbols), - training=True, - ) - model.relik_reader_re_model._tokenizer = train_dataset.tokenizer - # optimizer declaration - opt_conf = cfg.model.optimizer - - # adamw_optimizer_factory = AdamWWithWarmupOptimizer( - # lr=opt_conf.lr, - # warmup_steps=opt_conf.warmup_steps, - # total_steps=opt_conf.total_steps, - # no_decay_params=opt_conf.no_decay_params, - # weight_decay=opt_conf.weight_decay, 
- # ) - - electra_optimizer_factory = LayerWiseLRDecayOptimizer( - lr=opt_conf.lr, - warmup_steps=opt_conf.warmup_steps, - total_steps=opt_conf.total_steps, - total_reset=opt_conf.total_reset, - no_decay_params=opt_conf.no_decay_params, - weight_decay=opt_conf.weight_decay, - lr_decay=opt_conf.lr_decay, - ) - - model.set_optimizer_factory(electra_optimizer_factory) - - # callbacks declaration - callbacks = [ - REStrongMatchingCallback( - to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset - ), - ModelCheckpoint( - "model", - filename="{epoch}-{val_f1:.2f}", - monitor="val_f1", - mode="max", - ), - LearningRateMonitor(), - ] - - wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name) - - # trainer declaration - trainer: Trainer = hydra.utils.instantiate( - cfg.training.trainer, - callbacks=callbacks, - logger=wandb_logger, - ) - - # Trainer fit - trainer.fit( - model=model, - train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), - val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), - ) - - -def main(): - train() - - -if __name__ == "__main__": - main() diff --git a/relik/reader/utils/__init__.py b/relik/reader/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/reader/utils/metrics.py b/relik/reader/utils/metrics.py deleted file mode 100644 index fa17bf5d23cc888d6da0c6f40cf7bd3c20d77a66..0000000000000000000000000000000000000000 --- a/relik/reader/utils/metrics.py +++ /dev/null @@ -1,18 +0,0 @@ -def safe_divide(num: float, den: float) -> float: - if den == 0: - return 0 - else: - return num / den - - -def f1_measure(precision: float, recall: float) -> float: - if precision == 0 or recall == 0: - return 0.0 - return safe_divide(2 * precision * recall, (precision + recall)) - - -def compute_metrics(total_correct, total_preds, total_gold): - precision = safe_divide(total_correct, total_preds) - recall = safe_divide(total_correct, total_gold) - f1 = f1_measure(precision, recall) - return precision, recall, f1 diff --git a/relik/reader/utils/relation_matching_eval.py b/relik/reader/utils/relation_matching_eval.py deleted file mode 100644 index 94a6b1e7a8dc155ed1ab9f6c52cb3c1eebd44505..0000000000000000000000000000000000000000 --- a/relik/reader/utils/relation_matching_eval.py +++ /dev/null @@ -1,172 +0,0 @@ -from typing import Dict, List - -from lightning.pytorch.callbacks import Callback -from reader.data.relik_reader_sample import RelikReaderSample - -from relik.reader.relik_reader_predictor import RelikReaderPredictor -from relik.reader.utils.metrics import compute_metrics - - -class StrongMatching: - def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: - # accumulators - correct_predictions, total_predictions, total_gold = ( - 0, - 0, - 0, - ) - correct_predictions_strict, total_predictions_strict = ( - 0, - 0, - ) - correct_predictions_bound, total_predictions_bound = ( - 0, - 0, - ) - correct_span_predictions, total_span_predictions, total_gold_spans = 0, 0, 0 - - # collect data from samples - for sample in predicted_samples: - if sample.triplets is None: - sample.triplets = [] - - if sample.entity_candidates: - predicted_annotations_strict = set( - [ - ( - triplet["subject"]["start"], - triplet["subject"]["end"], - triplet["subject"]["type"], - triplet["relation"]["name"], - triplet["object"]["start"], - triplet["object"]["end"], - triplet["object"]["type"], - ) - for triplet in sample.predicted_relations 
- ] - ) - gold_annotations_strict = set( - [ - ( - triplet["subject"]["start"], - triplet["subject"]["end"], - triplet["subject"]["type"], - triplet["relation"]["name"], - triplet["object"]["start"], - triplet["object"]["end"], - triplet["object"]["type"], - ) - for triplet in sample.triplets - ] - ) - predicted_spans_strict = set(sample.predicted_entities) - gold_spans_strict = set(sample.entities) - # strict - correct_span_predictions += len( - predicted_spans_strict.intersection(gold_spans_strict) - ) - total_span_predictions += len(predicted_spans_strict) - total_gold_spans += len(gold_spans_strict) - correct_predictions_strict += len( - predicted_annotations_strict.intersection(gold_annotations_strict) - ) - total_predictions_strict += len(predicted_annotations_strict) - - predicted_annotations = set( - [ - ( - triplet["subject"]["start"], - triplet["subject"]["end"], - -1, - triplet["relation"]["name"], - triplet["object"]["start"], - triplet["object"]["end"], - -1, - ) - for triplet in sample.predicted_relations - ] - ) - gold_annotations = set( - [ - ( - triplet["subject"]["start"], - triplet["subject"]["end"], - -1, - triplet["relation"]["name"], - triplet["object"]["start"], - triplet["object"]["end"], - -1, - ) - for triplet in sample.triplets - ] - ) - predicted_spans = set( - [(ss, se) for (ss, se, _) in sample.predicted_entities] - ) - gold_spans = set([(ss, se) for (ss, se, _) in sample.entities]) - total_gold_spans += len(gold_spans) - - correct_predictions_bound += len(predicted_spans.intersection(gold_spans)) - total_predictions_bound += len(predicted_spans) - - total_predictions += len(predicted_annotations) - total_gold += len(gold_annotations) - # correct relation extraction - correct_predictions += len( - predicted_annotations.intersection(gold_annotations) - ) - - span_precision, span_recall, span_f1 = compute_metrics( - correct_span_predictions, total_span_predictions, total_gold_spans - ) - bound_precision, bound_recall, bound_f1 = compute_metrics( - correct_predictions_bound, total_predictions_bound, total_gold_spans - ) - - precision, recall, f1 = compute_metrics( - correct_predictions, total_predictions, total_gold - ) - - if sample.entity_candidates: - precision_strict, recall_strict, f1_strict = compute_metrics( - correct_predictions_strict, total_predictions_strict, total_gold - ) - return { - "span-precision": span_precision, - "span-recall": span_recall, - "span-f1": span_f1, - "precision": precision, - "recall": recall, - "f1": f1, - "precision-strict": precision_strict, - "recall-strict": recall_strict, - "f1-strict": f1_strict, - } - else: - return { - "span-precision": bound_precision, - "span-recall": bound_recall, - "span-f1": bound_f1, - "precision": precision, - "recall": recall, - "f1": f1, - } - - -class REStrongMatchingCallback(Callback): - def __init__(self, dataset_path: str, dataset_conf) -> None: - super().__init__() - self.dataset_path = dataset_path - self.dataset_conf = dataset_conf - self.strong_matching_metric = StrongMatching() - - def on_validation_epoch_start(self, trainer, pl_module) -> None: - relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_re_model) - predicted_samples = relik_reader_predictor._predict( - self.dataset_path, - None, - self.dataset_conf, - ) - predicted_samples = list(predicted_samples) - for k, v in self.strong_matching_metric(predicted_samples).items(): - pl_module.log(f"val_{k}", v) diff --git a/relik/reader/utils/save_load_utilities.py b/relik/reader/utils/save_load_utilities.py deleted file 
mode 100644 index 1e635650c1f69c0e223d268f97ec9d6e0677742c..0000000000000000000000000000000000000000 --- a/relik/reader/utils/save_load_utilities.py +++ /dev/null @@ -1,76 +0,0 @@ -import argparse -import os -from typing import Tuple - -import omegaconf -import torch - -from relik.common.utils import from_cache -from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule -from relik.reader.relik_reader_core import RelikReaderCoreModel - -CKPT_FILE_NAME = "model.ckpt" -CONFIG_FILE_NAME = "cfg.yaml" - - -def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None: - if not os.path.exists(output_dir): - os.makedirs(output_dir) - else: - print(f"{output_dir} already exists, aborting operation") - exit(1) - - relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint( - pl_module_ckpt_path - ) - torch.save( - relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}" - ) - with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f: - omegaconf.OmegaConf.save( - omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f - ) - - -def load_model_and_conf( - model_dir_path: str, -) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]: - # TODO: quick workaround to load the model from HF hub - model_dir = from_cache( - model_dir_path, - filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME], - cache_dir=None, - force_download=False, - ) - - ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}" - model = torch.load(ckpt_path, map_location=torch.device("cpu")) - - model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}" - model_conf = omegaconf.OmegaConf.load(model_cfg_path) - return model, model_conf - - -def parse_arg() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument( - "--ckpt", - help="Path to the pytorch lightning ckpt you want to convert.", - required=True, - ) - parser.add_argument( - "--output-dir", - "-o", - help="The output dir to store the bare models and the config.", - required=True, - ) - return parser.parse_args() - - -def main(): - args = parse_arg() - convert_pl_module(args.ckpt, args.output_dir) - - -if __name__ == "__main__": - main() diff --git a/relik/reader/utils/special_symbols.py b/relik/reader/utils/special_symbols.py deleted file mode 100644 index 170909ad6cb2b69e1d6a8384af34cba441e60ce4..0000000000000000000000000000000000000000 --- a/relik/reader/utils/special_symbols.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import List - -NME_SYMBOL = "--NME--" - - -def get_special_symbols(num_entities: int) -> List[str]: - return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)] - - -def get_special_symbols_re(num_entities: int) -> List[str]: - return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)] diff --git a/relik/reader/utils/strong_matching_eval.py b/relik/reader/utils/strong_matching_eval.py deleted file mode 100644 index 88ad651134907c8067c174ee5f6fcbbb5fc2cb73..0000000000000000000000000000000000000000 --- a/relik/reader/utils/strong_matching_eval.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import Dict, List - -from lightning.pytorch.callbacks import Callback -from reader.data.relik_reader_sample import RelikReaderSample - -from relik.reader.relik_reader_predictor import RelikReaderPredictor -from relik.reader.utils.metrics import f1_measure, safe_divide -from relik.reader.utils.special_symbols import NME_SYMBOL - - -class StrongMatching: - def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: - # accumulators - correct_predictions = 0 - 
correct_predictions_at_k = 0 - total_predictions = 0 - total_gold = 0 - correct_span_predictions = 0 - miss_due_to_candidates = 0 - - # prediction index stats - avg_correct_predicted_index = [] - avg_wrong_predicted_index = [] - less_index_predictions = [] - - # collect data from samples - for sample in predicted_samples: - predicted_annotations = sample.predicted_window_labels_chars - predicted_annotations_probabilities = sample.probs_window_labels_chars - gold_annotations = { - (ss, se, entity) - for ss, se, entity in sample.window_labels - if entity != NME_SYMBOL - } - total_predictions += len(predicted_annotations) - total_gold += len(gold_annotations) - - # correct named entity detection - predicted_spans = {(s, e) for s, e, _ in predicted_annotations} - gold_spans = {(s, e) for s, e, _ in gold_annotations} - correct_span_predictions += len(predicted_spans.intersection(gold_spans)) - - # correct entity linking - correct_predictions += len( - predicted_annotations.intersection(gold_annotations) - ) - - for ss, se, ge in gold_annotations.difference(predicted_annotations): - if ge not in sample.window_candidates: - miss_due_to_candidates += 1 - if ge in predicted_annotations_probabilities.get((ss, se), set()): - correct_predictions_at_k += 1 - - # indices metrics - predicted_spans_index = { - (ss, se): ent for ss, se, ent in predicted_annotations - } - gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations} - - for pred_span, pred_ent in predicted_spans_index.items(): - gold_ent = gold_spans_index.get(pred_span) - - if pred_span not in gold_spans_index: - continue - - # missing candidate - if gold_ent not in sample.window_candidates: - continue - - gold_idx = sample.window_candidates.index(gold_ent) - if gold_idx is None: - continue - pred_idx = sample.window_candidates.index(pred_ent) - - if gold_ent != pred_ent: - avg_wrong_predicted_index.append(pred_idx) - - if gold_idx is not None: - if pred_idx > gold_idx: - less_index_predictions.append(0) - else: - less_index_predictions.append(1) - - else: - avg_correct_predicted_index.append(pred_idx) - - # compute NED metrics - span_precision = safe_divide(correct_span_predictions, total_predictions) - span_recall = safe_divide(correct_span_predictions, total_gold) - span_f1 = f1_measure(span_precision, span_recall) - - # compute EL metrics - precision = safe_divide(correct_predictions, total_predictions) - recall = safe_divide(correct_predictions, total_gold) - recall_at_k = safe_divide( - (correct_predictions + correct_predictions_at_k), total_gold - ) - - f1 = f1_measure(precision, recall) - - wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold) - - out_dict = { - "span_precision": span_precision, - "span_recall": span_recall, - "span_f1": span_f1, - "core_precision": precision, - "core_recall": recall, - "core_recall-at-k": recall_at_k, - "core_f1": round(f1, 4), - "wrong-for-candidates": wrong_for_candidates, - "index_errors_avg-index": safe_divide( - sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index) - ), - "index_correct_avg-index": safe_divide( - sum(avg_correct_predicted_index), len(avg_correct_predicted_index) - ), - "index_avg-index": safe_divide( - sum(avg_correct_predicted_index + avg_wrong_predicted_index), - len(avg_correct_predicted_index + avg_wrong_predicted_index), - ), - "index_percentage-favoured-smaller-idx": safe_divide( - sum(less_index_predictions), len(less_index_predictions) - ), - } - - return {k: round(v, 5) for k, v in out_dict.items()} - - -class 
ELStrongMatchingCallback(Callback): - def __init__(self, dataset_path: str, dataset_conf) -> None: - super().__init__() - self.dataset_path = dataset_path - self.dataset_conf = dataset_conf - self.strong_matching_metric = StrongMatching() - - def on_validation_epoch_start(self, trainer, pl_module) -> None: - relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_core_model) - predicted_samples = relik_reader_predictor.predict( - self.dataset_path, - samples=None, - dataset_conf=self.dataset_conf, - ) - predicted_samples = list(predicted_samples) - for k, v in self.strong_matching_metric(predicted_samples).items(): - pl_module.log(f"val_{k}", v) diff --git a/relik/retriever/__init__.py b/relik/retriever/__init__.py deleted file mode 100644 index 42a3df6b991b0af65ec5974fc4faa381b8e555b7..0000000000000000000000000000000000000000 --- a/relik/retriever/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from relik.retriever.pytorch_modules.model import GoldenRetriever diff --git a/relik/retriever/callbacks/__init__.py b/relik/retriever/callbacks/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/retriever/callbacks/base.py b/relik/retriever/callbacks/base.py deleted file mode 100644 index 43042c94bfc93ac32fb60b344ca644cd1c79c1f3..0000000000000000000000000000000000000000 --- a/relik/retriever/callbacks/base.py +++ /dev/null @@ -1,168 +0,0 @@ -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union - -import hydra -import lightning as pl -import torch -from lightning.pytorch.trainer.states import RunningStage -from omegaconf import DictConfig -from torch.utils.data import DataLoader, Dataset - -from relik.common.log import get_logger -from relik.retriever.data.base.datasets import BaseDataset - -logger = get_logger() - - -STAGES_COMPATIBILITY_MAP = { - "train": RunningStage.TRAINING, - "val": RunningStage.VALIDATING, - "test": RunningStage.TESTING, -} - -DEFAULT_STAGES = { - RunningStage.VALIDATING, - RunningStage.TESTING, - RunningStage.SANITY_CHECKING, - RunningStage.PREDICTING, -} - - -class PredictionCallback(pl.Callback): - def __init__( - self, - batch_size: int = 32, - stages: Optional[Set[Union[str, RunningStage]]] = None, - other_callbacks: Optional[ - Union[List[DictConfig], List["NLPTemplateCallback"]] - ] = None, - datasets: Optional[Union[DictConfig, BaseDataset]] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - *args, - **kwargs, - ): - super().__init__() - # parameters - self.batch_size = batch_size - self.datasets = datasets - self.dataloaders = dataloaders - - # callback initialization - if stages is None: - stages = DEFAULT_STAGES - - # compatibily stuff - stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages} - self.stages = [RunningStage(stage) for stage in stages] - self.other_callbacks = other_callbacks or [] - for i, callback in enumerate(self.other_callbacks): - if isinstance(callback, DictConfig): - self.other_callbacks[i] = hydra.utils.instantiate( - callback, _recursive_=False - ) - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - *args, - **kwargs, - ) -> Any: - # it should return the predictions - raise NotImplementedError - - def on_validation_epoch_end( - self, trainer: pl.Trainer, pl_module: pl.LightningModule - ): - predictions = self(trainer, pl_module) - for callback in self.other_callbacks: - callback( - 
trainer=trainer, - pl_module=pl_module, - callback=self, - predictions=predictions, - ) - - def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - predictions = self(trainer, pl_module) - for callback in self.other_callbacks: - callback( - trainer=trainer, - pl_module=pl_module, - callback=self, - predictions=predictions, - ) - - @staticmethod - def _get_datasets_and_dataloaders( - dataset: Optional[Union[Dataset, DictConfig]], - dataloader: Optional[DataLoader], - trainer: pl.Trainer, - dataloader_kwargs: Optional[Dict[str, Any]] = None, - collate_fn: Optional[Callable] = None, - collate_fn_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[List[Dataset], List[DataLoader]]: - """ - Get the datasets and dataloaders from the datamodule or from the dataset provided. - - Args: - dataset (`Optional[Union[Dataset, DictConfig]]`): - The dataset to use. If `None`, the datamodule is used. - dataloader (`Optional[DataLoader]`): - The dataloader to use. If `None`, the datamodule is used. - trainer (`pl.Trainer`): - The trainer that contains the datamodule. - dataloader_kwargs (`Optional[Dict[str, Any]]`): - The kwargs to pass to the dataloader. - collate_fn (`Optional[Callable]`): - The collate function to use. - collate_fn_kwargs (`Optional[Dict[str, Any]]`): - The kwargs to pass to the collate function. - - Returns: - `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders. - """ - # if a dataset is provided, use it - if dataset is not None: - dataloader_kwargs = dataloader_kwargs or {} - # get dataset - if isinstance(dataset, DictConfig): - dataset = hydra.utils.instantiate(dataset, _recursive_=False) - datasets = [dataset] if not isinstance(dataset, list) else dataset - if dataloader is not None: - dataloaders = ( - [dataloader] if isinstance(dataloader, DataLoader) else dataloader - ) - else: - collate_fn = collate_fn or partial( - datasets[0].collate_fn, **collate_fn_kwargs - ) - dataloader_kwargs["collate_fn"] = collate_fn - dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)] - else: - # get the dataloaders and datasets from the datamodule - datasets = ( - trainer.datamodule.test_datasets - if trainer.state.stage == RunningStage.TESTING - else trainer.datamodule.val_datasets - ) - dataloaders = ( - trainer.test_dataloaders - if trainer.state.stage == RunningStage.TESTING - else trainer.val_dataloaders - ) - return datasets, dataloaders - - -class NLPTemplateCallback: - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - callback: PredictionCallback, - predictions: Dict[str, Any], - *args, - **kwargs, - ) -> Any: - raise NotImplementedError diff --git a/relik/retriever/callbacks/evaluation_callbacks.py b/relik/retriever/callbacks/evaluation_callbacks.py deleted file mode 100644 index 72d8fcb306f9e0e0fb58d1dbdb8ea6cfcea8c7e3..0000000000000000000000000000000000000000 --- a/relik/retriever/callbacks/evaluation_callbacks.py +++ /dev/null @@ -1,276 +0,0 @@ -import logging -from typing import Dict, List, Optional - -import lightning as pl -import torch -from lightning.pytorch.trainer.states import RunningStage -from sklearn.metrics import label_ranking_average_precision_score - -from relik.common.log import get_console_logger, get_logger -from relik.retriever.callbacks.base import DEFAULT_STAGES, NLPTemplateCallback - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class RecallAtKEvaluationCallback(NLPTemplateCallback): - """ - Computes the recall at k for the 
predictions. Recall at k is computed as the number of - correct predictions in the top k predictions divided by the total number of correct - predictions. - - Args: - k (`int`): - The number of predictions to consider. - prefix (`str`, `optional`): - The prefix to add to the metrics. - verbose (`bool`, `optional`, defaults to `False`): - Whether to log the metrics. - prog_bar (`bool`, `optional`, defaults to `True`): - Whether to log the metrics to the progress bar. - """ - - def __init__( - self, - k: int = 100, - prefix: Optional[str] = None, - verbose: bool = False, - prog_bar: bool = True, - *args, - **kwargs, - ): - super().__init__() - self.k = k - self.prefix = prefix - self.verbose = verbose - self.prog_bar = prog_bar - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - predictions: Dict, - *args, - **kwargs, - ) -> dict: - """ - Computes the recall at k for the predictions. - - Args: - trainer (:obj:`lightning.trainer.trainer.Trainer`): - The trainer object. - pl_module (:obj:`lightning.core.lightning.LightningModule`): - The lightning module. - predictions (:obj:`Dict`): - The predictions. - - Returns: - :obj:`Dict`: The computed metrics. - """ - if self.verbose: - logger.info(f"Computing recall@{self.k}") - - # metrics to return - metrics = {} - - stage = trainer.state.stage - if stage not in DEFAULT_STAGES: - raise ValueError( - f"Stage {stage} not supported, only `validate` and `test` are supported." - ) - - for dataloader_idx, samples in predictions.items(): - hits, total = 0, 0 - for sample in samples: - # compute the recall at k - # cut the predictions to the first k elements - predictions = sample["predictions"][: self.k] - hits += len(set(predictions) & set(sample["gold"])) - total += len(set(sample["gold"])) - - # compute the mean recall at k - recall_at_k = hits / total - metrics[f"recall@{self.k}_{dataloader_idx}"] = recall_at_k - metrics[f"recall@{self.k}"] = sum(metrics.values()) / len(metrics) - - if self.prefix is not None: - metrics = {f"{self.prefix}_{k}": v for k, v in metrics.items()} - else: - metrics = {f"{stage.value}_{k}": v for k, v in metrics.items()} - pl_module.log_dict( - metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar - ) - - if self.verbose: - logger.info( - f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}" - ) - - return metrics - - -class AvgRankingEvaluationCallback(NLPTemplateCallback): - """ - Computes the average ranking of the gold label in the predictions. Average ranking is - computed as the average of the rank of the gold label in the predictions. - - Args: - k (`int`): - The number of predictions to consider. - prefix (`str`, `optional`): - The prefix to add to the metrics. - stages (`List[str]`, `optional`): - The stages to compute the metrics on. Defaults to `["validate", "test"]`. - verbose (`bool`, `optional`, defaults to `False`): - Whether to log the metrics. - """ - - def __init__( - self, - k: int, - prefix: Optional[str] = None, - stages: Optional[List[str]] = None, - verbose: bool = True, - *args, - **kwargs, - ): - super().__init__() - self.k = k - self.prefix = prefix - self.verbose = verbose - self.stages = ( - [RunningStage(stage) for stage in stages] if stages else DEFAULT_STAGES - ) - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - predictions: Dict, - *args, - **kwargs, - ) -> dict: - """ - Computes the average ranking of the gold label in the predictions. 
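# A minimal sketch (illustrative, not from the original file) of the two metrics the
# RecallAtKEvaluationCallback and AvgRankingEvaluationCallback above compute, applied to
# toy prediction dicts shaped like the callback outputs ({"predictions": [...], "gold": [...]});
# the helper names and toy values below are ours, not the library's.
def recall_at_k(samples, k=100):
    hits, total = 0, 0
    for sample in samples:
        top_k = sample["predictions"][:k]              # cut predictions to the first k
        hits += len(set(top_k) & set(sample["gold"]))  # correct predictions in the top k
        total += len(set(sample["gold"]))              # total gold passages
    return hits / total if total else 0.0


def avg_ranking_at_k(samples, k=100):
    rankings = []  # 1-based rank of each gold passage found in the top-k predictions
    for sample in samples:
        top_k = sample["predictions"][:k]
        for gold in sample["gold"]:
            if gold in top_k:
                rankings.append(top_k.index(gold) + 1)
    return sum(rankings) / len(rankings) if rankings else 0


toy = [{"predictions": ["p1", "p3", "p2"], "gold": ["p2"]}]
assert recall_at_k(toy, k=2) == 0.0       # the only gold passage is ranked third
assert avg_ranking_at_k(toy, k=3) == 3.0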
- - Args: - trainer (:obj:`lightning.trainer.trainer.Trainer`): - The trainer object. - pl_module (:obj:`lightning.core.lightning.LightningModule`): - The lightning module. - predictions (:obj:`Dict`): - The predictions. - - Returns: - :obj:`Dict`: The computed metrics. - """ - if not predictions: - logger.warning("No predictions to compute the AVG Ranking metrics.") - return {} - - if self.verbose: - logger.info(f"Computing AVG Ranking@{self.k}") - - # metrics to return - metrics = {} - - stage = trainer.state.stage - if stage not in self.stages: - raise ValueError( - f"Stage `{stage}` not supported, only `validate` and `test` are supported." - ) - - for dataloader_idx, samples in predictions.items(): - rankings = [] - for sample in samples: - window_candidates = sample["predictions"][: self.k] - window_labels = sample["gold"] - for wl in window_labels: - if wl in window_candidates: - rankings.append(window_candidates.index(wl) + 1) - - avg_ranking = sum(rankings) / len(rankings) if len(rankings) > 0 else 0 - metrics[f"avg_ranking@{self.k}_{dataloader_idx}"] = avg_ranking - if len(metrics) == 0: - metrics[f"avg_ranking@{self.k}"] = 0 - else: - metrics[f"avg_ranking@{self.k}"] = sum(metrics.values()) / len(metrics) - - prefix = self.prefix or stage.value - metrics = { - f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32) - for k, v in metrics.items() - } - pl_module.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False) - - if self.verbose: - logger.info( - f"AVG Ranking@{self.k} on {prefix}: {metrics[f'{prefix}_avg_ranking@{self.k}']}" - ) - - return metrics - - -class LRAPEvaluationCallback(NLPTemplateCallback): - def __init__( - self, - k: int = 100, - prefix: Optional[str] = None, - verbose: bool = False, - prog_bar: bool = True, - *args, - **kwargs, - ): - super().__init__() - self.k = k - self.prefix = prefix - self.verbose = verbose - self.prog_bar = prog_bar - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - predictions: Dict, - *args, - **kwargs, - ) -> dict: - if self.verbose: - logger.info(f"Computing recall@{self.k}") - - # metrics to return - metrics = {} - - stage = trainer.state.stage - if stage not in DEFAULT_STAGES: - raise ValueError( - f"Stage {stage} not supported, only `validate` and `test` are supported." 
- ) - - for dataloader_idx, samples in predictions.items(): - scores = [sample["scores"][: self.k] for sample in samples] - golds = [sample["gold"] for sample in samples] - - # compute the mean recall at k - lrap = label_ranking_average_precision_score(golds, scores) - metrics[f"lrap@{self.k}_{dataloader_idx}"] = lrap - metrics[f"lrap@{self.k}"] = sum(metrics.values()) / len(metrics) - - prefix = self.prefix or stage.value - metrics = { - f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32) - for k, v in metrics.items() - } - pl_module.log_dict( - metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar - ) - - if self.verbose: - logger.info( - f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}" - ) - - return metrics diff --git a/relik/retriever/callbacks/prediction_callbacks.py b/relik/retriever/callbacks/prediction_callbacks.py deleted file mode 100644 index f8a051ad396d07872dfac05998d1ec550724677a..0000000000000000000000000000000000000000 --- a/relik/retriever/callbacks/prediction_callbacks.py +++ /dev/null @@ -1,432 +0,0 @@ -import logging -import random -import time -from copy import deepcopy -from pathlib import Path -from typing import List, Optional, Set, Union - -import lightning as pl -import torch -from lightning.pytorch.trainer.states import RunningStage -from omegaconf import DictConfig -from torch.utils.data import DataLoader -from tqdm import tqdm - -from relik.common.log import get_console_logger, get_logger -from relik.retriever.callbacks.base import PredictionCallback -from relik.retriever.common.model_inputs import ModelInputs -from relik.retriever.data.base.datasets import BaseDataset -from relik.retriever.data.datasets import GoldenRetrieverDataset -from relik.retriever.data.utils import HardNegativesManager -from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.pytorch_modules.model import GoldenRetriever - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class GoldenRetrieverPredictionCallback(PredictionCallback): - def __init__( - self, - k: Optional[int] = None, - batch_size: int = 32, - num_workers: int = 8, - document_index: Optional[BaseDocumentIndex] = None, - precision: Union[str, int] = 32, - force_reindex: bool = True, - retriever_dir: Optional[Path] = None, - stages: Optional[Set[Union[str, RunningStage]]] = None, - other_callbacks: Optional[List[DictConfig]] = None, - dataset: Optional[Union[DictConfig, BaseDataset]] = None, - dataloader: Optional[DataLoader] = None, - *args, - **kwargs, - ): - super().__init__(batch_size, stages, other_callbacks, dataset, dataloader) - self.k = k - self.num_workers = num_workers - self.document_index = document_index - self.precision = precision - self.force_reindex = force_reindex - self.retriever_dir = retriever_dir - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - datasets: Optional[ - Union[DictConfig, BaseDataset, List[DictConfig], List[BaseDataset]] - ] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - *args, - **kwargs, - ) -> dict: - stage = trainer.state.stage - logger.info(f"Computing predictions for stage {stage.value}") - if stage not in self.stages: - raise ValueError( - f"Stage `{stage}` not supported, only {self.stages} are supported" - ) - - self.datasets, self.dataloaders = self._get_datasets_and_dataloaders( - datasets, - dataloaders, - trainer, - dataloader_kwargs=dict( - batch_size=self.batch_size, - 
num_workers=self.num_workers, - pin_memory=True, - shuffle=False, - ), - ) - - # set the model to eval mode - pl_module.eval() - # get the retriever - retriever: GoldenRetriever = pl_module.model - - # here we will store the samples with predictions for each dataloader - dataloader_predictions = {} - # compute the passage embeddings index for each dataloader - for dataloader_idx, dataloader in enumerate(self.dataloaders): - current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx] - logger.info( - f"Computing passage embeddings for dataset {current_dataset.name}" - ) - # passages = self._get_passages_dataloader(current_dataset, trainer) - - tokenizer = current_dataset.tokenizer - - def collate_fn(x): - return ModelInputs( - tokenizer( - x, - truncation=True, - padding=True, - max_length=current_dataset.max_passage_length, - return_tensors="pt", - ) - ) - - # check if we need to reindex the passages and - # also if we need to load the retriever from disk - if (self.retriever_dir is not None and trainer.current_epoch == 0) or ( - self.retriever_dir is not None and stage == RunningStage.TESTING - ): - force_reindex = False - else: - force_reindex = self.force_reindex - - if ( - not force_reindex - and self.retriever_dir is not None - and stage == RunningStage.TESTING - ): - retriever = retriever.from_pretrained(self.retriever_dir) - # set the retriever to eval mode if we are loading it from disk - - # you never know :) - retriever.eval() - - retriever.index( - batch_size=self.batch_size, - num_workers=self.num_workers, - max_length=current_dataset.max_passage_length, - collate_fn=collate_fn, - precision=self.precision, - compute_on_cpu=False, - force_reindex=force_reindex, - ) - - # pl_module_original_device = pl_module.device - # if ( - # and pl_module.device.type == "cuda" - # ): - # pl_module.to("cpu") - - # now compute the question embeddings and compute the top-k accuracy - predictions = [] - start = time.time() - for batch in tqdm( - dataloader, - desc=f"Computing predictions for dataset {current_dataset.name}", - ): - batch = batch.to(pl_module.device) - # get the top-k indices - retriever_output = retriever.retrieve( - **batch.questions, k=self.k, precision=self.precision - ) - # compute recall at k - for batch_idx, retrieved_samples in enumerate(retriever_output): - # get the positive passages - gold_passages = batch["positives"][batch_idx] - # get the index of the gold passages in the retrieved passages - gold_passage_indices = [ - retriever.get_index_from_passage(passage) - for passage in gold_passages - ] - retrieved_indices = [r.index for r in retrieved_samples] - retrieved_passages = [r.label for r in retrieved_samples] - retrieved_scores = [r.score for r in retrieved_samples] - # correct predictions are the passages that are in the top-k and are gold - correct_indices = set(gold_passage_indices) & set(retrieved_indices) - # wrong predictions are the passages that are in the top-k and are not gold - wrong_indices = set(retrieved_indices) - set(gold_passage_indices) - # add the predictions to the list - prediction_output = dict( - sample_idx=batch.sample_idx[batch_idx], - gold=gold_passages, - predictions=retrieved_passages, - scores=retrieved_scores, - correct=[ - retriever.get_passage_from_index(i) for i in correct_indices - ], - wrong=[ - retriever.get_passage_from_index(i) for i in wrong_indices - ], - ) - predictions.append(prediction_output) - end = time.time() - logger.info(f"Time to retrieve: {str(end - start)}") - - dataloader_predictions[dataloader_idx] = 
predictions - - # if pl_module_original_device != pl_module.device: - # pl_module.to(pl_module_original_device) - - # return the predictions - return dataloader_predictions - - # @staticmethod - # def _get_passages_dataloader( - # indexer: Optional[BaseIndexer] = None, - # dataset: Optional[GoldenRetrieverDataset] = None, - # trainer: Optional[pl.Trainer] = None, - # ): - # if indexer is None: - # logger.info( - # f"Indexer is None, creating indexer from passages not found in dataset {dataset.name}, computing them from the dataloaders" - # ) - # # get the passages from the all the dataloader passage ids - # passages = set() # set to avoid duplicates - # for batch in trainer.train_dataloader: - # passages.update( - # [ - # " ".join(map(str, [c for c in passage_ids.tolist() if c != 0])) - # for passage_ids in batch["passages"]["input_ids"] - # ] - # ) - # for d in trainer.val_dataloaders: - # for batch in d: - # passages.update( - # [ - # " ".join( - # map(str, [c for c in passage_ids.tolist() if c != 0]) - # ) - # for passage_ids in batch["passages"]["input_ids"] - # ] - # ) - # for d in trainer.test_dataloaders: - # for batch in d: - # passages.update( - # [ - # " ".join( - # map(str, [c for c in passage_ids.tolist() if c != 0]) - # ) - # for passage_ids in batch["passages"]["input_ids"] - # ] - # ) - # passages = list(passages) - # else: - # passages = dataset.passages - # return passages - - -class NegativeAugmentationCallback(GoldenRetrieverPredictionCallback): - """ - Callback that computes the predictions of a retriever model on a dataset and computes the - negative examples for the training set. - - Args: - k (:obj:`int`, `optional`, defaults to 100): - The number of top-k retrieved passages to - consider for the evaluation. - batch_size (:obj:`int`, `optional`, defaults to 32): - The batch size to use for the evaluation. - num_workers (:obj:`int`, `optional`, defaults to 0): - The number of workers to use for the evaluation. - force_reindex (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to force the reindexing of the dataset. - retriever_dir (:obj:`Path`, `optional`): - The path to the retriever directory. If not specified, the retriever will be - initialized from scratch. - stages (:obj:`Set[str]`, `optional`): - The stages to run the callback on. If not specified, the callback will be run on - train, validation and test. - other_callbacks (:obj:`List[DictConfig]`, `optional`): - A list of other callbacks to run on the same stages. - dataset (:obj:`Union[DictConfig, BaseDataset]`, `optional`): - The dataset to use for the evaluation. If not specified, the dataset will be - initialized from scratch. - metrics_to_monitor (:obj:`List[str]`, `optional`): - The metrics to monitor for the evaluation. - threshold (:obj:`float`, `optional`, defaults to 0.8): - The threshold to consider. If the recall score of the retriever is above the - threshold, the negative examples will be added to the training set. - max_negatives (:obj:`int`, `optional`, defaults to 5): - The maximum number of negative examples to add to the training set. - add_with_probability (:obj:`float`, `optional`, defaults to 1.0): - The probability with which to add the negative examples to the training set. - refresh_every_n_epochs (:obj:`int`, `optional`, defaults to 1): - The number of epochs after which to refresh the index. 
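# A minimal sketch (illustrative, not from the original file) of the hard-negative selection
# this callback performs on each retriever prediction once the monitored metric clears the
# threshold; the helper name and toy values are ours.
import random


def select_hard_negatives(prediction, max_negatives=5, add_with_probability=1.0):
    # optionally skip the sample with probability (1 - add_with_probability)
    if random.random() < 1 - add_with_probability:
        return []
    top_k_passages = prediction["predictions"]
    gold_passages = prediction["gold"]
    # the highest-ranked retrieved passages that are not gold become hard negatives
    return [p for p in top_k_passages if p not in gold_passages][:max_negatives]


toy = {"predictions": ["p9", "p2", "p7", "p1"], "gold": ["p2"]}
print(select_hard_negatives(toy, max_negatives=2))  # ['p9', 'p7']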
- """ - - def __init__( - self, - k: int = 100, - batch_size: int = 32, - num_workers: int = 0, - force_reindex: bool = False, - retriever_dir: Optional[Path] = None, - stages: Set[Union[str, RunningStage]] = None, - other_callbacks: Optional[List[DictConfig]] = None, - dataset: Optional[Union[DictConfig, BaseDataset]] = None, - metrics_to_monitor: List[str] = None, - threshold: float = 0.8, - max_negatives: int = 5, - add_with_probability: float = 1.0, - refresh_every_n_epochs: int = 1, - *args, - **kwargs, - ): - super().__init__( - k=k, - batch_size=batch_size, - num_workers=num_workers, - force_reindex=force_reindex, - retriever_dir=retriever_dir, - stages=stages, - other_callbacks=other_callbacks, - dataset=dataset, - *args, - **kwargs, - ) - if metrics_to_monitor is None: - metrics_to_monitor = ["val_loss"] - self.metrics_to_monitor = metrics_to_monitor - self.threshold = threshold - self.max_negatives = max_negatives - self.add_with_probability = add_with_probability - self.refresh_every_n_epochs = refresh_every_n_epochs - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - *args, - **kwargs, - ) -> dict: - """ - Computes the predictions of a retriever model on a dataset and computes the negative - examples for the training set. - - Args: - trainer (:obj:`pl.Trainer`): - The trainer object. - pl_module (:obj:`pl.LightningModule`): - The lightning module. - - Returns: - A dictionary containing the negative examples. - """ - stage = trainer.state.stage - if stage not in self.stages: - return {} - - if self.metrics_to_monitor not in trainer.logged_metrics: - raise ValueError( - f"Metric `{self.metrics_to_monitor}` not found in trainer.logged_metrics" - f"Available metrics: {trainer.logged_metrics.keys()}" - ) - if trainer.logged_metrics[self.metrics_to_monitor] < self.threshold: - return {} - - if trainer.current_epoch % self.refresh_every_n_epochs != 0: - return {} - - # if all( - # [ - # trainer.logged_metrics.get(metric) is None - # for metric in self.metrics_to_monitor - # ] - # ): - # raise ValueError( - # f"No metric from {self.metrics_to_monitor} not found in trainer.logged_metrics" - # f"Available metrics: {trainer.logged_metrics.keys()}" - # ) - - # if all( - # [ - # trainer.logged_metrics.get(metric) < self.threshold - # for metric in self.metrics_to_monitor - # if trainer.logged_metrics.get(metric) is not None - # ] - # ): - # return {} - - if trainer.current_epoch % self.refresh_every_n_epochs != 0: - return {} - - logger.info( - f"At least one metric from {self.metrics_to_monitor} is above threshold " - f"{self.threshold}. Computing hard negatives." 
- ) - - # make a copy of the dataset to avoid modifying the original one - trainer.datamodule.train_dataset.hn_manager = None - dataset_copy = deepcopy(trainer.datamodule.train_dataset) - predictions = super().__call__( - trainer, - pl_module, - datasets=dataset_copy, - dataloaders=DataLoader( - dataset_copy.to_torch_dataset(), - shuffle=False, - batch_size=None, - num_workers=self.num_workers, - pin_memory=True, - collate_fn=lambda x: x, - ), - *args, - **kwargs, - ) - logger.info(f"Computing hard negatives for epoch {trainer.current_epoch}") - # predictions is a dict with the dataloader index as key and the predictions as value - # since we only have one dataloader, we can get the predictions directly - predictions = list(predictions.values())[0] - # store the predictions in a dictionary for faster access based on the sample index - hard_negatives_list = {} - for prediction in tqdm(predictions, desc="Collecting hard negatives"): - if random.random() < 1 - self.add_with_probability: - continue - top_k_passages = prediction["predictions"] - gold_passages = prediction["gold"] - # get the ids of the max_negatives wrong passages with the highest similarity - wrong_passages = [ - passage_id - for passage_id in top_k_passages - if passage_id not in gold_passages - ][: self.max_negatives] - hard_negatives_list[prediction["sample_idx"]] = wrong_passages - - trainer.datamodule.train_dataset.hn_manager = HardNegativesManager( - tokenizer=trainer.datamodule.train_dataset.tokenizer, - max_length=trainer.datamodule.train_dataset.max_passage_length, - data=hard_negatives_list, - ) - - # normalize predictions as in the original GoldenRetrieverPredictionCallback - predictions = {0: predictions} - return predictions diff --git a/relik/retriever/callbacks/utils_callbacks.py b/relik/retriever/callbacks/utils_callbacks.py deleted file mode 100644 index ba73e0d9ee02d9e1424611551befc002bdaaecf3..0000000000000000000000000000000000000000 --- a/relik/retriever/callbacks/utils_callbacks.py +++ /dev/null @@ -1,287 +0,0 @@ -import json -import logging -import os -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import lightning as pl -import torch -from lightning.pytorch.trainer.states import RunningStage - -from relik.common.log import get_console_logger, get_logger -from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback -from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -class SavePredictionsCallback(NLPTemplateCallback): - def __init__( - self, - saving_dir: Optional[Union[str, os.PathLike]] = None, - verbose: bool = False, - *args, - **kwargs, - ): - super().__init__() - self.saving_dir = saving_dir - self.verbose = verbose - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - predictions: Dict, - callback: PredictionCallback, - *args, - **kwargs, - ) -> dict: - # write the predictions to a file inside the experiment folder - if self.saving_dir is None and trainer.logger is None: - logger.info( - "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n" - "Skipping saving predictions." 
- ) - return - datasets = callback.datasets - for dataloader_idx, dataloader_predictions in predictions.items(): - # save to file - if self.saving_dir is not None: - prediction_folder = Path(self.saving_dir) - else: - try: - prediction_folder = ( - Path(trainer.logger.experiment.dir) / "predictions" - ) - except Exception: - logger.info( - "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n" - "Skipping saving predictions." - ) - return - prediction_folder.mkdir(exist_ok=True) - predictions_path = ( - prediction_folder - / f"{datasets[dataloader_idx].name}_{dataloader_idx}.json" - ) - if self.verbose: - logger.info(f"Saving predictions to {predictions_path}") - with open(predictions_path, "w") as f: - for prediction in dataloader_predictions: - for k, v in prediction.items(): - if isinstance(v, set): - prediction[k] = list(v) - f.write(json.dumps(prediction) + "\n") - - -class ResetModelCallback(pl.Callback): - def __init__( - self, - question_encoder: str, - passage_encoder: Optional[str] = None, - verbose: bool = True, - ) -> None: - super().__init__() - self.question_encoder = question_encoder - self.passage_encoder = passage_encoder - self.verbose = verbose - - def on_train_epoch_start( - self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs - ) -> None: - if trainer.current_epoch == 0: - if self.verbose: - logger.info("Current epoch is 0, skipping resetting model") - return - - if self.verbose: - logger.info("Resetting model, optimizer and lr scheduler") - # reload model from scratch - previous_device = pl_module.device - trainer.model.model.question_encoder = GoldenRetrieverModel.from_pretrained( - self.question_encoder - ) - trainer.model.model.question_encoder.to(previous_device) - if self.passage_encoder is not None: - trainer.model.model.passage_encoder = GoldenRetrieverModel.from_pretrained( - self.passage_encoder - ) - trainer.model.model.passage_encoder.to(previous_device) - - trainer.strategy.setup_optimizers(trainer) - - -class FreeUpIndexerVRAMCallback(pl.Callback): - def __call__( - self, - pl_module: pl.LightningModule, - *args, - **kwargs, - ) -> Any: - logger.info("Freeing up GPU memory") - - # remove the index from the GPU memory - # remove the embeddings from the GPU memory first - if pl_module.model.document_index is not None: - if pl_module.model.document_index.embeddings is not None: - pl_module.model.document_index.embeddings.cpu() - pl_module.model.document_index.embeddings = None - - import gc - - gc.collect() - torch.cuda.empty_cache() - - def on_train_epoch_start( - self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs - ) -> None: - return self(pl_module) - - def on_test_epoch_start( - self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs - ) -> None: - return self(pl_module) - - -class ShuffleTrainDatasetCallback(pl.Callback): - def __init__(self, seed: int = 42, verbose: bool = True) -> None: - super().__init__() - self.seed = seed - self.verbose = verbose - self.previous_epoch = -1 - - def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): - if self.verbose: - if trainer.current_epoch != self.previous_epoch: - logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}") - - # logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}") - if trainer.current_epoch != self.previous_epoch: - trainer.datamodule.train_dataset.shuffle_data( - seed=self.seed + trainer.current_epoch + 1 - ) - self.previous_epoch = 
trainer.current_epoch - - -class PrefetchTrainDatasetCallback(pl.Callback): - def __init__(self, verbose: bool = True) -> None: - super().__init__() - self.verbose = verbose - # self.previous_epoch = -1 - - def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): - if trainer.datamodule.train_dataset.prefetch_batches: - if self.verbose: - # if trainer.current_epoch != self.previous_epoch: - logger.info( - f"Prefetching train dataset at epoch {trainer.current_epoch}" - ) - # if trainer.current_epoch != self.previous_epoch: - trainer.datamodule.train_dataset.prefetch() - self.previous_epoch = trainer.current_epoch - - -class SubsampleTrainDatasetCallback(pl.Callback): - def __init__(self, seed: int = 43, verbose: bool = True) -> None: - super().__init__() - self.seed = seed - self.verbose = verbose - - def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): - if self.verbose: - logger.info(f"Subsampling train dataset at epoch {trainer.current_epoch}") - trainer.datamodule.train_dataset.random_subsample( - seed=self.seed + trainer.current_epoch + 1 - ) - - -class SaveRetrieverCallback(pl.Callback): - def __init__( - self, - saving_dir: Optional[Union[str, os.PathLike]] = None, - verbose: bool = True, - *args, - **kwargs, - ): - super().__init__() - self.saving_dir = saving_dir - self.verbose = verbose - self.free_up_indexer_callback = FreeUpIndexerVRAMCallback() - - @torch.no_grad() - def __call__( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - *args, - **kwargs, - ): - if self.saving_dir is None and trainer.logger is None: - logger.info( - "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n" - "Skipping saving retriever." - ) - return - if self.saving_dir is not None: - retriever_folder = Path(self.saving_dir) - else: - try: - retriever_folder = Path(trainer.logger.experiment.dir) / "retriever" - except Exception: - logger.info( - "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n" - "Skipping saving retriever." 
- ) - return - retriever_folder.mkdir(exist_ok=True, parents=True) - if self.verbose: - logger.info(f"Saving retriever to {retriever_folder}") - pl_module.model.save_pretrained(retriever_folder) - - def on_save_checkpoint( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - checkpoint: Dict[str, Any], - ): - self(trainer, pl_module) - # self.free_up_indexer_callback(pl_module) - - -class SampleNegativesDatasetCallback(pl.Callback): - def __init__(self, seed: int = 42, verbose: bool = True) -> None: - super().__init__() - self.seed = seed - self.verbose = verbose - - def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): - if self.verbose: - logger.info(f"Sampling negatives for train dataset at epoch {trainer.current_epoch}") - trainer.datamodule.train_dataset.sample_dataset_negatives( - seed=self.seed + trainer.current_epoch - ) - - -class SubsampleDataCallback(pl.Callback): - def __init__(self, seed: int = 42, verbose: bool = True) -> None: - super().__init__() - self.seed = seed - self.verbose = verbose - - def on_validation_epoch_start(self, trainer: pl.Trainer, *args, **kwargs): - if self.verbose: - logger.info(f"Subsampling data for train dataset at epoch {trainer.current_epoch}") - if trainer.state.stage == RunningStage.SANITY_CHECKING: - return - trainer.datamodule.train_dataset.subsample_data( - seed=self.seed + trainer.current_epoch - ) - - def on_fit_start(self, trainer: pl.Trainer, *args, **kwargs): - if self.verbose: - logger.info(f"Subsampling data for train dataset at epoch {trainer.current_epoch}") - trainer.datamodule.train_dataset.subsample_data( - seed=self.seed + trainer.current_epoch - ) diff --git a/relik/retriever/common/__init__.py b/relik/retriever/common/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/retriever/common/model_inputs.py b/relik/retriever/common/model_inputs.py deleted file mode 100644 index 06b28f74e9594df32f98f8a32a0b46177db54062..0000000000000000000000000000000000000000 --- a/relik/retriever/common/model_inputs.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -from collections import UserDict -from typing import Any, Union - -import torch -from lightning.fabric.utilities import move_data_to_device - -from relik.common.log import get_console_logger - -logger = get_console_logger() - - -class ModelInputs(UserDict): - """Model input dictionary wrapper.""" - - def __getattr__(self, item: str): - try: - return self.data[item] - except KeyError: - raise AttributeError(f"`ModelInputs` has no attribute `{item}`") - - def __getitem__(self, item: str) -> Any: - return self.data[item] - - def __getstate__(self): - return {"data": self.data} - - def __setstate__(self, state): - if "data" in state: - self.data = state["data"] - - def keys(self): - """A set-like object providing a view on D's keys.""" - return self.data.keys() - - def values(self): - """An object providing a view on D's values.""" - return self.data.values() - - def items(self): - """A set-like object providing a view on D's items.""" - return self.data.items() - - def to(self, device: Union[str, torch.device]) -> ModelInputs: - """ - Send all tensor values to device. - Args: - device (`str` or `torch.device`): The device to put the tensors on. - Returns: - :class:`ModelInputs`: The same instance of :class:`ModelInputs` - after modification.
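# A minimal usage sketch of the ModelInputs wrapper defined above (illustrative, not from
# the original file); the tokenizer checkpoint and sentence are placeholders, and the
# transformers library is assumed to be installed.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
batch = ModelInputs(tokenizer(["a toy question"], return_tensors="pt"))
print(batch.input_ids.shape)  # tokenizer keys are reachable as attributes
batch = batch.to("cpu")       # moves every tensor in the dict to the target device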
- """ - self.data = move_data_to_device(self.data, device) - return self diff --git a/relik/retriever/common/sampler.py b/relik/retriever/common/sampler.py deleted file mode 100644 index 024c57b23da6db71dd76929226b005f75b9e98f5..0000000000000000000000000000000000000000 --- a/relik/retriever/common/sampler.py +++ /dev/null @@ -1,108 +0,0 @@ -import math - -from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler - - -def identity(x): - return x - - -class SortedSampler(Sampler): - """ - Samples elements sequentially, always in the same order. - - Args: - data (`obj`: `Iterable`): - Iterable data. - sort_key (`obj`: `Callable`): - Specifies a function of one argument that is used to - extract a numerical comparison key from each list element. - - Example: - >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) - [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] - - """ - - def __init__(self, data, sort_key=identity): - super().__init__(data) - self.data = data - self.sort_key = sort_key - zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] - zip_ = sorted(zip_, key=lambda r: r[1]) - self.sorted_indexes = [item[0] for item in zip_] - - def __iter__(self): - return iter(self.sorted_indexes) - - def __len__(self): - return len(self.data) - - -class BucketBatchSampler(BatchSampler): - """ - `BucketBatchSampler` toggles between `sampler` batches and sorted batches. - Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between - random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice - versa. - Background: - ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like - ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar - size length to reduce the padding required for each batch while maintaining some noise - through bucketing. - **AllenNLP Implementation:** - https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py - **torchtext Implementation:** - https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225 - - Args: - sampler (`obj`: `torch.data.utils.sampler.Sampler): - batch_size (`int`): - Size of mini-batch. - drop_last (`bool`, optional, defaults to `False`): - If `True` the sampler will drop the last batch if its size would be less than `batch_size`. - sort_key (`obj`: `Callable`, optional, defaults to `identity`): - Callable to specify a comparison key for sorting. - bucket_size_multiplier (`int`, optional, defaults to `100`): - Buckets are of size `batch_size * bucket_size_multiplier`. 
- Example: - >>> from torchnlp.random import set_seed - >>> set_seed(123) - >>> - >>> from torch.utils.data.sampler import SequentialSampler - >>> sampler = SequentialSampler(list(range(10))) - >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False)) - [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]] - >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True)) - [[0, 1, 2], [3, 4, 5], [6, 7, 8]] - - """ - - def __init__( - self, - sampler, - batch_size, - drop_last: bool = False, - sort_key=identity, - bucket_size_multiplier=100, - ): - super().__init__(sampler, batch_size, drop_last) - self.sort_key = sort_key - _bucket_size = batch_size * bucket_size_multiplier - if hasattr(sampler, "__len__"): - _bucket_size = min(_bucket_size, len(sampler)) - self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) - - def __iter__(self): - for bucket in self.bucket_sampler: - sorted_sampler = SortedSampler(bucket, self.sort_key) - for batch in SubsetRandomSampler( - list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last)) - ): - yield [bucket[i] for i in batch] - - def __len__(self): - if self.drop_last: - return len(self.sampler) // self.batch_size - else: - return math.ceil(len(self.sampler) / self.batch_size) diff --git a/relik/retriever/conf/data/aida_dataset.yaml b/relik/retriever/conf/data/aida_dataset.yaml deleted file mode 100644 index 22fcc6458a2dc757b569baef37846751dd3c1c7a..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/data/aida_dataset.yaml +++ /dev/null @@ -1,47 +0,0 @@ -shared_params: - passages_path: null - max_passage_length: 64 - passage_batch_size: 64 - question_batch_size: 64 - use_topics: False - -datamodule: - _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule - datasets: - train: - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset - name: "train" - path: null - tokenizer: ${model.language_model} - max_passage_length: ${data.shared_params.max_passage_length} - question_batch_size: ${data.shared_params.question_batch_size} - passage_batch_size: ${data.shared_params.passage_batch_size} - subsample_strategy: null - subsample_portion: 0.1 - shuffle: True - use_topics: ${data.shared_params.use_topics} - - val: - - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset - name: "val" - path: null - tokenizer: ${model.language_model} - max_passage_length: ${data.shared_params.max_passage_length} - question_batch_size: ${data.shared_params.question_batch_size} - passage_batch_size: ${data.shared_params.passage_batch_size} - use_topics: ${data.shared_params.use_topics} - - test: - - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset - name: "test" - path: null - tokenizer: ${model.language_model} - max_passage_length: ${data.shared_params.max_passage_length} - question_batch_size: ${data.shared_params.question_batch_size} - passage_batch_size: ${data.shared_params.passage_batch_size} - use_topics: ${data.shared_params.use_topics} - - num_workers: - train: 4 - val: 4 - test: 4 diff --git a/relik/retriever/conf/data/dataset_v2.yaml b/relik/retriever/conf/data/dataset_v2.yaml deleted file mode 100644 index 6040616d1f92182c2d002c065a7b89dc240f96b3..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/data/dataset_v2.yaml +++ /dev/null @@ -1,43 +0,0 @@ -shared_params: - passages_path: null - max_passage_length: 64 - passage_batch_size: 64 - question_batch_size: 64 - -datamodule: - _target_: 
relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule - datasets: - train: - _target_: relik.retriever.data.datasets.InBatchNegativesDataset - name: "train" - path: null - tokenizer: ${model.language_model} - max_passage_length: ${data.shared_params.max_passage_length} - question_batch_size: ${data.shared_params.question_batch_size} - passage_batch_size: ${data.shared_params.passage_batch_size} - subsample_strategy: null - subsample_portion: 0.1 - shuffle: True - - val: - - _target_: relik.retriever.data.datasets.InBatchNegativesDataset - name: "val" - path: null - tokenizer: ${model.language_model} - max_passage_length: ${data.shared_params.max_passage_length} - question_batch_size: ${data.shared_params.question_batch_size} - passage_batch_size: ${data.shared_params.passage_batch_size} - - test: - - _target_: relik.retriever.data.datasets.InBatchNegativesDataset - name: "test" - path: null - tokenizer: ${model.language_model} - max_passage_length: ${data.shared_params.max_passage_length} - question_batch_size: ${data.shared_params.question_batch_size} - passage_batch_size: ${data.shared_params.passage_batch_size} - - num_workers: - train: 0 - val: 0 - test: 0 diff --git a/relik/retriever/conf/data/dpr_like.yaml b/relik/retriever/conf/data/dpr_like.yaml deleted file mode 100644 index d316b9868b0e64ea3e58725fc56549a2b31436be..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/data/dpr_like.yaml +++ /dev/null @@ -1,31 +0,0 @@ -datamodule: - _target_: relik.retriever.goldenretriever.lightning_modules.pl_data_modules.PLDataModule - tokenizer: ${model.language_model} - datasets: - train: - _target_: relik.retriever.data.dpr.datasets.DPRDataset - name: "train" - passages_path: ${data_overrides.passages_path} - path: ${data_overrides.train_path} - - val: - - _target_: relik.retriever.data.dpr.datasets.DPRDataset - name: "val" - passages_path: ${data_overrides.passages_path} - path: ${data_overrides.val_path} - - test: - - _target_: relik.retriever.data.dpr.datasets.DPRDataset - name: "test" - passages_path: ${data_overrides.passages_path} - path: ${data_overrides.test_path} - - batch_sizes: - train: 32 - val: 64 - test: 64 - - num_workers: - train: 4 - val: 4 - test: 4 diff --git a/relik/retriever/conf/data/in_batch_negatives.yaml b/relik/retriever/conf/data/in_batch_negatives.yaml deleted file mode 100644 index 1a90cb4619604e9d93df47f490f5948ebdfbf312..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/data/in_batch_negatives.yaml +++ /dev/null @@ -1,48 +0,0 @@ -shared_params: - passages_path: null - max_passage_length: 64 - prefetch_batches: True - use_topics: False - -datamodule: - _target_: goldenretriever.lightning_modules.pl_data_modules.PLDataModule - tokenizer: ${model.language_model} - datasets: - train: - _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset - name: "train" - path: null - passages_path: ${data.shared_params.passages_path} - max_passage_length: ${data.shared_params.max_passage_length} - prefetch_batches: ${data.shared_params.prefetch_batches} - subsample: null - shuffle: True - use_topics: ${data.shared_params.use_topics} - - val: - - _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset - name: "val" - path: null - passages_path: ${data.shared_params.passages_path} - max_passage_length: ${data.shared_params.max_passage_length} - prefetch_batches: ${data.shared_params.prefetch_batches} - use_topics: ${data.shared_params.use_topics} - - test: - - _target_: 
goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset - name: "test" - path: null - passages_path: ${data.shared_params.passages_path} - max_passage_length: ${data.shared_params.max_passage_length} - prefetch_batches: ${data.shared_params.prefetch_batches} - use_topics: ${data.shared_params.use_topics} - - batch_sizes: - train: 64 - val: 64 - test: 64 - - num_workers: - train: 4 - val: 4 - test: 4 diff --git a/relik/retriever/conf/data/iterable_in_batch_negatives.yaml b/relik/retriever/conf/data/iterable_in_batch_negatives.yaml deleted file mode 100644 index 397fda31c4b42ef5cd22c10805200f4d19cd590d..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/data/iterable_in_batch_negatives.yaml +++ /dev/null @@ -1,52 +0,0 @@ -shared_params: - passages_path: null - max_passage_length: 64 - max_passages_per_batch: 64 - max_questions_per_batch: 64 - prefetch_batches: True - use_topics: False - -datamodule: - _target_: relik.retriever.lightning_modules.pl_data_modules.PLDataModule - tokenizer: ${model.language_model} - datasets: - train: - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset - name: "train" - path: null - passages_path: ${data.shared_params.passages_path} - max_passage_length: ${data.shared_params.max_passage_length} - max_questions_per_batch: ${data.shared_params.max_questions_per_batch} - max_passages_per_batch: ${data.shared_params.max_passages_per_batch} - prefetch_batches: ${data.shared_params.prefetch_batches} - subsample: null - random_subsample: False - shuffle: True - use_topics: ${data.shared_params.use_topics} - - val: - - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset - name: "val" - path: null - passages_path: ${data.shared_params.passages_path} - max_passage_length: ${data.shared_params.max_passage_length} - max_questions_per_batch: ${data.shared_params.max_questions_per_batch} - max_passages_per_batch: ${data.shared_params.max_passages_per_batch} - prefetch_batches: ${data.shared_params.prefetch_batches} - use_topics: ${data.shared_params.use_topics} - - test: - - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset - name: "test" - path: null - passages_path: ${data.shared_params.passages_path} - max_passage_length: ${data.shared_params.max_passage_length} - max_questions_per_batch: ${data.shared_params.max_questions_per_batch} - max_passages_per_batch: ${data.shared_params.max_passages_per_batch} - prefetch_batches: ${data.shared_params.prefetch_batches} - use_topics: ${data.shared_params.use_topics} - - num_workers: - train: 0 - val: 0 - test: 0 diff --git a/relik/retriever/conf/data/sampled_negatives.yaml b/relik/retriever/conf/data/sampled_negatives.yaml deleted file mode 100644 index 4f581a4a6488a7ad516aca08a5d0a0fba5906fe1..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/data/sampled_negatives.yaml +++ /dev/null @@ -1,39 +0,0 @@ -max_passages: 64 - -datamodule: - _target_: relik.retriever.lightning_modules.pl_data_modules.PLDataModule - tokenizer: ${model.language_model} - datasets: - train: - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset - name: "train" - passages_path: ${data_overrides.passages_path} - max_passage_length: 64 - max_passages: ${data.max_passages} - path: ${data_overrides.train_path} - - val: - - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset - name: "val" - passages_path: ${data_overrides.passages_path} - max_passage_length: 64 - max_passages: ${data.max_passages} - path: 
${data_overrides.val_path} - - test: - - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset - name: "test" - passages_path: ${data_overrides.passages_path} - max_passage_length: 64 - max_passages: ${data.max_passages} - path: ${data_overrides.test_path} - - batch_sizes: - train: 4 - val: 64 - test: 64 - - num_workers: - train: 4 - val: 4 - test: 4 diff --git a/relik/retriever/conf/finetune_iterable_in_batch.yaml b/relik/retriever/conf/finetune_iterable_in_batch.yaml deleted file mode 100644 index 7bcaa034ca3dca3e82d6efa6d6b674839f2ec880..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/finetune_iterable_in_batch.yaml +++ /dev/null @@ -1,117 +0,0 @@ -# Required to make the "experiments" dir the default one for the output of the models -hydra: - run: - dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - -model_name: ${model.language_model} # used to name the model in wandb -project_name: relik-retriever # used to name the project in wandb - -defaults: - - _self_ - - model: golden_retriever - - index: inmemory - - loss: nce_loss - - optimizer: radamw - - scheduler: linear_scheduler - - data: dataset_v2 # iterable_in_batch_negatives #dataset_v2 - - logging: wandb_logging - - override hydra/job_logging: colorlog - - override hydra/hydra_logging: colorlog - -train: - # reproducibility - seed: 42 - set_determinism_the_old_way: False - # torch parameters - float32_matmul_precision: "medium" - # if true, only test the model - only_test: False - # if provided, initialize the model with the weights from the checkpoint - pretrain_ckpt_path: null - # if provided, start training from the checkpoint - checkpoint_path: null - - # task specific parameter - top_k: 100 - - # pl_trainer - pl_trainer: - _target_: lightning.Trainer - accelerator: gpu - devices: 1 - num_nodes: 1 - strategy: auto - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps - check_val_every_n_epoch: 1 - max_epochs: 0 - max_steps: 25_000 - deterministic: True - fast_dev_run: False - precision: 16 - reload_dataloaders_every_n_epochs: 1 - - early_stopping_callback: - # null - _target_: lightning.callbacks.EarlyStopping - monitor: validate_recall@${train.top_k} - mode: max - patience: 3 - - model_checkpoint_callback: - _target_: lightning.callbacks.ModelCheckpoint - monitor: validate_recall@${train.top_k} - mode: max - verbose: True - save_top_k: 1 - save_last: False - filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" - auto_insert_metric_name: False - - callbacks: - prediction_callback: - _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback - k: ${train.top_k} - batch_size: 64 - precision: 16 - index_precision: 16 - other_callbacks: - - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback - k: ${train.top_k} - verbose: True - - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback - k: 50 - verbose: True - prog_bar: False - - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback - k: ${train.top_k} - verbose: True - - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback - - hard_negatives_callback: - _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback - k: ${train.top_k} - batch_size: 64 - precision: 16 - index_precision: 16 - stages: [validate] 
#[validate, sanity_check] - metrics_to_monitor: - validate_recall@${train.top_k} - # - sanity_check_recall@${train.top_k} - threshold: 0.0 - max_negatives: 20 - add_with_probability: 1.0 - refresh_every_n_epochs: 1 - other_callbacks: - - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback - k: ${train.top_k} - verbose: True - prefix: "train" - - utils_callbacks: - - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback - - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback - # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback - # question_encoder: ${model.pl_module.model.question_encoder} - # passage_encoder: ${model.pl_module.model.passage_encoder} diff --git a/relik/retriever/conf/index/inmemory.yaml b/relik/retriever/conf/index/inmemory.yaml deleted file mode 100644 index d77f5b79946a82384cb4da0c0f60b7b2700e9280..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/index/inmemory.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex -documents: ${data.shared_params.passages_path} -device: cuda -precision: 16 diff --git a/relik/retriever/conf/logging/wandb_logging.yaml b/relik/retriever/conf/logging/wandb_logging.yaml deleted file mode 100644 index 1908d7e09789ca0b8e4973ec7f3ca5d47d460af3..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/logging/wandb_logging.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# don't forget loggers.login() for the first usage. - -log: True # set to False to avoid the logging - -wandb_arg: - _target_: lightning.loggers.WandbLogger - name: ${model_name} - project: ${project_name} - save_dir: ./ - log_model: True - mode: "online" - entity: null - -watch: - log: "all" - log_freq: 100 diff --git a/relik/retriever/conf/loss/nce_loss.yaml b/relik/retriever/conf/loss/nce_loss.yaml deleted file mode 100644 index fe9246b88027fd0d499b9bc3b4beaa7937b23586..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/loss/nce_loss.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: relik.retriever.pythorch_modules.losses.MultiLabelNCELoss diff --git a/relik/retriever/conf/loss/nll_loss.yaml b/relik/retriever/conf/loss/nll_loss.yaml deleted file mode 100644 index 1e0a5010025a4a6e9da382e5408af4a58cda6185..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/loss/nll_loss.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: torch.nn.NLLLoss diff --git a/relik/retriever/conf/optimizer/adamw.yaml b/relik/retriever/conf/optimizer/adamw.yaml deleted file mode 100644 index ff0f84e15ebd6c60e3e6e411c30e88fb910fdb29..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/optimizer/adamw.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: torch.optim.AdamW -lr: 1e-5 -weight_decay: 0.01 -fused: False diff --git a/relik/retriever/conf/optimizer/radam.yaml b/relik/retriever/conf/optimizer/radam.yaml deleted file mode 100644 index b5d2a4ecf468327bda98fd205bf483b6828cf653..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/optimizer/radam.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: torch.optim.RAdam -lr: 1e-5 -weight_decay: 0 diff --git a/relik/retriever/conf/optimizer/radamw.yaml b/relik/retriever/conf/optimizer/radamw.yaml deleted file mode 100644 index 6f1fc8c4696bf793180366a64baf107829cf7752..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/optimizer/radamw.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: relik.retriever.pytorch_modules.optim.RAdamW -lr: 1e-5 
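How the training code turns optimizer configs like the ones above into actual optimizers lives outside this diff. As an illustrative sketch only, Hydra can instantiate the stock `adamw.yaml`-style target directly; the `params` keyword and the throwaway linear model are assumptions made for the example, and the project's custom `RAdamW` target is replaced by torch's `AdamW` so the snippet runs standalone:

```python
import hydra
import torch
from omegaconf import OmegaConf

# same fields as conf/optimizer/adamw.yaml above; the custom RAdamW target is
# part of the deleted package, so the stock torch AdamW is used here instead
cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "lr": 1e-5, "weight_decay": 0.01})

model = torch.nn.Linear(8, 2)  # stand-in for the retriever encoders
optimizer = hydra.utils.instantiate(cfg, params=model.parameters())
print(type(optimizer).__name__, optimizer.defaults["lr"])  # AdamW 1e-05
```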
-weight_decay: 0.01 diff --git a/relik/retriever/conf/pretrain_iterable_in_batch.yaml b/relik/retriever/conf/pretrain_iterable_in_batch.yaml deleted file mode 100644 index c003d612cf718f4b4b84e866009a2944a5c22d9a..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/pretrain_iterable_in_batch.yaml +++ /dev/null @@ -1,114 +0,0 @@ -# Required to make the "experiments" dir the default one for the output of the models -hydra: - run: - dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - -model_name: ${model.language_model} # used to name the model in wandb -project_name: relik-retriever # used to name the project in wandb - -defaults: - - _self_ - - model: golden_retriever - - index: inmemory - - loss: nce_loss - - optimizer: radamw - - scheduler: linear_scheduler - - data: dataset_v2 - - logging: wandb_logging - - override hydra/job_logging: colorlog - - override hydra/hydra_logging: colorlog - -train: - # reproducibility - seed: 42 - set_determinism_the_old_way: False - # torch parameters - float32_matmul_precision: "medium" - # if true, only test the model - only_test: False - # if provided, initialize the model with the weights from the checkpoint - pretrain_ckpt_path: null - # if provided, start training from the checkpoint - checkpoint_path: null - - # task specific parameter - top_k: 100 - - # pl_trainer - pl_trainer: - _target_: lightning.Trainer - accelerator: gpu - devices: 1 - num_nodes: 1 - strategy: auto - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps - check_val_every_n_epoch: 1 - max_epochs: 0 - max_steps: 220_000 - deterministic: True - fast_dev_run: False - precision: 16 - reload_dataloaders_every_n_epochs: 1 - - early_stopping_callback: - null - # _target_: lightning.callbacks.EarlyStopping - # monitor: validate_recall@${train.top_k} - # mode: max - # patience: 15 - - model_checkpoint_callback: - _target_: lightning.callbacks.ModelCheckpoint - monitor: validate_recall@${train.top_k} - mode: max - verbose: True - save_top_k: 1 - save_last: True - filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" - auto_insert_metric_name: False - - callbacks: - prediction_callback: - _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback - k: ${train.top_k} - batch_size: 128 - precision: 16 - index_precision: 16 - other_callbacks: - - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback - k: ${train.top_k} - verbose: True - - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback - k: 50 - verbose: True - prog_bar: False - - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback - k: ${train.top_k} - verbose: True - - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback - - hard_negatives_callback: - _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback - k: ${train.top_k} - batch_size: 128 - precision: 16 - index_precision: 16 - stages: [validate] #[validate, sanity_check] - metrics_to_monitor: - validate_recall@${train.top_k} - # - sanity_check_recall@${train.top_k} - threshold: 0.0 - max_negatives: 15 - add_with_probability: 0.2 - refresh_every_n_epochs: 1 - other_callbacks: - - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback - k: ${train.top_k} - verbose: True - prefix: "train" - - 
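Several of the callbacks configured above monitor `validate_recall@k`. Their implementations are elsewhere in the repository, but the metric itself is simple; a minimal standalone version for reference (the function name and toy numbers are mine, not the project's):

```python
from typing import Sequence, Set

def recall_at_k(retrieved: Sequence[Sequence[int]], gold: Sequence[Set[int]], k: int) -> float:
    """Fraction of gold passages that appear among the top-k retrieved ones."""
    hits = total = 0
    for retrieved_ids, gold_ids in zip(retrieved, gold):
        hits += len(set(retrieved_ids[:k]) & gold_ids)
        total += len(gold_ids)
    return hits / total if total else 0.0

# two queries, top-2 retrieval: query 1 finds 1 of its 2 gold passages,
# query 2 finds none of its single gold passage -> recall@2 = 1/3
print(recall_at_k([[3, 7, 1], [5, 2, 9]], [{3, 1}, {9}], k=2))
```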
utils_callbacks: - - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback - - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback diff --git a/relik/retriever/conf/scheduler/linear_scheduler.yaml b/relik/retriever/conf/scheduler/linear_scheduler.yaml deleted file mode 100644 index d1896bff5d01ee2543639e4e379674d60682a0f6..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/scheduler/linear_scheduler.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: transformers.get_linear_schedule_with_warmup -num_warmup_steps: 0 -num_training_steps: ${train.pl_trainer.max_steps} diff --git a/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml deleted file mode 100644 index 417857489486469032b7e8b19d509a1e45da043c..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: transformers.get_linear_schedule_with_warmup -num_warmup_steps: 5_000 -num_training_steps: ${train.pl_trainer.max_steps} diff --git a/relik/retriever/conf/scheduler/none.yaml b/relik/retriever/conf/scheduler/none.yaml deleted file mode 100644 index ec747fa47ddb81e9bf2d282011ed32aa4c59f932..0000000000000000000000000000000000000000 --- a/relik/retriever/conf/scheduler/none.yaml +++ /dev/null @@ -1 +0,0 @@ -null \ No newline at end of file diff --git a/relik/retriever/data/__init__.py b/relik/retriever/data/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/retriever/data/base/__init__.py b/relik/retriever/data/base/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/retriever/data/base/datasets.py b/relik/retriever/data/base/datasets.py deleted file mode 100644 index 8a011402a7f6ead5914eb0808a62e9e967c8c12d..0000000000000000000000000000000000000000 --- a/relik/retriever/data/base/datasets.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union - -import torch -from torch.utils.data import Dataset, IterableDataset - -from relik.common.log import get_logger - -logger = get_logger() - - -class BaseDataset(Dataset): - def __init__( - self, - name: str, - path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None, - data: Any = None, - **kwargs, - ): - super().__init__() - self.name = name - if path is None and data is None: - raise ValueError("Either `path` or `data` must be provided") - self.path = path - self.project_folder = Path(__file__).parent.parent.parent - self.data = data - - def __len__(self) -> int: - return len(self.data) - - def __getitem__( - self, index - ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - return self.data[index] - - def __repr__(self) -> str: - return f"Dataset({self.name=}, {self.path=})" - - def load( - self, - paths: Union[str, os.PathLike, List[str], List[os.PathLike]], - *args, - **kwargs, - ) -> Any: - # load data from single or multiple paths in one single dataset - raise NotImplementedError - - @staticmethod - def collate_fn(batch: Any, *args, **kwargs) -> Any: - raise NotImplementedError - - -class IterableBaseDataset(IterableDataset): - def __init__( - self, - name: str, - path: Optional[Union[str, Path, List[str], List[Path]]] = None, - data: Any = None, - *args, - 
**kwargs, - ): - super().__init__() - self.name = name - if path is None and data is None: - raise ValueError("Either `path` or `data` must be provided") - self.path = path - self.project_folder = Path(__file__).parent.parent.parent - self.data = data - - def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - for sample in self.data: - yield sample - - def __repr__(self) -> str: - return f"Dataset({self.name=}, {self.path=})" - - def load( - self, - paths: Union[str, os.PathLike, List[str], List[os.PathLike]], - *args, - **kwargs, - ) -> Any: - # load data from single or multiple paths in one single dataset - raise NotImplementedError - - @staticmethod - def collate_fn(batch: Any, *args, **kwargs) -> Any: - raise NotImplementedError diff --git a/relik/retriever/data/datasets.py b/relik/retriever/data/datasets.py deleted file mode 100644 index 6cf1897a9955d9eb902d74dd08b7dbcd09fd68e4..0000000000000000000000000000000000000000 --- a/relik/retriever/data/datasets.py +++ /dev/null @@ -1,726 +0,0 @@ -import os -from copy import deepcopy -from enum import Enum -from functools import partial -from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union - -import datasets -import psutil -import torch -import transformers as tr -from datasets import load_dataset -from torch.utils.data import Dataset -from tqdm import tqdm - -from relik.common.log import get_console_logger, get_logger -from relik.retriever.common.model_inputs import ModelInputs -from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset -from relik.retriever.data.utils import HardNegativesManager - -console_logger = get_console_logger() - -logger = get_logger(__name__) - - -class SubsampleStrategyEnum(Enum): - NONE = "none" - RANDOM = "random" - IN_ORDER = "in_order" - - -class GoldenRetrieverDataset: - def __init__( - self, - name: str, - path: Union[str, os.PathLike, List[str], List[os.PathLike]] = None, - data: Any = None, - tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, - # passages: Union[str, os.PathLike, List[str]] = None, - passage_batch_size: int = 32, - question_batch_size: int = 32, - max_positives: int = -1, - max_negatives: int = 0, - max_hard_negatives: int = 0, - max_question_length: int = 256, - max_passage_length: int = 64, - shuffle: bool = False, - subsample_strategy: Optional[str] = SubsampleStrategyEnum.NONE, - subsample_portion: float = 0.1, - num_proc: Optional[int] = None, - load_from_cache_file: bool = True, - keep_in_memory: bool = False, - prefetch: bool = True, - load_fn_kwargs: Optional[Dict[str, Any]] = None, - batch_fn_kwargs: Optional[Dict[str, Any]] = None, - collate_fn_kwargs: Optional[Dict[str, Any]] = None, - ): - if path is None and data is None: - raise ValueError("Either `path` or `data` must be provided") - - if tokenizer is None: - raise ValueError("A tokenizer must be provided") - - # dataset parameters - self.name = name - self.path = Path(path) or path - if path is not None and not isinstance(self.path, Sequence): - self.path = [self.path] - # self.project_folder = Path(__file__).parent.parent.parent - self.data = data - - # hyper-parameters - self.passage_batch_size = passage_batch_size - self.question_batch_size = question_batch_size - self.max_positives = max_positives - self.max_negatives = max_negatives - self.max_hard_negatives = max_hard_negatives - self.max_question_length = max_question_length - self.max_passage_length = max_passage_length - self.shuffle = shuffle - self.num_proc = 
num_proc - self.load_from_cache_file = load_from_cache_file - self.keep_in_memory = keep_in_memory - self.prefetch = prefetch - - self.tokenizer = tokenizer - if isinstance(self.tokenizer, str): - self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer) - - self.padding_ops = { - "input_ids": partial( - self.pad_sequence, - value=self.tokenizer.pad_token_id, - ), - "attention_mask": partial(self.pad_sequence, value=0), - "token_type_ids": partial( - self.pad_sequence, - value=self.tokenizer.pad_token_type_id, - ), - } - - # check if subsample strategy is valid - if subsample_strategy is not None: - # subsample_strategy can be a string or a SubsampleStrategy - if isinstance(subsample_strategy, str): - try: - subsample_strategy = SubsampleStrategyEnum(subsample_strategy) - except ValueError: - raise ValueError( - f"Subsample strategy {subsample_strategy} is not valid. " - f"Valid strategies are: {SubsampleStrategyEnum.__members__}" - ) - if not isinstance(subsample_strategy, SubsampleStrategyEnum): - raise ValueError( - f"Subsample strategy {subsample_strategy} is not valid. " - f"Valid strategies are: {SubsampleStrategyEnum.__members__}" - ) - self.subsample_strategy = subsample_strategy - self.subsample_portion = subsample_portion - - # load the dataset - if data is None: - self.data: Dataset = self.load( - self.path, - tokenizer=self.tokenizer, - load_from_cache_file=load_from_cache_file, - load_fn_kwargs=load_fn_kwargs, - num_proc=num_proc, - shuffle=shuffle, - keep_in_memory=keep_in_memory, - max_positives=max_positives, - max_negatives=max_negatives, - max_hard_negatives=max_hard_negatives, - max_question_length=max_question_length, - max_passage_length=max_passage_length, - ) - else: - self.data: Dataset = data - - self.hn_manager: Optional[HardNegativesManager] = None - - # keep track of how many times the dataset has been iterated over - self.number_of_complete_iterations = 0 - - def __repr__(self) -> str: - return f"GoldenRetrieverDataset({self.name=}, {self.path=})" - - def __len__(self) -> int: - raise NotImplementedError - - def __getitem__( - self, index - ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - raise NotImplementedError - - def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset: - raise NotImplementedError - - def load( - self, - paths: Union[str, os.PathLike, List[str], List[os.PathLike]], - tokenizer: tr.PreTrainedTokenizer = None, - load_fn_kwargs: Dict = None, - load_from_cache_file: bool = True, - num_proc: Optional[int] = None, - shuffle: bool = False, - keep_in_memory: bool = True, - max_positives: int = -1, - max_negatives: int = -1, - max_hard_negatives: int = -1, - max_passages: int = -1, - max_question_length: int = 256, - max_passage_length: int = 64, - *args, - **kwargs, - ) -> Any: - # if isinstance(paths, Sequence): - # paths = [self.project_folder / path for path in paths] - # else: - # paths = [self.project_folder / paths] - - # read the data and put it in a placeholder list - for path in paths: - if not path.exists(): - raise ValueError(f"{path} does not exist") - - fn_kwargs = dict( - tokenizer=tokenizer, - max_positives=max_positives, - max_negatives=max_negatives, - max_hard_negatives=max_hard_negatives, - max_passages=max_passages, - max_question_length=max_question_length, - max_passage_length=max_passage_length, - ) - if load_fn_kwargs is not None: - fn_kwargs.update(load_fn_kwargs) - - if num_proc is None: - num_proc = psutil.cpu_count(logical=False) - - # The data is a list of dictionaries, 
each dictionary is a sample - # Each sample has the following keys: - # - "question": the question - # - "answers": a list of answers - # - "positive_ctxs": a list of positive passages - # - "negative_ctxs": a list of negative passages - # - "hard_negative_ctxs": a list of hard negative passages - # use the huggingface dataset library to load the data, by default it will load the - # data in a dict with the key being "train". - logger.info(f"Loading data for dataset {self.name}") - data = load_dataset( - "json", - data_files=[str(p) for p in paths], # datasets needs str paths and not Path - split="train", - streaming=False, # TODO maybe we can make streaming work - keep_in_memory=keep_in_memory, - ) - # add id if not present - if isinstance(data, datasets.Dataset): - data = data.add_column("sample_idx", range(len(data))) - else: - data = data.map( - lambda x, idx: x.update({"sample_idx": idx}), with_indices=True - ) - - map_kwargs = dict( - function=self.load_fn, - fn_kwargs=fn_kwargs, - ) - if isinstance(data, datasets.Dataset): - map_kwargs.update( - dict( - load_from_cache_file=load_from_cache_file, - keep_in_memory=keep_in_memory, - num_proc=num_proc, - desc="Loading data", - ) - ) - # preprocess the data - data = data.map(**map_kwargs) - - # shuffle the data - if shuffle: - data.shuffle(seed=42) - - return data - - @staticmethod - def create_batches( - data: Dataset, - batch_fn: Callable, - batch_fn_kwargs: Optional[Dict[str, Any]] = None, - prefetch: bool = True, - *args, - **kwargs, - ) -> Union[Iterable, List]: - if not prefetch: - # if we are streaming, we don't need to create batches right now - # we will create them on the fly when we need them - batched_data = ( - batch - for batch in batch_fn( - data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {}) - ) - ) - else: - batched_data = [ - batch - for batch in tqdm( - batch_fn( - data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {}) - ), - desc="Creating batches", - ) - ] - return batched_data - - @staticmethod - def collate_batches( - batched_data: Union[Iterable, List], - collate_fn: Callable, - collate_fn_kwargs: Optional[Dict[str, Any]] = None, - prefetch: bool = True, - *args, - **kwargs, - ) -> Union[Iterable, List]: - if not prefetch: - collated_data = ( - collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {})) - for batch in batched_data - ) - else: - collated_data = [ - collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {})) - for batch in tqdm(batched_data, desc="Collating batches") - ] - return collated_data - - @staticmethod - def load_fn(sample: Dict, *args, **kwargs) -> Dict: - raise NotImplementedError - - @staticmethod - def batch_fn(data: Dataset, *args, **kwargs) -> Any: - raise NotImplementedError - - @staticmethod - def collate_fn(batch: Any, *args, **kwargs) -> Any: - raise NotImplementedError - - @staticmethod - def pad_sequence( - sequence: Union[List, torch.Tensor], - length: int, - value: Any = None, - pad_to_left: bool = False, - ) -> Union[List, torch.Tensor]: - """ - Pad the input to the specified length with the given value. - - Args: - sequence (:obj:`List`, :obj:`torch.Tensor`): - Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`. - length (:obj:`int`, :obj:`str`, optional, defaults to :obj:`subtoken`): - Length after pad. - value (:obj:`Any`, optional): - Value to use as padding. - pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`): - If :obj:`True`, pads to the left, right otherwise. 
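For reference, the `load()` method above reads DPR-style JSON records; the keys it relies on (`question`, `positive_ctxs`, `negative_ctxs`, `hard_negative_ctxs`) look like this in practice, with invented texts:

```python
# one training record in the DPR-style format consumed by GoldenRetrieverDataset.load
# (keys as referenced by load_fn above; the texts are made up for illustration)
sample = {
    "question": "Who wrote the Divine Comedy?",
    "answers": ["Dante Alighieri"],
    "positive_ctxs": [{"text": "Dante Alighieri is the author of the Divine Comedy."}],
    "negative_ctxs": [{"text": "The Aeneid is a Latin epic poem by Virgil."}],
    "hard_negative_ctxs": [{"text": "Dante's Inferno is a 2010 action video game."}],
}
```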
- - Returns: - :obj:`List`, :obj:`torch.Tensor`: The padded sequence. - - """ - padding = [value] * abs(length - len(sequence)) - if isinstance(sequence, torch.Tensor): - if len(sequence.shape) > 1: - raise ValueError( - f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`" - ) - padding = torch.as_tensor(padding) - if pad_to_left: - if isinstance(sequence, torch.Tensor): - return torch.cat((padding, sequence), -1) - return padding + sequence - if isinstance(sequence, torch.Tensor): - return torch.cat((sequence, padding), -1) - return sequence + padding - - def convert_to_batch( - self, samples: Any, *args, **kwargs - ) -> Dict[str, torch.Tensor]: - """ - Convert the list of samples to a batch. - - Args: - samples (:obj:`List`): - List of samples to convert to a batch. - - Returns: - :obj:`Dict[str, torch.Tensor]`: The batch. - """ - # invert questions from list of dict to dict of list - samples = {k: [d[k] for d in samples] for k in samples[0]} - # get max length of questions - max_len = max(len(x) for x in samples["input_ids"]) - # pad the questions - for key in samples: - if key in self.padding_ops: - samples[key] = torch.as_tensor( - [self.padding_ops[key](b, max_len) for b in samples[key]] - ) - return samples - - def shuffle_data(self, seed: int = 42): - self.data = self.data.shuffle(seed=seed) - - -class InBatchNegativesDataset(GoldenRetrieverDataset): - def __len__(self) -> int: - if isinstance(self.data, datasets.Dataset): - return len(self.data) - - def __getitem__( - self, index - ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - return self.data[index] - - def to_torch_dataset(self) -> torch.utils.data.Dataset: - shuffle_this_time = self.shuffle - - if ( - self.subsample_strategy - and self.subsample_strategy != SubsampleStrategyEnum.NONE - ): - number_of_samples = int(len(self.data) * self.subsample_portion) - if self.subsample_strategy == SubsampleStrategyEnum.RANDOM: - logger.info( - f"Random subsampling {number_of_samples} samples from {len(self.data)}" - ) - data = ( - deepcopy(self.data) - .shuffle(seed=42 + self.number_of_complete_iterations) - .select(range(0, number_of_samples)) - ) - elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER: - # number_of_samples = int(len(self.data) * self.subsample_portion) - already_selected = ( - number_of_samples * self.number_of_complete_iterations - ) - logger.info( - f"Subsampling {number_of_samples} samples out of {len(self.data)}" - ) - to_select = min(already_selected + number_of_samples, len(self.data)) - logger.info( - f"Portion of data selected: {already_selected} " f"to {to_select}" - ) - data = deepcopy(self.data).select(range(already_selected, to_select)) - - # don't shuffle the data if we are subsampling, and we have still not completed - # one full iteration over the dataset - if self.number_of_complete_iterations > 0: - shuffle_this_time = False - - # reset the number of complete iterations - if to_select >= len(self.data): - # reset the number of complete iterations, - # we have completed one full iteration over the dataset - # the value is -1 because we want to start from 0 at the next iteration - self.number_of_complete_iterations = -1 - else: - raise ValueError( - f"Subsample strategy `{self.subsample_strategy}` is not valid. " - f"Valid strategies are: {SubsampleStrategyEnum.__members__}" - ) - - else: - data = data = self.data - - # do we need to shuffle the data? 
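(The shuffling branch of `to_torch_dataset` continues right below.) As an aside on the padding helpers defined earlier, `pad_sequence` plus `convert_to_batch` boil down to right-padding every question to the longest one in the batch; a stripped-down, list-only sketch of that behaviour:

```python
import torch

def pad_sequence(sequence, length, value=0, pad_to_left=False):
    # same idea as GoldenRetrieverDataset.pad_sequence, lists only for brevity
    padding = [value] * (length - len(sequence))
    return padding + sequence if pad_to_left else sequence + padding

# two tokenized questions padded to the longest one, as convert_to_batch does
batch = [[101, 2054, 102], [101, 2040, 2003, 2009, 102]]
max_len = max(len(ids) for ids in batch)
padded = torch.as_tensor([pad_sequence(ids, max_len) for ids in batch])
print(padded)  # shape (2, 5); the first row is right-padded with 0
```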
- if self.shuffle and shuffle_this_time: - logger.info("Shuffling the data") - data = data.shuffle(seed=42 + self.number_of_complete_iterations) - - batch_fn_kwargs = { - "passage_batch_size": self.passage_batch_size, - "question_batch_size": self.question_batch_size, - "hard_negatives_manager": self.hn_manager, - } - batched_data = self.create_batches( - data, - batch_fn=self.batch_fn, - batch_fn_kwargs=batch_fn_kwargs, - prefetch=self.prefetch, - ) - - batched_data = self.collate_batches( - batched_data, self.collate_fn, prefetch=self.prefetch - ) - - # increment the number of complete iterations - self.number_of_complete_iterations += 1 - - if self.prefetch: - return BaseDataset(name=self.name, data=batched_data) - else: - return IterableBaseDataset(name=self.name, data=batched_data) - - @staticmethod - def load_fn( - sample: Dict, - tokenizer: tr.PreTrainedTokenizer, - max_positives: int, - max_negatives: int, - max_hard_negatives: int, - max_passages: int = -1, - max_question_length: int = 256, - max_passage_length: int = 128, - *args, - **kwargs, - ) -> Dict: - # remove duplicates and limit the number of passages - positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]])) - if max_positives != -1: - positives = positives[:max_positives] - negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]])) - if max_negatives != -1: - negatives = negatives[:max_negatives] - hard_negatives = list( - set([h["text"].strip() for h in sample["hard_negative_ctxs"]]) - ) - if max_hard_negatives != -1: - hard_negatives = hard_negatives[:max_hard_negatives] - - question = tokenizer( - sample["question"], max_length=max_question_length, truncation=True - ) - - passage = positives + negatives + hard_negatives - if max_passages != -1: - passage = passage[:max_passages] - - passage = tokenizer(passage, max_length=max_passage_length, truncation=True) - - # invert the passage data structure from a dict of lists to a list of dicts - passage = [dict(zip(passage, t)) for t in zip(*passage.values())] - - output = dict( - question=question, - passage=passage, - positives=positives, - positive_pssgs=passage[: len(positives)], - ) - return output - - @staticmethod - def batch_fn( - data: Dataset, - passage_batch_size: int, - question_batch_size: int, - hard_negatives_manager: Optional[HardNegativesManager] = None, - *args, - **kwargs, - ) -> Dict[str, List[Dict[str, Any]]]: - def split_batch( - batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int - ) -> List[ModelInputs]: - """ - Split a batch into multiple batches of size `question_batch_size` while keeping - the same number of passages. 
- """ - - def split_fn(x): - return [ - x[i : i + question_batch_size] - for i in range(0, len(x), question_batch_size) - ] - - # split the sample_idx - sample_idx = split_fn(batch["sample_idx"]) - # split the questions - questions = split_fn(batch["questions"]) - # split the positives - positives = split_fn(batch["positives"]) - # split the positives_pssgs - positives_pssgs = split_fn(batch["positives_pssgs"]) - - # collect the new batches - batches = [] - for i in range(len(questions)): - batches.append( - ModelInputs( - dict( - sample_idx=sample_idx[i], - questions=questions[i], - passages=batch["passages"], - positives=positives[i], - positives_pssgs=positives_pssgs[i], - ) - ) - ) - return batches - - batch = [] - passages_in_batch = {} - - for sample in data: - if len(passages_in_batch) >= passage_batch_size: - # create the batch dict - batch_dict = ModelInputs( - dict( - sample_idx=[s["sample_idx"] for s in batch], - questions=[s["question"] for s in batch], - passages=list(passages_in_batch.values()), - positives_pssgs=[s["positive_pssgs"] for s in batch], - positives=[s["positives"] for s in batch], - ) - ) - # split the batch if needed - if len(batch) > question_batch_size: - for splited_batch in split_batch(batch_dict, question_batch_size): - yield splited_batch - else: - yield batch_dict - - # reset batch - batch = [] - passages_in_batch = {} - - batch.append(sample) - # yes it's a bit ugly but it works :) - # count the number of passages in the batch and stop if we reach the limit - # we use a set to avoid counting the same passage twice - # we use a tuple because set doesn't support lists - # we use input_ids as discriminator - passages_in_batch.update( - {tuple(passage["input_ids"]): passage for passage in sample["passage"]} - ) - # check for hard negatives and add with a probability of 0.1 - if hard_negatives_manager is not None: - if sample["sample_idx"] in hard_negatives_manager: - passages_in_batch.update( - { - tuple(passage["input_ids"]): passage - for passage in hard_negatives_manager.get( - sample["sample_idx"] - ) - } - ) - - # left over - if len(batch) > 0: - # create the batch dict - batch_dict = ModelInputs( - dict( - sample_idx=[s["sample_idx"] for s in batch], - questions=[s["question"] for s in batch], - passages=list(passages_in_batch.values()), - positives_pssgs=[s["positive_pssgs"] for s in batch], - positives=[s["positives"] for s in batch], - ) - ) - # split the batch if needed - if len(batch) > question_batch_size: - for splited_batch in split_batch(batch_dict, question_batch_size): - yield splited_batch - else: - yield batch_dict - - def collate_fn(self, batch: Any, *args, **kwargs) -> Any: - # convert questions and passages to a batch - questions = self.convert_to_batch(batch.questions) - passages = self.convert_to_batch(batch.passages) - - # build an index to map the position of the passage in the batch - passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)} - - # now we can create the labels - labels = torch.zeros( - questions["input_ids"].shape[0], passages["input_ids"].shape[0] - ) - # iterate over the questions and set the labels to 1 if the passage is positive - for sample_idx in range(len(questions["input_ids"])): - for pssg in batch["positives_pssgs"][sample_idx]: - # get the index of the positive passage - index = passage_index[tuple(pssg["input_ids"])] - # set the label to 1 - labels[sample_idx, index] = 1 - - model_inputs = ModelInputs( - { - "questions": questions, - "passages": passages, - "labels": labels, - 
"positives": batch["positives"], - "sample_idx": batch["sample_idx"], - } - ) - return model_inputs - - -class AidaInBatchNegativesDataset(InBatchNegativesDataset): - def __init__(self, use_topics: bool = False, *args, **kwargs): - if "load_fn_kwargs" not in kwargs: - kwargs["load_fn_kwargs"] = {} - kwargs["load_fn_kwargs"]["use_topics"] = use_topics - super().__init__(*args, **kwargs) - - @staticmethod - def load_fn( - sample: Dict, - tokenizer: tr.PreTrainedTokenizer, - max_positives: int, - max_negatives: int, - max_hard_negatives: int, - max_passages: int = -1, - max_question_length: int = 256, - max_passage_length: int = 128, - use_topics: bool = False, - *args, - **kwargs, - ) -> Dict: - # remove duplicates and limit the number of passages - positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]])) - if max_positives != -1: - positives = positives[:max_positives] - negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]])) - if max_negatives != -1: - negatives = negatives[:max_negatives] - hard_negatives = list( - set([h["text"].strip() for h in sample["hard_negative_ctxs"]]) - ) - if max_hard_negatives != -1: - hard_negatives = hard_negatives[:max_hard_negatives] - - question = sample["question"] - - if "doc_topic" in sample and use_topics: - question = tokenizer( - question, - sample["doc_topic"], - max_length=max_question_length, - truncation=True, - ) - else: - question = tokenizer( - question, max_length=max_question_length, truncation=True - ) - - passage = positives + negatives + hard_negatives - if max_passages != -1: - passage = passage[:max_passages] - - passage = tokenizer(passage, max_length=max_passage_length, truncation=True) - - # invert the passage data structure from a dict of lists to a list of dicts - passage = [dict(zip(passage, t)) for t in zip(*passage.values())] - - output = dict( - question=question, - passage=passage, - positives=positives, - positive_pssgs=passage[: len(positives)], - ) - return output diff --git a/relik/retriever/data/labels.py b/relik/retriever/data/labels.py deleted file mode 100644 index de8f87a2186a413b47d686f92b4d3039536a5988..0000000000000000000000000000000000000000 --- a/relik/retriever/data/labels.py +++ /dev/null @@ -1,338 +0,0 @@ -import json -from pathlib import Path -from typing import Dict, List, Optional, Set, Union - -import transformers as tr - - -class Labels: - """ - Class that contains the labels for a model. - - Args: - _labels_to_index (:obj:`Dict[str, Dict[str, int]]`): - A dictionary from :obj:`str` to :obj:`int`. - _index_to_labels (:obj:`Dict[str, Dict[int, str]]`): - A dictionary from :obj:`int` to :obj:`str`. - """ - - def __init__( - self, - _labels_to_index: Dict[str, Dict[str, int]] = None, - _index_to_labels: Dict[str, Dict[int, str]] = None, - **kwargs, - ): - self._labels_to_index = _labels_to_index or {"labels": {}} - self._index_to_labels = _index_to_labels or {"labels": {}} - # if _labels_to_index is not empty and _index_to_labels is not provided - # to the constructor, build the inverted label dictionary - if not _index_to_labels and _labels_to_index: - for namespace in self._labels_to_index: - self._index_to_labels[namespace] = { - v: k for k, v in self._labels_to_index[namespace].items() - } - - def get_index_from_label(self, label: str, namespace: str = "labels") -> int: - """ - Returns the index of a literal label. - - Args: - label (:obj:`str`): - The string representation of the label. 
- namespace (:obj:`str`, optional, defaults to ``labels``): - The namespace where the label belongs, e.g. ``roles`` for a SRL task. - - Returns: - :obj:`int`: The index of the label. - """ - if namespace not in self._labels_to_index: - raise ValueError( - f"Provided namespace `{namespace}` is not in the label dictionary." - ) - - if label not in self._labels_to_index[namespace]: - raise ValueError(f"Provided label {label} is not in the label dictionary.") - - return self._labels_to_index[namespace][label] - - def get_label_from_index(self, index: int, namespace: str = "labels") -> str: - """ - Returns the string representation of the label index. - - Args: - index (:obj:`int`): - The index of the label. - namespace (:obj:`str`, optional, defaults to ``labels``): - The namespace where the label belongs, e.g. ``roles`` for a SRL task. - - Returns: - :obj:`str`: The string representation of the label. - """ - if namespace not in self._index_to_labels: - raise ValueError( - f"Provided namespace `{namespace}` is not in the label dictionary." - ) - - if index not in self._index_to_labels[namespace]: - raise ValueError( - f"Provided label `{index}` is not in the label dictionary." - ) - - return self._index_to_labels[namespace][index] - - def add_labels( - self, - labels: Union[str, List[str], Set[str], Dict[str, int]], - namespace: str = "labels", - ) -> List[int]: - """ - Adds the labels in input in the label dictionary. - - Args: - labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`): - The labels (single label, list of labels or set of labels) to add to the dictionary. - namespace (:obj:`str`, optional, defaults to ``labels``): - Namespace where the labels belongs. - - Returns: - :obj:`List[int]`: The index of the labels just inserted. - """ - if isinstance(labels, dict): - self._labels_to_index[namespace] = labels - self._index_to_labels[namespace] = { - v: k for k, v in self._labels_to_index[namespace].items() - } - # normalize input - if isinstance(labels, (str, list)): - labels = set(labels) - # if new namespace, add to the dictionaries - if namespace not in self._labels_to_index: - self._labels_to_index[namespace] = {} - self._index_to_labels[namespace] = {} - # returns the new indices - return [self._add_label(label, namespace) for label in labels] - - def _add_label(self, label: str, namespace: str = "labels") -> int: - """ - Adds the label in input in the label dictionary. - - Args: - label (:obj:`str`): - The label to add to the dictionary. - namespace (:obj:`str`, optional, defaults to ``labels``): - Namespace where the label belongs. - - Returns: - :obj:`List[int]`: The index of the label just inserted. - """ - if label not in self._labels_to_index[namespace]: - index = len(self._labels_to_index[namespace]) - self._labels_to_index[namespace][label] = index - self._index_to_labels[namespace][index] = label - return index - else: - return self._labels_to_index[namespace][label] - - def get_labels(self, namespace: str = "labels") -> Dict[str, int]: - """ - Returns all the labels that belongs to the input namespace. - - Args: - namespace (:obj:`str`, optional, defaults to ``labels``): - Labels namespace to retrieve. - - Returns: - :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``. - """ - if namespace not in self._labels_to_index: - raise ValueError( - f"Provided namespace `{namespace}` is not in the label dictionary." 
- ) - return self._labels_to_index[namespace] - - def get_label_size(self, namespace: str = "labels") -> int: - """ - Returns the number of the labels in the namespace dictionary. - - Args: - namespace (:obj:`str`, optional, defaults to ``labels``): - Labels namespace to retrieve. - - Returns: - :obj:`int`: Number of labels. - """ - if namespace not in self._labels_to_index: - raise ValueError( - f"Provided namespace `{namespace}` is not in the label dictionary." - ) - return len(self._labels_to_index[namespace]) - - def get_namespaces(self) -> List[str]: - """ - Returns all the namespaces in the label dictionary. - - Returns: - :obj:`List[str]`: The namespaces in the label dictionary. - """ - return list(self._labels_to_index.keys()) - - @classmethod - def from_file(cls, file_path: Union[str, Path, dict], **kwargs): - with open(file_path, "r") as f: - labels_to_index = json.load(f) - return cls(labels_to_index, **kwargs) - - def save(self, file_path: Union[str, Path, dict], **kwargs): - with open(file_path, "w") as f: - json.dump(self._labels_to_index, f, indent=2) - - -class PassageManager: - def __init__( - self, - tokenizer: Optional[tr.PreTrainedTokenizer] = None, - passages: Optional[Union[Dict[str, Dict[str, int]], Labels, List[str]]] = None, - lazy: bool = True, - **kwargs, - ): - if passages is None: - self.passages = Labels() - elif isinstance(passages, Labels): - self.passages = passages - elif isinstance(passages, dict): - self.passages = Labels(passages) - elif isinstance(passages, list): - self.passages = Labels() - self.passages.add_labels(passages) - else: - raise ValueError( - "`passages` should be either a Labels object or a dictionary." - ) - - self.tokenizer = tokenizer - self.lazy = lazy - - self._tokenized_passages = {} - - if not self.lazy: - self._tokenize_passages(self.passages) - - def __len__(self) -> int: - return self.passages.get_label_size() - - def get_index_from_passage(self, passage: str) -> int: - """ - Returns the index of the passage in input. - - Args: - passage (:obj:`str`): - The passage to get the index from. - - Returns: - :obj:`int`: The index of the passage. - """ - return self.passages.get_index_from_label(passage) - - def get_passage_from_index(self, index: int) -> str: - """ " - Returns the passage from the index in input. - - Args: - index (:obj:`int`): - The index to get the passage from. - - Returns: - :obj:`str`: The passage. - """ - return self.passages.get_label_from_index(index) - - def add_passages( - self, - passages: Union[str, List[str], Set[str], Dict[str, int]], - lazy: Optional[bool] = None, - ) -> List[int]: - """ - Adds the passages in input in the passage dictionary. - - Args: - passages (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`, :obj:`Dict[str, int]`): - The passages (single passage, list of passages, set of passages or dictionary of passages) to add to the dictionary. - lazy (:obj:`bool`, optional, defaults to ``None``): - Whether to tokenize the passages right away or not. - - Returns: - :obj:`List[int]`: The index of the passages just inserted. - """ - - return self.passages.add_labels(passages) - - def get_passages(self) -> Dict[str, int]: - """ - Returns all the passages in the passage dictionary. - - Returns: - :obj:`Dict[str, int]`: The passage dictionary, from ``str`` to ``int``. - """ - return self.passages.get_labels() - - def get_tokenized_passage( - self, passage: Union[str, int], force_tokenize: bool = False, **kwargs - ) -> Dict: - """ - Returns the tokenized passage in input. 
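(The `get_tokenized_passage` docstring continues below.) Both `PassageManager` and the `Labels` mapping it wraps are plain two-way dictionaries; a quick round trip with `Labels`, assuming the deleted module is still importable at its original path:

```python
# import path as in the deleted file relik/retriever/data/labels.py
from relik.retriever.data.labels import Labels

labels = Labels()
labels.add_labels(["person", "location", "organization"], namespace="types")

idx = labels.get_index_from_label("location", namespace="types")
assert labels.get_label_from_index(idx, namespace="types") == "location"
print(labels.get_label_size("types"), labels.get_namespaces())  # 3 ['labels', 'types']
```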
- - Args: - passage (:obj:`Union[str, int]`): - The passage to tokenize. - force_tokenize (:obj:`bool`, optional, defaults to ``False``): - Whether to force the tokenization of the passage or not. - kwargs: - Additional keyword arguments to pass to the tokenizer. - - Returns: - :obj:`Dict`: The tokenized passage. - """ - passage_index: Optional[int] = None - passage_str: Optional[str] = None - - if isinstance(passage, str): - passage_index = self.passages.get_index_from_label(passage) - passage_str = passage - elif isinstance(passage, int): - passage_index = passage - passage_str = self.passages.get_label_from_index(passage) - else: - raise ValueError( - f"`passage` should be either a `str` or an `int`. Provided type: {type(passage)}." - ) - - if passage_index not in self._tokenized_passages or force_tokenize: - self._tokenized_passages[passage_index] = self.tokenizer( - passage_str, **kwargs - ) - - return self._tokenized_passages[passage_index] - - def _tokenize_passages(self, **kwargs): - for passage in self.passages.get_labels(): - self.get_tokenized_passage(passage, **kwargs) - - def tokenize(self, text: Union[str, List[str]], **kwargs): - """ - Tokenizes the text in input using the tokenizer. - - Args: - text (:obj:`str`, :obj:`List[str]`): - The text to tokenize. - **kwargs: - Additional keyword arguments to pass to the tokenizer. - - Returns: - :obj:`List[str]`: The tokenized text. - - """ - if self.tokenizer is None: - raise ValueError( - "No tokenizer was provided. Please provide a tokenizer to the passageManager." - ) - return self.tokenizer(text, **kwargs) diff --git a/relik/retriever/data/utils.py b/relik/retriever/data/utils.py deleted file mode 100644 index 928dfb833919ce2e30c9e90fed539c46f4bd3dec..0000000000000000000000000000000000000000 --- a/relik/retriever/data/utils.py +++ /dev/null @@ -1,176 +0,0 @@ -import json -import os -from collections import defaultdict -from typing import Any, Dict, Iterable, List, Optional, Union - -import numpy as np -import transformers as tr -from tqdm import tqdm - - -class HardNegativesManager: - def __init__( - self, - tokenizer: tr.PreTrainedTokenizer, - data: Union[List[Dict], os.PathLike, Dict[int, List]] = None, - max_length: int = 64, - batch_size: int = 1000, - lazy: bool = False, - ) -> None: - self._db: dict = None - self.tokenizer = tokenizer - - if data is None: - self._db = {} - else: - if isinstance(data, Dict): - self._db = data - elif isinstance(data, os.PathLike): - with open(data) as f: - self._db = json.load(f) - else: - raise ValueError( - f"Data type {type(data)} not supported, only Dict and os.PathLike are supported." 
- ) - # add the tokenizer to the class for future use - self.tokenizer = tokenizer - - # invert the db to have a passage -> sample_idx mapping - self._passage_db = defaultdict(set) - for sample_idx, passages in self._db.items(): - for passage in passages: - self._passage_db[passage].add(sample_idx) - - self._passage_hard_negatives = {} - if not lazy: - # create a dictionary of passage -> hard_negative mapping - batch_size = min(batch_size, len(self._passage_db)) - unique_passages = list(self._passage_db.keys()) - for i in tqdm( - range(0, len(unique_passages), batch_size), - desc="Tokenizing Hard Negatives", - ): - batch = unique_passages[i : i + batch_size] - tokenized_passages = self.tokenizer( - batch, - max_length=max_length, - truncation=True, - ) - for i, passage in enumerate(batch): - self._passage_hard_negatives[passage] = { - k: tokenized_passages[k][i] for k in tokenized_passages.keys() - } - - def __len__(self) -> int: - return len(self._db) - - def __getitem__(self, idx: int) -> Dict: - return self._db[idx] - - def __iter__(self): - for sample in self._db: - yield sample - - def __contains__(self, idx: int) -> bool: - return idx in self._db - - def get(self, idx: int) -> List[str]: - """Get the hard negatives for a given sample index.""" - if idx not in self._db: - raise ValueError(f"Sample index {idx} not in the database.") - - passages = self._db[idx] - - output = [] - for passage in passages: - if passage not in self._passage_hard_negatives: - self._passage_hard_negatives[passage] = self._tokenize(passage) - output.append(self._passage_hard_negatives[passage]) - - return output - - def _tokenize(self, passage: str) -> Dict: - return self.tokenizer(passage, max_length=self.max_length, truncation=True) - - -class NegativeSampler: - def __init__( - self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None - ): - if not isinstance(probabilities, np.ndarray): - probabilities = np.array(probabilities) - - if probabilities is None: - # probabilities should sum to 1 - probabilities = np.random.random(num_elements) - probabilities /= np.sum(probabilities) - self.probabilities = probabilities - - def __call__( - self, - sample_size: int, - num_samples: int = 1, - probabilities: np.array = None, - exclude: List[int] = None, - ) -> np.array: - """ - Fast sampling of `sample_size` elements from `num_elements` elements. - The sampling is done by randomly shifting the probabilities and then - finding the smallest of the negative numbers. This is much faster than - sampling from a multinomial distribution. - - Args: - sample_size (`int`): - number of elements to sample - num_samples (`int`, optional): - number of samples to draw. Defaults to 1. - probabilities (`np.array`, optional): - probabilities of each element. Defaults to None. - exclude (`List[int]`, optional): - indices of elements to exclude. Defaults to None. - - Returns: - `np.array`: array of sampled indices - """ - if probabilities is None: - probabilities = self.probabilities - - if exclude is not None: - probabilities[exclude] = 0 - # re-normalize? 
- # probabilities /= np.sum(probabilities) - - # replicate probabilities as many times as `num_samples` - replicated_probabilities = np.tile(probabilities, (num_samples, 1)) - # get random shifting numbers & scale them correctly - random_shifts = np.random.random(replicated_probabilities.shape) - random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis] - # shift by numbers & find largest (by finding the smallest of the negative) - shifted_probabilities = random_shifts - replicated_probabilities - sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[ - :, :sample_size - ] - return sampled_indices - - -def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]: - """ - Generate batches from samples. - - Args: - samples (`Iterable[Any]`): Iterable of samples. - batch_size (`int`): Batch size. - - Returns: - `Iterable[Any]`: Iterable of batches. - """ - batch = [] - for sample in samples: - batch.append(sample) - if len(batch) == batch_size: - yield batch - batch = [] - - # leftover batch - if len(batch) > 0: - yield batch diff --git a/relik/retriever/indexers/__init__.py b/relik/retriever/indexers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/retriever/indexers/base.py b/relik/retriever/indexers/base.py deleted file mode 100644 index c9b1b635566e941e5dfa6e204e594cd0d13c274c..0000000000000000000000000000000000000000 --- a/relik/retriever/indexers/base.py +++ /dev/null @@ -1,334 +0,0 @@ -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import hydra -import numpy -import torch -from omegaconf import OmegaConf -from rich.pretty import pprint - -from relik.common import upload -from relik.common.log import get_console_logger, get_logger -from relik.common.utils import ( - from_cache, - is_remote_url, - is_str_a_path, - relative_to_absolute_path, - sapienzanlp_model_urls, -) -from relik.retriever.data.labels import Labels - -# from relik.retriever.models.model import GoldenRetriever, RetrievedSample - - -logger = get_logger(__name__) -console_logger = get_console_logger() - - -@dataclass -class IndexerOutput: - indices: Union[torch.Tensor, numpy.ndarray] - distances: Union[torch.Tensor, numpy.ndarray] - - -class BaseDocumentIndex: - CONFIG_NAME = "config.yaml" - DOCUMENTS_FILE_NAME = "documents.json" - EMBEDDINGS_FILE_NAME = "embeddings.pt" - - def __init__( - self, - documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None, - embeddings: Optional[torch.Tensor] = None, - name_or_dir: Optional[Union[str, os.PathLike]] = None, - ) -> None: - if documents is not None: - if isinstance(documents, Labels): - self.documents = documents - else: - documents_are_paths = False - - # normalize the documents to list if not already - if not isinstance(documents, list): - documents = [documents] - - # now check if the documents are a list of paths (either str or os.PathLike) - if isinstance(documents[0], str) or isinstance( - documents[0], os.PathLike - ): - # check if the str is a path - documents_are_paths = is_str_a_path(documents[0]) - - # if the documents are a list of paths, then we load them - if documents_are_paths: - logger.info("Loading documents from paths") - _documents = [] - for doc in documents: - with open(relative_to_absolute_path(doc)) as f: - _documents += [line.strip() for line in f.readlines()] - # remove duplicates - documents = list(set(_documents)) 
- - self.documents = Labels() - self.documents.add_labels(documents) - else: - self.documents = Labels() - - self.embeddings = embeddings - self.name_or_dir = name_or_dir - - def __iter__(self): - # make this class iterable - for i in range(len(self)): - yield self[i] - - def __len__(self): - return self.documents.get_label_size() - - def __getitem__(self, index): - return self.get_passage_from_index(index) - - @property - def config(self) -> Dict[str, Any]: - """ - The configuration of the document index. - - Returns: - `Dict[str, Any]`: The configuration of the retriever. - """ - - def obj_to_dict(obj): - match obj: - case dict(): - data = {} - for k, v in obj.items(): - data[k] = obj_to_dict(v) - return data - - case list() | tuple(): - return [obj_to_dict(x) for x in obj] - - case object(__dict__=_): - data = { - "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}", - } - for k, v in obj.__dict__.items(): - if not k.startswith("_"): - data[k] = obj_to_dict(v) - return data - - case _: - return obj - - return obj_to_dict(self) - - def index( - self, - retriever, - *args, - **kwargs, - ) -> "BaseDocumentIndex": - raise NotImplementedError - - def search(self, query: Any, k: int = 1, *args, **kwargs) -> List: - raise NotImplementedError - - def get_index_from_passage(self, document: str) -> int: - """ - Get the index of the passage. - - Args: - document (`str`): - The document to get the index for. - - Returns: - `int`: The index of the document. - """ - return self.documents.get_index_from_label(document) - - def get_passage_from_index(self, index: int) -> str: - """ - Get the document from the index. - - Args: - index (`int`): - The index of the document. - - Returns: - `str`: The document. - """ - return self.documents.get_label_from_index(index) - - def get_embeddings_from_index(self, index: int) -> torch.Tensor: - """ - Get the document vector from the index. - - Args: - index (`int`): - The index of the document. - - Returns: - `torch.Tensor`: The document vector. - """ - if self.embeddings is None: - raise ValueError( - "The documents must be indexed before they can be retrieved." - ) - if index >= self.embeddings.shape[0]: - raise ValueError( - f"The index {index} is out of bounds. The maximum index is {len(self.embeddings) - 1}." - ) - return self.embeddings[index] - - def get_embeddings_from_passage(self, document: str) -> torch.Tensor: - """ - Get the document vector from the document label. - - Args: - document (`str`): - The document to get the vector for. - - Returns: - `torch.Tensor`: The document vector. - """ - if self.embeddings is None: - raise ValueError( - "The documents must be indexed before they can be retrieved." - ) - return self.get_embeddings_from_index(self.get_index_from_passage(document)) - - def save_pretrained( - self, - output_dir: Union[str, os.PathLike], - config: Optional[Dict[str, Any]] = None, - config_file_name: Optional[str] = None, - document_file_name: Optional[str] = None, - embedding_file_name: Optional[str] = None, - push_to_hub: bool = False, - **kwargs, - ): - """ - Save the retriever to a directory. - - Args: - output_dir (`str`): - The directory to save the retriever to. - config (`Optional[Dict[str, Any]]`, `optional`): - The configuration to save. If `None`, the current configuration of the retriever will be - saved. Defaults to `None`. - config_file_name (`Optional[str]`, `optional`): - The name of the configuration file. Defaults to `config.yaml`. 
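The `config` property above serializes the document index into a hydra-instantiable dictionary keyed by `_target_`. A minimal sketch of loading such a config back (the field values are illustrative assumptions, and the snippet assumes the relik package is importable):

import hydra
from omegaconf import OmegaConf

# illustrative config; a real one is produced by the `config` property above
config = OmegaConf.create(
    {
        "_target_": "relik.retriever.indexers.inmemory.InMemoryDocumentIndex",
        "device": "cpu",
        "precision": None,
    }
)
# hydra resolves `_target_` to the class and calls it with the remaining keys
document_index = hydra.utils.instantiate(config)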
- document_file_name (`Optional[str]`, `optional`): - The name of the document file. Defaults to `documents.json`. - embedding_file_name (`Optional[str]`, `optional`): - The name of the embedding file. Defaults to `embeddings.pt`. - push_to_hub (`bool`, `optional`): - Whether to push the saved retriever to the hub. Defaults to `False`. - """ - if config is None: - # create a default config - config = self.config - - config_file_name = config_file_name or self.CONFIG_NAME - document_file_name = document_file_name or self.DOCUMENTS_FILE_NAME - embedding_file_name = embedding_file_name or self.EMBEDDINGS_FILE_NAME - - # create the output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Saving retriever to {output_dir}") - logger.info(f"Saving config to {output_dir / config_file_name}") - # pretty print the config - pprint(config, console=console_logger, expand_all=True) - OmegaConf.save(config, output_dir / config_file_name) - - # save the current state of the retriever - embedding_path = output_dir / embedding_file_name - logger.info(f"Saving retriever state to {output_dir / embedding_path}") - torch.save(self.embeddings, embedding_path) - - # save the passage index - documents_path = output_dir / document_file_name - logger.info(f"Saving passage index to {documents_path}") - self.documents.save(documents_path) - - logger.info("Saving document index to disk done.") - - if push_to_hub: - # push to hub - logger.info(f"Pushing to hub") - model_id = model_id or output_dir.name - upload(output_dir, model_id, **kwargs) - - @classmethod - def from_pretrained( - cls, - name_or_dir: Union[str, os.PathLike], - device: str = "cpu", - precision: Optional[str] = None, - config_file_name: Optional[str] = None, - document_file_name: Optional[str] = None, - embedding_file_name: Optional[str] = None, - config_kwargs: Optional[Dict[str, Any]] = None, - *args, - **kwargs, - ) -> "BaseDocumentIndex": - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - - config_file_name = config_file_name or cls.CONFIG_NAME - document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME - embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME - - model_dir = from_cache( - name_or_dir, - filenames=[config_file_name, document_file_name, embedding_file_name], - cache_dir=cache_dir, - force_download=force_download, - ) - - config_path = model_dir / config_file_name - if not config_path.exists(): - raise FileNotFoundError( - f"Model configuration file not found at {config_path}." 
- ) - - config = OmegaConf.load(config_path) - # override the config with the kwargs - if config_kwargs is not None: - config = OmegaConf.merge(config, OmegaConf.create(config_kwargs)) - pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True) - - # load the documents - documents_path = model_dir / document_file_name - - if not documents_path.exists(): - raise ValueError(f"Document file `{documents_path}` does not exist.") - logger.info(f"Loading documents from {documents_path}") - documents = Labels.from_file(documents_path) - - # load the passage embeddings - embedding_path = model_dir / embedding_file_name - # run some checks - embeddings = None - if embedding_path.exists(): - logger.info(f"Loading embeddings from {embedding_path}") - embeddings = torch.load(embedding_path, map_location="cpu") - else: - logger.warning(f"Embedding file `{embedding_path}` does not exist.") - - document_index = hydra.utils.instantiate( - config, - documents=documents, - embeddings=embeddings, - device=device, - precision=precision, - name_or_dir=name_or_dir, - *args, - **kwargs, - ) - - return document_index diff --git a/relik/retriever/indexers/faiss.py b/relik/retriever/indexers/faiss.py deleted file mode 100644 index 153da317a99996ce9d80f2b212fa5eefdc5cb351..0000000000000000000000000000000000000000 --- a/relik/retriever/indexers/faiss.py +++ /dev/null @@ -1,422 +0,0 @@ -import contextlib -import logging -import math -import os -from dataclasses import dataclass -from typing import Callable, List, Optional, Union - -import numpy -import psutil -import torch -from relik.retriever.pytorch_modules import RetrievedSample -from torch.utils.data import DataLoader -from tqdm import tqdm - -from relik.common.log import get_logger -from relik.common.utils import is_package_available -from relik.retriever.common.model_inputs import ModelInputs -from relik.retriever.data.base.datasets import BaseDataset -from relik.retriever.data.labels import Labels -from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.pytorch_modules import PRECISION_MAP -from relik.retriever.pytorch_modules.model import GoldenRetriever - -if is_package_available("faiss"): - import faiss - import faiss.contrib.torch_utils - -logger = get_logger(__name__, level=logging.INFO) - - -@dataclass -class FaissOutput: - indices: Union[torch.Tensor, numpy.ndarray] - distances: Union[torch.Tensor, numpy.ndarray] - - -class FaissDocumentIndex(BaseDocumentIndex): - DOCUMENTS_FILE_NAME = "documents.json" - EMBEDDINGS_FILE_NAME = "embeddings.pt" - INDEX_FILE_NAME = "index.faiss" - - def __init__( - self, - documents: Union[List[str], Labels], - embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None, - index=None, - index_type: str = "Flat", - nprobe: int = 1, - metric: int = faiss.METRIC_INNER_PRODUCT, - normalize: bool = False, - device: str = "cpu", - name_or_dir: Optional[Union[str, os.PathLike]] = None, - *args, - **kwargs, - ) -> None: - super().__init__(documents, embeddings, name_or_dir) - - if embeddings is not None and documents is not None: - logger.info("Both documents and embeddings are provided.") - if documents.get_label_size() != embeddings.shape[0]: - raise ValueError( - "The number of documents and embeddings must be the same." 
- ) - - faiss.omp_set_num_threads(psutil.cpu_count(logical=False)) - - # device to store the embeddings - self.device = device - - # params - self.index_type = index_type - self.metric = metric - self.normalize = normalize - - if index is not None: - self.embeddings = index - if self.device == "cuda": - # use a single GPU - faiss_resource = faiss.StandardGpuResources() - self.embeddings = faiss.index_cpu_to_gpu( - faiss_resource, 0, self.embeddings - ) - else: - if embeddings is not None: - # build the faiss index - logger.info("Building the index from the embeddings.") - self.embeddings = self._build_faiss_index( - embeddings=embeddings, - index_type=index_type, - nprobe=nprobe, - normalize=normalize, - metric=metric, - ) - - def _build_faiss_index( - self, - embeddings: Optional[Union[torch.Tensor, numpy.ndarray]], - index_type: str, - nprobe: int, - normalize: bool, - metric: int, - ): - # build the faiss index - self.normalize = ( - normalize - and metric == faiss.METRIC_INNER_PRODUCT - and not isinstance(embeddings, torch.Tensor) - ) - if self.normalize: - index_type = f"L2norm,{index_type}" - faiss_vector_size = embeddings.shape[1] - # if self.device == "cpu": - # index_type = index_type.replace("x,", "x_HNSW32,") - # nlist = math.ceil(math.sqrt(faiss_vector_size)) * 4 - # # nlist = 8 - # index_type = index_type.replace( - # "x", str(nlist) - # ) - # print("Current nlist:", nlist) - print("Current index:", index_type) - self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric) - - # convert to GPU - if self.device == "cuda": - # use a single GPU - faiss_resource = faiss.StandardGpuResources() - self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings) - else: - # move to CPU if embeddings is a torch.Tensor - embeddings = ( - embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings - ) - - self.embeddings.hnsw.efConstruction = 20 - # convert to float32 if embeddings is a torch.Tensor and is float16 - if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16: - embeddings = embeddings.float() - - logger.info("Training the index.") - self.embeddings.train(embeddings) - - logger.info("Adding the embeddings to the index.") - self.embeddings.add(embeddings) - - self.embeddings.nprobe = nprobe - - # self.embeddings.hnsw.efSearch - self.embeddings.hnsw.efSearch = 256 - - # self.embeddings.k_factor = 10 - - # save parameters for saving/loading - self.index_type = index_type - self.metric = metric - - # clear the embeddings to free up memory - embeddings = None - - return self.embeddings - - @torch.no_grad() - @torch.inference_mode() - def index( - self, - retriever: GoldenRetriever, - documents: Optional[List[str]] = None, - batch_size: int = 32, - num_workers: int = 4, - max_length: Optional[int] = None, - collate_fn: Optional[Callable] = None, - encoder_precision: Optional[Union[str, int]] = None, - compute_on_cpu: bool = False, - force_reindex: bool = False, - *args, - **kwargs, - ) -> "FaissDocumentIndex": - """ - Index the documents using the encoder. - - Args: - retriever (:obj:`torch.nn.Module`): - The encoder to be used for indexing. - documents (:obj:`List[str]`, `optional`, defaults to None): - The documents to be indexed. - batch_size (:obj:`int`, `optional`, defaults to 32): - The batch size to be used for indexing. - num_workers (:obj:`int`, `optional`, defaults to 4): - The number of workers to be used for indexing. 
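The index construction above goes through `faiss.index_factory`, optionally moves the index to GPU, trains it, and adds the vectors. The following self-contained sketch shows the same flow on random vectors; the dimensionality and index string are illustrative, not the defaults used by the class.

import faiss
import numpy as np

dim = 128
embeddings = np.random.rand(1000, dim).astype("float32")

index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
index.train(embeddings)   # a no-op for Flat indexes, required for IVF/PQ variants
index.add(embeddings)

scores, indices = index.search(embeddings[:2], 5)
print(indices.shape)      # (2, 5)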
- max_length (:obj:`int`, `optional`, defaults to None): - The maximum length of the input to the encoder. - collate_fn (:obj:`Callable`, `optional`, defaults to None): - The collate function to be used for batching. - encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): - The precision to be used for the encoder. - compute_on_cpu (:obj:`bool`, `optional`, defaults to False): - Whether to compute the embeddings on CPU. - force_reindex (:obj:`bool`, `optional`, defaults to False): - Whether to force reindexing. - - Returns: - :obj:`InMemoryIndexer`: The indexer object. - """ - - if self.embeddings is not None and not force_reindex: - logger.log( - "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." - ) - if documents is None: - return self - - # release the memory - if collate_fn is None: - tokenizer = retriever.passage_tokenizer - - def collate_fn(x): - return ModelInputs( - tokenizer( - x, - padding=True, - return_tensors="pt", - truncation=True, - max_length=max_length or tokenizer.model_max_length, - ) - ) - - if force_reindex: - if documents is not None: - self.documents.add_labels(documents) - data = [k for k in self.documents.get_labels()] - - else: - if documents is not None: - data = [k for k in Labels(documents).get_labels()] - else: - return self - - dataloader = DataLoader( - BaseDataset(name="passage", data=data), - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=False, - collate_fn=collate_fn, - ) - - encoder = retriever.passage_encoder - - # Create empty lists to store the passage embeddings and passage index - passage_embeddings: List[torch.Tensor] = [] - - encoder_device = "cpu" if compute_on_cpu else self.device - - # fucking autocast only wants pure strings like 'cpu' or 'cuda' - # we need to convert the model device to that - device_type_for_autocast = str(encoder_device).split(":")[0] - # autocast doesn't work with CPU and stuff different from bfloat16 - autocast_pssg_mngr = ( - contextlib.nullcontext() - if device_type_for_autocast == "cpu" - else ( - torch.autocast( - device_type=device_type_for_autocast, - dtype=PRECISION_MAP[encoder_precision], - ) - ) - ) - with autocast_pssg_mngr: - # Iterate through each batch in the dataloader - for batch in tqdm(dataloader, desc="Indexing"): - # Move the batch to the device - batch: ModelInputs = batch.to(encoder_device) - # Compute the passage embeddings - passage_outs = encoder(**batch) - # Append the passage embeddings to the list - if self.device == "cpu": - passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) - else: - passage_embeddings.extend([c for c in passage_outs]) - - # move the passage embeddings to the CPU if not already done - passage_embeddings = [c.detach().cpu() for c in passage_embeddings] - # stack it - passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) - # convert to float32 for faiss - passage_embeddings.to(PRECISION_MAP["float32"]) - - # index the embeddings - self.embeddings = self._build_faiss_index( - embeddings=passage_embeddings, - index_type=self.index_type, - normalize=self.normalize, - metric=self.metric, - ) - # free up memory from the unused variable - del passage_embeddings - - return self - - @torch.no_grad() - @torch.inference_mode() - def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: - - k = min(k, self.embeddings.ntotal) - - if self.normalize: - faiss.normalize_L2(query) - if isinstance(query, torch.Tensor) and self.device == "cpu": - 
query = query.detach().cpu() - # Retrieve the indices of the top k passage embeddings - retriever_out = self.embeddings.search(query, k) - - # get int values (second element of the tuple) - batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist() - # get float values (first element of the tuple) - batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist() - # Retrieve the passages corresponding to the indices - batch_passages = [ - [self.documents.get_label_from_index(i) for i in indices if i != -1] - for indices in batch_top_k - ] - # build the output object - batch_retrieved_samples = [ - [ - RetrievedSample(label=passage, index=index, score=score) - for passage, index, score in zip(passages, indices, scores) - ] - for passages, indices, scores in zip( - batch_passages, batch_top_k, batch_scores - ) - ] - return batch_retrieved_samples - - # def save(self, saving_dir: Union[str, os.PathLike]): - # """ - # Save the indexer to the disk. - - # Args: - # saving_dir (:obj:`Union[str, os.PathLike]`): - # The directory where the indexer will be saved. - # """ - # saving_dir = Path(saving_dir) - # # save the passage embeddings - # index_path = saving_dir / self.INDEX_FILE_NAME - # logger.info(f"Saving passage embeddings to {index_path}") - # faiss.write_index(self.embeddings, str(index_path)) - # # save the passage index - # documents_path = saving_dir / self.DOCUMENTS_FILE_NAME - # logger.info(f"Saving passage index to {documents_path}") - # self.documents.save(documents_path) - - # @classmethod - # def load( - # cls, - # loading_dir: Union[str, os.PathLike], - # device: str = "cpu", - # document_file_name: Optional[str] = None, - # embedding_file_name: Optional[str] = None, - # index_file_name: Optional[str] = None, - # **kwargs, - # ) -> "FaissDocumentIndex": - # loading_dir = Path(loading_dir) - - # document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME - # embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME - # index_file_name = index_file_name or cls.INDEX_FILE_NAME - - # # load the documents - # documents_path = loading_dir / document_file_name - - # if not documents_path.exists(): - # raise ValueError(f"Document file `{documents_path}` does not exist.") - # logger.info(f"Loading documents from {documents_path}") - # documents = Labels.from_file(documents_path) - - # index = None - # embeddings = None - # # try to load the index directly - # index_path = loading_dir / index_file_name - # if not index_path.exists(): - # # try to load the embeddings - # embedding_path = loading_dir / embedding_file_name - # # run some checks - # if embedding_path.exists(): - # logger.info(f"Loading embeddings from {embedding_path}") - # embeddings = torch.load(embedding_path, map_location="cpu") - # logger.warning( - # f"Index file `{index_path}` and embedding file `{embedding_path}` do not exist." - # ) - # else: - # logger.info(f"Loading index from {index_path}") - # index = faiss.read_index(str(embedding_path)) - - # return cls( - # documents=documents, - # embeddings=embeddings, - # index=index, - # device=device, - # **kwargs, - # ) - - def get_embeddings_from_index( - self, index: int - ) -> Union[torch.Tensor, numpy.ndarray]: - """ - Get the document vector from the index. - - Args: - index (`int`): - The index of the document. - - Returns: - `torch.Tensor`: The document vector. - """ - if self.embeddings is None: - raise ValueError( - "The documents must be indexed before they can be retrieved." 
- ) - if index >= self.embeddings.ntotal: - raise ValueError( - f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}." - ) - return self.embeddings.reconstruct(index) diff --git a/relik/retriever/indexers/inmemory.py b/relik/retriever/indexers/inmemory.py deleted file mode 100644 index 8fb49bcaedf3f81c906c59dc23e7f8e0472a8598..0000000000000000000000000000000000000000 --- a/relik/retriever/indexers/inmemory.py +++ /dev/null @@ -1,287 +0,0 @@ -import contextlib -import logging -import os -from typing import Callable, List, Optional, Tuple, Union - -import torch -from torch.utils.data import DataLoader -from tqdm import tqdm - -from relik.common.log import get_logger -from relik.retriever.common.model_inputs import ModelInputs -from relik.retriever.data.base.datasets import BaseDataset -from relik.retriever.data.labels import Labels -from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample - -logger = get_logger(__name__, level=logging.INFO) - - -class InMemoryDocumentIndex(BaseDocumentIndex): - DOCUMENTS_FILE_NAME = "documents.json" - EMBEDDINGS_FILE_NAME = "embeddings.pt" - - def __init__( - self, - documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None, - embeddings: Optional[torch.Tensor] = None, - device: str = "cpu", - precision: Optional[str] = None, - name_or_dir: Optional[Union[str, os.PathLike]] = None, - *args, - **kwargs, - ) -> None: - """ - An in-memory indexer. - - Args: - documents (:obj:`Union[List[str], PassageManager]`): - The documents to be indexed. - embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`): - The embeddings of the documents. - device (:obj:`str`, `optional`, defaults to "cpu"): - The device to be used for storing the embeddings. - """ - - super().__init__(documents, embeddings, name_or_dir) - - if embeddings is not None and documents is not None: - logger.info("Both documents and embeddings are provided.") - if documents.get_label_size() != embeddings.shape[0]: - raise ValueError( - "The number of documents and embeddings must be the same." - ) - - # embeddings of the documents - self.embeddings = embeddings - # does this do anything? - del embeddings - # convert the embeddings to the desired precision - if precision is not None: - if ( - self.embeddings is not None - and self.embeddings.dtype != PRECISION_MAP[precision] - ): - logger.info( - f"Index vectors are of type {self.embeddings.dtype}. " - f"Converting to {PRECISION_MAP[precision]}." - ) - self.embeddings = self.embeddings.to(PRECISION_MAP[precision]) - else: - if ( - device == "cpu" - and self.embeddings is not None - and self.embeddings.dtype != torch.float32 - ): - logger.info( - "Index vectors are of type {}. 
Converting to float32.".format( - self.embeddings.dtype - ) - ) - self.embeddings = self.embeddings.to(PRECISION_MAP[32]) - # move the embeddings to the desired device - if self.embeddings is not None and not self.embeddings.device == device: - self.embeddings = self.embeddings.to(device) - - # device to store the embeddings - self.device = device - # precision to be used for the embeddings - self.precision = precision - - @torch.no_grad() - @torch.inference_mode() - def index( - self, - retriever, - documents: Optional[List[str]] = None, - batch_size: int = 32, - num_workers: int = 4, - max_length: Optional[int] = None, - collate_fn: Optional[Callable] = None, - encoder_precision: Optional[Union[str, int]] = None, - compute_on_cpu: bool = False, - force_reindex: bool = False, - add_to_existing_index: bool = False, - ) -> "InMemoryDocumentIndex": - """ - Index the documents using the encoder. - - Args: - retriever (:obj:`torch.nn.Module`): - The encoder to be used for indexing. - documents (:obj:`List[str]`, `optional`, defaults to :obj:`None`): - The documents to be indexed. - batch_size (:obj:`int`, `optional`, defaults to 32): - The batch size to be used for indexing. - num_workers (:obj:`int`, `optional`, defaults to 4): - The number of workers to be used for indexing. - max_length (:obj:`int`, `optional`, defaults to None): - The maximum length of the input to the encoder. - collate_fn (:obj:`Callable`, `optional`, defaults to None): - The collate function to be used for batching. - encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): - The precision to be used for the encoder. - compute_on_cpu (:obj:`bool`, `optional`, defaults to False): - Whether to compute the embeddings on CPU. - force_reindex (:obj:`bool`, `optional`, defaults to False): - Whether to force reindexing. - add_to_existing_index (:obj:`bool`, `optional`, defaults to False): - Whether to add the new documents to the existing index. - - Returns: - :obj:`InMemoryIndexer`: The indexer object. - """ - - if documents is None and self.documents is None: - raise ValueError("Documents must be provided.") - - if self.embeddings is not None and not force_reindex: - logger.info( - "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." 
- ) - if documents is None: - return self - - if collate_fn is None: - tokenizer = retriever.passage_tokenizer - - def collate_fn(x): - return ModelInputs( - tokenizer( - x, - padding=True, - return_tensors="pt", - truncation=True, - max_length=max_length or tokenizer.model_max_length, - ) - ) - - if force_reindex: - if documents is not None: - self.documents.add_labels(documents) - data = [k for k in self.documents.get_labels()] - - else: - if documents is not None: - data = [k for k in Labels(documents).get_labels()] - else: - return self - - # if force_reindex: - # data = [k for k in self.documents.get_labels()] - - dataloader = DataLoader( - BaseDataset(name="passage", data=data), - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=False, - collate_fn=collate_fn, - ) - - encoder = retriever.passage_encoder - - # Create empty lists to store the passage embeddings and passage index - passage_embeddings: List[torch.Tensor] = [] - - encoder_device = "cpu" if compute_on_cpu else self.device - - # fucking autocast only wants pure strings like 'cpu' or 'cuda' - # we need to convert the model device to that - device_type_for_autocast = str(encoder_device).split(":")[0] - # autocast doesn't work with CPU and stuff different from bfloat16 - autocast_pssg_mngr = ( - contextlib.nullcontext() - if device_type_for_autocast == "cpu" - else ( - torch.autocast( - device_type=device_type_for_autocast, - dtype=PRECISION_MAP[encoder_precision], - ) - ) - ) - with autocast_pssg_mngr: - # Iterate through each batch in the dataloader - for batch in tqdm(dataloader, desc="Indexing"): - # Move the batch to the device - batch: ModelInputs = batch.to(encoder_device) - # Compute the passage embeddings - passage_outs = encoder(**batch).pooler_output - # Append the passage embeddings to the list - if self.device == "cpu": - passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) - else: - passage_embeddings.extend([c for c in passage_outs]) - - # move the passage embeddings to the CPU if not already done - # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision - if not self.device == "cpu": # this if is to avoid unnecessary moves - passage_embeddings = [c.detach().cpu() for c in passage_embeddings] - # stack it - passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) - # move the passage embeddings to the gpu if needed - if not self.device == "cpu": - passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision]) - passage_embeddings = passage_embeddings.to(self.device) - self.embeddings = passage_embeddings - - # free up memory from the unused variable - del passage_embeddings - - return self - - @torch.no_grad() - @torch.inference_mode() - def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: - """ - Search the documents using the query. - - Args: - query (:obj:`torch.Tensor`): - The query to be used for searching. - k (:obj:`int`, `optional`, defaults to 1): - The number of documents to be retrieved. - - Returns: - :obj:`List[RetrievedSample]`: The retrieved documents. 
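The search implemented just below is an exact dot-product lookup followed by a top-k selection over the in-memory embedding matrix; a toy standalone version of the same computation, with random tensors standing in for real encodings:

import torch

passage_embeddings = torch.randn(1000, 256)  # stand-in for the indexed embeddings
queries = torch.randn(4, 256)                # stand-in for encoded questions

similarity = queries @ passage_embeddings.T           # (4, 1000)
scores, indices = torch.topk(similarity, k=5, dim=1)  # top 5 passages per query
print(indices.shape)                                  # torch.Size([4, 5])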
- """ - # fucking autocast only wants pure strings like 'cpu' or 'cuda' - # we need to convert the model device to that - device_type_for_autocast = str(self.device).split(":")[0] - # autocast doesn't work with CPU and stuff different from bfloat16 - autocast_pssg_mngr = ( - contextlib.nullcontext() - if device_type_for_autocast == "cpu" - else ( - torch.autocast( - device_type=device_type_for_autocast, - dtype=self.embeddings.dtype, - ) - ) - ) - with autocast_pssg_mngr: - similarity = torch.matmul(query, self.embeddings.T) - # Retrieve the indices of the top k passage embeddings - retriever_out: Tuple = torch.topk( - similarity, k=min(k, similarity.shape[-1]), dim=1 - ) - # get int values - batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist() - # get float values - batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist() - # Retrieve the passages corresponding to the indices - batch_passages = [ - [self.documents.get_label_from_index(i) for i in indices] - for indices in batch_top_k - ] - # build the output object - batch_retrieved_samples = [ - [ - RetrievedSample(label=passage, index=index, score=score) - for passage, index, score in zip(passages, indices, scores) - ] - for passages, indices, scores in zip( - batch_passages, batch_top_k, batch_scores - ) - ] - return batch_retrieved_samples diff --git a/relik/retriever/lightning_modules/__init__.py b/relik/retriever/lightning_modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/relik/retriever/lightning_modules/pl_data_modules.py b/relik/retriever/lightning_modules/pl_data_modules.py deleted file mode 100644 index 8c69f5f291789cb7b473dd8466f7373b1a010e8a..0000000000000000000000000000000000000000 --- a/relik/retriever/lightning_modules/pl_data_modules.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import Any, List, Optional, Sequence, Union - -import hydra -import lightning as pl -import torch -from lightning.pytorch.utilities.types import EVAL_DATALOADERS -from omegaconf import DictConfig -from torch.utils.data import DataLoader - -from relik.common.log import get_logger -from relik.retriever.data.datasets import GoldenRetrieverDataset - -logger = get_logger() - - -class GoldenRetrieverPLDataModule(pl.LightningDataModule): - def __init__( - self, - train_dataset: Optional[GoldenRetrieverDataset] = None, - val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, - test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, - num_workers: Optional[Union[DictConfig, int]] = None, - datasets: Optional[DictConfig] = None, - *args, - **kwargs, - ): - super().__init__() - self.datasets = datasets - if num_workers is None: - num_workers = 0 - if isinstance(num_workers, int): - num_workers = DictConfig( - {"train": num_workers, "val": num_workers, "test": num_workers} - ) - self.num_workers = num_workers - # data - self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset - self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets - self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets - - def prepare_data(self, *args, **kwargs): - """ - Method for preparing the data before the training. This method is called only once. - It is used to download the data, tokenize the data, etc. 
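The search methods return one ranked list of `RetrievedSample` objects per query. A self-contained sketch of that output structure, using a stand-in dataclass that mirrors the fields defined in relik/retriever/pytorch_modules/__init__.py (the values are made up for illustration):

from dataclasses import dataclass

@dataclass
class RetrievedSample:  # stand-in with the same fields as the library's dataclass
    score: float
    index: int
    label: str

batch_retrieved_samples = [
    [
        RetrievedSample(score=0.91, index=12, label="passage twelve"),
        RetrievedSample(score=0.77, index=3, label="passage three"),
    ]
]
print(batch_retrieved_samples[0][0].label)  # "passage twelve"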
- """ - pass - - def setup(self, stage: Optional[str] = None): - if stage == "fit" or stage is None: - # usually there is only one dataset for train - # if you need more train loader, you can follow - # the same logic as val and test datasets - if self.train_dataset is None: - self.train_dataset = hydra.utils.instantiate(self.datasets.train) - self.val_datasets = [ - hydra.utils.instantiate(dataset_cfg) - for dataset_cfg in self.datasets.val - ] - if stage == "test": - if self.test_datasets is None: - self.test_datasets = [ - hydra.utils.instantiate(dataset_cfg) - for dataset_cfg in self.datasets.test - ] - - def train_dataloader(self, *args, **kwargs) -> DataLoader: - torch_dataset = self.train_dataset.to_torch_dataset() - return DataLoader( - # self.train_dataset.to_torch_dataset(), - torch_dataset, - shuffle=False, - batch_size=None, - num_workers=self.num_workers.train, - pin_memory=True, - collate_fn=lambda x: x, - ) - - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - dataloaders = [] - for dataset in self.val_datasets: - torch_dataset = dataset.to_torch_dataset() - dataloaders.append( - DataLoader( - torch_dataset, - shuffle=False, - batch_size=None, - num_workers=self.num_workers.val, - pin_memory=True, - collate_fn=lambda x: x, - ) - ) - return dataloaders - - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - dataloaders = [] - for dataset in self.test_datasets: - torch_dataset = dataset.to_torch_dataset() - dataloaders.append( - DataLoader( - torch_dataset, - shuffle=False, - batch_size=None, - num_workers=self.num_workers.test, - pin_memory=True, - collate_fn=lambda x: x, - ) - ) - return dataloaders - - def predict_dataloader(self) -> EVAL_DATALOADERS: - raise NotImplementedError - - def transfer_batch_to_device( - self, batch: Any, device: torch.device, dataloader_idx: int - ) -> Any: - return super().transfer_batch_to_device(batch, device, dataloader_idx) - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(" f"{self.datasets=}, " f"{self.num_workers=}, " - ) diff --git a/relik/retriever/lightning_modules/pl_modules.py b/relik/retriever/lightning_modules/pl_modules.py deleted file mode 100644 index 1ef6ace6eb8ac95d92e16b1e28dc50d2492b33f4..0000000000000000000000000000000000000000 --- a/relik/retriever/lightning_modules/pl_modules.py +++ /dev/null @@ -1,123 +0,0 @@ -from typing import Any, Union - -import hydra -import lightning as pl -import torch -from omegaconf import DictConfig - -from relik.retriever.common.model_inputs import ModelInputs - - -class GoldenRetrieverPLModule(pl.LightningModule): - def __init__( - self, - model: Union[torch.nn.Module, DictConfig], - optimizer: Union[torch.optim.Optimizer, DictConfig], - lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, DictConfig] = None, - *args, - **kwargs, - ) -> None: - super().__init__() - self.save_hyperparameters(ignore=["model"]) - if isinstance(model, DictConfig): - self.model = hydra.utils.instantiate(model) - else: - self.model = model - - self.optimizer_config = optimizer - self.lr_scheduler_config = lr_scheduler - - def forward(self, **kwargs) -> dict: - """ - Method for the forward pass. - 'training_step', 'validation_step' and 'test_step' should call - this method in order to compute the output predictions and the loss. - - Returns: - output_dict: forward output containing the predictions (output logits ecc...) and the loss if any. 
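The dataloaders above use `batch_size=None` with an identity collate function because the underlying datasets yield already-collated batches, so the DataLoader only iterates them. A self-contained sketch of that pattern; the dataset here is a hypothetical stand-in, not one of the library's dataset classes:

import torch
from torch.utils.data import DataLoader, IterableDataset

class PreBatchedDataset(IterableDataset):
    """Yields ready-made batches instead of single samples."""
    def __iter__(self):
        for _ in range(3):
            yield {"input_ids": torch.zeros(8, 16, dtype=torch.long)}

loader = DataLoader(PreBatchedDataset(), batch_size=None, collate_fn=lambda x: x)
for batch in loader:
    print(batch["input_ids"].shape)  # torch.Size([8, 16])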
- - """ - return self.model(**kwargs) - - def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor: - forward_output = self.forward(**batch, return_loss=True) - self.log( - "loss", - forward_output["loss"], - batch_size=batch["questions"]["input_ids"].size(0), - prog_bar=True, - ) - return forward_output["loss"] - - def validation_step(self, batch: ModelInputs, batch_idx: int) -> None: - forward_output = self.forward(**batch, return_loss=True) - self.log( - "val_loss", - forward_output["loss"], - batch_size=batch["questions"]["input_ids"].size(0), - ) - - def test_step(self, batch: ModelInputs, batch_idx: int) -> Any: - forward_output = self.forward(**batch, return_loss=True) - self.log( - "test_loss", - forward_output["loss"], - batch_size=batch["questions"]["input_ids"].size(0), - ) - - def configure_optimizers(self): - if isinstance(self.optimizer_config, DictConfig): - param_optimizer = list(self.named_parameters()) - no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [ - p for n, p in param_optimizer if "layer_norm_layer" in n - ], - "weight_decay": self.hparams.optimizer.weight_decay, - "lr": 1e-4, - }, - { - "params": [ - p - for n, p in param_optimizer - if all(nd not in n for nd in no_decay) - and "layer_norm_layer" not in n - ], - "weight_decay": self.hparams.optimizer.weight_decay, - }, - { - "params": [ - p - for n, p in param_optimizer - if "layer_norm_layer" not in n - and any(nd in n for nd in no_decay) - ], - "weight_decay": 0.0, - }, - ] - optimizer = hydra.utils.instantiate( - self.optimizer_config, - # params=self.parameters(), - params=optimizer_grouped_parameters, - _convert_="partial", - ) - else: - optimizer = self.optimizer_config - - if self.lr_scheduler_config is None: - return optimizer - - if isinstance(self.lr_scheduler_config, DictConfig): - lr_scheduler = hydra.utils.instantiate( - self.lr_scheduler_config, optimizer=optimizer - ) - else: - lr_scheduler = self.lr_scheduler_config - - lr_scheduler_config = { - "scheduler": lr_scheduler, - "interval": "step", - "frequency": 1, - } - return [optimizer], [lr_scheduler_config] diff --git a/relik/retriever/pytorch_modules/__init__.py b/relik/retriever/pytorch_modules/__init__.py deleted file mode 100644 index 01752b8aa79367a7bcdc2d18438ae87bebdd87f2..0000000000000000000000000000000000000000 --- a/relik/retriever/pytorch_modules/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -from dataclasses import dataclass - -import torch - -PRECISION_MAP = { - None: torch.float32, - 16: torch.float16, - 32: torch.float32, - "float16": torch.float16, - "float32": torch.float32, - "half": torch.float16, - "float": torch.float32, - "16": torch.float16, - "32": torch.float32, - "fp16": torch.float16, - "fp32": torch.float32, -} - - -@dataclass -class RetrievedSample: - """ - Dataclass for the output of the GoldenRetriever model. 
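The optimizer groups built in `configure_optimizers` above follow the common pattern of excluding biases and LayerNorm parameters from weight decay. A reduced standalone sketch; the model and hyperparameters are placeholders, not the values used for training:

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
# substrings matched against parameter names; in this toy Sequential the
# LayerNorm weight is named "1.weight", so only the biases actually match
no_decay = ("bias", "LayerNorm")

decay_params = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
no_decay_params = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=1e-5,
)
print(len(decay_params), len(no_decay_params))  # 2 2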
- """ - - score: float - index: int - label: str diff --git a/relik/retriever/pytorch_modules/hf.py b/relik/retriever/pytorch_modules/hf.py deleted file mode 100644 index b5868d67ce5ed97a1e66c6d1d3a606350e75267a..0000000000000000000000000000000000000000 --- a/relik/retriever/pytorch_modules/hf.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Tuple, Union - -import torch -from transformers import PretrainedConfig -from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions -from transformers.models.bert.modeling_bert import BertModel - - -class GoldenRetrieverConfig(PretrainedConfig): - model_type = "bert" - - def __init__( - self, - vocab_size=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - layer_norm_eps=1e-12, - pad_token_id=0, - position_embedding_type="absolute", - use_cache=True, - classifier_dropout=None, - **kwargs, - ): - super().__init__(pad_token_id=pad_token_id, **kwargs) - - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type - self.use_cache = use_cache - self.classifier_dropout = classifier_dropout - - -class GoldenRetrieverModel(BertModel): - config_class = GoldenRetrieverConfig - - def __init__(self, config, *args, **kwargs): - super().__init__(config) - self.layer_norm_layer = torch.nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - - def forward( - self, **kwargs - ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: - attention_mask = kwargs.get("attention_mask", None) - model_outputs = super().forward(**kwargs) - if attention_mask is None: - pooler_output = model_outputs.pooler_output - else: - token_embeddings = model_outputs.last_hidden_state - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - ) - pooler_output = torch.sum( - token_embeddings * input_mask_expanded, 1 - ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) - - pooler_output = self.layer_norm_layer(pooler_output) - - if not kwargs.get("return_dict", True): - return (model_outputs[0], pooler_output) + model_outputs[2:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=model_outputs.last_hidden_state, - pooler_output=pooler_output, - past_key_values=model_outputs.past_key_values, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, - cross_attentions=model_outputs.cross_attentions, - ) diff --git a/relik/retriever/pytorch_modules/loss.py b/relik/retriever/pytorch_modules/loss.py deleted file mode 100644 index 643d3a486ca73ca38486094553b357f5e7c28adb..0000000000000000000000000000000000000000 --- a/relik/retriever/pytorch_modules/loss.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Optional - -import torch -from torch.nn.modules.loss import _WeightedLoss - - -class 
MultiLabelNCELoss(_WeightedLoss): - __constants__ = ["reduction"] - - def __init__( - self, - weight: Optional[torch.Tensor] = None, - size_average=None, - reduction: Optional[str] = "mean", - ) -> None: - super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction) - - def forward( - self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100 - ) -> torch.Tensor: - gold_scores = input.masked_fill(~(target.bool()), 0) - gold_scores_sum = gold_scores.sum(-1) # B x C - neg_logits = input.masked_fill(target.bool(), float("-inf")) # B x C x L - neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True) # B x C x 1 - norm_term = ( - torch.logaddexp(input, neg_log_sum_exp) - .masked_fill(~(target.bool()), 0) - .sum(-1) - ) - gold_log_probs = gold_scores_sum - norm_term - loss = -gold_log_probs.sum() - if self.reduction == "mean": - loss /= input.size(0) - return loss diff --git a/relik/retriever/pytorch_modules/model.py b/relik/retriever/pytorch_modules/model.py deleted file mode 100644 index f02aedd5b43cbd789b87ae4afd417c919b72b129..0000000000000000000000000000000000000000 --- a/relik/retriever/pytorch_modules/model.py +++ /dev/null @@ -1,533 +0,0 @@ -import contextlib -import logging -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Dict, List, Optional, Union - -import torch -import torch.nn.functional as F -import transformers as tr - -from relik.common.log import get_console_logger, get_logger -from relik.retriever.common.model_inputs import ModelInputs -from relik.retriever.data.labels import Labels -from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.indexers.inmemory import InMemoryDocumentIndex -from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample -from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel - -console_logger = get_console_logger() -logger = get_logger(__name__, level=logging.INFO) - - -@dataclass -class GoldenRetrieverOutput(tr.file_utils.ModelOutput): - """Class for model's outputs.""" - - logits: Optional[torch.FloatTensor] = None - loss: Optional[torch.FloatTensor] = None - question_encodings: Optional[torch.FloatTensor] = None - passages_encodings: Optional[torch.FloatTensor] = None - - -class GoldenRetriever(torch.nn.Module): - def __init__( - self, - question_encoder: Union[str, tr.PreTrainedModel], - loss_type: Optional[torch.nn.Module] = None, - passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None, - document_index: Optional[Union[str, BaseDocumentIndex]] = None, - question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, - passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, - device: Optional[Union[str, torch.device]] = None, - precision: Optional[Union[str, int]] = None, - index_precision: Optional[Union[str, int]] = 32, - index_device: Optional[Union[str, torch.device]] = "cpu", - *args, - **kwargs, - ): - super().__init__() - - self.passage_encoder_is_question_encoder = False - # question encoder model - if isinstance(question_encoder, str): - question_encoder = GoldenRetrieverModel.from_pretrained( - question_encoder, **kwargs - ) - self.question_encoder = question_encoder - if passage_encoder is None: - # if no passage encoder is provided, - # share the weights of the question encoder - passage_encoder = question_encoder - # keep track of the fact that the passage encoder is the same as the question encoder - self.passage_encoder_is_question_encoder = True - if 
isinstance(passage_encoder, str): - passage_encoder = GoldenRetrieverModel.from_pretrained( - passage_encoder, **kwargs - ) - # passage encoder model - self.passage_encoder = passage_encoder - - # loss function - self.loss_type = loss_type - - # indexer stuff - if document_index is None: - # if no indexer is provided, create a new one - document_index = InMemoryDocumentIndex( - device=index_device, precision=index_precision, **kwargs - ) - if isinstance(document_index, str): - document_index = BaseDocumentIndex.from_pretrained( - document_index, device=index_device, precision=index_precision, **kwargs - ) - self.document_index = document_index - - # lazy load the tokenizer for inference - self._question_tokenizer = question_tokenizer - self._passage_tokenizer = passage_tokenizer - - # move the model to the device - self.to(device or torch.device("cpu")) - - # set the precision - self.precision = precision - - def forward( - self, - questions: Optional[Dict[str, torch.Tensor]] = None, - passages: Optional[Dict[str, torch.Tensor]] = None, - labels: Optional[torch.Tensor] = None, - question_encodings: Optional[torch.Tensor] = None, - passages_encodings: Optional[torch.Tensor] = None, - passages_per_question: Optional[List[int]] = None, - return_loss: bool = False, - return_encodings: bool = False, - *args, - **kwargs, - ) -> GoldenRetrieverOutput: - """ - Forward pass of the model. - - Args: - questions (`Dict[str, torch.Tensor]`): - The questions to encode. - passages (`Dict[str, torch.Tensor]`): - The passages to encode. - labels (`torch.Tensor`): - The labels of the sentences. - return_loss (`bool`): - Whether to compute the predictions. - question_encodings (`torch.Tensor`): - The encodings of the questions. - passages_encodings (`torch.Tensor`): - The encodings of the passages. - passages_per_question (`List[int]`): - The number of passages per question. - return_loss (`bool`): - Whether to compute the loss. - return_encodings (`bool`): - Whether to return the encodings. - - Returns: - obj:`torch.Tensor`: The outputs of the model. 
- """ - if questions is None and question_encodings is None: - raise ValueError( - "Either `questions` or `question_encodings` must be provided" - ) - if passages is None and passages_encodings is None: - raise ValueError( - "Either `passages` or `passages_encodings` must be provided" - ) - - if question_encodings is None: - question_encodings = self.question_encoder(**questions).pooler_output - if passages_encodings is None: - passages_encodings = self.passage_encoder(**passages).pooler_output - - if passages_per_question is not None: - # multiply each question encoding with a passages_per_question encodings - concatenated_passages = torch.stack( - torch.split(passages_encodings, passages_per_question) - ).transpose(1, 2) - if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): - # normalize the encodings for cosine similarity - concatenated_passages = F.normalize(concatenated_passages, p=2, dim=2) - question_encodings = F.normalize(question_encodings, p=2, dim=1) - logits = torch.bmm( - question_encodings.unsqueeze(1), concatenated_passages - ).view(question_encodings.shape[0], -1) - else: - if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): - # normalize the encodings for cosine similarity - question_encodings = F.normalize(question_encodings, p=2, dim=1) - passages_encodings = F.normalize(passages_encodings, p=2, dim=1) - - logits = torch.matmul(question_encodings, passages_encodings.T) - - output = dict(logits=logits) - - if return_loss and labels is not None: - if self.loss_type is None: - raise ValueError( - "If `return_loss` is set to `True`, `loss_type` must be provided" - ) - if isinstance(self.loss_type, torch.nn.NLLLoss): - labels = labels.argmax(dim=1) - logits = F.log_softmax(logits, dim=1) - if len(question_encodings.size()) > 1: - logits = logits.view(question_encodings.size(0), -1) - - output["loss"] = self.loss_type(logits, labels) - - if return_encodings: - output["question_encodings"] = question_encodings - output["passages_encodings"] = passages_encodings - - return GoldenRetrieverOutput(**output) - - @torch.no_grad() - @torch.inference_mode() - def index( - self, - batch_size: int = 32, - num_workers: int = 4, - max_length: Optional[int] = None, - collate_fn: Optional[Callable] = None, - force_reindex: bool = False, - compute_on_cpu: bool = False, - precision: Optional[Union[str, int]] = None, - ): - """ - Index the passages for later retrieval. - - Args: - batch_size (`int`): - The batch size to use for the indexing. - num_workers (`int`): - The number of workers to use for the indexing. - max_length (`Optional[int]`): - The maximum length of the passages. - collate_fn (`Callable`): - The collate function to use for the indexing. - force_reindex (`bool`): - Whether to force reindexing even if the passages are already indexed. - compute_on_cpu (`bool`): - Whether to move the index to the CPU after the indexing. - precision (`Optional[Union[str, int]]`): - The precision to use for the model. - """ - if self.document_index is None: - raise ValueError( - "The retriever must be initialized with an indexer to index " - "the passages within the retriever." 
- ) - return self.document_index.index( - retriever=self, - batch_size=batch_size, - num_workers=num_workers, - max_length=max_length, - collate_fn=collate_fn, - encoder_precision=precision or self.precision, - compute_on_cpu=compute_on_cpu, - force_reindex=force_reindex, - ) - - @torch.no_grad() - @torch.inference_mode() - def retrieve( - self, - text: Optional[Union[str, List[str]]] = None, - text_pair: Optional[Union[str, List[str]]] = None, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - k: Optional[int] = None, - max_length: Optional[int] = None, - precision: Optional[Union[str, int]] = None, - ) -> List[List[RetrievedSample]]: - """ - Retrieve the passages for the questions. - - Args: - text (`Optional[Union[str, List[str]]]`): - The questions to retrieve the passages for. - text_pair (`Optional[Union[str, List[str]]]`): - The questions to retrieve the passages for. - input_ids (`torch.Tensor`): - The input ids of the questions. - attention_mask (`torch.Tensor`): - The attention mask of the questions. - token_type_ids (`torch.Tensor`): - The token type ids of the questions. - k (`int`): - The number of top passages to retrieve. - max_length (`Optional[int]`): - The maximum length of the questions. - precision (`Optional[Union[str, int]]`): - The precision to use for the model. - - Returns: - `List[List[RetrievedSample]]`: The retrieved passages and their indices. - """ - if self.document_index is None: - raise ValueError( - "The indexer must be indexed before it can be used within the retriever." - ) - if text is None and input_ids is None: - raise ValueError( - "Either `text` or `input_ids` must be provided to retrieve the passages." - ) - - if text is not None: - if isinstance(text, str): - text = [text] - if text_pair is not None and isinstance(text_pair, str): - text_pair = [text_pair] - tokenizer = self.question_tokenizer - model_inputs = ModelInputs( - tokenizer( - text, - text_pair=text_pair, - padding=True, - return_tensors="pt", - truncation=True, - max_length=max_length or tokenizer.model_max_length, - ) - ) - else: - model_inputs = ModelInputs(dict(input_ids=input_ids)) - if attention_mask is not None: - model_inputs["attention_mask"] = attention_mask - if token_type_ids is not None: - model_inputs["token_type_ids"] = token_type_ids - - model_inputs.to(self.device) - - # fucking autocast only wants pure strings like 'cpu' or 'cuda' - # we need to convert the model device to that - device_type_for_autocast = str(self.device).split(":")[0] - # autocast doesn't work with CPU and stuff different from bfloat16 - autocast_pssg_mngr = ( - contextlib.nullcontext() - if device_type_for_autocast == "cpu" - else ( - torch.autocast( - device_type=device_type_for_autocast, - dtype=PRECISION_MAP[precision], - ) - ) - ) - with autocast_pssg_mngr: - question_encodings = self.question_encoder(**model_inputs).pooler_output - - # TODO: fix if encoder and index are on different device - return self.document_index.search(question_encodings, k) - - def get_index_from_passage(self, passage: str) -> int: - """ - Get the index of the passage. - - Args: - passage (`str`): - The passage to get the index for. - - Returns: - `int`: The index of the passage. - """ - if self.document_index is None: - raise ValueError( - "The passages must be indexed before they can be retrieved." 
- ) - return self.document_index.get_index_from_passage(passage) - - def get_passage_from_index(self, index: int) -> str: - """ - Get the passage from the index. - - Args: - index (`int`): - The index of the passage. - - Returns: - `str`: The passage. - """ - if self.document_index is None: - raise ValueError( - "The passages must be indexed before they can be retrieved." - ) - return self.document_index.get_passage_from_index(index) - - def get_vector_from_index(self, index: int) -> torch.Tensor: - """ - Get the passage vector from the index. - - Args: - index (`int`): - The index of the passage. - - Returns: - `torch.Tensor`: The passage vector. - """ - if self.document_index is None: - raise ValueError( - "The passages must be indexed before they can be retrieved." - ) - return self.document_index.get_embeddings_from_index(index) - - def get_vector_from_passage(self, passage: str) -> torch.Tensor: - """ - Get the passage vector from the passage. - - Args: - passage (`str`): - The passage. - - Returns: - `torch.Tensor`: The passage vector. - """ - if self.document_index is None: - raise ValueError( - "The passages must be indexed before they can be retrieved." - ) - return self.document_index.get_embeddings_from_passage(passage) - - @property - def passage_embeddings(self) -> torch.Tensor: - """ - The passage embeddings. - """ - return self._passage_embeddings - - @property - def passage_index(self) -> Labels: - """ - The passage index. - """ - return self._passage_index - - @property - def device(self) -> torch.device: - """ - The device of the model. - """ - return next(self.parameters()).device - - @property - def question_tokenizer(self) -> tr.PreTrainedTokenizer: - """ - The question tokenizer. - """ - if self._question_tokenizer: - return self._question_tokenizer - - if ( - self.question_encoder.config.name_or_path - == self.passage_encoder.config.name_or_path - ): - if not self._question_tokenizer: - self._question_tokenizer = tr.AutoTokenizer.from_pretrained( - self.question_encoder.config.name_or_path - ) - self._passage_tokenizer = self._question_tokenizer - return self._question_tokenizer - - if not self._question_tokenizer: - self._question_tokenizer = tr.AutoTokenizer.from_pretrained( - self.question_encoder.config.name_or_path - ) - return self._question_tokenizer - - @property - def passage_tokenizer(self) -> tr.PreTrainedTokenizer: - """ - The passage tokenizer. - """ - if self._passage_tokenizer: - return self._passage_tokenizer - - if ( - self.question_encoder.config.name_or_path - == self.passage_encoder.config.name_or_path - ): - if not self._question_tokenizer: - self._question_tokenizer = tr.AutoTokenizer.from_pretrained( - self.question_encoder.config.name_or_path - ) - self._passage_tokenizer = self._question_tokenizer - return self._passage_tokenizer - - if not self._passage_tokenizer: - self._passage_tokenizer = tr.AutoTokenizer.from_pretrained( - self.passage_encoder.config.name_or_path - ) - return self._passage_tokenizer - - def save_pretrained( - self, - output_dir: Union[str, os.PathLike], - question_encoder_name: Optional[str] = None, - passage_encoder_name: Optional[str] = None, - document_index_name: Optional[str] = None, - push_to_hub: bool = False, - **kwargs, - ): - """ - Save the retriever to a directory. - - Args: - output_dir (`str`): - The directory to save the retriever to. - question_encoder_name (`Optional[str]`): - The name of the question encoder. - passage_encoder_name (`Optional[str]`): - The name of the passage encoder.
- document_index_name (`Optional[str]`): - The name of the document index. - push_to_hub (`bool`): - Whether to push the model to the hub. - """ - - # create the output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Saving retriever to {output_dir}") - - question_encoder_name = question_encoder_name or "question_encoder" - passage_encoder_name = passage_encoder_name or "passage_encoder" - document_index_name = document_index_name or "document_index" - - logger.info( - f"Saving question encoder state to {output_dir / question_encoder_name}" - ) - # self.question_encoder.config._name_or_path = question_encoder_name - self.question_encoder.register_for_auto_class() - self.question_encoder.save_pretrained( - output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs - ) - self.question_tokenizer.save_pretrained( - output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs - ) - if not self.passage_encoder_is_question_encoder: - logger.info( - f"Saving passage encoder state to {output_dir / passage_encoder_name}" - ) - # self.passage_encoder.config._name_or_path = passage_encoder_name - self.passage_encoder.register_for_auto_class() - self.passage_encoder.save_pretrained( - output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs - ) - self.passage_tokenizer.save_pretrained( - output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs - ) - - if self.document_index is not None: - # save the indexer - self.document_index.save_pretrained( - output_dir / document_index_name, push_to_hub=push_to_hub, **kwargs - ) - - logger.info("Saving retriever to disk done.") diff --git a/relik/retriever/pytorch_modules/optim.py b/relik/retriever/pytorch_modules/optim.py deleted file mode 100644 index e815acf44948a4835d7b093c74e6ca8cf8539dd1..0000000000000000000000000000000000000000 --- a/relik/retriever/pytorch_modules/optim.py +++ /dev/null @@ -1,213 +0,0 @@ -import math - -import torch -from torch.optim import Optimizer - - -class RAdamW(Optimizer): - r"""Implements RAdamW algorithm. 
- - RAdam from `On the Variance of the Adaptive Learning Rate and Beyond`. - - * `Adam: A Method for Stochastic Optimization` - * `Decoupled Weight Decay Regularization` - * `On the Convergence of Adam and Beyond` - * `On the Variance of the Adaptive Learning Rate and Beyond` - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - """ - - def __init__( - self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super(RAdamW, self).__init__(params, defaults) - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError( - "Adam does not support sparse gradients, please consider SparseAdam instead" - ) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p.data) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - beta1, beta2 = group["betas"] - eps = group["eps"] - lr = group["lr"] - if "rho_inf" not in group: - group["rho_inf"] = 2 / (1 - beta2) - 1 - rho_inf = group["rho_inf"] - - state["step"] += 1 - t = state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - rho_t = rho_inf - ((2 * t * (beta2**t)) / (1 - beta2**t)) - - # Perform stepweight decay - p.data.mul_(1 - lr * group["weight_decay"]) - - if rho_t >= 5: - var = exp_avg_sq.sqrt().add_(eps) - r = math.sqrt( - (1 - beta2**t) - * ((rho_t - 4) * (rho_t - 2) * rho_inf) - / ((rho_inf - 4) * (rho_inf - 2) * rho_t) - ) - - p.data.addcdiv_(exp_avg, var, value=-lr * r / (1 - beta1**t)) - else: - p.data.add_(exp_avg, alpha=-lr / (1 - beta1**t)) - - return loss - - -# from typing import List -# import collections - -# import torch -# import transformers -# from classy.optim.factories import Factory -# from transformers import AdamW - - -# class ElectraOptimizer(Factory): -# def __init__( -# self, -# lr: float, -# warmup_steps: int, -# total_steps: int, -# weight_decay: float, -# lr_decay: float, -# no_decay_params: List[str], -# ): -# self.lr = lr -# self.warmup_steps = warmup_steps -#
self.total_steps = total_steps -# self.weight_decay = weight_decay -# self.lr_decay = lr_decay -# self.no_decay_params = no_decay_params - -# def group_layers(self, module) -> dict: -# grouped_layers = collections.defaultdict(list) -# module_named_parameters = list(module.named_parameters()) -# for ln, lp in module_named_parameters: -# if "embeddings" in ln: -# grouped_layers["embeddings"].append((ln, lp)) -# elif "encoder.layer" in ln: -# layer_num = ln.replace("transformer_model.encoder.layer.", "") -# layer_num = layer_num[0 : layer_num.index(".")] -# grouped_layers[layer_num].append((ln, lp)) -# else: -# grouped_layers["head"].append((ln, lp)) - -# depth = len(grouped_layers) - 1 -# final_dict = dict() -# for key, value in grouped_layers.items(): -# if key == "head": -# final_dict[0] = value -# elif key == "embeddings": -# final_dict[depth] = value -# else: -# # -1 because layer number starts from zero -# final_dict[depth - int(key) - 1] = value - -# assert len(module_named_parameters) == sum( -# len(v) for _, v in final_dict.items() -# ) - -# return final_dict - -# def group_params(self, module) -> list: -# optimizer_grouped_params = [] -# for inverse_depth, layer in self.group_layers(module).items(): -# layer_lr = self.lr * (self.lr_decay**inverse_depth) -# layer_wd_params = { -# "params": [ -# lp -# for ln, lp in layer -# if not any(nd in ln for nd in self.no_decay_params) -# ], -# "weight_decay": self.weight_decay, -# "lr": layer_lr, -# } -# layer_no_wd_params = { -# "params": [ -# lp -# for ln, lp in layer -# if any(nd in ln for nd in self.no_decay_params) -# ], -# "weight_decay": 0, -# "lr": layer_lr, -# } - -# if len(layer_wd_params) != 0: -# optimizer_grouped_params.append(layer_wd_params) -# if len(layer_no_wd_params) != 0: -# optimizer_grouped_params.append(layer_no_wd_params) - -# return optimizer_grouped_params - -# def __call__(self, module: torch.nn.Module): -# optimizer_grouped_parameters = self.group_params(module) -# optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) -# scheduler = transformers.get_linear_schedule_with_warmup( -# optimizer, self.warmup_steps, self.total_steps -# ) -# return { -# "optimizer": optimizer, -# "lr_scheduler": { -# "scheduler": scheduler, -# "interval": "step", -# "frequency": 1, -# }, -# } diff --git a/relik/retriever/pytorch_modules/scheduler.py b/relik/retriever/pytorch_modules/scheduler.py deleted file mode 100644 index 5edd2433612b27c1a91fe189e57c4e2d41c462b2..0000000000000000000000000000000000000000 --- a/relik/retriever/pytorch_modules/scheduler.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -from torch.optim.lr_scheduler import LRScheduler - - -class LinearSchedulerWithWarmup(LRScheduler): - def __init__( - self, - optimizer: torch.optim.Optimizer, - num_warmup_steps: int, - num_training_steps: int, - last_epoch: int = -1, - verbose: bool = False, - **kwargs, - ): - self.num_warmup_steps = num_warmup_steps - self.num_training_steps = num_training_steps - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - def scheduler_fn(current_step): - if current_step < self.num_warmup_steps: - return current_step / max(1, self.num_warmup_steps) - return max( - 0.0, - float(self.num_training_steps - current_step) - / float(max(1, self.num_training_steps - self.num_warmup_steps)), - ) - - return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] - - -class LinearScheduler(LRScheduler): - def __init__( - self, - optimizer: torch.optim.Optimizer, - num_training_steps: int, - last_epoch: int = -1, - 
verbose: bool = False, - **kwargs, - ): - self.num_training_steps = num_training_steps - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - def scheduler_fn(current_step): - # if current_step < self.num_warmup_steps: - # return current_step / max(1, self.num_warmup_steps) - return max( - 0.0, - float(self.num_training_steps - current_step) - / float(max(1, self.num_training_steps)), - ) - - return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] diff --git a/relik/retriever/trainer/__init__.py b/relik/retriever/trainer/__init__.py deleted file mode 100644 index f1b18bf79091418217ae2bb782c3796dfa8b5b56..0000000000000000000000000000000000000000 --- a/relik/retriever/trainer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from relik.retriever.trainer.train import RetrieverTrainer diff --git a/relik/retriever/trainer/train.py b/relik/retriever/trainer/train.py deleted file mode 100644 index dc55298c178e7f519cf2f724e9a0881a168c9158..0000000000000000000000000000000000000000 --- a/relik/retriever/trainer/train.py +++ /dev/null @@ -1,667 +0,0 @@ -import os -from pathlib import Path -from typing import List, Optional, Union - -import hydra -import lightning as pl -import omegaconf -import torch -from lightning import Trainer -from lightning.pytorch.callbacks import ( - EarlyStopping, - LearningRateMonitor, - ModelCheckpoint, - ModelSummary, -) -from lightning.pytorch.loggers import WandbLogger -from omegaconf import OmegaConf -from rich.pretty import pprint - -from relik.common.log import get_console_logger -from relik.retriever.callbacks.evaluation_callbacks import ( - AvgRankingEvaluationCallback, - RecallAtKEvaluationCallback, -) -from relik.retriever.callbacks.prediction_callbacks import ( - GoldenRetrieverPredictionCallback, - NegativeAugmentationCallback, -) -from relik.retriever.callbacks.utils_callbacks import ( - FreeUpIndexerVRAMCallback, - SavePredictionsCallback, - SaveRetrieverCallback, -) -from relik.retriever.data.datasets import GoldenRetrieverDataset -from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.lightning_modules.pl_data_modules import ( - GoldenRetrieverPLDataModule, -) -from relik.retriever.lightning_modules.pl_modules import GoldenRetrieverPLModule -from relik.retriever.pytorch_modules.loss import MultiLabelNCELoss -from relik.retriever.pytorch_modules.model import GoldenRetriever -from relik.retriever.pytorch_modules.optim import RAdamW -from relik.retriever.pytorch_modules.scheduler import ( - LinearScheduler, - LinearSchedulerWithWarmup, -) - -logger = get_console_logger() - - -class RetrieverTrainer: - def __init__( - self, - retriever: GoldenRetriever, - train_dataset: GoldenRetrieverDataset, - val_dataset: Union[GoldenRetrieverDataset, list[GoldenRetrieverDataset]], - test_dataset: Optional[ - Union[GoldenRetrieverDataset, list[GoldenRetrieverDataset]] - ] = None, - num_workers: int = 4, - optimizer: torch.optim.Optimizer = RAdamW, - lr: float = 1e-5, - weight_decay: float = 0.01, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler = LinearScheduler, - num_warmup_steps: int = 0, - loss: torch.nn.Module = MultiLabelNCELoss, - callbacks: Optional[list] = None, - accelerator: str = "auto", - devices: int = 1, - num_nodes: int = 1, - strategy: str = "auto", - accumulate_grad_batches: int = 1, - gradient_clip_val: float = 1.0, - val_check_interval: float = 1.0, - check_val_every_n_epoch: int = 1, - max_steps: Optional[int] = None, - max_epochs: Optional[int] = None, - # checkpoint_path: 
Optional[Union[str, os.PathLike]] = None, - deterministic: bool = True, - fast_dev_run: bool = False, - precision: int = 16, - reload_dataloaders_every_n_epochs: int = 1, - top_ks: Union[int, List[int]] = 100, - # early stopping parameters - early_stopping: bool = True, - early_stopping_patience: int = 10, - # wandb logger parameters - log_to_wandb: bool = True, - wandb_entity: Optional[str] = None, - wandb_experiment_name: Optional[str] = None, - wandb_project_name: Optional[str] = None, - wandb_save_dir: Optional[Union[str, os.PathLike]] = None, - wandb_log_model: bool = True, - wandb_offline_mode: bool = False, - wandb_watch: str = "all", - # checkpoint parameters - model_checkpointing: bool = True, - chekpoint_dir: Optional[Union[str, os.PathLike]] = None, - checkpoint_filename: Optional[Union[str, os.PathLike]] = None, - save_top_k: int = 1, - save_last: bool = False, - # prediction callback parameters - prediction_batch_size: int = 128, - # hard negatives callback parameters - max_hard_negatives_to_mine: int = 15, - hard_negatives_threshold: float = 0.0, - metrics_to_monitor_for_hard_negatives: Optional[str] = None, - mine_hard_negatives_with_probability: float = 1.0, - # other parameters - seed: int = 42, - float32_matmul_precision: str = "medium", - **kwargs, - ): - # put all the parameters in the class - self.retriever = retriever - # datasets - self.train_dataset = train_dataset - self.val_dataset = val_dataset - self.test_dataset = test_dataset - self.num_workers = num_workers - # trainer parameters - self.optimizer = optimizer - self.lr = lr - self.weight_decay = weight_decay - self.lr_scheduler = lr_scheduler - self.num_warmup_steps = num_warmup_steps - self.loss = loss - self.callbacks = callbacks - self.accelerator = accelerator - self.devices = devices - self.num_nodes = num_nodes - self.strategy = strategy - self.accumulate_grad_batches = accumulate_grad_batches - self.gradient_clip_val = gradient_clip_val - self.val_check_interval = val_check_interval - self.check_val_every_n_epoch = check_val_every_n_epoch - self.max_steps = max_steps - self.max_epochs = max_epochs - # self.checkpoint_path = checkpoint_path - self.deterministic = deterministic - self.fast_dev_run = fast_dev_run - self.precision = precision - self.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs - self.top_ks = top_ks - # early stopping parameters - self.early_stopping = early_stopping - self.early_stopping_patience = early_stopping_patience - # wandb logger parameters - self.log_to_wandb = log_to_wandb - self.wandb_entity = wandb_entity - self.wandb_experiment_name = wandb_experiment_name - self.wandb_project_name = wandb_project_name - self.wandb_save_dir = wandb_save_dir - self.wandb_log_model = wandb_log_model - self.wandb_offline_mode = wandb_offline_mode - self.wandb_watch = wandb_watch - # checkpoint parameters - self.model_checkpointing = model_checkpointing - self.chekpoint_dir = chekpoint_dir - self.checkpoint_filename = checkpoint_filename - self.save_top_k = save_top_k - self.save_last = save_last - # prediction callback parameters - self.prediction_batch_size = prediction_batch_size - # hard negatives callback parameters - self.max_hard_negatives_to_mine = max_hard_negatives_to_mine - self.hard_negatives_threshold = hard_negatives_threshold - self.metrics_to_monitor_for_hard_negatives = ( - metrics_to_monitor_for_hard_negatives - ) - self.mine_hard_negatives_with_probability = mine_hard_negatives_with_probability - # other parameters - self.seed = seed - 
self.float32_matmul_precision = float32_matmul_precision - - if self.max_epochs is None and self.max_steps is None: - raise ValueError( - "Either `max_epochs` or `max_steps` should be specified in the trainer configuration" - ) - - if self.max_epochs is not None and self.max_steps is not None: - logger.log( - "Both `max_epochs` and `max_steps` are specified in the trainer configuration. " - "Will use `max_epochs` for the number of training steps" - ) - self.max_steps = None - - # reproducibility - pl.seed_everything(self.seed) - # set the precision of matmul operations - torch.set_float32_matmul_precision(self.float32_matmul_precision) - - # lightning data module declaration - self.lightining_datamodule = self.configure_lightning_datamodule() - - if self.max_epochs is not None: - logger.log(f"Number of training epochs: {self.max_epochs}") - self.max_steps = ( - len(self.lightining_datamodule.train_dataloader()) * self.max_epochs - ) - - # optimizer declaration - self.optimizer, self.lr_scheduler = self.configure_optimizers() - - # lightning module declaration - self.lightining_module = self.configure_lightning_module() - - # callbacks declaration - self.callbacks_store: List[pl.Callback] = self.configure_callbacks() - - logger.log("Instantiating the Trainer") - self.trainer = pl.Trainer( - accelerator=self.accelerator, - devices=self.devices, - num_nodes=self.num_nodes, - strategy=self.strategy, - accumulate_grad_batches=self.accumulate_grad_batches, - max_epochs=self.max_epochs, - max_steps=self.max_steps, - gradient_clip_val=self.gradient_clip_val, - val_check_interval=self.val_check_interval, - check_val_every_n_epoch=self.check_val_every_n_epoch, - deterministic=self.deterministic, - fast_dev_run=self.fast_dev_run, - precision=self.precision, - reload_dataloaders_every_n_epochs=self.reload_dataloaders_every_n_epochs, - callbacks=self.callbacks_store, - logger=self.wandb_logger, - ) - - def configure_lightning_datamodule(self, *args, **kwargs): - # lightning data module declaration - if isinstance(self.val_dataset, GoldenRetrieverDataset): - self.val_dataset = [self.val_dataset] - if self.test_dataset is not None and isinstance( - self.test_dataset, GoldenRetrieverDataset - ): - self.test_dataset = [self.test_dataset] - - self.lightining_datamodule = GoldenRetrieverPLDataModule( - train_dataset=self.train_dataset, - val_datasets=self.val_dataset, - test_datasets=self.test_dataset, - num_workers=self.num_workers, - *args, - **kwargs, - ) - return self.lightining_datamodule - - def configure_lightning_module(self, *args, **kwargs): - # add loss object to the retriever - if self.retriever.loss_type is None: - self.retriever.loss_type = self.loss() - - # lightning module declaration - self.lightining_module = GoldenRetrieverPLModule( - model=self.retriever, - optimizer=self.optimizer, - lr_scheduler=self.lr_scheduler, - *args, - **kwargs, - ) - - return self.lightining_module - - def configure_optimizers(self, *args, **kwargs): - # check if it is the class or the instance - if isinstance(self.optimizer, type): - self.optimizer = self.optimizer( - params=self.retriever.parameters(), - lr=self.lr, - weight_decay=self.weight_decay, - ) - else: - self.optimizer = self.optimizer - - # LR Scheduler declaration - # check if it is the class, the instance or a function - if self.lr_scheduler is not None: - if isinstance(self.lr_scheduler, type): - self.lr_scheduler = self.lr_scheduler( - optimizer=self.optimizer, - num_warmup_steps=self.num_warmup_steps, - num_training_steps=self.max_steps, - ) - - 
return self.optimizer, self.lr_scheduler - - def configure_callbacks(self, *args, **kwargs): - # callbacks declaration - self.callbacks_store = self.callbacks or [] - self.callbacks_store.append(ModelSummary(max_depth=2)) - - # metric to monitor - if isinstance(self.top_ks, int): - self.top_ks = [self.top_ks] - # order the top_ks in descending order - self.top_ks = sorted(self.top_ks, reverse=True) - # get the max top_k to monitor - self.top_k = self.top_ks[0] - self.metric_to_monitor = f"validate_recall@{self.top_k}" - self.monitor_mode = "max" - - # early stopping callback if specified - self.early_stopping_callback: Optional[EarlyStopping] = None - if self.early_stopping: - logger.log( - f"Eanbling Early Stopping, patience: {self.early_stopping_patience}" - ) - self.early_stopping_callback = EarlyStopping( - monitor=self.metric_to_monitor, - mode=self.monitor_mode, - patience=self.early_stopping_patience, - ) - self.callbacks_store.append(self.early_stopping_callback) - - # wandb logger if specified - self.wandb_logger: Optional[WandbLogger] = None - self.experiment_path: Optional[Path] = None - if self.log_to_wandb: - # define some default values for the wandb logger - if self.wandb_project_name is None: - self.wandb_project_name = "relik-retriever" - if self.wandb_save_dir is None: - self.wandb_save_dir = "./" - logger.log("Instantiating Wandb Logger") - self.wandb_logger = WandbLogger( - entity=self.wandb_entity, - project=self.wandb_project_name, - name=self.wandb_experiment_name, - save_dir=self.wandb_save_dir, - log_model=self.wandb_log_model, - mode="offline" if self.wandb_offline_mode else "online", - ) - self.wandb_logger.watch(self.lightining_module, log=self.wandb_watch) - self.experiment_path = Path(self.wandb_logger.experiment.dir) - # Store the YaML config separately into the wandb dir - # yaml_conf: str = OmegaConf.to_yaml(cfg=conf) - # (experiment_path / "hparams.yaml").write_text(yaml_conf) - # Add a Learning Rate Monitor callback to log the learning rate - self.callbacks_store.append(LearningRateMonitor(logging_interval="step")) - - # model checkpoint callback if specified - self.model_checkpoint_callback: Optional[ModelCheckpoint] = None - if self.model_checkpointing: - logger.log("Enabling Model Checkpointing") - if self.chekpoint_dir is None: - self.chekpoint_dir = ( - self.experiment_path / "checkpoints" - if self.experiment_path - else None - ) - if self.checkpoint_filename is None: - self.checkpoint_filename = ( - "checkpoint-validate_recall@" - + str(self.top_k) - + "_{validate_recall@" - + str(self.top_k) - + ":.4f}-epoch_{epoch:02d}" - ) - self.model_checkpoint_callback = ModelCheckpoint( - monitor=self.metric_to_monitor, - mode=self.monitor_mode, - verbose=True, - save_top_k=self.save_top_k, - save_last=self.save_last, - filename=self.checkpoint_filename, - dirpath=self.chekpoint_dir, - auto_insert_metric_name=False, - ) - self.callbacks_store.append(self.model_checkpoint_callback) - - # prediction callback - self.other_callbacks_for_prediction = [ - RecallAtKEvaluationCallback(k) for k in self.top_ks - ] - self.other_callbacks_for_prediction += [ - AvgRankingEvaluationCallback(k=self.top_k, verbose=True, prefix="train"), - SavePredictionsCallback(), - ] - self.prediction_callback = GoldenRetrieverPredictionCallback( - k=self.top_k, - batch_size=self.prediction_batch_size, - precision=self.precision, - other_callbacks=self.other_callbacks_for_prediction, - ) - self.callbacks_store.append(self.prediction_callback) - - # hard negative mining callback - 
self.hard_negatives_callback: Optional[NegativeAugmentationCallback] = None - if self.max_hard_negatives_to_mine > 0: - self.metrics_to_monitor = ( - self.metrics_to_monitor_for_hard_negatives - or f"validate_recall@{self.top_k}" - ) - self.hard_negatives_callback = NegativeAugmentationCallback( - k=self.top_k, - batch_size=self.prediction_batch_size, - precision=self.precision, - stages=["validate"], - metrics_to_monitor=self.metrics_to_monitor, - threshold=self.hard_negatives_threshold, - max_negatives=self.max_hard_negatives_to_mine, - add_with_probability=self.mine_hard_negatives_with_probability, - refresh_every_n_epochs=1, - other_callbacks=[ - AvgRankingEvaluationCallback( - k=self.top_k, verbose=True, prefix="train" - ) - ], - ) - self.callbacks_store.append(self.hard_negatives_callback) - - # utils callback - self.callbacks_store.extend( - [SaveRetrieverCallback(), FreeUpIndexerVRAMCallback()] - ) - return self.callbacks_store - - def train(self): - self.trainer.fit(self.lightining_module, datamodule=self.lightining_datamodule) - - def test( - self, - lightining_module: Optional[GoldenRetrieverPLModule] = None, - checkpoint_path: Optional[Union[str, os.PathLike]] = None, - lightining_datamodule: Optional[GoldenRetrieverPLDataModule] = None, - ): - if lightining_module is not None: - self.lightining_module = lightining_module - else: - if self.fast_dev_run: - best_lightining_module = self.lightining_module - else: - # load best model for testing - if checkpoint_path is not None: - best_model_path = checkpoint_path - elif self.checkpoint_path: - best_model_path = self.checkpoint_path - elif self.model_checkpoint_callback: - best_model_path = self.model_checkpoint_callback.best_model_path - else: - raise ValueError( - "Either `checkpoint_path` or `model_checkpoint_callback` should " - "be provided to the trainer" - ) - logger.log(f"Loading best model from {best_model_path}") - - try: - best_lightining_module = ( - GoldenRetrieverPLModule.load_from_checkpoint(best_model_path) - ) - except Exception as e: - logger.log(f"Failed to load the model from checkpoint: {e}") - logger.log("Using last model instead") - best_lightining_module = self.lightining_module - - lightining_datamodule = lightining_datamodule or self.lightining_datamodule - # module test - self.trainer.test(best_lightining_module, datamodule=lightining_datamodule) - - -def train(conf: omegaconf.DictConfig) -> None: - # reproducibility - pl.seed_everything(conf.train.seed) - torch.set_float32_matmul_precision(conf.train.float32_matmul_precision) - - logger.log(f"Starting training for [bold cyan]{conf.model_name}[/bold cyan] model") - if conf.train.pl_trainer.fast_dev_run: - logger.log( - f"Debug mode {conf.train.pl_trainer.fast_dev_run}. 
Forcing debugger configuration" - ) - # Debuggers don't like GPUs nor multiprocessing - # conf.train.pl_trainer.accelerator = "cpu" - conf.train.pl_trainer.devices = 1 - conf.train.pl_trainer.strategy = "auto" - conf.train.pl_trainer.precision = 32 - if "num_workers" in conf.data.datamodule: - conf.data.datamodule.num_workers = { - k: 0 for k in conf.data.datamodule.num_workers - } - # Switch wandb to offline mode to prevent online logging - conf.logging.log = None - # remove model checkpoint callback - conf.train.model_checkpoint_callback = None - - if "print_config" in conf and conf.print_config: - pprint(OmegaConf.to_container(conf), console=logger, expand_all=True) - - # data module declaration - logger.log("Instantiating the Data Module") - pl_data_module: GoldenRetrieverPLDataModule = hydra.utils.instantiate( - conf.data.datamodule, _recursive_=False - ) - # force setup to get labels initialized for the model - pl_data_module.prepare_data() - # main module declaration - pl_module: Optional[GoldenRetrieverPLModule] = None - - if not conf.train.only_test: - pl_data_module.setup("fit") - - # count the number of training steps - if ( - "max_epochs" in conf.train.pl_trainer - and conf.train.pl_trainer.max_epochs > 0 - ): - num_training_steps = ( - len(pl_data_module.train_dataloader()) - * conf.train.pl_trainer.max_epochs - ) - if "max_steps" in conf.train.pl_trainer: - logger.log( - "Both `max_epochs` and `max_steps` are specified in the trainer configuration. " - "Will use `max_epochs` for the number of training steps" - ) - conf.train.pl_trainer.max_steps = None - elif ( - "max_steps" in conf.train.pl_trainer and conf.train.pl_trainer.max_steps > 0 - ): - num_training_steps = conf.train.pl_trainer.max_steps - conf.train.pl_trainer.max_epochs = None - else: - raise ValueError( - "Either `max_epochs` or `max_steps` should be specified in the trainer configuration" - ) - logger.log(f"Expected number of training steps: {num_training_steps}") - - if "lr_scheduler" in conf.model.pl_module and conf.model.pl_module.lr_scheduler: - # set the number of warmup steps as x% of the total number of training steps - if conf.model.pl_module.lr_scheduler.num_warmup_steps is None: - if ( - "warmup_steps_ratio" in conf.model.pl_module - and conf.model.pl_module.warmup_steps_ratio is not None - ): - conf.model.pl_module.lr_scheduler.num_warmup_steps = int( - conf.model.pl_module.lr_scheduler.num_training_steps - * conf.model.pl_module.warmup_steps_ratio - ) - else: - conf.model.pl_module.lr_scheduler.num_warmup_steps = 0 - logger.log( - f"Number of warmup steps: {conf.model.pl_module.lr_scheduler.num_warmup_steps}" - ) - - logger.log("Instantiating the Model") - pl_module: GoldenRetrieverPLModule = hydra.utils.instantiate( - conf.model.pl_module, _recursive_=False - ) - if ( - "pretrain_ckpt_path" in conf.train - and conf.train.pretrain_ckpt_path is not None - ): - logger.log( - f"Loading pretrained checkpoint from {conf.train.pretrain_ckpt_path}" - ) - pl_module.load_state_dict( - torch.load(conf.train.pretrain_ckpt_path)["state_dict"], strict=False - ) - - if "compile" in conf.model.pl_module and conf.model.pl_module.compile: - try: - pl_module = torch.compile(pl_module, backend="inductor") - except Exception: - logger.log( - "Failed to compile the model, you may need to install PyTorch 2.0" - ) - - # callbacks declaration - callbacks_store = [ModelSummary(max_depth=2)] - - experiment_logger: Optional[WandbLogger] = None - experiment_path: Optional[Path] = None - if conf.logging.log: - 
logger.log("Instantiating Wandb Logger") - experiment_logger = hydra.utils.instantiate(conf.logging.wandb_arg) - if pl_module is not None: - # it may happen that the model is not instantiated if we are only testing - # in that case, we don't need to watch the model - experiment_logger.watch(pl_module, **conf.logging.watch) - experiment_path = Path(experiment_logger.experiment.dir) - # Store the YaML config separately into the wandb dir - yaml_conf: str = OmegaConf.to_yaml(cfg=conf) - (experiment_path / "hparams.yaml").write_text(yaml_conf) - # Add a Learning Rate Monitor callback to log the learning rate - callbacks_store.append(LearningRateMonitor(logging_interval="step")) - - early_stopping_callback: Optional[EarlyStopping] = None - if conf.train.early_stopping_callback is not None: - early_stopping_callback = hydra.utils.instantiate( - conf.train.early_stopping_callback - ) - callbacks_store.append(early_stopping_callback) - - model_checkpoint_callback: Optional[ModelCheckpoint] = None - if conf.train.model_checkpoint_callback is not None: - model_checkpoint_callback = hydra.utils.instantiate( - conf.train.model_checkpoint_callback, - dirpath=experiment_path / "checkpoints" if experiment_path else None, - ) - callbacks_store.append(model_checkpoint_callback) - - if "callbacks" in conf.train and conf.train.callbacks is not None: - for _, callback in conf.train.callbacks.items(): - # callback can be a list of callbacks or a single callback - if isinstance(callback, omegaconf.listconfig.ListConfig): - for cb in callback: - if cb is not None: - callbacks_store.append( - hydra.utils.instantiate(cb, _recursive_=False) - ) - else: - if callback is not None: - callbacks_store.append(hydra.utils.instantiate(callback)) - - # trainer - logger.log("Instantiating the Trainer") - trainer: Trainer = hydra.utils.instantiate( - conf.train.pl_trainer, callbacks=callbacks_store, logger=experiment_logger - ) - - if not conf.train.only_test: - # module fit - trainer.fit(pl_module, datamodule=pl_data_module) - - if conf.train.pl_trainer.fast_dev_run: - best_pl_module = pl_module - else: - # load best model for testing - if conf.train.checkpoint_path: - best_model_path = conf.evaluation.checkpoint_path - elif model_checkpoint_callback: - best_model_path = model_checkpoint_callback.best_model_path - else: - raise ValueError( - "Either `checkpoint_path` or `model_checkpoint_callback` should " - "be specified in the evaluation configuration" - ) - logger.log(f"Loading best model from {best_model_path}") - - try: - best_pl_module = GoldenRetrieverPLModule.load_from_checkpoint( - best_model_path - ) - except Exception as e: - logger.log(f"Failed to load the model from checkpoint: {e}") - logger.log("Using last model instead") - best_pl_module = pl_module - if "compile" in conf.model.pl_module and conf.model.pl_module.compile: - try: - best_pl_module = torch.compile(best_pl_module, backend="inductor") - except Exception: - logger.log( - "Failed to compile the model, you may need to install PyTorch 2.0" - ) - - # module test - trainer.test(best_pl_module, datamodule=pl_data_module) - - -@hydra.main(config_path="../../conf", config_name="default", version_base="1.3") -def main(conf: omegaconf.DictConfig): - train(conf) - - -if __name__ == "__main__": - main() diff --git a/relik/version.py b/relik/version.py deleted file mode 100644 index bed137800c980e0e82d7c8ccdf474053baed630f..0000000000000000000000000000000000000000 --- a/relik/version.py +++ /dev/null @@ -1,13 +0,0 @@ -import os - -_MAJOR = "0" -_MINOR = "1" -# 
On main and in a nightly release the patch should be one ahead of the last -# released build. -_PATCH = "0" -# This is mainly for nightly builds which have the suffix ".dev$DATE". See -# https://semver.org/#is-v123-a-semantic-version for the semantics. -_SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "") - -VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) -VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
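Note on the removed retriever module: the deleted relik/retriever/pytorch_modules/model.py above documents an index()/retrieve() API on GoldenRetriever. The sketch below is a hypothetical reconstruction from those docstrings only; the constructor argument names, the InMemoryDocumentIndex import path and its documents= keyword, and the fields of RetrievedSample are assumptions that are not confirmed by this diff.

    from relik.retriever.pytorch_modules.model import GoldenRetriever
    # Assumed import path; only BaseDocumentIndex is referenced in the diff above.
    from relik.retriever.indexers.inmemory import InMemoryDocumentIndex

    # Hypothetical setup: a bi-encoder checkpoint plus a tiny passage collection.
    retriever = GoldenRetriever(
        question_encoder="some-encoder-checkpoint",  # assumed argument name and value
        document_index=InMemoryDocumentIndex(
            documents=["Rome is the capital of Italy.", "Paris is the capital of France."]
        ),
    )

    # Encode and cache the passage embeddings; retrieve() raises if the index is missing.
    retriever.index(batch_size=32, force_reindex=True)

    # Returns a List[List[RetrievedSample]], one inner list per input question.
    candidates = retriever.retrieve("What is the capital of Italy?", k=2)
    for sample in candidates[0]:
        print(sample)  # per-sample fields (e.g. passage, score) are not shown in this diff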
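Note on the removed optimizer and scheduler: the deleted optim.py and scheduler.py define RAdamW and LinearSchedulerWithWarmup with the exact signatures shown above. A minimal, self-contained sketch of how they compose in a training loop follows; the toy linear model, loss, and step counts are illustrative only. With this scheduler the learning rate grows linearly as base_lr * step / num_warmup_steps during warmup, then decays linearly to zero at num_training_steps.

    import torch

    from relik.retriever.pytorch_modules.optim import RAdamW
    from relik.retriever.pytorch_modules.scheduler import LinearSchedulerWithWarmup

    model = torch.nn.Linear(8, 2)  # stand-in for the retriever encoders
    optimizer = RAdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    scheduler = LinearSchedulerWithWarmup(
        optimizer, num_warmup_steps=100, num_training_steps=1000
    )

    for step in range(1000):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).sum()  # dummy objective
        loss.backward()
        optimizer.step()
        scheduler.step()  # linear warmup for 100 steps, then linear decay to 0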