"""Gradio demo that visualises GPT-2 beam search as a tree of token choices."""

import html

import gradio as gr
import numpy as np
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# CSS for the nested <ul>/<li> "tree" produced by display_tree()/generate_html();
# it is injected into the page through gr.Blocks(css=STYLE) below.
STYLE = """
@import url('https://fonts.googleapis.com/css2?family=Poppins:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');

* {
    padding: 0px;
    margin: 0px;
    box-sizing: border-box;
    font-size: 16px;
}

body {
    height: 100vh;
    width: 100vw;
    display: grid;
    align-items: center;
    font-family: 'Poppins', sans-serif;
}

.tree {
    width: 100%;
    height: auto;
    text-align: center;
}

.tree ul {
    padding-top: 20px;
    position: relative;
    transition: .5s;
}

.tree li {
    display: flex;
    flex-direction: row;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding: 10px;
    transition: .5s;
}

.tree li::before,
.tree li::after {
    content: '';
    position: absolute;
    top: 0;
    right: 50%;
    border-top: 1px solid #ccc;
    width: 51%;
    height: 10px;
}

.tree li::after {
    right: auto;
    left: 50%;
    border-left: 1px solid #ccc;
}

.tree li:only-child::after,
.tree li:only-child::before {
    display: none;
}

.tree li:only-child {
    padding-top: 0;
}

.tree li:first-child::before,
.tree li:last-child::after {
    border: 0 none;
}

.tree li:last-child::before {
    border-right: 1px solid #ccc;
    border-radius: 0 5px 0 0;
    -webkit-border-radius: 0 5px 0 0;
    -moz-border-radius: 0 5px 0 0;
}

.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}

.tree ul ul::before {
    content: '';
    position: absolute;
    top: 0;
    left: 50%;
    border-left: 1px solid #ccc;
    width: 0;
    height: 20px;
}

.tree li a {
    border: 1px solid #ccc;
    padding: 10px;
    display: inline-grid;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
}

.tree li a img {
    width: 50px;
    height: 50px;
    margin-bottom: 10px !important;
    border-radius: 100px;
    margin: auto;
}

.tree li a span {
    border: 1px solid #ccc;
    border-radius: 5px;
    color: #666;
    padding: 8px;
    font-size: 12px;
    text-transform: uppercase;
    letter-spacing: 1px;
    font-weight: 500;
}

/*Hover-Section*/
.tree li a:hover,
.tree li a:hover i,
.tree li a:hover span,
.tree li a:hover+ul li a {
    background: #c8e4f8;
    color: #000;
    border: 1px solid #94a0b4;
}

.tree li a:hover+ul li::after,
.tree li a:hover+ul li::before,
.tree li a:hover+ul::before,
.tree li a:hover+ul ul::before {
    border-color: #94a0b4;
}
"""

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# GPT-2 has no pad token; reuse EOS so generate() can pad finished beams.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Inline style used to highlight the token that was actually chosen.
_CHOSEN_ROW_STYLE = ' style="background-color:red"'
# Header row shared by the score tables.  NOTE(review): for beam search the
# values from generate(output_scores=True) are log-space scores, not
# probabilities, despite the column title.
_TABLE_HEADER = "<tr><th><b>Token</b></th><th><b>Probability</b></th></tr>"


def display_top_k_tokens(scores, sequences, beam_indices):
    """Render one HTML table per returned sequence with the top-5 tokens per step.

    Not used by the Gradio app below (which calls display_tree()); kept as a
    debugging view.

    Args:
        scores: per-step score matrices from ``model.generate(..., output_scores=True)``;
            ``scores[step][beam, token_id]``.
        sequences: generated token ids (prompt removed), one row per sequence.
        beam_indices: ``beam_indices[i, step]`` is the row of ``scores[step]``
            that produced ``sequences[i, step]``.

    Returns:
        The HTML string (also printed for debugging).
    """
    display = "<div>"
    for i, sequence in enumerate(sequences):
        markdown_table = (
            f"<h2>Sequence {i}: {html.escape(str(tokenizer.batch_decode(sequence)))}</h2>"
            f"<table>{_TABLE_HEADER}"
        )
        for step, step_scores in enumerate(scores):
            markdown_table += f"<tr><td><b>Step {step}</b></td><td><u>=====</u></td></tr>"
            current_beam = beam_indices[i, step]
            chosen_token = int(sequences[i, step])
            beam_scores = np.asarray(step_scores[current_beam, :])
            # argsort is ascending, so the last 5 entries are the best 5.
            for token_idx in np.argsort(beam_scores)[-5:]:
                row_style = _CHOSEN_ROW_STYLE if int(token_idx) == chosen_token else ""
                markdown_table += (
                    f"<tr{row_style}>"
                    f"<td>{html.escape(tokenizer.decode([int(token_idx)]))}</td>"
                    f"<td>{float(beam_scores[token_idx]):.4f}</td>"
                    "</tr>"
                )
        markdown_table += "</table>"
        display += markdown_table
    display += "</div>"
    print(display)
    return display


