pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluates image-text retrieval results."""
from typing import List, Mapping
import numpy as np
RECALL_THRESHOLDS = (1, 5, 10)
def text_to_image_retrieval_eval(
dist_matrix: np.ndarray,
text_image_correspondence: List[int]) -> Mapping[str, float]:
"""Runs the text-to-image retrieval eval from the distance matrix.
Args:
dist_matrix: Distance matrix between text and image embeddings (shape
N_IMAGES x N_TEXTS).
text_image_correspondence: Mapping between rows and columns of
`dist_matrix`, that is, a list of N_TEXTS integers n_i that represent that
the text embedding in column i corresponds to the image embedding in row
n_i. Please note that many texts can be assigned to the same image. For
instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4), then
`text_image_correspondence = [0, 0, 1, 1]` means that the two first texts
correspond to the first image and the two last texts to the second image.
Returns:
A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
"""
per_text_ranks = dist_matrix.argsort(axis=0)
text_image_correspondence = np.array(text_image_correspondence)
def recall_at(k):
wins = per_text_ranks[:k, :] == text_image_correspondence[None]
return wins.any(axis=0).mean()
return {
f'Recall@{k}': recall_at(k)
for k in RECALL_THRESHOLDS
}
def image_to_text_retrieval_eval(
dist_matrix: np.ndarray,
text_image_correspondence: List[int]) -> Mapping[str, float]:
"""Runs the image-to-text retrieval eval from the distance matrix.
Args:
dist_matrix: Distance matrix between text and image embeddings (shape
N_IMAGES x N_TEXTS).
text_image_correspondence: Mapping between rows and columns of
`dist_matrix`, that is, a list of N_TEXTS integers n_i that represent that
the text embedding in column i corresponds to the image embedding in row
n_i. Please note that many texts can be assigned to the same image. For
instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4), then
`text_image_correspondence = [0, 0, 1, 1]` means that the two first texts
correspond to the first image and the two last texts to the second image.
Returns:
A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
"""
per_image_ranks = dist_matrix.argsort(axis=1)
text_image_correspondence = np.array(text_image_correspondence)
def recall_at(k):
top_k_images = text_image_correspondence[per_image_ranks[:, :k]]
wins = top_k_images == np.arange(len(per_image_ranks))[:, None]
return wins.any(axis=1).mean()
return {
f'Recall@{k}': recall_at(k)
for k in RECALL_THRESHOLDS
}