"""Calculate the corr matrix."""
# pylint: disable=invalid-name
from typing import List, Optional

import numpy as np
from logzero import logger

# from model_pool import load_model_s
from hf_model_s import model_s

# Instantiated once at import time so every gradio_cmat call reuses it.
model = model_s()


def gradio_cmat(
    list1: List[str],
    list2_: Optional[List[str]] = None,
) -> np.ndarray:
    """Gen corr matrix given two lists of str.

    Both lists are embedded with ``model.encode`` and the result is the
    matrix of dot products between every embedding of ``list1`` and every
    embedding of ``list2``.

    Args:
        list1: list of strings
        list2_: list of strings; if None *or empty*, list1 is used, giving
            the self-correlation matrix of list1

    Returns:
        numpy.ndarray, (len(list1) x len(list2)) -- assuming model.encode
        returns one row per input string (TODO confirm for model_s).

    Raises:
        Whatever model.encode or the matrix product raises; the error is
        logged with its traceback and re-raised unchanged.
    """
    # Falsy check (not ``is None``) is deliberate and preserved: an empty
    # list2_ also falls back to list1. No copies needed; inputs are only read.
    list2 = list2_ if list2_ else list1

    try:
        vec1 = model.encode(list1)
    except Exception as e:
        logger.exception("model.encode(list1) error: %s", e)
        raise

    try:
        vec2 = model.encode(list2)
    except Exception as e:
        logger.exception("model.encode(list2) error: %s", e)
        raise

    try:
        # np.asarray is a no-op for ndarrays and guarantees the declared
        # np.ndarray return if the encoder yields list-like output.
        res = np.asarray(vec1).dot(np.asarray(vec2).T)
    except Exception as e:
        logger.exception("vec1.dot(vec2.T) error: %s", e)
        raise

    return res