File size: 4,964 Bytes
d620330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import asyncio
import spaces

# Markdown blurb shown at the top of the demo page; it explains which side of
# the UI shows which sampler's output.
description = """## Compare Creative Writing: Custom Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) using one of its algorithms named "Creative Writing Strategy".
<br />On the left you have the output of the standard sampling and on the right the output provided by Backtrack Sampler.
"""
# Checkpoint shared by the tokenizer and both model copies.
model_name = "unsloth/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# First copy: moved to CUDA right away.
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

# Second copy: not moved here; it is handed to the Backtrack Sampler provider
# together with the target device.
model2 = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device("cuda")

strategy = CreativeWritingStrategy(
    top_p_flat=0.8,
    top_k_threshold_flat=2,
    min_prob_second_highest=0.2,
)
provider = TransformersProvider(model2, tokenizer, device)
creative_sampler = BacktrackSampler(strategy, provider)

# Helper function to create message array for the chat template
def create_chat_template_messages(history, prompt):
    """Build the chat-template message list for a new user turn.

    Args:
        history: Previous turns as ``(user_text, assistant_text)`` pairs,
            oldest first. ``None`` is treated as an empty history.
        prompt: The new user message to be answered.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: each past turn as a
        user message followed by its assistant reply, then ``prompt`` as the
        final user message.
    """
    messages = []
    for user_text, assistant_text in history or []:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})

    # The new prompt goes last so the conversation ends on the user turn that
    # still needs a reply.
    messages.append({"role": "user", "content": prompt})
    return messages

# Async function for generating responses using two models
@spaces.GPU(duration=60)
async def generate_responses(prompt, history):
    """Generate a reply to ``prompt`` with both samplers concurrently.

    Args:
        prompt: The new user message.
        history: Earlier chat turns, passed to
            ``create_chat_template_messages`` together with ``prompt``.

    Returns:
        A ``(standard_response, custom_response)`` tuple of stripped strings:
        the first decoded from ``model1.generate``, the second from the tokens
        yielded by ``creative_sampler.generate``.
    """
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_special_tokens=True, add_generation_prompt=True
    )

    # The templated prompt already has special tokens, so don't add them again.
    inputs = tokenizer.encode(
        wrapped_prompt, add_special_tokens=False, return_tensors="pt"
    ).to("cuda")

    def run_creative_sampler():
        """Drain the Backtrack Sampler generator and decode its tokens."""
        generated_tokens = list(
            creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1)
        )
        return tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Both generations are blocking calls. Each runs in its own worker thread
    # so neither one stalls the event loop while the other is in progress.
    standard_output, custom_output = await asyncio.gather(
        asyncio.to_thread(model1.generate, inputs, max_length=2048, temperature=1),
        asyncio.to_thread(run_creative_sampler),
    )

    # Drop the prompt tokens from the front of the standard output so only the
    # newly generated text is decoded.
    standard_response = tokenizer.decode(
        standard_output[0][len(inputs[0]):], skip_special_tokens=True
    )

    return standard_response.strip(), custom_output.strip()

# Build the Citrus-themed Gradio UI: two chat panes, a prompt box with example
# prompts, and a submit button.
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    # One chat pane per sampler, side by side
    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    # Prompt entry
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your message here...",
            lines=1,
        )

    # Example prompts wired to the prompt textbox
    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party.",
    ]
    gr.Examples(examples=examples, inputs=prompt_input)

    submit_button = gr.Button("Submit")

    async def update_chat(prompt, standard_history, custom_history):
        """Run both samplers on ``prompt`` and extend both chat histories.

        Both samplers are prompted with ``standard_history``;
        ``custom_history`` is only extended with the new turn.

        Returns:
            The extended standard history, the extended custom history, and an
            empty string that clears the prompt textbox.
        """
        standard_reply, custom_reply = await generate_responses(prompt, standard_history)
        return (
            standard_history + [(prompt, standard_reply)],
            custom_history + [(prompt, custom_reply)],
            "",
        )

    # Pressing Enter in the textbox and clicking the button trigger the same handler
    for register in (prompt_input.submit, submit_button.click):
        register(
            fn=update_chat,
            inputs=[prompt_input, standard_chat, custom_chat],
            outputs=[standard_chat, custom_chat, prompt_input],
        )

# Start the app with the request queue on, a public share link, and debug mode
demo.queue().launch(share=True, debug=True)