# app.py: Dream 7B diffusion language model Gradio demo
import torch
import numpy as np
import gradio as gr
import spaces
import time
import re
from transformers import AutoModel, AutoTokenizer
from threading import Lock
from queue import Queue, Empty
# --- Configuration ---
MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
# --- Load Model and Tokenizer ---
print("Loading model and tokenizer...")
# Need configuration files for trust_remote_code
# Make sure config.json, configuration_dream.py, modeling_dream.py,
# generation_utils.py, generation_config.json are in the same directory
# or accessible in the Hugging Face cache.
model = AutoModel.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
trust_remote_code=True
).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True
)
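# Note: trust_remote_code pulls Dream's custom modeling/generation code (modeling_dream.py,
# generation_utils.py) from the repo; that code provides model.diffusion_generate() and the
# generation_tokens_hook_func hook used further below.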
print("Model and tokenizer loaded.")
# --- Constants ---
# Get IDs from tokenizer/config if possible, otherwise hardcode from provided files
MASK_TOKEN = tokenizer.mask_token # Should be "<|mask|>"
try:
MASK_ID = tokenizer.mask_token_id # Should be 151666
if MASK_ID is None: raise AttributeError # Handle case where it might not be set directly
except AttributeError:
print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
MASK_ID = 151666
try:
EOS_ID = tokenizer.eos_token_id # Should be 151643
PAD_ID = tokenizer.pad_token_id # Should be 151643
if EOS_ID is None or PAD_ID is None: raise AttributeError
except AttributeError:
print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
EOS_ID = 151643
PAD_ID = 151643
# Ensure MASK_TOKEN and MASK_ID are valid
if MASK_TOKEN is None or MASK_ID is None:
raise ValueError("Mask token or ID is not defined correctly.")
if EOS_ID is None or PAD_ID is None:
raise ValueError("EOS/PAD token or ID is not defined correctly.")
# --- Helper Functions ---
def parse_constraints(constraints_text):
"""Parse constraints in format: 'position:word, position:word, ...'"""
constraints = {}
if not constraints_text:
return constraints
parts = constraints_text.split(',')
for part in parts:
if ':' not in part:
continue
try:
pos_str, word = part.split(':', 1)
pos = int(pos_str.strip())
word = word.strip()
if word and pos >= 0:
# Tokenize the word - handle potential multi-token words
# Add space prefix for consistency, similar to how model might see words mid-sentence
tokens = tokenizer.encode(" " + word, add_special_tokens=False)
for i, token_id in enumerate(tokens):
constraints[pos + i] = token_id
except ValueError:
continue
except Exception as e:
print(f"Error parsing constraint part '{part}': {e}")
continue
return constraints
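# Illustrative example (actual token ids depend on the tokenizer, so they are shown symbolically):
#   parse_constraints("0:Once, 5:upon")  ->  {0: id(" Once"), 5: id(" upon")}
# A word that encodes to several tokens occupies consecutive positions starting at its index.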
def format_chat_history(history):
"""
Format chat history for the Dream model using its chat template logic.
Args:
history: List of [user_message, assistant_message] pairs
Returns:
Formatted list of message dictionaries for the model
"""
messages = []
    # No explicit system message is added here: the chat template shipped in
    # tokenizer_config.json inserts a default system prompt automatically.
for user_msg, assistant_msg in history:
if user_msg: # Handle potential initial system message possibility if needed
messages.append({"role": "user", "content": user_msg})
if assistant_msg is not None: # Skip if None (for the latest user message)
messages.append({"role": "assistant", "content": assistant_msg})
return messages
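# Illustrative example of the structure this produces:
#   format_chat_history([["Hi", "Hello!"], ["How are you?", None]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "How are you?"}]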
# --- Core Generation Logic with Visualization ---
# Use a thread-safe queue to pass visualization states from the hook
vis_queue = Queue()
# Lock to prevent race conditions when accessing shared state like previous_x
state_lock = Lock()
# Store the previous state for comparison in the hook
previous_x_shared = None
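# NOTE: vis_queue and previous_x_shared are module-level, so they are shared across requests.
# With several concurrent users their visualization states would interleave; that is fine for
# a single-user demo but would need per-session state for anything more.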
@spaces.GPU
def generate_response_with_visualization(
messages, # List of message dicts from format_chat_history
max_new_tokens=64,
steps=64, # Default steps based on README example
constraints=None,
temperature=0.6, # Default from demo_token_control
top_p=0.95, # Default from demos
alg="entropy", # Default from demos
alg_temp=0.1, # Default from demo_multiturn_chat
):
"""
Generate text with Dream model and capture visualization states using a hook.
Args:
messages: List of message dictionaries with 'role' and 'content'.
max_new_tokens: Max tokens to generate.
steps: Diffusion steps.
constraints: Dictionary mapping positions (relative to response start) to token IDs.
temperature: Sampling temperature.
top_p: Nucleus sampling p.
alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
alg_temp: Temperature for confidence-based algorithms.
Returns:
Tuple: (List of visualization states, final generated text string)
"""
global previous_x_shared, vis_queue
if constraints is None:
constraints = {}
visualization_states = []
# Clear the queue for a new generation
while not vis_queue.empty():
try:
vis_queue.get_nowait()
        except Empty:
break
# Prepare the prompt using chat template
# The template automatically adds the generation prompt like "<|im_start|>assistant\n"
try:
inputs = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True,
return_dict=True
)
input_ids = inputs.input_ids.to(device=DEVICE)
# Dream doesn't seem to explicitly use attention_mask in simple demos,
# but it's good practice if padding were involved.
# For now, assume no padding in this interactive demo.
attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None
except Exception as e:
print(f"Error applying chat template: {e}")
# Provide a fallback or error state
error_state = [("Error in chat formatting.", "red")]
return [error_state], f"Error: Could not format chat history. {e}"
prompt_length = input_ids.shape[1]
total_length = prompt_length + max_new_tokens
# --- Define the Hook Function ---
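    # Hook contract: diffusion_generate invokes this function at each diffusion step with the
    # current token tensor `x` (prompt + response). Whatever tensor it returns is used for the
    # next step, which is how the positional constraints are re-applied every step; each
    # intermediate state is also pushed onto vis_queue so the UI can replay the process.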
def generation_tokens_hook_func(step, x, logits):
global previous_x_shared, vis_queue
with state_lock: # Ensure thread safety if needed, though hooks might run sequentially
current_x = x.clone() # Shape: (batch_size, total_length)
# --- Apply Constraints ---
# Constraints are relative to the start of the *response*
for rel_pos, token_id in constraints.items():
abs_pos = prompt_length + rel_pos
if 0 <= abs_pos < current_x.shape[1]:
# Ensure constraint application doesn't go out of bounds
# Apply constraint for the first batch element (batch size is 1 here)
current_x[0, abs_pos] = token_id
# --- Create Visualization State ---
current_vis_state = []
x_response = current_x[0, prompt_length:] # Get the response part for batch 0
prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None
for i in range(max_new_tokens):
current_token_id = x_response[i].item()
token_str = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for vis
# Clean up visual representation of special tokens
if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
token_str = "[EOS/PAD]" # Make it visually distinct
elif token_str == tokenizer.mask_token:
token_str = "[MASK]"
elif token_str.strip() == "": # Handle empty strings from decoding potentially odd tokens
token_str = "[UNK/SPACE]"
color = "#DDDDDD" # Default background
if current_token_id == MASK_ID:
color = "#444444" # Dark gray for masks
elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
# Token was mask, now it's revealed in this step
# Use green for newly revealed
color = "#66CC66" # Light green
else:
# Token was already revealed in a previous step or is a constraint
# Check if it's a constraint applied *now*
                    is_constraint = i in constraints and constraints[i] == current_token_id
if is_constraint:
color = "#FFD700" # Gold for constraints
else:
color = "#6699CC" # Light blue for previously revealed
current_vis_state.append((token_str, color))
# --- Update shared state and put vis state in queue ---
previous_x_shared = current_x.clone() # Update for the *next* step's comparison
vis_queue.put(current_vis_state)
# The hook must return the potentially modified tensor `x`
return current_x
# --- End of Hook Function ---
# Initialize previous_x_shared before generation starts
# Create initial masked state for visualization
initial_x = input_ids.clone()
if initial_x.shape[1] < total_length:
padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
initial_x = torch.cat([initial_x, padding], dim=1)
else:
initial_x = initial_x[:, :total_length] # Truncate if prompt is too long
# Apply initial constraints to the starting state
for rel_pos, token_id in constraints.items():
abs_pos = prompt_length + rel_pos
if 0 <= abs_pos < initial_x.shape[1]:
initial_x[0, abs_pos] = token_id
with state_lock:
previous_x_shared = initial_x.clone()
# Add the initial all-masked state (or with constraints) to the visualization queue
initial_vis_state = []
initial_x_response = initial_x[0, prompt_length:]
for i in range(max_new_tokens):
token_id = initial_x_response[i].item()
if token_id == MASK_ID:
            initial_vis_state.append(("[MASK]", "#444444"))  # Same label the hook uses for masked positions
else:
# Must be a pre-applied constraint
token_str = tokenizer.decode([token_id], skip_special_tokens=False)
if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
token_str = "[EOS/PAD]"
elif token_str.strip() == "":
token_str = "[UNK/SPACE]"
initial_vis_state.append((token_str, "#FFD700")) # Gold for constraints
vis_queue.put(initial_vis_state)
# --- Run Generation ---
try:
# output_history=False because the hook handles state capture
# return_dict_in_generate=True to get the GenerationOutput object
output = model.diffusion_generate(
initial_x, # Start with the potentially constraint-applied tensor
attention_mask=None, # Assuming no padding needed for interactive use
max_new_tokens=max_new_tokens, # This might not be strictly needed if total_length is fixed
output_history=False,
return_dict_in_generate=True,
steps=steps,
temperature=temperature,
top_p=top_p,
alg=alg,
alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence algs
generation_tokens_hook_func=generation_tokens_hook_func
)
final_sequence = output.sequences[0] # Batch size 1
# Decode the final response text, cleaning up special tokens
response_tokens = final_sequence[prompt_length:]
# Filter out EOS/PAD tokens for the final text display
response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
final_text = tokenizer.decode(response_tokens_filtered,
skip_special_tokens=True,
clean_up_tokenization_spaces=True) # Standard cleanup
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
# Provide error state
error_state = [("Generation Error.", "red")]
visualization_states.append(error_state)
final_text = f"Error: Generation failed. {e}"
# Add any states captured before the error
while not vis_queue.empty():
try:
visualization_states.append(vis_queue.get_nowait())
            except Empty:
break
return visualization_states, final_text
# Retrieve all visualization states captured by the hook
while not vis_queue.empty():
try:
visualization_states.append(vis_queue.get_nowait())
        except Empty:
break
# If somehow no states were captured, add the initial one
if not visualization_states:
visualization_states.append(initial_vis_state)
return visualization_states, final_text.strip()
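# Rough standalone usage sketch (outside the Gradio UI), assuming the model loaded above:
#   msgs = [{"role": "user", "content": "Write a haiku about rain."}]
#   states, text = generate_response_with_visualization(msgs, max_new_tokens=64, steps=64)
#   print(text)          # final decoded response
#   print(len(states))   # roughly one (token, color) frame per step, plus the initial frame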
# --- Gradio UI ---
css = '''
.category-legend{display:none}
button{height: 60px}
'''
def create_chatbot_demo():
with gr.Blocks(css=css) as demo:
gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")
# STATE MANAGEMENT
chat_history = gr.State([])
# UI COMPONENTS
with gr.Row():
with gr.Column(scale=3):
chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])
# Message input
with gr.Group():
with gr.Row():
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
show_label=False,
scale=9
)
send_btn = gr.Button("Send", scale=1)
constraints_input = gr.Textbox(
label="Word Constraints (Optional)",
info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
placeholder="0:Once, 5:upon, 10:a",
value=""
)
with gr.Column(scale=2):
output_vis = gr.HighlightedText(
label="Diffusion Process Visualization",
combine_adjacent=False,
show_legend=True, # Keep legend hidden via CSS if desired
)
# Legend (colors defined in generate_response_with_visualization)
gr.Markdown(
"<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
" <span style='background-color:#66CC66;'>Newly Revealed</span>"
" <span style='background-color:#6699CC;'>Previously Revealed</span>"
" <span style='background-color:#FFD700;'>Constraint</span>"
" <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
)
# Advanced generation settings
with gr.Accordion("Generation Settings", open=False):
max_new_tokens_slider = gr.Slider(
minimum=16, maximum=512, value=128, step=16, # Increased default/max
label="Max New Tokens (Generation Length)"
)
steps_slider = gr.Slider(
minimum=8, maximum=512, value=128, step=8, # Increased default/max
label="Diffusion Steps"
)
temp_slider = gr.Slider(
minimum=0.0, maximum=1.0, value=0.6, step=0.05, # Finer steps for temp
label="Temperature"
)
top_p_slider = gr.Slider(
minimum=0.0, maximum=1.0, value=0.95, step=0.05,
label="Top-P (Nucleus Sampling)"
)
alg_radio = gr.Radio(
# Choices from README
choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
value='entropy',
label="Remasking Algorithm"
)
alg_temp_slider = gr.Slider(
minimum=0.0, maximum=1.0, value=0.1, step=0.05,
label="Algorithm Temperature (for confidence-based algs)"
)
vis_delay_slider = gr.Slider(
minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Faster default delay
label="Visualization Delay (seconds)"
)
# Clear button
clear_btn = gr.Button("Clear Conversation")
# HELPER FUNCTIONS (UI Logic)
def add_message_to_history(history, message, response):
"""Add a message pair to the history state"""
new_history = history + [[message, response]]
return new_history
def user_message_submitted(message, history):
""" Handle user sending a message: update history, clear input """
if not message or message.strip() == "":
return history, history, "", [] # No change if empty
# Add user message, response is initially None
new_history = add_message_to_history(history, message, None)
# Prepare display version (immediately shows user message)
display_history = new_history
# Clear input box
message_out = ""
# Clear visualization
vis_out = []
return new_history, display_history, message_out, vis_out
def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
""" Generator function to stream bot response and visualization """
            if not history or history[-1][1] is not None:  # Ensure there's a user msg waiting for a response
                print("Warning: Bot response triggered without pending user message.")
                yield history, [("Error: No user message to respond to.", "red")]
                return
# Get the full conversation history formatted for the model
last_user_message = history[-1][0]
messages_for_model = format_chat_history(history[:-1]) # History *before* the last user msg
messages_for_model.append({"role": "user", "content": last_user_message})
# Parse constraints
try:
parsed_constraints = parse_constraints(constraints_str)
except Exception as e:
print(f"Error parsing constraints: {e}")
yield history, [("Constraint Error", "red")], f"Error: Failed to parse constraints: {e}"
return
# Generate response and visualization states
try:
vis_states, final_response_text = generate_response_with_visualization(
messages_for_model,
max_new_tokens=max_tokens,
steps=steps,
constraints=parsed_constraints,
temperature=temp,
top_p=top_p,
alg=alg,
alg_temp=alg_temp
)
except Exception as e:
print(f"Error in generate_response_with_visualization: {e}")
import traceback
traceback.print_exc()
yield history, [("Generation Error", "red")], f"Error: Generation failed: {e}"
return
# Update the history state with the final response *once*
history[-1][1] = final_response_text # Update the None placeholder
# Yield initial state immediately
if vis_states:
yield history, vis_states[0]
else:
yield history, [] # Should not happen if generation worked
# Stream intermediate visualization states
for state in vis_states[1:]:
time.sleep(delay)
yield history, state
            # The loop above already yields the final visualization state together with the
            # updated history, so the chatbot UI ends up with the complete conversation.
def clear_conversation():
"""Clear the conversation history and visualization"""
return [], [], "", [] # history, chatbot_ui, user_input, output_vis
        # EVENT HANDLERS
        # User presses Enter or clicks Send: record the message, then stream the bot response
        submit_args = {
            "fn": user_message_submitted,
            "inputs": [user_input, chat_history],
            "outputs": [chat_history, chatbot_ui, user_input, output_vis]
        }
        generate_args = {
            "fn": bot_response_generator,
            "inputs": [
                chat_history, constraints_input, max_new_tokens_slider, steps_slider,
                temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
            ],
            "outputs": [chatbot_ui, output_vis]  # Update chatbot history and visualization
        }
        # Chain generation off the same event so it only runs after the history state is updated
        user_input.submit(**submit_args).then(**generate_args)
        send_btn.click(**submit_args).then(**generate_args)
# Clear button handler
clear_btn.click(
fn=clear_conversation,
inputs=[],
outputs=[chat_history, chatbot_ui, user_input, output_vis]
)
return demo
# Launch the demo
if __name__ == "__main__":
demo = create_chatbot_demo()
# queue() allows streaming and handling multiple users
# share=True creates a public link (use with caution)
demo.queue().launch(share=True, debug=True)