from openai import OpenAI
import streamlit as st
import streamlit.components.v1 as components
import datetime, time
from dataclasses import dataclass
import math
import base64

## Firestore ??
import os

# ## ----------------------------------------------------------------
# ## LLM Part
import openai
# NOTE: this re-binds `OpenAI` (imported from `openai` above) to LangChain's
# completion-model wrapper; the `OpenAI(model_name=..., temperature=0)` call
# further down relies on the LangChain class, so keep this import order.
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
import tiktoken
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from operator import itemgetter
from langchain.schema import StrOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import langchain_community.embeddings.huggingface
from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import LLMChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory #, ConversationBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory

import os, dotenv
from dotenv import load_dotenv

load_dotenv()

_STREAMLIT_DIR = "./.streamlit"
_SECRETS_PATH = "./.streamlit/secrets.toml"


def _ensure_streamlit_secrets_file():
    """Create ``./.streamlit/secrets.toml`` from ``$STREAMLIT_SECRETS`` if it is absent.

    Lets a host that only provides environment variables ship the whole
    secrets file in one variable so that ``st.secrets`` can find it.

    If the secrets file already exists, it is left untouched. If
    STREAMLIT_SECRETS is not set, nothing is written. The app can then still
    run on plain environment variables, and a later run that does set
    STREAMLIT_SECRETS will still get its file.
    """
    if not os.path.isdir(_STREAMLIT_DIR):
        os.makedirs(_STREAMLIT_DIR, exist_ok=True)
        print('made streamlit folder')
    if os.path.isfile(_SECRETS_PATH):
        return
    secrets_text = os.environ.get("STREAMLIT_SECRETS")
    if secrets_text is None:
        # Previously None was passed to f.write(), raising TypeError *after*
        # open(..., "w") had already created an empty secrets.toml, which
        # then masked the missing variable on every later run.
        return
    with open(_SECRETS_PATH, "w", encoding="utf-8") as f:
        f.write(secrets_text)
    print('made new file')


_ensure_streamlit_secrets_file()

# Imported only after the secrets file is in place. Presumably db_firestore
# reads secrets / FIREBASE_CREDENTIAL at import time — TODO confirm.
import db_firestore as db
## Load from streamlit!!
## Load from streamlit!!
# Prefer real environment variables and fall back to .streamlit/secrets.toml.
# The resolved value is written back into os.environ, presumably so the
# OpenAI / HuggingFace / Firestore clients can pick it up from the
# environment (TODO confirm for FIREBASE_CREDENTIAL).
for _secret_name in ("HF_TOKEN", "OPENAI_API_KEY", "FIREBASE_CREDENTIAL"):
    os.environ[_secret_name] = os.environ.get(_secret_name) or st.secrets[_secret_name]

if "openai_model" not in st.session_state:
    st.session_state["openai_model"] = "gpt-3.5-turbo-1106"

## Hardcode indexes for now
## TODO: Move indexes to firebase
# Scenario names; later code builds the store paths
# indexes/<name>/QA and indexes/<name>/Rubric from them.
indexes = [
    "Bleeding",
    "ChestPain",
    "Dysphagia",
    "Headache",
    "ShortnessOfBreath",
    "Vomiting",
    "Weakness",
    "Weakness2",
]

model_name = "bge-large-en-v1.5"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}


def _session_cached(key, factory):
    """Return ``st.session_state[key]``, building it with ``factory()`` on first use.

    Streamlit re-runs this script on every interaction. Caching the embedding
    model and the LLM clients in session state builds them once per session.
    """
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


# NOTE(review): `model_name` is deliberately NOT passed (it was commented out
# upstream), so HuggingFaceBgeEmbeddings uses its library default model. The
# FAISS indexes presumably were built with that same model — confirm before
# switching to model_name.
embeddings = _session_cached(
    "embeddings",
    lambda: HuggingFaceBgeEmbeddings(
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    ),
)
llm = _session_cached(
    "llm", lambda: ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
)
llm_i = _session_cached(
    "llm_i", lambda: OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
)
llm_gpt4 = _session_cached(
    "llm_gpt4", lambda: ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
)


def _load_template(key, path):
    """Read the prompt template at *path* once per session and cache it under *key*.

    The file is decoded as UTF-8 explicitly. The platform default encoding
    (e.g. cp1252 on Windows) can garble or reject non-ASCII characters.
    """
    if key not in st.session_state:
        with open(path, "r", encoding="utf-8") as file:
            st.session_state[key] = file.read()
    return st.session_state[key]


# ## ------------------------------------------------------------------------
# ## Patient part
TEMPLATE = _load_template("TEMPLATE", 'templates/patient.txt')
prompt = PromptTemplate(
    input_variables=["question", "context"],
    template=st.session_state.TEMPLATE,
)


def format_docs(docs):
    """Join the ``page_content`` of retrieved documents with a dashed separator line."""
    return "\n--------------------\n".join(doc.page_content for doc in docs)


