MaxLSB commited on
Commit
a3c4cbd
·
verified ·
1 Parent(s): 7c9e931

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -16,17 +16,27 @@ current_model_name = None
16
  # Load selected model
17
  def load_model(model_name):
18
  global tokenizer, model, current_model_name
 
 
 
 
 
19
  full_model_name = f"MaxLSB/{model_name}"
 
20
  tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
21
  model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
22
  model.eval()
23
  current_model_name = model_name
 
24
 
25
  # Initialize default model
26
  load_model("LeCarnet-8M")
27
 
28
  # Streaming generation function
29
- def respond(message, max_tokens, temperature, top_p):
 
 
 
30
  inputs = tokenizer(message, return_tensors="pt")
31
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
32
 
@@ -57,18 +67,17 @@ def user(message, chat_history):
57
  chat_history.append([message, None])
58
  return "", chat_history
59
 
60
- # Bot response handler
61
- def bot(chatbot, max_tokens, temperature, top_p):
62
  message = chatbot[-1][0]
63
- response_generator = respond(message, max_tokens, temperature, top_p)
64
  for response in response_generator:
65
  chatbot[-1][1] = response
66
  yield chatbot
67
 
68
- # Model selector handler - FIXED
69
  def update_model(model_name):
70
  load_model(model_name)
71
- # Return the model_name directly instead of using gr.Dropdown.update()
72
  return model_name
73
 
74
  # Clear chat handler
@@ -84,7 +93,6 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
84
  </div>
85
  """)
86
 
87
- # Create the msg_input early, but don't render it yet
88
  msg_input = gr.Textbox(
89
  placeholder="Il était une fois un petit garçon",
90
  label="User Input",
@@ -118,14 +126,13 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
118
  bubble_full_width=False,
119
  height=500
120
  )
121
- # Now render the msg_input inside the right column, below the chatbot
122
  msg_input.render()
123
 
124
  # Event Handlers
125
  model_selector.change(
126
  fn=update_model,
127
  inputs=[model_selector],
128
- outputs=[model_selector], # This will update the dropdown value
129
  )
130
 
131
  msg_input.submit(
@@ -135,7 +142,7 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
135
  queue=False
136
  ).then(
137
  fn=bot,
138
- inputs=[chatbot, max_tokens, temperature, top_p],
139
  outputs=[chatbot]
140
  )
141
 
 
16
  # Load selected model
17
  def load_model(model_name):
18
  global tokenizer, model, current_model_name
19
+
20
+ # Only load if it's a different model
21
+ if current_model_name == model_name:
22
+ return
23
+
24
  full_model_name = f"MaxLSB/{model_name}"
25
+ print(f"Loading model: {full_model_name}")
26
  tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
27
  model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
28
  model.eval()
29
  current_model_name = model_name
30
+ print(f"Model loaded: {current_model_name}")
31
 
32
  # Initialize default model
33
  load_model("LeCarnet-8M")
34
 
35
  # Streaming generation function
36
+ def respond(message, max_tokens, temperature, top_p, selected_model):
37
+ # Ensure the correct model is loaded before generation
38
+ load_model(selected_model)
39
+
40
  inputs = tokenizer(message, return_tensors="pt")
41
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
42
 
 
67
  chat_history.append([message, None])
68
  return "", chat_history
69
 
70
+ # Bot response handler - UPDATED to pass selected model
71
+ def bot(chatbot, max_tokens, temperature, top_p, selected_model):
72
  message = chatbot[-1][0]
73
+ response_generator = respond(message, max_tokens, temperature, top_p, selected_model)
74
  for response in response_generator:
75
  chatbot[-1][1] = response
76
  yield chatbot
77
 
78
+ # Model selector handler
79
  def update_model(model_name):
80
  load_model(model_name)
 
81
  return model_name
82
 
83
  # Clear chat handler
 
93
  </div>
94
  """)
95
 
 
96
  msg_input = gr.Textbox(
97
  placeholder="Il était une fois un petit garçon",
98
  label="User Input",
 
126
  bubble_full_width=False,
127
  height=500
128
  )
 
129
  msg_input.render()
130
 
131
  # Event Handlers
132
  model_selector.change(
133
  fn=update_model,
134
  inputs=[model_selector],
135
+ outputs=[model_selector],
136
  )
137
 
138
  msg_input.submit(
 
142
  queue=False
143
  ).then(
144
  fn=bot,
145
+ inputs=[chatbot, max_tokens, temperature, top_p, model_selector], # Pass model_selector
146
  outputs=[chatbot]
147
  )
148