import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
from langchain_core.tracers.context import collect_runs
from langsmith import Client
from streamlit_feedback import streamlit_feedback

from rag_chain.chain import get_rag_chain

# Langsmith client for the feedback system
client = Client()

# Streamlit page configuration
st.set_page_config(page_title="Tall Tree Health",
                   page_icon="💬",
                   layout="centered",
                   initial_sidebar_state="expanded")

# Streamlit CSS configuration
# Explicit encoding: without it, open() uses the platform default
# (e.g. cp1252 on Windows), which can fail on non-ASCII characters.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

# Error message template shown whenever chain setup or streaming fails.
base_error_message = (
    "Something went wrong while processing your request. "
    "Please refresh the page or try again later.\n\n"
    "If the error persists, please contact us at "
    "[Tall Tree Health](https://www.talltreehealth.ca/contact-us)."
)

# Get chain and memory


@st.cache_resource(ttl="5d", show_spinner=False)
def get_chain_and_memory():
    """Build and cache the RAG chain and its conversation memory.

    Returns:
        The ``(chain, memory)`` pair produced by ``get_rag_chain``. Cached
        by Streamlit for 5 days so the chain is constructed once, not on
        every script rerun.

    On any construction failure the user sees ``base_error_message`` and
    the script run is halted via ``st.stop()``.
    """
    try:
        # gpt-4 points to gpt-4-0613
        # gpt-4-turbo-preview points to gpt-4-0125-preview
        # Fine-tuned: ft:gpt-3.5-turbo-1106:tall-tree::8mAkOSED
        # gpt-4-1106-preview
        return get_rag_chain(model_name="gpt-4-turbo", temperature=0.2)

    except Exception:  # broad by design: any setup failure halts the app
        st.warning(base_error_message, icon="🙁")
        st.stop()


# Unpack the cached chain and its conversation memory.
chain, memory = get_chain_and_memory()

# Set up session state and clean memory (important to clean the memory at the end of each session)
# "history" is absent exactly once per browser session, so this branch also
# clears the (cached, shared) memory so a new session starts fresh.
if "history" not in st.session_state:
    st.session_state["history"] = []  # list of (user_prompt, ai_response) tuples
    memory.clear()

# Messages rendered in the chat transcript (ChatMessage objects).
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Select locations element into a container
with st.container(border=False):
    # Welcome copy shown above the location picker.
    welcome_text = (
        "\n\nHello there! 👋 Need help finding the right service or practitioner? Let our AI-powered assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )
    st.markdown(welcome_text)

    clinic_locations = [
        "Cordova Bay - Victoria",
        "James Bay - Victoria",
        "Commercial Drive - Vancouver",
    ]
    # index=None means no default selection, so the chat input below stays
    # hidden until the user picks a clinic.
    location = st.radio(
        "**Our Locations**:",
        clinic_locations,
        horizontal=False,
        index=None,
    )
    st.write("\n")

# Get user input only if a location is selected
prompt = ""
if location:
    user_input = st.chat_input("Enter your message...")
    if user_input:
        # Record the raw user turn for transcript rendering, then build the
        # chain prompt with the selected clinic appended.
        user_msg = ChatMessage(role="user", content=user_input)
        st.session_state["messages"].append(user_msg)
        prompt = f"{user_input}\nLocation: {location}"


# Display previous messages

user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
# Replay the stored transcript so the conversation survives Streamlit reruns.
for msg in st.session_state["messages"]:
    is_user = msg.role == "user"
    with st.chat_message(msg.role, avatar=user_avatar if is_user else ai_avatar):
        st.markdown(msg.content)

# Chat interface
if prompt:

    # Replay prior turns into the chain memory so the model sees the full
    # conversation context.
    # NOTE(review): the current user turn is not added to memory here —
    # presumably the chain handles it internally; confirm against
    # get_rag_chain before changing.
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # Render the assistant's response as it streams.
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()

        try:
            partial_message = ""
            # Collect runs for feedback using Langsmith
            with st.spinner(" "), collect_runs() as cb:
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    # Trailing "|" acts as a typing cursor while streaming.
                    message_placeholder.markdown(partial_message + "|")
                # Remember the traced run id so the feedback widget can
                # attach ratings to this specific generation.
                st.session_state.run_id = cb.traced_runs[0].id
            message_placeholder.markdown(partial_message)
        except Exception:
            # Single broad handler at this UI boundary: the previous separate
            # openai.BadRequestError handler ran an identical body, so it is
            # folded in here.
            st.warning(base_error_message, icon="🙁")
            st.stop()

        # Add the full response to the history
        st.session_state["history"].append((prompt, partial_message))

        # Add AI message to memory after the response is generated
        memory.chat_memory.add_ai_message(AIMessage(content=partial_message))

        # Add the full response to the message history
        st.session_state["messages"].append(ChatMessage(
            role="assistant", content=partial_message))


# Feedback system using streamlit feedback and Langsmith

# Widget type for the feedback component; "thumbs" renders 👍/👎.
feedback_option = "thumbs"

if st.session_state.get("run_id"):
    run_id = st.session_state.run_id
    feedback = streamlit_feedback(
        feedback_type=feedback_option,
        optional_text_label="[Optional] Please provide an explanation",
        key=f"feedback_{run_id}",
    )

    # Emoji -> numeric score tables per widget type (only "thumbs" is used
    # today; "faces" is kept for a possible future switch).
    score_mappings = {
        "thumbs": {"👍": 1, "👎": 0},
        "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
    }
    scores = score_mappings[feedback_option]

    if feedback:
        # Translate the emoji the user clicked into a numeric score.
        score = scores.get(feedback["score"])

        if score is None:
            st.warning("Invalid feedback score.")
        else:
            # Combine the widget type with the raw emoji, e.g. "thumbs 👍".
            feedback_type_str = f"{feedback_option} {feedback['score']}"

            # Record the rating in Langsmith against the traced run.
            feedback_record = client.create_feedback(
                run_id,
                feedback_type_str,
                score=score,
                comment=feedback.get("text"),
            )
            st.session_state.feedback = {
                "feedback_id": str(feedback_record.id),
                "score": score,
            }