# tess-2-demo / app.py
import logging

import gradio as gr
import torch
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
)

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
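

# Overview: load the TESS 2 model and simplex-diffusion pipeline once at startup,
# then serve a Gradio UI that streams intermediate decodes while denoising.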
def main():
    model_args, data_args, training_args, diffusion_args = get_args("args.json")
    tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
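    # Build the pipeline once; per-request knobs (top_p, temperature, number of
    # diffusion steps, etc.) are overridden inside `generate` below.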
    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )
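
    # `generate` is a closure over the loaded pipeline, tokenizer, and device,
    # so each request reuses the already-loaded model.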
    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        generated_sequence_length=256,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
    ):
        """
        Gradio-friendly generation function. Adjusts the pipeline's parameters
        (simplex_value, top_p, etc.) as requested, then runs generation,
        yielding partial text and a progress update after every diffusion step.
        """
        with torch.inference_mode():
            # Update the pipeline and scheduler with user-provided parameters.
            pipeline.scheduler.num_train_timesteps = diffusion_steps
            pipeline.scheduler.beta_schedule = beta_schedule
            pipeline.scheduler.simplex_value = simplex_value
            pipeline.scheduler.clip_sample = clip_sample
            pipeline.simplex_value = simplex_value
            pipeline.top_p = top_p
            pipeline.temperature = temperature
            # Timesteps must be a descending sequence (noisiest step first).
            pipeline.scheduler.timesteps = torch.arange(0, diffusion_steps).flip(0)
            # Wrap the prompt in the Tulu chat template.
            inputs = "<|user|>\n" + inputs + "<|assistant|>\n"
            # Tokenize and prepare input for diffusion.
            tokenized_input = tokenizer([inputs], add_special_tokens=False, return_tensors="pt").input_ids
            tokenized_input_len = tokenized_input.shape[1]
            # Concatenate BOS + prompt + pad tokens reserving room for the generation.
            tokenized_input = torch.cat(
                [
                    torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
                    tokenized_input,
                    torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id,
                ],
                dim=-1,
            )
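            # Sequence layout: [BOS][prompt][PAD x generated_sequence_length];
            # the pad positions are placeholders the diffusion process overwrites.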
            # Mask over the generation region: False positions (BOS + prompt) stay
            # fixed, True positions are denoised by the diffusion process.
            span_mask = torch.cat(
                [
                    torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
                    torch.ones((1, generated_sequence_length), dtype=torch.bool),
                ],
                dim=-1,
            )
            batch = {
                "input_ids": tokenized_input.to(device),
                "span_mask": span_mask.to(device),
            }
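            # The pipeline is a generator: each diffusion step yields an output
            # whose logits we argmax and decode to stream partial text.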
            # Run sampling.
            current_step = 0
            pipe = pipeline(batch=batch, seq_length=generated_sequence_length, guidance_scale=guidance_scale)
            for out in pipe:
                output_ids = out.logits.argmax(dim=-1)
                generated_tokens = output_ids[:, tokenized_input_len + 1 :]
                text_output = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
                current_step += 1
                progress = current_step / diffusion_steps
                yield [
                    text_output,
                    gr.Slider(value=progress, minimum=0, maximum=1, label=f"Step {current_step}/{diffusion_steps}"),
                ]
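
    # Gradio UI: prompt and sampling controls on the left; streamed output and a
    # progress indicator on the right.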
    with gr.Blocks() as demo:
        gr.Markdown("# TESS 2 Demo!")
        gr.Markdown(
            "A live demo of TESS 2 v0.3. Check out the models and code "
            "[here](https://github.com/hamishivi/tess-2), or the paper "
            "[here](https://arxiv.org/abs/2502.13917)!"
        )
        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(lines=5, label="Input Prompt")
                simplex_value = gr.Number(value=5.0, label="Simplex value")
                top_p = gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p")
                temperature = gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature")
                diffusion_steps = gr.Number(value=100, precision=0, label="Diffusion steps")
                # precision=0 keeps the value an int, since it sizes a tensor above.
                seq_length = gr.Number(value=256, precision=0, label="Generation length (tokens)")
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Generated Text")
                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0,
                    label="Generation Progress",
                    interactive=False,
                )
        generate_btn = gr.Button("Generate")
        # beta_schedule, clip_sample, and guidance_scale are not exposed in the
        # UI, so they fall back to generate()'s defaults.
        generate_btn.click(
            generate,
            inputs=[
                input_text,
                simplex_value,
                top_p,
                temperature,
                diffusion_steps,
                seq_length,
            ],
            outputs=[output_text, progress_bar],
        )
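    # queue() lets Gradio stream the generator's intermediate yields to the client.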
    demo.queue().launch()


if __name__ == "__main__":
    main()
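

# Usage (a sketch, assuming an args.json with model and diffusion settings sits
# next to this file, as get_args("args.json") above expects):
#   python app.py
# then open the local URL Gradio prints.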