# medKGC / app_ui.py
# Author: hanbinChen — commit 1a1c17c ("update")
import streamlit as st
from streamlit_agraph import agraph, Node, Edge, Config
from streamlit_text_label import label_select
from app_logic import *
def display_entity_selections(selections):
    """Render one button per selected entity, spread across four columns.

    Each button is captioned "<text> (<first label>)" and shows the span's
    start/end offsets as a tooltip. Clicking a button stores that entity in
    ``st.session_state.selected_entity``.
    """
    per_row = 4  # number of side-by-side columns the buttons are distributed over
    st.subheader("Selected Entities:")
    columns = st.columns(per_row)
    for idx, sel in enumerate(selections):
        with columns[idx % per_row]:
            clicked = st.button(
                f"{sel.text} ({sel.labels[0]})",
                key=f"entity_{idx}",
                help=f"Start: {sel.start}, End: {sel.end}",
            )
            if clicked:
                st.session_state.selected_entity = sel
def create_graph(entities, relations):
    """Build and render a directed entity-relationship graph with agraph.

    One node is created per distinct entity text; the first entity seen with
    a given text decides the node's label and colour. A relation becomes an
    edge only when both its source and target text have a node.

    Returns:
        Whatever ``agraph`` returns for the rendered graph.
    """
    seen = {}  # entity text -> Node, in first-seen order
    for ent in entities:
        if ent.text in seen:
            continue
        label = ent.labels[0]
        seen[ent.text] = Node(
            id=ent.text,
            label=f"{ent.text}\n({label})",
            size=25,
            color=get_label_color(label),
        )
    nodes = list(seen.values())

    edges = [
        Edge(
            source=rel.source.text,
            target=rel.target.text,
            label=rel.label,
            color="#666666",
        )
        for rel in relations
        if rel.source.text in seen and rel.target.text in seen
    ]

    graph_config = Config(
        width=750,
        height=500,
        directed=True,
        physics=True,
        hierarchical=False,
        nodeHighlightBehavior=True,
        highlightColor="#F7A7A6",
    )
    return agraph(nodes=nodes, edges=edges, config=graph_config)
def setup_report_selection():
"""Setup report selection columns and return selected report"""
col1, col2 = st.columns(2)
with col1:
st.subheader("Reports to Review")
unreviewed_reports = [
report_id for report_id, content in st.session_state.reports_json.items()
if 'reviewed' not in content
]
selected_report = st.selectbox(
"Select Report",
unreviewed_reports,
key="unreviewed"
)
with col2:
st.subheader("Reviewed Reports")
reviewed_reports = [
report_id for report_id, content in st.session_state.reports_json.items()
if content.get('reviewed', False)
]
st.selectbox(
"Completed Reports",
reviewed_reports if reviewed_reports else ['None'],
key="reviewed"
)
return selected_report
def display_report_content(report_data):
    """Render a report's text as Markdown.

    ``report_data`` is either a dict holding the text under ``'text'`` or the
    text itself.
    """
    st.subheader("Report Content:")
    body = report_data['text'] if isinstance(report_data, dict) else report_data
    st.markdown(body)
def display_entities(report_text, entities):
    """Render the entity annotation widget followed by the clickable entity list.

    Args:
        report_text: Text shown in the ``label_select`` annotation widget.
        entities: Pre-existing selections. The label palette offered to the
            annotator is built from each entity's first label.

    Returns:
        The selections returned by ``label_select`` (the current annotations).
    """
    st.subheader("Entity Annotation:")
    # De-duplicate labels while keeping first-seen order. Iterating a set of
    # strings depends on hash randomization, so the palette order would
    # change between server restarts.
    labels = list(dict.fromkeys(e.labels[0] for e in entities))
    selections = label_select(
        body=report_text,
        labels=labels,
        selections=entities,
    )
    # Show the chosen entities as buttons.
    display_entity_selections(selections)
    return selections
def display_relationship_graph(entities: list[Selection], entities_data: dict):
    """Render the relationship graph for the given entities.

    Relations are looked up with ``find_relations_with_entities`` (not defined
    in this module; presumably from app_logic) and drawn by ``create_graph``.
    """
    st.subheader("Entity Relationship Graph:")
    create_graph(entities, find_relations_with_entities(entities, entities_data))
def handle_review_submission(selected_report, selections, entities_data,
                             file_path='mockedReports.json'):
    """Render the "Mark as Reviewed" button and persist the review when clicked.

    When the button is clicked, the function:

    - converts the current selections with ``selection2entities``;
    - copies any ``'relations'`` stored for the same entity id in
      ``entities_data``;
    - stores the result under ``reports_json[selected_report]['reviewed']``;
    - saves the whole ``reports_json`` to ``file_path`` via ``save_data``;
    - triggers a rerun.

    Args:
        selected_report: Key of the report being reviewed in
            ``st.session_state.reports_json``.
        selections: Current annotations (as produced by ``label_select``).
        entities_data: Existing entity records keyed by entity id. Presumably
            each may carry a ``'relations'`` entry; verify against app_logic.
        file_path: File the reports are written to.
    """
    message_key = "_review_saved_message"
    # st.rerun() discards everything rendered in the current run, so a
    # success message emitted just before it is never seen. Show the message
    # stashed by the previous run instead.
    if message_key in st.session_state:
        st.success(st.session_state[message_key])
        del st.session_state[message_key]

    if not st.button("Mark as Reviewed"):
        return

    updated_entities = selection2entities(selections)
    for entity_id, entity in updated_entities.items():
        existing = entities_data.get(entity_id)
        # Carry over previously recorded relations. Skip entities without a
        # 'relations' entry instead of crashing with a KeyError.
        if existing is not None and 'relations' in existing:
            entity['relations'] = existing['relations']

    st.session_state.reports_json[selected_report]['reviewed'] = {
        'entities': updated_entities
    }
    save_data(file_path, st.session_state.reports_json)
    st.session_state[message_key] = "Review status saved!"
    st.rerun()
def setup_input_selection():
    """Let the user choose between dataset reports and manually entered text.

    Returns:
        ``{"type": "dataset"}`` when "Select from Dataset" is chosen.
        ``{"type": "user_input", "text": <entered text>}`` on the run in which
        "Analyze Text" is clicked with non-blank text.
        ``None`` otherwise, including when the text area is blank; a warning
        is shown in that case.
    """
    st.subheader("Select Input Method")
    input_method = st.radio(
        "Select Input Method",
        ["Select from Dataset", "Manual Text Input"],
        key="input_method"
    )
    if input_method != "Manual Text Input":
        return {"type": "dataset"}

    user_text = st.text_area(
        "Please Input Radiology Report Text",
        height=200,
        placeholder="Enter report text here...",
        key="user_input_text"
    )
    if not st.button("Analyze Text"):
        return None
    # Analyzing an empty report is meaningless; ask for text instead.
    if not user_text or not user_text.strip():
        st.warning("Please enter report text before analyzing.")
        return None
    return {"type": "user_input", "text": user_text}