lukecq committed on
Commit
e28f513
·
verified ·
1 Parent(s): fc4bf24

Update files to support multi-turn

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +82 -83
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.0.1
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
 
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.21.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
app.py CHANGED
@@ -11,14 +11,6 @@ from vllm import LLM, SamplingParams
11
  import vllm
12
  import re
13
 
14
- from huggingface_hub import login
15
- TOKEN = os.environ.get("TOKEN", None)
16
- login(token=TOKEN)
17
-
18
- print("transformers version:", transformers.__version__)
19
- print("vllm version:", vllm.__version__)
20
- print("gradio version:", gr.__version__)
21
-
22
 
23
  def load_model_processor(model_path):
24
  processor = AutoProcessor.from_pretrained(model_path)
@@ -32,24 +24,15 @@ def load_model_processor(model_path):
32
  model_path1 = "SeaLLMs/SeaLLMs-Audio-7B"
33
  model1, processor1 = load_model_processor(model_path1)
34
 
35
- def response_to_audio(audio_url, text, model=None, processor=None, temperature = 0,repetition_penalty=1.1, top_p = 0.9,max_new_tokens = 2048):
36
- if text == None:
37
- conversation = [
38
- {"role": "user", "content": [
39
- {"type": "audio", "audio_url": audio_url},
40
- ]},]
41
- elif audio_url == None:
42
- conversation = [
43
- {"role": "user", "content": [
44
- {"type": "text", "text": text},
45
- ]},]
46
- else:
47
- conversation = [
48
- {"role": "user", "content": [
49
- {"type": "audio", "audio_url": audio_url},
50
- {"type": "text", "text": text},
51
- ]},]
52
-
53
  text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
54
  audios = []
55
  for message in conversation:
