schellrw committed on
Commit
07fff33
·
verified ·
1 Parent(s): f4c7d5d

Create chat/bot.py

Browse files
Files changed (1) hide show
  1. chat/bot.py +100 -0
chat/bot.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import chromadb
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import MergerRetriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.vectorstores import Chroma as LangChainChroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_pinecone.vectorstores import PineconeVectorStore
from pinecone import Pinecone

# Load environment variables from a local .env file (no-op when absent,
# e.g. on Streamlit Cloud where st.secrets is used instead).
load_dotenv()


def _get_config(name):
    """Return a configuration value for *name*.

    The environment (populated by load_dotenv above) wins; st.secrets is
    consulted only for values the environment does not provide. This replaces
    the previous all-or-nothing fallback, which overwrote every setting from
    st.secrets whenever any single one was missing from the environment —
    discarding valid env values and raising KeyError when a secret was absent.
    Returns None when the value is defined in neither place.
    """
    value = os.getenv(name)
    if value is None:
        # st.secrets implements the Mapping protocol, so .get is safe
        # even when a particular secret is not configured.
        value = st.secrets.get(name)
    return value


# Connection / model configuration, resolved once at import time.
PINECONE_API_KEY = _get_config("PINECONE_API_KEY")
PINECONE_INDEX = _get_config("PINECONE_INDEX")
HUGGINGFACE_API_TOKEN = _get_config("HUGGINGFACEHUB_API_TOKEN")
EMBEDDINGS_MODEL = _get_config("EMBEDDINGS_MODEL")
CHAT_MODEL = _get_config("CHAT_MODEL")
36
def ChatBot():
    """Build the conversational retrieval chain backed by Pinecone and Chroma.

    Returns:
        tuple: ``(retrieval_chain, chroma_collection, langchain_chroma)`` —
            the ConversationalRetrievalChain over the merged retrievers, the
            raw chromadb collection (for direct document ingestion), and its
            LangChain vector-store wrapper.
    """
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)

    # Initialize Pinecone. The Index handle is not used directly below, but
    # acts as an early credential/connectivity check; the vector store
    # reconnects via the index name.
    pc = Pinecone(api_key=PINECONE_API_KEY)
    pc.Index(PINECONE_INDEX)
    pinecone_docsearch = PineconeVectorStore.from_existing_index(
        index_name=PINECONE_INDEX, embedding=embeddings
    )
    # NOTE(review): this filters on the literal string 'user_id', not an
    # actual user's id — presumably a placeholder; confirm against the
    # metadata written at ingestion time.
    pinecone_retriever = pinecone_docsearch.as_retriever(
        search_kwargs={'filter': {'source': 'user_id'}}
    )

    # Bug fix: PersistentClient(path=":memory:") created a literal ":memory:"
    # directory on disk — ":memory:" is a SQLite idiom chromadb does not
    # recognize. EphemeralClient is chromadb's actual in-memory client, which
    # matches the intent of per-session user documents.
    chroma_client = chromadb.EphemeralClient()
    chroma_collection = chroma_client.get_or_create_collection(name="user_docs")
    langchain_chroma = LangChainChroma(
        client=chroma_client,
        collection_name="user_docs",
        embedding_function=embeddings,
    )
    chroma_retriever = langchain_chroma.as_retriever()

    # Query both stores and merge their results into a single document list.
    combined_retriever = MergerRetriever(
        retrievers=[pinecone_retriever, chroma_retriever]
    )

    # Bug fix: the token was previously passed inside model_kwargs under the
    # key "huggingface_api_token", which HuggingFaceEndpoint never reads;
    # huggingfacehub_api_token is the documented constructor argument.
    llm = HuggingFaceEndpoint(
        repo_id=CHAT_MODEL,
        huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
        temperature=0.5,  # TODO: expose via st.slider
        top_k=10,         # TODO: expose via st.slider
    )

    # Prompt typo fixed: "Crimnal" -> "Criminal".
    prompt_template = """
You are a trained bot to guide people about Illinois Criminal Law Statutes and the Safe-T Act. You will answer user's query with your knowledge and the context provided.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
Do not say thank you and tell you are an AI Assistant and be open about everything.
Use the following pieces of context to answer the users question.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""

    PROMPT = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )

    # Conversation history kept in-process; output_key="answer" is required
    # because the chain returns source documents alongside the answer.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=ChatMessageHistory(),
        return_messages=True,
    )

    retrieval_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=combined_retriever,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": PROMPT},
        memory=memory,
    )
    return retrieval_chain, chroma_collection, langchain_chroma