# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces
import time
import re
from transformers import AutoModel, AutoTokenizer
from threading import Lock
from queue import Queue, Empty  # Empty is needed for the non-blocking queue reads below
# --- Configuration ---
MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
print("Loading model and tokenizer...")
# trust_remote_code needs the model's custom files:
# config.json, configuration_dream.py, modeling_dream.py,
# generation_utils.py, and generation_config.json must be in the same directory
# or accessible in the Hugging Face cache.
model = AutoModel.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
print("Model and tokenizer loaded.")
# --- Constants ---
# Get IDs from the tokenizer/config if possible, otherwise fall back to the
# hardcoded values from the provided model files.
MASK_TOKEN = tokenizer.mask_token  # Should be "<|mask|>"
try:
    MASK_ID = tokenizer.mask_token_id  # Should be 151666
    if MASK_ID is None:
        raise AttributeError  # Handle the case where it is not set directly
except AttributeError:
    print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
    MASK_ID = 151666
try:
    EOS_ID = tokenizer.eos_token_id  # Should be 151643
    PAD_ID = tokenizer.pad_token_id  # Should be 151643
    if EOS_ID is None or PAD_ID is None:
        raise AttributeError
except AttributeError:
    print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
    EOS_ID = 151643
    PAD_ID = 151643

# Ensure MASK, EOS, and PAD tokens/IDs are valid
if MASK_TOKEN is None or MASK_ID is None:
    raise ValueError("Mask token or ID is not defined correctly.")
if EOS_ID is None or PAD_ID is None:
    raise ValueError("EOS/PAD token or ID is not defined correctly.")
# --- Helper Functions ---
def parse_constraints(constraints_text):
    """Parse constraints in the format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        try:
            pos_str, word = part.split(':', 1)
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                # Tokenize the word, handling potential multi-token words.
                # Add a space prefix for consistency, similar to how the model
                # would see the word mid-sentence.
                tokens = tokenizer.encode(" " + word, add_special_tokens=False)
                for i, token_id in enumerate(tokens):
                    constraints[pos + i] = token_id
        except ValueError:
            continue
        except Exception as e:
            print(f"Error parsing constraint part '{part}': {e}")
            continue
    return constraints
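
# Illustrative example (token IDs shown symbolically; actual values depend on the tokenizer):
#   parse_constraints("0:Once, 5:upon") -> {0: id(" Once"), 5: id(" upon")}
# A word that tokenizes into several pieces fills consecutive positions starting at its index.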
def format_chat_history(history):
    """
    Format chat history for the Dream model using its chat template logic.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted list of message dictionaries for the model
    """
    messages = []
    # No explicit system message is added here: the chat template provided in
    # tokenizer_config.json inserts a default system prompt on its own.
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:  # Skip None (the latest, not-yet-answered user message)
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
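
# Illustrative example:
#   format_chat_history([["Hi", "Hello!"], ["How are you?", None]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "How are you?"}]
# The trailing None assistant slot is skipped so apply_chat_template can append the generation prompt.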
# --- Core Generation Logic with Visualization ---

# Thread-safe queue used to pass visualization states out of the generation hook
vis_queue = Queue()
# Lock to prevent race conditions when accessing shared state like previous_x_shared
state_lock = Lock()
# Stores the previous step's tensor for comparison inside the hook
previous_x_shared = None
@spaces.GPU  # Request a GPU for this call when running on Hugging Face ZeroGPU (why `spaces` is imported)
def generate_response_with_visualization(
    messages,                # List of message dicts from format_chat_history
    max_new_tokens=64,
    steps=64,                # Default steps based on the README example
    constraints=None,
    temperature=0.6,         # Default from demo_token_control
    top_p=0.95,              # Default from the demos
    alg="entropy",           # Default from the demos
    alg_temp=0.1,            # Default from demo_multiturn_chat
):
    """
    Generate text with the Dream model and capture visualization states using a hook.

    Args:
        messages: List of message dictionaries with 'role' and 'content'.
        max_new_tokens: Max tokens to generate.
        steps: Diffusion steps.
        constraints: Dictionary mapping positions (relative to response start) to token IDs.
        temperature: Sampling temperature.
        top_p: Nucleus sampling p.
        alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
        alg_temp: Temperature for confidence-based algorithms.

    Returns:
        Tuple: (list of visualization states, final generated text string)
    """
    global previous_x_shared, vis_queue

    if constraints is None:
        constraints = {}

    visualization_states = []

    # Clear the queue for a new generation
    while not vis_queue.empty():
        try:
            vis_queue.get_nowait()
        except Empty:
            break
    # Prepare the prompt using the chat template.
    # The template automatically adds the generation prompt, e.g. "<|im_start|>assistant\n".
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
            return_dict=True
        )
        input_ids = inputs.input_ids.to(device=DEVICE)
        # Dream doesn't seem to explicitly use attention_mask in the simple demos,
        # but it's good practice if padding were involved.
        # For now, assume no padding in this interactive demo.
        attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Provide a fallback error state
        error_state = [("Error in chat formatting.", "red")]
        return [error_state], f"Error: Could not format chat history. {e}"

    prompt_length = input_ids.shape[1]
    total_length = prompt_length + max_new_tokens
    # --- Define the Hook Function ---
    def generation_tokens_hook_func(step, x, logits):
        global previous_x_shared, vis_queue
        with state_lock:  # Ensure thread safety, though hooks likely run sequentially
            current_x = x.clone()  # Shape: (batch_size, total_length)

            # --- Apply Constraints ---
            # Constraint positions are relative to the start of the *response*.
            for rel_pos, token_id in constraints.items():
                abs_pos = prompt_length + rel_pos
                # Guard against out-of-bounds positions; apply to the first (and only) batch element.
                if 0 <= abs_pos < current_x.shape[1]:
                    current_x[0, abs_pos] = token_id

            # --- Create Visualization State ---
            current_vis_state = []
            x_response = current_x[0, prompt_length:]  # Response part for batch 0
            prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None

            for i in range(max_new_tokens):
                current_token_id = x_response[i].item()
                token_str = tokenizer.decode([current_token_id], skip_special_tokens=False)  # Keep special tokens for vis

                # Clean up the visual representation of special tokens
                if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
                    token_str = "[EOS/PAD]"  # Make it visually distinct
                elif token_str == tokenizer.mask_token:
                    token_str = "[MASK]"
                elif token_str.strip() == "":  # Handle empty strings from decoding odd tokens
                    token_str = "[UNK/SPACE]"

                color = "#DDDDDD"  # Default background
                if current_token_id == MASK_ID:
                    color = "#444444"  # Dark gray for masks
                elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
                    # Token was a mask and is revealed in this step
                    color = "#66CC66"  # Light green for newly revealed
                else:
                    # Token was already revealed in a previous step, or is a constraint
                    is_constraint = i in constraints and constraints[i] == current_token_id
                    if is_constraint:
                        color = "#FFD700"  # Gold for constraints
                    else:
                        color = "#6699CC"  # Light blue for previously revealed

                current_vis_state.append((token_str, color))

            # --- Update shared state and put the vis state in the queue ---
            previous_x_shared = current_x.clone()  # Used for the *next* step's comparison
            vis_queue.put(current_vis_state)

        # The hook must return the (potentially modified) tensor `x`
        return current_x
    # --- End of Hook Function ---
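
    # Note: per the hook contract above, diffusion_generate continues each step from the
    # tensor this hook returns, so re-applying the positional constraints on every step
    # should keep those tokens pinned even if the sampler would otherwise resample them.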
    # Initialize previous_x_shared before generation starts:
    # create the initial masked state for visualization.
    initial_x = input_ids.clone()
    if initial_x.shape[1] < total_length:
        padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
        initial_x = torch.cat([initial_x, padding], dim=1)
    else:
        initial_x = initial_x[:, :total_length]  # Truncate if the prompt is too long

    # Apply initial constraints to the starting state
    for rel_pos, token_id in constraints.items():
        abs_pos = prompt_length + rel_pos
        if 0 <= abs_pos < initial_x.shape[1]:
            initial_x[0, abs_pos] = token_id

    with state_lock:
        previous_x_shared = initial_x.clone()

    # Add the initial all-masked state (or with constraints) to the visualization queue
    initial_vis_state = []
    initial_x_response = initial_x[0, prompt_length:]
    for i in range(max_new_tokens):
        token_id = initial_x_response[i].item()
        if token_id == MASK_ID:
            initial_vis_state.append(("[MASK]", "#444444"))  # Same label as the hook uses
        else:
            # Must be a pre-applied constraint
            token_str = tokenizer.decode([token_id], skip_special_tokens=False)
            if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
                token_str = "[EOS/PAD]"
            elif token_str.strip() == "":
                token_str = "[UNK/SPACE]"
            initial_vis_state.append((token_str, "#FFD700"))  # Gold for constraints
    vis_queue.put(initial_vis_state)
    # --- Run Generation ---
    try:
        # output_history=False because the hook handles state capture;
        # return_dict_in_generate=True to get the GenerationOutput object.
        output = model.diffusion_generate(
            initial_x,  # Start with the (possibly constraint-applied) tensor
            attention_mask=None,  # Assuming no padding is needed for interactive use
            max_new_tokens=max_new_tokens,  # Might not be strictly needed since total_length is fixed
            output_history=False,
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg=alg,
            alg_temp=alg_temp if alg != 'origin' else None,  # alg_temp only applies to confidence-based algs
            generation_tokens_hook_func=generation_tokens_hook_func
        )
        final_sequence = output.sequences[0]  # Batch size 1

        # Decode the final response text, cleaning up special tokens
        response_tokens = final_sequence[prompt_length:]
        # Filter out EOS/PAD tokens for the final text display
        response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
        final_text = tokenizer.decode(response_tokens_filtered,
                                      skip_special_tokens=True,
                                      clean_up_tokenization_spaces=True)  # Standard cleanup
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Provide an error state
        error_state = [("Generation Error.", "red")]
        visualization_states.append(error_state)
        final_text = f"Error: Generation failed. {e}"
        # Add any states captured before the error
        while not vis_queue.empty():
            try:
                visualization_states.append(vis_queue.get_nowait())
            except Empty:
                break
        return visualization_states, final_text

    # Retrieve all visualization states captured by the hook
    while not vis_queue.empty():
        try:
            visualization_states.append(vis_queue.get_nowait())
        except Empty:
            break

    # If somehow no states were captured, fall back to the initial one
    if not visualization_states:
        visualization_states.append(initial_vis_state)

    return visualization_states, final_text.strip()
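
# Standalone usage sketch (outside the Gradio UI), assuming the model loaded above:
#   states, text = generate_response_with_visualization(
#       [{"role": "user", "content": "Write a haiku about rain."}],
#       max_new_tokens=64, steps=64)
#   print(text)          # final decoded response
#   print(len(states))   # one (token, color) list per captured diffusion step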
# --- Gradio UI ---
css = '''
.category-legend{display:none}
button{height: 60px}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
        gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                            scale=9
                        )
                        send_btn = gr.Button("Send", scale=1)

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
                    placeholder="0:Once, 5:upon, 10:a",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Diffusion Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,  # The built-in legend is hidden via the CSS above; the Markdown legend below is shown instead
                )
                # Legend (colors defined in generate_response_with_visualization)
                gr.Markdown(
                    "<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
                    " <span style='background-color:#66CC66;'>Newly Revealed</span>"
                    " <span style='background-color:#6699CC;'>Previously Revealed</span>"
                    " <span style='background-color:#FFD700;'>Constraint</span>"
                    " <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
                )
        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            max_new_tokens_slider = gr.Slider(
                minimum=16, maximum=512, value=128, step=16,  # Increased default/max
                label="Max New Tokens (Generation Length)"
            )
            steps_slider = gr.Slider(
                minimum=8, maximum=512, value=128, step=8,  # Increased default/max
                label="Diffusion Steps"
            )
            temp_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.6, step=0.05,  # Finer steps for temperature
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                label="Top-P (Nucleus Sampling)"
            )
            alg_radio = gr.Radio(
                # Choices from the README
                choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
                value='entropy',
                label="Remasking Algorithm"
            )
            alg_temp_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                label="Algorithm Temperature (for confidence-based algs)"
            )
            vis_delay_slider = gr.Slider(
                minimum=0.0, maximum=0.5, value=0.03, step=0.01,  # Fast default delay
                label="Visualization Delay (seconds)"
            )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")
        # HELPER FUNCTIONS (UI Logic)
        def add_message_to_history(history, message, response):
            """Add a message pair to the history state"""
            new_history = history + [[message, response]]
            return new_history

        def user_message_submitted(message, history):
            """Handle the user sending a message: update history, clear input"""
            if not message or message.strip() == "":
                return history, history, "", []  # No change if empty

            # Add the user message; the response is initially None
            new_history = add_message_to_history(history, message, None)
            # Display version (immediately shows the user message)
            display_history = new_history
            # Clear the input box and the visualization
            message_out = ""
            vis_out = []
            return new_history, display_history, message_out, vis_out
        def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
            """Generator function to stream the bot response and visualization"""
            if not history or history[-1][1] is not None:  # Ensure there's a user msg waiting for a response
                print("Warning: Bot response triggered without a pending user message.")
                yield history, []
                return

            # Get the full conversation history formatted for the model
            last_user_message = history[-1][0]
            messages_for_model = format_chat_history(history[:-1])  # History *before* the last user msg
            messages_for_model.append({"role": "user", "content": last_user_message})

            # Parse constraints
            try:
                parsed_constraints = parse_constraints(constraints_str)
            except Exception as e:
                print(f"Error parsing constraints: {e}")
                yield history, [(f"Constraint Error: {e}", "red")]
                return

            # Generate the response and visualization states
            try:
                vis_states, final_response_text = generate_response_with_visualization(
                    messages_for_model,
                    max_new_tokens=max_tokens,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temp,
                    top_p=top_p,
                    alg=alg,
                    alg_temp=alg_temp
                )
            except Exception as e:
                print(f"Error in generate_response_with_visualization: {e}")
                import traceback
                traceback.print_exc()
                yield history, [(f"Generation Error: {e}", "red")]
                return

            # Update the history state with the final response *once*
            history[-1][1] = final_response_text  # Replace the None placeholder

            # Yield the initial state immediately
            if vis_states:
                yield history, vis_states[0]
            else:
                yield history, []  # Should not happen if generation worked

            # Stream the intermediate visualization states
            for state in vis_states[1:]:
                time.sleep(delay)
                yield history, state
            # The last state in vis_states is already yielded by the loop,
            # so the chatbot UI ends up with the complete history.
        def clear_conversation():
            """Clear the conversation history and visualization"""
            return [], [], "", []  # history, chatbot_ui, user_input, output_vis
        # EVENT HANDLERS
        # User presses Enter or clicks Send: record the message, then generate the response.
        submit_args = {
            "fn": user_message_submitted,
            "inputs": [user_input, chat_history],
            "outputs": [chat_history, chatbot_ui, user_input, output_vis]
        }
        generate_args = {
            "fn": bot_response_generator,
            "inputs": [
                chat_history, constraints_input, max_new_tokens_slider, steps_slider,
                temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
            ],
            "outputs": [chatbot_ui, output_vis]  # Update the chatbot history and visualization
        }
        # Chain generation after the user message has been added to the history,
        # for both Enter and the Send button.
        user_input.submit(**submit_args).then(**generate_args)
        send_btn.click(**submit_args).then(**generate_args)

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )
    return demo

# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # queue() enables streaming and handling multiple users;
    # share=True creates a public link (use with caution).
    demo.queue().launch(share=True, debug=True)
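
# To run locally (a sketch; this file does not pin exact versions):
#   pip install torch transformers gradio spaces
#   python dream_app.py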