File size: 5,947 Bytes
846e270
 
9cf8e68
 
 
d4904e9
 
9cf8e68
846e270
 
9cf8e68
 
 
 
 
 
 
 
d4904e9
9cf8e68
d4904e9
9cf8e68
d4904e9
 
 
9cf8e68
 
d4904e9
 
 
 
 
9e2c057
d4904e9
9cf8e68
 
 
 
 
 
 
 
 
d4904e9
9cf8e68
 
d4904e9
 
 
 
 
 
9cf8e68
 
 
 
 
 
 
 
 
d4904e9
9cf8e68
 
 
 
 
 
 
846e270
9cf8e68
 
 
 
 
 
 
 
d4904e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122d511
 
 
 
 
 
 
 
9cf8e68
122d511
 
d4904e9
122d511
 
 
 
d4904e9
122d511
d4904e9
122d511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4904e9
122d511
 
 
 
 
 
d4904e9
 
 
 
 
 
813e436
2c1a7b5
9cf8e68
 
846e270
9cf8e68
 
122d511
 
 
9cf8e68
d4904e9
 
 
 
 
 
 
 
 
 
122d511
 
9cf8e68
 
2c1a7b5
d4904e9
 
 
 
122d511
 
d4904e9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
from dotenv import find_dotenv, load_dotenv
import streamlit as st
from typing import Generator
from groq import Groq
import datetime
import json

# Load environment variables (e.g. GROQ_API_KEY) from the nearest .env file, if one exists.
_ = load_dotenv(find_dotenv())
# Page icon was previously mojibake ("💬" mis-decoded); use the real emoji.
st.set_page_config(page_icon="💬", layout="wide", page_title="Groq Chat Bot...")

def icon(emoji: str):
    """Render *emoji* as a large, Notion-style page icon.

    The emoji is wrapped in an inline ``<span>`` at 78px with a tight line
    height; ``unsafe_allow_html=True`` is passed so Streamlit renders the
    markup as HTML.
    """
    markup = f'<span style="font-size: 78px; line-height: 1">{emoji}</span>'
    st.write(markup, unsafe_allow_html=True)

icon("📣")

st.subheader("Groq Chat Streamlit App", divider="rainbow", anchor=False)

# Show a readable in-app error instead of a KeyError traceback when the key
# is missing from both the environment and the .env file.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    st.error(
        "GROQ_API_KEY is not set. Add it to your environment or a .env file.",
        icon="🚨",
    )
    st.stop()

client = Groq(
    api_key=groq_api_key,
)

# Model id -> display name, maximum token count used for the slider, developer.
models = {
    "mixtral-8x7b-32768": {
        "name": "Mixtral-8x7b-Instruct-v0.1",
        "tokens": 32768,
        "developer": "Mistral",
    },
    "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
    "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
}

col1, col2 = st.columns(2)

with col1:
    model_option = st.selectbox(
        "Choose a model:",
        options=list(models.keys()),
        format_func=lambda x: models[x]["name"],
        index=0,
    )

if "messages" not in st.session_state:
    st.session_state.messages = []

if "selected_model" not in st.session_state:
    st.session_state.selected_model = None

# Switching models starts a fresh conversation.
if st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.selected_model = model_option

max_tokens_range = models[model_option]["tokens"]

with col2:
    max_tokens = st.slider(
        "Max Tokens:",
        min_value=512,
        max_value=max_tokens_range,
        value=min(32768, max_tokens_range),
        step=512,
        help=f"Adjust the maximum number of tokens (words) for the model's response. Max for selected model: {max_tokens_range}",
    )

