Update app.py

app.py
@@ -17,7 +17,7 @@ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name=
 rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
 rag_model.retriever.init_retrieval()
 rag_model.to(device)
-model = AutoModelForCausalLM.from_pretrained('
+model = AutoModelForCausalLM.from_pretrained('HuggingFaceH4/zephyr-7b-beta',
     device_map = 'auto',
     torch_dtype = torch.bfloat16,
 )
@@ -73,7 +73,7 @@ def retrieved_info(query, rag_model = rag_model, generating_model = model):
     generation_model_input = input_format(query, retrieved_context)

     # Generating answer using gemma model
-    tokenizer = AutoTokenizer.from_pretrained("
+    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
     input_ids = tokenizer(generation_model_input, return_tensors='pt').to(device)
     output = generating_model.generate(input_ids, max_new_tokens = 512)
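For context, below is a minimal sketch of how both hunks fit into the app's retrieve-then-generate flow. Only the lines shown in the diff come from app.py; everything else is an assumption or a placeholder, including the index_name value, the input_format prompt template, and the retrieval step (which the diff elides).

# Minimal sketch of the pipeline around the changed lines; see the
# assumptions noted above.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Retrieval side, from the diff's context lines
# (index_name='exact' is an assumption; the hunk header is truncated).
retriever = RagRetriever.from_pretrained('facebook/rag-sequence-nq', index_name='exact')
rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
rag_model.retriever.init_retrieval()
rag_model.to(device)

# Generation side: the model this commit swaps in.
model = AutoModelForCausalLM.from_pretrained(
    'HuggingFaceH4/zephyr-7b-beta',
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

def input_format(query, context):
    # Hypothetical prompt template; the real one is defined elsewhere in app.py.
    return f"Answer the question using the context.\n\nContext: {context}\n\nQuestion: {query}\n\nAnswer:"

def retrieved_info(query, rag_model=rag_model, generating_model=model):
    # The retrieval of retrieved_context via rag_model is outside the diff,
    # so it is stubbed out here.
    retrieved_context = "..."
    generation_model_input = input_format(query, retrieved_context)
    inputs = tokenizer(generation_model_input, return_tensors='pt').to(generating_model.device)
    # tokenizer(...) returns a BatchEncoding, so unpack it; generate() expects tensors.
    output = generating_model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(output[0], skip_special_tokens=True)

Two details worth flagging in the committed code itself: tokenizer(...) returns a BatchEncoding rather than a tensor, so the input_ids = tokenizer(...) line pairs more safely with generate(**inputs, ...) or tokenizer(...).input_ids than with passing the encoding positionally; and with device_map='auto' the generator's weights may not sit on the fixed device variable, so moving inputs to generating_model.device is the more robust choice. Note also that the "# Generating answer using gemma model" comment is now stale, since this commit makes zephyr-7b-beta the generator.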