import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import spaces
import asyncio
description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of the [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) framework using its "Creative Writing Strategy".
<br />On the left is the output of the standard sampler and on the right is the output provided by Backtrack Sampler.
"""
model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model2 = AutoModelForCausalLM.from_pretrained(model_name)
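# Two copies of the model: model1 runs the standard transformers generate() on the GPU,
# while model2 is wrapped by a TransformersProvider so the backtrack sampler can drive it.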
provider = TransformersProvider(model2, tokenizer, device)
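# Strategy parameters (hedged reading; see the backtrack_sampler repo for the exact semantics):
# top_p_flat / top_k_threshold_flat appear to control when the next-token distribution counts as
# "flat" enough to trigger backtracking, and eos_penalty scales down the end-of-sequence probability.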
strategy = CreativeWritingStrategy(provider,
                                   top_p_flat=0.65,
                                   top_k_threshold_flat=9,
                                   eos_penalty=0.75)
creative_sampler = BacktrackSampler(provider, strategy)
def create_chat_template_messages(history, prompt):
    messages = []
    # Replay prior turns in order: each history entry is a (user_message, assistant_response) pair.
    for input_text, response_text in history:
        messages.append({"role": "user", "content": input_text})
        messages.append({"role": "assistant", "content": response_text})
    # The current prompt goes last so the generation prompt follows it.
    messages.append({"role": "user", "content": prompt})
    return messages
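
# On Hugging Face ZeroGPU Spaces, spaces.GPU requests a GPU for each call to the decorated
# function; duration=60 hints the expected maximum runtime in seconds.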
@spaces.GPU(duration=60)
def generate_responses(prompt, history):
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # wrapped_prompt already contains the chat template's special tokens, so don't add them again
    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")

    async def custom_sampler_task():
        generated_list = []
        generator = creative_sampler.generate(wrapped_prompt, max_new_tokens=1024, temperature=1)
        for token in generator:
            generated_list.append(token)
        return tokenizer.decode(generated_list, skip_special_tokens=True)

    custom_output = asyncio.run(custom_sampler_task())

    standard_output = model1.generate(inputs, max_new_tokens=1024, temperature=1)
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

    return standard_response.strip(), custom_output.strip()
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)

    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party."
    ]
    gr.Examples(examples=examples, inputs=prompt_input)

    submit_button = gr.Button("Submit")

    def update_chat(prompt, standard_history, custom_history):
        standard_response, custom_response = generate_responses(prompt, standard_history)
        standard_history = standard_history + [(prompt, standard_response)]
        custom_history = custom_history + [(prompt, custom_response)]
        return standard_history, custom_history, ""

    prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
    submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
demo.queue().launch(debug=True)