zhenyundeng committed
Commit 7168c2f · 1 Parent(s): 8d2d2b1
Files changed (2)
  1. app.py +81 -76
  2. requirements.txt +3 -2
app.py CHANGED
@@ -15,6 +15,7 @@ import json
 import pytorch_lightning as pl
 from urllib.parse import urlparse
 from accelerate import Accelerator
+import spaces
 
 from transformers import BartTokenizer, BartForConditionalGeneration
 from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
@@ -273,6 +274,7 @@ def fever_veracity_prediction(claim, evidence)
     return pred_label
 
 
+@spaces.GPU
 def veracity_prediction(claim, qa_evidence):
     # bert_model_name = "bert-base-uncased"
     # tokenizer = BertTokenizer.from_pretrained(bert_model_name)
@@ -375,6 +377,7 @@ def google_justification_generation(claim, qa_evidence, verdict_label):
     return pred_justification.strip()
 
 
+@spaces.GPU
 def justification_generation(claim, qa_evidence, verdict_label):
     #
     claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
@@ -465,6 +468,7 @@ def docs2prompt(top_docs):
     return "\n\n".join([doc2prompt(d) for d in top_docs])
 
 
+@spaces.GPU
 def prompt_question_generation(test_claim, speaker="they", topk=10):
     #
     reference_file = "averitec/data/train.json"
@@ -926,88 +930,89 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=5): # top_k=10, 100
     return generate_qa_pairs
 
 
-def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10): # top_k=100
-    #
-    reference_file = "averitec/data/train.json"
-    tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
-    prompt_bm25 = BM25Okapi(tokenized_corpus)
-
-    # Define the bloom model:
-    accelerator = Accelerator()
-    accel_device = accelerator.device
-    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
-    # model = BloomForCausalLM.from_pretrained(
-    #     "bigscience/bloom-7b1",
-    #     device_map="auto",
-    #     torch_dtype=torch.bfloat16,
-    #     offload_folder="./offload"
-    # )
-
-    #
-    tokenized_corpus = []
-    all_data_corpus = []
-
-    for retri_evi in tqdm.tqdm(retrieve_evidence):
-        store_file = retri_evi[-1]
-
-        with open(store_file, 'r') as f:
-            first = True
-            for line in f:
-                line = line.strip()
-
-                if first:
-                    first = False
-                    location_url = line
-                    continue
-
-                if len(line) > 3:
-                    entry = nltk.word_tokenize(line)
-                    if (location_url, line) not in all_data_corpus:
-                        tokenized_corpus.append(entry)
-                        all_data_corpus.append((location_url, line))
-
-    if len(tokenized_corpus) == 0:
-        print("")
-
-    bm25 = BM25Okapi(tokenized_corpus)
-    s = bm25.get_scores(nltk.word_tokenize(claim))
-    top_n = np.argsort(s)[::-1][:top_k]
-    docs = [all_data_corpus[i] for i in top_n]
-
-    generate_qa_pairs = []
-    # Then, generate questions for those top 50:
-    for doc in tqdm.tqdm(docs):
-        # prompt_lookup_str = example["claim"] + " " + doc[1]
-        prompt_lookup_str = doc[1]
-
-        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
-        prompt_n = 10
-        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
-        prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
-
-        claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
-        prompt = "\n\n".join(prompt_docs + [claim_prompt])
-        sentences = [prompt]
-
-        inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
-        outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
-                                    early_stopping=True)
-
-        tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
-        # We are not allowed to generate more than 250 characters:
-        tgt_text = tgt_text[:250]
-
-        qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
-        generate_qa_pairs.append(qa_pair)
-
-    return generate_qa_pairs
+# def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10): # top_k=100
+#     #
+#     reference_file = "averitec/data/train.json"
+#     tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
+#     prompt_bm25 = BM25Okapi(tokenized_corpus)
+#
+#     # Define the bloom model:
+#     accelerator = Accelerator()
+#     accel_device = accelerator.device
+#     # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+#     # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+#     # model = BloomForCausalLM.from_pretrained(
+#     #     "bigscience/bloom-7b1",
+#     #     device_map="auto",
+#     #     torch_dtype=torch.bfloat16,
+#     #     offload_folder="./offload"
+#     # )
+#
+#     #
+#     tokenized_corpus = []
+#     all_data_corpus = []
+#
+#     for retri_evi in tqdm.tqdm(retrieve_evidence):
+#         store_file = retri_evi[-1]
+#
+#         with open(store_file, 'r') as f:
+#             first = True
+#             for line in f:
+#                 line = line.strip()
+#
+#                 if first:
+#                     first = False
+#                     location_url = line
+#                     continue
+#
+#                 if len(line) > 3:
+#                     entry = nltk.word_tokenize(line)
+#                     if (location_url, line) not in all_data_corpus:
+#                         tokenized_corpus.append(entry)
+#                         all_data_corpus.append((location_url, line))
+#
+#     if len(tokenized_corpus) == 0:
+#         print("")
+#
+#     bm25 = BM25Okapi(tokenized_corpus)
+#     s = bm25.get_scores(nltk.word_tokenize(claim))
+#     top_n = np.argsort(s)[::-1][:top_k]
+#     docs = [all_data_corpus[i] for i in top_n]
+#
+#     generate_qa_pairs = []
+#     # Then, generate questions for those top 50:
+#     for doc in tqdm.tqdm(docs):
+#         # prompt_lookup_str = example["claim"] + " " + doc[1]
+#         prompt_lookup_str = doc[1]
+#
+#         prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
+#         prompt_n = 10
+#         prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
+#         prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
+#
+#         claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
+#         prompt = "\n\n".join(prompt_docs + [claim_prompt])
+#         sentences = [prompt]
+#
+#         inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+#         outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
+#                                     early_stopping=True)
+#
+#         tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
+#         # We are not allowed to generate more than 250 characters:
+#         tgt_text = tgt_text[:250]
+#
+#         qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
+#         generate_qa_pairs.append(qa_pair)
+#
+#     return generate_qa_pairs
 
 
 def triple_to_string(x):
     return " </s> ".join([item.strip() for item in x])
 
 
+@spaces.GPU
 def rerank_questions(claim, bm25_qas, topk=3):
     #
     # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
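
The substance of the app.py change is small: an `import spaces` plus a `@spaces.GPU` decorator on each of the four model-calling functions (veracity_prediction, justification_generation, prompt_question_generation, rerank_questions), with the unused decorate_with_questions_michale variant commented out rather than deleted. `@spaces.GPU` is the Hugging Face ZeroGPU pattern: the Space is granted a GPU only while a decorated function runs. A minimal sketch of that pattern, using an illustrative function that is not taken from the app:

import spaces
import torch

@spaces.GPU  # a runtime hint is also accepted, e.g. @spaces.GPU(duration=120)
def gpu_task(n: int) -> float:
    # On ZeroGPU hardware a GPU is attached for the duration of this call;
    # outside a ZeroGPU Space the decorator is effectively a no-op.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.ones(n, n, device=device).sum().item()

print(gpu_task(4))  # illustrative usage; requires the `spaces` package

This decorator is also why `spaces` is appended to requirements.txt below.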
requirements.txt CHANGED
@@ -1,9 +1,9 @@
 gradio
-nltk
+nltk==3.8.1
 rank_bm25
 accelerate
 trafilatura
-spacy
+spacy==3.7.5
 pytorch_lightning
 transformers==4.29.2
 datasets
@@ -21,3 +21,4 @@ azure-storage-blob
 bm25s
 PyStemmer
 lxml_html_clean
+spaces
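
Besides adding `spaces` for the new decorator, the commit pins nltk and spacy to fixed versions rather than floating releases. One runtime detail the diff does not show: app.py relies on nltk.word_tokenize, which in nltk 3.8.1 needs the punkt tokenizer data to be present. A minimal setup sketch, assuming the download happens once at startup:

import nltk

# word_tokenize requires the punkt model at runtime (nltk 3.8.1);
# fetch it once, e.g. when the Space boots.
nltk.download("punkt", quiet=True)
print(nltk.word_tokenize("A short sanity check."))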