Krish45 committed (verified)
Commit 3dcc866 · 1 Parent(s): a74f64b

Update app.py

Files changed (1)
  1. app.py +17 -8
app.py CHANGED
@@ -9,11 +9,16 @@ model = AutoModelForCausalLM.from_pretrained(
     model_name, low_cpu_mem_usage=True, device_map="auto", torch_dtype="auto"
 )
 
-def predict(history):
+def predict(history, message):
     """
     history: list of [user, bot] message pairs from the Chatbot
+    message: new user input string
     """
-    # Convert history into the 'messages' format for chat template
+    # Add the latest user message to the conversation
+    history = history or []  # make sure it's a list
+    history.append((message, ""))
+
+    # Convert to messages format for Qwen
     messages = []
     for human, bot in history:
         if human:
@@ -21,24 +26,28 @@ def predict(history):
         if bot:
             messages.append({"role": "assistant", "content": bot})
 
+    # Apply chat template
     text = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
+
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
+    # Generate response
     generated_ids = model.generate(**model_inputs, max_new_tokens=512)
     generated_ids = [
         output_ids[len(input_ids):]
         for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
     ]
-
     reply = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    history.append((messages[-1]["content"] if messages else "", reply))
-    return history
 
-with gr.Blocks() as server:
+    # Update last message with bot reply
+    history[-1] = (message, reply)
+    return history, ""  # return history + clear textbox
+
+with gr.Blocks() as demo:
     chatbot = gr.Chatbot()
     msg = gr.Textbox(placeholder="Type your message here...")
-    msg.submit(predict, [chatbot], chatbot)
+    msg.submit(predict, [chatbot, msg], [chatbot, msg])
 
-server.launch(server_name="0.0.0.0", server_port=7860, share=False)
+demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
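
For reference, below is a minimal sketch of app.py as it should read after this commit. The imports, the tokenizer setup, and the user-role messages.append(...) line fall outside the diff hunks and are reconstructed here as assumptions; the model_name value is a placeholder, not the checkpoint actually used by this Space.

# Sketch of the updated app.py; lines outside the diff hunks are assumptions.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder: real checkpoint not shown in the diff
tokenizer = AutoTokenizer.from_pretrained(model_name)  # assumed; not part of the diff
model = AutoModelForCausalLM.from_pretrained(
    model_name, low_cpu_mem_usage=True, device_map="auto", torch_dtype="auto"
)

def predict(history, message):
    """
    history: list of [user, bot] message pairs from the Chatbot
    message: new user input string
    """
    # Add the latest user message to the conversation
    history = history or []  # make sure it's a list
    history.append((message, ""))

    # Convert to messages format for Qwen
    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})  # context line between hunks, reconstructed
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Generate a response, then drop the prompt tokens from each output
    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    reply = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Update last message with bot reply
    history[-1] = (message, reply)
    return history, ""  # return history + clear textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here...")
    msg.submit(predict, [chatbot, msg], [chatbot, msg])

demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

Wiring msg as both an input and an output of predict lets one handler receive the new message and clear the textbox in the same call: the returned "" is written back into msg after each submit.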