Update app.py
app.py CHANGED

@@ -16,17 +16,27 @@ current_model_name = None
 # Load selected model
 def load_model(model_name):
     global tokenizer, model, current_model_name
+
+    # Only load if it's a different model
+    if current_model_name == model_name:
+        return
+
     full_model_name = f"MaxLSB/{model_name}"
+    print(f"Loading model: {full_model_name}")
     tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
     model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
     model.eval()
     current_model_name = model_name
+    print(f"Model loaded: {current_model_name}")
 
 # Initialize default model
 load_model("LeCarnet-8M")
 
 # Streaming generation function
-def respond(message, max_tokens, temperature, top_p):
+def respond(message, max_tokens, temperature, top_p, selected_model):
+    # Ensure the correct model is loaded before generation
+    load_model(selected_model)
+
     inputs = tokenizer(message, return_tensors="pt")
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
 
@@ -57,18 +67,17 @@ def user(message, chat_history):
     chat_history.append([message, None])
     return "", chat_history
 
-# Bot response handler
-def bot(chatbot, max_tokens, temperature, top_p):
+# Bot response handler - UPDATED to pass selected model
+def bot(chatbot, max_tokens, temperature, top_p, selected_model):
     message = chatbot[-1][0]
-    response_generator = respond(message, max_tokens, temperature, top_p)
+    response_generator = respond(message, max_tokens, temperature, top_p, selected_model)
     for response in response_generator:
         chatbot[-1][1] = response
         yield chatbot
 
-# Model selector handler
+# Model selector handler
 def update_model(model_name):
     load_model(model_name)
-    # Return the model_name directly instead of using gr.Dropdown.update()
     return model_name
 
 # Clear chat handler
@@ -84,7 +93,6 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
     </div>
     """)
 
-    # Create the msg_input early, but don't render it yet
     msg_input = gr.Textbox(
         placeholder="Il était une fois un petit garçon",
         label="User Input",
@@ -118,14 +126,13 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
         bubble_full_width=False,
         height=500
     )
-    # Now render the msg_input inside the right column, below the chatbot
     msg_input.render()
 
     # Event Handlers
     model_selector.change(
         fn=update_model,
         inputs=[model_selector],
-        outputs=[model_selector],
+        outputs=[model_selector],
     )
 
     msg_input.submit(
@@ -135,7 +142,7 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
         queue=False
     ).then(
         fn=bot,
-        inputs=[chatbot, max_tokens, temperature, top_p],
+        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],  # Pass model_selector
         outputs=[chatbot]
     )
 
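The first hunk cuts respond() off right after the TextIteratorStreamer is created. For readers following the change, here is a minimal sketch of the standard pattern the rest of the function presumably follows: model.generate runs in a background thread while the streamer yields decoded text. The generation_kwargs below (including the sampling flags) are assumptions, not shown in this commit, and the sketch reuses app.py's module-level tokenizer, model, and load_model:

    # Sketch only: everything past the streamer line is assumed, not from the commit.
    from threading import Thread

    def respond(message, max_tokens, temperature, top_p, selected_model):
        load_model(selected_model)  # ensure the dropdown's model is active

        inputs = tokenizer(message, return_tensors="pt")
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

        # model.generate blocks, so it runs in a worker thread while the
        # generator loop below drains the streamer chunk by chunk.
        generation_kwargs = dict(  # hypothetical names; the commit does not show these
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        Thread(target=model.generate, kwargs=generation_kwargs).start()

        response = ""
        for new_text in streamer:
            response += new_text
            yield response  # bot() writes each partial response into chatbot[-1][1]

Yielding the accumulated string rather than individual tokens matches bot(), which overwrites chatbot[-1][1] with each value it receives.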
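The wiring half of the change works because Gradio passes the current value of every component listed in inputs to the handler positionally; appending model_selector to the .then() inputs is exactly what delivers selected_model to bot(). A stripped-down, hypothetical repro of just that mechanism (component names are illustrative, not from app.py):

    import gradio as gr

    def bot(chat_history, selected_model):
        # Gradio injects the dropdown's current value positionally, matching
        # its position in the `inputs` list of the binding below.
        chat_history[-1][1] = f"(generated by {selected_model})"
        return chat_history

    with gr.Blocks() as demo:
        model_dropdown = gr.Dropdown(["LeCarnet-8M"], value="LeCarnet-8M", label="Model")
        chatbot = gr.Chatbot()
        msg = gr.Textbox()

        msg.submit(
            fn=lambda message, history: ("", history + [[message, None]]),
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            fn=bot,
            inputs=[chatbot, model_dropdown],  # the extra component is the whole trick
            outputs=[chatbot],
        )

    demo.launch()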