Moha782 commited on
Commit
9811ddc
·
verified ·
1 Parent(s): 1989656

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from pathlib import Path
4
- from transformers import RagTokenForGeneration, RagTokenizer, DenseRetriever
5
  import faiss
6
  from typing import List
7
  from pdfplumber import open as open_pdf
@@ -29,8 +29,19 @@ embeddings = rag_model.question_encoder(rag_tokenizer(text_chunks, padding=True,
29
  index = faiss.IndexFlatL2(embeddings.size(-1))
30
  index.add(embeddings.detach().numpy())
31
 
32
- # Create a custom retriever
33
- retriever = DenseRetriever(document_store=text_chunks, embedding=embeddings, similarities=index.search)
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def respond(
36
  message,
@@ -52,8 +63,9 @@ def respond(
52
 
53
  # Retrieve relevant chunks using the custom retriever
54
  rag_input_ids = rag_tokenizer(message, return_tensors="pt").input_ids
55
- rag_output = rag_model(rag_input_ids, retriever=retriever, return_retrieved_inputs=True)
56
- retrieved_text = rag_output.retrieved_inputs
 
57
 
58
  # Generate the response using the zephyr model
59
  for message in client.chat_completion(
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from pathlib import Path
4
+ from transformers import RagTokenForGeneration, RagTokenizer
5
  import faiss
6
  from typing import List
7
  from pdfplumber import open as open_pdf
 
29
  index = faiss.IndexFlatL2(embeddings.size(-1))
30
  index.add(embeddings.detach().numpy())
31
 
32
+ # Custom retriever class
33
+ class CustomRetriever:
34
+ def __init__(self, documents, embeddings, index):
35
+ self.documents = documents
36
+ self.embeddings = embeddings
37
+ self.index = index
38
+
39
+ def get_relevant_docs(self, query_embeddings, top_k=4):
40
+ scores, doc_indices = self.index.search(query_embeddings.detach().numpy(), top_k)
41
+ return [(self.documents[doc_idx], score) for doc_idx, score in zip(doc_indices[0], scores[0])]
42
+
43
+ # Create a custom retriever instance
44
+ retriever = CustomRetriever(text_chunks, embeddings, index)
45
 
46
  def respond(
47
  message,
 
63
 
64
  # Retrieve relevant chunks using the custom retriever
65
  rag_input_ids = rag_tokenizer(message, return_tensors="pt").input_ids
66
+ query_embeddings = rag_model.question_encoder(rag_input_ids)
67
+ relevant_docs = retriever.get_relevant_docs(query_embeddings)
68
+ retrieved_text = "\n".join([doc for doc, _ in relevant_docs])
69
 
70
  # Generate the response using the zephyr model
71
  for message in client.chat_completion(