Spaces:
Running
Running
File size: 4,729 Bytes
36942d4 f6b834f a7a20a5 39c555f 790cffd 954f37f 790cffd 2c7b633 f6b834f 3a38c1f 790cffd 3a38c1f 790cffd 3a38c1f a3c4cbd 790cffd a3c4cbd 790cffd 3a38c1f a3c4cbd 790cffd a3c4cbd 790cffd 90d1b16 954f37f a7a20a5 954f37f f6b834f a7a20a5 954f37f 603f014 a7a20a5 603f014 790cffd a3c4cbd 603f014 a3c4cbd 603f014 790cffd a3c4cbd 790cffd 7c9e931 790cffd 603f014 7811152 f19d748 99f5fa0 7e54aad f19d748 7e54aad b7b0fd1 02deb9a 7e54aad 02deb9a 7e54aad 52a9a97 790cffd 44b31eb 790cffd 1cbb5a4 790cffd b7b0fd1 7e54aad 790cffd 954f37f 56d40da 2d0a01f 954f37f 790cffd 7b4f2fa 954f37f 5b21f39 7e54aad 5b21f39 7e54aad 07da16a 7c9e931 a3c4cbd 7c9e931 a3c4cbd 7c9e931 07da16a 7c9e931 a167f72 6ecb51d 341bd22 7e54aad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token (optional). Using .get() instead of indexing means a
# missing env var no longer crashes at import time — public models can still
# be downloaded with token=None.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
# Limit intra-op parallelism; presumably this Space runs on shared CPU hardware
torch.set_num_threads(1)
# Globals holding the currently active tokenizer/model pair (set by load_model)
tokenizer = None
model = None
current_model_name = None
# Load selected model
def load_model(model_name):
    """Load the tokenizer and model for ``model_name`` into the module globals.

    The repo is resolved as ``MaxLSB/{model_name}``. No-op when the requested
    model is already active. Both artifacts are loaded into locals first and
    the globals are rebound only after both succeed, so a failed download
    cannot leave a mismatched tokenizer/model pair behind.
    """
    global tokenizer, model, current_model_name
    # Only load if it's a different model
    if current_model_name == model_name:
        return
    full_model_name = f"MaxLSB/{model_name}"
    print(f"Loading model: {full_model_name}")
    new_tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    new_model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    new_model.eval()  # inference mode: disables dropout etc.
    # Publish only after both loads succeeded
    tokenizer, model = new_tokenizer, new_model
    current_model_name = model_name
    print(f"Model loaded: {current_model_name}")
# Eagerly load the default model at import time so the first request is served fast
load_model("LeCarnet-8M")
# Streaming generation function
def respond(message, max_tokens, temperature, top_p, selected_model):
    """Generate a streamed completion of ``message`` with the selected model.

    Yields progressively longer markdown strings of the form
    ``**{model_name}**\\n\\n{text_so_far}`` (prompt included, since
    skip_prompt=False). Generation runs on a background thread feeding a
    TextIteratorStreamer that this generator drains.
    """
    # Ensure the correct model is loaded before generation
    load_model(selected_model)
    # Snapshot the globals: a concurrent request may call load_model() and swap
    # them mid-stream; locals guarantee a consistent tokenizer/model/label trio.
    local_tokenizer, local_model, model_label = tokenizer, model, current_model_name
    inputs = local_tokenizer(message, return_tensors="pt")
    streamer = TextIteratorStreamer(local_tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=local_tokenizer.eos_token_id,
    )
    def run():
        # Inference only — no autograd bookkeeping needed
        with torch.no_grad():
            local_model.generate(**generate_kwargs)
    # daemon=True: an abandoned generation can't block interpreter shutdown
    thread = threading.Thread(target=run, daemon=True)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield f"**{model_label}**\n\n{response}"
    # Streamer is exhausted only once generate() returns, so this join is cheap
    thread.join()
# User input handler
def user(message, chat_history):
    """Append the submitted message as a new turn awaiting a reply.

    Returns ("", chat_history): the empty string clears the textbox, and the
    history gains a [message, None] pair for bot() to fill in.
    """
    pending_turn = [message, None]
    chat_history.append(pending_turn)
    return "", chat_history
# Bot response handler — streams the reply for the selected model
def bot(chatbot, max_tokens, temperature, top_p, selected_model):
    """Stream the model's reply into the last chat turn, yielding the
    updated history after every partial chunk."""
    prompt = chatbot[-1][0]
    for partial in respond(prompt, max_tokens, temperature, top_p, selected_model):
        chatbot[-1][1] = partial
        yield chatbot
# Model selector handler
def update_model(model_name):
    """Switch the active model to ``model_name`` and echo the selection
    back to the dropdown component."""
    load_model(model_name)
    return model_name
# Clear chat handler
def clear_chat():
    """Empty the conversation; returning None resets the Chatbot component."""
    return None
# Gradio UI — layout declarations are render-order sensitive: msg_input is
# created with render=False so it can be declared before the layout but
# placed (via .render()) below the chatbot column.
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        # Centered page title
        gr.HTML("""
            <div style="text-align: center; width: 100%;">
                <h1 style="margin: 0;">LeCarnet Demo</h1>
            </div>
        """)
    # Declared early so gr.Examples below can target it; rendered later
    msg_input = gr.Textbox(
        placeholder="Il était une fois un petit garçon",
        label="User Input",
        render=False
    )
    with gr.Row():
        # Left column: model choice + sampling controls
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-8M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
            # Clicking an example fills msg_input with the prompt
            gr.Examples(
                examples=[
                    ["Il était une fois un petit phoque nommé Zoom. Zoom était très habile et aimait jouer dans l'eau."],
                    ["Il était une fois un petit écureuil nommé Pipo. Pipo adorait grimper aux arbres."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
        # Right column: conversation view with the input box underneath
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input.render()
    # Event Handlers
    # Dropdown change eagerly loads the chosen model
    model_selector.change(
        fn=update_model,
        inputs=[model_selector],
        outputs=[model_selector],
    )
    # Submit: user() appends the turn (unqueued, instant echo),
    # then bot() streams the generated reply into it
    msg_input.submit(
        fn=user,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
        queue=False
    ).then(
        fn=bot,
        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],  # selected model is forwarded to bot()
        outputs=[chatbot]
    )
    # Clearing returns None, which empties the Chatbot component
    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=chatbot,
        queue=False
    )
if __name__ == "__main__":
    # Queue up to 10 concurrent generations; ssr_mode=False for Spaces compatibility
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)