zhenyundeng committed
Commit 0fa98b8 · 1 Parent(s): 7168c2f
Files changed (2):
  1. app.py +75 -233
  2. requirements.txt +3 -2
app.py CHANGED
@@ -43,7 +43,7 @@ try:
 except Exception as e:
     pass
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
+# os.environ["TOKENIZERS_PARALLELISM"] = "false"
 account_url = os.environ["AZURE_ACCOUNT_URL"]
 credential = {
     "account_key": os.environ['AZURE_ACCOUNT_KEY'],
@@ -93,30 +93,38 @@ LABEL = [
     "Conflicting Evidence/Cherrypicking",
 ]
 
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-# Question Generation
-qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
-qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to(device)
-# qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
-# qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
-
-# rerank
-rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-rereank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
-best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
-rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model).to(device)
-
-# Veracity
-veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
-veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
-                                                                   tokenizer=veracity_tokenizer, model=bert_model).to(device)
-# Justification
-justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
-bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
-best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
-justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
-# ---------------------------------------------------------------------------
+if torch.cuda.is_available():
+    # # device
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    # question generation
+    qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16)
+    # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+    # qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+
+    # rerank
+    rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    rereank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
+    best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
+    rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model)
+    # rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model).to(device)
+
+    # Veracity
+    veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
+    veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model)
+    # veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model).to(device)
+
+    # Justification
+    justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
+    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+    best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+    justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model)
+    # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
+
 
 # Set up Gradio Theme
 theme = gr.themes.Base(
@@ -124,9 +132,9 @@ theme = gr.themes.Base(
     secondary_hue="red",
     font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
-
 # ---------- Setting ----------
 
+
 class Docs:
     def __init__(self, metadata=dict(), page_content=""):
         self.metadata = metadata
@@ -184,6 +192,7 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
 
         return input_ids, attention_masks
 
+
    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        if bool_explanation is not None and len(bool_explanation) > 0:
            bool_explanation = ", because " + bool_explanation.lower().strip()
@@ -200,91 +209,8 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
        )
 
 
-def google_veracity_prediction(claim, qa_evidence):
-    bert_model_name = "bert-base-uncased"
-    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
-    bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
-                                                               problem_type="single_label_classification")
-
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
-                                                                      tokenizer=tokenizer, model=bert_model).to(device)
-
-    dataLoader = SequenceClassificationDataLoader(
-        tokenizer=tokenizer,
-        data_file="this_is_discontinued",
-        batch_size=32,
-        add_extra_nee=False,
-    )
-
-    evidence_strings = []
-    for evidence in qa_evidence:
-        evidence_strings.append(
-            dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
-
-    if len(evidence_strings) == 0:  # If we found no evidence e.g. because google returned 0 pages, just output NEI.
-        pred_label = "Not Enough Evidence"
-        return pred_label
-
-    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
-    example_support = torch.argmax(
-        trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
-
-    has_unanswerable = False
-    has_true = False
-    has_false = False
-
-    for v in example_support:
-        if v == 0:
-            has_true = True
-        if v == 1:
-            has_false = True
-        if v in (2, 3,):  # TODO another hack -- we cant have different labels for train and test so we do this
-            has_unanswerable = True
-
-    if has_unanswerable:
-        answer = 2
-    elif has_true and not has_false:
-        answer = 0
-    elif not has_true and has_false:
-        answer = 1
-    else:
-        answer = 3
-
-    pred_label = LABEL[answer]
-
-    return pred_label
-
-
-def fever_veracity_prediction(claim, evidence):
-    tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check')
-    model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check')
-
-    evidence_string = ""
-    for evi in evidence:
-        evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' '
-
-    input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt")
-    with torch.no_grad():
-        prediction = model(**input_sequence)
-
-    label = torch.argmax(prediction[0]).item()
-    pred_label = LABEL[label]
-
-    return pred_label
-
-
 @spaces.GPU
 def veracity_prediction(claim, qa_evidence):
-    # bert_model_name = "bert-base-uncased"
-    # tokenizer = BertTokenizer.from_pretrained(bert_model_name)
-    # bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
-    #                                                            problem_type="single_label_classification")
-    #
-    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    # trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
-    #                                                                   tokenizer=tokenizer, model=bert_model).to(device)
-
     dataLoader = SequenceClassificationDataLoader(
         tokenizer=veracity_tokenizer,
         data_file="this_is_discontinued",
@@ -302,7 +228,8 @@ def veracity_prediction(claim, qa_evidence):
        return pred_label
 
    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
-    example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+    example_support = torch.argmax(veracity_model(tokenized_strings.to(veracity_model.device), attention_mask=attention_mask.to(veracity_model.device)).logits, axis=1)
+    # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
 
    has_unanswerable = False
    has_true = False
@@ -330,6 +257,7 @@ def veracity_prediction(claim, qa_evidence):
    return pred_label
 
 
+@spaces.GPU
 def extract_claim_str(claim, qa_evidence, verdict_label):
    claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
 
@@ -359,43 +287,43 @@ def extract_claim_str(claim, qa_evidence, verdict_label):
    return claim_str
 
 
-def google_justification_generation(claim, qa_evidence, verdict_label):
-    #
-    claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
-    claim_str.strip()
-
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
-    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
-
-    best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
-    trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
-                                                                       model=bart_model).to(device)
-
-    pred_justification = trained_model.generate(claim_str, device=device)
-
-    return pred_justification.strip()
-
-
 @spaces.GPU
 def justification_generation(claim, qa_evidence, verdict_label):
    #
-    claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
-    claim_str.strip()
-
-    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    # tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
-    # bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
-    #
-    # best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
-    # trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
-    #                                                                    model=bart_model).to(device)
-
-    pred_justification = justification_model.generate(claim_str, device=device)
+    # claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
+    claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
+
+    for evi in qa_evidence:
+        q_text = evi.metadata['query'].strip()
+
+        if len(q_text) == 0:
+            continue
+
+        if not q_text[-1] == "?":
+            q_text += "?"
+
+        answer_strings = []
+        answer_strings.append(evi.metadata['answer'])
+
+        claim_str += q_text
+        for a_text in answer_strings:
+            if a_text:
+                if not a_text[-1] == ".":
+                    a_text += "."
+                claim_str += " " + a_text.strip()
+
+        claim_str += " "
+
+    claim_str += " [VERDICT] " + verdict_label
+    #
+    claim_str.strip()
+
+    pred_justification = justification_model.generate(claim_str, device=justification_model.device)
+    # pred_justification = justification_model.generate(claim_str, device=device)
 
    return pred_justification.strip()
 
-
+@spaces.GPU
 def QAprediction(claim, evidence, sources):
    parts = []
    #
@@ -493,9 +421,9 @@ def prompt_question_generation(test_claim, speaker="they", topk=10):
        "\". Criticism includes questions like: "
    sentences = [prompt]
 
-    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
-    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
-                                early_stopping=True)
+    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
+    # inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
 
    tgt_text = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    in_len = len(sentences[0])
@@ -629,96 +557,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_dat
    return search_results
 
 
-def averitec_search_michael(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
-    # default config
-    api_key = os.environ["GOOGLE_API_KEY"]
-    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
-
-    blacklist = [
-        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
-        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
-        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
-        "nlp.cs.princeton.edu",
-        "huggingface.co"
-    ]
-
-    blacklist_files = [  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
-        "/glove.",
-        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
-        "https://web.mit.edu/adamrose/Public/googlelist",
-    ]
-
-    # save to folder
-    store_folder = "averitec/data/store/retrieved_docs"
-    #
-    index = 0
-    questions = [q["question"] for q in generate_question]
-
-    # check the date of the claim
-    current_date = datetime.now().strftime("%Y-%m-%d")
-    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"
-
-    #
-    search_strings = []
-    search_types = []
-
-    search_string_2 = string_to_search_query(claim, None)
-    search_strings += [search_string_2, claim, ]
-    search_types += ["claim", "claim-noformat", ]
-
-    search_strings += questions
-    search_types += ["question" for _ in questions]
-
-    # start to search
-    search_results = []
-    visited = {}
-    store_counter = 0
-    worker_stack = list(range(10))
-
-    retrieve_evidence = []
-
-    for this_search_string, this_search_type in zip(search_strings, search_types):
-        for page_num in range(n_pages):
-            search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
-                                                       this_search_string, page=page_num)
-
-            for result in search_results:
-                link = str(result["link"])
-                domain = get_domain_name(link)
-
-                if domain in blacklist:
-                    continue
-                broken = False
-                for b_file in blacklist_files:
-                    if b_file in link:
-                        broken = True
-                if broken:
-                    continue
-                if link.endswith(".pdf") or link.endswith(".doc"):
-                    continue
-
-                if link in visited:
-                    store_file_path = visited[link]
-                else:
-                    store_counter += 1
-                    store_file_path = store_folder + "/search_result_" + str(index) + "_" + str(
-                        store_counter) + ".store"
-                    visited[link] = store_file_path
-
-                    while len(worker_stack) == 0:  # Wait for a worker to become available. Check every second.
-                        sleep(1)
-
-                    worker = worker_stack.pop()
-
-                    t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack))
-                    t.start()
-
-                line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path]
-                retrieve_evidence.append(line)
-
-    return retrieve_evidence
-
-
+@spaces.GPU
 def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
    # default config
    api_key = os.environ["GOOGLE_API_KEY"]
@@ -848,7 +687,7 @@ def generate_step2_reference_corpus(reference_file):
 
    return tokenized_corpus, prompt_corpus
 
-
+@spaces.GPU
 def decorate_with_questions(claim, retrieve_evidence, top_k=5):  # top_k=10, 100
    #
    reference_file = "averitec/data/train.json"
@@ -916,9 +755,9 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=5):  # top_k=10, 100
    prompt = "\n\n".join(prompt_docs + [claim_prompt])
    sentences = [prompt]
 
-    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
-    outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
-                                early_stopping=True)
+    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
+    # inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+    outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
 
    tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
    # We are not allowed to generate more than 250 characters:
@@ -1034,8 +873,8 @@ def rerank_questions(claim, bm25_qas, topk=3):
        values.append([question, answer, source])
 
    if len(bm25_qas) > 0:
-        encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
-                                        return_tensors="pt").to(device)
+        encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(rerank_trained_model.device)
+        # encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(device)
 
        input_ids = encoded_dict['input_ids']
        attention_masks = encoded_dict['attention_mask']
@@ -1052,6 +891,7 @@ def rerank_questions(claim, bm25_qas, topk=3):
    return top3_qa_pairs
 
 
+@spaces.GPU
 def Googleretriever(query, sources):
    # ----- Generate QA pairs using AVeriTeC
    # step 1: generate questions for the query/claim using Bloom
@@ -1207,6 +1047,7 @@ def WikipediaDumpsretriever(claim):
 
    return results
 
+
 # ----------WikipediaAPIretriever---------
 def clean_str(p):
    return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
@@ -1556,6 +1397,7 @@ def main():
    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
    demo.queue()
 
+    # demo.launch()
    demo.launch(share=True)
 
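The app.py changes above converge on one pattern: checkpoints are loaded once at module scope, each GPU entry point is decorated with @spaces.GPU, and tensors follow model.device rather than a hard-coded "cuda:0". This is the usual ZeroGPU setup for Hugging Face Spaces. Below is a minimal sketch of that pattern under stated assumptions; the sentiment model name and the classify() helper are illustrative stand-ins, not code from this commit.

# Minimal ZeroGPU-style sketch. The model name and classify() are
# illustrative stand-ins, not the AVeriTeC checkpoints used above.
import spaces
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load once at import time; ZeroGPU attaches a GPU only while a
# @spaces.GPU-decorated function is executing.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

@spaces.GPU
def classify(text: str) -> str:
    # Follow the model's current placement, mirroring the commit's
    # tokenized_strings.to(veracity_model.device) change, instead of
    # assuming "cuda:0" exists at import time.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    return model.config.id2label[int(logits.argmax(dim=-1))]

On ZeroGPU hardware the decorator requests a GPU only for the duration of the call; elsewhere the spaces package treats it as a pass-through, so the same code should also run on CPU-only machines.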
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 gradio
-nltk==3.8.1
+nltk
 rank_bm25
 accelerate
 trafilatura
-spacy==3.7.5
+spacy
 pytorch_lightning
 transformers==4.29.2
+SentencePiece
 datasets
 leven
 scikit-learn