MaxLSB committed on
Commit
a7a20a5
·
verified ·
1 Parent(s): 537cd44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -139
app.py CHANGED
@@ -1,142 +1,99 @@
1
  import os
2
- import uuid
3
- import time
4
- import json
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import gradio as gr
7
- import modelscope_studio.components.antd as antd
8
- import modelscope_studio.components.antdx as antdx
9
- import modelscope_studio.components.base as ms
10
- import modelscope_studio.components.pro as pro
11
-
12
- # Define model paths
13
- MODEL_PATHS = {
14
- "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
15
- "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
16
- "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
17
- }
18
-
19
- # Set HF token
20
- hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
21
- if not hf_token:
22
- raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
23
-
24
- # Load tokenizer and model globally
25
- tokenizer = None
26
- model = None
27
-
28
- def load_model(model_name: str):
29
- global tokenizer, model
30
- if model_name not in MODEL_PATHS:
31
- raise ValueError(f"Unknown model: {model_name}")
32
- print(f"Loading {model_name}...")
33
- repo = MODEL_PATHS[model_name]
34
- tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=hf_token)
35
- model = AutoModelForCausalLM.from_pretrained(repo, use_auth_token=hf_token)
36
- model.eval()
37
- print(f"{model_name} loaded.")
38
-
39
- def generate_response(prompt, max_new_tokens=200):
40
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
41
- outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
42
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
43
- return response[len(prompt):].strip()
44
-
45
- # CSS for styling chatbot header with avatar
46
- css = """
47
- .chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header {
48
- display: flex;
49
- align-items: center;
50
- }
51
- .chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header img {
52
- width: 20px;
53
- height: 20px;
54
- margin-right: 8px;
55
- vertical-align: middle;
56
- }
57
- """
58
-
59
- # Default settings
60
- DEFAULT_SETTINGS = {
61
- "model": "LeCarnet-3M",
62
- "sys_prompt": "",
63
- }
64
-
65
- # Initial state with one fixed conversation
66
- state = gr.State({
67
- "conversation_id": "default",
68
- "conversation_contexts": {
69
- "default": {
70
- "history": [],
71
- "settings": DEFAULT_SETTINGS,
72
- }
73
- },
74
- })
75
-
76
- # Welcome message (optional)
77
- def welcome_config():
78
- return {
79
- "title": "LeCarnet Chatbot",
80
- "description": "Start chatting below!",
81
- "promptSuggestions": ["Hello", "Tell me a story", "How are you?"]
82
- }
83
-
84
- with gr.Blocks(css=css) as demo:
85
- with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
86
- # Right Column - Chat Interface
87
- with antd.Col(flex=1, elem_style=dict(height="100%")):
88
- with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
89
- chatbot = pro.Chatbot(
90
- elem_classes="chatbot-chat-messages",
91
- height=0,
92
- welcome_config=welcome_config()
93
- )
94
- with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
95
- with ms.Slot("children"):
96
- input = antdx.Sender(placeholder="Type your message here...")
97
-
98
- current_state = state
99
-
100
- def add_message(user_input, state_value):
101
- history = state_value["conversation_contexts"]["default"]["history"]
102
- settings = state_value["conversation_contexts"]["default"]["settings"]
103
- selected_model = settings["model"]
104
-
105
- # Add user message
106
- history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
107
- yield {"chatbot": gr.update(value=history)}
108
-
109
- # Start assistant response
110
- history.append({
111
- "role": "assistant",
112
- "content": [],
113
- "key": str(uuid.uuid4()),
114
- "header": f'<img src="/file=media/le-carnet.png" style="width:20px;height:20px;margin-right:8px;"> <span>{selected_model}</span>',
115
- "loading": True
116
- })
117
- yield {"chatbot": gr.update(value=history)}
118
-
119
- try:
120
- # Generate model response
121
- prompt = "\n".join([msg["content"] for msg in history if msg["role"] == "user"])
122
- response = generate_response(prompt)
123
-
124
- # Update assistant message
125
- history[-1]["content"] = [{"type": "text", "content": response}]
126
- history[-1]["loading"] = False
127
- yield {"chatbot": gr.update(value=history)}
128
- except Exception as e:
129
- history[-1]["content"] = [{
130
- "type": "text",
131
- "content": f'<span style="color: red;">{str(e)}</span>'
132
- }]
133
- history[-1]["loading"] = False
134
- yield {"chatbot": gr.update(value=history)}
135
-
136
- input.submit(fn=add_message, inputs=[input, state], outputs=[chatbot])
137
-
138
- # Load default model on startup
139
- load_model(DEFAULT_SETTINGS["model"])
140
-
141
  if __name__ == "__main__":
