File size: 4,593 Bytes
569f484 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from .image_base import ImageBaseDataset
from .utils.judge_util import build_judge
from ..smp import *
from ..utils import track_progress_rich
class ImageMTDataset(ImageBaseDataset):
    """Image dataset whose records are multi-turn (question, answer) dialogues."""

    def build_prompt(self, line):
        """Build an alternating user/assistant message list from one record.

        Args:
            line: Either an integer row position or the record itself
                (indexed by column name, e.g. a ``pd.Series``). The record
                must provide ``question`` and, when ``self.meta_only`` is set,
                ``image_path``; ``answer`` is optional.

        Returns:
            A list of dicts ``dict(role=..., content=[...])`` alternating
            ``'user'`` and ``'assistant'``. Each ``content`` entry is
            ``dict(type='text' | 'image', value=...)``. Every ``<ImageHere>``
            tag in a question is replaced, in order, by the next unused image
            of the record, so images are consumed sequentially across turns.

        Raises:
            ValueError: A question contains more ``<ImageHere>`` tags than
                there are remaining images in the record.
            AssertionError: The question and answer lists differ in length,
                or an answer contains an ``<ImageHere>`` tag.
        """
        if isinstance(line, int):
            # Resolve a row position to the row itself (wrapping it in a list
            # made every ``line[...]`` lookup below fail).
            # NOTE(review): assumes the base class keeps the loaded table in
            # ``self.data`` — confirm against ImageBaseDataset.
            line = self.data.iloc[line]

        if self.meta_only:
            # Image paths are stored directly in the record.
            tgt_path = toliststr(line['image_path'])
        else:
            # Otherwise let the base class produce the image paths.
            tgt_path = self.dump_image(line)

        questions = toliststr(line['question'])
        if 'answer' in line:
            answers = toliststr(line['answer'])
        else:
            # No reference answers: emit empty assistant turns so the
            # dialogue still alternates user/assistant.
            answers = [''] * len(questions)
        assert len(questions) == len(answers), \
            'Each question must have a matching answer.'

        dlgs, pics_number = [], 0
        for q, a in zip(questions, answers):
            if '<ImageHere>' in q:
                content = []
                tag_number = q.count('<ImageHere>')
                images = tgt_path[pics_number: pics_number + tag_number]
                if len(images) < tag_number:
                    raise ValueError(
                        f'Question has {tag_number} <ImageHere> tag(s) but only '
                        f'{len(images)} image(s) remain in the record.')
                pics_number += tag_number
                q_split = q.split('<ImageHere>')
                # q_split has tag_number + 1 pieces; zip pairs the text before
                # each tag with its image, the trailing text is handled below.
                for qsp, im in zip(q_split, images):
                    if qsp != '':
                        content.append(dict(type='text', value=qsp))
                    content.append(dict(type='image', value=im))
                if q_split[-1] != '':
                    content.append(dict(type='text', value=q_split[-1]))
            else:
                content = [dict(type='text', value=q)]
            dlgs.append(dict(role='user', content=content))
            assert '<ImageHere>' not in a, 'We currently do not support images in the answer. '
            dlgs.append(dict(role='assistant', content=[dict(type='text', value=a)]))
        return dlgs
