|
import streamlit as st |
|
import os |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceBgeEmbeddings |
|
from langchain.llms import Together |
|
from langchain.prompts import ChatPromptTemplate, PromptTemplate |
|
from langchain.schema import format_document |
|
from typing import List |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain_community.chat_message_histories import StreamlitChatMessageHistory |
|
import time |
|
|
|
|
|
model_name = "BAAI/bge-base-en" |
|
encode_kwargs = {'normalize_embeddings': True} |
|
embedding_function = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs) |
|
|
|
|
|
llm = Together(model="mistralai/Mixtral-8x22B-Instruct-v0.1", temperature=0.2, max_tokens=19096, top_k=10, together_api_key=os.environ['pilotikval'], streaming=True) |
|
|
|
msgs = StreamlitChatMessageHistory(key="langchain_messages") |
|
memory = ConversationBufferMemory(chat_memory=msgs) |
|
|
|
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") |
|
|
|
def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"): |
|
doc_strings = [format_document(doc, document_prompt) for doc in docs] |
|
return document_separator.join(doc_strings) |
|
|
|
chistory = [] |
|
|
|
def store_chat_history(role: str, content: str): |
|
chistory.append({"role": role, "content": content}) |
|
|
|
def render_message_with_copy_button(role: str, content: str, key: str): |
|
html_code = f""" |
|
<div class="message" style="position: relative; padding-right: 40px;"> |
|
<div class="message-content">{content}</div> |
|
<button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0; background-color: transparent; border: none; cursor: pointer;"> |
|
<img src="https://img.icons8.com/material-outlined/24/grey/copy.png" alt="Copy"> |
|
</button> |
|
</div> |
|
<textarea id="{key}" style="display:none;">{content}</textarea> |
|
<script> |
|
function copyToClipboard(key) {{ |
|
var copyText = document.getElementById(key); |
|
copyText.style.display = "block"; |
|
copyText.select(); |
|
document.execCommand("copy"); |
|
copyText.style.display = "none"; |
|
alert("Copied to clipboard"); |
|
}} |
|
</script> |
|
""" |
|
st.write(html_code, unsafe_allow_html=True) |
|
|
|
def get_streaming_response(user_query, chat_history): |
|
template = """ |
|
You are a knowledgeable assistant. Provide a detailed and thorough answer to the question based on the following context: |
|
|
|
Chat history: {chat_history} |
|
|
|
User question: {user_question} |
|
""" |
|
prompt = ChatPromptTemplate.from_template(template) |
|
|
|
inputs = { |
|
"chat_history": chat_history, |
|
"user_question": user_query |
|
} |
|
|
|
chain = prompt | llm |
|
return chain.stream(inputs) |
|
|
|
def app(): |
|
with st.sidebar: |
|
st.title("dochatter") |
|
option = st.selectbox('Which retriever would you like to use?', ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')) |
|
if option == 'RespiratoryFishman': |
|
persist_directory = "./respfishmandbcud/" |
|
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud") |
|
retriever = vectordb.as_retriever(search_kwargs={"k": 5}) |
|
elif option == 'RespiratoryMurray': |
|
persist_directory = "./respmurray/" |
|
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="respmurraynotes") |
|
retriever = vectordb.as_retriever(search_kwargs={"k": 5}) |
|
elif option == 'MedMRCP2': |
|
persist_directory = "./medmrcp2store/" |
|
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="medmrcp2notes") |
|
retriever = vectordb.as_retriever(search_kwargs={"k": 5}) |
|
elif option == 'General Medicine': |
|
persist_directory = "./oxfordmedbookdir/" |
|
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="oxfordmed") |
|
retriever = vectordb.as_retriever(search_kwargs={"k": 7}) |
|
else: |
|
persist_directory = "./mrcpchromadb/" |
|
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes") |
|
retriever = vectordb.as_retriever(search_kwargs={"k": 5}) |
|
|
|
if "messages" not in st.session_state.keys(): |
|
st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}] |
|
|
|
st.header("Ask Away!") |
|
for i, message in enumerate(st.session_state.messages): |
|
with st.chat_message(message["role"]): |
|
render_message_with_copy_button(message["role"], message["content"], key=f"message-{i}") |
|
store_chat_history(message["role"], message["content"]) |
|
|
|
user_query = st.chat_input("Say something") |
|
if user_query: |
|
st.session_state.messages.append({"role": "user", "content": user_query}) |
|
with st.chat_message("user"): |
|
st.write(user_query) |
|
|
|
with st.chat_message("assistant"): |
|
with st.spinner("Thinking..."): |
|
chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chistory]) |
|
try: |
|
response_generator = get_streaming_response(user_query, chat_history) |
|
response_text = "" |
|
for response_part in response_generator: |
|
response_text += response_part |
|
st.write(response_text) |
|
st.session_state.messages.append({"role": "assistant", "content": response_text}) |
|
except Exception as e: |
|
st.error(f"An error occurred: {e}") |
|
|
|
if __name__ == '__main__': |
|
app() |
|
|