import json
import os
from typing import Generator

import streamlit as st
from dotenv import find_dotenv, load_dotenv
from groq import Groq

_ = load_dotenv(find_dotenv())
st.set_page_config(page_icon="⚡", layout="wide", page_title="...")

def icon(emoji: str):
    """Shows an emoji as a Notion-style page icon."""
    st.write(
        f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
        unsafe_allow_html=True,
    )


icon("⚡")

st.subheader("Chatbot", divider="rainbow", anchor=False)

client = Groq(
    api_key=os.environ['GROQ_API_KEY'],
)
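
# Assumes GROQ_API_KEY is set in the shell or in a .env file found by
# load_dotenv above; os.environ[...] raises a KeyError if it is missing.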

# Initialize chat history and selected model
if "messages" not in st.session_state:
    st.session_state.messages = []
if "selected_model" not in st.session_state:
    st.session_state.selected_model = None
if "prompts" not in st.session_state:
    st.session_state.prompts = {}
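
# Prompts saved below live only in st.session_state, so they vanish when the
# browser session ends. A minimal sketch of disk persistence (the
# "saved_prompts.json" path and this helper are assumptions, not wired in):
def persist_prompts(path: str = "saved_prompts.json") -> None:
    """Write the prompt store to a JSON file so it survives restarts."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(st.session_state.prompts, f, ensure_ascii=False, indent=2)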

# Define model details
models = {
    "mixtral-8x7b-32768": {
        "name": "Mixtral-8x7b-Instruct-v0.1",
        "tokens": 32768,
        "developer": "Mistral",
    },
    "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
    "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
    "llama3-70b-8192": {"name": "LLaMA3-70b-8192", "tokens": 8192, "developer": "Meta"},
    "llama3-8b-8192": {"name": "LLaMA3-8b-8192", "tokens": 8192, "developer": "Meta"},
}
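
# The "tokens" field is each model's context window; it is used below as the
# upper bound of the max_tokens slider.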

# Layout for model selection and max_tokens slider
col1, col2 = st.columns(2)

with col1:
    model_option = st.selectbox(
        "Choose a model:",
        options=list(models.keys()),
        format_func=lambda x: models[x]["name"],
        index=0,  # Default to the first model in the list
    )

# Detect model change and clear chat history if model has changed
if st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.selected_model = model_option

max_tokens_range = models[model_option]["tokens"]

with col2:
    # Adjust max_tokens slider dynamically based on the selected model
    max_tokens = st.slider(
        "Max Tokens:",
        min_value=512,  # Minimum value to allow some flexibility
        max_value=max_tokens_range,
        # Default to the model's maximum (32768 is the largest window in the table)
        value=min(32768, max_tokens_range),
        step=512,
        help=f"Adjust the maximum number of tokens the model may generate. Max for selected model: {max_tokens_range}",
    )

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    avatar = "" if message["role"] == "assistant" else ""
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
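
# Streamlit reruns this entire script on every interaction, so the chat
# history above must be re-rendered from session_state each time.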

def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
    """Yield chat response content from the Groq API response."""
    for chunk in chat_completion:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
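
# Note: the truthiness check above skips chunks whose delta carries no content
# (typically the role-only first chunk and the final stop chunk), so only
# real text reaches st.write_stream.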

with st.expander("Save Prompt"):
    save_prompt_button = st.button("Save Prompt")
    if save_prompt_button:
        prompt_name = st.text_input("Enter a name for the prompt:")
        if prompt_name:
            st.session_state.prompts[prompt_name] = st.session_state.messages[-1]["content"]
            st.success(f"Prompt '{prompt_name}' saved!")

with st.expander("Load Prompt"):
    load_prompt_option = st.selectbox(
        "Select a saved prompt:",
        options=list(st.session_state.prompts.keys()),
    )
    if load_prompt_option:
        st.session_state.messages[-1]["content"] = st.session_state.prompts[load_prompt_option]
        st.write("Prompt loaded!")

if load_prompt_option:
    prompt = st.session_state.prompts[load_prompt_option]
else:
    prompt = st.chat_input("Enter your prompt here...")

if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user", avatar=""):  
        st.markdown(prompt)

    # Fetch response from Groq API
    full_response = None  # stays None if the request below fails
    try:
        chat_completion = client.chat.completions.create(
            model=model_option,
            messages=[
                {"role": m["role"], "content": m["content"]}
                for m in st.session_state.messages
            ],
            max_tokens=max_tokens,
            stream=True,
        )

        # Use the generator function with st.write_stream
        with st.chat_message("assistant", avatar=""):
            chat_responses_generator = generate_chat_responses(chat_completion)
            full_response = st.write_stream(chat_responses_generator)
    except Exception as e:
        st.error(e, icon="🚨")

    # Append the full response to session_state.messages
    if isinstance(full_response, str):
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )
    elif full_response is not None:
        # st.write_stream can return a list of the objects it wrote; join
        # them into a single string before storing in the history
        combined_response = "\n".join(str(item) for item in full_response)
        st.session_state.messages.append(
            {"role": "assistant", "content": combined_response}
        )
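
# To run locally (assuming this file is saved as app.py):
#   pip install streamlit groq python-dotenv
#   streamlit run app.py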