import uuid from functools import partial from .image_base import ImageBaseDataset from ..smp import * rouge = None nlp_en = None nlp_zh = None nlp = None def initialize(): import evaluate import spacy global rouge, nlp_en, nlp_zh, nlp try: rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4())) except: warnings.warn('Please first `pip install rouge_score`.') try: nlp_en = spacy.load('en_core_web_sm') except: warnings.warn('Will automatically download en_core_web_sm via spacy.') spacy.cli.download('en_core_web_sm') nlp_en = spacy.load('en_core_web_sm') try: nlp_zh = spacy.load('zh_core_web_sm') except: warnings.warn('Will automatically download zh_core_web_sm via spacy.') spacy.cli.download('zh_core_web_sm') nlp_zh = spacy.load('zh_core_web_sm') nlp = {'en': nlp_en, 'zh': nlp_zh} def rough_filter(answer_text): if "I can't" in answer_text: return False elif 'I cannot' in answer_text: return False elif 'sorry' in answer_text.lower(): return False if '无法' in answer_text: return False elif '抱歉' in answer_text: return False else: return True def zero_template(crossed_text): return { 'crossed_text': crossed_text, 'max_sim_val': 0, 'max_sim_string': '', 'precision': 0, 'recall': 0, 'f1': 0, 'jaccard': 0, 'rouge1': 0, 'exact_match': 0, } def tokenize(text, language): """ Tokenize the text and return the tokens. Parameters: text (str): The text to tokenize. language (str): The language of the text. Returns: list: The list of tokens. """ assert language in ['en', 'zh'] nlp_language = nlp[language] processed_text = nlp_language(text) return [token.text for token in processed_text] def find_best_match(needle, hay, language, rouge): """ Finds the best matching n-gram in the haystack for the given needle. Parameters: needle (str): The string to find. hay (str): The text to search within. Returns: tuple: The highest similarity value and the best matching string. """ assert language in ['en', 'zh'] from nltk.util import ngrams from difflib import SequenceMatcher as SM tokens_hay = tokenize(hay, language) tokens_needle = tokenize(needle, language) splitter = '' if language == 'zh' else ' ' ngrams_ = ngrams(tokens_hay, len(tokens_needle)) max_sim_val = 0 max_sim_string = '' max_sim_ngram = [] tokens_needle_set = set(tokens_needle) ngrams_hasjoint = [ ngram for ngram in ngrams_ if not set(ngram).isdisjoint(tokens_needle_set) ] for ngram in ngrams_hasjoint: hay_ngram = splitter.join(ngram) similarity = SM(None, hay_ngram, needle).ratio() if similarity > max_sim_val: max_sim_val = similarity max_sim_string = hay_ngram max_sim_ngram = ngram # Evaluate if len(max_sim_ngram) == 0: return { 'crossed_text': needle, 'max_sim_val': 0, 'max_sim_string': '', 'precision': 0, 'recall': 0, 'f1': 0, 'jaccard': 0, 'rouge1': 0, 'exact_match': 0, } pred_set = set(max_sim_ngram) ref_set = set(tokens_needle) correct_tokens = pred_set.intersection(ref_set) len_correct_tokens = len(correct_tokens) precision = len_correct_tokens / len(pred_set) recall = len_correct_tokens / len(ref_set) if (precision + recall) == 0: f1 = 0 else: f1 = 2 * precision * recall / (precision + recall) union = pred_set.union(ref_set) jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0 rouge_1 = rouge.compute( predictions=[max_sim_string], references=[needle], tokenizer=partial(tokenize, language=language), rouge_types=['rouge1'], )['rouge1'] exact_match = float(list(max_sim_ngram) == list(tokens_needle)) out = { 'crossed_text': needle, 'max_sim_string': max_sim_string, 'max_sim_val': max_sim_val, 'precision': precision, 'recall': recall, 'f1': f1, 'jaccard': jaccard, 'rouge1': rouge_1, 'exact_match': exact_match, } return out def process_match_single_new( image_id, prediction, answer, language, progress ): """ process the inference results for a single image and calculate the metrics Parameters: image_id (int): The image id (question id). prediction (str): The prediction text. answer (Union[str, List[str]]): The answer text, or a list of answer texts. The masked n-grams in the image. language (str): The language of the text. Can be "en" or "zh". rouge (rouge): The rouge metric object. progress (multiprocessing.Queue): The progress queue. Returns: tuple: The image id (question_id, int) and the result per id (dict of dict of dict). """ result_per_id = {image_id: {}} if isinstance(answer, str): answer = eval(answer) assert isinstance(answer, list) result = prediction.split('Assistant: ')[-1] for i, crossed_text in enumerate(answer): if rough_filter(result): find_best_match_result = find_best_match( crossed_text, result, language, rouge ) if i == 0: result_per_id[image_id] = {str(i): find_best_match_result} else: result_per_id[image_id][str(i)] = find_best_match_result else: if i == 0: result_per_id[image_id] = {str(i): zero_template(crossed_text)} else: result_per_id[image_id][str(i)] = zero_template(crossed_text) progress.put(1) return image_id, result_per_id class VCRDataset(ImageBaseDataset): TYPE = 'VQA' URL_PREFIX = 'https://huggingface.co/datasets/vcr-org' DATASET_URL = { 'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv', 'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv', 'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv', 'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv', 'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv', 'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv', 'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv', 'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv', 'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv', 'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv', 'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv', 'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv', } DATASET_MD5 = { 'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261', 'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0', 'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1', 'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4', 'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958', 'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e', 'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048', 'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5', 'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea', 'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7', 'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b', 'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c', } def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True): super().__init__(dataset, skip_noimg) initialize() self.language = 'en' if 'EN' in dataset else 'zh' self.difficulty = 'easy' if 'EASY' in dataset else 'hard' # def build_prompt(self, line): # msgs = super().build_prompt(line) # assert msgs[-1]['type'] == 'text' # if self.language == 'zh': # msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。' # else: # msgs[-1]['value'] += ('What is the covered texts in the image? ' # 'Please restore the covered texts without outputting the explanations.') # return msgs def evaluate(self, eval_file, **judge_kwargs): import multiprocessing vcr_score_list = {'Exact_Match': [], 'Jaccard': []} vcr_score = {'Exact_Match': 0, 'Jaccard': 0} logger = get_logger('Evaluation') data = load(eval_file) lt = len(data) lines = [data.iloc[i] for i in range(lt)] pool = multiprocessing.Pool() manager = multiprocessing.Manager() progress_queue = manager.Queue() results = [] overall_results = {str(image_id): {} for image_id in range(len(lines))} for instance_id, instance in enumerate(lines): results.append( pool.apply_async( process_match_single_new, args=( str(instance_id), instance['prediction'], instance['answer'], self.language, progress_queue, ), ) ) pool.close() # Display progress bar for _ in tqdm(range(len(results))): progress_queue.get() pool.join() # Merging results into overall_result for result in results: image_id, result_per_id = result.get() overall_results[str(image_id)].update(result_per_id[image_id]) for blank_id_str in result_per_id[image_id].keys(): vcr_score_list['Exact_Match'].append( result_per_id[image_id][blank_id_str]['exact_match'] ) vcr_score_list['Jaccard'].append( result_per_id[image_id][blank_id_str]['jaccard'] ) vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match']) vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard']) results_out = { k: v for i in range(len(results)) for k, v in results[i].get()[1].items() } results_with_metrics = { 'Exact_Match': vcr_score['Exact_Match'], 'Jaccard': vcr_score['Jaccard'], 'Predictions': results_out, } score_pth = eval_file.replace( '.xlsx', f'{self.language}_{self.difficulty}_score.json' ) dump(results_with_metrics, score_pth) logger.info( f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}' ) logger.info('Score: ') for key, value in vcr_score.items(): logger.info('{}:{}'.format(key, value))