import ast
import uuid
from functools import partial
from .image_base import ImageBaseDataset
from ..smp import *
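# NLP resources are heavyweight, so they are created lazily in initialize()
# rather than at import time.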
rouge = None
nlp_en = None
nlp_zh = None
nlp = None
def initialize():
import evaluate
import spacy
global rouge, nlp_en, nlp_zh, nlp
try:
rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4()))
    except Exception:
warnings.warn('Please first `pip install rouge_score`.')
try:
nlp_en = spacy.load('en_core_web_sm')
    except Exception:
warnings.warn('Will automatically download en_core_web_sm via spacy.')
spacy.cli.download('en_core_web_sm')
nlp_en = spacy.load('en_core_web_sm')
try:
nlp_zh = spacy.load('zh_core_web_sm')
    except Exception:
warnings.warn('Will automatically download zh_core_web_sm via spacy.')
spacy.cli.download('zh_core_web_sm')
nlp_zh = spacy.load('zh_core_web_sm')
nlp = {'en': nlp_en, 'zh': nlp_zh}
def rough_filter(answer_text):
    """Return False if the answer looks like a refusal, in English or Chinese."""
    if 'sorry' in answer_text.lower():
        return False
    refusal_phrases = ["I can't", 'I cannot', '无法', '抱歉']
    return not any(phrase in answer_text for phrase in refusal_phrases)
def zero_template(crossed_text):
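    """Return an all-zero metric dict for a blank with no usable match."""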
return {
'crossed_text': crossed_text,
'max_sim_val': 0,
'max_sim_string': '',
'precision': 0,
'recall': 0,
'f1': 0,
'jaccard': 0,
'rouge1': 0,
'exact_match': 0,
}
def tokenize(text, language):
"""
Tokenize the text and return the tokens.
Parameters:
text (str): The text to tokenize.
language (str): The language of the text.
Returns:
list: The list of tokens.
"""
assert language in ['en', 'zh']
nlp_language = nlp[language]
processed_text = nlp_language(text)
return [token.text for token in processed_text]
def find_best_match(needle, hay, language, rouge):
"""
Finds the best matching n-gram in the haystack for the given needle.
Parameters:
needle (str): The string to find.
hay (str): The text to search within.
Returns:
tuple: The highest similarity value and the best matching string.
"""
assert language in ['en', 'zh']
from nltk.util import ngrams
from difflib import SequenceMatcher as SM
tokens_hay = tokenize(hay, language)
tokens_needle = tokenize(needle, language)
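    # Chinese tokens are joined without separators; English tokens with spaces.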
splitter = '' if language == 'zh' else ' '
ngrams_ = ngrams(tokens_hay, len(tokens_needle))
max_sim_val = 0
max_sim_string = ''
max_sim_ngram = []
tokens_needle_set = set(tokens_needle)
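    # Speed-up heuristic: only consider n-grams that share at least one token
    # with the needle.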
ngrams_hasjoint = [
ngram
for ngram in ngrams_
if not set(ngram).isdisjoint(tokens_needle_set)
]
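    # Keep the candidate with the highest character-level similarity ratio.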
for ngram in ngrams_hasjoint:
hay_ngram = splitter.join(ngram)
similarity = SM(None, hay_ngram, needle).ratio()
if similarity > max_sim_val:
max_sim_val = similarity
max_sim_string = hay_ngram
max_sim_ngram = ngram
# Evaluate
    if len(max_sim_ngram) == 0:
        # No candidate n-gram shared a token with the needle.
        return zero_template(needle)
pred_set = set(max_sim_ngram)
ref_set = set(tokens_needle)
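    # Token-set precision/recall between the best n-gram and the needle.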
correct_tokens = pred_set.intersection(ref_set)
len_correct_tokens = len(correct_tokens)
precision = len_correct_tokens / len(pred_set)
recall = len_correct_tokens / len(ref_set)
if (precision + recall) == 0:
f1 = 0
else:
f1 = 2 * precision * recall / (precision + recall)
union = pred_set.union(ref_set)
jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
rouge_1 = rouge.compute(
predictions=[max_sim_string],
references=[needle],
tokenizer=partial(tokenize, language=language),
rouge_types=['rouge1'],
)['rouge1']
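    # Exact match: the chosen n-gram reproduces the needle token-for-token.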
exact_match = float(list(max_sim_ngram) == list(tokens_needle))
out = {
'crossed_text': needle,
'max_sim_string': max_sim_string,
'max_sim_val': max_sim_val,
'precision': precision,
'recall': recall,
'f1': f1,
'jaccard': jaccard,
'rouge1': rouge_1,
'exact_match': exact_match,
}
return out
def process_match_single_new(
image_id, prediction, answer, language, progress
):
"""
process the inference results for a single image and calculate the metrics
Parameters:
image_id (int): The image id (question id).
prediction (str): The prediction text.
answer (Union[str, List[str]]): The answer text, or a list of answer texts. The masked n-grams in the image.
language (str): The language of the text. Can be "en" or "zh".
rouge (rouge): The rouge metric object.
progress (multiprocessing.Queue): The progress queue.
Returns:
tuple: The image id (question_id, int) and the result per id (dict of dict of dict).
"""
result_per_id = {image_id: {}}
    if isinstance(answer, str):
        # Answers arrive as stringified lists; parse them safely.
        answer = ast.literal_eval(answer)
assert isinstance(answer, list)
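    # Keep only the model's reply if the prediction still carries a chat
    # template prefix.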
result = prediction.split('Assistant: ')[-1]
    for i, crossed_text in enumerate(answer):
        if rough_filter(result):
            result_per_id[image_id][str(i)] = find_best_match(
                crossed_text, result, language, rouge
            )
        else:
            result_per_id[image_id][str(i)] = zero_template(crossed_text)
progress.put(1)
return image_id, result_per_id
class VCRDataset(ImageBaseDataset):
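    """VCR (Visual Caption Restoration) benchmark datasets.

    Each TSV row pairs an image containing partially covered caption text with
    the list of masked n-grams; restorations are scored per blank with exact
    match and Jaccard overlap.
    """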
TYPE = 'VQA'
URL_PREFIX = 'https://huggingface.co/datasets/vcr-org'
DATASET_URL = {
'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv',
'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv',
'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv',
'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv',
'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv',
'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv',
'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv',
'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv',
'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv',
'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv',
'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv',
'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv',
}
DATASET_MD5 = {
'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261',
'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0',
'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1',
'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4',
'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958',
'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e',
'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048',
'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5',
'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea',
'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7',
'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b',
'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c',
}
def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True):
super().__init__(dataset, skip_noimg)
initialize()
self.language = 'en' if 'EN' in dataset else 'zh'
self.difficulty = 'easy' if 'EASY' in dataset else 'hard'
# def build_prompt(self, line):
# msgs = super().build_prompt(line)
# assert msgs[-1]['type'] == 'text'
# if self.language == 'zh':
# msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。'
# else:
# msgs[-1]['value'] += ('What is the covered texts in the image? '
# 'Please restore the covered texts without outputting the explanations.')
# return msgs
def evaluate(self, eval_file, **judge_kwargs):
import multiprocessing
vcr_score_list = {'Exact_Match': [], 'Jaccard': []}
vcr_score = {'Exact_Match': 0, 'Jaccard': 0}
logger = get_logger('Evaluation')
data = load(eval_file)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
pool = multiprocessing.Pool()
manager = multiprocessing.Manager()
progress_queue = manager.Queue()
results = []
overall_results = {str(image_id): {} for image_id in range(len(lines))}
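        # Dispatch one scoring task per instance; workers report completion
        # through the shared progress queue.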
for instance_id, instance in enumerate(lines):
results.append(
pool.apply_async(
process_match_single_new,
args=(
str(instance_id),
instance['prediction'],
instance['answer'],
self.language,
progress_queue,
),
)
)
pool.close()
# Display progress bar
for _ in tqdm(range(len(results))):
progress_queue.get()
pool.join()
# Merging results into overall_result
for result in results:
image_id, result_per_id = result.get()
overall_results[str(image_id)].update(result_per_id[image_id])
for blank_id_str in result_per_id[image_id].keys():
vcr_score_list['Exact_Match'].append(
result_per_id[image_id][blank_id_str]['exact_match']
)
vcr_score_list['Jaccard'].append(
result_per_id[image_id][blank_id_str]['jaccard']
)
vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match'])
vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard'])
        # overall_results already contains every merged per-image result.
        results_out = overall_results
results_with_metrics = {
'Exact_Match': vcr_score['Exact_Match'],
'Jaccard': vcr_score['Jaccard'],
'Predictions': results_out,
}
score_pth = eval_file.replace(
'.xlsx', f'{self.language}_{self.difficulty}_score.json'
)
dump(results_with_metrics, score_pth)
logger.info(
f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}'
)
logger.info('Score: ')
for key, value in vcr_score.items():
            logger.info(f'{key}: {value}')
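
# Minimal usage sketch (an assumption-laden example, not part of the module):
# it assumes this file lives in VLMEvalKit's dataset package and that the
# .xlsx passed to evaluate() was produced by the framework's inference step,
# with `prediction` and `answer` columns.
#
#     dataset = VCRDataset('VCR_EN_EASY_500')
#     dataset.evaluate('VCR_EN_EASY_500_MyModel.xlsx')  # hypothetical filename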