isayahc commited on
Commit
3116721
·
verified ·
1 Parent(s): 0190e25

Added the code needed for QA.

Browse files
Files changed (1) hide show
  1. qa.py +112 -0
qa.py CHANGED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # logging
3
+ import logging
4
+
5
+ # access .env file
6
+ import os
7
+ from dotenv import load_dotenv
8
+
9
+ import time
10
+
11
+ #boto3 for S3 access
12
+ import boto3
13
+ from botocore import UNSIGNED
14
+ from botocore.client import Config
15
+
16
+ # HF libraries
17
+ from langchain.llms import HuggingFaceHub
18
+ from langchain.embeddings import HuggingFaceHubEmbeddings
19
+ # vectorstore
20
+ from langchain.vectorstores import Chroma
21
+
22
+ # retrieval chain
23
+ from langchain.chains import RetrievalQAWithSourcesChain
24
+ # prompt template
25
+ from langchain.prompts import PromptTemplate
26
+ from langchain.memory import ConversationBufferMemory
27
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
28
+ # reorder retrieved documents
29
+ # github issues
30
+ from langchain.document_loaders import GitHubIssuesLoader
31
+ # debugging
32
+ from langchain.globals import set_verbose
33
+ # caching
34
+ from langchain.globals import set_llm_cache
35
+ # We can do the same thing with a SQLite cache
36
+ from langchain.cache import SQLiteCache
37
+
38
+
39
+ # template for prompt
40
+ from prompt import template
41
+
42
+
43
+
44
# Turn on LangChain's global verbose mode for debugging chain internals.
set_verbose(True)

# Configure logging so the retriever and QA-with-sources components
# emit INFO-level messages during chain execution.
logging.basicConfig()
for _chain_logger_name in ("langchain.retrievers", "langchain.chains.qa_with_sources"):
    logging.getLogger(_chain_logger_name).setLevel(logging.INFO)
50
+
51
# Load configuration from the local .env file into the process environment.
config = load_dotenv(".env")
# NOTE(review): the HF token is read here but never passed explicitly;
# presumably HuggingFaceHub picks it up from the environment — confirm.
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
AWS_S3_LOCATION = os.getenv('AWS_S3_LOCATION')
AWS_S3_FILE = os.getenv('AWS_S3_FILE')
VS_DESTINATION = os.getenv('VS_DESTINATION')

# Fail fast with a readable error instead of the opaque TypeError that
# os.path.exists(None) / s3.download_file(None, ...) would raise later.
_missing = [name for name, value in (
    ('AWS_S3_LOCATION', AWS_S3_LOCATION),
    ('AWS_S3_FILE', AWS_S3_FILE),
    ('VS_DESTINATION', VS_DESTINATION),
) if value is None]
if _missing:
    raise RuntimeError(
        "Missing required environment variables: " + ", ".join(_missing)
    )

# Remove the old vectorstore so the fresh S3 snapshot replaces it.
if os.path.exists(VS_DESTINATION):
    os.remove(VS_DESTINATION)

# Remove the old SQLite LLM cache so stale completions are not reused.
if os.path.exists('.langchain.sqlite'):
    os.remove('.langchain.sqlite')
65
+
66
+
67
+
68
# LLM configuration: Mistral-7B-Instruct served through the HF Hub.
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# Generation parameters for the hosted model (disabled options kept for
# reference). Renamed from model_id to llm as is common.
_llm_kwargs = {
    # "temperature": 0.1,
    "max_new_tokens": 1024,
    "repetition_penalty": 1.2,
    # "streaming": True,
    # "return_full_text": True
}
llm = HuggingFaceHub(repo_id=llm_model_name, model_kwargs=_llm_kwargs)
79
+
80
# Embedding model served through the HF Hub inference API.
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
embeddings = HuggingFaceHubEmbeddings(repo_id=embedding_model_name)

# Cache LLM completions in a local SQLite database to avoid repeat calls.
set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))
85
+
86
# Retrieve the prebuilt vectorstore from S3.
# UNSIGNED config: the bucket is public, so no AWS credentials are needed.
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

## Chroma DB
# Download the Chroma snapshot into VS_DESTINATION.
s3.download_file(AWS_S3_LOCATION, AWS_S3_FILE, VS_DESTINATION)
# use the cached embeddings instead of embeddings to speed up re-retrival
# NOTE(review): persist_directory is hardcoded to "./vectorstore" while the
# snapshot above was downloaded to VS_DESTINATION — confirm the two paths
# refer to the same location, otherwise the downloaded data is never loaded.
db = Chroma(persist_directory="./vectorstore", embedding_function=embeddings)
db.get()

# Maximal-marginal-relevance retriever over the vectorstore
# (alternative search_kwargs kept for reference).
retriever = db.as_retriever(search_type="mmr")#, search_kwargs={'k': 3, 'lambda_mult': 0.25})
96
+
97
# asks LLM to create 3 alternatives based on user query
# asks LLM to extract relevant parts from retrieved documents

# Prompt wiring: the template consumes conversation history, retrieved
# context, and the user question.
prompt = PromptTemplate(
    input_variables=["history", "context", "question"],
    template=template,
)

# Conversation memory keyed on the question input so history accumulates.
memory = ConversationBufferMemory(memory_key="history", input_key="question")

# Assemble the retrieval QA chain that also returns its source documents.
_chain_kwargs = {
    "verbose": True,
    "memory": memory,
    "prompt": prompt,
    "document_variable_name": "context"
}
qa = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,
    verbose=True,
    chain_type_kwargs=_chain_kwargs,
)