# Source: Hugging Face Spaces page (status at capture: Running)
import base64 | |
import gradio as gr | |
import json | |
import mimetypes | |
import os | |
import requests | |
import time | |
import modelscope_studio.components.antd as antd | |
import modelscope_studio.components.antdx as antdx | |
import modelscope_studio.components.base as ms | |
import modelscope_studio.components.pro as pro | |
from modelscope_studio.components.pro.chatbot import ( | |
ChatbotActionConfig, ChatbotBotConfig, ChatbotMarkdownConfig, | |
ChatbotPromptsConfig, ChatbotUserConfig, ChatbotWelcomeConfig) | |
# --- Runtime configuration, read from the environment at import time ---

# Required settings: a missing variable raises KeyError during import.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
# JSON object; `submit` reads its 'tokens_to_generate', 'temperature' and
# 'top_p' entries.
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])

# Optional settings: None when the variable is unset.
SYSTEM_PROMPT = os.getenv('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.getenv('MULTIMODAL')

# Optional per-role value copied into each API message's "name" field by
# `add_name_for_message`; a None entry means "do not add a name".
NAME_MAP = dict(
    system=os.getenv('SYSTEM_NAME'),
    user=os.getenv('USER_NAME'),
)

# Header text shown above assistant messages in the chat UI.
MODEL_NAME = 'MiniMax-M1'
def prompt_select(e: gr.EventData):
    """Return an update whose value is the selected prompt's description.

    The description is read from ``e._data["payload"][0]["value"]``.
    """
    selected_prompt = e._data["payload"][0]["value"]
    return gr.update(value=selected_prompt["description"])
def clear():
    """Return an update that resets the target component's value to None."""
    reset_update = gr.update(value=None)
    return reset_update
def retry(chatbot_value, e: gr.EventData):
    """Regenerate the reply at the message index carried by the event.

    The history is cut just before ``e._data["payload"][0]["index"]``, one
    loading-state update is emitted as a 3-tuple, and then every update
    produced by ``submit(None, <truncated history>)`` is re-yielded.
    """
    retry_index = e._data["payload"][0]["index"]
    truncated_history = chatbot_value[:retry_index]

    # First update: show the busy state while the reply is regenerated.
    yield (
        gr.update(loading=True),
        gr.update(value=truncated_history),
        gr.update(disabled=True),
    )

    # No new user message is added (sender_value=None); the truncated history
    # is simply re-sent.
    regeneration = submit(None, truncated_history)
    for update in regeneration:
        yield update
def cancel(chatbot_value):
    """Mark the in-flight reply as stopped and re-enable the input controls.

    Args:
        chatbot_value: Chat history (list of message dicts). The last message
            is updated in place. May be empty or None, in which case only the
            control-resetting updates are produced.

    Returns:
        A 3-tuple of updates: the chat history value, ``loading=False`` and
        ``disabled=False`` (the event wiring maps these to the chatbot, the
        sender and the clear button respectively).
    """
    # Guard against an empty/None history so the controls are always reset
    # instead of failing with IndexError/TypeError on chatbot_value[-1].
    if chatbot_value:
        last_message = chatbot_value[-1]
        last_message["loading"] = False
        last_message["status"] = "done"
        last_message["footer"] = "Chat completion paused"
    return gr.update(value=chatbot_value), gr.update(loading=False), gr.update(disabled=False)
def add_name_for_message(message):
    """Set ``message['name']`` in place when NAME_MAP has a value for its role.

    Roles that are absent from NAME_MAP, or mapped to None, leave the message
    untouched.
    """
    display_name = NAME_MAP.get(message['role'])
    if display_name is None:
        return
    message['name'] = display_name
def convert_content(content):
    """Convert one chat message's content into the API's content format.

    * ``str``   -> returned unchanged.
    * ``tuple`` -> a one-element list holding an ``image_url`` part built from
      ``content[0]`` (a file path).
    * mapping   -> a list of parts: a ``text`` part for the ``'text'`` key and
      one ``image_url`` part per path under the ``'files'`` key, in the
      mapping's iteration order. Other keys are ignored.
    """
    if isinstance(content, str):
        return content

    def image_part(path):
        # Images are sent as data URLs produced by encode_base64.
        return {
            'type': 'image_url',
            'image_url': {
                'url': encode_base64(path),
            },
        }

    if isinstance(content, tuple):
        return [image_part(content[0])]

    parts = []
    for field, payload in content.items():
        if field == 'text':
            parts.append({'type': 'text', 'text': payload})
        elif field == 'files':
            parts.extend(image_part(file_path) for file_path in payload)
    return parts
def encode_base64(path):
    """Read an image file and return it as a ``data:<mime>;base64,...`` URL.

    Args:
        path: Path of the file to encode. Its MIME type is guessed from the
            file name via ``mimetypes.guess_type``.

    Returns:
        A data-URL string containing the base64-encoded file bytes.

    Raises:
        gr.Error: If the MIME type cannot be determined or is not ``image/*``.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type is None for unknown/missing extensions; treat that as
    # "not an image" rather than crashing on None.startswith.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(
        guess_type,
        base64.b64encode(data).decode(),
    )
def format_history(history):
    """Translate the chatbot's message history into the API's message list.

    A system message is emitted first when SYSTEM_PROMPT is set. User entries
    have their content passed through ``convert_content``. Assistant entries
    are split into main content (the last ``"text"`` item) and reasoning
    content (the last ``"tool"`` item) when their content is a list; any other
    assistant content is used as the main content as-is. Entries with other
    roles are skipped.
    """
    api_messages = []
    if SYSTEM_PROMPT is not None:
        api_messages.append({'role': 'system', 'content': SYSTEM_PROMPT})

    for entry in history:
        role = entry["role"]

        if role == "user":
            api_messages.append({
                'role': 'user',
                'content': convert_content(entry["content"]),
            })
            continue

        if role != "assistant":
            continue

        raw_content = entry["content"]
        thinking_text, answer_text = "", ""
        if isinstance(raw_content, list):
            # The UI stores reasoning as a "tool" item and the answer as a
            # "text" item; later items of the same type win.
            for part in raw_content:
                part_type = part.get("type")
                if part_type == "tool":
                    thinking_text = part.get("content", "")
                elif part_type == "text":
                    answer_text = part.get("content", "")
        else:
            answer_text = raw_content

        api_messages.append({
            'role': 'assistant',
            'content': convert_content(answer_text),
            'reasoning_content': convert_content(thinking_text),
        })

    return api_messages
def _reasoning_elapsed(reasoning_start_time):
    """Return seconds since reasoning began, or 0.0 if none was ever received."""
    if reasoning_start_time is None:
        return 0.0
    return time.time() - reasoning_start_time


def submit(sender_value, chatbot_value):
    """Send the conversation to the model API and stream the reply into the UI.

    Args:
        sender_value: New user input to append to the history, or None to
            re-send the history unchanged (as ``retry`` does).
        chatbot_value: Chat history (list of message dicts). It is mutated in
            place: the user message (if any) and a new assistant message are
            appended, and the assistant message is filled in as chunks arrive.

    Yields:
        Dicts keyed by the ``sender``, ``clear_btn`` and ``chatbot`` components
        holding ``gr.update(...)`` values: first the busy state, then one
        chatbot update per streamed chunk, and finally the idle state.

    Raises:
        Exception: Any failure while requesting or parsing the stream is
            re-raised after the assistant message has been replaced with a
            failure notice and the controls re-enabled. This includes the
            ``gr.Error`` raised for a chunk without ``choices`` and HTTP error
            statuses.
    """
    if sender_value is not None:
        chatbot_value.append({
            "role": "user",
            "content": sender_value,
        })
    api_messages = format_history(chatbot_value)
    for message in api_messages:
        add_name_for_message(message)
    chatbot_value.append({
        "role": "assistant",
        "content": [],
        "loading": True,
        "status": "pending"
    })
    yield {
        sender: gr.update(value=None, loading=True),
        clear_btn: gr.update(disabled=True),
        chatbot: gr.update(value=chatbot_value)
    }
    try:
        payload = {
            'model': MODEL_VERSION,
            'messages': api_messages,
            'stream': True,
            'max_tokens': MODEL_CONTROL_DEFAULTS['tokens_to_generate'],
            'temperature': MODEL_CONTROL_DEFAULTS['temperature'],
            'top_p': MODEL_CONTROL_DEFAULTS['top_p'],
        }
        # The context manager releases the streaming connection on every exit
        # path, including when this generator is closed early (e.g. cancelled).
        with requests.post(
            API_URL,
            headers={
                'Content-Type': 'application/json',
                'Authorization': 'Bearer {}'.format(API_KEY),
            },
            data=json.dumps(payload),
            stream=True,
        ) as response:
            # Surface HTTP errors instead of finishing with an empty reply.
            response.raise_for_status()

            thought_done = False
            start_time = time.time()
            message_content = chatbot_value[-1]["content"]
            # Reasoning content (tool type)
            thinking_item = {
                "type": "tool",
                "content": "",
                "options": {
                    "title": "🤔 Thinking..."
                }
            }
            message_content.append(thinking_item)
            # Main content (text type)
            answer_item = {
                "type": "text",
                "content": "",
            }
            message_content.append(answer_item)

            reasoning_start_time = None
            reasoning_duration = None
            for row in response.iter_lines():
                if not row.startswith(b'data:'):
                    continue
                body = row[5:].strip()
                # End-of-stream sentinel used by OpenAI-style endpoints; it is
                # not JSON, so stop before trying to parse it.
                if body == b'[DONE]':
                    break
                chunk = json.loads(body)
                if 'choices' not in chunk:
                    raise gr.Error('request failed')
                choice = chunk['choices'][0]

                if 'delta' in choice:
                    # Incremental chunk: append to what has been shown so far.
                    delta = choice['delta']
                    reasoning_content = delta.get('reasoning_content', '')
                    content = delta.get('content', '')
                    chatbot_value[-1]["loading"] = False

                    # Handle reasoning content
                    if reasoning_content:
                        if reasoning_start_time is None:
                            reasoning_start_time = time.time()
                        thinking_item["content"] += reasoning_content

                    # Handle main content
                    if content:
                        answer_item["content"] += content
                        # The first answer token marks the end of reasoning.
                        if not thought_done:
                            thought_done = True
                            reasoning_duration = _reasoning_elapsed(reasoning_start_time)
                            thinking_item["options"] = {"title": f"End of Thought ({reasoning_duration:.2f}s)"}

                    yield {chatbot: gr.update(value=chatbot_value)}

                elif 'message' in choice:
                    # Complete message: replace the accumulated text wholesale.
                    message_data = choice['message']
                    reasoning_content = message_data.get('reasoning_content', '')
                    main_content = message_data.get('content', '')
                    thinking_item["content"] = reasoning_content
                    answer_item["content"] = main_content

                    if reasoning_content and main_content:
                        # Keep the duration measured during streaming, if any.
                        if reasoning_duration is None:
                            reasoning_duration = _reasoning_elapsed(reasoning_start_time)
                        thinking_item["options"] = {"title": f"End of Thought ({reasoning_duration:.2f}s)"}

                    chatbot_value[-1]["loading"] = False
                    yield {chatbot: gr.update(value=chatbot_value)}

        chatbot_value[-1]["footer"] = "{:.2f}s".format(time.time() - start_time)
        chatbot_value[-1]["status"] = "done"
        yield {
            clear_btn: gr.update(disabled=False),
            sender: gr.update(loading=False),
            chatbot: gr.update(value=chatbot_value),
        }
    except Exception:
        # Show a failure notice and unlock the controls, then let the error
        # propagate with its original traceback.
        chatbot_value[-1]["loading"] = False
        chatbot_value[-1]["status"] = "done"
        chatbot_value[-1]["content"] = "Request failed, please try again."
        yield {
            clear_btn: gr.update(disabled=False),
            sender: gr.update(loading=False),
            chatbot: gr.update(value=chatbot_value),
        }
        raise
# Build the chat UI: a chatbot panel above a text sender, with the event
# handlers defined earlier in this file wired to the components.
# NOTE(review): the nesting below is reconstructed because the original
# indentation was lost — confirm against the deployed file.
with gr.Blocks() as demo, ms.Application(), antdx.XProvider():
    with antd.Flex(vertical=True, gap="middle"):
        chatbot = pro.Chatbot(
            height="calc(100vh - 200px)",
            markdown_config=ChatbotMarkdownConfig(allow_tags=["think"]),
            # Welcome screen with clickable example prompts.
            welcome_config=ChatbotWelcomeConfig(
                variant="borderless",
                icon="./assets/minimax-logo.png",
                title="Hello, I'm MiniMax-M1",
                description="You can input text to get started.",
                prompts=ChatbotPromptsConfig(
                    title="How can I help you today?",
                    styles={
                        "list": {
                            "width": '100%',
                        },
                        "item": {
                            "flex": 1,
                        },
                    },
                    items=[{
                        "label": "🤔 Logical Reasoning",
                        "children": [{
                            "description": "A is taller than B, B is shorter than C. Who is taller, A or C?"
                        }, {
                            "description": "Alice put candy in the drawer and went out. Bob moved the candy to the cabinet. Where will Alice look for the candy when she returns?"
                        }]
                    }, {
                        "label": "📚 Knowledge Q&A",
                        "children": [{
                            "description": "Can you tell me about middle school mathematics?"
                        }, {
                            "description": "If Earth's gravity suddenly halved, what would happen to the height humans can jump?"
                        }]
                    }])),
            user_config=ChatbotUserConfig(actions=["copy", "edit"]),
            bot_config=ChatbotBotConfig(
                header=MODEL_NAME,
                avatar="./assets/minimax-logo.png",
                actions=["copy", "retry"]
            ),
        )
        # Text input; the clear button sits in its "prefix" slot.
        with antdx.Sender() as sender:
            with ms.Slot("prefix"):
                with antd.Button(value=None, color="default", variant="text") as clear_btn:
                    with ms.Slot("icon"):
                        antd.Icon("ClearOutlined")

    # Clear button: reset the chatbot value to None.
    clear_btn.click(fn=clear, outputs=[chatbot])
    # Submitting streams a reply; the event handle is kept so that cancel can
    # stop it.
    submit_event = sender.submit(
        fn=submit,
        inputs=[sender, chatbot],
        outputs=[sender, chatbot, clear_btn]
    )
    # Cancelling stops the running submit event and runs `cancel`.
    sender.cancel(
        fn=cancel,
        inputs=[chatbot],
        outputs=[chatbot, sender, clear_btn],
        cancels=[submit_event],
        queue=False
    )
    # The "retry" action on a bot message regenerates the reply via `retry`.
    chatbot.retry(
        fn=retry,
        inputs=[chatbot],
        outputs=[sender, chatbot, clear_btn]
    )
    # Selecting a welcome prompt puts its description into the sender.
    chatbot.welcome_prompt_select(fn=prompt_select, outputs=[sender])

# Script entry point: enable the queue and launch the app.
if __name__ == '__main__':
    demo.queue(default_concurrency_limit=50).launch(share=True)