Chris4K commited on
Commit
47575a3
·
1 Parent(s): 57289c2

Update vector_store_retriever.py

Browse files
Files changed (1) hide show
  1. vector_store_retriever.py +25 -45
vector_store_retriever.py CHANGED
@@ -1,59 +1,39 @@
1
  import gradio as gr
2
- from langchain.document_loaders import DirectoryLoader, PyPDFLoader
3
  from langchain.vectorstores import Chroma
4
- from langchain.chains import RetrievalQA
5
  from langchain.embeddings import HuggingFaceInstructEmbeddings
6
- from langchain.agents import Tool
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain.llms import HuggingFacePipeline
9
- from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
10
 
11
- # Load and process the text files
12
- loader = DirectoryLoader('./new_papers/', glob="./*.pdf", loader_cls=PyPDFLoader)
13
- documents = loader.load()
14
-
15
- # Splitting the text into chunks
16
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
17
- texts = text_splitter.split_documents(documents)
18
-
19
- # HF Instructor Embeddings
20
- instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl", model_kwargs={"device": "cuda"})
21
-
22
- # Embed and store the texts
23
- persist_directory = 'db'
24
- embedding = instructor_embeddings
25
- vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)
26
-
27
- # Make a retriever
28
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
29
-
30
- # Setup LLM for text generation
31
- tokenizer = LlamaTokenizer.from_pretrained("TheBloke/wizardLM-7B-HF")
32
- model = LlamaForCausalLM.from_pretrained("TheBloke/wizardLM-7B-HF", load_in_8bit=True, device_map='auto', torch_dtype=torch.float16, low_cpu_mem_usage=True)
33
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=1024, temperature=0, top_p=0.95, repetition_penalty=1.15)
34
- local_llm = HuggingFacePipeline(pipeline=pipe)
35
 
36
- # Make a chain
37
- qa_chain = RetrievalQA.from_chain_type(llm=local_llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
 
38
 
39
- class VectorStoreRetrieverTool(Tool):
40
- name = "vectorstore_retriever"
41
- description = "This tool uses LangChain's RetrievalQA to find relevant answers from a vector store based on a given query."
42
 
43
- inputs = ["text"]
44
- outputs = ["text"]
 
45
 
46
- def __call__(self, query: str):
47
- # Run the query through the RetrievalQA chain
48
- llm_response = qa_chain(query)
49
- return llm_response['result']
50
 
51
- # Create the Gradio interface using the HuggingFaceTool
52
  tool = gr.Interface(
53
- VectorStoreRetrieverTool(),
 
 
54
  live=True,
55
- title="LangChain-Application: Vectorstore-Retriever",
56
- description="This tool uses LangChain's RetrievalQA to find relevant answers from a vector store based on a given query.",
57
  )
58
 
59
  # Launch the Gradio interface
 
1
  import gradio as gr
 
2
  from langchain.vectorstores import Chroma
3
+ from langchain.document_loaders import PyPDFLoader
4
  from langchain.embeddings import HuggingFaceInstructEmbeddings
 
 
 
 
5
 
6
+ # Initialize the HuggingFaceInstructEmbeddings
7
+ hf = HuggingFaceInstructEmbeddings(
8
+ model_name="hkunlp/instructor-large",
9
+ embed_instruction="Represent the document for retrieval: ",
10
+ query_instruction="Represent the query for retrieval: "
11
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # Load and process the PDF files
14
+ loader = PyPDFLoader('./new_papers/new_papers/', glob="./*.pdf")
15
+ documents = loader.load()
16
 
17
+ # Create a Chroma vector store from the PDF documents
18
+ db = Chroma.from_documents(documents, hf, collection_name="my-collection")
 
19
 
20
+ class VectoreStoreRetrievalTool:
21
+ def __init__(self):
22
+ self.retriever = db.as_retriever(search_kwargs={"k": 1})
23
 
24
+ def __call__(self, query):
25
+ # Run the query through the retriever
26
+ response = self.retriever.run(query)
27
+ return response['result']
28
 
29
+ # Create the Gradio interface using the PDFRetrievalTool
30
  tool = gr.Interface(
31
+ PDFRetrievalTool(),
32
+ inputs=gr.Textbox(),
33
+ outputs=gr.Textbox(),
34
  live=True,
35
+ title="PDF Retrieval Tool",
36
+ description="This tool indexes PDF documents and retrieves relevant answers based on a given query.",
37
  )
38
 
39
  # Launch the Gradio interface