class MMDUDataset(ImageMTDataset):
    """MMDU multi-turn benchmark, scored per dimension by an LLM judge."""

    # Expected MD5 of the MMDU data file (presumably verified by the base
    # class when the data is loaded — confirm).
    DATASET_MD5 = {'MMDU': '848b635a88a078f49aebcc6e39792061'}
    # Score columns read from each judge result DataFrame by calculat_metric.
    DIMS = [
        'Creativity', 'Richness', 'Visual Perception', 'Logical Coherence',
        'Answer Accuracy', 'Image Relationship Understanding', 'Overall Score'
    ]

    def calculat_metric(self, ans):
        """Aggregate per-dimension judge scores into a two-row summary table.

        Args:
            ans: Mapping from sample index to a dict whose ``'res'`` entry is
                a ``pd.DataFrame`` with (some of) the ``DIMS`` columns.

        Returns:
            A ``pd.DataFrame`` with rows ``set='all'`` (score sum divided by
            every row seen) and ``set='valid'`` (divided by rows whose score
            parsed). Each score is clipped to [0, 10] and the averages are
            multiplied by 10, so values lie in [0, 100]. A dimension with a
            zero denominator is reported as NaN instead of raising.
        """
        score_sum = defaultdict(int)
        total = defaultdict(int)
        valid = defaultdict(int)
        for record in ans.values():
            res = record['res']
            assert isinstance(res, pd.DataFrame)
            for _, row in res.iterrows():
                for dim in self.DIMS:
                    total[dim] += 1
                    if dim not in row or row[dim] is None:
                        continue
                    try:
                        score = int(row[dim])
                    except (TypeError, ValueError, OverflowError) as e:
                        # e.g. NaN or free-text scores from the judge.
                        print(f'Failed to parse the score: {str(e)}')
                        continue
                    score_sum[dim] += min(max(score, 0), 10)
                    valid[dim] += 1

        def _scaled_mean(num, den):
            # No rows / no parseable scores: the average is undefined.
            return num / den * 10 if den else float('nan')

        sp1 = {'set': 'all'}
        sp1.update({dim: _scaled_mean(score_sum[dim], total[dim]) for dim in self.DIMS})
        sp2 = {'set': 'valid'}
        sp2.update({dim: _scaled_mean(score_sum[dim], valid[dim]) for dim in self.DIMS})
        return pd.DataFrame([sp1, sp2])

    def evaluate(self, eval_file, **judge_kwargs):
        """Judge every prediction in ``eval_file`` and aggregate the scores.

        Judge results are cached in ``<stem>_<model>.pkl`` next to
        ``eval_file``; samples already present there are not re-judged.

        Args:
            eval_file: Path of the prediction file, readable by ``load``; its
                rows must carry an ``index`` column.
            **judge_kwargs: Forwarded to ``build_judge``. ``model`` (default
                ``'gpt-4o'``) and ``nproc`` (default 4) are consumed here.

        Returns:
            The summary DataFrame from ``calculat_metric``, also written to
            ``<stem>_<model>_score.csv``.
        """
        # Pop once so the 'gpt-4o' default is reachable (a plain
        # ``judge_kwargs['model']`` lookup raised KeyError first).
        model = judge_kwargs.pop('model', 'gpt-4o')
        nproc = judge_kwargs.pop('nproc', 4)

        # Strip only the final extension; str.replace would rewrite every
        # occurrence in the path, and for extension-less names produced a
        # cache path equal to eval_file itself.
        stem = osp.splitext(eval_file)[0]
        tmp_file = f'{stem}_{model}.pkl'
        score_file = f'{stem}_{model}_score.csv'

        data = load(eval_file)
        judge_model = build_judge(model=model, **judge_kwargs)

        lines = [data.iloc[i] for i in range(len(data))]
        all_indices = [line['index'] for line in lines]

        # Resume from previously cached judge results, if any.
        ans = load(tmp_file) if osp.exists(tmp_file) else {}
        tups = [(judge_model, line)
                for idx, line in zip(all_indices, lines) if idx not in ans]
        indices = [idx for idx in all_indices if idx not in ans]

        from .utils.mmdu import mmdu_score

        if indices:
            # NOTE(review): the argument list of this call was truncated in
            # this copy; restored on the assumption that track_progress_rich
            # runs mmdu_score over tups and saves results keyed by ``keys``
            # into ``save`` (the reload below relies on that) — confirm.
            track_progress_rich(
                mmdu_score,
                tups,
                nproc=nproc,
                keys=indices,
                save=tmp_file,
            )
            ans = load(tmp_file)
            missing = [k for k in indices if k not in ans]
            assert not missing, f'Judge results missing for indices: {missing}'

        metric = self.calculat_metric(ans)
        dump(metric, score_file)
        return metric