@@ -76,45 +59,76 @@ def response_to_audio(audio_url, text, model=None, processor=None, temperature =
76
 
77
  output = model.generate([input], sampling_params=sampling_params)[0]
78
  response = output.outputs[0].text
 
 
79
  return response
80
 
81
- def clear_inputs():
82
- return None, "", ""
83
 
84
  def contains_chinese(text):
85
  # Regular expression for Chinese characters
86
  chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
87
  return bool(chinese_char_pattern.search(text))
88
 
89
- def compare_responses(audio_url, text):
90
- if contains_chinese(text):
91
- return "Caution! This demo does not support Chinese!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- response1 = response_to_audio(audio_url, text, model1, processor1)
94
- if contains_chinese(response1):
95
- return "ERROR! Try another example!"
96
-
97
- return response1
 
 
 
 
 
 
 
 
 
 
98
 
99
  with gr.Blocks() as demo:
100
- # gr.Markdown(f"Evaluate {model_path1}")
101
  gr.HTML("""<p align="center"><img src="https://DAMO-NLP-SG.github.io/SeaLLMs-Audio/static/images/seallm-audio-logo.png" style="height: 80px"/><p>""")
102
- # gr.Image("images/seal_logo.png", elem_id="seal_logo", show_label=False,height=80,show_fullscreen_button=False)
103
  gr.HTML("""<h1 align="center" id="space-title">SeaLLMs-Audio-Demo</h1>""")
104
- # gr.Markdown(
105
- # """\
106
- # <center><font size=4>This WebUI is based on SeaLLMs-Audio-7B, developed by Alibaba DAMO Academy.<br>
107
- # You can interact with the chatbot in <b>English, Chinese, Indonesian, Thai, or Vietnamese</b>.<br>
108
- # For the input, you can input <b>audio and/or text</center>.""")
109
-
110
- # # Links with proper formatting
111
- # gr.Markdown(
112
- # """<center><font size=4>
113
- # <a href="https://huggingface.co/SeaLLMs/SeaLLMs-v3-7B-Chat">[Website]</a> &nbsp;
114
- # <a href="https://huggingface.co/SeaLLMs/SeaLLMs-Audio-7B">[Model🤗]</a> &nbsp;
115
- # <a href="https://github.com/DAMO-NLP-SG/SeaLLMs-Audio">[Github]</a>
116
- # </center>""",
117
- # )
118
 
119
  gr.HTML(
120
  """<div style="text-align: center; font-size: 16px;">
@@ -141,42 +155,27 @@ with gr.Blocks() as demo:
141
  # with gr.Column():
142
  # repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.1, step=0.1, label="Repetition Penalty")
143
 
144
- with gr.Row():
145
- with gr.Column():
146
- # mic_input = gr.Microphone(label="Record Audio", type="filepath", elem_id="mic_input")
147
- mic_input = gr.Audio(sources = ['upload', 'microphone'], label="Record Audio", type="filepath", elem_id="mic_input")
148
- with gr.Column():
149
- additional_input = gr.Textbox(label="Text Input")
150
-
151
- # Button to trigger the function
152
- with gr.Row():
153
- btn_submit = gr.Button("Submit")
154
- btn_clear = gr.Button("Clear")
155
-
156
- with gr.Row():
157
- output_text1 = gr.Textbox(label=model_path1.split('/')[-1], interactive=False, elem_id="output_text1")
158
-
159
- btn_submit.click(
160
- fn=compare_responses,
161
- inputs=[mic_input, additional_input],
162
- outputs=[output_text1],
163
  )
164
 
165
- btn_clear.click(
166
- fn=clear_inputs,
167
- inputs=None,
168
- outputs=[mic_input, additional_input, output_text1],
169
- queue=False,
170
  )
 
 
 
171
 
 
172
 
173
- # demo.launch(
174
- # share=False,
175
- # inbrowser=True,
176
- # server_port=7950,
177
- # server_name="0.0.0.0",
178
- # max_threads=40
179
- # )
180
 
181
  demo.launch(share=True)
182
  demo.queue(default_concurrency_limit=40).launch(share=True)
 
11
  import vllm
12
  import re
13
 
 
 
 
 
 
 
 
 
14
 
15
  def load_model_processor(model_path):
16
  processor = AutoProcessor.from_pretrained(model_path)
 
24
  model_path1 = "SeaLLMs/SeaLLMs-Audio-7B"
25
  model1, processor1 = load_model_processor(model_path1)
26
 
27
+ def response_to_audio_conv(conversation, model=None, processor=None, temperature = 0.7,repetition_penalty=1.1, top_p = 0.5,max_new_tokens = 2048):
28
+ turn = conversation[-1]
29
+ if turn["role"] == "user":
30
+ for content in turn['content']:
31
+ if content["type"] == "text":
32
+ if contains_chinese(content["text"]):
33
+ return "Caution! This demo does not support Chinese!"
34
+
35
+
 
 
 
 
 
 
 
 
 
36
  text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
37
  audios = []
38
  for message in conversation:
 
59
 
60
  output = model.generate([input], sampling_params=sampling_params)[0]
61
  response = output.outputs[0].text
62
+ if contains_chinese(response):
63
+ return "ERROR! Try a different instruction/prompt!"
64
  return response
65
 
66
def print_like_dislike(x: gr.LikeData):
    """Log a chatbot like/dislike event (message index, value, liked flag) to stdout."""
    details = (x.index, x.value, x.liked)
    print(*details)
68
 
69
def contains_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph."""
    # Match anything in the basic CJK Unified Ideographs range (U+4E00..U+9FFF).
    return re.search(r'[\u4e00-\u9fff]', text) is not None
73
 
74
def add_message(history, message):
    """Append a MultimodalTextbox submission to the chat history.

    Audio files that already appear in the history (uploaded on an earlier
    turn) are skipped; each new file becomes its own user turn, followed by
    an optional text turn.

    Args:
        history: Chatbot history in "messages" format (list of role/content dicts).
        message: MultimodalTextbox value, a dict with "files" (list of paths)
            and "text" keys.

    Returns:
        The updated history and a cleared, non-interactive MultimodalTextbox
        (re-enabled by a later .then() callback once the bot reply finishes).
    """
    # Collect audio paths already present so re-submitted files are not duplicated.
    known_paths = []
    for turn in history:
        # Non-string content is a file turn; its first element is the path.
        if turn["role"] == "user" and not isinstance(turn["content"], str):
            known_paths.append(turn["content"][0])
    for path in message["files"]:
        if path not in known_paths:
            history.append({"role": "user", "content": {"path": path}})
    # NOTE(review): an empty string still creates a text turn here — only None
    # is filtered; confirm that is intended.
    if message["text"] is not None:
        history.append({"role": "user", "content": message["text"]})
    return history, gr.MultimodalTextbox(value=None, interactive=False)
85
+
86
def format_user_messgae(message):
    """Normalize a chat-history user turn into the model conversation format.

    NOTE(review): the identifier keeps the original (misspelled) name because
    sibling code in this file calls it by that name.

    Args:
        message: A history entry dict whose "content" is either a plain string
            (a text turn) or an indexable container whose first element is an
            audio file path (as produced by gr.Chatbot file turns).

    Returns:
        A dict with role "user" and a one-item structured content list of
        either {"type": "text", ...} or {"type": "audio", ...}.
    """
    content = message["content"]
    if isinstance(content, str):
        # Plain-text turn.
        return {"role": "user", "content": [{"type": "text", "text": content}]}
    # Anything else is treated as a file turn; first element is the audio path.
    return {"role": "user", "content": [{"type": "audio", "audio_url": content[0]}]}
91
+
92
def history_to_conversation(history):
    """Convert gr.Chatbot "messages" history into the model conversation format.

    Consecutive user items are merged into a single user turn carrying a
    multi-part content list. Audio files already included earlier in the
    history are dropped so each file appears at most once. Assistant turns
    pass through unchanged.
    """
    conversation = []
    seen_audio = []
    for entry in history:
        if entry["role"] != "user":
            conversation.append(entry)
            continue
        if not entry["content"]:
            # Skip empty user messages.
            continue
        normalized = format_user_messgae(entry)
        part = normalized["content"][0]
        if part["type"] == "audio":
            if part["audio_url"] in seen_audio:
                continue  # this file was already added on an earlier turn
            seen_audio.append(part["audio_url"])
        if conversation and conversation[-1]["role"] == "user":
            # Fold into the previous user turn instead of opening a new one.
            conversation[-1]["content"].append(part)
        else:
            conversation.append(normalized)

    # Debug trace of the conversation actually sent to the model.
    print(json.dumps(conversation, indent=4, ensure_ascii=False))
    return conversation
115
+
116
def bot(history: list, temperature=0.7, repetition_penalty=1.1, top_p=0.5,
        max_new_tokens=2048):
    """Generate the assistant reply and stream it into the chat history.

    Builds the model conversation from the Chatbot history, runs inference,
    then yields the history after each appended character to animate typing
    in the Chatbot component.
    """
    conversation = history_to_conversation(history)
    response = response_to_audio_conv(
        conversation,
        model=model1,
        processor=processor1,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    print("Bot:", response)

    # Stream the reply one character at a time into a fresh assistant turn.
    history.append({"role": "assistant", "content": ""})
    for ch in response:
        history[-1]["content"] += ch
        time.sleep(0.01)
        yield history
128
 
129
  with gr.Blocks() as demo:
 
130
  gr.HTML("""<p align="center"><img src="https://DAMO-NLP-SG.github.io/SeaLLMs-Audio/static/images/seallm-audio-logo.png" style="height: 80px"/><p>""")
 
131
  gr.HTML("""<h1 align="center" id="space-title">SeaLLMs-Audio-Demo</h1>""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  gr.HTML(
134
  """<div style="text-align: center; font-size: 16px;">
 
155
  # with gr.Column():
156
  # repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.1, step=0.1, label="Repetition Penalty")
157
 
158
+ chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages")
159
+
160
+ chat_input = gr.MultimodalTextbox(
161
+ interactive=True,
162
+ file_count="single",
163
+ file_types=['.wav'],
164
+ placeholder="Enter message (optional) ...",
165
+ show_label=False,
166
+ sources=["microphone", "upload"],
 
 
 
 
 
 
 
 
 
 
167
  )
168
 
169
+ chat_msg = chat_input.submit(
170
+ add_message, [chatbot, chat_input], [chatbot, chat_input]
 
 
 
171
  )
172
+ bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
173
+ # bot_msg = chat_msg.then(bot, [chatbot, temperature, repetition_penalty, top_p], chatbot, api_name="bot_response")
174
+ bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
175
 
176
+ # chatbot.like(print_like_dislike, None, None, like_user_message=True)
177
 
178
+ clear_button = gr.ClearButton([chatbot, chat_input])
 
 
 
 
 
 
179
 
180
  demo.launch(share=True)
181
  demo.queue(default_concurrency_limit=40).launch(share=True)