Krish45 committed
Commit 5ed6726 · verified · 1 parent: dd3768c

Update app.py

Files changed (1)
  1. app.py +37 -23
app.py CHANGED
@@ -2,8 +2,8 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import threading
-import time
 import os
+import time
 
 # Model config
 model_name = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -19,29 +19,33 @@ def load_model():
     global tokenizer, model
     if model is None:
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        # Ensure offload folder exists
        os.makedirs(offload_dir, exist_ok=True)
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            load_in_8bit=True,  # Quantize to 8-bit
+            load_in_8bit=True,
             device_map="auto",
-            offload_folder=offload_dir,  # Offload some weights to disk
+            offload_folder=offload_dir,
             torch_dtype=torch.float16
         )
 
 # Chatbot prediction function
-def predict(history, message):
+def predict(history, message, bot_name="Bot", personality="wise AI", tone="friendly"):
     load_model()
     history = history or []
-    history.append((message, ""))
+    # Append user message
+    history.append({"role": "user", "content": message})
+
+    # Build dynamic system prompt
+    system_prompt = (
+        f"You are {bot_name}, a {personality}.\n"
+        f"You express emotion, think logically, and talk like a wise, emotional, intelligent human being.\n"
+        f"Your tone is always {tone}."
+    )
 
-    # Convert to Qwen message format
-    messages = []
-    for human, bot in history:
-        if human:
-            messages.append({"role": "user", "content": human})
-        if bot:
-            messages.append({"role": "assistant", "content": bot})
+    # Prepare messages for Qwen
+    messages = [{"role": "system", "content": system_prompt}]
+    for msg in history:
+        messages.append({"role": msg["role"], "content": msg["content"]})
 
     text = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
@@ -50,14 +54,13 @@ def predict(history, message):
 
     reply = ""
     try:
-        with model_lock:  # Serialize CPU inference safely
+        with model_lock:
             with torch.no_grad():
                 start = time.time()
                 generated_ids = model.generate(**model_inputs, max_new_tokens=256)
-                if time.time() - start > 30:  # 30s timeout
+                if time.time() - start > 30:
                     reply = "[Response timed out]"
                 else:
-                    # Remove input_ids from output
                     generated_ids = [
                         output_ids[len(input_ids):]
                         for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
@@ -66,24 +69,35 @@ def predict(history, message):
     except Exception as e:
         reply = f"[Error: {str(e)}]"
 
-    history[-1] = (message, reply)
+    # Append bot reply
+    history.append({"role": "assistant", "content": reply})
     return history, ""
 
-# Keep-alive endpoint for local client ping
+# Keep-alive endpoint
 def keep_alive(msg="ping"):
     return "pong"
 
 # Gradio UI
 with gr.Blocks() as demo:
     with gr.Tab("Chatbot"):
-        chatbot = gr.Chatbot()
+        chatbot = gr.Chatbot(type="messages")
         msg = gr.Textbox(placeholder="Type your message here...")
-        msg.submit(predict, [chatbot, msg], [chatbot, msg])
+        bot_name_input = gr.Textbox(label="Bot Name", value="Bot")
+        personality_input = gr.Textbox(label="Personality", value="wise AI")
+        tone_input = gr.Textbox(label="Tone", value="friendly")
+
+        msg.submit(
+            predict,
+            inputs=[chatbot, msg, bot_name_input, personality_input, tone_input],
+            outputs=[chatbot, msg]
+        )
 
     with gr.Tab("Keep Alive"):
         gr.Textbox(label="Ping", value="ping", interactive=False)
         gr.Button("Ping").click(keep_alive, inputs=None, outputs=None)
 
-# Multi-user queue with concurrency
-demo.queue(concurrency_count=4, max_size=8)  # 4 simultaneous, 8 waiting
+# Enable request queue (multi-user safe)
+demo.queue()  # simple queue; compatible with current Gradio versions
+
+# Launch Space
+demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
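
Context for the queue change: current Gradio releases no longer accept concurrency_count in Blocks.queue(), so the bare demo.queue() call avoids a startup error on up-to-date Spaces images. If per-request concurrency limits are still wanted, a rough modern analogue of the old call could look like the line below; the parameter names assume Gradio 4.x and are not part of this commit.

    demo.queue(default_concurrency_limit=4, max_size=8)  # roughly: 4 concurrent requests, 8 waiting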
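
For reviewers unfamiliar with Gradio's messages-format components: after this change, predict() and gr.Chatbot(type="messages") share the same history shape, a list of role/content dicts, which is also what tokenizer.apply_chat_template() consumes once the system prompt is prepended. A minimal sketch with hypothetical turns (not part of the commit):

    # One user entry is appended on submit, one assistant entry after generation.
    example_history = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help you today?"},
    ]

Because both sides now use this dict format, the old tuple-based history and the manual human/bot unpacking loop are no longer needed.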