Asankhaya Sharma committed
Commit 7cd26c6
1 Parent(s): d536adc

update model names

Files changed:
- main.py +2 -2
- question.py +7 -6
main.py CHANGED

@@ -31,7 +31,7 @@ embeddings = HuggingFaceInferenceAPIEmbeddings(
 
 vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
 
-models = ["llama-2"]
+models = ["meta-llama/Llama-2-7b-chat-hf", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
 
 if openai_api_key:
     models += ["gpt-3.5-turbo", "gpt-4"]

@@ -77,7 +77,7 @@ if st.session_state["authenticated"]:
 
     # Initialize session state variables
    if 'model' not in st.session_state:
-        st.session_state['model'] = "llama-2"
+        st.session_state['model'] = "meta-llama/Llama-2-7b-chat-hf"
    if 'temperature' not in st.session_state:
        st.session_state['temperature'] = 0.1
    if 'chunk_size' not in st.session_state:
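For context, the new entries are fully qualified Hub IDs ("org/name"), which is the form the hosted Inference API routes on, so the string selected in main.py can be passed straight through to question.py. A minimal sketch of that flow, assuming Streamlit's standard selectbox widget (the selector UI itself is outside this diff):

import streamlit as st

# Fully qualified Hub IDs, as set in main.py above; the same string doubles
# as the route suffix for the hosted Inference API used in question.py.
models = ["meta-llama/Llama-2-7b-chat-hf", "mistralai/Mixtral-8x7B-Instruct-v0.1"]

if 'model' not in st.session_state:
    st.session_state['model'] = models[0]

# Hypothetical selector; question.py reads st.session_state['model'].
st.session_state['model'] = st.selectbox(
    "Model", models, index=models.index(st.session_state['model']))

# The Hub ID appends directly onto the Inference API base URL:
endpoint_url = "https://api-inference.huggingface.co/models/" + st.session_state['model']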
question.py CHANGED

@@ -9,8 +9,8 @@ from langchain.chat_models import ChatAnthropic
 from langchain.vectorstores import SupabaseVectorStore
 from stats import add_usage
 
-memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
-
+# memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
+memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 openai_api_key = st.secrets.openai_api_key
 anthropic_api_key = st.secrets.anthropic_api_key
 hf_api_key = st.secrets.hf_api_key

@@ -62,10 +62,10 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
         qa = ConversationalRetrievalChain.from_llm(
             ChatAnthropic(
                 model=st.session_state['model'], anthropic_api_key=anthropic_api_key, temperature=st.session_state['temperature'], max_tokens_to_sample=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True, max_tokens_limit=102400)
-    elif hf_api_key
-        logger.info('Using
+    elif hf_api_key:
+        logger.info('Using HF model %s', model)
         # print(st.session_state['max_tokens'])
-        endpoint_url = ("https://api-inference.huggingface.co/models/
+        endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
         model_kwargs = {"temperature" : st.session_state['temperature'],
                         "max_new_tokens" : st.session_state['max_tokens'],
                         "return_full_text" : False}

@@ -75,7 +75,7 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
             huggingfacehub_api_token=hf_api_key,
             model_kwargs=model_kwargs
         )
-        qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(), memory=memory, verbose=True)
+        qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(), memory=memory, verbose=True, return_source_documents=True)
 
         st.session_state['chat_history'].append(("You", question))
 

@@ -84,6 +84,7 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
         logger.info('Result: %s', model_response)
 
         st.session_state['chat_history'].append(("meraKB", model_response["answer"]))
+        # logger.info('Sources: %s', model_response["source_documents"][0])
 
         # Display chat history
         st.empty()
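For reference, a minimal self-contained sketch of the retrieval chain this file configures, assuming langchain's HuggingFaceEndpoint wrapper (the constructor name sits in context lines the diff does not show, so it is an assumption here), with keyword arguments mirroring the hunks above:

from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory

model = "meta-llama/Llama-2-7b-chat-hf"  # chosen in main.py

# Mirrors the memory at the top of question.py. Note: once the chain returns
# two keys ('answer' and 'source_documents'), ConversationBufferMemory
# typically needs output_key='answer' to know which value to store.
memory = ConversationBufferMemory(
    memory_key="chat_history", output_key="answer", return_messages=True)

# Assumed constructor; the diff only shows its trailing keyword arguments.
hf = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/" + model,
    task="text-generation",
    huggingfacehub_api_token="hf_...",  # st.secrets.hf_api_key in the app
    model_kwargs={"temperature": 0.1,
                  "max_new_tokens": 500,
                  "return_full_text": False},
)

# vector_store is the SupabaseVectorStore built in main.py.
qa = ConversationalRetrievalChain.from_llm(
    hf, retriever=vector_store.as_retriever(),
    memory=memory, verbose=True, return_source_documents=True)

response = qa({"question": "What do my documents say about pricing?"})
print(response["answer"])            # appended to chat history as ("meraKB", ...)
print(response["source_documents"])  # available because return_source_documents=True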