# Spaces: Runtime error
import io
import os
import time
import wave
from xml.etree import ElementTree

import gradio as gr
import openai
import requests
from pydub import AudioSegment as am
# Azure OpenAI resource; the WALL·E (DALL·E) REST calls also build their URL from it.
api_base = "https://mvp-azureopenai.openai.azure.com/"
# None when OPENAI_API_KEY is not set in the environment.
api_key = os.getenv("OPENAI_API_KEY")

# Point the openai module at the Azure deployment.
openai.api_type, openai.api_base = "azure", api_base
openai.api_version, openai.api_key = "2023-03-15-preview", api_key

# Conversation state; module-level, so it is shared by every session of the app.
messages_gpt = []  # GPT Playground: prompts and completions, in order
messages_chat = [
    dict(role="system", content="You are an AI assistant that helps people find information."),
]
prompts = ""  # most recent GPT Playground prompt
response_walle = []  # raw JSON bodies returned by the WALL·E endpoints
messages_vchat = [
    dict(role="system", content="You are an AI assistant that helps people find information and just response with SSML."),
]
# ---------------------------------------------------------------------------
# Shared state read by the "Show OpenAI response" radios and by text-to-speech.
# Initialised up front so that picking a radio option before the first request
# shows an empty box instead of raising NameError.
# ---------------------------------------------------------------------------
gpt_x = None  # last raw Completion response (GPT Playground tab)
chat_x = None  # last raw ChatCompletion response (ChatGPT tab)
vchat_x = None  # last raw ChatCompletion response (VoiceChat tab)
vchat_reply = ""  # last VoiceChat reply text; text_to_speech() speaks it

SPEECH_REGION = "westus"
TTS_OUTPUT_FILE = "chatgpt.wav"
WALLE_API_VERSION = "2022-08-03-preview"
# Upper bound on status polls, so a stuck job cannot block the worker forever.
WALLE_MAX_POLLS = 60
# Terminal non-success statuses (assumed from Azure async-operation
# conventions -- TODO confirm for the DALL·E preview API). Unknown statuses
# are still bounded by WALLE_MAX_POLLS.
WALLE_FAILED_STATUSES = {"Failed", "Canceled", "Cancelled", "Deleted"}


# ------------------------------ shared helpers ------------------------------
def _slider_value(value, slider):
    """Return ``value``, or ``slider.value`` when ``value`` is None."""
    return slider.value if value is None else value


def _awaiting_reply(history):
    """Return True when the last chatbot turn has a user message but no reply."""
    return bool(history) and history[-1][1] is None


def _chat_completion(messages, temperature, max_tokens, top_p):
    """Send ``messages`` to the ChatGPT deployment and return the raw response."""
    return openai.ChatCompletion.create(
        engine="mvp-gpt-35-turbo",
        messages=messages,
        temperature=temperature,
        # Slider values may arrive as floats; the API expects an integer.
        max_tokens=int(max_tokens),
        top_p=top_p,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )


# ------------------------------ GPT Playground ------------------------------
def get_parameters_gpt(slider_1, slider_2, slider_3):
    """Store the GPT Playground slider values on the components and log them."""
    ui_temp_gpt.value = slider_1
    ui_max_tokens_gpt.value = slider_2
    ui_top_p_gpt.value = slider_3
    print("Log - Updated GPT parameters: Temperature=", ui_temp_gpt.value,
          " Max Tokens=", ui_max_tokens_gpt.value, " Top_P=", ui_top_p_gpt.value)


def select_response_gpt(radio):
    """Show either the last raw model response or the prompt history."""
    if radio == "Response from OpenAI Model":
        return gr.update(value=gpt_x)
    return gr.update(value=messages_gpt)


def user_gpt(user_message, history):
    """Record the prompt, clear the input box and add a pending chatbot turn."""
    global prompts
    prompts = user_message
    messages_gpt.append(prompts)
    return "", history + [[user_message, None]]


def bot_gpt(history, temperature=None, max_tokens=None, top_p=None):
    """Complete the latest prompt with the mvp-text-davinci-003 deployment.

    Args:
        history: Chatbot turns; the last one is filled in with the reply.
        temperature, max_tokens, top_p: Values from the GPT Playground
            sliders. Each one falls back to the slider's stored value if None.

    Returns:
        The updated history.
    """
    global gpt_x
    if not _awaiting_reply(history):
        return history
    gpt_x = openai.Completion.create(
        engine="mvp-text-davinci-003",
        prompt=prompts,
        # Previously hard-coded (0.6 / 1000 / 1), which made the sliders inert.
        temperature=_slider_value(temperature, ui_temp_gpt),
        max_tokens=int(_slider_value(max_tokens, ui_max_tokens_gpt)),
        top_p=_slider_value(top_p, ui_top_p_gpt),
        frequency_penalty=0,
        presence_penalty=0,
        best_of=1,
        stop=None,
    )
    gpt_reply = gpt_x.choices[0].text
    messages_gpt.append(gpt_reply)
    history[-1][1] = gpt_reply
    return history


