# NOTE: "Spaces: Sleeping" — page-status text captured together with this source; it is not part of the program.
import json
import os
import tempfile

import requests
import streamlit as st
from lagent.schema import AgentStatusCode
from pyvis.network import Network
# Function to create the network graph
def create_network_graph(nodes, adjacency_list):
    """Build a pyvis ``Network`` that visualises the search graph.

    Args:
        nodes: mapping of node id -> node data.  The 'root' and 'response'
            nodes use their optional 'content' entry as hover title (falling
            back to the id); every other node must provide
            ``node_data['detail']['content']``.
        adjacency_list: mapping of node id -> list of neighbour dicts, each
            carrying a 'name' key.

    Returns:
        The configured ``pyvis.network.Network`` (not yet written to disk).
    """
    net = Network(height='500px',
                  width='60%',
                  bgcolor='white',
                  font_color='black')
    for node_id, node_data in nodes.items():
        if node_id in ('root', 'response'):
            title = node_data.get('content', node_id)
        else:
            title = node_data['detail']['content']
        net.add_node(node_id,
                     label=node_id,
                     title=title,
                     color='#FF5733',
                     size=25)
    for node_id, neighbors in adjacency_list.items():
        # pyvis' add_edge asserts that BOTH endpoints were added as nodes, so
        # skip edges whose source or target is absent from this snapshot.
        if node_id not in nodes:
            continue
        for neighbor in neighbors:
            if neighbor['name'] in nodes:
                net.add_edge(node_id, neighbor['name'])
    net.show_buttons(filter_=['physics'])
    return net
# Function to draw the graph and return the HTML file path
def draw_graph(net):
    """Write *net* to a new temporary ``.html`` file and return its path.

    The file is created with ``delete=False`` so it survives this call; the
    caller owns it and is responsible for removing it.

    Args:
        net: any object exposing ``save_graph(path)`` (e.g. a pyvis Network).

    Returns:
        str: path of the written HTML file.
    """
    # NamedTemporaryFile creates the file atomically; the deprecated
    # tempfile.mktemp only picks a name, leaving a race before it is used.
    with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as tmp:
        path = tmp.name
    net.save_graph(path)
    return path
def streaming(raw_response):
    """Yield ``(response, current_node)`` pairs from a server-sent-events stream.

    Args:
        raw_response: a streamed ``requests.Response`` (or any object with a
            compatible ``iter_lines``).

    Yields:
        tuple: the decoded payload's ``'response'`` and ``'current_node'``.

    Blank lines and SSE comment lines (lines starting with ``':'``, such as
    the ``': ping - ...'`` keep-alives) are skipped.  A ``'data:'`` field
    prefix is removed before the remainder is JSON-decoded.

    Raises:
        json.JSONDecodeError: if a remaining line is not valid JSON.
        KeyError: if a payload lacks 'response' or 'current_node'.
    """
    for chunk in raw_response.iter_lines(chunk_size=8192,
                                         decode_unicode=False,
                                         delimiter=b'\n'):
        # Splitting on b'\n' leaves the '\r' of CRLF line endings attached.
        line = chunk.decode('utf-8').rstrip('\r')
        if not line or line.startswith(':'):
            continue
        if line.startswith('data:'):
            # The space after 'data:' is optional in SSE; json.loads
            # tolerates the leading whitespace either way.
            line = line[len('data:'):]
        payload = json.loads(line)
        yield payload['response'], payload['current_node']
# Seed the per-session storage exactly once: Streamlit re-executes this script
# on every interaction, and 'queries' acts as the "already initialised" marker.
if 'queries' not in st.session_state:
    st.session_state.update({
        state_key: []
        for state_key in ('queries', 'responses', 'graphs_html',
                          'nodes_list', 'adjacency_list_list', 'history',
                          'already_used_keys')
    })

