Krish45 committed
Commit dd3768c · verified · 1 Parent(s): 74f9277

Update app.py

Files changed (1): app.py (+62 -26)
app.py CHANGED
@@ -1,24 +1,41 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import threading
+import time
+import os
 
+# Model config
 model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+offload_dir = "offload"
 
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name, low_cpu_mem_usage=True, device_map="auto", torch_dtype="auto"
-)
+# Global variables
+tokenizer = None
+model = None
+model_lock = threading.Lock()
 
+# Lazy-load the model with quantization & offloading
+def load_model():
+    global tokenizer, model
+    if model is None:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # Ensure offload folder exists
+        os.makedirs(offload_dir, exist_ok=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            load_in_8bit=True,            # Quantize to 8-bit
+            device_map="auto",
+            offload_folder=offload_dir,   # Offload some weights to disk
+            torch_dtype=torch.float16
+        )
+
+# Chatbot prediction function
 def predict(history, message):
-    """
-    history: list of [user, bot] message pairs from the Chatbot
-    message: new user input string
-    """
-    # Add the latest user message to the conversation
-    history = history or []  # make sure it's a list
+    load_model()
+    history = history or []
     history.append((message, ""))
 
-    # Convert to messages format for Qwen
+    # Convert to Qwen message format
     messages = []
     for human, bot in history:
         if human:
@@ -26,28 +43,47 @@ def predict(history, message):
         if bot:
             messages.append({"role": "assistant", "content": bot})
 
-    # Apply chat template
     text = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
-
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Generate response
-    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
-    generated_ids = [
-        output_ids[len(input_ids):]
-        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-    reply = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    reply = ""
+    try:
+        with model_lock:  # Serialize CPU inference safely
+            with torch.no_grad():
+                start = time.time()
+                generated_ids = model.generate(**model_inputs, max_new_tokens=256)
+                if time.time() - start > 30:  # 30s timeout
+                    reply = "[Response timed out]"
+                else:
+                    # Remove input_ids from output
+                    generated_ids = [
+                        output_ids[len(input_ids):]
+                        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+                    ]
+                    reply = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    except Exception as e:
+        reply = f"[Error: {str(e)}]"
 
-    # Update last message with bot reply
     history[-1] = (message, reply)
-    return history, ""  # return history + clear textbox
+    return history, ""
 
+# Keep-alive endpoint for local client ping
+def keep_alive(msg="ping"):
+    return "pong"
+
+# Gradio UI
 with gr.Blocks() as demo:
-    chatbot = gr.Chatbot()
-    msg = gr.Textbox(placeholder="Type your message here...")
-    msg.submit(predict, [chatbot, msg], [chatbot, msg])
+    with gr.Tab("Chatbot"):
+        chatbot = gr.Chatbot()
+        msg = gr.Textbox(placeholder="Type your message here...")
+        msg.submit(predict, [chatbot, msg], [chatbot, msg])
+
+    with gr.Tab("Keep Alive"):
+        gr.Textbox(label="Ping", value="ping", interactive=False)
+        gr.Button("Ping").click(keep_alive, inputs=None, outputs=None)
 
-demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
+# Multi-user queue with concurrency
+demo.queue(concurrency_count=4, max_size=8)  # 4 simultaneous, 8 waiting
+demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
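A note on the model load: passing load_in_8bit=True directly to from_pretrained is deprecated in recent transformers releases in favour of a BitsAndBytesConfig, and bitsandbytes 8-bit quantization generally requires a CUDA GPU, so on a CPU-only Space this path may fail and an unquantized load is the practical fallback. A minimal sketch of the same load expressed through BitsAndBytesConfig, assuming bitsandbytes and a GPU are available:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # replaces the bare load_in_8bit kwarg
    device_map="auto",
    offload_folder="offload",    # same disk-offload folder the commit uses
    torch_dtype=torch.float16,
)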
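On the 30-second limit: the elapsed-time check runs only after model.generate() has returned, so a slow generation is never interrupted, only its result discarded. If the goal is to bound generation time itself, generate() accepts a max_time budget in seconds; a sketch reusing model, model_inputs and model_lock from predict():

with model_lock:
    with torch.no_grad():
        # max_time stops decoding once the budget is spent and returns
        # whatever tokens were produced so far; it does not raise.
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=256,
            max_time=30.0,
        )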
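The Ping button is wired with outputs=None, so the "pong" return value is never shown in the UI (a local client ping, as the comment suggests, would go through the API instead). If the intent is to surface the reply, a sketch of a visible wiring, with ping_box and pong_box as illustrative names:

with gr.Tab("Keep Alive"):
    ping_box = gr.Textbox(label="Ping", value="ping", interactive=False)
    pong_box = gr.Textbox(label="Response")
    gr.Button("Ping").click(keep_alive, inputs=ping_box, outputs=pong_box)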
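Finally, the queue settings are version-dependent: concurrency_count is a Gradio 3.x parameter and was removed in Gradio 4.x, where passing it raises an error. If the Space runs Gradio 4.x, the roughly equivalent call is:

# Gradio 4.x: default_concurrency_limit replaces concurrency_count
demo.queue(default_concurrency_limit=4, max_size=8)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)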