Update app.py
app.py CHANGED
@@ -288,7 +288,7 @@
 
 import torch
 import transformers
-from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
+from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
 import gradio as gr
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -304,11 +304,12 @@ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name=
 rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
 rag_model.retriever.init_retrieval()
 rag_model.to(device)
+
 pipe = pipeline(
     "text-generation",
     model="google/gemma-2-2b-it",
     model_kwargs={"torch_dtype": torch.bfloat16},
-    device=device,
+    device=device,
 )
 
 def strip_title(title):
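
For context, here is a minimal sketch of how the updated import and pipeline setup fit together after this commit. It is based only on the lines visible in the diff; the prompt string and max_new_tokens value are illustrative assumptions, not part of app.py, and google/gemma-2-2b-it is a gated model that requires accepting its license on the Hub.

import torch
from transformers import pipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Generation pipeline as configured in the diff; passing device=device keeps it
# on the same device as the RAG model set up earlier in app.py.
pipe = pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
)

# Hypothetical usage (not in the commit): generate a short completion.
out = pipe("Question: What does RAG stand for? Answer:", max_new_tokens=32)
print(out[0]["generated_text"])

Importing pipeline explicitly from transformers (rather than relying on the bare import transformers) is what lets the pipe = pipeline(...) call at line 308 resolve without a NameError.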