import torch
print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

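# CSS injected into the Gradio app: it lays out the nested <ul>/<li> markup produced
# by generate_html as a top-down tree, drawing the parent/child connector lines with
# ::before/::after borders and highlighting a node's subtree on hover.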
STYLE = """
.container {
	width: 100%;
	display: grid;
	align-items: center;
    margin: 0!important;
}
.prose ul ul {
    margin: 0!important;
}
.tree {
	padding: 0px;
	margin: 0!important;
	box-sizing: border-box;
    font-size: 16px;
	width: 100%;
	height: auto;
	text-align: center;
}
.tree ul {
	padding-top: 20px;
	position: relative;
	transition: .5s;
    margin: 0!important;
}
.tree li {
	display: inline-table;
	text-align: center;
	list-style-type: none;
	position: relative;
	padding: 10px;
	transition: .5s;
}
.tree li::before, .tree li::after {
	content: '';
	position: absolute;
	top: 0;
	right: 50%;
	border-top: 1px solid #ccc;
	width: 51%;
	height: 10px;
}
.tree li::after {
	right: auto;
	left: 50%;
	border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
	display: none;
}
.tree li:only-child {
	padding-top: 0;
}
.tree li:first-child::before, .tree li:last-child::after {
	border: 0 none;
}
.tree li:last-child::before {
	border-right: 1px solid #ccc;
	border-radius: 0 5px 0 0;
	-webkit-border-radius: 0 5px 0 0;
	-moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
	border-radius: 5px 0 0 0;
	-webkit-border-radius: 5px 0 0 0;
	-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
	content: '';
	position: absolute;
	top: 0;
	left: 50%;
	border-left: 1px solid #ccc;
	width: 0;
	height: 20px;
}
.tree li a {
	border: 1px solid #ccc;
	padding: 10px;
	display: inline-grid;
	border-radius: 5px;
	text-decoration-line: none;
	transition: .5s;
}
.tree li a span {
	border: 1px solid #ccc;
	border-radius: 5px;
	color: #666;
	padding: 8px;
	font-size: 12px;
	text-transform: uppercase;
	letter-spacing: 1px;
	font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
	background: #c8e4f8;
	color: #000;
	border: 1px solid #94a0b4;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
	border-color: #94a0b4;
}
"""

from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

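# Load a small causal LM and its tokenizer; the EOS token is reused for padding.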
tokenizer = AutoTokenizer.from_pretrained("Locutusque/TinyMistral-248M-v2")
model = AutoModelForCausalLM.from_pretrained("Locutusque/TinyMistral-248M-v2")
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")
def generate_html(token, node):
    """Recursively generate the nested <ul>/<li> HTML for the tree.

    Each node is a dict with a "table" entry (the HTML score table for that step,
    or None) and a "children" entry mapping a decoded token to its child node.
    """
    html_content = f" <li> <a href='#'> <span> <b>{token}</b> </span> "
    html_content += node["table"] if node["table"] is not None else ""
    html_content += "</a>"
    if len(node["children"]) > 0:
        html_content += "<ul> "
        for child_token, subnode in node["children"].items():
            html_content += generate_html(child_token, subnode)
        html_content += "</ul>"

    html_content += "</li>"

    return html_content


def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    """Build an HTML table of the top_k highest-scoring tokens for one step.

    Tokens that were actually kept by the search (`chosen_tokens`) are highlighted.
    """
    markdown_table = """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Probability</b></th>
        </tr>"""
    for token_idx in np.argsort(scores)[-top_k:]:
        token = tokenizer.decode([token_idx])
        style = ""
        if chosen_tokens and token in chosen_tokens:
            style = "background-color:red"
        markdown_table += f"""
        <tr style="{style}">
            <td>{token}</td>
            <td>{scores[token_idx]:.4f}</td>
        </tr>"""
    markdown_table += """
    </table>"""
    return markdown_table


def display_tree(start_sentence, scores, sequences, beam_indices):
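    """Replay each returned sequence step by step and build a tree of chosen tokens,
    attaching to each node the score table from which its children were chosen."""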
    display = """<div class="container">
				<div class="tree">
                <ul>"""
    sequences = sequences.cpu().numpy()
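    # Log the decoded candidate sequences for inspection.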
    print(tokenizer.batch_decode(sequences))
    original_tree = {"table": None, "children": {}}
    for sequence_ix in range(len(sequences)):
        current_tree = original_tree
        for step, step_scores in enumerate(scores):
            current_token_choice = tokenizer.decode([sequences[sequence_ix, step]])
            current_beam = beam_indices[sequence_ix, step]

            if current_token_choice not in current_tree["children"]:
                current_tree["children"][current_token_choice] = {
                    "table": None,
                    "children": {},
                }

            # Rebuild the score table even if it already exists: a new chosen token may
            # have been added to this node's children and must now be highlighted.
            markdown_table = generate_markdown_table(
                step_scores[current_beam, :].cpu().numpy(),
                chosen_tokens=current_tree["children"].keys(),
            )
            current_tree["table"] = markdown_table

            current_tree = current_tree["children"][current_token_choice]

    display += generate_html(start_sentence, original_tree)

    display += """
        </ul>
        </div>
    </body>
    """
    return display


def get_tables(input_text, number_steps, number_beams):
    """Generate number_steps tokens with number_beams beams and return the HTML tree."""
    inputs = tokenizer([input_text], return_tensors="pt")
    n_input_tokens = inputs["input_ids"].shape[-1]

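    # Sampling restricted to the top_k tokens at each step; with num_beams > 1 this is
    # beam-sample search. Scores and beam indices are returned so that display_tree
    # can replay how each sequence was built.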
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )

    # With a single beam, generate() returns no beam indices; every token then comes
    # from beam 0, so an all-zero tensor of the same shape is an equivalent stand-in.
    beam_indices = getattr(outputs, "beam_indices", None)
    if beam_indices is None:
        beam_indices = torch.zeros_like(outputs.sequences)

    # Drop the prompt tokens so the tree only contains the newly generated tokens.
    tables = display_tree(
        input_text,
        outputs.scores,
        outputs.sequences[:, n_input_tokens:],
        beam_indices[:, :-n_input_tokens],
    )
    return tables
    
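# Example outside the UI: get_tables("Today is", number_steps=4, number_beams=3)
# returns the tree as an HTML string.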
import gradio as gr

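# Minimal Gradio UI: a prompt box and sliders for the number of generation steps and
# beams; clicking the button runs get_tables and renders the resulting tree HTML.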
with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.green
    ),
    css=STYLE,
) as demo:
    text = gr.Textbox(label="Sentence to decode from🪶", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=10, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=1, maximum=3, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(get_tables, inputs=[text, steps, beams], outputs=out)

demo.launch()