from dataclasses import dataclass
import numbers
import statistics
import sys
from typing import List, Tuple, Union

from numpy.typing import NDArray

# Either a flat list of counts or a list of lists of counts.
NumSentencesType = Union[List[int], List[List[int]]]
# Embedding slices mirroring the (flat or nested) structure of NumSentencesType.
EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]


def _is_int_list(obj) -> bool:
    """Return True if *obj* is a list whose items are all integers.

    ``numbers.Integral`` accepts Python ``int`` as well as NumPy integer
    scalars such as ``np.int64``. An empty list counts as an integer list.
    """
    return isinstance(obj, list) and all(isinstance(item, numbers.Integral) for item in obj)


def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """Split ``embeddings`` into consecutive slices along its first axis.

    Each count in ``num_sentences`` takes the next ``count`` rows, continuing
    from where the previous slice ended. This holds across sublists in the
    nested form too.

    :param embeddings: Array to slice along its first axis.
    :param num_sentences: A list of integer counts, which yields a flat list of
        slices. Alternatively, a list of lists of integer counts, which yields a
        nested list of slices with the same shape.
    :return: The slices, in the same structure as ``num_sentences``.
    :raises TypeError: If ``num_sentences`` is neither a list of integers nor a
        list of lists of integers.
    """

    def _slice_embeddings(s_idx: int, n_sentences: List[int]) -> Tuple[List[NDArray], int]:
        # Take consecutive slices starting at s_idx; also return the index
        # just past the last slice so the caller can continue from there.
        _result = []
        for count in n_sentences:
            # NOTE: counts are not validated against len(embeddings); slicing
            # past the end yields shorter (or empty) slices rather than an error.
            _result.append(embeddings[s_idx:s_idx + count])
            s_idx += count
        return _result, s_idx

    if _is_int_list(num_sentences):
        result, _ = _slice_embeddings(0, num_sentences)
        return result
    elif isinstance(num_sentences, list) and all(_is_int_list(sublist) for sublist in num_sentences):
        nested_result = []
        start_idx = 0
        for nested_num_sentences in num_sentences:
            embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
            nested_result.append(embedding_slice)
        return nested_result
    else:
        raise TypeError(f"Incorrect Type for {num_sentences=}")


def is_list_of_strings_at_depth(obj, depth: int) -> bool:
    """Check whether ``obj`` is ``str`` nested in exactly ``depth`` levels of lists.

    A depth of 0 means ``obj`` itself must be a ``str``. A depth of 2 means a
    list of lists of ``str``.

    :param obj: Object to inspect.
    :param depth: Required number of list levels around the strings.
    :return: True if every leaf sits at exactly ``depth`` and is a ``str``.
    :raises ValueError: If ``depth`` is negative.
    """
    if depth == 0:
        return isinstance(obj, str)
    elif depth > 0:
        return isinstance(obj, list) and all(is_list_of_strings_at_depth(item, depth - 1) for item in obj)
    else:
        raise ValueError("Depth can't be negative")


def flatten_list(nested_list: list) -> list:
    """
    Flattens a nested list of any depth, preserving left-to-right order.

    Uses an explicit stack instead of recursion, so very deeply nested input
    does not hit Python's recursion limit.

    Parameters:
    nested_list (list): The nested list to flatten.

    Returns:
    list: A flat list containing all the non-list elements of the nested list.
    """
    flat_list = []
    # Stack of partially-consumed iterators, one per currently open list level.
    stack = [iter(nested_list)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                # Descend into the sublist. The outer iterator keeps its
                # position and resumes once the sublist is exhausted.
                stack.append(iter(item))
                break
            flat_list.append(item)
        else:
            # Current level exhausted; return to its parent.
            stack.pop()
    return flat_list


def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
    """
    Computes F1 value
    :param p: Precision Value
    :param r: Recall Value
    :param eps: Epsilon added to the denominator so that p + r == 0 yields 0
        instead of raising ZeroDivisionError
    :return: 2 * p * r / (p + r + eps)
    """
    f1 = 2 * p * r / (p + r + eps)
    return f1


@dataclass
class Scores:
    """Precision, a list of recall values, and the derived F1 score.

    ``f1`` is computed in ``__post_init__`` from ``precision`` and the
    arithmetic mean of ``recall``. It is a plain instance attribute rather than
    a dataclass field, so it is not part of ``repr`` or equality comparisons.
    """

    precision: float
    recall: List[float]

    def __post_init__(self):
        # statistics.fmean raises statistics.StatisticsError if recall is empty.
        self.f1: float = compute_f1(self.precision, statistics.fmean(self.recall))