MaxLSB committed on
Commit 39c555f · verified · 1 Parent(s): f5f805b

Update app.py

Files changed (1)
  1. app.py +66 -66
app.py CHANGED
```diff
@@ -1,33 +1,55 @@
 import os
 import threading
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
-HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
-MEDIA_PATH = "media/le-carnet.png"
-
-models = {}
-tokenizers = {}
+from collections import defaultdict
 
-for name in MODEL_NAMES:
-    hub_id = f"MaxLSB/{name}"
-    tokenizers[name] = AutoTokenizer.from_pretrained(hub_id, token=HF_TOKEN)
-    models[name] = AutoModelForCausalLM.from_pretrained(hub_id, token=HF_TOKEN)
-    models[name].eval()
-
-
-def respond(prompt: str, chat_history, selected_model: str, max_tokens: int, temperature: float, top_p: float):
-    tokenizer = tokenizers[selected_model]
-    model = models[selected_model]
+import gradio as gr
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+)
+
+# Define model paths
+model_name_to_path = {
+    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
+    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
+    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
+}
+
+# Load Hugging Face token
+hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+
+# Preload models and tokenizers
+loaded_models = defaultdict(dict)
+
+for name, path in model_name_to_path.items():
+    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
+    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
+    loaded_models[name]["model"].eval()
+
+def respond(
+    prompt: str,
+    chat_history,
+    model_name: str,
+    max_tokens: int,
+    temperature: float,
+    top_p: float,
+):
+    # Select the appropriate model and tokenizer
+    tokenizer = loaded_models[model_name]["tokenizer"]
+    model = loaded_models[model_name]["model"]
+
+    # Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt")
 
+    # Set up streaming
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_prompt=False,
         skip_special_tokens=True,
     )
 
+    # Configure generation parameters
     generate_kwargs = dict(
         **inputs,
         streamer=streamer,
@@ -38,60 +60,38 @@ def respond(prompt: str, chat_history, selected_model: str, max_tokens: int, tem
         eos_token_id=tokenizer.eos_token_id,
     )
 
+    # Run generation in a background thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
-    prefix = f"<img src='{MEDIA_PATH}' width='24' style='display:inline; vertical-align:middle; margin-right:6px;'/> <strong>{selected_model}</strong>: "
+    # Stream results
     accumulated = ""
-    first = True
     for new_text in streamer:
-        if first:
-            accumulated = prefix + new_text
-            first = False
-        else:
-            accumulated += new_text
+        accumulated += new_text
         yield accumulated
 
-
-with gr.Blocks(css=".gr-chatbox {height: 600px !important;}") as demo:
-    gr.Markdown("## LeCarnet")
-
-    with gr.Row():
-        with gr.Column(scale=4):
-            with gr.Row():
-                toggle_btn = gr.Button("Show/hide parameters", elem_id="toggle-btn")
-            chat = gr.ChatInterface(
-                fn=respond,
-                additional_inputs=[
-                    gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model"),
-                    gr.Slider(1, 512, value=512, step=1, label="Max new tokens"),
-                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
-                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
-                ],
-                examples=[
-                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
-                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
-                    ["Il était une fois un petit lapin perdu"],
-                ],
-                cache_examples=False,
-                type="messages",
-            )
-
-        with gr.Column(scale=1, visible=True, elem_id="settings-panel"):
-            pass  # Inputs are already defined in ChatInterface
-
-    demo.load(
-        js="""
-        () => {
-            const toggleBtn = document.querySelector('#toggle-btn button') || document.querySelector('#toggle-btn');
-            const panel = document.querySelector('#settings-panel');
-            toggleBtn.addEventListener('click', () => {
-                panel.style.display = (panel.style.display === 'none') ? 'flex' : 'none';
-            });
-        }
-        """
-    )
-
+# Create Gradio Chat Interface
+demo = gr.ChatInterface(
+    fn=respond,
+    additional_inputs=[
+        gr.Dropdown(
+            choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
+            value="LeCarnet-8M",
+            label="Model",
+        ),
+        gr.Slider(1, 512, value=512, step=1, label="Max New Tokens"),
+        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
+    ],
+    title="LeCarnet",
+    description="Select a model and enter text to get started.",
+    examples=[
+        ["Il était une fois un petit garçon qui vivait dans un village paisible."],
+        ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
+        ["Il était une fois un petit lapin perdu"],
+    ],
+    cache_examples=False,
+)
 
 if __name__ == "__main__":
     demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
```
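For context on the pattern this file keeps using: `model.generate()` blocks until the last token, so `respond()` runs it in a background thread and iterates a `TextIteratorStreamer`, which yields decoded text fragments as they arrive; yielding the growing `accumulated` string is what lets `gr.ChatInterface` re-render the partial reply. Below is a minimal self-contained sketch of that pattern, using the public `sshleifer/tiny-gpt2` checkpoint as a hypothetical stand-in for the LeCarnet models (so it runs without `HUGGINGFACEHUB_API_TOKEN`); the `stream()` helper is illustrative, not part of this repo.

```python
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hypothetical stand-in for MaxLSB/LeCarnet-8M: a tiny public model
# that needs no access token.
model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()


def stream(prompt: str, max_new_tokens: int = 32):
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True)

    # generate() blocks until generation finishes, so it runs in a worker
    # thread while this generator drains the streamer's queue.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    accumulated = ""
    for new_text in streamer:
        accumulated += new_text
        yield accumulated  # each yield re-renders the growing reply in the UI
    thread.join()


for partial in stream("Il était une fois"):
    print(partial)
```

Two behaviors of the new version worth noting: `gr.ChatInterface` passes `additional_inputs` values positionally after `(message, history)`, which is why `respond(prompt, chat_history, model_name, max_tokens, temperature, top_p)` lines up with the Dropdown and the three Sliders in that order; and `os.environ["HUGGINGFACEHUB_API_TOKEN"]` raises `KeyError` at startup when the secret is unset, where the old `os.environ.get(...)` lookup silently produced `None`.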