# logging
import logging
# access .env file
import os
from dotenv import load_dotenv
#import time
#boto3 for S3 access
import boto3
from botocore import UNSIGNED
from botocore.client import Config
# HF libraries
from langchain_huggingface import HuggingFaceEndpoint
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
# vector store
#from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import FAISS
import zipfile
# retrieval chain
from langchain.chains import RetrievalQAWithSourcesChain
# prompt template
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
# github issues
#from langchain.document_loaders import GitHubIssuesLoader
# debugging
from langchain.globals import set_verbose
# caching
from langchain.globals import set_llm_cache
# back the LLM cache with SQLite
from langchain_community.cache import SQLiteCache
# template for prompt
from prompt import template
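# `template` (defined in the local prompt.py) is expected to contain {history},
# {context}, and {question} placeholders, matching the PromptTemplate built below.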
set_verbose(True)
# set up logging for the chain
logging.basicConfig()
logging.getLogger("langchain.retrievers").setLevel(logging.INFO)
logging.getLogger("langchain.chains.qa_with_sources").setLevel(logging.INFO)
# load .env variables
load_dotenv(".env")  # returns a bool, so no need to keep the result
HUGGINGFACEHUB_API_TOKEN=os.getenv('HUGGINGFACEHUB_API_TOKEN')
AWS_S3_LOCATION=os.getenv('AWS_S3_LOCATION')
AWS_S3_FILE=os.getenv('AWS_S3_FILE')
VS_DESTINATION=os.getenv('VS_DESTINATION')
# remove old vectorstore
if os.path.exists(VS_DESTINATION):
    os.remove(VS_DESTINATION)
# remove old sqlite cache
if os.path.exists('.langchain.sqlite'):
    os.remove('.langchain.sqlite')
# initialize Model config
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# renamed from model_id to llm, as is conventional
llm = HuggingFaceEndpoint(
    repo_id=llm_model_name,
    temperature=0.1,
    max_new_tokens=1024,
    repetition_penalty=1.2,
    return_full_text=False,
)
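# Note: HuggingFaceEndpoint typically picks up HUGGINGFACEHUB_API_TOKEN from the
# environment (loaded above), so no explicit token argument is passed here.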
# initialize Embedding config
embedding_model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
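# Optional sanity check (a minimal sketch): embed a sample query to confirm the
# model downloads and loads; mpnet-base models return 768-dim vectors.
# _ = embeddings.embed_query("test query")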
set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))
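# With the cache set, repeated identical prompts to the same LLM are answered
# from .langchain.sqlite instead of re-calling the HF endpoint.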
# retrieve vectorstore from a public S3 bucket (UNSIGNED config = anonymous access)
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
## download vectorstore from S3
s3.download_file(AWS_S3_LOCATION, AWS_S3_FILE, VS_DESTINATION)
with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
    zip_ref.extractall('./vectorstore/')
FAISS_INDEX_PATH='./vectorstore/lc-faiss-multi-qa-mpnet'
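# FAISS.load_local unpickles the docstore, so allow_dangerous_deserialization
# must be set explicitly; only do this for an index you created or trust.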
db = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
# alternative: a persisted Chroma store with cached embeddings to speed up re-retrieval
# db = Chroma(persist_directory="./vectorstore", embedding_function=embeddings)
# db.get()
retriever = db.as_retriever(search_type="mmr", search_kwargs={'k': 3, 'lambda_mult': 0.25})
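# MMR re-ranks for diversity as well as relevance; lambda_mult=0.25 leans
# strongly toward diverse results. Quick check (a sketch, hypothetical query):
# docs = retriever.invoke("How do I load a FAISS index in LangChain?")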
prompt = PromptTemplate(
    input_variables=["history", "context", "question"],
    template=template,
)
memory = ConversationBufferMemory(memory_key="history", input_key="question")
qa = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,
    verbose=True,
    chain_type_kwargs={
        "verbose": True,
        "memory": memory,
        "prompt": prompt,
        "document_variable_name": "context",
    },
)
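# Example invocation (a sketch; the question string is illustrative):
# result = qa.invoke({"question": "What does a retriever do in LangChain?"})
# print(result["answer"])
# print(result["sources"])           # source string extracted by the chain
# print(result["source_documents"])  # full docs, since return_source_documents=True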