# Maps LangChain message types / chat roles to the speaker labels used in
# the transcript handed to the grader.
sp_mapper = {
    "human": "student",
    "ai": "patient",
    "user": "student",
    "assistant": "patient",
}

# ## ------------------------------------------------------------------------
# ## Grader part
TEMPLATE2 = _load_template("TEMPLATE2", 'templates/grader.txt')
prompt2 = PromptTemplate(
    input_variables=["question", "context", "history"],
    template=st.session_state.TEMPLATE2,
)


def get_patient_chat_history(_):
    """Return the patient transcript kept in ``st.session_state["patient_chat_history"]``.

    The ignored argument lets this be used as a runnable in a chain's input
    map. Returns None if no transcript has been stored yet.
    """
    return st.session_state.get("patient_chat_history")
------------------------------------------------------------------------------------------------ # ## Streamlit now # # from dotenv import load_dotenv # # import os # # load_dotenv() # # key = os.environ.get("OPENAI_API_KEY") # # client = OpenAI(api_key=key) # if st.button("Clear History and Memory", type="primary"): # st.session_state.messages_1 = [] # st.session_state.messages_2 = [] # st.session_state.memory = ConversationBufferWindowMemory(llm=llm, memory_key="chat_history", input_key="question" ) # memory = st.session_state.memory # ## Testing HTML # # html_string = """ # # # # # # # # """ # # components.html(html_string, # # width=1280, # # height=640) # st.write("Timer has been removed, switch with this button") # if st.button(f"Switch to {'PATIENT' if st.session_state.active_chat==2 else 'GRADER'}"+".... Buggy button, please double click"): # st.session_state.active_chat = 3 - st.session_state.active_chat # # st.write("Currently in " + ('PATIENT' if st.session_state.active_chat==2 else 'GRADER')) # # Create two columns for the two chat interfaces # col1, col2 = st.columns(2) # # First chat interface # with col1: # st.subheader("Student LLM") # for message in st.session_state.messages_1: # with st.chat_message(message["role"]): # st.markdown(message["content"]) # # Second chat interface # with col2: # # st.write("pls dun spam this, its tons of tokens cos chat history") # st.subheader("Grader LLM") # st.write("grader takes a while to load... 
please be patient") # for message in st.session_state.messages_2: # with st.chat_message(message["role"]): # st.markdown(message["content"]) # # Timer and Input # # time_left = None # # if st.session_state.start_time: # # time_elapsed = datetime.datetime.now() - st.session_state.start_time # # time_left = datetime.timedelta(minutes=10) - time_elapsed # # st.write(f"Time left: {time_left}") # # if time_left is None or time_left > datetime.timedelta(0): # # # Chat 1 is active # # prompt = st.text_input("Enter your message for Chat 1:") # # active_chat = 1 # # messages = st.session_state.messages_1 # # elif time_left and time_left <= datetime.timedelta(0): # # # Chat 2 is active # # prompt = st.text_input("Enter your message for Chat 2:") # # active_chat = 2 # # messages = st.session_state.messages_2 # if st.session_state.active_chat==1: # text_prompt = st.text_input("Enter your message for PATIENT") # messages = st.session_state.messages_1 # else: # text_prompt = st.text_input("Enter your message for GRADER") # messages = st.session_state.messages_2 # from langchain.callbacks.manager import tracing_v2_enabled # from uuid import uuid4 # import os # if text_prompt: # messages.append({"role": "user", "content": text_prompt}) # with (col1 if st.session_state.active_chat == 1 else col2): # with st.chat_message("user"): # st.markdown(text_prompt) # with (col1 if st.session_state.active_chat == 1 else col2): # with st.chat_message("assistant"): # message_placeholder = st.empty() # if True: ## with tracing_v2_enabled(project_name = "streamlit"): # if st.session_state.active_chat==1: # full_response = chain.invoke(text_prompt).get("text") # else: # full_response = chain2.invoke(text_prompt).get("text").get("text") # message_placeholder.markdown(full_response) # messages.append({"role": "assistant", "content": full_response}) # st.write('fake history is:') # st.write(y("")) # st.write('done') ## ==================== if not st.session_state.get("scenario_list", None): 
st.session_state.scenario_list = indexes def init_patient_llm(): if "messages_1" not in st.session_state: st.session_state.messages_1 = [] ## messages 2? index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/QA" if "store" not in st.session_state: st.session_state.store = db.get_store(index_name, embeddings=embeddings) if "retriever" not in st.session_state: st.session_state.retriever = st.session_state.store.as_retriever(search_type="similarity", search_kwargs={"k":2}) if "memory" not in st.session_state: st.session_state.memory = ConversationBufferWindowMemory( llm=llm, memory_key="chat_history", input_key="question", k=5, human_prefix="student", ai_prefix="patient",) if ("chain" not in st.session_state or st.session_state.TEMPLATE != TEMPLATE): st.session_state.chain = ( { "context": st.session_state.retriever | format_docs, "question": RunnablePassthrough() } | LLMChain(llm=llm, prompt=prompt, memory=st.session_state.memory, verbose=False) ) def init_grader_llm(): ## Grader index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/Rubric" ## Reset time st.session_state.start_time = False if "store2" not in st.session_state: st.session_state.store2 = db.get_store(index_name, embeddings=embeddings) if "retriever2" not in st.session_state: st.session_state.retriever2 = st.session_state.store2.as_retriever(search_type="similarity", search_kwargs={"k":2}) ## Re-init history st.session_state["patient_chat_history"] = "History\n" + '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in st.session_state.memory.chat_memory.messages]) if ("chain2" not in st.session_state or st.session_state.TEMPLATE2 != TEMPLATE2): st.session_state.chain2 = ( { "context": st.session_state.retriever2 | format_docs, "history": (get_patient_chat_history), "question": RunnablePassthrough(), } | # LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #| LLMChain(llm=llm_gpt4, prompt=prompt2, verbose=False ) 
#| | { "json": itemgetter("text"), "text": ( LLMChain( llm=llm, prompt=PromptTemplate( input_variables=["text"], template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"), verbose=False) ) } ) login_info = { "bob":"builder", "student1": "password", "admin":"admin" } def set_username(x): st.session_state.username = x def validate_username(username, password): if login_info.get(username) == password: set_username(username) else: st.warning("Wrong username or password") return None if not st.session_state.get("username"): ## ask to login st.title("Login") username = st.text_input("Username:") password = st.text_input("Password:", type="password") login_button = st.button("Login", on_click=validate_username, args=[username, password]) else: if True: ## Says hello and logout col_1, col_2 = st.columns([1,3]) col_2.title(f"Hello there, {st.session_state.username}") # Display logout button if col_1.button('Logout'): # Remove username from session state del st.session_state.username # Rerun the app to go back to the login view st.rerun() scenario_tab, dashboard_tab = st.tabs(["Training", "Dashboard"]) # st.header("head") # st.markdown("## markdown") # st.caption("caption") # st.divider() # import pandas as pd # import numpy as np # map_data = pd.DataFrame( # np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4], # columns=['lat', 'lon']) # st.map(map_data) class ScenarioTabIndex: SELECT_SCENARIO = 0 PATIENT_LLM = 1 GRADER_LLM = 2 def set_scenario_tab_index(x): st.session_state.scenario_tab_index=x return None def select_scenario_and_change_tab(_): set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM) def go_to_patient_llm(): selected_scenario = st.session_state.get('selected_scenario') if selected_scenario is None or selected_scenario < 0: st.warning("Please select a scenario!") else: ## TODO: Clear state for time, LLM, Index, etc states = ["store", "store2","retriever","retriever2","chain","chain2"] for 
state_to_del in states: if state_to_del in st.session_state: del st.session_state[state_to_del] init_patient_llm() set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM) if not st.session_state.get("scenario_tab_index"): set_scenario_tab_index(ScenarioTabIndex.SELECT_SCENARIO) with scenario_tab: ## Check in select scenario if st.session_state.scenario_tab_index == ScenarioTabIndex.SELECT_SCENARIO: def change_scenario(scenario_index): st.session_state.selected_scenario = scenario_index if st.session_state.get("selected_scenario", None) is None: st.session_state.selected_scenario = -1 total_cols = 3 rows = list() # for _ in range(0, number_of_indexes, total_cols): # rows.extend(st.columns(total_cols)) st.header(f"Selected Scenario: {st.session_state.scenario_list[st.session_state.selected_scenario] if st.session_state.selected_scenario>=0 else 'None'}") for i, scenario in enumerate(st.session_state.scenario_list): if i % total_cols == 0: rows.extend(st.columns(total_cols)) curr_col = rows[(-total_cols + i % total_cols)] tile = curr_col.container(height=120) ## TODO: Implement highlight box if index is selected # if st.session_state.selected_scenario == i: # tile.markdown("", unsafe_allow_html=True) tile.write(":balloon:") tile.button(label=scenario, on_click=change_scenario, args=[i]) select_scenario_btn = st.button("Select Scenario", on_click=go_to_patient_llm, args=[]) elif st.session_state.scenario_tab_index == ScenarioTabIndex.PATIENT_LLM: st.header("Patient info") st.write("Pull the info here!!!") col1, col2, col3 = st.columns([1,3,1]) with col1: back_to_scenario_btn = st.button("Back to selection", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO]) with col3: start_timer_button = st.button("START") with col2: TIME_LIMIT = 60*10 ## to change to 10 minutes time.sleep(1) if start_timer_button: st.session_state.start_time = datetime.datetime.now() # st.session_state.time = -1 if not st.session_state.get('time') else st.session_state.get('time') 
st.session_state.start_time = False if not st.session_state.get('start_time') else st.session_state.start_time from streamlit.components.v1 import html html(f"""