# Page chrome: wide layout plus the app title.
st.set_page_config(layout='wide')
st.title('MindSearch-思索')
# Function to update chat
def update_chat(query):
    """Send *query* to the MindSearch backend and stream the answer into the UI.

    The query is always echoed in a user chat bubble.  If the exact same query
    was already asked in this session nothing else happens.  Otherwise it is
    POSTed to the local ``/solve`` endpoint, and every streamed event refreshes
    the search-graph view, the planner column (left) and the searcher-node
    column (right).  When the stream ends, the query's graph HTML, nodes,
    adjacency list and the last ``inner_steps`` seen are stored in
    ``st.session_state`` so ``display_chat_history`` can replay them.

    Raises:
        requests.HTTPError: if the backend answers with a non-2xx status.
    """
    with st.chat_message('user'):
        st.write(query)
    state = st.session_state
    if query in state['queries']:
        return
    state['queries'].append(query)
    state['responses'].append([])
    # Defaults so the bookkeeping after the loop never hits an unbound name
    # when the stream carries no renderable event.
    nodes, adjacency_list, graph_html, history = {}, {}, None, None
    # Multi-turn conversation is not supported yet: only this query is sent.
    message = [dict(role='user', content=query)]
    url = 'http://localhost:8002/solve'
    headers = {'Content-Type': 'application/json'}
    data = {'inputs': message}
    # The context manager releases the streamed connection even if rendering
    # raises part-way through the stream.
    with requests.post(url,
                       headers=headers,
                       data=json.dumps(data),
                       timeout=20,
                       stream=True) as raw_response:
        raw_response.raise_for_status()
        for agent_return, node_name in streaming(raw_response):
            # Events whose current node is 'root' or 'response' are not drawn.
            if node_name in ('root', 'response'):
                continue
            nodes = agent_return['nodes']
            adjacency_list = agent_return['adj']
            history = agent_return['inner_steps']
            graph_html = _render_graph_html(nodes, adjacency_list)
            _show_graph(graph_html)
            _show_columns(agent_return, node_name, nodes)
    # Commit planner text that never reached a PLUGIN_RETURN event.
    if state.get('session_info_temp'):
        state['responses'][-1].append(state['session_info_temp'])
        state['session_info_temp'] = ''
    state['graphs_html'].append(graph_html)
    state['nodes_list'].append(nodes)
    state['adjacency_list_list'].append(adjacency_list)
    state['history'] = history


def _render_graph_html(nodes, adjacency_list):
    """Return the pyvis HTML for the graph, or None when *nodes* is empty."""
    if not nodes:
        return None
    path = draw_graph(create_network_graph(nodes, adjacency_list))
    try:
        with open(path, encoding='utf-8') as f:
            return f.read()
    finally:
        # draw_graph creates a fresh temp file per streamed event; drop it
        # once read so long searches do not pile up files.
        os.remove(path)


def _show_graph(graph_html):
    """Create the graph placeholders on first use and draw *graph_html* if set."""
    state = st.session_state
    if 'graph_placeholder' not in state:
        state['graph_placeholder'] = st.empty()
    if 'expander_placeholder' not in state:
        state['expander_placeholder'] = st.empty()
    if graph_html:
        with state['expander_placeholder'].expander('Show Graph',
                                                     expanded=False):
            # NOTE: `_html` is a private Streamlit API.
            state['graph_placeholder']._html(graph_html, height=500)


def _show_columns(agent_return, node_name, nodes):
    """Lay out the planner (2/3) and searcher (1/3) columns and refresh both."""
    state = st.session_state
    if 'container_placeholder' not in state:
        state['container_placeholder'] = st.empty()
    with state['container_placeholder'].container():
        if 'columns_placeholder' not in state:
            state['columns_placeholder'] = st.empty()
        col1, col2 = state['columns_placeholder'].columns([2, 1])
        with col1:
            _update_planner(agent_return, node_name)
        with col2:
            _update_searcher(node_name, nodes)


