Krish45 committed
Commit a74f64b · verified · 1 Parent(s): c94e3f9

Update app.py

Files changed (1)
  1. app.py +21 -8
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 model_name = "Qwen/Qwen2.5-0.5B-Instruct"
 
@@ -8,24 +9,36 @@ model = AutoModelForCausalLM.from_pretrained(
     model_name, low_cpu_mem_usage=True, device_map="auto", torch_dtype="auto"
 )
 
-def predict(messages):
+def predict(history):
+    """
+    history: list of [user, bot] message pairs from the Chatbot
+    """
+    # Convert history into the 'messages' format for chat template
+    messages = []
+    for human, bot in history:
+        if human:
+            messages.append({"role": "user", "content": human})
+        if bot:
+            messages.append({"role": "assistant", "content": bot})
+
     text = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    logger.info(f"Model generation process started at - {process_id}")
     generated_ids = model.generate(**model_inputs, max_new_tokens=512)
     generated_ids = [
-        output_ids[len(input_ids) :]
+        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
     ]
-    logger.info(f"Model generation process completed [{process_id}]")
 
     reply = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return reply
+    history.append((messages[-1]["content"] if messages else "", reply))
+    return history
 
-iface = gr.Interface(fn=predict, inputs="messages", outputs="reply")
+with gr.Blocks() as server:
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(placeholder="Type your message here...")
+    msg.submit(predict, [chatbot], chatbot)
 
-# Launch with API access
-iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
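For context on the main change: the new predict rebuilds the role/content messages list from the Chatbot's [user, bot] pairs before templating. A minimal sketch of the shapes involved, using a fabricated history (the sample strings are illustrative, not from this repo):

# Illustration only: how the new predict() turns Chatbot pairs into
# chat-template messages. The sample history below is made up.
history = [
    ["Hello! Who are you?", "I'm a small Qwen2.5 assistant."],
    ["What's 2 + 2?", None],  # newest user turn, no reply yet
]

messages = []
for human, bot in history:
    if human:
        messages.append({"role": "user", "content": human})
    if bot:  # None/empty bot slots are skipped
        messages.append({"role": "assistant", "content": bot})

# tokenizer.apply_chat_template(messages, tokenize=False,
#                               add_generation_prompt=True)
# renders this into the model's chat markup and leaves an open
# assistant turn, so model.generate() continues as the bot; the
# output_ids[len(input_ids):] slice then drops the echoed prompt
# tokens so only the new reply is decoded.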
 
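One wiring detail worth flagging for a follow-up commit: msg.submit(predict, [chatbot], chatbot) passes only the Chatbot's existing history, so the text typed into msg never reaches predict, and the box is not cleared after submit. A minimal sketch of the usual Gradio pattern, assuming the committed predict stays as-is (user_then_bot is an illustrative name, not part of this commit):

# Hypothetical follow-up sketch, not part of this commit: route the
# Textbox value into the handler and clear it after each turn.
def user_then_bot(message, history):
    history = (history or []) + [[message, None]]  # new user turn, no reply yet
    return "", predict(history)  # "" clears the Textbox

with gr.Blocks() as server:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here...")
    msg.submit(user_then_bot, [msg, chatbot], [msg, chatbot])

Since predict appends a fresh (user, reply) pair rather than filling the None slot, the placeholder row would still show in the transcript; filling the last pair in place would be the tidier fix.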