# NOTE: removed non-code residue ("Spaces: / Sleeping / Sleeping") captured
# from the hosting page during extraction — it was never part of the program.
# Imports, grouped stdlib / third-party / local (PEP 8).
import os

import streamlit as st
from dotenv import load_dotenv
from langchain import memory as lc_memory
from langchain_core.tracers.context import collect_runs
from langsmith import Client
from qdrant_client import QdrantClient
from streamlit_feedback import streamlit_feedback

from utils import get_expression_chain, retriever, get_embeddings, create_qdrant_collection

# Session-state defaults.  Streamlit re-executes this script on every
# interaction; st.session_state is the only per-user storage that survives a
# rerun, so each key is seeded exactly once.
if "access_granted" not in st.session_state:
    st.session_state.access_granted = False  # has the user passed the profile gate?
if "profile" not in st.session_state:
    st.session_state.profile = None  # role chosen on the profile form
if "name" not in st.session_state:
    st.session_state.name = None  # display name entered on the profile form
if not st.session_state.access_granted:
    # Profile gate: collect a name and a role before unlocking the assistant.
    st.title("User Profile")
    name = st.text_input("Name")
    profile_selector = st.selectbox(
        "Profile", options=["Student", "Professor", "Administrator", "Other"]
    )
    # "Other" opens a free-text role field; any named option is used directly.
    if profile_selector == "Other":
        profile = st.text_input("What is your role?")
    else:
        profile = profile_selector
    # Submit stays disabled until both fields are non-empty (replaces the
    # original if/else that built a separate `d` flag for the same test).
    submission = st.button("Submit", disabled=not (profile and name))
    if submission:
        st.session_state.profile = profile
        st.session_state.name = name
        st.session_state.access_granted = True  # Grant access to main app
        st.rerun()  # Reload so the main-app branch renders on the next run
else: | |
load_dotenv() | |
profile = st.session_state.profile | |
client = Client() | |
qdrant_api=os.getenv("QDRANT_API_KEY") | |
qdrant_url=os.getenv("QDRANT_URL") | |
qdrant_client = QdrantClient(qdrant_url ,api_key=qdrant_api) | |
st.set_page_config(page_title = "SUP'ASSISTANT") | |
st.subheader(f"Hello {st.session_state.name}! How can I help you today!") | |
memory = lc_memory.ConversationBufferMemory( | |
chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"), | |
return_messages=True, | |
memory_key="chat_history", | |
) | |
st.sidebar.markdown("## Feedback Scale") | |
feedback_option = ( | |
"thumbs" if st.sidebar.toggle(label="`Faces` β `Thumbs`", value=False) else "faces" | |
) | |
with st.sidebar: | |
temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001) | |
n_docs = st.number_input("**Number of retireved documents**", min_value=0, max_value=10, value=5, step=1) | |
if st.sidebar.button("Clear message history"): | |
print("Clearing message history") | |
memory.clear() | |
retriever = retriever(n_docs=n_docs) | |
# Create Chain | |
chain = get_expression_chain(retriever,"llama-3.3-70b-versatile",temp) | |
for msg in st.session_state.langchain_messages: | |
avatar = "π¦" if msg.type == "ai" else None | |
with st.chat_message(msg.type, avatar=avatar): | |
st.markdown(msg.content) | |
prompt = st.chat_input(placeholder="What do you need to know about SUP'COM ?") | |
if prompt : | |
with st.chat_message("user"): | |
st.write(prompt) | |
with st.chat_message("assistant", avatar="π¦"): | |
message_placeholder = st.empty() | |
full_response = "" | |
# Define the basic input structure for the chains | |
input_dict = {"input": prompt.lower()} | |
used_docs = retriever.get_relevant_documents(prompt.lower()) | |
with collect_runs() as cb: | |
for chunk in chain.stream(input_dict, config={"tags": ["SUP'ASSISTANT"]}): | |
full_response += chunk.content | |
message_placeholder.markdown(full_response + "β") | |
memory.save_context(input_dict, {"output": full_response}) | |
st.session_state.run_id = cb.traced_runs[0].id | |
message_placeholder.markdown(full_response) | |
if used_docs : | |
docs_content = "\n\n".join( | |
[ | |
f"Doc {i+1}:\n" | |
f"Source: {doc.metadata['source']}\n" | |
f"Title: {doc.metadata['title']}\n" | |
f"Content: {doc.page_content}\n" | |
for i, doc in enumerate(used_docs) | |
] | |
) | |
with st.sidebar: | |
st.download_button( | |
label="Consulted Documents", | |
data=docs_content, | |
file_name="Consulted_documents.txt", | |
mime="text/plain", | |
) | |
with st.spinner("Just a sec! Dont enter prompts while loading pelase!"): | |
run_id = st.session_state.run_id | |
question_embedding = get_embeddings(prompt) | |
answer_embedding = get_embeddings(full_response) | |
# Add question and answer to Qdrant | |
qdrant_client.upload_collection( | |
collection_name="chat-history", | |
payload=[ | |
{"text": prompt, "type": "question", "question_ID": run_id}, | |
{"text": full_response, "type": "answer", "question_ID": run_id, "used_docs":used_docs} | |
], | |
vectors=[ | |
question_embedding, | |
answer_embedding, | |
], | |
parallel=4, | |
max_retries=3, | |
) | |
if st.session_state.get("run_id"): | |
run_id = st.session_state.run_id | |
feedback = streamlit_feedback( | |
feedback_type=feedback_option, | |
optional_text_label="[Optional] Please provide an explanation", | |
key=f"feedback_{run_id}", | |
) | |
# Define score mappings for both "thumbs" and "faces" feedback systems | |
score_mappings = { | |
"thumbs": {"π": 1, "π": 0}, | |
"faces": {"π": 1, "π": 0.75, "π": 0.5, "π": 0.25, "π": 0}, | |
} | |
# Get the score mapping based on the selected feedback option | |
scores = score_mappings[feedback_option] | |
if feedback: | |
# Get the score from the selected feedback option's score mapping | |
score = scores.get(feedback["score"]) | |
if score is not None: | |
# Formulate feedback type string incorporating the feedback option | |
# and score value | |
feedback_type_str = f"{feedback_option} {feedback['score']}" | |
# Record the feedback with the formulated feedback type string | |
# and optional comment | |
with st.spinner("Just a sec! Dont enter prompts while loading pelase!"): | |
feedback_record = client.create_feedback( | |
run_id, | |
feedback_type_str, | |
score=score, | |
comment=feedback.get("text"), | |
source_info={"profile":profile} | |
) | |
st.session_state.feedback = { | |
"feedback_id": str(feedback_record.id), | |
"score": score, | |
} | |
else: | |
st.warning("Invalid feedback score.") | |
with st.spinner("Just a sec! Dont enter prompts while loading pelase!"): | |
if feedback.get("text"): | |
comment = feedback.get("text") | |
feedback_embedding = get_embeddings(comment) | |
else: | |
comment = "no comment" | |
feedback_embedding = get_embeddings(comment) | |
qdrant_client.upload_collection( | |
collection_name="chat-history", | |
payload=[ | |
{"text": comment, | |
"Score:":score, | |
"type": "feedback", | |
"question_ID": run_id, | |
"User_profile":profile} | |
], | |
vectors=[ | |
feedback_embedding | |
], | |
parallel=4, | |
max_retries=3, | |
) |