File size: 6,198 Bytes
0fd775b
 
 
 
 
0dd8117
 
eda2dbf
0fd775b
c90c7ed
42f3ad6
0fd775b
 
 
 
 
eda2dbf
 
 
c90c7ed
eda2dbf
 
fd137e6
 
eda2dbf
fd137e6
 
eda2dbf
fd137e6
 
 
 
 
 
 
 
eda2dbf
 
0dd8117
eda2dbf
 
 
 
 
 
 
 
 
 
fd137e6
 
0fd775b
fd137e6
eda2dbf
7697929
 
 
eda2dbf
 
a573f7f
eda2dbf
0dd8117
 
 
 
eda2dbf
 
 
0dd8117
 
 
 
eda2dbf
 
93d55ac
eda2dbf
93d55ac
eda2dbf
93d55ac
 
eda2dbf
9a56c61
eda2dbf
9a56c61
 
eda2dbf
 
9a56c61
eda2dbf
9a56c61
 
 
 
0dd8117
eda2dbf
 
 
 
 
 
0dd8117
9a56c61
 
 
 
 
 
 
 
 
 
0dd8117
eda2dbf
 
 
 
 
 
 
 
 
 
 
 
 
bbcf582
 
 
 
 
 
eda2dbf
bbcf582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eda2dbf
 
 
 
 
 
 
 
 
 
 
d0568c3
 
dba44ef
3d91d29
d0568c3
 
 
 
760dc49
d0568c3
 
 
 
eda2dbf
 
 
 
5283206
 
 
 
 
 
 
 
d0568c3
5283206
 
 
 
 
 
 
 
 
 
 
eda2dbf
 
 
 
 
0dd8117
 
eda2dbf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from dotenv import load_dotenv

import os
from pathlib import Path

import gradio as gr
from huggingface_hub import InferenceClient
from openai import OpenAI

from prompt_template import PromptTemplate, PromptLoader
from assistant import AIAssistant

# Load .env file so os.getenv can see locally-configured secrets.
load_dotenv()

# API key for the inference endpoint (used in create_assistant below).
# Will be None if neither the environment nor .env defines API_KEY.
API_KEY = os.getenv('API_KEY')

# Load prompts from YAML.
# NOTE(review): treated below as a mapping of string key -> PromptTemplate
# (see the lookups by PROMPT_STRATEGIES values and "system_context") — the
# exact return type of PromptLoader.load_prompts is defined elsewhere; confirm.
prompts = PromptLoader.load_prompts("prompts.yaml")

# Available models and their configurations.
# Maps the display name shown in the UI dropdown to the provider-side model
# identifier passed to the API (resolved in create_assistant).
MODELS = {
    "Llama 3.3 70B Instruct": {
        "name": "meta/llama-3.3-70b-instruct",
    },
    "Llama 3.1 405B Instruct": {
        "name": "meta/llama-3.1-405b-instruct",
    },
    "Llama 3.2 3B Instruct": {
        "name": "meta/llama-3.2-3b-instruct",
    },
    "Falcon 3 7B Instruct": {
        "name": "tiiuae/falcon3-7b-instruct",
    },
    "Granite 3.0 8B Instruct": {
        "name": "ibm/granite-3.0-8b-instruct",
    }
}

# Available prompt strategies.
# Maps the UI dropdown label to the corresponding template key in prompts
# (loaded from prompts.yaml above).
PROMPT_STRATEGIES = {
    "Default": "system_context",
    "Chain of Thought": "cot_prompt",
    "Knowledge-based": "knowledge_prompt",
    "Few-shot Learning": "few_shot_prompt",
    "Meta-prompting": "meta_prompt"
}

def create_assistant(model_name):
    """Build an AIAssistant backed by the NVIDIA NIM OpenAI-compatible API.

    Args:
        model_name: Display name as shown in the UI — a key of MODELS
            (e.g. "Llama 3.3 70B Instruct").

    Returns:
        AIAssistant configured with the provider's model identifier.

    Raises:
        KeyError: if ``model_name`` is not a key of MODELS.
    """
    client = OpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key=API_KEY,
    )

    # Resolve the UI display name to the provider-side model identifier
    # (kept separate from the parameter instead of shadowing it).
    resolved_model = MODELS[model_name]["name"]

    return AIAssistant(
        client=client,
        model=resolved_model,
    )