# --------------------------------- ChatGPT ----------------------------------
def get_parameters_chat(slider_1, slider_2, slider_3):
    """Store the ChatGPT slider values on the components and log them."""
    ui_temp_chat.value = slider_1
    ui_max_tokens_chat.value = slider_2
    ui_top_p_chat.value = slider_3
    print("Log - Updated chatGPT parameters: Temperature=", ui_temp_chat.value,
          " Max Tokens=", ui_max_tokens_chat.value, " Top_P=", ui_top_p_chat.value)


def select_response_chat(radio):
    """Show either the last raw model response or the message history."""
    if radio == "Response from OpenAI Model":
        return gr.update(value=chat_x)
    return gr.update(value=messages_chat)


def user_chat(user_message, history):
    """Append the user message to the conversation and add a pending turn."""
    messages_chat.append({"role": "user", "content": user_message})
    return "", history + [[user_message, None]]


def bot_chat(history, temperature=None, max_tokens=None, top_p=None):
    """Ask ChatGPT for the next reply and fill in the pending chatbot turn.

    Sampling settings come from the ChatGPT sliders. Each one falls back to
    the slider's stored value if None.
    """
    global chat_x
    if not _awaiting_reply(history):
        return history
    chat_x = _chat_completion(
        messages_chat,
        _slider_value(temperature, ui_temp_chat),
        _slider_value(max_tokens, ui_max_tokens_chat),
        _slider_value(top_p, ui_top_p_chat),
    )
    print(chat_x)
    chat_reply = chat_x.choices[0].message.content
    messages_chat.append({"role": "assistant", "content": chat_reply})
    history[-1][1] = chat_reply
    return history


def reset_sys(sysmsg):
    """Restart the ChatGPT conversation with ``sysmsg`` as the system prompt."""
    global messages_chat
    messages_chat = [
        {"role": "system", "content": sysmsg},
    ]


# --------------------------------- WALL·E 2 ---------------------------------
def get_image_walle(prompt_walle):
    """Generate an image for ``prompt_walle`` with the DALL·E preview REST API.

    Submits the job, then polls the Operation-Location URL until the job
    succeeds, fails, or WALLE_MAX_POLLS is reached. Every JSON body is kept in
    ``response_walle`` for the detail view.

    Returns:
        ``gr.update`` with the image URL, or with None when no image was made.
    """
    global response_walle
    # Keep only the current request's responses; the list used to grow forever.
    response_walle = []
    url = "{}dalle/text-to-image?api-version={}".format(api_base, WALLE_API_VERSION)
    headers = {"api-key": api_key, "Content-Type": "application/json"}
    body = {
        "caption": prompt_walle,
        "resolution": "1024x1024",
    }
    submission = requests.post(url, headers=headers, json=body)
    submission_json = submission.json()
    response_walle.append(submission_json)
    print("Log - WALL·E status: {}".format(submission_json))

    operation_location = submission.headers.get("Operation-Location")
    if not operation_location:
        print("Log - WALL·E request was not accepted (no Operation-Location header).")
        return gr.update(value=None)
    # requests headers are case-insensitive; default to 1s if the header is absent.
    retry_after = int(submission.headers.get("Retry-After", 1))

    status = ""
    for _ in range(WALLE_MAX_POLLS):
        time.sleep(retry_after)
        result = requests.get(operation_location, headers=headers).json()
        response_walle.append(result)
        print("Log - WALL·E status: {}".format(result))
        status = result.get("status")
        if status == "Succeeded":
            return gr.update(value=result["result"]["contentUrl"])
        if status in WALLE_FAILED_STATUSES:
            break
    print("Log - WALL·E produced no image (last status: {}).".format(status))
    return gr.update(value=None)


def get_response_walle():
    """Return the raw JSON responses collected by the last WALL·E request."""
    return gr.update(value=response_walle)


