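"""Gradio demo app for TESS 2, a simplex diffusion language model.

Loads the model and diffusion pipeline once at startup, then serves an
interactive interface that streams partially denoised generations back to
the user as diffusion progresses.
"""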
import logging

import gradio as gr
import numpy as np
import torch
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
)

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
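

# Model loading, pipeline construction, and the UI all live inside main(),
# so the model is loaded exactly once when the app starts.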
def main():
    model_args, data_args, training_args, diffusion_args = get_args("args.json")
    tokenizer, model = load_model(
        model_args, data_args, training_args, diffusion_args, logger
    )
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )
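
    # The defaults above come from diffusion_args; the Gradio controls below
    # override them per request inside `generate`.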
    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        generated_sequence_length=256,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
    ):
        """
        Gradio-friendly generation function. Adjusts the pipeline's parameters
        (simplex_value, top_p, etc.) as requested, then runs generation,
        yielding the partially decoded text and a progress update after each
        diffusion step.
        """
        with torch.inference_mode():
            # Update the pipeline and scheduler with user-provided parameters:
            pipeline.scheduler.num_train_timesteps = diffusion_steps
            pipeline.scheduler.beta_schedule = beta_schedule
            pipeline.scheduler.simplex_value = simplex_value
            pipeline.scheduler.clip_sample = clip_sample
            pipeline.simplex_value = simplex_value
            pipeline.top_p = top_p
            pipeline.temperature = temperature
            # Ensure timesteps are a reversed sequence of step indices, not a scalar
            pipeline.scheduler.timesteps = torch.arange(0, diffusion_steps).flip(0)
            # Wrap the prompt in the tulu chat template
            inputs = "<|user|>\n" + inputs + "<|assistant|>\n"
            # Tokenize and prepare input for diffusion
            tokenized_input = tokenizer([inputs], add_special_tokens=False, return_tensors="pt").input_ids
            tokenized_input_len = tokenized_input.shape[1]
            # Concatenate BOS + input + blank space for generation
            tokenized_input = torch.cat(
                [
                    torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
                    tokenized_input,
                    torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id,
                ],
                dim=-1,
            )
            # Create a mask over the generation region
            span_mask = torch.cat(
                [
                    torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
                    torch.ones((1, generated_sequence_length), dtype=torch.bool),
                ],
                dim=-1,
            )
            batch = {
                "input_ids": tokenized_input.to(device),
                "span_mask": span_mask.to(device),
            }
            # Run sampling, streaming intermediate decodes back to the UI
            current_step = 0
            pipe = pipeline(batch=batch, seq_length=generated_sequence_length, guidance_scale=guidance_scale)
            for out in pipe:
                output_ids = out.logits.argmax(dim=-1)
                generated_tokens = output_ids[:, tokenized_input_len + 1 :]
                text_output = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
                current_step += 1
                progress = current_step / diffusion_steps
                yield [
                    text_output,
                    gr.Slider(value=progress, minimum=0, maximum=1, label=f"Step {current_step}/{diffusion_steps}"),
                ]
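
    # Build the Gradio UI. Because `generate` is a generator, Gradio streams
    # each yielded (text, progress) pair to the outputs as diffusion proceeds.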
    with gr.Blocks() as demo:
        gr.Markdown("# TESS 2 Demo!")
        gr.Markdown("A live demo of TESS 2 v0.3. Check out the models and code [here](https://github.com/hamishivi/tess-2), or the paper [here](https://arxiv.org/abs/2502.13917)!")
        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(lines=5, label="Input Prompt")
                simplex_value = gr.Number(value=5.0, label="Simplex value")
                top_p = gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p")
                temperature = gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature")
                diffusion_steps = gr.Number(value=100, precision=0, label="Diffusion steps")
                seq_length = gr.Number(value=256, precision=0, label="Generation length (tokens)")
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Generated Text")
                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0,
                    label="Generation Progress",
                    interactive=False,
                )
        generate_btn = gr.Button("Generate")
        generate_btn.click(
            generate,
            inputs=[
                input_text,
                simplex_value,
                top_p,
                temperature,
                diffusion_steps,
                seq_length,
            ],
            outputs=[output_text, progress_bar],
        )
    demo.queue().launch()


if __name__ == "__main__":
    main()