MaxLSB committed on
Commit d91c9af · verified · 1 Parent(s): 9be0b0d

Update app.py

Files changed (1)
  app.py +75 -102
app.py CHANGED
@@ -1,34 +1,30 @@
  import os
- import threading
  import gradio as gr
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
- )

- # Define your models
  MODEL_PATHS = {
      "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
      "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
      "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
  }

- # Add your Hugging Face token
  hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
  if not hf_token:
      raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

- # Load tokenizers & models - only load one initially
  tokenizer = None
  model = None

  def load_model(model_name: str):
-     """Loads the specified model and tokenizer."""
      global tokenizer, model
      if model_name not in MODEL_PATHS:
          raise ValueError(f"Unknown model: {model_name}")
-
      print(f"Loading {model_name}...")
      repo = MODEL_PATHS[model_name]
      tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=hf_token)
@@ -36,98 +32,75 @@ def load_model(model_name: str):
      model.eval()
      print(f"{model_name} loaded.")

- # Initial model load
- initial_model = list(MODEL_PATHS.keys())[0]
- load_model(initial_model)


- def respond(
-     prompt: str,
-     chat_history: list,
-     model_choice: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
- ):
-     global tokenizer, model

-     # Reload model if it's not the currently loaded one
-     current_path = getattr(model.config, "_name_or_path", None)
-     desired_path = MODEL_PATHS[model_choice]
-     if current_path != desired_path:
-         load_model(model_choice)

-     # Tokenize
-     inputs = tokenizer(prompt, return_tensors="pt")
-     streamer = TextIteratorStreamer(
-         tokenizer,
-         skip_prompt=False,
-         skip_special_tokens=True,
-     )

-     # Prepare generation kwargs
-     generate_kwargs = dict(
-         **inputs,
-         streamer=streamer,
-         max_new_tokens=max_tokens,
-         do_sample=True,
-         temperature=temperature,
-         top_p=top_p,
-         eos_token_id=tokenizer.eos_token_id,
-     )

-     # Launch generation in a background thread
-     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
-     thread.start()

-     # Stream back to the UI
-     accumulated = ""
-     for new_text in streamer:
-         accumulated += new_text
-         yield accumulated


- # If you have custom CSS, define it here; otherwise set to None or remove the css= line below
- custom_css = None

- with gr.Blocks(css=custom_css, fill_width=True) as demo:
-     with gr.Row():
-         with gr.Column(scale=1):
-             model_dropdown = gr.Dropdown(
-                 choices=list(MODEL_PATHS.keys()),
-                 value=initial_model,
-                 label="Choose Model",
-                 interactive=True
-             )
-             max_tokens_slider = gr.Slider(
-                 minimum=1, maximum=512, value=512, step=1, label="Max new tokens"
-             )
-             temperature_slider = gr.Slider(
-                 minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"
-             )
-             top_p_slider = gr.Slider(
-                 minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"
-             )

-         with gr.Column(scale=3):
-             chatbot = gr.ChatInterface(
-                 fn=respond,
-                 additional_inputs=[
-                     model_dropdown,
-                     max_tokens_slider,
-                     temperature_slider,
-                     top_p_slider,
-                 ],
-                 examples=[
-                     ["Il était une fois un petit garçon qui vivait dans un village paisible."],
-                     ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
-                     ["Il était une fois un petit lapin perdu"],
-                 ],
-                 cache_examples=False,
-                 submit_btn="Generate",
-                 avatar_images=(None, "media/le-carnet.png")
-             )

  if __name__ == "__main__":
-     demo.queue()
-     demo.launch()
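The removed implementation streamed partial text to the UI by running model.generate on a background thread and reading decoded chunks from a TextIteratorStreamer. Its replacement in the new file, listed next, generates in a single blocking call. For reference, a sketch of how streaming could be layered back onto the new code; generate_response_stream is a hypothetical name, the rest reuses only names that appear in this diff, and it is untested against the commit:

    import threading
    from transformers import TextIteratorStreamer

    def generate_response_stream(prompt, max_new_tokens=200):
        # Hypothetical streaming variant of the new generate_response:
        # generate on a worker thread and yield the accumulated text so far.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = threading.Thread(
            target=model.generate,
            kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens),
        )
        thread.start()
        accumulated = ""
        for new_text in streamer:  # yields decoded text chunks as they are generated
            accumulated += new_text
            yield accumulated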
 
  import os
+ import uuid
+ import time
+ import json
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
+ import modelscope_studio.components.antd as antd
+ import modelscope_studio.components.antdx as antdx  # used below for antdx.Suggestion / antdx.Sender
+ import modelscope_studio.components.base as ms
+ import modelscope_studio.components.pro as pro

  MODEL_PATHS = {
      "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
      "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
      "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
  }

  hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
  if not hf_token:
      raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

  tokenizer = None
  model = None

  def load_model(model_name: str):
      global tokenizer, model
      if model_name not in MODEL_PATHS:
          raise ValueError(f"Unknown model: {model_name}")
      print(f"Loading {model_name}...")
      repo = MODEL_PATHS[model_name]
      tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=hf_token)
      ...  # (unchanged line elided in the diff view)
      model.eval()
      print(f"{model_name} loaded.")

+ def generate_response(prompt, max_new_tokens=200):
+     # Run a single blocking generation and return only the newly generated text
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return response[len(prompt):].strip()

+ DEFAULT_SETTINGS = {
+     "model": "LeCarnet-3M",
+     "sys_prompt": "",
+ }

+ css = None  # no custom CSS; passed to gr.Blocks below

+ with gr.Blocks(css=css) as demo:
+     # Initial state with one fixed conversation; gr.State is created inside
+     # the Blocks context so it can be wired into event listeners
+     state = gr.State({
+         "conversation_id": "default",
+         "conversation_contexts": {
+             "default": {
+                 "history": [],
+                 "settings": DEFAULT_SETTINGS,
+             }
+         },
+     })
+
+     with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
+         # Right column - chat interface
+         with antd.Col(flex=1, elem_style=dict(height="100%")):
+             with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
+                 chatbot = pro.Chatbot(elem_classes="chatbot-chat-messages", height=0)
+                 with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
+                     with ms.Slot("children"):
+                         input = antdx.Sender(placeholder="Type your message here...")
+
+     def add_message(user_input, state_value):
+         history = state_value["conversation_contexts"]["default"]["history"]
+         settings = state_value["conversation_contexts"]["default"]["settings"]
+         selected_model = settings["model"]  # informational; the model is loaded once at startup
+
+         # Add user message
+         history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
+         yield {chatbot: gr.update(value=history)}
+
+         # Start assistant response (rendered in a loading state while generating)
+         history.append({"role": "assistant", "content": [], "key": str(uuid.uuid4()), "loading": True})
+         yield {chatbot: gr.update(value=history)}
+
+         try:
+             # Generate model response from the concatenated user turns
+             prompt = "\n".join([msg["content"] for msg in history if msg["role"] == "user"])
+             response = generate_response(prompt)
+
+             # Update assistant message
+             history[-1]["content"] = [{"type": "text", "content": response}]
+             history[-1]["loading"] = False
+             yield {chatbot: gr.update(value=history)}
+         except Exception as e:
+             history[-1]["content"] = [{
+                 "type": "text",
+                 "content": f'<span style="color: red;">{str(e)}</span>'
+             }]
+             history[-1]["loading"] = False
+             yield {chatbot: gr.update(value=history)}
+
+     input.submit(fn=add_message, inputs=[input, state], outputs=[chatbot])
+
+     # Load default model on startup
+     load_model(DEFAULT_SETTINGS["model"])

  if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=10).launch()
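A quick local smoke test of the new generation path might look like the following; it assumes the Space's dependencies (torch, transformers, gradio, modelscope_studio) are installed, HUGGINGFACEHUB_API_TOKEN is set, and `app` refers to the new app.py. The prompt is illustrative, borrowed from the examples removed above; this check is not part of the commit:

    # Load the smallest checkpoint and print a short continuation.
    from app import load_model, generate_response

    load_model("LeCarnet-3M")
    print(generate_response("Il était une fois un petit lapin perdu", max_new_tokens=50))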