MCK-02 committed
Commit de8fc84 · 1 Parent(s): 5fc21e4

Update app.py

Files changed (1)
  1. app.py +186 -187
app.py CHANGED
@@ -25,21 +25,20 @@ all_datasets = get_datasets()


  #def get_split(dataset_name):
- '''
- if dataset_name == "Communication Networks: unseen questions":
- split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
- if dataset_name == "Communication Networks: unseen answers":
- split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
- if dataset_name == "Micro Job: unseen questions":
- split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
- if dataset_name == "Micro Job: unseen answers":
- split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
- if dataset_name == "Legal Domain: unseen questions":
- split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
- if dataset_name == "Legal Domain: unseen answers":
- split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
- return split
- '''
+ # if dataset_name == "Communication Networks: unseen questions":
+ # split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
+ # if dataset_name == "Communication Networks: unseen answers":
+ # split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
+ # if dataset_name == "Micro Job: unseen questions":
+ # split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
+ # if dataset_name == "Micro Job: unseen answers":
+ # split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
+ # if dataset_name == "Legal Domain: unseen questions":
+ # split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
+ # if dataset_name == "Legal Domain: unseen answers":
+ # split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
+ # return split
+

  def get_model(datasetname):
  if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
@@ -50,195 +49,195 @@ def get_model(datasetname):
  model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
  return model

- '''
- def get_tokenizer(datasetname):
- if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
- tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
- if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
- tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
- if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
- tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
- return tokenizer
-
- sacrebleu = load_metric('sacrebleu')
- rouge = load_metric('rouge')
- meteor = load_metric('meteor')
- bertscore = load_metric('bertscore')
-
- # use gpu if it's available
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
- MAX_INPUT_LENGTH = 256
- MAX_TARGET_LENGTH = 128
-
- def preprocess_function(examples, **kwargs):
- """
- Preprocess entries of the given dataset
-
- Params:
- examples (Dataset): dataset to be preprocessed
- Returns:
- model_inputs (BatchEncoding): tokenized dataset entries
- """
-
- inputs, targets = [], []
- for i in range(len(examples['question'])):
- inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
- targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
-
- # apply tokenization to inputs and labels
- tokenizer = kwargs["tokenizer"]
- model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
- labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
-
- model_inputs['labels'] = labels['input_ids']
-
- return model_inputs
-
-
-
- def flatten_list(l):
- """
- Utility function to convert a list of lists into a flattened list
- Params:
- l (list of lists): list to be flattened
- Returns:
- A flattened list with the elements of the original list
- """
- return [item for sublist in l for item in sublist]
-
-
- def extract_feedback(predictions):
- """
- Utility function to extract the feedback from the predictions of the model
- Params:
- predictions (list): complete model predictions
- Returns:
- feedback (list): extracted feedback from the model's predictions
- """
- feedback = []
- # iterate through predictions and try to extract predicted feedback
- for pred in predictions:
- try:
- fb = pred.split(':', 1)[1]
- except IndexError:
- try:
- if pred.lower().startswith('partially correct'):
- fb = pred.split(' ', 1)[2]
- else:
- fb = pred.split(' ', 1)[1]
- except IndexError:
- fb = pred
- feedback.append(fb.strip())
+
+ # def get_tokenizer(datasetname):
+ # if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
+ # tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
+ # if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
+ # tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
+ # if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
+ # tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
+ # return tokenizer
+
+ # sacrebleu = load_metric('sacrebleu')
+ # rouge = load_metric('rouge')
+ # meteor = load_metric('meteor')
+ # bertscore = load_metric('bertscore')
+
+ # # use gpu if it's available
+ # device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # MAX_INPUT_LENGTH = 256
+ # MAX_TARGET_LENGTH = 128
+
+ # def preprocess_function(examples, **kwargs):
+ # """
+ # Preprocess entries of the given dataset
+
+ # Params:
+ # examples (Dataset): dataset to be preprocessed
+ # Returns:
+ # model_inputs (BatchEncoding): tokenized dataset entries
+ # """
+
+ # inputs, targets = [], []
+ # for i in range(len(examples['question'])):
+ # inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
+ # targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
+
+ # # apply tokenization to inputs and labels
+ # tokenizer = kwargs["tokenizer"]
+ # model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
+ # labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
+
+ # model_inputs['labels'] = labels['input_ids']
+
+ # return model_inputs
+
+
+
+ # def flatten_list(l):
+ # """
+ # Utility function to convert a list of lists into a flattened list
+ # Params:
+ # l (list of lists): list to be flattened
+ # Returns:
+ # A flattened list with the elements of the original list
+ # """
+ # return [item for sublist in l for item in sublist]
+
+
+ # def extract_feedback(predictions):
+ # """
+ # Utility function to extract the feedback from the predictions of the model
+ # Params:
+ # predictions (list): complete model predictions
+ # Returns:
+ # feedback (list): extracted feedback from the model's predictions
+ # """
+ # feedback = []
+ # # iterate through predictions and try to extract predicted feedback
+ # for pred in predictions:
+ # try:
+ # fb = pred.split(':', 1)[1]
+ # except IndexError:
+ # try:
+ # if pred.lower().startswith('partially correct'):
+ # fb = pred.split(' ', 1)[2]
+ # else:
+ # fb = pred.split(' ', 1)[1]
+ # except IndexError:
+ # fb = pred
+ # feedback.append(fb.strip())

- return feedback
-
-
- def extract_labels(predictions):
- """
- Utility function to extract the labels from the predictions of the model
- Params:
- predictions (list): complete model predictions
- Returns:
- feedback (list): extracted labels from the model's predictions
- """
- labels = []
- for pred in predictions:
- if pred.lower().startswith('correct'):
- label = 'Correct'
- elif pred.lower().startswith('partially correct'):
- label = 'Partially correct'
- elif pred.lower().startswith('incorrect'):
- label = 'Incorrect'
- else:
- label = 'Unknown label'
- labels.append(label)
+ # return feedback
+
+
+ # def extract_labels(predictions):
+ # """
+ # Utility function to extract the labels from the predictions of the model
+ # Params:
+ # predictions (list): complete model predictions
+ # Returns:
+ # feedback (list): extracted labels from the model's predictions
+ # """
+ # labels = []
+ # for pred in predictions:
+ # if pred.lower().startswith('correct'):
+ # label = 'Correct'
+ # elif pred.lower().startswith('partially correct'):
+ # label = 'Partially correct'
+ # elif pred.lower().startswith('incorrect'):
+ # label = 'Incorrect'
+ # else:
+ # label = 'Unknown label'
+ # labels.append(label)

- return labels
-
-
- def get_predictions_labels(model, dataloader, tokenizer):
- """
- Evaluate model on the given dataset
-
- Params:
- model (PreTrainedModel): seq2seq model
- dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation
- Returns:
- results (dict): dictionary with the computed evaluation metrics
- predictions (list): list of the decoded predictions of the model
- """
- decoded_preds, decoded_labels = [], []
-
- model.eval()
- # iterate through batchs in the dataloader
- for batch in tqdm(dataloader):
- with torch.no_grad():
- batch = {k: v.to(device) for k, v in batch.items()}
- # generate tokens from batch
- generated_tokens = model.generate(
- batch['input_ids'],
- attention_mask=batch['attention_mask'],
- max_length=MAX_TARGET_LENGTH
- )
- # get golden labels from batch
- labels_batch = batch['labels']
+ # return labels
+
+
+ # def get_predictions_labels(model, dataloader, tokenizer):
+ # """
+ # Evaluate model on the given dataset
+
+ # Params:
+ # model (PreTrainedModel): seq2seq model
+ # dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation
+ # Returns:
+ # results (dict): dictionary with the computed evaluation metrics
+ # predictions (list): list of the decoded predictions of the model
+ # """
+ # decoded_preds, decoded_labels = [], []
+
+ # model.eval()
+ # # iterate through batchs in the dataloader
+ # for batch in tqdm(dataloader):
+ # with torch.no_grad():
+ # batch = {k: v.to(device) for k, v in batch.items()}
+ # # generate tokens from batch
+ # generated_tokens = model.generate(
+ # batch['input_ids'],
+ # attention_mask=batch['attention_mask'],
+ # max_length=MAX_TARGET_LENGTH
+ # )
+ # # get golden labels from batch
+ # labels_batch = batch['labels']

- # decode model predictions and golden labels
- decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
- decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
+ # # decode model predictions and golden labels
+ # decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ # decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)

- decoded_preds.append(decoded_preds_batch)
- decoded_labels.append(decoded_labels_batch)
+ # decoded_preds.append(decoded_preds_batch)
+ # decoded_labels.append(decoded_labels_batch)

- # convert predictions and golden labels into flattened lists
- predictions = flatten_list(decoded_preds)
- labels = flatten_list(decoded_labels)
+ # # convert predictions and golden labels into flattened lists
+ # predictions = flatten_list(decoded_preds)
+ # labels = flatten_list(decoded_labels)

- return predictions, labels
+ # return predictions, labels


- def load_data():
- df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
- for ds in all_datasets:
- split = get_split(ds)
- model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds))
- tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))
+ # def load_data():
+ # df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
+ # for ds in all_datasets:
+ # split = get_split(ds)
+ # model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds))
+ # tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))

- processed_dataset = split.map(
- preprocess_function,
- batched=True,
- remove_columns=split.column_names,
- fn_kwargs={"tokenizer": tokenizer}
- )
- processed_dataset.set_format('torch')
+ # processed_dataset = split.map(
+ # preprocess_function,
+ # batched=True,
+ # remove_columns=split.column_names,
+ # fn_kwargs={"tokenizer": tokenizer}
+ # )
+ # processed_dataset.set_format('torch')

- dataloader = DataLoader(processed_dataset, batch_size=4)
+ # dataloader = DataLoader(processed_dataset, batch_size=4)

- predictions, labels = get_predictions_labels(model, dataloader, tokenizer)
+ # predictions, labels = get_predictions_labels(model, dataloader, tokenizer)

- predicted_feedback = extract_feedback(predictions)
- predicted_labels = extract_labels(predictions)
+ # predicted_feedback = extract_feedback(predictions)
+ # predicted_labels = extract_labels(predictions)

- reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
- reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
+ # reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
+ # reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]

- rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
- bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
- meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
- bert_score = bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)
+ # rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
+ # bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
+ # meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
+ # bert_score = bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)

- reference_labels_np = np.array(reference_labels)
- accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
- f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
- f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])
+ # reference_labels_np = np.array(reference_labels)
+ # accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
+ # f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
+ # f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])

- new_row_data = {"Model": get_model(ds), "Dataset": ds, "SacreBLEU": bleu_score, "ROUGE-2": rouge_score, "METEOR": meteor_score, "BERTScore": bert_score, "Accuracy": accuracy_value, "Weighted F1": f1_weighted_value, "Macro F1": f1_macro_value}
- new_row = pd.DataFrame(new_row_data)
+ # new_row_data = {"Model": get_model(ds), "Dataset": ds, "SacreBLEU": bleu_score, "ROUGE-2": rouge_score, "METEOR": meteor_score, "BERTScore": bert_score, "Accuracy": accuracy_value, "Weighted F1": f1_weighted_value, "Macro F1": f1_macro_value}
+ # new_row = pd.DataFrame(new_row_data)

- df = pd.concat([df, new_row])
- return df
- '''
+ # df = pd.concat([df, new_row])
+ # return df
+

  def get_rows(datasetname):
  if datasetname == "Communication Networks: unseen questions":
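Note: the commented-out block above still calls load_metric, which has since been deprecated and removed from the datasets library in favour of the standalone evaluate package. A minimal sketch of the equivalent metric setup, should this code ever be revived (illustrative suggestion, not part of this commit; assumes evaluate plus the sacrebleu, rouge_score, and bert_score backends are installed):

import evaluate  # replacement for datasets.load_metric

# Same metric names as the commented-out load_metric() calls above.
sacrebleu = evaluate.load('sacrebleu')
rouge = evaluate.load('rouge')
meteor = evaluate.load('meteor')
bertscore = evaluate.load('bertscore')

The compute() signatures match the old metric objects, though some return formats differ slightly (for example, evaluate's rouge returns plain floats rather than aggregate score objects, so the ['rouge2'] lookup would yield the score directly).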