legacy107 commited on
Commit
51eb71a
·
1 Parent(s): 8e1c673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -96,6 +96,9 @@ def generate_answer(question, context, ground, do_pretrained):
96
  max_length=max_length,
97
  ).input_ids
98
 
 
 
 
99
  # Generate the answer
100
  with torch.no_grad():
101
  generated_ids = model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
@@ -110,7 +113,7 @@ def generate_answer(question, context, ground, do_pretrained):
110
  pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
111
  pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
112
 
113
- return generated_answer, context, pretrained_answer
114
 
115
 
116
  # Define a function to list examples from the dataset
 
96
  max_length=max_length,
97
  ).input_ids
98
 
99
+ # Decode the context back
100
+ decoded_context = tokenizer.decode(input_ids, skip_special_tokens=True)[len(f"question: {question} context: "):]
101
+
102
  # Generate the answer
103
  with torch.no_grad():
104
  generated_ids = model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
 
113
  pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
114
  pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
115
 
116
+ return generated_answer, decoded_context, pretrained_answer
117
 
118
 
119
  # Define a function to list examples from the dataset