# -------------------------------- VoiceChat ---------------------------------
def get_parameters_vchat(slider_1, slider_2, slider_3):
    """Store the VoiceChat slider values on the components and log them."""
    ui_temp_vchat.value = slider_1
    ui_max_tokens_vchat.value = slider_2
    ui_top_p_vchat.value = slider_3
    print("Log - Updated VoiceChat parameters: Temperature=", ui_temp_vchat.value,
          " Max Tokens=", ui_max_tokens_vchat.value, " Top_P=", ui_top_p_vchat.value)


def select_response_vchat(radio):
    """Show either the last raw model response or the message history."""
    if radio == "Response from OpenAI Model":
        return gr.update(value=vchat_x)
    return gr.update(value=messages_vchat)


def speech_to_text(voice_message):
    """Transcribe a recorded WAV file with the Azure Speech-to-Text REST API.

    Args:
        voice_message: Path to the recording (``gr.Audio(type="filepath")``),
            or None when the recording was cleared.

    Returns:
        The recognised text, or None if there is no recording or recognition
        fails.
    """
    if not voice_message:
        return None
    # Re-encode as 16 kHz, 16-bit, mono PCM to match the Content-Type below.
    # Keep it in memory rather than overwriting the uploaded file.
    audio = am.from_file(voice_message, format='wav')
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    wav_buffer = io.BytesIO()
    audio.export(wav_buffer, format='wav')

    base_url = "https://" + SPEECH_REGION + ".stt.speech.microsoft.com/"
    path = 'speech/recognition/conversation/cognitiveservices/v1'
    params = {
        'language': 'zh-CN',
        'format': 'detailed',
    }
    headers = {
        'Ocp-Apim-Subscription-Key': os.getenv("OASK_Speech"),
        'Content-Type': 'audio/wav; codecs=audio/pcm; samplerate=16000',
        'Accept': 'application/json;text/xml',
    }
    response = requests.post(base_url + path, params=params, headers=headers,
                             data=wav_buffer.getvalue())
    if response.status_code != 200:
        print("\nLog - Status code: " + str(response.status_code) + "\nSomething went wrong. Check your subscription key and headers.\n")
        print("Reason: " + str(response.reason) + "\n")
        return None
    rs = response.json()
    print(rs)
    # Presumably DisplayText is absent when nothing was recognised (e.g. silence);
    # treat that as "no result" rather than raising KeyError.
    return rs.get('DisplayText') or None


def text_to_speech():
    """Synthesise ``vchat_reply`` with Azure Text-to-Speech (zh-CN-XiaoxiaoNeural).

    Returns:
        ``gr.update`` that points the output player at the generated WAV file,
        or clears it when there is nothing to say or the request fails.
    """
    if not vchat_reply:
        return gr.update(value=None)
    constructed_url = "https://" + SPEECH_REGION + ".tts.speech.microsoft.com/cognitiveservices/v1"
    headers = {
        'Ocp-Apim-Subscription-Key': os.getenv("OASK_Speech"),
        'Content-Type': 'application/ssml+xml',
        'X-Microsoft-OutputFormat': 'riff-24khz-16bit-mono-pcm',
        'User-Agent': 'Voice ChatGPT',
    }
    xml_body = ElementTree.Element('speak', version='1.0')
    xml_body.set('{http://www.w3.org/XML/1998/namespace}lang', 'zh-cn')
    voice = ElementTree.SubElement(xml_body, 'voice')
    voice.set('{http://www.w3.org/XML/1998/namespace}lang', 'zh-cn')
    voice.set('name', 'zh-CN-XiaoxiaoNeural')
    # NOTE(review): the VoiceChat system prompt asks the model for SSML, but
    # the reply is inserted here as element text. ElementTree escapes any markup
    # as literal text -- confirm this is the intended behaviour.
    voice.text = vchat_reply
    body = ElementTree.tostring(xml_body)
    response = requests.post(constructed_url, headers=headers, data=body)
    if response.status_code != 200:
        print("\nStatus code: " + str(response.status_code) + "\nSomething went wrong. Check your subscription key and headers.\n")
        print("Reason: " + str(response.reason) + "\n")
        # Do not replay a stale file left over from an earlier request.
        return gr.update(value=None)
    with open(TTS_OUTPUT_FILE, 'wb') as audio:
        audio.write(response.content)
    print("\nStatus code: " + str(response.status_code) + "\nYour TTS is ready for playback.\n")
    return gr.update(value=TTS_OUTPUT_FILE, interactive=True)


