File size: 6,173 Bytes
50af927
378ea6c
 
 
50af927
378ea6c
 
 
 
 
50af927
378ea6c
 
50af927
378ea6c
 
 
 
50af927
378ea6c
 
 
 
 
 
50af927
378ea6c
50af927
5029883
 
 
 
 
50af927
378ea6c
 
 
 
 
 
 
 
50af927
378ea6c
 
 
 
50af927
 
378ea6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50af927
378ea6c
 
 
 
 
 
 
50af927
378ea6c
 
 
 
 
 
50af927
378ea6c
fa4be47
378ea6c
845f4fe
 
 
 
378ea6c
 
 
 
 
 
50af927
378ea6c
 
 
 
 
 
 
 
 
 
 
50af927
378ea6c
 
 
 
 
845f4fe
378ea6c
50af927
378ea6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50af927
 
378ea6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50af927
378ea6c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import gradio as gr
import os
import re
from groq import Groq

def validate_api_key(api_key):
    """Check that ``api_key`` looks like a Groq API key (format only, no network).

    Args:
        api_key: Key as entered by the user. ``None`` is treated as empty, and
            surrounding whitespace is ignored so a pasted key with a stray
            space or newline is judged on its actual content.

    Returns:
        A ``(is_valid, message)`` tuple. ``is_valid`` is ``False`` when the key
        is empty or does not start with ``gsk_``; ``message`` explains why.
        A ``True`` result does not prove the key is active.
    """
    key = (api_key or "").strip()

    if not key:
        return False, "API key cannot be empty"

    # Groq keys conventionally start with 'gsk_'; this is a cheap sanity check.
    if not key.startswith("gsk_"):
        return False, "Invalid API key format. Groq API keys typically start with 'gsk_'"

    return True, "API key looks valid"

def test_api_connection(api_key, model="deepseek-r1-distill-llama-70b"):
    """Probe the Groq API with a tiny chat completion to check that ``api_key`` works.

    Args:
        api_key: Groq API key as entered by the user. ``None`` is treated as
            empty and surrounding whitespace is stripped before use.
        model: Model used for the probe request. Defaults to the model this
            check has always used; callers may pass the model they actually
            intend to chat with, so an unavailable probe model cannot make a
            working key look broken.

    Returns:
        ``(True, "API connection successful")`` on success, otherwise
        ``(False, <reason>)``. Never raises: every exception is converted into
        a status message suitable for display in the UI.
    """
    try:
        client = Groq(api_key=(api_key or "").strip())
        # A minimal request (max_tokens=5) keeps the probe cheap.
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "test"}],
            max_tokens=5,
        )
    except Exception as e:  # UI boundary: report the failure instead of raising
        detail = str(e)
        lowered = detail.lower()
        # Classify by message text so this does not depend on SDK exception classes.
        if "authentication" in lowered or "api key" in lowered:
            return False, "Authentication failed: Invalid API key"
        return False, f"Error connecting to Groq API: {detail}"

    return True, "API connection successful"

def chat_with_groq(api_key, model, user_message, temperature, max_tokens, top_p, chat_history):
    """
    Interact with the Groq API to get a response.
    """
    # Validate API key
    is_valid, message = validate_api_key(api_key)
    if not is_valid:
        return chat_history + [[user_message, f"Error: {message}"]]
    
    # Test API connection
    connection_valid, connection_message = test_api_connection(api_key)
    if not connection_valid:
        return chat_history + [[user_message, f"Error: {connection_message}"]]
    
    try:
        # Format history for the API
        messages = []
        for human, assistant in chat_history:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        
        # Add the current message
        messages.append({"role": "user", "content": user_message})
        
        # Create the client and make the API call
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p
        )
        
        # Extract the response text
        assistant_response = response.choices[0].message.content
        
        # Return updated chat history
        return chat_history + [[user_message, assistant_response]]
    
    except Exception as e:
        error_message = f"Error: {str(e)}"
        return chat_history + [[user_message, error_message]]

def clear_conversation():
    """Produce a brand-new, empty chat history for resetting the Chatbot view."""
    fresh_history = list()
    return fresh_history

# Chat models offered in the UI; the first entry is the default selection.
models = [
    "llama3-70b-8192",
    "llama3-8b-8192",
    "mistral-saba-24b",
    "gemma2-9b-it",
    "allam-2-7b"
]


def _report_api_status(api_key):
    """Return a display string for the "Test API Connection" button.

    ``test_api_connection`` returns an ``(ok, message)`` tuple. Wiring it
    directly to a single Textbox would display the raw tuple, so only the
    message is returned here. The cheap local format check runs first, so an
    obviously malformed key never triggers a network request.
    """
    is_valid, message = validate_api_key(api_key)
    if not is_valid:
        return message
    _, connection_message = test_api_connection(api_key)
    return connection_message


# Create the Gradio interface
with gr.Blocks(title="Groq AI Chat Interface") as app:
    gr.Markdown("# Groq AI Chat Interface")
    gr.Markdown("Enter your Groq API key to start chatting with AI models.")

    with gr.Row():
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="Enter your Groq API key (starts with gsk_)",
                type="password"
            )

        with gr.Column(scale=1):
            test_button = gr.Button("Test API Connection")
            api_status = gr.Textbox(label="API Status", interactive=False)

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=models,
                label="Select Model",
                value=models[0]
            )

    with gr.Row():
        with gr.Column():
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.7, step=0.01,
                    label="Temperature (higher = more creative, lower = more focused)"
                )
                max_tokens_slider = gr.Slider(
                    minimum=256, maximum=8192, value=4096, step=256,
                    label="Max Tokens (maximum length of response)"
                )
                top_p_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.95, step=0.01,
                    label="Top P (nucleus sampling probability threshold)"
                )

    chatbot = gr.Chatbot(label="Conversation", height=500)

    with gr.Row():
        message_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=3
        )

    with gr.Row():
        submit_button = gr.Button("Send", variant="primary")
        clear_button = gr.Button("Clear Conversation")

    # Order must match chat_with_groq's positional parameters.
    chat_inputs = [
        api_key_input,
        model_dropdown,
        message_input,
        temperature_slider,
        max_tokens_slider,
        top_p_slider,
        chatbot,
    ]

    # "Send" and pressing Enter in the message box behave identically: run the
    # chat turn, then clear the input box.
    for trigger in (submit_button.click, message_input.submit):
        trigger(
            fn=chat_with_groq,
            inputs=chat_inputs,
            outputs=chatbot
        ).then(
            fn=lambda: "",
            inputs=None,
            outputs=message_input
        )

    clear_button.click(
        fn=clear_conversation,
        inputs=None,
        outputs=chatbot
    )

    test_button.click(
        fn=_report_api_status,
        inputs=[api_key_input],
        outputs=[api_status]
    )

# Launch the app
if __name__ == "__main__":
    app.launch(share=False)