def _update_planner(agent_return, node_name):
    """Accumulate and show the planner's streamed text for one event.

    Planner events have no ``node_name``.  Their text is buffered in
    ``session_info_temp`` and committed to the current query's response list
    on PLUGIN_RETURN.  While a searcher node is streaming, the buffered text
    (or else the last committed text) is kept on screen.
    """
    state = st.session_state
    if 'planner_placeholder' not in state:
        state['planner_placeholder'] = st.empty()
    if 'session_info_temp' not in state:
        state['session_info_temp'] = ''
    if node_name:
        committed = state['responses'][-1]
        state['planner_placeholder'].markdown(
            state['session_info_temp'] or (committed[-1] if committed else ''))
        return
    response = agent_return['response']
    status = agent_return['state']
    if status in (AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING):
        state['session_info_temp'] = response
    elif status == AgentStatusCode.PLUGIN_START:
        # Keep the prose before the first code fence and append the new code.
        thought = state['session_info_temp'].split('```')[0]
        if response.startswith('```'):
            state['session_info_temp'] = thought + '\n' + response
    elif status == AgentStatusCode.PLUGIN_RETURN:
        assert agent_return['inner_steps'][-1]['role'] == 'environment'
        state['session_info_temp'] += '\n' + agent_return['inner_steps'][-1][
            'content']
    state['planner_placeholder'].markdown(state['session_info_temp'])
    if status == AgentStatusCode.PLUGIN_RETURN:
        state['responses'][-1].append(state['session_info_temp'])
        state['session_info_temp'] = ''


def _update_searcher(node_name, nodes):
    """Show the node selector and the streaming text of the current node.

    Per-node history is recorded in ``st.session_state[f'{node}_info']`` as
    ``('thought' | 'observation' | 'answer', text)`` tuples for later replay.
    """
    state = st.session_state
    if 'selectbox_placeholder' not in state:
        state['selectbox_placeholder'] = st.empty()
    if 'searcher_placeholder' not in state:
        state['searcher_placeholder'] = st.empty()
    if not node_name:
        return
    selected_node_key = f"selected_node_{len(state['queries'])}_{node_name}"
    if selected_node_key not in state:
        state[selected_node_key] = node_name
    if selected_node_key not in state['already_used_keys']:
        # The selector is created once per (query, node) pair.
        node_names = list(nodes.keys())
        selected_node = state['selectbox_placeholder'].selectbox(
            'Select a node:',
            node_names,
            key=f'key_{selected_node_key}',
            index=node_names.index(node_name))
        state['already_used_keys'].append(selected_node_key)
    else:
        selected_node = node_name
    state[selected_node_key] = selected_node
    if selected_node not in nodes:
        return
    detail = nodes[selected_node]['detail']
    node_info_key = f'{selected_node}_info'
    if 'node_info_temp' not in state:
        state['node_info_temp'] = f'### {detail["content"]}'
    if node_info_key not in state:
        state[node_info_key] = []
    status = detail['state']
    if status in (AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING):
        state['node_info_temp'] = detail['response']
    elif status == AgentStatusCode.PLUGIN_START:
        thought = state['node_info_temp'].split('```')[0]
        if detail['response'].startswith('```'):
            state['node_info_temp'] = thought + '\n' + detail['response']
    elif status == AgentStatusCode.PLUGIN_END:
        thought = state['node_info_temp'].split('```')[0]
        if isinstance(detail['response'], dict):
            state['node_info_temp'] = thought + '\n' + f'```json\n{json.dumps(detail["response"], ensure_ascii=False, indent=4)}\n```'  # noqa: E501
    elif status == AgentStatusCode.PLUGIN_RETURN:
        assert detail['inner_steps'][-1]['role'] == 'environment'
        state[node_info_key].append(('thought', state['node_info_temp']))
        state[node_info_key].append(
            ('observation', detail['inner_steps'][-1]['content']))
    state['searcher_placeholder'].markdown(state['node_info_temp'])
    if status == AgentStatusCode.END:
        state[node_info_key].append(('answer', state['node_info_temp']))
        state['node_info_temp'] = ''