def user_vchat(user_voice_message, history):
    """Transcribe the recording and add it as a pending chatbot turn.

    If there is no recording (it was cleared) or transcription fails, the
    history is returned unchanged, so bot_vchat and text_to_speech do nothing.
    """
    global vchat_reply
    # Reset so text_to_speech never re-speaks the previous reply.
    vchat_reply = ""
    user_message = speech_to_text(user_voice_message)
    if not user_message:
        return history
    messages_vchat.append({"role": "user", "content": user_message})
    return history + [[user_message, None]]


def bot_vchat(history, temperature=None, max_tokens=None, top_p=None):
    """Ask ChatGPT for the next VoiceChat reply, using the VoiceChat sliders."""
    global vchat_x, vchat_reply
    if not _awaiting_reply(history):
        return history
    vchat_x = _chat_completion(
        messages_vchat,
        _slider_value(temperature, ui_temp_vchat),
        _slider_value(max_tokens, ui_max_tokens_vchat),
        _slider_value(top_p, ui_top_p_vchat),
    )
    print(vchat_x)
    vchat_reply = vchat_x.choices[0].message.content
    messages_vchat.append({"role": "assistant", "content": vchat_reply})
    history[-1][1] = vchat_reply
    return history


def reset_sys_vchat(sysmsg):
    """Restart the VoiceChat conversation with ``sysmsg`` as the system prompt."""
    global messages_vchat
    messages_vchat = [
        {"role": "system", "content": sysmsg},
    ]


