MaxLSB committed
Commit a167f72 · verified · 1 Parent(s): 954f37f

Update app.py

Files changed (1): app.py +19 -52
app.py CHANGED
@@ -17,50 +17,34 @@ model_name_to_path = {
 }
 
 # Load Hugging Face token
-hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN", "default_token")  # Use default to avoid errors
 
 # Preload models and tokenizers
 loaded_models = defaultdict(dict)
 
 for name, path in model_name_to_path.items():
-    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
-    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
-    loaded_models[name]["model"].eval()
+    try:
+        loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
+        loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
+        loaded_models[name]["model"].eval()
+    except Exception as e:
+        print(f"Error loading {name}: {str(e)}")
 
 def respond(message, history, model_name, max_tokens, temperature, top_p):
-    """
-    Generate a response from the selected model, streaming the output and updating chat history.
-
-    Args:
-        message (str): User's input message.
-        history (list): Current chat history as list of (user_msg, bot_msg) tuples.
-        model_name (str): Selected model name.
-        max_tokens (int): Maximum number of tokens to generate.
-        temperature (float): Sampling temperature.
-        top_p (float): Top-p sampling parameter.
-
-    Yields:
-        list: Updated chat history with the user's message and streaming bot response.
-    """
-    # Append user's message to history with an empty bot response
     history = history + [(message, "")]
-    yield history  # Display user's message immediately
+    yield history
 
-    # Select tokenizer and model
     tokenizer = loaded_models[model_name]["tokenizer"]
     model = loaded_models[model_name]["model"]
 
-    # Tokenize input
     inputs = tokenizer(message, return_tensors="pt")
 
-    # Set up streaming
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_prompt=False,
         skip_special_tokens=True,
     )
 
-    # Configure generation parameters
     generate_kwargs = dict(
         **inputs,
         streamer=streamer,
@@ -71,53 +55,32 @@ def respond(message, history, model_name, max_tokens, temperature, top_p):
         eos_token_id=tokenizer.eos_token_id,
     )
 
-    # Start generation in a background thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
-    # Stream the response with model name prefix
-    accumulated = f"**{model_name}:** "
+    accumulated = ""  # Removed model name prefix
     for new_text in streamer:
         accumulated += new_text
         history[-1] = (message, accumulated)
         yield history
 
 def submit(message, history, model_name, max_tokens, temperature, top_p):
-    """
-    Handle form submission by calling respond and clearing the input box.
-
-    Args:
-        message (str): User's input message.
-        history (list): Current chat history.
-        model_name (str): Selected model name.
-        max_tokens (int): Max tokens parameter.
-        temperature (float): Temperature parameter.
-        top_p (float): Top-p parameter.
-
-    Yields:
-        tuple: (updated chat history, cleared user input)
-    """
     for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
         yield updated_history, ""
 
-# Create the Gradio interface with Blocks
 with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
-    # Title and description
     gr.Markdown("# LeCarnet")
     gr.Markdown("Select a model on the right and type a message to chat.")
 
-    # Two-column layout with specific widths
     with gr.Row():
-        # Left column: Chat interface (80% width)
        with gr.Column(scale=4):
             chatbot = gr.Chatbot(
-                avatar_images=(None, "media/le-carnet.png"),  # User avatar: None, Bot avatar: Logo
+                avatar_images=(None, "https://raw.githubusercontent.com/maxlsb/le-carnet/main/media/le-carnet.png"),  # Using URL for reliability
                 label="Chat",
-                height=600,  # Increase chat height for larger display
+                height=600,
             )
             user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
             submit_btn = gr.Button("Send")
-            # Example prompts
             examples = gr.Examples(
                 examples=[
                     ["Il était une fois un petit garçon qui vivait dans un village paisible."],
@@ -127,25 +90,29 @@ with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
                 inputs=user_input,
             )
 
-        # Right column: Model selection and parameters (20% width)
         with gr.Column(scale=1, min_width=200):
-            # Dropdown for model selection
             model_dropdown = gr.Dropdown(
                 choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                 value="LeCarnet-8M",
                 label="Select Model"
             )
-            # Sliders for parameters
             max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
 
-    # Event handling for submit button
+    # Submit button click
     submit_btn.click(
         fn=submit,
         inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
         outputs=[chatbot, user_input],
     )
+
+    # Enter key press
+    user_input.submit(
+        fn=submit,
+        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
+        outputs=[chatbot, user_input],
+    )
 
 if __name__ == "__main__":
     demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
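The generation path in `respond` is the stock `transformers` streaming recipe: `model.generate` blocks until decoding finishes, so it runs on a worker thread while the caller iterates a `TextIteratorStreamer` and yields partial text. A minimal standalone sketch of that pattern, using `gpt2` as a placeholder checkpoint rather than the LeCarnet models:

```python
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model; any small causal LM is enough to exercise the pattern.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("Il était une fois", return_tensors="pt")

# skip_prompt=False echoes the prompt back first, matching app.py above.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

# generate() blocks until done, so it runs on a worker thread while
# the main thread consumes decoded fragments from the streamer.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    ),
)
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)  # arrives incrementally, not all at once
thread.join()
```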
 
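Two startup fixes land in this commit: `os.environ[...]`, which raises `KeyError` when the secret is absent, becomes `os.environ.get(...)` with a placeholder default, and each model load is wrapped in `try/except` so one bad checkpoint logs an error instead of killing the Space. A hedged sketch of a slightly stricter variant of the same idea; the `load_model` helper and its messages are illustrative, not part of this commit:

```python
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# None falls back to anonymous access, which suffices for public
# checkpoints; a real token only matters for gated or private repos.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or None


def load_model(path):
    """Load one checkpoint; return (tokenizer, model) or None on failure."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(path, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
        model.eval()  # inference mode: disables dropout etc.
        return tokenizer, model
    except Exception as exc:
        print(f"Error loading {path}: {exc}")
        return None
```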
 
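On the UI side, the new `user_input.submit(...)` handler mirrors `submit_btn.click(...)`, so pressing Enter and clicking Send drive the same generator; because `submit` yields `(updated_history, "")`, the textbox clears as soon as streaming begins. A self-contained toy sketch of that double binding; the `echo` handler stands in for `submit`, and this assumes a Gradio version that still accepts the tuple-style chat history used in app.py:

```python
import gradio as gr


def echo(message, history):
    # Stand-in for submit(): append one exchange, then clear the textbox.
    history = (history or []) + [(message, f"echo: {message}")]
    yield history, ""


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(label="Message")
    send = gr.Button("Send")

    # Identical wiring on both events: Enter and the Send button behave the same.
    send.click(fn=echo, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
    user_input.submit(fn=echo, inputs=[user_input, chatbot], outputs=[chatbot, user_input])

demo.launch()
```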
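Because `respond` is an ordinary Python generator, the streaming logic can be smoke-tested without launching the interface at all. A hypothetical check, assuming `app.py` is importable from the repo root (importing it preloads the models) and that `LeCarnet-8M` loaded successfully:

```python
# Hypothetical smoke test, run from the repo root.
from app import respond

history = []
for history in respond(
    "Il était une fois un petit garçon qui vivait dans un village paisible.",
    history,
    "LeCarnet-8M",
    max_tokens=64,
    temperature=0.7,
    top_p=0.9,
):
    pass  # each iteration corresponds to one streaming UI update

# Final bot reply; after this commit it no longer carries the model-name prefix.
print(history[-1][1])
```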