Update app.py
app.py CHANGED
@@ -1,33 +1,55 @@
 import os
 import threading
-
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
-HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
-MEDIA_PATH = "media/le-carnet.png"
-
-models = {}
-tokenizers = {}
+from collections import defaultdict
 
+import gradio as gr
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+)
+
+# Define model paths
+model_name_to_path = {
+    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
+    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
+    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
+}
+
+# Load Hugging Face token
+hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+
+# Preload models and tokenizers
+loaded_models = defaultdict(dict)
+
+for name, path in model_name_to_path.items():
+    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
+    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
+    loaded_models[name]["model"].eval()
+
+def respond(
+    prompt: str,
+    chat_history,
+    model_name: str,
+    max_tokens: int,
+    temperature: float,
+    top_p: float,
+):
+    # Select the appropriate model and tokenizer
+    tokenizer = loaded_models[model_name]["tokenizer"]
+    model = loaded_models[model_name]["model"]
+
+    # Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt")
 
+    # Set up streaming
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_prompt=False,
         skip_special_tokens=True,
     )
 
+    # Configure generation parameters
     generate_kwargs = dict(
         **inputs,
         streamer=streamer,
@@ -38,60 +60,38 @@ def respond(prompt: str, chat_history, selected_model: str, max_tokens: int, tem
         eos_token_id=tokenizer.eos_token_id,
     )
 
+    # Run generation in a background thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
+    # Stream results
     accumulated = ""
-    first = True
     for new_text in streamer:
-        if first:
-            accumulated = prefix + new_text
-            first = False
-        else:
-            accumulated += new_text
+        accumulated += new_text
         yield accumulated
 
-        type="messages",
-    )
-
-    with gr.Column(scale=1, visible=True, elem_id="settings-panel"):
-        pass  # Inputs are already defined in ChatInterface
-
-    demo.load(
-        js="""
-        () => {
-            const toggleBtn = document.querySelector('#toggle-btn button') || document.querySelector('#toggle-btn');
-            const panel = document.querySelector('#settings-panel');
-            toggleBtn.addEventListener('click', () => {
-                panel.style.display = (panel.style.display === 'none') ? 'flex' : 'none';
-            });
-        }
-        """
-    )
-
+# Create Gradio Chat Interface
+demo = gr.ChatInterface(
+    fn=respond,
+    additional_inputs=[
+        gr.Dropdown(
+            choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
+            value="LeCarnet-8M",
+            label="Model",
+        ),
+        gr.Slider(1, 512, value=512, step=1, label="Max New Tokens"),
+        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
+    ],
+    title="LeCarnet",
+    description="Select a model and enter text to get started.",
+    examples=[
+        ["Il était une fois un petit garçon qui vivait dans un village paisible."],
+        ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
+        ["Il était une fois un petit lapin perdu"],
+    ],
+    cache_examples=False,
+)
 
 if __name__ == "__main__":
     demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
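The updated respond() is a plain Python generator, so it can be smoke-tested outside Gradio. A minimal sketch, assuming HUGGINGFACEHUB_API_TOKEN is exported (app.py reads it and preloads all three checkpoints at import time); the prompt and sampling values below are only illustrative:

import app  # importing app.py loads the models and builds the ChatInterface, but does not launch it

final = ""
for partial in app.respond(
    prompt="Il était une fois un petit lapin perdu",
    chat_history=[],           # required by the signature; not used in the visible diff
    model_name="LeCarnet-8M",
    max_tokens=64,
    temperature=0.7,
    top_p=0.9,
):
    final = partial  # respond() yields the accumulated text so far

print(final)

With gr.ChatInterface, the additional_inputs are passed to the callback after the message and history, so the Dropdown and the three Sliders map onto model_name, max_tokens, temperature, and top_p in that order.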