# ------------------------------- UI + wiring --------------------------------
with gr.Blocks() as page:
    with gr.Tabs():
        with gr.TabItem("GPT Playground"):
            ui_chatbot_gpt = gr.Chatbot(label="GPT Playground:")
            with gr.Row():
                with gr.Column(scale=0.9):
                    ui_prompt_gpt = gr.Textbox(placeholder="Please enter your prompt here.", show_label=False).style(container=False)
                with gr.Column(scale=0.1, min_width=100):
                    ui_clear_gpt = gr.Button("Clear Input")
            with gr.Accordion("Expand to config parameters:", open=False):
                gr.Markdown("Look at me...")
                with gr.Row():
                    ui_temp_gpt = gr.Slider(0.1, 1.0, 0.9, step=0.1, label="Temperature", interactive=True)
                    ui_max_tokens_gpt = gr.Slider(100, 4000, 1000, step=100, label="Max Tokens", interactive=True)
                    ui_top_p_gpt = gr.Slider(0.1, 1.0, 0.5, step=0.1, label="Top P", interactive=True)
            with gr.Accordion("Select radio button to see detail:", open=False):
                ui_res_radio_gpt = gr.Radio(["Response from OpenAI Model", "Prompt messages history"], label="Show OpenAI response:", interactive=True)
                ui_response_gpt = gr.TextArea(show_label=False, interactive=False).style(container=False)

            sliders_gpt = [ui_temp_gpt, ui_max_tokens_gpt, ui_top_p_gpt]
            ui_temp_gpt.change(get_parameters_gpt, sliders_gpt)
            ui_max_tokens_gpt.change(get_parameters_gpt, sliders_gpt)
            ui_top_p_gpt.change(get_parameters_gpt, sliders_gpt)
            # Slider values are passed as inputs so each request uses them.
            ui_prompt_gpt.submit(user_gpt, [ui_prompt_gpt, ui_chatbot_gpt], [ui_prompt_gpt, ui_chatbot_gpt], queue=False).then(
                bot_gpt, [ui_chatbot_gpt] + sliders_gpt, ui_chatbot_gpt
            )
            ui_clear_gpt.click(lambda: None, None, ui_chatbot_gpt, queue=False)
            ui_res_radio_gpt.change(select_response_gpt, ui_res_radio_gpt, ui_response_gpt)

        with gr.TabItem("ChatGPT"):
            ui_chatbot_chat = gr.Chatbot(label="ChatGPT:")
            with gr.Row():
                with gr.Column(scale=0.9):
                    ui_prompt_chat = gr.Textbox(placeholder="Please enter your prompt here.", show_label=False).style(container=False)
                with gr.Column(scale=0.1, min_width=100):
                    ui_clear_chat = gr.Button("Clear Chat")
            with gr.Blocks():
                with gr.Accordion("Expand to config parameters:", open=False):
                    gr.Markdown("Here is the default system prompt, you can change it to your own prompt.")
                    ui_prompt_sys = gr.Textbox(value="You are an AI assistant that helps people find information.", show_label=False, interactive=True).style(container=False)
                    with gr.Row():
                        ui_temp_chat = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature", interactive=True)
                        ui_max_tokens_chat = gr.Slider(100, 8000, 800, step=100, label="Max Tokens", interactive=True)
                        ui_top_p_chat = gr.Slider(0.05, 1.0, 0.9, step=0.1, label="Top P", interactive=True)
                with gr.Accordion("Select radio button to see detail:", open=False):
                    ui_res_radio_chat = gr.Radio(["Response from OpenAI Model", "Prompt messages history"], label="Show OpenAI response:", interactive=True)
                    ui_response_chat = gr.TextArea(show_label=False, interactive=False).style(container=False)

            sliders_chat = [ui_temp_chat, ui_max_tokens_chat, ui_top_p_chat]
            ui_res_radio_chat.change(select_response_chat, ui_res_radio_chat, ui_response_chat)
            ui_temp_chat.change(get_parameters_chat, sliders_chat)
            ui_max_tokens_chat.change(get_parameters_chat, sliders_chat)
            ui_top_p_chat.change(get_parameters_chat, sliders_chat)
            ui_prompt_sys.submit(reset_sys, ui_prompt_sys)
            ui_prompt_chat.submit(user_chat, [ui_prompt_chat, ui_chatbot_chat], [ui_prompt_chat, ui_chatbot_chat], queue=False).then(
                bot_chat, [ui_chatbot_chat] + sliders_chat, ui_chatbot_chat
            )
            ui_clear_chat.click(lambda: None, None, ui_chatbot_chat, queue=False).then(reset_sys, ui_prompt_sys)

        with gr.TabItem("WALL·E 2"):
            ui_prompt_walle = gr.Textbox(placeholder="Please enter your prompt here to generate image.", show_label=False).style(container=False)
            ui_image_walle = gr.Image()
            with gr.Accordion("Select radio button to see detail:", open=False):
                ui_response_walle = gr.TextArea(show_label=False, interactive=False).style(container=False)

            ui_prompt_walle.submit(get_image_walle, ui_prompt_walle, ui_image_walle, queue=False).then(
                get_response_walle, None, ui_response_walle
            )

        with gr.TabItem("VoiceChat"):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Expand to config parameters:", open=False):
                        ui_prompt_sys_vchat = gr.Textbox(value="You are an AI assistant that helps people find information and just response with SSML.", show_label=False, interactive=True).style(container=False)
                    ui_voice_inc_vchat = gr.Audio(source="microphone", type="filepath")
                    ui_voice_out_vchat = gr.Audio(value=None, type="filepath", interactive=False).style(container=False)
                    with gr.Accordion("Expand to config parameters:", open=False):
                        with gr.Row():
                            ui_temp_vchat = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature", interactive=True)
                            ui_max_tokens_vchat = gr.Slider(100, 8000, 800, step=100, label="Max Tokens", interactive=True)
                            ui_top_p_vchat = gr.Slider(0.05, 1.0, 0.9, step=0.1, label="Top P", interactive=True)
                with gr.Column():
                    ui_chatbot_vchat = gr.Chatbot(label="Voice to ChatGPT:")
                    with gr.Accordion("Select radio button to see detail:", open=False):
                        ui_res_radio_vchat = gr.Radio(["Response from OpenAI Model", "Prompt messages history"], label="Show OpenAI response:", interactive=True)
                        ui_response_vchat = gr.TextArea(show_label=False, interactive=False).style(container=False)

            # These used to re-register the ChatGPT-tab handlers by mistake.
            sliders_vchat = [ui_temp_vchat, ui_max_tokens_vchat, ui_top_p_vchat]
            ui_res_radio_vchat.change(select_response_vchat, ui_res_radio_vchat, ui_response_vchat)
            ui_temp_vchat.change(get_parameters_vchat, sliders_vchat)
            ui_max_tokens_vchat.change(get_parameters_vchat, sliders_vchat)
            ui_top_p_vchat.change(get_parameters_vchat, sliders_vchat)
            ui_prompt_sys_vchat.submit(reset_sys_vchat, ui_prompt_sys_vchat)
            ui_voice_inc_vchat.change(user_vchat, [ui_voice_inc_vchat, ui_chatbot_vchat], ui_chatbot_vchat, queue=False).then(
                bot_vchat, [ui_chatbot_vchat] + sliders_vchat, ui_chatbot_vchat, queue=False
            ).then(text_to_speech, None, ui_voice_out_vchat)
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module for tooling or tests does not block on launch().
if __name__ == "__main__":
    page.launch(share=False)