142
- demo.queue(default_concurrency_limit=10).launch()
 
 
1
  import os
2
+ import threading
 
 
 
3
  import gradio as gr
4
+ from transformers import (
5
+ AutoModelForCausalLM,
6
+ AutoTokenizer,
7
+ TextIteratorStreamer,
8
+ )
9
+
10
# Configuration
# Display names offered in the UI; also used as keys of the model/tokenizer caches.
MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
# Hugging Face access token; None when the variable is unset.
# A single lookup is enough: os.getenv is the same lookup as os.environ.get.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
MEDIA_PATH = "media/le-carnet.png"  # Relative path to logo
14
+
15
# Pre-load all tokenizers and models at import time so that switching models
# in the UI is a dictionary lookup rather than a download/load.
models = {}
tokenizers = {}
for name in MODEL_NAMES:
    # The display name already is the repository name ("LeCarnet-3M").
    # Re-deriving it from the size suffix and appending another "M" produced
    # non-matching ids such as "MaxLSB/LeCarnet-3MM".
    hub_id = f"MaxLSB/{name}"
    tokenizers[name] = AutoTokenizer.from_pretrained(hub_id, token=HF_TOKEN)
    models[name] = AutoModelForCausalLM.from_pretrained(hub_id, token=HF_TOKEN)
    # Inference only: switch the module to evaluation mode.
    models[name].eval()
23
+
24
+
25
def respond(
    prompt: str,
    chat_history,
    selected_model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream generated text for ``prompt`` from the selected LeCarnet model.

    This is a generator: each yielded value is the full message so far,
    i.e. an HTML prefix (logo image and bold model name) followed by all
    text received from the streamer up to that point.

    Args:
        prompt: Text passed to the tokenizer as the model input.
        chat_history: Accepted for the chat callback signature; not used here.
        selected_model: Key into the module-level ``tokenizers``/``models``.
        max_tokens: Forwarded to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Yields:
        The accumulated message string after each chunk from the streamer.
    """
    tok = tokenizers[selected_model]
    lm = models[selected_model]
    encoded = tok(prompt, return_tensors="pt")

    # skip_prompt=False: the streamer is not asked to drop the prompt text.
    text_stream = TextIteratorStreamer(
        tok,
        skip_prompt=False,
        skip_special_tokens=True,
    )

    generation_args = {
        **encoded,
        "streamer": text_stream,
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "eos_token_id": tok.eos_token_id,
    }

    # generate() runs in a background thread so this generator can consume
    # the streamer while text is still being produced.
    worker = threading.Thread(target=lm.generate, kwargs=generation_args)
    worker.start()

    # Seed the running text with the logo/model-name header so it appears
    # exactly once, in front of the first streamed chunk.
    message = f"<img src='{MEDIA_PATH}' alt='logo' width='20' style='vertical-align: middle;'/> <strong>{selected_model}</strong>: "
    for chunk in text_stream:
        message += chunk
        yield message
72
+
73
+
74
# Build the Gradio UI: a page title above a single chat interface whose extra
# controls (model, max new tokens, temperature, top-p) are passed to `respond`
# after the message and history, in the order listed.
with gr.Blocks() as demo:
    gr.Markdown("# LeCarnet: Short French Stories")
    with gr.Row(), gr.Column():
        chat = gr.ChatInterface(
            fn=respond,
            additional_inputs=[
                gr.Dropdown(choices=MODEL_NAMES, value="LeCarnet-8M", label="Model"),
                gr.Slider(minimum=1, maximum=512, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top‑p"),
            ],
            title="LeCarnet Chat",
            description="Type the beginning of a sentence and watch the model finish it.",
            # Story openers offered as clickable examples.
            examples=[
                ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                ["Il était une fois un petit lapin perdu"],
            ],
            cache_examples=False,
        )
96
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
if __name__ == "__main__":
    # Enable the request queue, then start the app; queue() returns the
    # Blocks instance, so the two calls can be chained.
    demo.queue().launch()