import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import asyncio
import spaces

description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) using one of its algorithms named "Creative Writing Strategy".
On the left is the output of standard sampling; on the right is the output produced by Backtrack Sampler.
"""

# Load the tokenizer
model_name = "unsloth/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load two instances of the model so both samplers can run in parallel:
# model1 goes straight to CUDA; model2 is handed to the Backtrack Sampler
# provider together with the target device
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model2 = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device("cuda")

strategy = CreativeWritingStrategy(
    top_p_flat=0.8,
    top_k_threshold_flat=2,
    min_prob_second_highest=0.2,
)
provider = TransformersProvider(model2, tokenizer, device)
creative_sampler = BacktrackSampler(strategy, provider)

# Build the message list for the chat template: replay the (user, assistant)
# pairs from the history, then append the new prompt as the latest user turn
def create_chat_template_messages(history, prompt):
    messages = []
    for input_text, response_text in history:
        messages.append({"role": "user", "content": input_text})
        messages.append({"role": "assistant", "content": response_text})
    messages.append({"role": "user", "content": prompt})
    return messages

# Generate responses from both samplers concurrently
@spaces.GPU(duration=60)
async def generate_responses(prompt, history):
    # Build the message list from the chat history and apply the template
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The templated prompt already contains the special tokens
    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")

    # Standard sampler task; do_sample=True so the temperature actually applies
    standard_task = asyncio.to_thread(
        model1.generate, inputs, max_length=2048, temperature=1, do_sample=True
    )

    # Custom sampler task: consume the generator and collect the token ids
    async def custom_sampler_task():
        generated_list = []
        generator = creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1)
        for token in generator:
            generated_list.append(token)
        return tokenizer.decode(generated_list, skip_special_tokens=True)

    # Wait for both responses
    standard_output, custom_output = await asyncio.gather(standard_task, custom_sampler_task())

    # Decode the standard output, dropping the prompt tokens from the front
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

    return standard_response.strip(), custom_output.strip()

# Build the Gradio interface with the Citrus theme
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    # One chatbot per sampler, side by side
    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    # Prompt input
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)

    # Example prompts
    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party.",
    ]

    # Example buttons that fill the prompt box
    gr.Examples(examples=examples, inputs=prompt_input)

    # Button to submit the prompt
    submit_button = gr.Button("Submit")

    # Handle a chat turn: run both samplers, then append the new exchange
    async def update_chat(prompt, standard_history, custom_history):
        standard_response, custom_response = await generate_responses(prompt, standard_history)
        # Append the new exchange to each chat history
        standard_history = standard_history + [(prompt, standard_response)]
        custom_history = custom_history + [(prompt, custom_response)]
        # Return "" as the third output to clear the input field after submission
        return standard_history, custom_history, ""

    # Bind both pressing Enter in the textbox and clicking the button to the update function
    prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
    submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])

# Launch the app with queueing and sharing enabled
demo.queue().launch(share=True, debug=True)
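# ---------------------------------------------------------------------------
# Minimal standalone sketch, kept commented out so it does not run alongside
# the app. It shows the core Backtrack Sampler flow used above without any
# Gradio plumbing, assuming, as the code above does, that
# BacktrackSampler.generate() yields token ids the tokenizer can decode.
# The prompt text is illustrative only.
#
# sketch_prompt = tokenizer.apply_chat_template(
#     [{"role": "user", "content": "Tell me a one-sentence story."}],
#     tokenize=False,
#     add_generation_prompt=True,
# )
# token_ids = list(creative_sampler.generate(sketch_prompt, max_length=256, temperature=1))
# print(tokenizer.decode(token_ids, skip_special_tokens=True))
# ---------------------------------------------------------------------------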