def generate_html(token, node):
    """Recursively generate HTML for the tree.

    ``node`` is rendered as an ``<li>`` whose ``<a>`` shows ``token`` (HTML-escaped)
    plus the node's score table, followed by a nested ``<ul>`` holding its
    children. This is the ``li > a + ul`` shape targeted by STYLE's selectors.

    Args:
        token: label for this node (decoded token, or the prompt for the root).
        node: dict with ``"table"`` (HTML string or None) and ``"children"``
            (mapping from decoded token to child node of the same shape).

    Returns:
        The ``<li>...</li>`` HTML for this subtree.
    """
    html_content = f"<li><a href='#'><span>{html.escape(token)}</span>"
    if node["table"] is not None:
        html_content += node["table"]
    html_content += "</a>"
    if node["children"]:
        children_html = "".join(
            generate_html(child_token, child_node)
            for child_token, child_node in node["children"].items()
        )
        html_content += f"<ul>{children_html}</ul>"
    html_content += "</li>"
    return html_content


def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    """Build an HTML table of the ``top_k`` highest-scoring tokens.

    Args:
        scores: 1-D scores over the vocabulary (torch tensor or array-like).
        top_k: number of tokens to list; rows go from lowest to highest score.
            ``0`` yields an empty table.
        chosen_tokens: decoded token strings whose rows are highlighted.

    Returns:
        An HTML ``<table>`` string (compact, so Markdown rendering leaves it intact).
    """
    scores = np.asarray(scores)
    order = np.argsort(scores)
    # Explicit start index: a plain [-top_k:] slice would return everything for top_k == 0.
    top_indices = order[max(order.size - top_k, 0):]
    rows = []
    for token_idx in top_indices:
        token = tokenizer.decode([int(token_idx)])
        style = _CHOSEN_ROW_STYLE if chosen_tokens and token in chosen_tokens else ""
        rows.append(
            f"<tr{style}><td>{html.escape(token)}</td>"
            f"<td>{float(scores[token_idx]):.4f}</td></tr>"
        )
    return f"<table>{_TABLE_HEADER}{''.join(rows)}</table>"


def display_tree(scores, sequences, beam_indices, root_label="Today is"):
    """Render the returned beams as a nested HTML tree of token choices.

    Sequences sharing a prefix share tree nodes. Each node's table lists the
    top candidates at that step, with the tokens actually taken highlighted.

    Args:
        scores: per-step score matrices from ``generate(output_scores=True)``.
        sequences: torch tensor of generated ids (prompt removed), one row per
            returned sequence.
        beam_indices: ``beam_indices[i, step]`` is the row of ``scores[step]``
            that produced ``sequences[i, step]``; a negative entry (transformers
            pads with -1 after a hypothesis ends) stops that sequence's path.
        root_label: text for the root node, normally the prompt.

    Returns:
        The tree HTML (also printed for debugging).
    """
    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))
    original_tree = {"table": None, "children": {}}
    # Never index past either the recorded steps or the returned tokens.
    num_steps = min(len(scores), sequences.shape[1])
    for sequence_ix in range(len(sequences)):
        current_tree = original_tree
        for step in range(num_steps):
            current_beam = int(beam_indices[sequence_ix, step])
            if current_beam < 0:
                break
            current_token_choice = tokenizer.decode([int(sequences[sequence_ix, step])])
            if current_token_choice not in current_tree["children"]:
                current_tree["children"][current_token_choice] = {
                    "table": None,
                    "children": {},
                }
            # Rewrite the probs table even if it was there before, since new
            # chosen nodes have appeared in the children of current tree.
            current_tree["table"] = generate_markdown_table(
                scores[step][current_beam, :],
                chosen_tokens=current_tree["children"].keys(),
            )
            current_tree = current_tree["children"][current_token_choice]

    display = f'<div class="tree"><ul>{generate_html(root_label, original_tree)}</ul></div>'
    print(display)
    return display


def get_tables(input_text, number_steps, number_beams):
    """Decode from ``input_text`` with beam sampling and return the tree HTML.

    Args:
        input_text: prompt to continue.
        number_steps: maximum number of new tokens (Gradio sliders may send floats).
        number_beams: number of beams, also the number of returned sequences.

    Returns:
        HTML for the output component; an empty string for an empty prompt.
    """
    number_steps = int(number_steps)
    number_beams = int(number_beams)
    inputs = tokenizer([input_text], return_tensors="pt")
    # len(inputs) counts the BatchEncoding's keys, not tokens: use the id tensor.
    input_length = inputs["input_ids"].shape[1]
    if input_length == 0:
        return ""
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )
    generated = outputs.sequences[:, input_length:]
    beam_indices = getattr(outputs, "beam_indices", None)
    if beam_indices is None:
        # With one beam, generate() samples without beams and reports no
        # beam_indices; every step then uses row 0 of its score matrix.
        beam_indices = np.zeros(tuple(generated.shape), dtype=np.int64)
    else:
        beam_indices = beam_indices[:, : generated.shape[1]]
    return display_tree(
        outputs.scores,
        generated,
        beam_indices,
        root_label=input_text,
    )


with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.green
    ),
    css=STYLE,
) as demo:
    text = gr.Textbox(label="Sentence to decode from🪶", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=10, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=1, maximum=3, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(get_tables, inputs=[text, steps, beams], outputs=out)

demo.launch()