Mihaiii committed on
Commit d620330 · verified · 1 Parent(s): cd0b23a

Update app.py

Files changed (1)
  1. app.py +108 -107
app.py CHANGED
@@ -1,107 +1,108 @@
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import asyncio
+ import spaces
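# "spaces" is the Hugging Face ZeroGPU helper package; without this import the
# @spaces.GPU decorator used below raises a NameError, which is what this commit fixes.
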
description = """## Compare Creative Writing: Standard Sampling vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) using one of its algorithms, the "Creative Writing Strategy".
<br />On the left is the output of standard sampling; on the right, the output produced by Backtrack Sampler.
"""
# Load the tokenizer
model_name = "unsloth/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load two instances of the model for parallel inference:
# model1 does standard sampling, model2 is driven by the backtrack sampler
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

model2 = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device('cuda')

strategy = CreativeWritingStrategy(top_p_flat=0.8, top_k_threshold_flat=2, min_prob_second_highest=0.2)
provider = TransformersProvider(model2, tokenizer, device)
creative_sampler = BacktrackSampler(strategy, provider)
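# (Assumption, inferred from the parameter names rather than from documentation:
# these thresholds control when the strategy judges the token distribution too
# flat and backtracks; treat the values above as tunable, not canonical.)
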
# Helper that builds the messages array for the chat template:
# past (user, assistant) turns first, then the current prompt as the final user message
def create_chat_template_messages(history, prompt):
    messages = []
    for input_text, response_text in history:
        messages.append({"role": "user", "content": input_text})
        messages.append({"role": "assistant", "content": response_text})
    messages.append({"role": "user", "content": prompt})
    return messages
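
# Example of the shape this helper produces:
#   create_chat_template_messages([("Hi!", "Hello, how can I help?")], "Tell me a story")
#   -> [{"role": "user", "content": "Hi!"},
#       {"role": "assistant", "content": "Hello, how can I help?"},
#       {"role": "user", "content": "Tell me a story"}]
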
# Async function that generates both responses; the decorator requests a ZeroGPU
# slot for up to 60 seconds
@spaces.GPU(duration=60)
async def generate_responses(prompt, history):
    # Build the messages array from the chat history and apply the chat template
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The templated prompt already contains special tokens, so don't add them again
    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")

    # Standard sampler task: run the blocking generate() call in a worker thread;
    # do_sample=True is required for temperature to have any effect
    standard_task = asyncio.to_thread(
        model1.generate, inputs, max_length=2048, do_sample=True, temperature=1.0
    )

    # Custom sampler task: drain the backtrack sampler's token generator and decode.
    # Kept as a plain function so it can also run in a worker thread instead of
    # blocking the event loop, which would serialize the two generations.
    def custom_sampler_task():
        generated_list = []
        for token in creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1.0):
            generated_list.append(token)
        return tokenizer.decode(generated_list, skip_special_tokens=True)

    # Wait for both responses
    standard_output, custom_output = await asyncio.gather(
        standard_task, asyncio.to_thread(custom_sampler_task)
    )
    # Decode the standard output, dropping the prompt tokens from the front
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

    return standard_response.strip(), custom_output.strip()
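
# Design note: the two generations only overlap because both blocking loops run in
# worker threads via asyncio.to_thread; loading model1 and model2 as separate
# instances keeps the two threads from contending for a single module.
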
# Create the Gradio interface with the Citrus theme
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    # Chatbot components
    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    # Input components
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)

    # Example prompts
    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party."
    ]

    # Add example buttons
    gr.Examples(examples=examples, inputs=prompt_input)

    # Button to submit the prompt
    submit_button = gr.Button("Submit")

    # Handle a chat update: generate both responses, append them to the histories,
    # and clear the input box
    async def update_chat(prompt, standard_history, custom_history):
        standard_response, custom_response = await generate_responses(prompt, standard_history)

        standard_history = standard_history + [(prompt, standard_response)]
        custom_history = custom_history + [(prompt, custom_response)]

        return standard_history, custom_history, ""

    # Bind both pressing Enter and clicking Submit to the update function
    prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
    submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])

# Launch the app with queueing enabled; share=True only matters for local runs
# (it is ignored, with a warning, when the app runs on Hugging Face Spaces)
demo.queue().launch(share=True, debug=True)
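
# Hypothetical local smoke test (not part of the Space; assumes a CUDA machine and
# that the spaces decorator passes the call through outside the Spaces runtime):
#
#   import asyncio
#   standard, creative = asyncio.run(
#       generate_responses("Tell me a short tale of a dragon who is afraid of heights.", [])
#   )
#   print(standard, "\n---\n", creative)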