MaxLSB committed on
Commit
644b0a5
·
verified ·
1 Parent(s): 52a9a97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -56
app.py CHANGED
@@ -27,20 +27,31 @@ for name, path in model_name_to_path.items():
27
  loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
28
  loaded_models[name]["model"].eval()
29
 
30
- def respond(
31
- prompt: str,
32
- chat_history,
33
- model_name: str,
34
- max_tokens: int,
35
- temperature: float,
36
- top_p: float,
37
- ):
38
- # Select the appropriate model and tokenizer
 
 
 
 
 
 
 
 
 
 
 
39
  tokenizer = loaded_models[model_name]["tokenizer"]
40
  model = loaded_models[model_name]["model"]
41
 
42
  # Tokenize input
43
- inputs = tokenizer(prompt, return_tensors="pt")
44
 
45
  # Set up streaming
46
  streamer = TextIteratorStreamer(
@@ -60,65 +71,79 @@ def respond(
60
  eos_token_id=tokenizer.eos_token_id,
61
  )
62
 
63
- # Run generation in a background thread
64
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
65
  thread.start()
66
 
67
- # Stream results
68
- accumulated = ""
69
  for new_text in streamer:
70
  accumulated += new_text
71
- yield accumulated
 
72
 
73
- # Create Gradio Chat Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  with gr.Blocks() as demo:
75
- # Custom title with logo
 
 
 
 
76
  with gr.Row():
77
- gr.HTML(
78
- '<div style="display: flex; align-items: center;">'
79
- f'<img src="file/{os.path.abspath("media/le-carnet.png")}" style="height: 50px; margin-right: 10px;" />'
80
- '<h1 style="margin: 0;">LeCarnet</h1>'
81
- '</div>'
82
- )
83
-
84
- # Chat interface
85
- chatbot = gr.ChatInterface(
86
- fn=respond,
87
- title=None, # Remove default title
88
- description=None, # Remove default description
 
 
 
 
 
 
 
 
 
 
89
  examples=[
90
  ["Il était une fois un petit garçon qui vivait dans un village paisible."],
91
  ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
92
  ["Il était une fois un petit lapin perdu"],
93
  ],
94
- cache_examples=False,
95
  )
96
-
97
- # Sidebar for model selection and parameters
98
- with gr.Column(elem_classes="sidebar", variant="panel"):
99
- gr.Markdown("### Model Configuration")
100
- model_dropdown = gr.Dropdown(
101
- choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
102
- value="LeCarnet-8M",
103
- label="Model",
104
- )
105
- max_tokens_slider = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
106
- temperature_slider = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
107
- top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
108
-
109
- # Pass parameters to the chatbot
110
- chatbot.load(
111
- fn=lambda x, y, z, w: None,
112
- inputs=[model_dropdown, max_tokens_slider, temperature_slider, top_p_slider],
113
- outputs=None,
114
- )
115
- chatbot.config.update({
116
- "model_name": model_dropdown,
117
- "max_tokens": max_tokens_slider,
118
- "temperature": temperature_slider,
119
- "top_p": top_p_slider,
120
- })
121
-
122
- # Launch the app
123
  if __name__ == "__main__":
124
  demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
 
27
  loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
28
  loaded_models[name]["model"].eval()
29
 
30
+ def respond(message, history, model_name, max_tokens, temperature, top_p):
31
+ """
32
+ Generate a response from the selected model, streaming the output and updating chat history.
33
+
34
+ Args:
35
+ message (str): User's input message.
36
+ history (list): Current chat history as list of (user_msg, bot_msg) tuples.
37
+ model_name (str): Selected model name.
38
+ max_tokens (int): Maximum number of tokens to generate.
39
+ temperature (float): Sampling temperature.
40
+ top_p (float): Top-p sampling parameter.
41
+
42
+ Yields:
43
+ list: Updated chat history with the user's message and streaming bot response.
44
+ """
45
+ # Append user's message to history with an empty bot response
46
+ history = history + [(message, "")]
47
+ yield history # Display user's message immediately
48
+
49
+ # Select tokenizer and model
50
  tokenizer = loaded_models[model_name]["tokenizer"]
51
  model = loaded_models[model_name]["model"]
52
 
53
  # Tokenize input
54
+ inputs = tokenizer(message, return_tensors="pt")
55
 
56
  # Set up streaming
57
  streamer = TextIteratorStreamer(
 
71
  eos_token_id=tokenizer.eos_token_id,
72
  )
73
 
74
+ # Start generation in a background thread
75
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
76
  thread.start()
77
 
78
+ # Stream the response with model name prefix
79
+ accumulated = f"**{model_name}:** "
80
  for new_text in streamer:
81
  accumulated += new_text
82
+ history[-1] = (message, accumulated)
83
+ yield history
84
 
85
def submit(message, history, model_name, max_tokens, temperature, top_p):
    """
    Relay streamed chat updates from ``respond`` and blank the input box.

    Every value produced by ``respond`` is forwarded unchanged, paired with
    an empty string. The click handler maps the pair onto its two outputs
    (the chatbot and the message textbox), so the textbox is cleared as
    soon as the first update arrives.

    Args:
        message (str): User's input message.
        history (list): Current chat history.
        model_name (str): Selected model name.
        max_tokens (int): Max tokens parameter.
        temperature (float): Temperature parameter.
        top_p (float): Top-p parameter.

    Yields:
        tuple: (updated chat history, empty string for the user input)
    """
    cleared_input = ""
    chat_stream = respond(
        message, history, model_name, max_tokens, temperature, top_p
    )
    for chat_state in chat_stream:
        yield chat_state, cleared_input
102
+
103
+ # Create the Gradio interface with Blocks
104
  with gr.Blocks() as demo:
105
+ # Title and description
106
+ gr.Markdown("# LeCarnet")
107
+ gr.Markdown("Select a model on the right and type a message to chat.")
108
+
109
+ # Two-column layout
110
  with gr.Row():
111
+ # Left column: Chat interface
112
+ with gr.Column():
113
+ chatbot = gr.Chatbot(
114
+ avatar_images=(None, "media/le-carnet.png"), # User avatar: None, Bot avatar: Logo
115
+ label="Chat"
116
+ )
117
+ user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
118
+ submit_btn = gr.Button("Send")
119
+
120
+ # Right column: Model selection and parameters
121
+ with gr.Column():
122
+ model_dropdown = gr.Dropdown(
123
+ choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
124
+ value="LeCarnet-8M",
125
+ label="Model"
126
+ )
127
+ max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
128
+ temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
129
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
130
+
131
+ # Example prompts
132
+ examples = gr.Examples(
133
  examples=[
134
  ["Il était une fois un petit garçon qui vivait dans un village paisible."],
135
  ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
136
  ["Il était une fois un petit lapin perdu"],
137
  ],
138
+ inputs=user_input,
139
  )
140
+
141
+ # Event handling for submit button
142
+ submit_btn.click(
143
+ fn=submit,
144
+ inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
145
+ outputs=[chatbot, user_input],
146
+ )
147
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  if __name__ == "__main__":
149
  demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)