def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    prompt_strategy,
    override_params: bool,
    max_tokens,
    temperature,
    top_p,
):
    """Stream an assistant reply for the Gradio ChatInterface.

    Builds the OpenAI-style message list (system context + prior turns +
    the new user message), picks generation parameters either from the
    selected prompt template or from the manual slider overrides, and
    yields streamed response chunks. On failure, yields a single error
    string instead of raising.
    """
    assistant = create_assistant(model_name)

    # Resolve the chosen strategy template and wrap it in the base
    # system context.
    strategy: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]
    base_context: PromptTemplate = prompts["system_context"]
    system_message = base_context.format(prompt_strategy=strategy.template)

    # Assemble the conversation: system prompt, then each historical
    # turn (skipping empty halves), then the incoming message.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        for role, content in (("user", user_turn), ("assistant", assistant_turn)):
            if content:
                messages.append({"role": role, "content": str(content)})
    messages.append({"role": "user", "content": str(message)})

    # Manual slider values win only when the override checkbox is on;
    # otherwise the template's own parameters apply.
    if override_params:
        generation_params = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p
        }
    else:
        generation_params = strategy.parameters

    try:
        # Delegate streaming straight through to the assistant.
        yield from assistant.generate_response(
            prompt_template=strategy,
            generation_params=generation_params,
            stream=True,
            messages=messages
        )
    except Exception as e:
        yield f"Error: {str(e)}"

# UI definition. Component creation order inside the context managers
# determines the rendered layout, so the structure below is
# order-sensitive.
with gr.Blocks() as demo:
    # Top row: model and prompt-strategy selectors, defaulting to the
    # first entry of each mapping.
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model"
            )
            prompt_strategy_dropdown = gr.Dropdown(
                choices=list(PROMPT_STRATEGIES.keys()),
                value=list(PROMPT_STRATEGIES.keys())[0],
                label="Select Prompt Strategy"
            )

    # Checkbox that gates whether the sliders below override the
    # template's own generation parameters (see respond()).
    with gr.Row():
        override_params = gr.Checkbox(
            label="Override Template Parameters",
            value=False
        )

    # Manual generation-parameter controls; hidden until the override
    # checkbox is ticked (toggled via override_params.change below).
    with gr.Row():
        with gr.Column(visible=False) as param_controls:
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)"
            )
    # Chat surface; the additional inputs are forwarded to respond()
    # after (message, history), in this exact order.
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_dropdown,
            prompt_strategy_dropdown,
            override_params,
            max_tokens,
            temperature,
            top_p,
        ]
    )

    # Parameters and Prompt Details section below the chat
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            with gr.Accordion("Current Prompt Details", open=False):
                # Read-only view of the fully formatted system prompt.
                system_prompt_display = gr.TextArea(
                    label="System Prompt",
                    interactive=False,
                    lines=20
                )
                # JSON view of the selected template's parameters and
                # raw strategy text.
                current_messages_display = gr.JSON(
                    label="Full Conversation Context",
                )

    def toggle_param_controls(override):
        # Show/hide the slider column to mirror the checkbox state.
        return gr.Column(visible=override)

    def update_prompt_display(prompt_strategy):
        # Re-render the system prompt preview for the chosen strategy,
        # mirroring the formatting logic used in respond().
        prompt_template = prompts[PROMPT_STRATEGIES[prompt_strategy]]
        system_context = prompts["system_context"]
        formatted_system_message = system_context.format(prompt_strategy=prompt_template.template)

        return (
            formatted_system_message,
            {
                "Template Parameters": prompt_template.parameters,
                "Prompt Strategy": prompt_template.template
            }
        )

    # Update prompt display when strategy changes
    prompt_strategy_dropdown.change(
        update_prompt_display,
        inputs=[prompt_strategy_dropdown],
        outputs=[system_prompt_display, current_messages_display]
    )

    # Toggle slider visibility when the override checkbox changes.
    override_params.change(
        toggle_param_controls,
        inputs=[override_params],
        outputs=[param_controls]
    )

if __name__ == "__main__":
    # Launch the Gradio server (blocking) only when run as a script.
    demo.launch()