MaxLSB committed
Commit 954f37f · verified · 1 Parent(s): 7b4f2fa

Update app.py

Files changed (1): app.py +72 -34
app.py CHANGED
@@ -1,8 +1,7 @@
 import os
 import threading
 from collections import defaultdict
-from PIL import Image
-import tempfile
+
 import gradio as gr
 from transformers import (
     AutoModelForCausalLM,
@@ -10,14 +9,17 @@ from transformers import (
     TextIteratorStreamer,
 )
 
+# Define model paths
 model_name_to_path = {
     "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
     "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
     "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
 }
 
+# Load Hugging Face token
 hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
 
+# Preload models and tokenizers
 loaded_models = defaultdict(dict)
 
 for name, path in model_name_to_path.items():
@@ -25,24 +27,40 @@ for name, path in model_name_to_path.items():
     loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
     loaded_models[name]["model"].eval()
 
-def resize_logo(input_path, size=(100, 100)):
-    with Image.open(input_path) as img:
-        img = img.resize(size, Image.LANCZOS)
-        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
-        img.save(temp_file.name, format="PNG")
-        return temp_file.name
-
 def respond(message, history, model_name, max_tokens, temperature, top_p):
+    """
+    Generate a response from the selected model, streaming the output and updating chat history.
+
+    Args:
+        message (str): User's input message.
+        history (list): Current chat history as list of (user_msg, bot_msg) tuples.
+        model_name (str): Selected model name.
+        max_tokens (int): Maximum number of tokens to generate.
+        temperature (float): Sampling temperature.
+        top_p (float): Top-p sampling parameter.
+
+    Yields:
+        list: Updated chat history with the user's message and streaming bot response.
+    """
+    # Append user's message to history with an empty bot response
     history = history + [(message, "")]
-    yield history
+    yield history  # Display user's message immediately
+
+    # Select tokenizer and model
     tokenizer = loaded_models[model_name]["tokenizer"]
     model = loaded_models[model_name]["model"]
+
+    # Tokenize input
     inputs = tokenizer(message, return_tensors="pt")
+
+    # Set up streaming
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_prompt=False,
         skip_special_tokens=True,
     )
+
+    # Configure generation parameters
     generate_kwargs = dict(
         **inputs,
         streamer=streamer,
@@ -52,62 +70,82 @@ def respond(message, history, model_name, max_tokens, temperature, top_p):
         top_p=top_p,
         eos_token_id=tokenizer.eos_token_id,
     )
+
+    # Start generation in a background thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
-    accumulated = f"**{model_name}**\n\n"
+
+    # Stream the response with model name prefix
+    accumulated = f"**{model_name}:** "
     for new_text in streamer:
         accumulated += new_text
         history[-1] = (message, accumulated)
         yield history
 
 def submit(message, history, model_name, max_tokens, temperature, top_p):
+    """
+    Handle form submission by calling respond and clearing the input box.
+
+    Args:
+        message (str): User's input message.
+        history (list): Current chat history.
+        model_name (str): Selected model name.
+        max_tokens (int): Max tokens parameter.
+        temperature (float): Temperature parameter.
+        top_p (float): Top-p parameter.
+
+    Yields:
+        tuple: (updated chat history, cleared user input)
+    """
     for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
         yield updated_history, ""
 
-def start_with_example(example, model_name, max_tokens, temperature, top_p):
-    for updated_history in respond(example, [], model_name, max_tokens, temperature, top_p):
-        yield updated_history, ""
-
-resized_logo_path = resize_logo("media/le-carnet.png", size=(100, 100))
-
-examples = [
-    "Il était une fois un petit garçon qui vivait dans un village paisible.",
-    "Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.",
-    "Il était une fois un petit lapin perdu",
-]
-
+# Create the Gradio interface with Blocks
 with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
+    # Title and description
     gr.Markdown("# LeCarnet")
-    gr.Markdown("Select a model on the right and type a message to chat, or choose an example below.")
+    gr.Markdown("Select a model on the right and type a message to chat.")
+
+    # Two-column layout with specific widths
     with gr.Row():
+        # Left column: Chat interface (80% width)
         with gr.Column(scale=4):
-            dataset = gr.Dataset(components=[gr.Textbox(visible=False)], samples=[[ex] for ex in examples], type="values")
             chatbot = gr.Chatbot(
-                avatar_images=(None, resized_logo_path),
+                avatar_images=(None, "media/le-carnet.png"),  # User avatar: None, Bot avatar: Logo
                 label="Chat",
-                height=600,
+                height=600,  # Increase chat height for larger display
             )
             user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
             submit_btn = gr.Button("Send")
+            # Example prompts
+            examples = gr.Examples(
+                examples=[
+                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
+                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
+                    ["Il était une fois un petit lapin perdu"],
+                ],
+                inputs=user_input,
+            )
+
+        # Right column: Model selection and parameters (20% width)
        with gr.Column(scale=1, min_width=200):
+            # Dropdown for model selection
             model_dropdown = gr.Dropdown(
-                choices=list(model_name_to_path.keys()),
+                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                 value="LeCarnet-8M",
-                label="Model"
+                label="Select Model"
             )
+            # Sliders for parameters
             max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+
+    # Event handling for submit button
     submit_btn.click(
         fn=submit,
         inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
         outputs=[chatbot, user_input],
    )
-    dataset.change(
-        fn=start_with_example,
-        inputs=[dataset, model_dropdown, max_tokens, temperature, top_p],
-        outputs=[chatbot, user_input],
-    )
 
 if __name__ == "__main__":
     demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
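For context, the core mechanism this file relies on is the streaming pattern in respond(): model.generate() blocks until generation finishes, so it runs in a background thread while TextIteratorStreamer yields decoded text chunks to the main thread. A minimal standalone sketch of that pattern, assuming the MaxLSB/LeCarnet-3M checkpoint is publicly readable (the app itself passes token=hf_token); the prompt and generation settings here are illustrative, not taken from the commit:

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("MaxLSB/LeCarnet-3M")
model = AutoModelForCausalLM.from_pretrained("MaxLSB/LeCarnet-3M")
model.eval()

inputs = tokenizer("Il était une fois un petit lapin perdu", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

# generate() blocks until done, so it runs in a worker thread while the
# main thread consumes decoded chunks from the streamer iterator.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=64,  # illustrative; the app exposes this as a slider
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    ),
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

This is the same reason the Gradio callback is a generator: each chunk pulled from the streamer is yielded as an updated chat history, so the UI repaints as tokens arrive.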