# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces
import time
import re
from transformers import AutoModel, AutoTokenizer
from threading import Lock
from queue import Queue, Empty

# --- Configuration ---
MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
print("Loading model and tokenizer...")
# Need configuration files for trust_remote_code
# Make sure config.json, configuration_dream.py, modeling_dream.py,
# generation_utils.py, generation_config.json are in the same directory
# or accessible in the Hugging Face cache.
model = AutoModel.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
print("Model and tokenizer loaded.")

# --- Constants ---
# Get IDs from tokenizer/config if possible, otherwise hardcode from provided files
MASK_TOKEN = tokenizer.mask_token # Should be "<|mask|>"
try:
    MASK_ID = tokenizer.mask_token_id # Should be 151666
    if MASK_ID is None: raise AttributeError # Handle case where it might not be set directly
except AttributeError:
    print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
    MASK_ID = 151666

try:
    EOS_ID = tokenizer.eos_token_id # Should be 151643
    PAD_ID = tokenizer.pad_token_id # Should be 151643
    if EOS_ID is None or PAD_ID is None: raise AttributeError
except AttributeError:
    print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
    EOS_ID = 151643
    PAD_ID = 151643

# Ensure MASK_TOKEN and MASK_ID are valid
if MASK_TOKEN is None or MASK_ID is None:
    raise ValueError("Mask token or ID is not defined correctly.")
if EOS_ID is None or PAD_ID is None:
     raise ValueError("EOS/PAD token or ID is not defined correctly.")

# --- Helper Functions ---

def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        try:
            pos_str, word = part.split(':', 1)
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                # Tokenize the word - handle potential multi-token words
                # Add space prefix for consistency, similar to how model might see words mid-sentence
                tokens = tokenizer.encode(" " + word, add_special_tokens=False)
                for i, token_id in enumerate(tokens):
                     constraints[pos + i] = token_id
        except ValueError:
            continue
        except Exception as e:
            print(f"Error parsing constraint part '{part}': {e}")
            continue

    return constraints
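
# Illustrative example (hypothetical output; actual token IDs depend on the
# Dream tokenizer's vocabulary):
#   parse_constraints("0:Once, 5:upon")
#     -> {0: <id of " Once">, 5: <id of " upon">}
# A word that tokenizes into several pieces fills consecutive positions
# starting at the given (0-indexed) response position.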

def format_chat_history(history):
    """
    Format chat history for the Dream model using its chat template logic.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted list of message dictionaries for the model
    """
    messages = []
    # No explicit system message is added here: the chat template defined in
    # tokenizer_config.json inserts a default system prompt automatically.
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None: # None marks a message still awaiting a reply
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages
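
# Illustrative example (hypothetical messages):
#   format_chat_history([["Hi", "Hello!"], ["Tell me a joke", None]])
#     -> [{"role": "user", "content": "Hi"},
#         {"role": "assistant", "content": "Hello!"},
#         {"role": "user", "content": "Tell me a joke"}]
# A None assistant entry (a message still awaiting a reply) is skipped rather
# than emitted as an empty assistant turn.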

# --- Core Generation Logic with Visualization ---

# Use a thread-safe queue to pass visualization states from the hook
vis_queue = Queue()
# Lock to prevent race conditions when accessing shared state like previous_x
state_lock = Lock()
# Store the previous state for comparison in the hook
previous_x_shared = None

@spaces.GPU
def generate_response_with_visualization(
    messages, # List of message dicts from format_chat_history
    max_new_tokens=64,
    steps=64, # Default steps based on README example
    constraints=None,
    temperature=0.6, # Default from demo_token_control
    top_p=0.95,      # Default from demos
    alg="entropy",   # Default from demos
    alg_temp=0.1,    # Default from demo_multiturn_chat
):
    """
    Generate text with Dream model and capture visualization states using a hook.

    Args:
        messages: List of message dictionaries with 'role' and 'content'.
        max_new_tokens: Max tokens to generate.
        steps: Diffusion steps.
        constraints: Dictionary mapping positions (relative to response start) to token IDs.
        temperature: Sampling temperature.
        top_p: Nucleus sampling p.
        alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
        alg_temp: Temperature for confidence-based algorithms.

    Returns:
        Tuple: (List of visualization states, final generated text string)
    """
    global previous_x_shared, vis_queue
    if constraints is None:
        constraints = {}

    visualization_states = []

    # Clear the queue for a new generation
    while not vis_queue.empty():
        try:
            vis_queue.get_nowait()
        except Empty:
            break

    # Prepare the prompt using chat template
    # The template automatically adds the generation prompt like "<|im_start|>assistant\n"
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
            return_dict=True
        )
        input_ids = inputs.input_ids.to(device=DEVICE)
        # Dream doesn't seem to explicitly use attention_mask in simple demos,
        # but it's good practice if padding were involved.
        # For now, assume no padding in this interactive demo.
        attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None

    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Provide a fallback or error state
        error_state = [("Error in chat formatting.", "red")]
        return [error_state], f"Error: Could not format chat history. {e}"

    prompt_length = input_ids.shape[1]
    total_length = prompt_length + max_new_tokens

    # --- Define the Hook Function ---
    def generation_tokens_hook_func(step, x, logits):
        global previous_x_shared, vis_queue
        with state_lock: # Ensure thread safety if needed, though hooks might run sequentially
            current_x = x.clone() # Shape: (batch_size, total_length)

            # --- Apply Constraints ---
            # Constraints are relative to the start of the *response*
            for rel_pos, token_id in constraints.items():
                abs_pos = prompt_length + rel_pos
                if 0 <= abs_pos < current_x.shape[1]:
                    # Ensure constraint application doesn't go out of bounds
                    # Apply constraint for the first batch element (batch size is 1 here)
                    current_x[0, abs_pos] = token_id

            # --- Create Visualization State ---
            current_vis_state = []
            x_response = current_x[0, prompt_length:] # Get the response part for batch 0
            prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None

            for i in range(max_new_tokens):
                current_token_id = x_response[i].item()
                token_str = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for vis

                # Clean up visual representation of special tokens
                if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
                     token_str = "[EOS/PAD]" # Make it visually distinct
                elif token_str == tokenizer.mask_token:
                     token_str = "[MASK]"
                elif token_str.strip() == "": # Handle empty strings from decoding potentially odd tokens
                    token_str = "[UNK/SPACE]"


                color = "#DDDDDD" # Default background

                if current_token_id == MASK_ID:
                    color = "#444444" # Dark gray for masks
                elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
                    # Token was a mask and is revealed in this step
                    color = "#66CC66" # Light green for newly revealed
                else:
                    # Token was already revealed in a previous step or is a constraint.
                    # Position i is relative to the response start, matching the keys
                    # produced by parse_constraints.
                    is_constraint = i in constraints and constraints[i] == current_token_id
                    if is_constraint:
                        color = "#FFD700" # Gold for constraints
                    else:
                        color = "#6699CC" # Light blue for previously revealed

                current_vis_state.append((token_str, color))

            # --- Update shared state and put vis state in queue ---
            previous_x_shared = current_x.clone() # Update for the *next* step's comparison
            vis_queue.put(current_vis_state)

            # The hook must return the potentially modified tensor `x`
            return current_x
    # --- End of Hook Function ---

    # Initialize previous_x_shared before generation starts
    # Create initial masked state for visualization
    initial_x = input_ids.clone()
    if initial_x.shape[1] < total_length:
        padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
        initial_x = torch.cat([initial_x, padding], dim=1)
    else:
        initial_x = initial_x[:, :total_length] # Truncate if prompt is too long

    # Apply initial constraints to the starting state
    for rel_pos, token_id in constraints.items():
        abs_pos = prompt_length + rel_pos
        if 0 <= abs_pos < initial_x.shape[1]:
             initial_x[0, abs_pos] = token_id

    with state_lock:
        previous_x_shared = initial_x.clone()

    # Add the initial all-masked state (or with constraints) to the visualization queue
    initial_vis_state = []
    initial_x_response = initial_x[0, prompt_length:]
    for i in range(max_new_tokens):
        token_id = initial_x_response[i].item()
        if token_id == MASK_ID:
            initial_vis_state.append((MASK_TOKEN, "#444444"))
        else:
            # Must be a pre-applied constraint
            token_str = tokenizer.decode([token_id], skip_special_tokens=False)
            if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
                token_str = "[EOS/PAD]"
            elif token_str.strip() == "":
                token_str = "[UNK/SPACE]"
            initial_vis_state.append((token_str, "#FFD700")) # Gold for constraints
    vis_queue.put(initial_vis_state)


    # --- Run Generation ---
    try:
        # output_history=False because the hook handles state capture
        # return_dict_in_generate=True to get the GenerationOutput object
        output = model.diffusion_generate(
            initial_x, # Start with the potentially constraint-applied tensor
            attention_mask=None, # Assuming no padding needed for interactive use
            max_new_tokens=max_new_tokens, # This might not be strictly needed if total_length is fixed
            output_history=False,
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg=alg,
            alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence algs
            generation_tokens_hook_func=generation_tokens_hook_func
        )

        final_sequence = output.sequences[0] # Batch size 1

        # Decode the final response text, cleaning up special tokens
        response_tokens = final_sequence[prompt_length:]
        # Filter out EOS/PAD tokens for the final text display
        response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
        final_text = tokenizer.decode(response_tokens_filtered,
                                      skip_special_tokens=True,
                                      clean_up_tokenization_spaces=True) # Standard cleanup

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Provide error state
        error_state = [("Generation Error.", "red")]
        visualization_states.append(error_state)
        final_text = f"Error: Generation failed. {e}"
        # Add any states captured before the error
        while not vis_queue.empty():
            try:
                visualization_states.append(vis_queue.get_nowait())
            except Empty:
                break
        return visualization_states, final_text

    # Retrieve all visualization states captured by the hook
    while not vis_queue.empty():
        try:
            visualization_states.append(vis_queue.get_nowait())
        except Empty:
            break

    # If somehow no states were captured, add the initial one
    if not visualization_states:
        visualization_states.append(initial_vis_state)


    return visualization_states, final_text.strip()
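
# Minimal usage sketch outside the UI (illustrative; assumes the model and
# tokenizer above loaded successfully and a GPU is available):
#   states, text = generate_response_with_visualization(
#       [{"role": "user", "content": "Write a haiku about rain."}],
#       max_new_tokens=64,
#       steps=64,
#       constraints=parse_constraints("0:Rain"),
#   )
# `states` holds one (token, color) snapshot per hook invocation, typically one
# per diffusion step plus the initial fully-masked state; `text` is the final
# decoded response.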


# --- Gradio UI ---

css = '''
.category-legend{display:none}
button{height: 60px}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
        gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                            scale=9
                        )
                        send_btn = gr.Button("Send", scale=1)

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
                    placeholder="0:Once, 5:upon, 10:a",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Diffusion Process Visualization",
                    combine_adjacent=False,
                    show_legend=True, # Keep legend hidden via CSS if desired
                )
                # Legend (colors defined in generate_response_with_visualization)
                gr.Markdown(
                    "<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
                    " <span style='background-color:#66CC66;'>Newly Revealed</span>"
                    " <span style='background-color:#6699CC;'>Previously Revealed</span>"
                    " <span style='background-color:#FFD700;'>Constraint</span>"
                     " <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            max_new_tokens_slider = gr.Slider(
                minimum=16, maximum=512, value=128, step=16, # Increased default/max
                label="Max New Tokens (Generation Length)"
            )
            steps_slider = gr.Slider(
                minimum=8, maximum=512, value=128, step=8,   # Increased default/max
                label="Diffusion Steps"
            )
            temp_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.6, step=0.05, # Finer steps for temp
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                label="Top-P (Nucleus Sampling)"
            )
            alg_radio = gr.Radio(
                # Choices from README
                choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
                value='entropy',
                label="Remasking Algorithm"
            )
            alg_temp_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                label="Algorithm Temperature (for confidence-based algs)"
            )
            vis_delay_slider = gr.Slider(
                minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Faster default delay
                label="Visualization Delay (seconds)"
            )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # HELPER FUNCTIONS (UI Logic)
        def add_message_to_history(history, message, response):
            """Add a message pair to the history state"""
            new_history = history + [[message, response]]
            return new_history

        def user_message_submitted(message, history):
            """ Handle user sending a message: update history, clear input """
            if not message or message.strip() == "":
                return history, history, "", [] # No change if empty

            # Add user message, response is initially None
            new_history = add_message_to_history(history, message, None)

            # Prepare display version (immediately shows user message)
            display_history = new_history

            # Clear input box
            message_out = ""

            # Clear visualization
            vis_out = []

            return new_history, display_history, message_out, vis_out

        def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
            """ Generator function to stream bot response and visualization """
            if not history or history[-1][1] is not None: # Ensure there's a user msg waiting for response
                print("Warning: Bot response triggered without pending user message.")
                yield history, [], "Error: No user message to respond to." # Send error state back?
                return

            # Get the full conversation history formatted for the model
            last_user_message = history[-1][0]
            messages_for_model = format_chat_history(history[:-1]) # History *before* the last user msg
            messages_for_model.append({"role": "user", "content": last_user_message})

            # Parse constraints
            try:
                parsed_constraints = parse_constraints(constraints_str)
            except Exception as e:
                 print(f"Error parsing constraints: {e}")
                 yield history, [("Constraint Error", "red")], f"Error: Failed to parse constraints: {e}"
                 return

            # Generate response and visualization states
            try:
                 vis_states, final_response_text = generate_response_with_visualization(
                    messages_for_model,
                    max_new_tokens=max_tokens,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temp,
                    top_p=top_p,
                    alg=alg,
                    alg_temp=alg_temp
                 )
            except Exception as e:
                print(f"Error in generate_response_with_visualization: {e}")
                import traceback
                traceback.print_exc()
                yield history, [("Generation Error", "red")], f"Error: Generation failed: {e}"
                return

            # Update the history state with the final response *once*
            history[-1][1] = final_response_text # Update the None placeholder

            # Yield initial state immediately
            if vis_states:
                yield history, vis_states[0]
            else:
                 yield history, [] # Should not happen if generation worked

            # Stream intermediate visualization states
            for state in vis_states[1:]:
                time.sleep(delay)
                yield history, state

            # Final yield ensures the chatbot UI has the complete history
            # The last state in vis_states should already be yielded by the loop
            # yield history, vis_states[-1] if vis_states else []


        def clear_conversation():
            """Clear the conversation history and visualization"""
            return [], [], "", [] # history, chatbot_ui, user_input, output_vis

        # EVENT HANDLERS

        # User presses Enter or Send button
        submit_args = {
             "fn": user_message_submitted,
             "inputs": [user_input, chat_history],
             "outputs": [chat_history, chatbot_ui, user_input, output_vis]
        }

        # After user message is submitted, trigger bot response generation
        generate_args = {
            "fn": bot_response_generator,
            "inputs": [
                chat_history, constraints_input, max_new_tokens_slider, steps_slider,
                temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
            ],
            "outputs": [chatbot_ui, output_vis] # Update chatbot history and visualization
        }
        # Chain the events: record the user message first, then stream the bot response
        user_input.submit(**submit_args).then(**generate_args)
        send_btn.click(**submit_args).then(**generate_args)


        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )

    return demo

# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # queue() allows streaming and handling multiple users
    # share=True creates a public link (use with caution)
    demo.queue().launch(share=True, debug=True)