import os
import threading

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token, read from the Space's environment (raises KeyError if unset)
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
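
# Limit PyTorch to one intra-op CPU thread; this helps avoid CPU
# oversubscription when several requests run concurrently on a shared Space.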
torch.set_num_threads(1)

# Globals
tokenizer = None
model = None
current_model_name = None
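# Note: these globals are shared by every request; if two users pick different
# models at the same time, the most recent load wins. Acceptable for a small demo.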


# Load selected model
def load_model(model_name):
    global tokenizer, model, current_model_name

    # Only load if it's a different model
    if current_model_name == model_name:
        return

    full_model_name = f"MaxLSB/{model_name}"
    print(f"Loading model: {full_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()
    current_model_name = model_name
    print(f"Model loaded: {current_model_name}")


# Initialize default model
load_model("LeCarnet-8M")


# Streaming generation: tokenize the prompt, then yield partial text as tokens arrive
def respond(message, max_tokens, temperature, top_p, selected_model):
    # Ensure the correct model is loaded before generation
    load_model(selected_model)

    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False: the prompt is echoed back at the start of the stream
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
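
    # model.generate() blocks until generation finishes, so run it on a
    # background thread and consume tokens from the streamer on this thread.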
    def run():
        with torch.no_grad():
            model.generate(**generate_kwargs)

    thread = threading.Thread(target=run)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield f"**{current_model_name}**\n\n{response}"


# User input handler: append the message to the history and clear the textbox
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history


# Bot response handler: stream the reply for the latest user message
def bot(chatbot, max_tokens, temperature, top_p, selected_model):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p, selected_model)
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot


# Model selector handler
def update_model(model_name):
    load_model(model_name)
    return model_name


# Clear chat handler: returning None resets the Chatbot component
def clear_chat():
    return None


# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        gr.HTML("""
            <div style="text-align: center; width: 100%;">
                <h1 style="margin: 0;">LeCarnet Demo</h1>
            </div>
        """)

    # Created with render=False so gr.Examples can reference it; rendered later
    # inside the right-hand column.
    msg_input = gr.Textbox(
        placeholder="Il était une fois un petit garçon",
        label="User Input",
        render=False,
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
            gr.Examples(
                examples=[
                    ["Il était une fois un petit phoque nommé Zoom. Zoom était très habile et aimait jouer dans l'eau."],
                    ["Il était une fois un petit écureuil nommé Pipo. Pipo adorait grimper aux arbres."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts",
            )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500,
            )
            msg_input.render()

    # Event Handlers
    model_selector.change(
        fn=update_model,
        inputs=[model_selector],
        outputs=[model_selector],
    )
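
    # Two-step submit: `user` appends the new message immediately (queue=False so
    # the textbox clears without waiting), then `bot` streams the reply into the
    # last chat turn.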
    msg_input.submit(
        fn=user,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],
        outputs=[chatbot],
    )
    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=chatbot,
        queue=False,
    )


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
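
# Usage note (assumes this script is saved as app.py, the Space's entry point):
#   HUGGINGFACEHUB_API_TOKEN=<your-token> python app.py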