|
|
|
from langchain_pinecone.vectorstores import PineconeVectorStore |
|
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint |
|
from langchain.prompts import PromptTemplate |
|
from pinecone import Pinecone |
|
from langchain_community.chat_message_histories import ChatMessageHistory |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.retrievers import MergerRetriever |
|
from dotenv import load_dotenv |
|
import os |
|
|
|
from langchain_community.vectorstores import Chroma as LangChainChroma |
|
import chromadb |
|
|
|
|
|
|
|
|
|
# Load variables from a .env file (if one is found) into os.environ so the
# lookups below can see them.
load_dotenv()

# --- Service credentials / identifiers (each is None when unset) ---
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_INDEX = os.environ.get("PINECONE_INDEX")
HUGGINGFACE_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# --- Hugging Face model identifiers ---
# Sentence-embedding model used to embed documents and queries.
EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Chat/completion model, referenced by repo id on the HF inference endpoint.
CHAT_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Name of the Chroma collection used for runtime-added documents; shared by the
# raw chromadb handle and the LangChain wrapper so the two cannot drift apart.
_CHROMA_COLLECTION_NAME = "user_docs"

# Prompt for the "stuff" combine-docs step: {context} receives the retrieved
# documents and {question} the user's question.
_PROMPT_TEMPLATE = """

You are a trained bot to guide people about Illinois Criminal Law Statutes and the Safe-T Act. You will answer user's query with your knowledge and the context provided.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.

Do not say thank you and tell you are an AI Assistant and be open about everything.

Use the following pieces of context to answer the users question.

Context: {context}

Question: {question}

Only return the helpful answer below and nothing else.

Helpful answer:

"""


def _build_pinecone_retriever(embeddings):
    """Return a retriever over the existing Pinecone index ``PINECONE_INDEX``."""
    pc = Pinecone(api_key=PINECONE_API_KEY)
    index = pc.Index(PINECONE_INDEX)
    # Build the store from the explicitly-authenticated index rather than
    # from_existing_index(), which would create a second client from env vars.
    vectorstore = PineconeVectorStore(index=index, embedding=embeddings)
    # NOTE(review): this filters on the *literal* string 'user_id'. It looks like
    # a placeholder for a real per-user value -- confirm the intended metadata.
    return vectorstore.as_retriever(
        search_kwargs={"filter": {"source": "user_id"}}
    )


def _build_chroma_store(embeddings):
    """Create an in-memory Chroma collection and a LangChain store over it.

    Returns:
        tuple: ``(collection, store)``, i.e. the raw chromadb collection and the
        LangChain ``Chroma`` wrapper bound to the same collection.
    """
    # EphemeralClient keeps data in process memory. PersistentClient(path=":memory:")
    # would instead create an on-disk directory literally named ":memory:".
    client = chromadb.EphemeralClient()
    # NOTE(review): documents added directly through this raw collection are
    # embedded by chromadb's default embedding function, not `embeddings`;
    # adding via the LangChain store keeps document and query vectors consistent.
    collection = client.get_or_create_collection(name=_CHROMA_COLLECTION_NAME)
    store = LangChainChroma(
        client=client,
        collection_name=_CHROMA_COLLECTION_NAME,
        embedding_function=embeddings,
    )
    return collection, store


def _build_llm():
    """Return the Hugging Face inference-endpoint LLM used to answer questions."""
    return HuggingFaceEndpoint(
        repo_id=CHAT_MODEL,
        # The token is a dedicated field. Anything in model_kwargs is forwarded to
        # the model as a generation parameter, so it must not carry credentials.
        huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
        temperature=0.5,
        top_k=10,
    )


def ChatBot():
    """Build the conversational retrieval chain for the legal-guidance bot.

    Retrieval merges two sources: the existing Pinecone index named by
    ``PINECONE_INDEX``, and an in-memory Chroma collection (presumably filled
    by callers through the returned handles -- confirm with callers).

    Returns:
        tuple: ``(retrieval_chain, chroma_collection, langchain_chroma)``:
        the ``ConversationalRetrievalChain`` (with buffer memory and source
        documents enabled), the raw chromadb collection, and the LangChain
        ``Chroma`` store over that same collection.

    Raises:
        ValueError: if ``PINECONE_API_KEY`` or ``PINECONE_INDEX`` is unset.
    """
    # Fail fast with a clear message before loading the embedding model.
    if not PINECONE_API_KEY or not PINECONE_INDEX:
        raise ValueError(
            "PINECONE_API_KEY and PINECONE_INDEX must be set (e.g. in a .env file)"
        )

    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)

    pinecone_retriever = _build_pinecone_retriever(embeddings)
    chroma_collection, langchain_chroma = _build_chroma_store(embeddings)

    # MergerRetriever interleaves the documents returned by both retrievers.
    combined_retriever = MergerRetriever(
        retrievers=[pinecone_retriever, langchain_chroma.as_retriever()]
    )

    prompt = PromptTemplate(
        template=_PROMPT_TEMPLATE,
        input_variables=["context", "question"],
    )

    # output_key="answer" tells the memory which chain output to record. It is
    # needed because return_source_documents=True adds a second output key.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=ChatMessageHistory(),
        return_messages=True,
    )

    retrieval_chain = ConversationalRetrievalChain.from_llm(
        llm=_build_llm(),
        chain_type="stuff",
        retriever=combined_retriever,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": prompt},
        memory=memory,
    )

    return retrieval_chain, chroma_collection, langchain_chroma