# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces
import time
import re
from transformers import AutoModel, AutoTokenizer
from threading import Lock
from queue import Queue, Empty
# --- Configuration ---
MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
# --- Load Model and Tokenizer ---
print("Loading model and tokenizer...")
# Need configuration files for trust_remote_code
# Make sure config.json, configuration_dream.py, modeling_dream.py,
# generation_utils.py, generation_config.json are in the same directory
# or accessible in the Hugging Face cache.
model = AutoModel.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
trust_remote_code=True
).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True
)
print("Model and tokenizer loaded.")
# --- Constants ---
# Get IDs from tokenizer/config if possible, otherwise hardcode from provided files
MASK_TOKEN = tokenizer.mask_token # Should be "<|mask|>"
try:
MASK_ID = tokenizer.mask_token_id # Should be 151666
if MASK_ID is None: raise AttributeError # Handle case where it might not be set directly
except AttributeError:
print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
MASK_ID = 151666
try:
EOS_ID = tokenizer.eos_token_id # Should be 151643
PAD_ID = tokenizer.pad_token_id # Should be 151643
if EOS_ID is None or PAD_ID is None: raise AttributeError
except AttributeError:
print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
EOS_ID = 151643
PAD_ID = 151643
# Ensure MASK_TOKEN and MASK_ID are valid
if MASK_TOKEN is None or MASK_ID is None:
raise ValueError("Mask token or ID is not defined correctly.")
if EOS_ID is None or PAD_ID is None:
raise ValueError("EOS/PAD token or ID is not defined correctly.")
# --- Helper Functions ---
def parse_constraints(constraints_text):
"""Parse constraints in format: 'position:word, position:word, ...'"""
constraints = {}
if not constraints_text:
return constraints
parts = constraints_text.split(',')
for part in parts:
if ':' not in part:
continue
try:
pos_str, word = part.split(':', 1)
pos = int(pos_str.strip())
word = word.strip()
if word and pos >= 0:
# Tokenize the word - handle potential multi-token words
# Add space prefix for consistency, similar to how model might see words mid-sentence
tokens = tokenizer.encode(" " + word, add_special_tokens=False)
for i, token_id in enumerate(tokens):
constraints[pos + i] = token_id
except ValueError:
continue
except Exception as e:
print(f"Error parsing constraint part '{part}': {e}")
continue
return constraints
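# Illustrative example (the exact token IDs depend on the tokenizer's vocabulary):
#   parse_constraints("0:Once, 5:upon")
#   -> {0: <id of " Once">, 5: <id of " upon">}
# A word that tokenizes into several pieces occupies consecutive positions
# starting at the given index.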
def format_chat_history(history):
"""
Format chat history for the Dream model using its chat template logic.
Args:
history: List of [user_message, assistant_message] pairs
Returns:
Formatted list of message dictionaries for the model
"""
messages = []
    # No explicit system message is added here: the chat template in
    # tokenizer_config.json inserts a default one when apply_chat_template is called.
for user_msg, assistant_msg in history:
if user_msg: # Handle potential initial system message possibility if needed
messages.append({"role": "user", "content": user_msg})
if assistant_msg is not None: # Skip if None (for the latest user message)
messages.append({"role": "assistant", "content": assistant_msg})
return messages
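# Illustrative example of the expected history structure:
#   format_chat_history([["Hi", "Hello!"], ["Tell me a story", None]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Tell me a story"}]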
# --- Core Generation Logic with Visualization ---
# Use a thread-safe queue to pass visualization states from the hook
vis_queue = Queue()
# Lock to prevent race conditions when accessing shared state like previous_x
state_lock = Lock()
# Store the previous state for comparison in the hook
previous_x_shared = None
@spaces.GPU
def generate_response_with_visualization(
messages, # List of message dicts from format_chat_history
max_new_tokens=64,
steps=64, # Default steps based on README example
constraints=None,
temperature=0.6, # Default from demo_token_control
top_p=0.95, # Default from demos
alg="entropy", # Default from demos
alg_temp=0.1, # Default from demo_multiturn_chat
):
"""
Generate text with Dream model and capture visualization states using a hook.
Args:
messages: List of message dictionaries with 'role' and 'content'.
max_new_tokens: Max tokens to generate.
steps: Diffusion steps.
constraints: Dictionary mapping positions (relative to response start) to token IDs.
temperature: Sampling temperature.
top_p: Nucleus sampling p.
alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
alg_temp: Temperature for confidence-based algorithms.
Returns:
Tuple: (List of visualization states, final generated text string)
"""
global previous_x_shared, vis_queue
if constraints is None:
constraints = {}
visualization_states = []
# Clear the queue for a new generation
while not vis_queue.empty():
try:
vis_queue.get_nowait()
        except Empty:
break
# Prepare the prompt using chat template
# The template automatically adds the generation prompt like "<|im_start|>assistant\n"
try:
inputs = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True,
return_dict=True
)
input_ids = inputs.input_ids.to(device=DEVICE)
# Dream doesn't seem to explicitly use attention_mask in simple demos,
# but it's good practice if padding were involved.
# For now, assume no padding in this interactive demo.
attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None
except Exception as e:
print(f"Error applying chat template: {e}")
# Provide a fallback or error state
error_state = [("Error in chat formatting.", "red")]
return [error_state], f"Error: Could not format chat history. {e}"
prompt_length = input_ids.shape[1]
total_length = prompt_length + max_new_tokens
# --- Define the Hook Function ---
def generation_tokens_hook_func(step, x, logits):
global previous_x_shared, vis_queue
with state_lock: # Ensure thread safety if needed, though hooks might run sequentially
current_x = x.clone() # Shape: (batch_size, total_length)
# --- Apply Constraints ---
# Constraints are relative to the start of the *response*
for rel_pos, token_id in constraints.items():
abs_pos = prompt_length + rel_pos
if 0 <= abs_pos < current_x.shape[1]:
# Ensure constraint application doesn't go out of bounds
# Apply constraint for the first batch element (batch size is 1 here)
current_x[0, abs_pos] = token_id
# --- Create Visualization State ---
current_vis_state = []
x_response = current_x[0, prompt_length:] # Get the response part for batch 0
prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None
for i in range(max_new_tokens):
current_token_id = x_response[i].item()
token_str = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for vis
# Clean up visual representation of special tokens
if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
token_str = "[EOS/PAD]" # Make it visually distinct
elif token_str == tokenizer.mask_token:
token_str = "[MASK]"
elif token_str.strip() == "": # Handle empty strings from decoding potentially odd tokens
token_str = "[UNK/SPACE]"
color = "#DDDDDD" # Default background
if current_token_id == MASK_ID:
color = "#444444" # Dark gray for masks
elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
# Token was mask, now it's revealed in this step
# Use green for newly revealed
color = "#66CC66" # Light green
else:
# Token was already revealed in a previous step or is a constraint
# Check if it's a constraint applied *now*
                    is_constraint = i in constraints and constraints[i] == current_token_id
if is_constraint:
color = "#FFD700" # Gold for constraints
else:
color = "#6699CC" # Light blue for previously revealed
current_vis_state.append((token_str, color))
# --- Update shared state and put vis state in queue ---
previous_x_shared = current_x.clone() # Update for the *next* step's comparison
vis_queue.put(current_vis_state)
# The hook must return the potentially modified tensor `x`
return current_x
# --- End of Hook Function ---
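    # Note: the hook contract assumed here (inferred from Dream's generation_utils.py
    # and the official demos) is that diffusion_generate invokes it once per diffusion
    # step with (step, x, logits) and uses the returned tensor as the sequence state
    # for the next step.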
# Initialize previous_x_shared before generation starts
# Create initial masked state for visualization
initial_x = input_ids.clone()
if initial_x.shape[1] < total_length:
padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
initial_x = torch.cat([initial_x, padding], dim=1)
else:
initial_x = initial_x[:, :total_length] # Truncate if prompt is too long
# Apply initial constraints to the starting state
for rel_pos, token_id in constraints.items():
abs_pos = prompt_length + rel_pos
if 0 <= abs_pos < initial_x.shape[1]:
initial_x[0, abs_pos] = token_id
with state_lock:
previous_x_shared = initial_x.clone()
# Add the initial all-masked state (or with constraints) to the visualization queue
initial_vis_state = []
initial_x_response = initial_x[0, prompt_length:]
for i in range(max_new_tokens):
token_id = initial_x_response[i].item()
if token_id == MASK_ID:
            initial_vis_state.append(("[MASK]", "#444444"))  # Same label as used in the hook
else:
# Must be a pre-applied constraint
token_str = tokenizer.decode([token_id], skip_special_tokens=False)
if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
token_str = "[EOS/PAD]"
elif token_str.strip() == "":
token_str = "[UNK/SPACE]"
initial_vis_state.append((token_str, "#FFD700")) # Gold for constraints
vis_queue.put(initial_vis_state)
# --- Run Generation ---
try:
# output_history=False because the hook handles state capture
# return_dict_in_generate=True to get the GenerationOutput object
output = model.diffusion_generate(
initial_x, # Start with the potentially constraint-applied tensor
attention_mask=None, # Assuming no padding needed for interactive use
max_new_tokens=max_new_tokens, # This might not be strictly needed if total_length is fixed
output_history=False,
return_dict_in_generate=True,
steps=steps,
temperature=temperature,
top_p=top_p,
alg=alg,
alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence algs
generation_tokens_hook_func=generation_tokens_hook_func
)
final_sequence = output.sequences[0] # Batch size 1
# Decode the final response text, cleaning up special tokens
response_tokens = final_sequence[prompt_length:]
# Filter out EOS/PAD tokens for the final text display
response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
final_text = tokenizer.decode(response_tokens_filtered,
skip_special_tokens=True,
clean_up_tokenization_spaces=True) # Standard cleanup
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
# Provide error state
error_state = [("Generation Error.", "red")]
visualization_states.append(error_state)
final_text = f"Error: Generation failed. {e}"
# Add any states captured before the error
while not vis_queue.empty():
try:
visualization_states.append(vis_queue.get_nowait())
            except Empty:
break
return visualization_states, final_text
# Retrieve all visualization states captured by the hook
while not vis_queue.empty():
try:
visualization_states.append(vis_queue.get_nowait())
        except Empty:
break
# If somehow no states were captured, add the initial one
if not visualization_states:
visualization_states.append(initial_vis_state)
return visualization_states, final_text.strip()
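# Standalone usage sketch (illustrative; assumes a CUDA device is available, and
# notes that outside a ZeroGPU Space the @spaces.GPU decorator is effectively a no-op):
#   states, text = generate_response_with_visualization(
#       [{"role": "user", "content": "Write a haiku about rain."}],
#       max_new_tokens=64, steps=64)
#   print(text)          # final decoded response
#   print(len(states))   # roughly one visualization state per step, plus the initial one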
# --- Gradio UI ---
css = '''
.category-legend{display:none}
button{height: 60px}
'''
def create_chatbot_demo():
with gr.Blocks(css=css) as demo:
gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")
# STATE MANAGEMENT
chat_history = gr.State([])
# UI COMPONENTS
with gr.Row():
with gr.Column(scale=3):
chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])
# Message input
with gr.Group():
with gr.Row():
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
show_label=False,
scale=9
)
send_btn = gr.Button("Send", scale=1)
constraints_input = gr.Textbox(
label="Word Constraints (Optional)",
info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
placeholder="0:Once, 5:upon, 10:a",
value=""
)
with gr.Column(scale=2):
output_vis = gr.HighlightedText(
label="Diffusion Process Visualization",
combine_adjacent=False,
show_legend=True, # Keep legend hidden via CSS if desired
)
# Legend (colors defined in generate_response_with_visualization)
gr.Markdown(
"<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
" <span style='background-color:#66CC66;'>Newly Revealed</span>"
" <span style='background-color:#6699CC;'>Previously Revealed</span>"
" <span style='background-color:#FFD700;'>Constraint</span>"
" <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
)
# Advanced generation settings
with gr.Accordion("Generation Settings", open=False):
max_new_tokens_slider = gr.Slider(
minimum=16, maximum=512, value=128, step=16, # Increased default/max
label="Max New Tokens (Generation Length)"
)
steps_slider = gr.Slider(
minimum=8, maximum=512, value=128, step=8, # Increased default/max
label="Diffusion Steps"
)
temp_slider = gr.Slider(
minimum=0.0, maximum=1.0, value=0.6, step=0.05, # Finer steps for temp
label="Temperature"
)
top_p_slider = gr.Slider(
minimum=0.0, maximum=1.0, value=0.95, step=0.05,
label="Top-P (Nucleus Sampling)"
)
alg_radio = gr.Radio(
# Choices from README
choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
value='entropy',
label="Remasking Algorithm"
)
alg_temp_slider = gr.Slider(
minimum=0.0, maximum=1.0, value=0.1, step=0.05,
label="Algorithm Temperature (for confidence-based algs)"
)
vis_delay_slider = gr.Slider(
minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Faster default delay
label="Visualization Delay (seconds)"
)
# Clear button
clear_btn = gr.Button("Clear Conversation")
# HELPER FUNCTIONS (UI Logic)
def add_message_to_history(history, message, response):
"""Add a message pair to the history state"""
new_history = history + [[message, response]]
return new_history
def user_message_submitted(message, history):
""" Handle user sending a message: update history, clear input """
if not message or message.strip() == "":
return history, history, "", [] # No change if empty
# Add user message, response is initially None
new_history = add_message_to_history(history, message, None)
# Prepare display version (immediately shows user message)
display_history = new_history
# Clear input box
message_out = ""
# Clear visualization
vis_out = []
return new_history, display_history, message_out, vis_out
def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
""" Generator function to stream bot response and visualization """
if not history or history[-1][1] is not None: # Ensure there's a user msg waiting for response
print("Warning: Bot response triggered without pending user message.")
yield history, [], "Error: No user message to respond to." # Send error state back?
return
# Get the full conversation history formatted for the model
last_user_message = history[-1][0]
messages_for_model = format_chat_history(history[:-1]) # History *before* the last user msg
messages_for_model.append({"role": "user", "content": last_user_message})
# Parse constraints
try:
parsed_constraints = parse_constraints(constraints_str)
except Exception as e:
print(f"Error parsing constraints: {e}")
yield history, [("Constraint Error", "red")], f"Error: Failed to parse constraints: {e}"
return
# Generate response and visualization states
try:
vis_states, final_response_text = generate_response_with_visualization(
messages_for_model,
max_new_tokens=max_tokens,
steps=steps,
constraints=parsed_constraints,
temperature=temp,
top_p=top_p,
alg=alg,
alg_temp=alg_temp
)
except Exception as e:
print(f"Error in generate_response_with_visualization: {e}")
import traceback
traceback.print_exc()
yield history, [("Generation Error", "red")], f"Error: Generation failed: {e}"
return
# Update the history state with the final response *once*
history[-1][1] = final_response_text # Update the None placeholder
# Yield initial state immediately
if vis_states:
yield history, vis_states[0]
else:
yield history, [] # Should not happen if generation worked
# Stream intermediate visualization states
for state in vis_states[1:]:
time.sleep(delay)
yield history, state
            # The loop above already yields the final visualization state, and `history`
            # was updated in place, so no additional yield is needed here.
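            # Gradio streams generator output when queuing is enabled: every yielded
            # (history, visualization) pair is pushed to the UI as it is produced,
            # which is what animates the diffusion process token by token.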
def clear_conversation():
"""Clear the conversation history and visualization"""
return [], [], "", [] # history, chatbot_ui, user_input, output_vis
        # EVENT HANDLERS
        # User presses Enter or clicks Send: first record the message, then stream the bot response.
        submit_args = {
            "fn": user_message_submitted,
            "inputs": [user_input, chat_history],
            "outputs": [chat_history, chatbot_ui, user_input, output_vis]
        }
        generate_args = {
            "fn": bot_response_generator,
            "inputs": [
                chat_history, constraints_input, max_new_tokens_slider, steps_slider,
                temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
            ],
            "outputs": [chatbot_ui, output_vis]  # Update chatbot history and visualization
        }
        # Chain generation onto the submit/click handlers so the user message is committed
        # to history before generation starts (avoids the race that occurs when two
        # independent handlers are registered on the same event).
        user_input.submit(**submit_args).then(**generate_args)
        send_btn.click(**submit_args).then(**generate_args)
# Clear button handler
clear_btn.click(
fn=clear_conversation,
inputs=[],
outputs=[chat_history, chatbot_ui, user_input, output_vis]
)
return demo
# Launch the demo
if __name__ == "__main__":
demo = create_chatbot_demo()
# queue() allows streaming and handling multiple users
# share=True creates a public link (use with caution)
    demo.queue().launch(share=True, debug=True)