MaxLSB committed
Commit 790cffd · verified · 1 Parent(s): a167f72

Update app.py

Files changed (1)
  1. app.py +88 -84
app.py CHANGED
@@ -1,49 +1,34 @@
 import os
 import threading
-from collections import defaultdict
-
 import gradio as gr
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-)
-
-# Define model paths
-model_name_to_path = {
-    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
-    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
-    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
-}
-
-# Load Hugging Face token
-hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN", "default_token")  # Use default to avoid errors
-
-# Preload models and tokenizers
-loaded_models = defaultdict(dict)
-
-for name, path in model_name_to_path.items():
-    try:
-        loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
-        loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
-        loaded_models[name]["model"].eval()
-    except Exception as e:
-        print(f"Error loading {name}: {str(e)}")
-
-def respond(message, history, model_name, max_tokens, temperature, top_p):
-    history = history + [(message, "")]
-    yield history
-
-    tokenizer = loaded_models[model_name]["tokenizer"]
-    model = loaded_models[model_name]["model"]
-
-    inputs = tokenizer(message, return_tensors="pt")
-
-    streamer = TextIteratorStreamer(
-        tokenizer,
-        skip_prompt=False,
-        skip_special_tokens=True,
-    )
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+# Hugging Face token
+hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+
+# Global model & tokenizer
+tokenizer = None
+model = None
+
+# Load selected model
+def load_model(model_name):
+    global tokenizer, model
+    full_model_name = f"MaxLSB/{model_name}"
+    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
+    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
+    model.eval()
+
+# Initialize default model
+load_model("LeCarnet-8M")
+
+# Streamer for real-time generation
+streamer = None
+
+# Streaming generation function
+def respond(message, max_tokens, temperature, top_p):
+    global streamer
+    inputs = tokenizer(message, return_tensors="pt")
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         **inputs,
@@ -58,61 +43,80 @@ def respond(message, history, model_name, max_tokens, temperature, top_p):
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
-    accumulated = ""  # Removed model name prefix
+    response = ""
     for new_text in streamer:
-        accumulated += new_text
-        history[-1] = (message, accumulated)
-        yield history
-
-def submit(message, history, model_name, max_tokens, temperature, top_p):
-    for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
-        yield updated_history, ""
-
-with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
-    gr.Markdown("# LeCarnet")
-    gr.Markdown("Select a model on the right and type a message to chat.")
-
+        response += new_text
+        yield response
+
+# User input handler
+def user(message, chat_history):
+    chat_history.append([message, None])
+    return "", chat_history
+
+# Bot response handler
+def bot(chatbot, max_tokens, temperature, top_p):
+    message = chatbot[-1][0]
+    response_generator = respond(message, max_tokens, temperature, top_p)
+    for response in response_generator:
+        chatbot[-1][1] = response
+        yield chatbot
+
+# Model selector handler
+def update_model(model_name):
+    load_model(model_name)
+    return []
+
+# Gradio UI
+with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
     with gr.Row():
+        # Left column: Options
+        with gr.Column(scale=1, min_width=150):
+            gr.Markdown("### 🧠 Model Settings")
+            model_selector = gr.Dropdown(
+                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
+                value="LeCarnet-8M",
+                label="Select Model"
+            )
+            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
+            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
+            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
+            clear_button = gr.Button("🗑️ Clear Chat")
+
+        # Right column: Chat + Image
         with gr.Column(scale=4):
+            gr.Markdown("### 🤖 LeCarnet Chatbot")
+            model_logo = gr.Image(
+                value="media/le-carnet.png",
+                label="Model Logo",
+                height=100,
+                width=100,
+                interactive=False
+            )
             chatbot = gr.Chatbot(
-                avatar_images=(None, "https://raw.githubusercontent.com/maxlsb/le-carnet/main/media/le-carnet.png"),  # Using URL for reliability
-                label="Chat",
-                height=600,
+                bubble_full_width=False,
+                height=500
             )
-            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
-            submit_btn = gr.Button("Send")
-            examples = gr.Examples(
+            msg_input = gr.Textbox(
+                placeholder="Type your message and press Enter...",
+                label="User Input"
+            )
+            gr.Examples(
                 examples=[
                     ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                     ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                     ["Il était une fois un petit lapin perdu"],
                 ],
-                inputs=user_input,
+                inputs=msg_input,
+                label="Example Prompts"
             )
-
-        with gr.Column(scale=1, min_width=200):
-            model_dropdown = gr.Dropdown(
-                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
-                value="LeCarnet-8M",
-                label="Select Model"
-            )
-            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
-            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
-            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
 
-    # Submit button click
-    submit_btn.click(
-        fn=submit,
-        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
-        outputs=[chatbot, user_input],
-    )
-
-    # Enter key press
-    user_input.submit(
-        fn=submit,
-        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
-        outputs=[chatbot, user_input],
-    )
+    # Event handlers
+    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[])
+    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
+        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
+    )
+    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)
 
 if __name__ == "__main__":
-    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
+    demo.queue()
+    demo.launch()
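
Note: the lines between the two hunks (old 50-57 / new 35-42) are not shown, so the exact contents of generate_kwargs are elided. Below is a minimal self-contained sketch of the threaded streaming pattern the new respond() relies on; the generate_kwargs wiring (do_sample, max_new_tokens, etc.) is an assumption for illustration, not the commit's code.

    # Sketch of the TextIteratorStreamer pattern (assumed generate_kwargs).
    import threading

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("MaxLSB/LeCarnet-8M")
    model = AutoModelForCausalLM.from_pretrained("MaxLSB/LeCarnet-8M")
    model.eval()

    def stream(prompt, max_tokens=512, temperature=0.7, top_p=0.9):
        inputs = tokenizer(prompt, return_tensors="pt")
        # A fresh streamer per call; skip_prompt=True omits the prompt echo.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(  # assumed: the diff elides these arguments
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
        # generate() blocks, so it runs on a worker thread while this
        # generator yields partial text as the streamer decodes tokens.
        threading.Thread(target=model.generate, kwargs=generate_kwargs).start()
        response = ""
        for new_text in streamer:
            response += new_text
            yield response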
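
The new event wiring is a two-step chain: user() appends the user turn with queue=False so it appears immediately, then bot() streams the reply from a generator. A minimal runnable sketch of that submit()/then() pattern, assuming the Gradio 3.x list-of-pairs Chatbot history the commit uses:

    # Sketch of the two-step submit()/then() chain (Gradio 3.x assumed).
    import gradio as gr

    def user(message, history):
        # Step 1 (queue=False): echo the user turn at once, clear the textbox.
        history.append([message, None])
        return "", history

    def bot(history):
        # Step 2: fill in the bot turn; yielding streams partial updates.
        message = history[-1][0]
        for i in range(1, len(message) + 1):
            history[-1][1] = message[:i]  # placeholder "typing" effect
            yield history

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot], [chatbot]
        )

    if __name__ == "__main__":
        demo.queue()
        demo.launch()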