mindsearch / frontend /mindsearch_streamlit.py
solitudeLin's picture
Update frontend/mindsearch_streamlit.py
9a025c8 verified
import json
import tempfile
import requests
import streamlit as st
from lagent.schema import AgentStatusCode
from pyvis.network import Network
# Function to create the network graph
def create_network_graph(nodes, adjacency_list):
net = Network(height='500px',
width='60%',
bgcolor='white',
font_color='black')
for node_id, node_data in nodes.items():
if node_id in ['root', 'response']:
title = node_data.get('content', node_id)
else:
title = node_data['detail']['content']
net.add_node(node_id,
label=node_id,
title=title,
color='#FF5733',
size=25)
for node_id, neighbors in adjacency_list.items():
for neighbor in neighbors:
if neighbor['name'] in nodes:
net.add_edge(node_id, neighbor['name'])
net.show_buttons(filter_=['physics'])
return net
# Function to draw the graph and return the HTML file path
def draw_graph(net):
path = tempfile.mktemp(suffix='.html')
net.save_graph(path)
return path
def streaming(raw_response):
for chunk in raw_response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\n'):
if chunk:
decoded = chunk.decode('utf-8')
if decoded == '\r':
continue
if decoded[:6] == 'data: ':
decoded = decoded[6:]
elif decoded.startswith(': ping - '):
continue
response = json.loads(decoded)
yield (response['response'], response['current_node'])
# Initialize Streamlit session state
if 'queries' not in st.session_state:
st.session_state['queries'] = []
st.session_state['responses'] = []
st.session_state['graphs_html'] = []
st.session_state['nodes_list'] = []
st.session_state['adjacency_list_list'] = []
st.session_state['history'] = []
st.session_state['already_used_keys'] = list()
# Set up page layout
st.set_page_config(layout='wide')
st.title('MindSearch-思索')
# Function to update chat
def update_chat(query):
with st.chat_message('user'):
st.write(query)
if query not in st.session_state['queries']:
# Mock data to simulate backend response
# response, history, nodes, adjacency_list
st.session_state['queries'].append(query)
st.session_state['responses'].append([])
history = None
# 暂不支持多轮
message = [dict(role='user', content=query)]
url = 'http://localhost/solve'
headers = {'Content-Type': 'application/json'}
data = {'inputs': message}
raw_response = requests.post(url,
headers=headers,
data=json.dumps(data),
timeout=20,
stream=True)
for resp in streaming(raw_response):
agent_return, node_name = resp
if node_name and node_name in ['root', 'response']:
continue
nodes = agent_return['nodes']
adjacency_list = agent_return['adj']
response = agent_return['response']
history = agent_return['inner_steps']
if nodes:
net = create_network_graph(nodes, adjacency_list)
graph_html_path = draw_graph(net)
with open(graph_html_path, encoding='utf-8') as f:
graph_html = f.read()
else:
graph_html = None
if 'graph_placeholder' not in st.session_state:
st.session_state['graph_placeholder'] = st.empty()
if 'expander_placeholder' not in st.session_state:
st.session_state['expander_placeholder'] = st.empty()
if graph_html:
with st.session_state['expander_placeholder'].expander(
'Show Graph', expanded=False):
st.session_state['graph_placeholder']._html(graph_html,
height=500)
if 'container_placeholder' not in st.session_state:
st.session_state['container_placeholder'] = st.empty()
with st.session_state['container_placeholder'].container():
if 'columns_placeholder' not in st.session_state:
st.session_state['columns_placeholder'] = st.empty()
col1, col2 = st.session_state['columns_placeholder'].columns(
[2, 1])
with col1:
if 'planner_placeholder' not in st.session_state:
st.session_state['planner_placeholder'] = st.empty()
if 'session_info_temp' not in st.session_state:
st.session_state['session_info_temp'] = ''
if not node_name:
if agent_return['state'] in [
AgentStatusCode.STREAM_ING,
AgentStatusCode.ANSWER_ING
]:
st.session_state['session_info_temp'] = response
elif agent_return[
'state'] == AgentStatusCode.PLUGIN_START:
thought = st.session_state[
'session_info_temp'].split('```')[0]
if agent_return['response'].startswith('```'):
st.session_state[
'session_info_temp'] = thought + '\n' + response
elif agent_return[
'state'] == AgentStatusCode.PLUGIN_RETURN:
assert agent_return['inner_steps'][-1][
'role'] == 'environment'
st.session_state[
'session_info_temp'] += '\n' + agent_return[
'inner_steps'][-1]['content']
st.session_state['planner_placeholder'].markdown(
st.session_state['session_info_temp'])
if agent_return[
'state'] == AgentStatusCode.PLUGIN_RETURN:
st.session_state['responses'][-1].append(
st.session_state['session_info_temp'])
st.session_state['session_info_temp'] = ''
else:
st.session_state['planner_placeholder'].markdown(
st.session_state['responses'][-1][-1] if
not st.session_state['session_info_temp'] else st.
session_state['session_info_temp'])
with col2:
if 'selectbox_placeholder' not in st.session_state:
st.session_state['selectbox_placeholder'] = st.empty()
if 'searcher_placeholder' not in st.session_state:
st.session_state['searcher_placeholder'] = st.empty()
# st.session_state['searcher_placeholder'].markdown('')
if node_name:
selected_node_key = f"selected_node_{len(st.session_state['queries'])}_{node_name}"
if selected_node_key not in st.session_state:
st.session_state[selected_node_key] = node_name
if selected_node_key not in st.session_state[
'already_used_keys']:
selected_node = st.session_state[
'selectbox_placeholder'].selectbox(
'Select a node:',
list(nodes.keys()),
key=f'key_{selected_node_key}',
index=list(nodes.keys()).index(node_name))
st.session_state['already_used_keys'].append(
selected_node_key)
else:
selected_node = node_name
st.session_state[selected_node_key] = selected_node
if selected_node in nodes:
node = nodes[selected_node]
agent_return = node['detail']
node_info_key = f'{selected_node}_info'
if 'node_info_temp' not in st.session_state:
st.session_state[
'node_info_temp'] = f'### {agent_return["content"]}'
if node_info_key not in st.session_state:
st.session_state[node_info_key] = []
if agent_return['state'] in [
AgentStatusCode.STREAM_ING,
AgentStatusCode.ANSWER_ING
]:
st.session_state[
'node_info_temp'] = agent_return[
'response']
elif agent_return[
'state'] == AgentStatusCode.PLUGIN_START:
thought = st.session_state[
'node_info_temp'].split('```')[0]
if agent_return['response'].startswith('```'):
st.session_state[
'node_info_temp'] = thought + '\n' + agent_return[
'response']
elif agent_return[
'state'] == AgentStatusCode.PLUGIN_END:
thought = st.session_state[
'node_info_temp'].split('```')[0]
if isinstance(agent_return['response'], dict):
st.session_state[
'node_info_temp'] = thought + '\n' + f'```json\n{json.dumps(agent_return["response"], ensure_ascii=False, indent=4)}\n```' # noqa: E501
elif agent_return[
'state'] == AgentStatusCode.PLUGIN_RETURN:
assert agent_return['inner_steps'][-1][
'role'] == 'environment'
st.session_state[node_info_key].append(
('thought',
st.session_state['node_info_temp']))
st.session_state[node_info_key].append(
('observation',
agent_return['inner_steps'][-1]['content']
))
st.session_state['searcher_placeholder'].markdown(
st.session_state['node_info_temp'])
if agent_return['state'] == AgentStatusCode.END:
st.session_state[node_info_key].append(
('answer',
st.session_state['node_info_temp']))
st.session_state['node_info_temp'] = ''
if st.session_state['session_info_temp']:
st.session_state['responses'][-1].append(
st.session_state['session_info_temp'])
st.session_state['session_info_temp'] = ''
# st.session_state['responses'][-1] = '\n'.join(st.session_state['responses'][-1])
st.session_state['graphs_html'].append(graph_html)
st.session_state['nodes_list'].append(nodes)
st.session_state['adjacency_list_list'].append(adjacency_list)
st.session_state['history'] = history
def display_chat_history():
for i, query in enumerate(st.session_state['queries'][-1:]):
# with st.chat_message('assistant'):
if st.session_state['graphs_html'][i]:
with st.session_state['expander_placeholder'].expander(
'Show Graph', expanded=False):
st.session_state['graph_placeholder']._html(
st.session_state['graphs_html'][i], height=500)
with st.session_state['container_placeholder'].container():
col1, col2 = st.session_state['columns_placeholder'].columns(
[2, 1])
with col1:
st.session_state['planner_placeholder'].markdown(
st.session_state['responses'][-1][-1])
with col2:
selected_node_key = st.session_state['already_used_keys'][
-1]
st.session_state['selectbox_placeholder'] = st.empty()
selected_node = st.session_state[
'selectbox_placeholder'].selectbox(
'Select a node:',
list(st.session_state['nodes_list'][i].keys()),
key=f'replay_key_{i}',
index=list(st.session_state['nodes_list'][i].keys(
)).index(st.session_state[selected_node_key]))
st.session_state[selected_node_key] = selected_node
if selected_node not in [
'root', 'response'
] and selected_node in st.session_state['nodes_list'][i]:
node_info_key = f'{selected_node}_info'
for item in st.session_state[node_info_key]:
if item[0] in ['thought', 'answer']:
st.session_state[
'searcher_placeholder'] = st.empty()
st.session_state[
'searcher_placeholder'].markdown(item[1])
elif item[0] == 'observation':
st.session_state[
'observation_expander'] = st.empty()
with st.session_state[
'observation_expander'].expander(
'Results'):
st.write(item[1])
# st.session_state['searcher_placeholder'].markdown(st.session_state[node_info_key])
def clean_history():
st.session_state['queries'] = []
st.session_state['responses'] = []
st.session_state['graphs_html'] = []
st.session_state['nodes_list'] = []
st.session_state['adjacency_list_list'] = []
st.session_state['history'] = []
st.session_state['already_used_keys'] = list()
for k in st.session_state:
if k.endswith('placeholder') or k.endswith('_info'):
del st.session_state[k]
# Main function to run the Streamlit app
def main():
st.sidebar.title('Model Control')
col1, col2 = st.columns([4, 1])
with col1:
user_input = st.chat_input('Enter your query:')
with col2:
if st.button('Clear History'):
clean_history()
if user_input:
update_chat(user_input)
display_chat_history()
if __name__ == '__main__':
main()