def display_chat_history():
    """Replay the most recent completed query from ``st.session_state``.

    Streamlit re-executes the whole script on every interaction, so the graph,
    planner text and node details that ``update_chat`` streamed into its
    placeholders are redrawn here from the stored state.  Only the latest
    completed query is shown.  Does nothing until a query has completed
    (or after ``clean_history``).
    """
    state = st.session_state
    required = ('graph_placeholder', 'expander_placeholder',
                'container_placeholder', 'columns_placeholder',
                'planner_placeholder')
    # graphs_html/nodes_list are appended once per completed query and the
    # placeholders are created while streaming; bail out if either is missing.
    if not state['graphs_html'] or any(k not in state for k in required):
        return
    # Index the newest entry explicitly; enumerating a one-element slice
    # always yields 0, which would replay the FIRST query's graph and nodes.
    latest = len(state['graphs_html']) - 1
    graph_html = state['graphs_html'][latest]
    nodes = state['nodes_list'][latest] or {}
    if graph_html:
        with state['expander_placeholder'].expander('Show Graph',
                                                     expanded=False):
            # NOTE: `_html` is a private Streamlit API.
            state['graph_placeholder']._html(graph_html, height=500)
    with state['container_placeholder'].container():
        col1, col2 = state['columns_placeholder'].columns([2, 1])
        with col1:
            planner_steps = state['responses'][-1] if state['responses'] else []
            if planner_steps:
                state['planner_placeholder'].markdown(planner_steps[-1])
        with col2:
            _replay_node_details(nodes, widget_key=f'replay_key_{latest}')


def _replay_node_details(nodes, widget_key):
    """Redraw the node selector and the stored transcript of the chosen node."""
    state = st.session_state
    # The last key recorded by update_chat names the most recently shown node.
    if not state['already_used_keys']:
        return
    selected_node_key = state['already_used_keys'][-1]
    node_names = list(nodes.keys())
    previous = state.get(selected_node_key)
    default_index = node_names.index(previous) if previous in node_names else 0
    state['selectbox_placeholder'] = st.empty()
    selected_node = state['selectbox_placeholder'].selectbox(
        'Select a node:', node_names, key=widget_key, index=default_index)
    state[selected_node_key] = selected_node
    if selected_node in ('root', 'response') or selected_node not in nodes:
        return
    for kind, text in state.get(f'{selected_node}_info', []):
        if kind in ('thought', 'answer'):
            state['searcher_placeholder'] = st.empty()
            state['searcher_placeholder'].markdown(text)
        elif kind == 'observation':
            state['observation_expander'] = st.empty()
            with state['observation_expander'].expander('Results'):
                st.write(text)
def clean_history():
    """Reset the conversation stored in ``st.session_state``.

    Empties the per-query lists, then removes the saved placeholders
    (keys ending in 'placeholder'), the per-node transcripts (keys ending in
    '_info') and the transient streaming buffers (keys ending in '_temp',
    i.e. 'session_info_temp' / 'node_info_temp').  The next query then starts
    from a fresh layout and cannot pick up half-streamed text from an
    aborted run.
    """
    for key in ('queries', 'responses', 'graphs_html', 'nodes_list',
                'adjacency_list_list', 'history', 'already_used_keys'):
        st.session_state[key] = []
    # Snapshot the keys first: deleting from a mapping while iterating over
    # it is only safe if the iterator happens to work on a copy.
    for key in list(st.session_state):
        if key.endswith(('placeholder', '_info', '_temp')):
            del st.session_state[key]
# Entry point of the Streamlit app.
def main():
    """Draw the page controls, handle a submitted query, then replay the latest turn."""
    st.sidebar.title('Model Control')
    input_col, button_col = st.columns([4, 1])
    with input_col:
        submitted = st.chat_input('Enter your query:')
    with button_col:
        wants_reset = st.button('Clear History')
    if wants_reset:
        clean_history()
    if submitted:
        update_chat(submitted)
    display_chat_history()


if __name__ == '__main__':
    main()