Baweja commited on
Commit
579a454
·
verified ·
1 Parent(s): d1bed13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -288,7 +288,7 @@
288
 
289
  import torch
290
  import transformers
291
- from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
292
  import gradio as gr
293
 
294
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -304,11 +304,12 @@ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name=
304
  rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
305
  rag_model.retriever.init_retrieval()
306
  rag_model.to(device)
 
307
  pipe = pipeline(
308
  "text-generation",
309
  model="google/gemma-2-2b-it",
310
  model_kwargs={"torch_dtype": torch.bfloat16},
311
- device=device, # replace with "mps" to run on a Mac device
312
  )
313
 
314
  def strip_title(title):
 
288
 
289
  import torch
290
  import transformers
291
+ from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
292
  import gradio as gr
293
 
294
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
304
  rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
305
  rag_model.retriever.init_retrieval()
306
  rag_model.to(device)
307
+
308
  pipe = pipeline(
309
  "text-generation",
310
  model="google/gemma-2-2b-it",
311
  model_kwargs={"torch_dtype": torch.bfloat16},
312
+ device=device,
313
  )
314
 
315
  def strip_title(title):