m-ric's picture
m-ric HF staff
Update app.py
fb66618 verified
raw
history blame
6.28 kB
import html

import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
# Startup diagnostics: log whether torch can see a CUDA device and, if so,
# the name of the currently selected one.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    _device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}")
# Custom CSS passed to ``gr.Blocks(css=STYLE)`` below.
# It lays out the nested <ul>/<li> markup produced by ``display_tree`` as a
# tree diagram: each ``.tree li`` is a node, the ``::before``/``::after``
# pseudo-elements draw the connector lines with 1px borders, and the hover
# rules at the end highlight a node together with its descendants.
STYLE = """
.container {
width: 100%;
display: grid;
align-items: center;
margin: 0!important;
}
.prose ul ul {
margin: 0!important;
font-size: 13px!important;
}
.tree {
padding: 0px;
margin: 0!important;
box-sizing: border-box;
font-size: 16px;
width: 100%;
height: auto;
text-align: center;
}
.tree ul {
padding-top: 20px;
position: relative;
transition: .5s;
margin: 0!important;
}
.tree li {
display: inline-table;
text-align: center;
list-style-type: none;
position: relative;
padding: 10px;
transition: .5s;
}
.tree li::before, .tree li::after {
content: '';
position: absolute;
top: 0;
right: 50%;
border-top: 1px solid #ccc;
width: 51%;
height: 10px;
}
.tree li::after {
right: auto;
left: 50%;
border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
display: none;
}
.tree li:only-child {
padding-top: 0;
}
.tree li:first-child::before, .tree li:last-child::after {
border: 0 none;
}
.tree li:last-child::before {
border-right: 1px solid #ccc;
border-radius: 0 5px 0 0;
-webkit-border-radius: 0 5px 0 0;
-moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
border-radius: 5px 0 0 0;
-webkit-border-radius: 5px 0 0 0;
-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
content: '';
position: absolute;
top: 0;
left: 50%;
border-left: 1px solid #ccc;
width: 0;
height: 20px;
}
.tree li a {
border: 1px solid #ccc;
padding: 10px;
display: inline-grid;
border-radius: 5px;
text-decoration-line: none;
border-radius: 5px;
transition: .5s;
}
.tree li a span {
border: 1px solid #ccc;
border-radius: 5px;
color: #666;
padding: 8px;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
background: #c8e4f8;
color: #000;
border: 1px solid #94a0b4;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
border-color: #94a0b4;
}
"""
# Load the GPT-2 tokenizer and causal language model once at import time so
# every request handled by the app reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")
def generate_html(token, node):
    """Recursively generate HTML for the tree.

    Args:
        token: Label shown for this node. It is HTML-escaped before being
            inserted, because it is either the user-supplied prompt or text
            decoded from the model and may contain ``<``, ``>`` or ``&``.
        node: Dict with a ``"table"`` entry (an HTML string, or ``None`` when
            the node has no table) and a ``"children"`` entry mapping each
            child's token text to a node of the same shape.

    Returns:
        The ``<li>`` element for this node as a string, with one nested
        ``<ul>`` holding the children when there are any.
    """
    label = html.escape(str(token))
    parts = [f" <li> <a href='#'> <span> <b>{label}</b> </span> "]
    # The table is markup built by this app, so it is inserted verbatim.
    if node["table"] is not None:
        parts.append(node["table"])
    parts.append("</a>")

    children = node["children"]
    if children:
        parts.append("<ul> ")
        parts.extend(
            generate_html(child_token, child_node)
            for child_token, child_node in children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)
def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    """Build an HTML table of the ``top_k`` highest-scoring tokens.

    Args:
        scores: One-dimensional, vocabulary-indexed scores for a single
            generation step.
        top_k: Number of highest-scoring tokens to list. Values below 1
            produce a table with no body rows.
        chosen_tokens: Optional collection of decoded token strings; rows
            whose token is in it are given a red background.

    Returns:
        The table markup as a string. Rows appear in ascending score order,
        as returned by ``np.argsort``.
    """
    # NOTE(review): score values are printed exactly as received; if the
    # caller passes log-probabilities the "Probability" header is misleading
    # -- confirm against the caller.
    markdown_table = """
<table>
<tr>
<th><b>Token</b></th>
<th><b>Probability</b></th>
</tr>"""
    # ``[-0:]`` would select the whole vocabulary, so guard non-positive top_k.
    top_indices = np.argsort(scores)[-top_k:] if top_k > 0 else []
    for token_idx in top_indices:
        token = tokenizer.decode([token_idx])
        style = ""
        # Membership is tested on the raw token text; escaping is display-only.
        if chosen_tokens and token in chosen_tokens:
            style = "background-color:red"
        markdown_table += f"""
<tr style="{style}">
<td>{html.escape(token)}</td>
<td>{scores[token_idx]}</td>
</tr>"""
    markdown_table += """
</table>"""
    return markdown_table
