MaxLSB committed on
Commit
eaff982
·
verified ·
1 Parent(s): 892e21c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -22
app.py CHANGED
@@ -7,12 +7,10 @@ from transformers import (
7
  TextIteratorStreamer,
8
  )
9
 
10
- # Configuration
11
  MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
12
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
- MEDIA_PATH = "media/le-carnet.png" # Relative path to logo
14
 
15
- # Pre-load all tokenizers and models
16
  models = {}
17
  tokenizers = {}
18
  for name in MODEL_NAMES:
@@ -30,10 +28,6 @@ def respond(
30
  temperature: float,
31
  top_p: float,
32
  ):
33
- """
34
- Generate a streaming response from the chosen LeCarnet model,
35
- prepending the logo and model name in the chat bubble.
36
- """
37
  tokenizer = tokenizers[selected_model]
38
  model = models[selected_model]
39
  inputs = tokenizer(prompt, return_tensors="pt")
@@ -54,16 +48,14 @@ def respond(
54
  eos_token_id=tokenizer.eos_token_id,
55
  )
56
 
57
- # Start generation in background thread
58
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
59
  thread.start()
60
 
61
- prefix = f"<img src='{MEDIA_PATH}' alt='logo' width='20' style='vertical-align: middle;'/> <strong>{selected_model}</strong>: "
62
  accumulated = ""
63
  first = True
64
  for new_text in streamer:
65
  if first:
66
- # include prefix only once at start
67
  accumulated = prefix + new_text
68
  first = False
69
  else:
@@ -71,19 +63,16 @@ def respond(
71
  yield accumulated
72
 
73
 
74
- # Build Gradio ChatInterface
75
- with gr.Blocks() as demo:
76
- gr.Markdown("# LeCarnet: Short French Stories")
77
  with gr.Row():
78
- with gr.Column():
 
 
79
  chat = gr.ChatInterface(
80
  fn=respond,
81
- additional_inputs=[
82
- gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model"),
83
- gr.Slider(1, 512, value=512, step=1, label="Max new tokens"),
84
- gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
85
- gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top‑p"),
86
- ],
87
  title="LeCarnet Chat",
88
  description="Type the beginning of a sentence and watch the model finish it.",
89
  examples=[
@@ -93,7 +82,25 @@ with gr.Blocks() as demo:
93
  ],
94
  cache_examples=False,
95
  )
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if __name__ == "__main__":
98
  demo.queue()
99
- demo.launch()
 
7
  TextIteratorStreamer,
8
  )
9
 
 
10
  MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
11
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
12
+ MEDIA_PATH = "media/le-carnet.png"
13
 
 
14
  models = {}
15
  tokenizers = {}
16
  for name in MODEL_NAMES:
 
28
  temperature: float,
29
  top_p: float,
30
  ):
 
 
 
 
31
  tokenizer = tokenizers[selected_model]
32
  model = models[selected_model]
33
  inputs = tokenizer(prompt, return_tensors="pt")
 
48
  eos_token_id=tokenizer.eos_token_id,
49
  )
50
 
 
51
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
52
  thread.start()
53
 
54
+ prefix = f"<img src='{MEDIA_PATH}' width='24' style='display:inline; vertical-align:middle; margin-right:6px;'/> <strong>{selected_model}</strong>: "
55
  accumulated = ""
56
  first = True
57
  for new_text in streamer:
58
  if first:
 
59
  accumulated = prefix + new_text
60
  first = False
61
  else:
 
63
  yield accumulated
64
 
65
 
66
+ with gr.Blocks(css=".gr-chatbox {height: 600px !important;}") as demo:
67
+ gr.Markdown("## LeCarnet: Short French Stories")
68
+
69
  with gr.Row():
70
+ with gr.Column(scale=4):
71
+ with gr.Row():
72
+ toggle_btn = gr.Button("Show/hide parameters", elem_id="toggle-btn")
73
  chat = gr.ChatInterface(
74
  fn=respond,
75
+ additional_inputs=[],
 
 
 
 
 
76
  title="LeCarnet Chat",
77
  description="Type the beginning of a sentence and watch the model finish it.",
78
  examples=[
 
82
  ],
83
  cache_examples=False,
84
  )
85
+
86
+ with gr.Column(scale=1, visible=True, elem_id="settings-panel") as param_panel:
87
+ selected_model = gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model")
88
+ max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max new tokens")
89
+ temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
90
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top‑p")
91
+
92
+ chat.additional_inputs = [selected_model, max_tokens, temperature, top_p]
93
+
94
+ demo.load(None, None, _js="""
95
+ () => {
96
+ const toggleBtn = document.querySelector('#toggle-btn button') || document.querySelector('#toggle-btn');
97
+ const panel = document.querySelector('#settings-panel');
98
+ toggleBtn.addEventListener('click', () => {
99
+ panel.style.display = (panel.style.display === 'none') ? 'flex' : 'none';
100
+ });
101
+ }
102
+ """)
103
+
104
  if __name__ == "__main__":
105
  demo.queue()
106
+ demo.launch()