import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import asyncio
import spaces
description = """## Compare Creative Writing: Custom Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) using one of its algorithms named "Creative Writing Strategy".
<br />On the left you have the output of the standard sampling and on the write the output privided by Backtrack Sampler.
"""
# Load tokenizer
model_name = "unsloth/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
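# The tokenizer also carries the model's chat template, which is used
# below to format the conversation for Llama 3.2 Instruct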
# Load two instances of the model for parallel inference: model1 runs
# standard sampling on CUDA; model2 is handed to the Backtrack Sampler
# provider together with the target device
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model2 = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device("cuda")
strategy = CreativeWritingStrategy(top_p_flat=0.8, top_k_threshold_flat=2, min_prob_second_highest=0.2)
provider = TransformersProvider(model2, tokenizer, device)
creative_sampler = BacktrackSampler(strategy, provider)
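# Rough intuition (see the Backtrack Sampler README for the exact
# semantics): the strategy watches the next-token distribution and, when
# it looks "flat" (probability mass spread over several candidates), it
# can backtrack to steer generation away from the most predictable token.
# The three thresholds above tune when that behavior kicks in.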
# Helper function to create message array for the chat template
def create_chat_template_messages(history, prompt):
    messages = []
    # Replay the conversation history first, in order
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    # The new prompt goes last
    messages.append({"role": "user", "content": prompt})
    return messages
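# Example (illustrative values):
#   create_chat_template_messages([("Hi", "Hello!")], "Tell me a joke")
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Tell me a joke"}]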
# Async function for generating responses using two models
@spaces.GPU(duration=60)
async def generate_responses(prompt, history):
# Create messages array for chat history and apply template
messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The templated prompt already contains the special tokens, so don't add them again
    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")
    # Standard sampler task (do_sample=True so the temperature actually applies;
    # without it, generate() decodes greedily and ignores temperature)
    standard_task = asyncio.to_thread(
        model1.generate, inputs, max_length=2048, temperature=1, do_sample=True
    )
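    # Note: asyncio.to_thread only wraps the call in a coroutine; the worker
    # thread starts once it is awaited in asyncio.gather below, so both
    # generations run concurrently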
    # Custom sampler task: drain the blocking token generator in a worker
    # thread, then decode the collected token ids
    async def custom_sampler_task():
        def run_sampler():
            return list(creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1))
        generated_list = await asyncio.to_thread(run_sampler)
        return tokenizer.decode(generated_list, skip_special_tokens=True)
# Wait for both responses
standard_output, custom_output = await asyncio.gather(standard_task, custom_sampler_task())
# Decode standard output and remove the prompt from the generated response
standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
return standard_response.strip(), custom_output.strip()
# Create the Gradio interface with the Citrus theme
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
gr.Markdown(description)
# Chatbot components
with gr.Row():
standard_chat = gr.Chatbot(label="Standard Sampler")
custom_chat = gr.Chatbot(label="Creative Writing Strategy")
# Input components
with gr.Row():
prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)
# Example prompts
examples = [
"Write me a short story about a talking dog who wants to be a detective.",
"Tell me a short tale of a dragon who is afraid of heights.",
"Create a short story where aliens land on Earth, but they just want to throw a party."
]
# Add example buttons
gr.Examples(examples=examples, inputs=prompt_input)
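    # Clicking an example only fills the textbox; generation still starts
    # via the Submit button or the Enter key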
# Button to submit the prompt
submit_button = gr.Button("Submit")
# Function to handle chat updates
async def update_chat(prompt, standard_history, custom_history):
standard_response, custom_response = await generate_responses(prompt, standard_history)
# Append new responses to chat histories
standard_history = standard_history + [(prompt, standard_response)]
custom_history = custom_history + [(prompt, custom_response)]
# Clear the input field after submission
return standard_history, custom_history, ""
# Bind the submit button to the update function and allow pressing Enter to submit
prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
# Launch the app with queueing and sharing enabled
demo.queue().launch(share=True, debug=True)