def display_tree(start_sentence, scores, sequences, beam_indices):
    """Render the returned beams as a nested HTML tree.

    Args:
        start_sentence: Prompt text, used as the label of the root node.
        scores: Per-step scores; element ``step`` is indexed as
            ``[beam, vocab]``.
        sequences: Tensor of generated token ids, one row per returned
            sequence, with the prompt already removed.
        beam_indices: For each returned sequence and step, the beam row of
            ``scores[step]`` that produced the token.

    Returns:
        HTML for the whole tree, wrapped in the ``container`` and ``tree``
        divs that the app's CSS targets.
    """
    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))

    original_tree = {"table": None, "children": {}}
    for sequence_ix, sequence in enumerate(sequences):
        current_tree = original_tree
        for step, step_scores in enumerate(scores):
            current_token_choice = tokenizer.decode([sequence[step]])
            # NOTE(review): a negative beam index (padding, if the generation
            # backend emits one) would select the last beam row -- confirm.
            current_beam = beam_indices[sequence_ix, step]

            children = current_tree["children"]
            if current_token_choice not in children:
                children[current_token_choice] = {"table": None, "children": {}}

            # Rewrite the table even if it was there before, since new chosen
            # nodes may have appeared among the children of the current node.
            current_tree["table"] = generate_markdown_table(
                step_scores[current_beam, :],
                chosen_tokens=children.keys(),
            )
            current_tree = children[current_token_choice]

    # Close both wrapper divs; the original closed one div and a stray
    # </body> that was never opened.
    return (
        '<div class="container">\n<div class="tree">\n<ul>'
        + generate_html(start_sentence, original_tree)
        + "\n</ul>\n</div>\n</div>\n"
    )
@spaces.GPU
def get_tables(input_text, number_steps, number_beams):
    """Run beam-sample generation on ``input_text`` and return its tree HTML.

    Args:
        input_text: Prompt to continue.
        number_steps: Maximum number of new tokens to generate.
        number_beams: Number of beams; the same number of sequences is
            returned.

    Returns:
        The HTML string produced by ``display_tree``.
    """
    inputs = tokenizer([input_text], return_tensors="pt")
    # ``len(inputs)`` counts the entries of the tokenizer output, not tokens,
    # so the prompt length must come from the input_ids tensor.
    prompt_length = inputs["input_ids"].shape[1]

    # Cast to int in case the UI sliders deliver floats.
    number_steps = int(number_steps)
    number_beams = int(number_beams)

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )

    generated = outputs.sequences[:, prompt_length:]
    # Keep one beam index per generated token; any further columns are not
    # generation steps.
    generated_beams = outputs.beam_indices[:, : generated.shape[1]]

    return display_tree(input_text, outputs.scores, generated, generated_beams)
# Build the UI: a prompt box, two sliders for generation length and beam
# count, and a button that renders the tree returned by get_tables.
_soft_theme = gr.themes.Soft(
    text_size="lg",
    font=["monospace"],
    primary_hue=gr.themes.colors.green,
)

with gr.Blocks(theme=_soft_theme, css=STYLE) as demo:
    text = gr.Textbox(label="Sentence to decode from", value="Today is")
    steps = gr.Slider(
        label="Number of steps",
        minimum=1,
        maximum=10,
        step=1,
        value=4,
    )
    beams = gr.Slider(
        label="Number of beams",
        minimum=2,
        maximum=4,
        step=1,
        value=3,
    )
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(fn=get_tables, inputs=[text, steps, beams], outputs=out)

# Start the web server as soon as the module is executed.
demo.launch()