MaxLSB committed
Commit 7b4f2fa · verified · 1 Parent(s): 269f4b4

Update app.py

Files changed (1): app.py +36 -108
app.py CHANGED
@@ -1,7 +1,8 @@
 import os
 import threading
 from collections import defaultdict
-
+from PIL import Image
+import tempfile
 import gradio as gr
 from transformers import (
     AutoModelForCausalLM,
@@ -9,17 +10,14 @@ from transformers import (
     TextIteratorStreamer,
 )
 
-# Define model paths
 model_name_to_path = {
     "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
     "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
     "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
 }
 
-# Load Hugging Face token
 hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
 
-# Preload models and tokenizers
 loaded_models = defaultdict(dict)
 
 for name, path in model_name_to_path.items():
@@ -27,40 +25,24 @@ for name, path in model_name_to_path.items():
     loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
     loaded_models[name]["model"].eval()
 
+def resize_logo(input_path, size=(100, 100)):
+    with Image.open(input_path) as img:
+        img = img.resize(size, Image.LANCZOS)
+        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+        img.save(temp_file.name, format="PNG")
+        return temp_file.name
+
 def respond(message, history, model_name, max_tokens, temperature, top_p):
-    """
-    Generate a response from the selected model, streaming the output and updating chat history.
-
-    Args:
-        message (str): User's input message.
-        history (list): Current chat history as list of (user_msg, bot_msg) tuples.
-        model_name (str): Selected model name.
-        max_tokens (int): Maximum number of tokens to generate.
-        temperature (float): Sampling temperature.
-        top_p (float): Top-p sampling parameter.
-
-    Yields:
-        list: Updated chat history with the user's message and streaming bot response.
-    """
-    # Append user's message to history with an empty bot response
     history = history + [(message, "")]
-    yield history  # Display user's message immediately
-
-    # Select tokenizer and model
+    yield history
     tokenizer = loaded_models[model_name]["tokenizer"]
     model = loaded_models[model_name]["model"]
-
-    # Tokenize input
     inputs = tokenizer(message, return_tensors="pt")
-
-    # Set up streaming
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_prompt=False,
         skip_special_tokens=True,
     )
-
-    # Configure generation parameters
     generate_kwargs = dict(
         **inputs,
         streamer=streamer,
@@ -70,115 +52,61 @@ def respond(message, history, model_name, max_tokens, temperature, top_p):
         top_p=top_p,
         eos_token_id=tokenizer.eos_token_id,
     )
-
-    # Start generation in a background thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
-
-    # Stream the response with model name prefix
-    accumulated = f"**{model_name}:** "
+    accumulated = f"**{model_name}**\n\n"
     for new_text in streamer:
         accumulated += new_text
         history[-1] = (message, accumulated)
         yield history
 
 def submit(message, history, model_name, max_tokens, temperature, top_p):
-    """
-    Handle form submission by calling respond and clearing the input box.
-
-    Args:
-        message (str): User's input message.
-        history (list): Current chat history.
-        model_name (str): Selected model name.
-        max_tokens (int): Max tokens parameter.
-        temperature (float): Temperature parameter.
-        top_p (float): Top-p parameter.
-
-    Yields:
-        tuple: (updated chat history, cleared user input)
-    """
     for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
         yield updated_history, ""
 
-def select_model(model_name, current_model):
-    """
-    Update the selected model name when a model button is clicked.
-
-    Args:
-        model_name (str): The model name to select.
-        current_model (str): The currently selected model.
-
-    Returns:
-        str: The newly selected model name.
-    """
-    return model_name
+def start_with_example(example, model_name, max_tokens, temperature, top_p):
+    for updated_history in respond(example, [], model_name, max_tokens, temperature, top_p):
+        yield updated_history, ""
+
+resized_logo_path = resize_logo("media/le-carnet.png", size=(100, 100))
+
+examples = [
+    "Il était une fois un petit garçon qui vivait dans un village paisible.",
+    "Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.",
+    "Il était une fois un petit lapin perdu",
+]
 
-# Create the Gradio interface with Blocks
 with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
-    # Title and description
     gr.Markdown("# LeCarnet")
-    gr.Markdown("Select a model on the right and type a message to chat.")
-
-    # Two-column layout with specific widths
+    gr.Markdown("Select a model on the right and type a message to chat, or choose an example below.")
     with gr.Row():
-        # Left column: Chat interface (80% width)
         with gr.Column(scale=4):
+            dataset = gr.Dataset(components=[gr.Textbox(visible=False)], samples=[[ex] for ex in examples], type="values")
             chatbot = gr.Chatbot(
-                avatar_images=(None, "media/le-carnet.png"),  # User avatar: None, Bot avatar: Logo
+                avatar_images=(None, resized_logo_path),
                 label="Chat",
-                height=600,  # Increase chat height for larger display
+                height=600,
             )
             user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
             submit_btn = gr.Button("Send")
-
-        # Right column: Model selection and parameters (20% width)
         with gr.Column(scale=1, min_width=200):
-            # State to track selected model
-            model_state = gr.State(value="LeCarnet-8M")
-
-            # Model selection buttons
-            gr.Markdown("**Select Model**")
-            btn_3m = gr.Button("LeCarnet-3M")
-            btn_8m = gr.Button("LeCarnet-8M")
-            btn_21m = gr.Button("LeCarnet-21M")
-
-            # Sliders for parameters
+            model_dropdown = gr.Dropdown(
+                choices=list(model_name_to_path.keys()),
+                value="LeCarnet-8M",
+                label="Model"
+            )
             max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
-
-    # Example prompts
-    examples = gr.Examples(
-        examples=[
-            ["Il était une fois un petit garçon qui vivait dans un village paisible."],
-            ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
-            ["Il était une fois un petit lapin perdu"],
-        ],
-        inputs=user_input,
-    )
-
-    # Event handling for submit button
     submit_btn.click(
         fn=submit,
-        inputs=[user_input, chatbot, model_state, max_tokens, temperature, top_p],
+        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
         outputs=[chatbot, user_input],
     )
-
-    # Event handling for model selection buttons
-    btn_3m.click(
-        fn=select_model,
-        inputs=[gr.State("LeCarnet-3M"), model_state],
-        outputs=model_state,
-    )
-    btn_8m.click(
-        fn=select_model,
-        inputs=[gr.State("LeCarnet-8M"), model_state],
-        outputs=model_state,
-    )
-    btn_21m.click(
-        fn=select_model,
-        inputs=[gr.State("LeCarnet-21M"), model_state],
-        outputs=model_state,
+    dataset.change(
+        fn=start_with_example,
+        inputs=[dataset, model_dropdown, max_tokens, temperature, top_p],
+        outputs=[chatbot, user_input],
     )
 
 if __name__ == "__main__":
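
Note on the generation pattern this commit keeps: model.generate() blocks until decoding finishes, so respond runs it in a background thread and iterates a TextIteratorStreamer to surface partial text as it is produced. A minimal, self-contained sketch of that pattern (the prompt and max_new_tokens are illustrative; the checkpoint is one of the repos preloaded above):

    import threading

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("MaxLSB/LeCarnet-3M")
    model = AutoModelForCausalLM.from_pretrained("MaxLSB/LeCarnet-3M")

    inputs = tokenizer("Il était une fois", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    # generate() runs in a worker thread; the main thread consumes decoded chunks.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64),
    )
    thread.start()

    for chunk in streamer:
        print(chunk, end="", flush=True)
    thread.join()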
 
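Note on the UI side: respond and start_with_example are generator functions, and Gradio re-renders a gr.Chatbot each time a generator callback yields an updated tuple-format history. A minimal sketch of that mechanism, assuming a Gradio release that accepts tuple-format Chatbot history (as this app uses); the canned token list here stands in for model output:

    import time

    import gradio as gr

    def respond(message, history):
        history = history + [(message, "")]
        yield history                      # show the user's turn immediately
        partial = ""
        for token in ["Il ", "était ", "une ", "fois..."]:
            time.sleep(0.2)                # stand-in for generation latency
            partial += token
            history[-1] = (message, partial)
            yield history                  # each yield re-renders the chat

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(label="Chat")
        box = gr.Textbox(label="Message")
        box.submit(respond, inputs=[box, chatbot], outputs=chatbot)

    demo.launch()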
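Note on the model selector: the previous revision tracked the active model with three gr.Button widgets writing into a gr.State via select_model; the commit collapses that into a single gr.Dropdown whose current value is passed positionally into the callback, so no separate state or handler is needed. A minimal sketch of the dropdown-as-input pattern (the callback and output box are illustrative):

    import gradio as gr

    model_names = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]

    def show_choice(model_name):
        # The dropdown's current value arrives as a plain string argument.
        return f"Selected: {model_name}"

    with gr.Blocks() as demo:
        model_dropdown = gr.Dropdown(choices=model_names, value="LeCarnet-8M", label="Model")
        btn = gr.Button("Send")
        out = gr.Textbox(label="Output")
        btn.click(fn=show_choice, inputs=model_dropdown, outputs=out)

    demo.launch()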