initial
app.py
ADDED
@@ -0,0 +1,235 @@
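"""
Streamlit app that evaluates the Short-Answer-Feedback (SAF) BART/mBART
models on held-out test splits of the SAF datasets and displays text metrics
(SacreBLEU, ROUGE-2, METEOR, BERTScore) and label metrics (accuracy, weighted
and macro F1) in a single table.
"""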
import streamlit as st
import pandas as pd
import numpy as np
import torch

from datasets import load_dataset
from evaluate import load as load_metric
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sklearn.metrics import accuracy_score, f1_score
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

# run inference on a GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
select = st.selectbox('Which model would you like to evaluate?',
                      ('Bart', 'mBart'))
def get_datasets():
    if select == "Bart":
        all_datasets = ["Communication Networks: unseen questions", "Communication Networks: unseen answers"]
    elif select == "mBart":
        all_datasets = ["Micro Job: unseen questions", "Micro Job: unseen answers", "Legal Domain: unseen questions", "Legal Domain: unseen answers"]
    return all_datasets

all_datasets = get_datasets()
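# e.g. selecting 'Bart' evaluates the two English Communication Networks test
# splits, while 'mBart' evaluates the four German-language test splits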
def get_split(dataset_name):
    if dataset_name == "Communication Networks: unseen questions":
        split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
    elif dataset_name == "Communication Networks: unseen answers":
        split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
    elif dataset_name == "Micro Job: unseen questions":
        split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
    elif dataset_name == "Micro Job: unseen answers":
        split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
    elif dataset_name == "Legal Domain: unseen questions":
        split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
    elif dataset_name == "Legal Domain: unseen answers":
        split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
    return split
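# Example: get_split('Legal Domain: unseen answers') loads the
# 'test_unseen_answers' split of Short-Answer-Feedback/saf_legal_domain_german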
def get_model(datasetname):
    if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
        model = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
    elif datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
        model = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
    elif datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
        model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
    return model
def get_tokenizer(datasetname):
    if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
        tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
    elif datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
        tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
    elif datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
        tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
    return tokenizer
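# each fine-tuned checkpoint ships its own tokenizer, so get_model and
# get_tokenizer return the same repository name for a given dataset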
sacrebleu = load_metric('sacrebleu')
rouge = load_metric('rouge')
meteor = load_metric('meteor')
bertscore = load_metric('bertscore')
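# the evaluate metrics are fetched on first use; bertscore additionally loads
# its scoring model (bert-base-multilingual-cased below) at compute time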
MAX_INPUT_LENGTH = 256
MAX_TARGET_LENGTH = 128
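# tokenized inputs are truncated to 256 tokens; targets (and generated
# feedback, via max_length in model.generate) are capped at 128 tokens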
def preprocess_function(examples):
    """
    Preprocess entries of the given dataset
    (relies on the module-level `tokenizer` that load_data sets per dataset)

    Params:
        examples (dict of lists): batch of dataset entries to be preprocessed
    Returns:
        model_inputs (BatchEncoding): tokenized dataset entries
    """
    inputs, targets = [], []
    for i in range(len(examples['question'])):
        inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
        targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")

    # apply tokenization to inputs and labels
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
    labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)

    model_inputs['labels'] = labels['input_ids']

    return model_inputs
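# A (made-up) entry is serialized as
#   input:  'Antwort: <provided answer> Lösung: <reference answer> Frage: <question>'
#   target: 'Correct Feedback: <annotated feedback>'
# which is the 'label, then feedback' format that extract_labels and
# extract_feedback parse below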
def flatten_list(l):
    """
    Utility function to convert a list of lists into a flattened list

    Params:
        l (list of lists): list to be flattened
    Returns:
        A flattened list with the elements of the original list
    """
    return [item for sublist in l for item in sublist]
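# e.g. flatten_list([['a', 'b'], ['c']]) -> ['a', 'b', 'c']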
def extract_feedback(predictions):
    """
    Utility function to extract the feedback from the predictions of the model

    Params:
        predictions (list): complete model predictions
    Returns:
        feedback (list): extracted feedback from the model's predictions
    """
    feedback = []
    # iterate through predictions and try to extract predicted feedback
    for pred in predictions:
        try:
            # predictions are expected to look like 'LABEL Feedback: ...'
            fb = pred.split(':', 1)[1]
        except IndexError:
            # no colon present: drop the leading label words instead
            try:
                if pred.lower().startswith('partially correct'):
                    fb = pred.split(' ', 2)[2]
                else:
                    fb = pred.split(' ', 1)[1]
            except IndexError:
                fb = pred
        feedback.append(fb.strip())

    return feedback
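# Illustration with made-up predictions:
#   extract_feedback(['Incorrect Feedback: The handshake step is missing.'])
#     -> ['The handshake step is missing.']
#   extract_feedback(['Partially correct some feedback'])  # no colon present
#     -> ['some feedback']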
def extract_labels(predictions):
    """
    Utility function to extract the labels from the predictions of the model

    Params:
        predictions (list): complete model predictions
    Returns:
        labels (list): extracted labels from the model's predictions
    """
    labels = []
    for pred in predictions:
        if pred.lower().startswith('correct'):
            label = 'Correct'
        elif pred.lower().startswith('partially correct'):
            label = 'Partially correct'
        elif pred.lower().startswith('incorrect'):
            label = 'Incorrect'
        else:
            label = 'Unknown label'
        labels.append(label)

    return labels
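# Illustration: extract_labels(['Partially correct Feedback: fix step 2', 'gibberish'])
#   -> ['Partially correct', 'Unknown label']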
def get_predictions_labels(model, dataloader):
    """
    Evaluate model on the given dataset

    Params:
        model (PreTrainedModel): seq2seq model
        dataloader (torch DataLoader): dataloader of the dataset to be used for evaluation
    Returns:
        predictions (list): decoded predictions of the model
        labels (list): decoded golden labels
    """
    decoded_preds, decoded_labels = [], []

    model.eval()
    # iterate through batches in the dataloader
    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            # generate tokens from batch
            generated_tokens = model.generate(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=MAX_TARGET_LENGTH
            )
            # get golden labels from batch
            labels_batch = batch['labels']

            # decode model predictions and golden labels
            decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)

            decoded_preds.append(decoded_preds_batch)
            decoded_labels.append(decoded_labels_batch)

    # convert predictions and golden labels into flattened lists
    predictions = flatten_list(decoded_preds)
    labels = flatten_list(decoded_labels)

    return predictions, labels
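# note: the labels were padded to MAX_TARGET_LENGTH above, but batch_decode
# drops the pad tokens via skip_special_tokens=True before comparison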
def load_data():
    global tokenizer  # preprocess_function and get_predictions_labels use the current tokenizer

    df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
    for ds in all_datasets:
        split = get_split(ds)
        model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds)).to(device)
        tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))

        processed_dataset = split.map(
            preprocess_function,
            batched=True,
            remove_columns=split.column_names
        )
        processed_dataset.set_format('torch')

        dataloader = DataLoader(processed_dataset, batch_size=4)

        predictions, labels = get_predictions_labels(model, dataloader)

        predicted_feedback = extract_feedback(predictions)
        predicted_labels = extract_labels(predictions)

        reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
        reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]

        rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
        bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
        meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
        # bertscore returns per-example scores; report the mean F1
        bert_score = np.mean(bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)['f1'])

        reference_labels_np = np.array(reference_labels)

        accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
        f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
        f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])

        new_row = pd.DataFrame([{'Model': get_model(ds), 'Dataset': ds, 'SacreBLEU': bleu_score, 'ROUGE-2': rouge_score, 'METEOR': meteor_score, 'BERTScore': bert_score, 'Accuracy': accuracy_value, 'Weighted F1': f1_weighted_value, 'Macro F1': f1_macro_value}])

        df = pd.concat([df, new_row], ignore_index=True)
    return df
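# load_data() runs full inference over every selected split on each rerun of
# the script; one option (an assumption, not part of the original app) would
# be to wrap it with Streamlit's st.cache_data decorator so reruns reuse the
# computed table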
dataframe = load_data()

st.dataframe(dataframe)