import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token (read from the environment; raises KeyError if unset)
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
torch.set_num_threads(1)
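# Note: set_num_threads(1) caps PyTorch's intra-op CPU parallelism; together with the
# queue's concurrency limit at the bottom of the file, this avoids oversubscribing
# cores on a shared host (assumption: this Space runs on CPU, not GPU).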
# Globals (swapped in place when the user picks a different model)
tokenizer = None
model = None
current_model_name = None
_model_lock = threading.Lock()  # serializes reloads across concurrent requests

# Load the selected model (no-op if it is already the active one)
def load_model(model_name):
    global tokenizer, model, current_model_name
    with _model_lock:
        # Only load if it's a different model
        if current_model_name == model_name:
            return
        full_model_name = f"MaxLSB/{model_name}"
        print(f"Loading model: {full_model_name}")
        tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
        model.eval()
        current_model_name = model_name
        print(f"Model loaded: {current_model_name}")
# Initialize default model
load_model("LeCarnet-8M")
# Streaming generation: yields progressively longer responses as tokens arrive
def respond(message, max_tokens, temperature, top_p, selected_model):
    # Ensure the correct model is loaded before generation
    load_model(selected_model)
    # Snapshot the globals so a concurrent model switch cannot swap them mid-generation
    local_tokenizer, local_model = tokenizer, model

    inputs = local_tokenizer(message, return_tensors="pt")
    # skip_prompt=False echoes the prompt, so the output continues the user's text
    streamer = TextIteratorStreamer(local_tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=local_tokenizer.eos_token_id,
    )

    # generate() blocks, so run it on a worker thread and consume the streamer here
    def run():
        with torch.no_grad():
            local_model.generate(**generate_kwargs)

    thread = threading.Thread(target=run)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield f"**{selected_model}**\n\n{response}"
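# Usage sketch, outside Gradio (hypothetical, e.g. in a REPL; not part of the app flow):
#   for partial in respond("Il était une fois", 64, 0.4, 0.9, "LeCarnet-3M"):
#       print(partial)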
# User input handler
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history
# Bot response handler: streams the reply into the last chat turn
def bot(chatbot, max_tokens, temperature, top_p, selected_model):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p, selected_model)
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot
# Model selector handler
def update_model(model_name):
    load_model(model_name)
    return model_name

# Clear chat handler
def clear_chat():
    return None  # returning None resets the Chatbot component
# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        gr.HTML(
            """
            <div style="text-align: center; width: 100%;">
                <h1 style="margin: 0;">LeCarnet Demo</h1>
            </div>
            """
        )

    msg_input = gr.Textbox(
        placeholder="Il était une fois un petit garçon",
        label="User Input",
        render=False,
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
            gr.Examples(
                examples=[
                    ["Il était une fois un petit phoque nommé Zoom. Zoom était très habile et aimait jouer dans l'eau."],
                    ["Il était une fois un petit écureuil nommé Pipo. Pipo adorait grimper aux arbres."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts",
            )
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500,
            )
            msg_input.render()

    # Event handlers
    model_selector.change(
        fn=update_model,
        inputs=[model_selector],
        outputs=[model_selector],
    )
    msg_input.submit(
        fn=user,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],
        outputs=[chatbot],
    )
    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=chatbot,
        queue=False,
    )
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
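# To run locally (sketch; assumes the token can read the MaxLSB/LeCarnet-* checkpoints):
#   HUGGINGFACEHUB_API_TOKEN=hf_... python app.py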