MaxLSB committed on
Commit
f0687e5
·
verified ·
1 Parent(s): 0a364b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -28
app.py CHANGED
@@ -7,29 +7,58 @@ from transformers import (
7
  TextIteratorStreamer,
8
  )
9
 
10
- MODEL_NAME = "MaxLSB/LeCarnet-8M"
11
- hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Load tokenizer & model locally
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
15
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=hf_token)
16
- model.eval()
17
 
18
  def respond(
19
  prompt: str,
20
  chat_history,
 
21
  max_tokens: int,
22
  temperature: float,
23
  top_p: float,
24
  ):
25
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
26
 
 
27
  streamer = TextIteratorStreamer(
28
  tokenizer,
29
  skip_prompt=False,
30
  skip_special_tokens=True,
31
  )
32
-
33
  generate_kwargs = dict(
34
  **inputs,
35
  streamer=streamer,
@@ -39,33 +68,117 @@ def respond(
39
  top_p=top_p,
40
  eos_token_id=tokenizer.eos_token_id,
41
  )
42
-
43
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
44
  thread.start()
45
-
46
  accumulated = ""
47
  for new_text in streamer:
48
  accumulated += new_text
49
  yield accumulated
50
 
51
- # Wire it up in Gradio
52
- demo = gr.ChatInterface(
53
- fn=respond,
54
- additional_inputs=[
55
- gr.Slider(1, 512, value=512, step=1, label="Max new tokens"),
56
- gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
57
- gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top‑p"),
58
- ],
59
- title="LeCarnet-8M",
60
- description="Type the beginning of a sentence and watch the model finish it.",
61
- examples = [
62
- ["Il était une fois un petit garçon qui vivait dans un village paisible."],
63
- ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
64
- ["Il était une fois un petit lapin perdu"],
65
- ],
66
- cache_examples=False,
67
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  demo.queue()
71
- demo.launch()
 
7
  TextIteratorStreamer,
8
  )
9
 
10
+ # Define your models
11
+ MODEL_PATHS = {
12
+ "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
13
+ "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
14
+ "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
15
+ }
16
+
17
+ # Add your Hugging Face token
18
+ hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
19
+ if not hf_token:
20
+ raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
21
+
22
+ # Load tokenizers & models - only load one initially
23
+ tokenizer = None
24
+ model = None
25
+
26
+ def load_model(model_name):
27
+ """Loads the specified model and tokenizer."""
28
+ global tokenizer, model
29
+ if model_name not in MODEL_PATHS:
30
+ raise ValueError(f"Unknown model: {model_name}")
31
+
32
+ print(f"Loading {model_name}...")
33
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATHS[model_name], token=hf_token)
34
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATHS[model_name], token=hf_token)
35
+ model.eval()
36
+ print(f"{model_name} loaded.")
37
+
38
+ # Initial model load
39
+ initial_model = list(MODEL_PATHS.keys())[0]
40
+ load_model(initial_model)
41
 
 
 
 
 
42
 
43
  def respond(
44
  prompt: str,
45
  chat_history,
46
+ model_choice: str,
47
  max_tokens: int,
48
  temperature: float,
49
  top_p: float,
50
  ):
51
+ global tokenizer, model
52
+ # Reload model if it's not the currently loaded one
53
+ if model.config._name_or_path != MODEL_PATHS[model_choice]:
54
+ load_model(model_choice)
55
 
56
+ inputs = tokenizer(prompt, return_tensors="pt")
57
  streamer = TextIteratorStreamer(
58
  tokenizer,
59
  skip_prompt=False,
60
  skip_special_tokens=True,
61
  )
 
62
  generate_kwargs = dict(
63
  **inputs,
64
  streamer=streamer,
 
68
  top_p=top_p,
69
  eos_token_id=tokenizer.eos_token_id,
70
  )
 
71
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
72
  thread.start()
 
73
  accumulated = ""
74
  for new_text in streamer:
75
  accumulated += new_text
76
  yield accumulated
77
 
78
+ # --- Gradio Interface ---
79
+ # CSS for the custom logo and layout
80
+ css = """
81
+ .gradio-container {
82
+ padding: 0 !important;
83
+ }
84
+ .gradio-container > main.fillable {
85
+ padding: 0 !important;
86
+ }
87
+ #chatbot {
88
+ height: calc(100vh - 21px - 16px);
89
+ max-height: 1500px;
90
+ }
91
+ #chatbot .chatbot-conversations {
92
+ height: 100vh;
93
+ background-color: var(--ms-gr-ant-color-bg-layout);
94
+ padding-left: 4px;
95
+ padding-right: 4px;
96
+ }
97
+ #chatbot .chatbot-conversations .chatbot-conversations-list {
98
+ padding-left: 0;
99
+ padding-right: 0;
100
+ }
101
+ #chatbot .chatbot-chat {
102
+ padding: 32px;
103
+ padding-bottom: 0;
104
+ height: 100%;
105
+ }
106
+ @media (max-width: 768px) {
107
+ #chatbot .chatbot-chat {
108
+ padding: 0;
109
+ }
110
+ }
111
+ #chatbot .chatbot-chat .chatbot-chat-messages {
112
+ flex: 1;
113
+ }
114
+ .logo-container {
115
+ display: flex;
116
+ justify-content: center;
117
+ padding: 10px;
118
+ }
119
+ .logo-container img {
120
+ max-width: 80%; /* Adjust as needed */
121
+ height: auto;
122
+ }
123
+ """
124
+
125
+ with gr.Blocks(css=css, fill_width=True) as demo:
126
+ with gr.Column(elem_id="chatbot", variant="panel"):
127
+ # Custom Logo
128
+ with gr.Row(elem_classes="logo-container"):
129
+ gr.Image(
130
+ value="media/le-carnet.png", # Replace with the path to your image file
131
+ label="LeCarnet Logo",
132
+ interactive=False,
133
+ show_label=False,
134
+ show_download_button=False,
135
+ height=100 # Adjust height as needed
136
+ )
137
+
138
+ gr.Markdown(
139
+ """
140
+ # LeCarnet AI Assistant
141
+ Type the beginning of a sentence and watch the model finish it.
142
+ """
143
+ )
144
+
145
+ with gr.Row():
146
+ with gr.Column(scale=1):
147
+ model_dropdown = gr.Dropdown(
148
+ choices=list(MODEL_PATHS.keys()),
149
+ value=initial_model,
150
+ label="Choose Model",
151
+ interactive=True
152
+ )
153
+ max_tokens_slider = gr.Slider(
154
+ 1, 512, value=512, step=1, label="Max new tokens"
155
+ )
156
+ temperature_slider = gr.Slider(
157
+ 0.1, 2.0, value=0.7, step=0.1, label="Temperature"
158
+ )
159
+ top_p_slider = gr.Slider(
160
+ 0.1, 1.0, value=0.9, step=0.05, label="Top‑p"
161
+ )
162
+
163
+ with gr.Column(scale=3):
164
+ chatbot = gr.ChatInterface(
165
+ fn=respond,
166
+ additional_inputs=[
167
+ model_dropdown, # Pass model choice to respond function
168
+ max_tokens_slider,
169
+ temperature_slider,
170
+ top_p_slider,
171
+ ],
172
+ examples=[
173
+ ["Il était une fois un petit garçon qui vivait dans un village paisible."],
174
+ ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
175
+ ["Il était une fois un petit lapin perdu"],
176
+ ],
177
+ cache_examples=False,
178
+ submit_btn="Generate",
179
+ clear_btn="Clear Chat",
180
+ )
181
 
182
  if __name__ == "__main__":
183
  demo.queue()
184
+ demo.launch()