Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,12 +7,10 @@ from transformers import (
|
|
7 |
TextIteratorStreamer,
|
8 |
)
|
9 |
|
10 |
-
# Configuration
|
11 |
MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
|
12 |
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
13 |
-
MEDIA_PATH = "media/le-carnet.png"
|
14 |
|
15 |
-
# Pre-load all tokenizers and models
|
16 |
models = {}
|
17 |
tokenizers = {}
|
18 |
for name in MODEL_NAMES:
|
@@ -30,10 +28,6 @@ def respond(
|
|
30 |
temperature: float,
|
31 |
top_p: float,
|
32 |
):
|
33 |
-
"""
|
34 |
-
Generate a streaming response from the chosen LeCarnet model,
|
35 |
-
prepending the logo and model name in the chat bubble.
|
36 |
-
"""
|
37 |
tokenizer = tokenizers[selected_model]
|
38 |
model = models[selected_model]
|
39 |
inputs = tokenizer(prompt, return_tensors="pt")
|
@@ -54,16 +48,14 @@ def respond(
|
|
54 |
eos_token_id=tokenizer.eos_token_id,
|
55 |
)
|
56 |
|
57 |
-
# Start generation in background thread
|
58 |
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
|
59 |
thread.start()
|
60 |
|
61 |
-
prefix = f"<img src='{MEDIA_PATH}'
|
62 |
accumulated = ""
|
63 |
first = True
|
64 |
for new_text in streamer:
|
65 |
if first:
|
66 |
-
# include prefix only once at start
|
67 |
accumulated = prefix + new_text
|
68 |
first = False
|
69 |
else:
|
@@ -71,19 +63,16 @@ def respond(
|
|
71 |
yield accumulated
|
72 |
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
with gr.Row():
|
78 |
-
with gr.Column():
|
|
|
|
|
79 |
chat = gr.ChatInterface(
|
80 |
fn=respond,
|
81 |
-
additional_inputs=[
|
82 |
-
gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model"),
|
83 |
-
gr.Slider(1, 512, value=512, step=1, label="Max new tokens"),
|
84 |
-
gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
|
85 |
-
gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top‑p"),
|
86 |
-
],
|
87 |
title="LeCarnet Chat",
|
88 |
description="Type the beginning of a sentence and watch the model finish it.",
|
89 |
examples=[
|
@@ -93,7 +82,25 @@ with gr.Blocks() as demo:
|
|
93 |
],
|
94 |
cache_examples=False,
|
95 |
)
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
if __name__ == "__main__":
|
98 |
demo.queue()
|
99 |
-
demo.launch()
|
|
|
7 |
TextIteratorStreamer,
|
8 |
)
|
9 |
|
|
|
10 |
MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
|
11 |
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
12 |
+
MEDIA_PATH = "media/le-carnet.png"
|
13 |
|
|
|
14 |
models = {}
|
15 |
tokenizers = {}
|
16 |
for name in MODEL_NAMES:
|
|
|
28 |
temperature: float,
|
29 |
top_p: float,
|
30 |
):
|
|
|
|
|
|
|
|
|
31 |
tokenizer = tokenizers[selected_model]
|
32 |
model = models[selected_model]
|
33 |
inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
48 |
eos_token_id=tokenizer.eos_token_id,
|
49 |
)
|
50 |
|
|
|
51 |
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
|
52 |
thread.start()
|
53 |
|
54 |
+
prefix = f"<img src='{MEDIA_PATH}' width='24' style='display:inline; vertical-align:middle; margin-right:6px;'/> <strong>{selected_model}</strong>: "
|
55 |
accumulated = ""
|
56 |
first = True
|
57 |
for new_text in streamer:
|
58 |
if first:
|
|
|
59 |
accumulated = prefix + new_text
|
60 |
first = False
|
61 |
else:
|
|
|
63 |
yield accumulated
|
64 |
|
65 |
|
66 |
+
with gr.Blocks(css=".gr-chatbox {height: 600px !important;}") as demo:
|
67 |
+
gr.Markdown("## LeCarnet: Short French Stories")
|
68 |
+
|
69 |
with gr.Row():
|
70 |
+
with gr.Column(scale=4):
|
71 |
+
with gr.Row():
|
72 |
+
toggle_btn = gr.Button("Show/hide parameters", elem_id="toggle-btn")
|
73 |
chat = gr.ChatInterface(
|
74 |
fn=respond,
|
75 |
+
additional_inputs=[],
|
|
|
|
|
|
|
|
|
|
|
76 |
title="LeCarnet Chat",
|
77 |
description="Type the beginning of a sentence and watch the model finish it.",
|
78 |
examples=[
|
|
|
82 |
],
|
83 |
cache_examples=False,
|
84 |
)
|
85 |
+
|
86 |
+
with gr.Column(scale=1, visible=True, elem_id="settings-panel") as param_panel:
|
87 |
+
selected_model = gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model")
|
88 |
+
max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max new tokens")
|
89 |
+
temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
|
90 |
+
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top‑p")
|
91 |
+
|
92 |
+
chat.additional_inputs = [selected_model, max_tokens, temperature, top_p]
|
93 |
+
|
94 |
+
demo.load(None, None, _js="""
|
95 |
+
() => {
|
96 |
+
const toggleBtn = document.querySelector('#toggle-btn button') || document.querySelector('#toggle-btn');
|
97 |
+
const panel = document.querySelector('#settings-panel');
|
98 |
+
toggleBtn.addEventListener('click', () => {
|
99 |
+
panel.style.display = (panel.style.display === 'none') ? 'flex' : 'none';
|
100 |
+
});
|
101 |
+
}
|
102 |
+
""")
|
103 |
+
|
104 |
if __name__ == "__main__":
|
105 |
demo.queue()
|
106 |
+
demo.launch()
|