zhenyundeng commited on
Commit
8c5fc49
·
1 Parent(s): 0fa98b8

upadte app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -99,7 +99,7 @@ if torch.cuda.is_available():
99
 
100
  # question generation
101
  qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
102
- qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16)
103
  # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
104
  # qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
105
  # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
@@ -557,7 +557,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_dat
557
  return search_results
558
 
559
 
560
- @spaces.GPU
561
  def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1): # n_pages=3
562
  # default config
563
  api_key = os.environ["GOOGLE_API_KEY"]
@@ -688,15 +688,15 @@ def generate_step2_reference_corpus(reference_file):
688
  return tokenized_corpus, prompt_corpus
689
 
690
  @spaces.GPU
691
- def decorate_with_questions(claim, retrieve_evidence, top_k=5): # top_k=10, 100
692
  #
693
  reference_file = "averitec/data/train.json"
694
  tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
695
  prompt_bm25 = BM25Okapi(tokenized_corpus)
696
 
697
  # Define the bloom model:
698
- accelerator = Accelerator()
699
- accel_device = accelerator.device
700
  # device = "cuda:0" if torch.cuda.is_available() else "cpu"
701
  # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
702
  # model = BloomForCausalLM.from_pretrained(
 
99
 
100
  # question generation
101
  qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
102
+ qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to('cuda')
103
  # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
104
  # qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
105
  # qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
 
557
  return search_results
558
 
559
 
560
+ # @spaces.GPU
561
  def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1): # n_pages=3
562
  # default config
563
  api_key = os.environ["GOOGLE_API_KEY"]
 
688
  return tokenized_corpus, prompt_corpus
689
 
690
  @spaces.GPU
691
+ def decorate_with_questions(claim, retrieve_evidence, top_k=3): # top_k=5, 10, 100
692
  #
693
  reference_file = "averitec/data/train.json"
694
  tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
695
  prompt_bm25 = BM25Okapi(tokenized_corpus)
696
 
697
  # Define the bloom model:
698
+ # accelerator = Accelerator()
699
+ # accel_device = accelerator.device
700
  # device = "cuda:0" if torch.cuda.is_available() else "cpu"
701
  # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
702
  # model = BloomForCausalLM.from_pretrained(