import logging

import gradio as gr
import torch
import numpy as np
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
)

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def main():
    model_args, data_args, training_args, diffusion_args = get_args("args.json")
    tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )

    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        generated_sequence_length=256,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
    ):
        """
        Gradio-friendly generation function.

        Adjusts the pipeline's parameters (simplex_value, top_p, etc.) as
        requested, then runs generation, yielding intermediate results so
        the UI can display progress.
        """
        with torch.inference_mode():
            # Update pipeline and scheduler with user-provided parameters.
            pipeline.scheduler.num_train_timesteps = diffusion_steps
            pipeline.scheduler.beta_schedule = beta_schedule
            pipeline.scheduler.simplex_value = simplex_value
            pipeline.scheduler.clip_sample = clip_sample
            pipeline.simplex_value = simplex_value
            pipeline.top_p = top_p
            pipeline.temperature = temperature
            # Timesteps must be a descending sequence of step indices, not a scalar.
            pipeline.scheduler.timesteps = torch.arange(0, diffusion_steps).flip(0)

            # Tulu chat template (newline-delimited user/assistant turns).
            inputs = "<|user|>\n" + inputs + "\n<|assistant|>\n"

            # Tokenize and prepare input for diffusion.
            tokenized_input = tokenizer([inputs], add_special_tokens=False, return_tensors="pt").input_ids
            tokenized_input_len = tokenized_input.shape[1]

            # Concatenate BOS + prompt + pad tokens to be filled in by generation.
            tokenized_input = torch.cat(
                [
                    torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
                    tokenized_input,
                    torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id,
                ],
                dim=-1,
            )

            # Mask is False over the prompt (kept fixed) and True over the
            # region the diffusion process is allowed to generate.
            span_mask = torch.cat(
                [
                    torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
                    torch.ones((1, generated_sequence_length), dtype=torch.bool),
                ],
                dim=-1,
            )
            batch = {
                "input_ids": tokenized_input.to(device),
                "span_mask": span_mask.to(device),
            }

            # Run sampling, decoding the intermediate prediction at each step.
            current_step = 0
            pipe = pipeline(batch=batch, seq_length=generated_sequence_length, guidance_scale=guidance_scale)
            for out in pipe:
                output_ids = out.logits.argmax(dim=-1)
                generated_tokens = output_ids[:, tokenized_input_len + 1 :]
                text_output = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
                current_step += 1
                progress = current_step / diffusion_steps
                yield [
                    text_output,
                    gr.Slider(
                        value=progress,
                        minimum=0,
                        maximum=1,
                        label=f"Step {current_step}/{diffusion_steps}",
                    ),
                ]

    with gr.Blocks() as demo:
        gr.Markdown("# TESS 2 Demo!")
        gr.Markdown(
            "A live demo of TESS 2 v0.3. Check out the models and code "
            "[here](https://github.com/hamishivi/tess-2), or the paper "
            "[here](https://arxiv.org/abs/2502.13917)!"
        )
        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(lines=5, label="Input Prompt")
                simplex_value = gr.Number(value=5.0, label="Simplex value")
                top_p = gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p")
                temperature = gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature")
                diffusion_steps = gr.Number(value=100, precision=0, label="Diffusion steps")
                # precision=0 so the length arrives as an int (torch.ones rejects float sizes).
                seq_length = gr.Number(value=256, precision=0, label="Generation length (tokens)")
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Generated Text")
                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0,
                    label="Generation Progress",
                    interactive=False,
                )
        generate_btn = gr.Button("Generate")
        generate_btn.click(
            generate,
            inputs=[
                input_text,
                simplex_value,
                top_p,
                temperature,
                diffusion_steps,
                seq_length,
            ],
            outputs=[output_text, progress_bar],
        )

    demo.queue().launch()


if __name__ == "__main__":
    main()