|
import anthropic |
|
import streamlit as st |
|
from streamlit.logger import get_logger |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.llms import OpenAI |
|
from langchain.llms import HuggingFaceEndpoint |
|
from langchain.chat_models import ChatAnthropic |
|
from langchain.vectorstores import SupabaseVectorStore |
|
from stats import add_usage |
|
|
|
# Conversation memory for the retrieval chain. NOTE(review): if this module is
# imported (rather than run as the Streamlit entry script) this object is
# created once per server process and is therefore shared by every browser
# session that uses it.
memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)

# Provider API keys. ``.get`` returns None for a secret that is not configured,
# so an unused provider does not crash the app at import time; callers are
# expected to check the key before using that provider.
openai_api_key = st.secrets.get("openai_api_key")
anthropic_api_key = st.secrets.get("anthropic_api_key")
hf_api_key = st.secrets.get("hf_api_key")

logger = get_logger(__name__)
|
|
|
|
|
def count_tokens(question, model):
    """Build a short human-readable size summary for ``question``.

    The whitespace-separated word count is always included. For models whose
    name starts with ``"claude"`` the Anthropic tokenizer's count is appended,
    separated by ``" | "``.
    """
    segments = [f'Words: {len(question.split())}']
    if model.startswith("claude"):
        segments.append(f'Tokens: {anthropic.count_tokens(question)}')
    return ' | '.join(segments)
|
|
|
|
|
def _get_session_memory():
    """Return this browser session's conversation memory, creating it on first use.

    Stored in ``st.session_state`` so each session keeps its own chat context;
    an object held at module level would be shared by every session that
    imports this module.
    """
    if 'chat_memory' not in st.session_state:
        st.session_state['chat_memory'] = ConversationBufferMemory(
            memory_key="chat_history", input_key='question',
            output_key='answer', return_messages=True)
    return st.session_state['chat_memory']


def _build_qa_chain(model, vector_store, chat_memory):
    """Return a ConversationalRetrievalChain for ``model``, or None if no provider applies.

    Provider is chosen by prefix: ``gpt*`` -> OpenAI; ``claude*`` -> Anthropic
    (only when an Anthropic key is configured); otherwise the Hugging Face
    Inference API (only when an HF key is configured). Generation settings are
    read from ``st.session_state['temperature']`` and ``['max_tokens']``.
    """
    temperature = st.session_state['temperature']
    max_tokens = st.session_state['max_tokens']

    if model.startswith("gpt"):
        logger.info('Using OpenAI model %s', model)
        llm = OpenAI(model_name=model, openai_api_key=openai_api_key,
                     temperature=temperature, max_tokens=max_tokens)
        # NOTE(review): unlike the Hugging Face branch, this retriever applies
        # no per-user metadata filter — confirm that is intended.
        return ConversationalRetrievalChain.from_llm(
            llm, vector_store.as_retriever(), memory=chat_memory, verbose=True)

    if anthropic_api_key and model.startswith("claude"):
        logger.info('Using Anthropics model %s', model)
        llm = ChatAnthropic(model=model, anthropic_api_key=anthropic_api_key,
                            temperature=temperature, max_tokens_to_sample=max_tokens)
        # max_tokens_limit caps the token size of the retrieved documents that
        # are stuffed into the prompt. NOTE(review): same missing per-user
        # filter as the OpenAI branch.
        return ConversationalRetrievalChain.from_llm(
            llm, vector_store.as_retriever(), memory=chat_memory, verbose=True,
            max_tokens_limit=102400)

    if hf_api_key:
        logger.info('Using HF model %s', model)
        hf = HuggingFaceEndpoint(
            endpoint_url="https://api-inference.huggingface.co/models/" + model,
            task="text-generation",
            huggingfacehub_api_token=hf_api_key,
            model_kwargs={
                "temperature": temperature,
                "max_new_tokens": max_tokens,
                "return_full_text": False,
            },
        )
        retriever = vector_store.as_retriever(search_kwargs={
            "score_threshold": 0.6,
            "k": 4,
            "filter": {"user": st.session_state["username"]},
        })
        return ConversationalRetrievalChain.from_llm(
            hf, retriever=retriever, memory=chat_memory, verbose=True,
            return_source_documents=True)

    # No configured provider can serve this model.
    return None


def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
    """Render the document Q&A chat UI and answer questions from ``vector_store``.

    Shows a question box with "Ask", "Count Tokens" and "Clear History"
    buttons. "Ask" runs the question through a retrieval chain built for
    ``model``, records usage via ``add_usage(stats_db, ...)`` and renders the
    session's transcript. "Count Tokens" shows ``count_tokens`` output and
    "Clear History" resets the session's memory and transcript.

    Args:
        model: Model identifier; its prefix selects the LLM provider.
        vector_store: Vector store whose retriever supplies context documents.
        stats_db: Handle forwarded to ``add_usage`` for usage accounting.

    Reads ``st.session_state`` keys ``overused``, ``temperature``,
    ``max_tokens`` and, for Hugging Face models, ``username``.
    """
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    question = st.text_area("## Ask a question")
    columns = st.columns(3)
    with columns[0]:
        button = st.button("Ask")
    with columns[1]:
        count_button = st.button("Count Tokens", type='secondary')
    with columns[2]:
        clear_history = st.button("Clear History", type='secondary')

    if clear_history:
        _get_session_memory().clear()
        st.session_state['chat_history'] = []
        st.experimental_rerun()

    if button:
        if st.session_state["overused"]:
            st.error("You have used all your free credits. Please try again later or self host.")
        elif not question.strip():
            st.warning("Please enter a question first.")
        else:
            qa = _build_qa_chain(model, vector_store, _get_session_memory())
            if qa is None:
                # Previously this fell through to calling None and crashed.
                st.error(f"No API key is configured for model '{model}'.")
            else:
                add_usage(stats_db, "chat", "prompt" + question,
                          {"model": model, "temperature": st.session_state['temperature']})

                model_response = qa({"question": question})
                answer = model_response["answer"]
                logger.info('Result: %s', answer)
                # Only the Hugging Face chain requests source documents, so the
                # key may be absent for the other providers.
                logger.info('Sources: %s', model_response.get("source_documents", []))

                # Record the exchange only after the model call succeeded, so a
                # failed call does not leave an unanswered question behind.
                st.session_state['chat_history'].append(("You", question))
                st.session_state['chat_history'].append(("meraKB", answer))

                for speaker, text in st.session_state['chat_history']:
                    st.markdown(f"**{speaker}:** {text}")

    if count_button:
        st.write(count_tokens(question, model))
|
|