# Redraw the stored conversation. Avatars were previously mojibake strings,
# which Streamlit cannot use as emoji avatars.
for message in st.session_state.messages:
    avatar = "🤖" if message["role"] == "assistant" else "🕺"
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
    """Yield text from a streamed Groq chat completion.

    Each streamed chunk carries its incremental payload in
    ``choices[0].delta``. Non-empty ``content`` fragments are yielded as-is.
    If a chunk's delta includes a ``time_date`` tool call, the JSON string
    returned by ``get_tool_owner_info`` is yielded.

    Args:
        chat_completion: Iterable of streamed chunks, as returned by
            ``client.chat.completions.create(..., stream=True)``.

    Yields:
        Text fragments of the assistant's reply.
    """
    for chunk in chat_completion:
        # Defensively skip any chunk that carries no choices.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta.content:
            yield delta.content
        # Streamed chunks have no ``message`` attribute (accessing it raised
        # AttributeError); tool calls, when present, arrive on the delta.
        for tool_call in getattr(delta, "tool_calls", None) or []:
            function = getattr(tool_call, "function", None)
            if function is not None and function.name == "time_date":
                yield get_tool_owner_info()

def run_conversation(user_prompt):
    """Answer *user_prompt* in a non-streaming exchange that may use a tool.

    The model is offered a ``time_date`` tool. If it calls the tool, the
    tool's output is appended to the conversation as a ``tool`` message, and a
    second completion is requested so the model can use that output.

    Args:
        user_prompt: The user's message text.

    Returns:
        The assistant's reply text, or ``None`` if any step raised (the error
        is shown in the app via ``st.error``).
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant named ChattyBot."
        },
        {
            "role": "user",
            "content": user_prompt,
        }
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "time_date",
                "description": "The tool will return information about the time and date to the AI.",
                # Empty JSON-Schema object: the tool takes no arguments.
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]
    available_functions = {
        "time_date": get_tool_owner_info,
    }
    try:
        response = client.chat.completions.create(
            model=model_option,
            messages=messages,
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )

        # Non-streaming responses carry the reply in ``message``; ``delta``
        # only exists on streamed chunks (using it raised AttributeError).
        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls
        if not tool_calls:
            return response_message.content

        # The assistant turn that requested the tools must precede the tool
        # results in the follow-up request.
        messages.append(response_message)

        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_to_call = available_functions.get(function_name)
            if function_to_call is None:
                function_response = json.dumps(
                    {"error": f"Unknown tool: {function_name}"}
                )
            else:
                # Treat a missing/empty argument string as "no arguments".
                function_args = json.loads(tool_call.function.arguments or "{}")
                function_response = function_to_call(**function_args)
            messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": function_response,
                }
            )

        second_response = client.chat.completions.create(
            model=model_option,
            messages=messages,
        )
        return second_response.choices[0].message.content
    except Exception as e:
        st.error(e, icon="🚨")
        return None

def get_tool_owner_info():
    """Return the current local date and time as a JSON string.

    The payload is a single-key object of the form
    ``{"date_time": "YYYY-MM-DD HH:MM:SS"}``. It is used as the ``content``
    of a ``tool`` message.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return json.dumps({"date_time": timestamp})

# Handle a newly submitted prompt: record it, stream the model's reply, and
# store the reply in the session history.
if prompt := st.chat_input("Enter your prompt here..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user", avatar="🕺"):
        st.markdown(prompt)

    full_response = None

    try:
        chat_completion = client.chat.completions.create(
            model=model_option,
            messages=[
                {"role": m["role"], "content": m["content"]}
                for m in st.session_state.messages
            ],
            max_tokens=max_tokens,
            stream=True,
        )

        # Render the streamed reply in an assistant bubble, matching how the
        # history loop redraws it on later reruns.
        with st.chat_message("assistant", avatar="🤖"):
            full_response = st.write_stream(generate_chat_responses(chat_completion))
    except Exception as e:
        st.error(e, icon="🚨")

    # st.write_stream returns a str for pure-text output, otherwise a list of
    # the written items. The old fallback re-iterated the already-consumed
    # generator, which always produced an empty assistant message.
    if isinstance(full_response, list):
        full_response = "\n".join(str(item) for item in full_response)

    # Skip empty/failed replies so no blank assistant turn enters the history.
    if full_response:
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )