import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token (read at startup; raises KeyError if unset)
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
# Limit intra-op CPU parallelism for the small models served here
torch.set_num_threads(1)

# Globals shared across requests; the lock prevents concurrent reloads
# when several sessions switch models at the same time
_model_lock = threading.Lock()
tokenizer = None
model = None
current_model_name = None

# Load the selected model (no-op if it is already the active one)
def load_model(model_name):
    global tokenizer, model, current_model_name

    with _model_lock:
        # Only load if it's a different model
        if current_model_name == model_name:
            return

        full_model_name = f"MaxLSB/{model_name}"
        print(f"Loading model: {full_model_name}")
        tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
        model.eval()
        current_model_name = model_name
        print(f"Model loaded: {current_model_name}")

# Warm-load the default model at startup so the first request is fast
load_model("LeCarnet-8M")

# Streaming generation function
def respond(message, max_tokens, temperature, top_p, selected_model):
    # Ensure the correct model is loaded before generation
    load_model(selected_model)
    
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False echoes the prompt so the completion reads as one continuous story
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Generate on a background thread so tokens can be streamed as they arrive
    def run():
        with torch.no_grad():
            model.generate(**generate_kwargs)

    thread = threading.Thread(target=run)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield f"**{current_model_name}**\n\n{response}"
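
# A minimal usage sketch (not part of the app): the generator above can be
# consumed directly for a quick smoke test outside Gradio; the prompt and
# sampling settings below are illustrative only.
#
#   for partial in respond("Il était une fois", 64, 0.4, 0.9, "LeCarnet-8M"):
#       print(partial)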

# User input handler
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history

# Bot response handler: streams the reply for the selected model into the last chat turn
def bot(chatbot, max_tokens, temperature, top_p, selected_model):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p, selected_model)
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot

# Model selector handler
def update_model(model_name):
    load_model(model_name)
    return model_name

# Clear chat handler
def clear_chat():
    return None

# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
            <h1 style="margin: 0;">LeCarnet Demo</h1>
        </div>
        """)

    msg_input = gr.Textbox(
        placeholder="Il était une fois un petit garçon",
        label="User Input",
        render=False
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-8M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")

            gr.Examples(
                examples=[
                    ["Il était une fois un petit phoque nommé Zoom. Zoom était très habile et aimait jouer dans l'eau."],
                    ["Il était une fois un petit écureuil nommé Pipo. Pipo adorait grimper aux arbres."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input.render()

    # Event Handlers
    model_selector.change(
        fn=update_model,
        inputs=[model_selector],
        outputs=[model_selector],
    )

    msg_input.submit(
        fn=user, 
        inputs=[msg_input, chatbot], 
        outputs=[msg_input, chatbot], 
        queue=False
    ).then(
        fn=bot, 
        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],
        outputs=[chatbot]
    )

    clear_button.click(
        fn=clear_chat, 
        inputs=None, 
        outputs=chatbot, 
        queue=False
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
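
# To run locally (assumed invocation; a Hugging Face Space executes app.py
# automatically, and the token value here is a placeholder):
#   HUGGINGFACEHUB_API_TOKEN=hf_xxx python app.py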