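"""Gradio demo for TESS 2, a simplex diffusion language model.

Loads the model and diffusion pipeline once at startup, then serves an
interactive UI that streams the intermediate decoding after every
denoising step.
"""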
import logging
import gradio as gr
import torch
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
)

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def main():
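    # Parse model/data/training/diffusion settings from args.json and load
    # the checkpoint plus its tokenizer via the repo's helpers.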
    model_args, data_args, training_args, diffusion_args = get_args("args.json")
    tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)

    model.eval()
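    # Run on GPU when available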
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

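    # Build the diffusion pipeline once at startup; generate() below patches
    # the scheduler and sampling parameters per request.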
    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )

    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        generated_sequence_length=256,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
    ):
        """
        Gradio-friendly generation function. Patches the pipeline and scheduler
        with the user-provided parameters, then runs diffusion sampling, yielding
        the decoded text and a progress update after each denoising step.
        """
        with torch.inference_mode():
            # Patch the scheduler and pipeline with user-provided parameters:
            pipeline.scheduler.num_train_timesteps = diffusion_steps
            pipeline.scheduler.beta_schedule = beta_schedule
            pipeline.scheduler.simplex_value = simplex_value
            pipeline.scheduler.clip_sample = clip_sample
            pipeline.simplex_value = simplex_value
            pipeline.top_p = top_p
            pipeline.temperature = temperature

            # timesteps must be a descending sequence of step indices,
            # [diffusion_steps - 1, ..., 1, 0], not a bare int
            pipeline.scheduler.timesteps = torch.arange(0, diffusion_steps).flip(0)
            
            # Wrap the prompt in the Tulu chat template
            inputs = "<|user|>\n" + inputs + "\n<|assistant|>\n"

            # Tokenize and prepare input for diffusion
            tokenized_input = tokenizer(
                [inputs], add_special_tokens=False, return_tensors="pt"
            ).input_ids
            tokenized_input_len = tokenized_input.shape[1]

            # Concatenate BOS + prompt + pad tokens marking the region to generate
            tokenized_input = torch.cat(
                [
                    torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
                    tokenized_input,
                    torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id,
                ],
                dim=-1,
            )
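            # Final layout: [BOS | prompt | generated_sequence_length pad tokens]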

            # Span mask: False over BOS + prompt (kept fixed), True over the
            # positions to be generated by diffusion
            span_mask = torch.cat(
                [
                    torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
                    torch.ones((1, generated_sequence_length), dtype=torch.bool),
                ],
                dim=-1,
            )

            batch = {
                "input_ids": tokenized_input.to(device),
                "span_mask": span_mask.to(device),
            }

            # Run sampling, streaming an intermediate decoding at each step
            current_step = 0
            pipe = pipeline(
                batch=batch,
                seq_length=generated_sequence_length,
                guidance_scale=guidance_scale,
            )
            for out in pipe:
                output_ids = out.logits.argmax(dim=-1)
                # Decode only the generated span (skip BOS + prompt)
                generated_tokens = output_ids[:, tokenized_input_len + 1 :]
                text_output = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
                current_step += 1
                progress = current_step / diffusion_steps
                slider = gr.Slider(value=progress, minimum=0, maximum=1,
                                   label=f"Step {current_step}/{diffusion_steps}")
                yield [text_output, slider]

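    # Build the Gradio interface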
    with gr.Blocks() as demo:
        gr.Markdown("# TESS 2 Demo!")
        gr.Markdown("A live demo of TESS 2 v0.3. Check out the models and code [here](https://github.com/hamishivi/tess-2), or the paper [here](https://arxiv.org/abs/2502.13917)!")
        
        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(lines=5, label="Input Prompt")
                simplex_value = gr.Number(value=5.0, label="Simplex value")
                top_p = gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p")
                temperature = gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature")
                diffusion_steps = gr.Number(value=100, precision=0, label="Diffusion steps")
                seq_length = gr.Number(value=256, precision=0, label="Generation length (tokens)")
            
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Generated Text")
                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0,
                    label="Generation Progress",
                    interactive=False
                )
                
                generate_btn = gr.Button("Generate")
                generate_btn.click(
                    generate,
                    inputs=[
                        input_text,
                        simplex_value,
                        top_p,
                        temperature,
                        diffusion_steps,
                        seq_length,
                    ],
                    outputs=[output_text, progress_bar]
                )

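    # queue() is needed so the generator-based handler can stream updates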
    demo.queue().launch()


if __name__ == "__main__":
    main()