Moha782 committed
Commit 4cce6fa · verified · 1 Parent(s): edc2346

Update app.py

Files changed (1): app.py +12 -8
app.py CHANGED
@@ -1,12 +1,14 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-from transformers import pipeline
+from transformers import RagTokenizer, RagTokenForGeneration
 from typing import List, Dict, Tuple
 import re
 import os
+import torch
 
-# Set up the retriever pipeline
-retriever = pipeline('retrieval', model='facebook/rag-token-nq')
+# Load the RAG model and tokenizer
+rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
 
 # Load your PDF document
 pdf_path = "apexcustoms.pdf"
@@ -44,10 +46,12 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    # Retrieve relevant context from the PDF
-    retrieval_output = retriever(message, corpus, top_k=3)
-    retrieved_contexts = [passage['text'] for passage in retrieval_output['retrieved_passages']]
-    context = ' '.join(retrieved_contexts)
+    # Tokenize the input and retrieve relevant context from the PDF
+    inputs = rag_tokenizer(message, return_tensors="pt")
+    inputs.update({"corpus": corpus})
+    input_ids = inputs.pop("input_ids")
+    output_ids = rag_model.generate(**inputs, max_length=max_tokens, temperature=temperature, top_p=top_p, num_beams=2)
+    retrieved_context = rag_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
 
     response = ""
 
@@ -57,7 +61,7 @@ def respond(
         stream=True,
         temperature=temperature,
         top_p=top_p,
-        context=context,  # Include the retrieved context
+        context=retrieved_context,  # Include the retrieved context
     ):
         token = message.choices[0].delta.content
 
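A note on the new retrieval block: in transformers, RagTokenForGeneration.generate() accepts no corpus keyword, and the input_ids popped out of inputs are never passed back in, so the call as committed would not run. In the library's documented pattern, retrieval comes from a RagRetriever attached to the model instead. A minimal sketch under that assumption (the query string is a placeholder, and use_dummy_dataset=True stands in for a real index built from the PDF text):

    from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    # The retriever owns the passage index; generate() queries it internally
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )
    rag_model = RagTokenForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=rag_retriever
    )

    inputs = rag_tokenizer("What body kits are offered?", return_tensors="pt")  # placeholder query
    output_ids = rag_model.generate(input_ids=inputs["input_ids"], num_beams=2)
    print(rag_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])

Note that batch_decode here yields a generated answer rather than retrieved passages, so naming the result retrieved_context, as the commit does, and feeding it to a second model is a design choice worth revisiting.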
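The context= keyword in the final hunk is also suspect: assuming the streaming call is huggingface_hub's InferenceClient.chat_completion (the surrounding kwargs and chunk.choices[0].delta access match that method), it defines no context parameter and the call would raise a TypeError. The usual approach is to fold the retrieved text into the message list before the call; a sketch, with build_messages as a hypothetical helper:

    # Hypothetical helper: inject retrieved text via the system prompt
    # rather than an unsupported context= keyword.
    def build_messages(system_message: str, history: list,
                       message: str, retrieved_context: str) -> list:
        system = (f"{system_message}\n\n"
                  f"Relevant excerpts from apexcustoms.pdf:\n{retrieved_context}")
        messages = [{"role": "system", "content": system}]
        messages.extend(history)  # earlier {"role": ..., "content": ...} turns
        messages.append({"role": "user", "content": message})
        return messages

The streaming loop would then stay as-is except that context= is dropped, with client being the app's InferenceClient instance:

    for chunk in client.chat_completion(
        build_messages(system_message, history, message, retrieved_context),
        max_tokens=max_tokens, stream=True,
        temperature=temperature, top_p=top_p,
    ):
        token = chunk.choices[0].delta.content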