Baweja committed (verified) · Commit ee6ab98 · 1 Parent(s): d7e5236

Update app.py

Files changed (1): app.py (+2 −2)
app.py CHANGED

@@ -17,7 +17,7 @@ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name=
 rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
 rag_model.retriever.init_retrieval()
 rag_model.to(device)
-model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it',
+model = AutoModelForCausalLM.from_pretrained('HuggingFaceH4/zephyr-7b-beta',
                                              device_map = 'auto',
                                              torch_dtype = torch.bfloat16,
                                              )

@@ -73,7 +73,7 @@ def retrieved_info(query, rag_model = rag_model, generating_model = model):
     generation_model_input = input_format(query, retrieved_context)

     # Generating answer using gemma model
-    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
+    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
     input_ids = tokenizer(generation_model_input, return_tensors='pt').to(device)
     output = generating_model.generate(input_ids, max_new_tokens = 512)
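
The commit swaps the generator from google/gemma-2-2b-it to HuggingFaceH4/zephyr-7b-beta in the two places the checkpoint id appears: the AutoModelForCausalLM load and the AutoTokenizer load. (The "# Generating answer using gemma model" comment is a context line the commit does not touch, so it now reads stale.) For reference, a minimal runnable sketch of this generation path; the prompt string is a hypothetical stand-in for the input_format(query, retrieved_context) helper defined elsewhere in app.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'HuggingFaceH4/zephyr-7b-beta'  # the checkpoint this commit switches to

# Same load as the first hunk of the diff.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
# The tokenizer must match the generator checkpoint, which is why the
# second hunk changes in lockstep with the first.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Hypothetical prompt; in app.py this comes from
# input_format(query, retrieved_context).
prompt = "Context: ...\nQuestion: ...\nAnswer:"

inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
# Unpack the encoding so generate() receives input_ids and
# attention_mask as tensors.
output = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Keeping the checkpoint id in a single variable avoids the model and tokenizer strings drifting apart, which is exactly the paired edit this commit has to make by hand in two hunks.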