legacy107 commited on
Commit
8c7ff32
·
1 Parent(s): 0811f96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -99,7 +99,7 @@ def clean_data(text):
99
  return cleaned_text
100
 
101
 
102
- def paraphrase_answer(question, answer):
103
  # Combine question and context
104
  input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
105
 
@@ -114,7 +114,10 @@ def paraphrase_answer(question, answer):
114
 
115
  # Generate the answer
116
  with torch.no_grad():
117
- generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
 
 
 
118
 
119
  # Decode and return the generated answer
120
  paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
@@ -174,7 +177,12 @@ def generate_answer(question, context, ground, do_pretrained, do_natural):
174
  pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
175
  pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
176
 
177
- return generated_answer, context, paraphrased_answer, pretrained_answer
 
 
 
 
 
178
 
179
 
180
  # Define a function to list examples from the dataset
@@ -184,7 +192,7 @@ def list_examples():
184
  context = example["context"]
185
  question = example["question"]
186
  answer = example["answer"]
187
- examples.append([question, context, answer, True, True])
188
  return examples
189
 
190
 
@@ -196,13 +204,15 @@ iface = gr.Interface(
196
  Textbox(label="Context"),
197
  Textbox(label="Ground truth"),
198
  Checkbox(label="Include pretrained model's result"),
199
- Checkbox(label="Include natural answer")
 
200
  ],
201
  outputs=[
202
  Textbox(label="Generated Answer"),
203
  Textbox(label="Retrieved Context"),
204
  Textbox(label="Natural Answer"),
205
- Textbox(label="Pretrained Model's Answer")
 
206
  ],
207
  examples=list_examples(),
208
  examples_per_page=1,
 
99
  return cleaned_text
100
 
101
 
102
+ def paraphrase_answer(question, answer, use_pretrained=False):
103
  # Combine question and context
104
  input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
105
 
 
114
 
115
  # Generate the answer
116
  with torch.no_grad():
117
+ if use_pretrained:
118
+ generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
119
+ else:
120
+ generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
121
 
122
  # Decode and return the generated answer
123
  paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
177
  pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
178
  pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
179
 
180
+ # Get pretrained model's natural answer
181
+ pretrained_paraphrased_answer = ""
182
+ if do_pretrained_natural:
183
+ pretrained_paraphrased_answer = paraphrase_answer(question, generated_answer, True)
184
+
185
+ return generated_answer, context, paraphrased_answer, pretrained_answer, pretrained_paraphrased_answer
186
 
187
 
188
  # Define a function to list examples from the dataset
 
192
  context = example["context"]
193
  question = example["question"]
194
  answer = example["answer"]
195
+ examples.append([question, context, answer, True, True, True])
196
  return examples
197
 
198
 
 
204
  Textbox(label="Context"),
205
  Textbox(label="Ground truth"),
206
  Checkbox(label="Include pretrained model's result"),
207
+ Checkbox(label="Include natural answer"),
208
+ Checkbox(label="Include pretrained model's natural answer")
209
  ],
210
  outputs=[
211
  Textbox(label="Generated Answer"),
212
  Textbox(label="Retrieved Context"),
213
  Textbox(label="Natural Answer"),
214
+ Textbox(label="Pretrained Model's Answer"),
215
+ Textbox(label="Pretrained Model's Natural Answer")
216
  ],
217
  examples=list_examples(),
218
  examples_per_page=1,