Update app.py
app.py CHANGED

@@ -25,21 +25,20 @@ all_datasets = get_datasets()
 
 
 #def get_split(dataset_name):
-'''
+# if dataset_name == "Communication Networks: unseen questions":
+# split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
+# if dataset_name == "Communication Networks: unseen answers":
+# split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
+# if dataset_name == "Micro Job: unseen questions":
+# split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
+# if dataset_name == "Micro Job: unseen answers":
+# split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
+# if dataset_name == "Legal Domain: unseen questions":
+# split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
+# if dataset_name == "Legal Domain: unseen answers":
+# split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
+# return split
+
 
 def get_model(datasetname):
 if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":

@@ -50,195 +49,195 @@ def get_model(datasetname):
 model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
 return model
 
-def get_tokenizer(datasetname):
-sacrebleu = load_metric('sacrebleu')
-rouge = load_metric('rouge')
-meteor = load_metric('meteor')
-bertscore = load_metric('bertscore')
-# use gpu if it's available
-device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-MAX_INPUT_LENGTH = 256
-MAX_TARGET_LENGTH = 128
-def preprocess_function(examples, **kwargs):
-def flatten_list(l):
-def extract_feedback(predictions):
-def extract_labels(predictions):
-def get_predictions_labels(model, dataloader, tokenizer):
-def load_data():
+
+# def get_tokenizer(datasetname):
+# if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
+# tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
+# if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
+# tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
+# if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
+# tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
+# return tokenizer
+
+# sacrebleu = load_metric('sacrebleu')
+# rouge = load_metric('rouge')
+# meteor = load_metric('meteor')
+# bertscore = load_metric('bertscore')
+
+# # use gpu if it's available
+# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# MAX_INPUT_LENGTH = 256
+# MAX_TARGET_LENGTH = 128
+
+# def preprocess_function(examples, **kwargs):
+# """
+# Preprocess entries of the given dataset
+
+# Params:
+# examples (Dataset): dataset to be preprocessed
+# Returns:
+# model_inputs (BatchEncoding): tokenized dataset entries
+# """
+
+# inputs, targets = [], []
+# for i in range(len(examples['question'])):
+# inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
+# targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
+
+# # apply tokenization to inputs and labels
+# tokenizer = kwargs["tokenizer"]
+# model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
+# labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
+
+# model_inputs['labels'] = labels['input_ids']
+
+# return model_inputs
+
+
+
+# def flatten_list(l):
+# """
+# Utility function to convert a list of lists into a flattened list
+# Params:
+# l (list of lists): list to be flattened
+# Returns:
+# A flattened list with the elements of the original list
+# """
+# return [item for sublist in l for item in sublist]
+
+
+# def extract_feedback(predictions):
+# """
+# Utility function to extract the feedback from the predictions of the model
+# Params:
+# predictions (list): complete model predictions
+# Returns:
+# feedback (list): extracted feedback from the model's predictions
+# """
+# feedback = []
+# # iterate through predictions and try to extract predicted feedback
+# for pred in predictions:
+# try:
+# fb = pred.split(':', 1)[1]
+# except IndexError:
+# try:
+# if pred.lower().startswith('partially correct'):
+# fb = pred.split(' ', 1)[2]
+# else:
+# fb = pred.split(' ', 1)[1]
+# except IndexError:
+# fb = pred
+# feedback.append(fb.strip())

+# return feedback
+
+
+# def extract_labels(predictions):
+# """
+# Utility function to extract the labels from the predictions of the model
+# Params:
+# predictions (list): complete model predictions
+# Returns:
+# feedback (list): extracted labels from the model's predictions
+# """
+# labels = []
+# for pred in predictions:
+# if pred.lower().startswith('correct'):
+# label = 'Correct'
+# elif pred.lower().startswith('partially correct'):
+# label = 'Partially correct'
+# elif pred.lower().startswith('incorrect'):
+# label = 'Incorrect'
+# else:
+# label = 'Unknown label'
+# labels.append(label)

+# return labels
+
+
+# def get_predictions_labels(model, dataloader, tokenizer):
+# """
+# Evaluate model on the given dataset
+
+# Params:
+# model (PreTrainedModel): seq2seq model
+# dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation
+# Returns:
+# results (dict): dictionary with the computed evaluation metrics
+# predictions (list): list of the decoded predictions of the model
+# """
+# decoded_preds, decoded_labels = [], []
+
+# model.eval()
+# # iterate through batchs in the dataloader
+# for batch in tqdm(dataloader):
+# with torch.no_grad():
+# batch = {k: v.to(device) for k, v in batch.items()}
+# # generate tokens from batch
+# generated_tokens = model.generate(
+# batch['input_ids'],
+# attention_mask=batch['attention_mask'],
+# max_length=MAX_TARGET_LENGTH
+# )
+# # get golden labels from batch
+# labels_batch = batch['labels']

+# # decode model predictions and golden labels
+# decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+# decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)

+# decoded_preds.append(decoded_preds_batch)
+# decoded_labels.append(decoded_labels_batch)

+# # convert predictions and golden labels into flattened lists
+# predictions = flatten_list(decoded_preds)
+# labels = flatten_list(decoded_labels)

+# return predictions, labels


+# def load_data():
+# df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
+# for ds in all_datasets:
+# split = get_split(ds)
+# model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds))
+# tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))

+# processed_dataset = split.map(
+# preprocess_function,
+# batched=True,
+# remove_columns=split.column_names,
+# fn_kwargs={"tokenizer": tokenizer}
+# )
+# processed_dataset.set_format('torch')

+# dataloader = DataLoader(processed_dataset, batch_size=4)

+# predictions, labels = get_predictions_labels(model, dataloader, tokenizer)

+# predicted_feedback = extract_feedback(predictions)
+# predicted_labels = extract_labels(predictions)

+# reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
+# reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]

+# rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
+# bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
+# meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
+# bert_score = bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)

+# reference_labels_np = np.array(reference_labels)
+# accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
+# f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
+# f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])

+# new_row_data = {"Model": get_model(ds), "Dataset": ds, "SacreBLEU": bleu_score, "ROUGE-2": rouge_score, "METEOR": meteor_score, "BERTScore": bert_score, "Accuracy": accuracy_value, "Weighted F1": f1_weighted_value, "Macro F1": f1_macro_value}
+# new_row = pd.DataFrame(new_row_data)

+# df = pd.concat([df, new_row])
+# return df
+
 
 def get_rows(datasetname):
 if datasetname == "Communication Networks: unseen questions":
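
For reference, a minimal sketch of how the text and label metrics named in the commented-out load_data() pipeline (SacreBLEU, ROUGE-2, METEOR, accuracy, macro F1; BERTScore omitted for brevity) can be computed. This is not part of the commit: it assumes the evaluate library as the current replacement for the deprecated load_metric call used above, and the feedback/label strings are made-up placeholders rather than dataset content.

# Illustrative only; strings below are invented placeholders, not SAF dataset entries.
import evaluate
from sklearn.metrics import accuracy_score, f1_score

predicted_feedback = ["Die Antwort ist korrekt und vollständig."]   # hypothetical model feedback
reference_feedback = ["Die Antwort ist vollständig und korrekt."]   # hypothetical gold feedback
predicted_labels = ["Correct"]
reference_labels = ["Partially correct"]

# load the same text metrics that app.py loaded via load_metric(...)
sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")

bleu_score = sacrebleu.compute(predictions=predicted_feedback,
                               references=[[x] for x in reference_feedback])["score"]
rouge2_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)["rouge2"]
meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)["meteor"]

# label-level metrics, mirroring the accuracy/F1 columns of load_data()
accuracy_value = accuracy_score(reference_labels, predicted_labels)
f1_macro_value = f1_score(reference_labels, predicted_labels, average="macro",
                          labels=["Incorrect", "Partially correct", "Correct"])

print(bleu_score, rouge2_score, meteor_score, accuracy_value, f1_macro_value)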