hanbinChen committed on
Commit
a06ac21
·
1 Parent(s): 9985fd7

Enhance report input handling and UI updates

Browse files

- Added a new input selection method to allow users to choose between dataset selection and manual text input.
- Updated the report content display function to handle both dictionary and string formats for report data.
- Integrated user input processing for radiology reports, including error handling for empty inputs.
- Adjusted the main application flow to accommodate the new input method and ensure proper entity display and review submission.

This commit improves user experience by providing more flexible input options and enhances the overall functionality of the application.

Files changed (4) hide show
  1. app.py +31 -13
  2. app_ui.py +28 -1
  3. ie.py +189 -0
  4. train.json +0 -0
app.py CHANGED
@@ -5,7 +5,8 @@ from app_ui import (
5
  display_report_content,
6
  display_entities,
7
  display_relationship_graph,
8
- handle_review_submission
 
9
  )
10
 
11
  def initialize_session_state():
@@ -29,15 +30,28 @@ def main():
29
  # Initialize session state
30
  initialize_session_state()
31
 
32
- # Setup report selection
33
- selected_report = setup_report_selection()
34
-
35
- if selected_report:
36
- report_data = st.session_state.reports_json[selected_report]
37
- entities_data = report_data['entities']
38
-
39
- # Setup entity annotation
40
- selections_og = entities2Selection(report_data['text'], entities_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # 创建两列布局
43
  col1, col2 = st.columns([2, 2]) # 调整列宽比例
@@ -53,10 +67,14 @@ def main():
53
 
54
  # Display entities
55
  with col2:
56
- selections = display_entities(report_data['text'], selections_og)
 
 
 
57
 
58
- # Handle review submission
59
- handle_review_submission(selected_report, selections, entities_data)
 
60
 
61
  if __name__ == "__main__":
62
  main()
 
5
  display_report_content,
6
  display_entities,
7
  display_relationship_graph,
8
+ handle_review_submission,
9
+ setup_input_selection
10
  )
11
 
12
  def initialize_session_state():
 
30
  # Initialize session state
31
  initialize_session_state()
32
 
33
+ # 获取输入方式
34
+ input_selection = setup_input_selection()
35
+
36
+ if input_selection:
37
+ if input_selection["type"] == "dataset":
38
+ # 原有的数据集选择逻辑
39
+ selected_report = setup_report_selection()
40
+ if selected_report:
41
+ report_data = st.session_state.reports_json[selected_report]
42
+ entities_data = report_data['entities']
43
+ selections_og = entities2Selection(report_data['text'], entities_data)
44
+ else:
45
+ # 处理用户输入的文本
46
+ user_text = input_selection["text"]
47
+ if user_text:
48
+ from ie import process_sample
49
+ entities_data = process_sample(user_text, num_shots=5)
50
+ report_data = user_text
51
+ selections_og = entities2Selection(user_text, entities_data)
52
+ else:
53
+ st.warning("请输入文本内容")
54
+ return
55
 
56
  # 创建两列布局
57
  col1, col2 = st.columns([2, 2]) # 调整列宽比例
 
67
 
68
  # Display entities
69
  with col2:
70
+ selections = display_entities(
71
+ report_data['text'] if isinstance(report_data, dict) else report_data,
72
+ selections_og
73
+ )
74
 
75
+ # 仅对数据集的报告显示提交按钮
76
+ if input_selection["type"] == "dataset":
77
+ handle_review_submission(selected_report, selections, entities_data)
78
 
79
  if __name__ == "__main__":
80
  main()
app_ui.py CHANGED
@@ -96,7 +96,10 @@ def setup_report_selection():
96
  def display_report_content(report_data):
97
  """Display the report text content"""
98
  st.subheader("Report Content:")
99
- st.markdown(report_data['text'])
 
 
 
100
 
101
 
102
  def display_entities(report_text, entities):
@@ -137,3 +140,27 @@ def handle_review_submission(selected_report, selections, entities_data):
137
  save_data(st.session_state.reports_json)
138
  st.success("Review status saved!")
139
  st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def display_report_content(report_data):
97
  """Display the report text content"""
98
  st.subheader("Report Content:")
99
+ if isinstance(report_data, dict):
100
+ st.markdown(report_data['text'])
101
+ else:
102
+ st.markdown(report_data)
103
 
104
 
105
  def display_entities(report_text, entities):
 
140
  save_data(st.session_state.reports_json)
141
  st.success("Review status saved!")
142
  st.rerun()
143
+
144
+
145
def setup_input_selection():
    """Render the input-method chooser and report what the user picked.

    Returns:
        ``{"type": "dataset"}`` when the dataset option is selected,
        ``{"type": "user_input", "text": <text area content>}`` when manual
        input is selected and the analyse button was pressed on this run,
        otherwise ``None``.
    """
    st.subheader("选择输入方式")
    chosen_method = st.radio(
        "请选择输入方式",
        ["从数据集选择", "手动输入文本"],
        key="input_method",
    )

    # Dataset mode needs no further widgets.
    if chosen_method != "手动输入文本":
        return {"type": "dataset"}

    typed_text = st.text_area(
        "请输入放射学报告文本",
        height=200,
        placeholder="在此输入报告文本...",
        key="user_input_text",
    )
    # NOTE(review): the button result is used directly, so a later rerun
    # without a fresh click presumably yields None again — confirm that the
    # caller is fine with the analysis view disappearing on such reruns.
    analyse_clicked = st.button("分析文本")
    if analyse_clicked:
        return {"type": "user_input", "text": typed_text}
    return None
ie.py CHANGED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 医学放射学报告命名实体识别(NER)评估脚本
3
+
4
+ 该脚本用于从放射学报告中提取临床术语实体,包括解剖部位(ANAT)和观察结果(OBS)。
5
+ 主要功能:
6
+ 1. 使用LLM模型进行实体抽取
7
+ 2. 支持few-shot学习,可以从训练数据中加载示例
8
+ 3. 包含错误重试机制
9
+ 4. 结果输出格式化处理
10
+
11
+ 实体标签说明:
12
+ - OBS-DP: 明确存在的观察结果
13
+ - ANAT-DP: 明确存在的解剖部位
14
+ - OBS-U: 不确定的观察结果
15
+ - OBS-DA: 明确不存在的观察结果
16
+
17
+ 输出格式:
18
+ {
19
+ "1": {"tokens": "...", "label": "..."},
20
+ "2": {"tokens": "...", "label": "..."},
21
+ ...
22
+ }
23
+ """
24
+
25
+ import json
26
+ import requests
27
+ import pprint
28
+ import time # 添加time模块用于重试机制
29
+ from typing import Dict, List, Any # 添加类型提示支持
30
+
31
+
32
+ # Prepare prompt with instructions
33
+ system_prompt = '''You are a radiologist performing clinical term extraction from the FINDINGS and IMPRESSION sections
34
+ in the radiology report. Here a clinical term can be either anatomy or observation that is related to a finding or an impression. The anatomy term refers to an anatomical body part such as a 'lung'. The observation terms refer to observations made when referring to the associated radiology image. Observations are associated with visual features, identifiable pathophysiologic processes, or diagnostic disease classifications. For example, an observation could be 'effusion' or description phrases like 'increased'. You also need to assign a label to indicate whether the clinical term is present, absent or uncertain. The labels are:
35
+ - OBS-DP: Observation definitely present
36
+ - ANAT-DP: Anatomy definitely present
37
+ - OBS-U: Observation uncertain
38
+ - OBS-DA: Observation definitely absent
39
+
40
+ Given a piece of radiology text input in the format:
41
+
42
+ <INPUT>
43
+ <text>
44
+ </INPUT>
45
+
46
+ reply with the following structure:
47
+
48
+ <OUTPUT>
49
+ ANSWER: tuples separated by newlines. Each tuple has the format: (<clinical term text>, <label: observation-present |observation-absent|observation-uncertain|anatomy-present>). If there are no extraction related to findings or impression, return ()
50
+ </OUTPUT>
51
+ '''
52
+
53
+ # Create a session object for reuse
54
+ session = requests.Session()
55
+
56
+
57
def get_question_prompt(text):
    """Build the user prompt for one report.

    The report text is wrapped in an ``<INPUT>`` envelope, followed by the
    extraction question, and the prompt ends on an open ``<OUTPUT> ANSWER:``
    line for the model to complete.
    """
    question = (
        "What are the clinical terms and their labels in this text? "
        "Discard sections other than FINDINGS and IMPRESSION: eg. INDICATION, "
        "HISTORY, TECHNIQUE, COMPARISON sections. "
        "If there is no extraction from findings and impression, return (). "
        "Please only output the tuples without additional notes or explanations."
    )
    prompt_lines = [
        "<INPUT>",
        f"{text}",
        "</INPUT>",
        "",
        question,
        "",
        "<OUTPUT> ANSWER:",
        "",  # keeps the trailing newline after "ANSWER:"
    ]
    return "\n".join(prompt_lines)
66
+
67
+
68
def _parse_entity_tuples(content):
    """Turn the model's reply text into the entity mapping.

    Each line that contains parentheses is parsed as a Python literal and
    kept only when it is a 2-tuple ``(term, label)``. Lines that are not
    valid literals, or not 2-tuples (for example the empty answer ``()``),
    are skipped.

    Args:
        content: Raw reply text, one candidate tuple per line.

    Returns:
        dict: ``{"<line number>": {"tokens": term, "label": label}}``. Keys
        are the 1-based position of the line within the reply, so they may
        contain gaps when other lines are skipped.
    """
    import ast  # local import: only this parser needs it

    parsed = {}
    for line_no, raw_line in enumerate(content.split('\n'), start=1):
        candidate = raw_line.strip()
        if '(' not in candidate or ')' not in candidate:
            continue
        try:
            # literal_eval only accepts literals, so model output can never
            # execute code here (unlike eval); unquoted or otherwise
            # malformed tuples raise and are skipped.
            value = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError):
            continue
        # Reject e.g. ("ab"), which is a plain string and would otherwise
        # be unpacked character by character.
        if not isinstance(value, tuple) or len(value) != 2:
            continue
        term, label = value
        parsed[str(line_no)] = {"tokens": term, "label": label}
    return parsed


def extract_entities(text, num_shots=5, timeout=300):
    """Ask the chat endpoint to extract labelled clinical terms from ``text``.

    Builds the conversation from the system prompt, ``num_shots`` few-shot
    examples and the question for ``text``, posts it to the chat API and
    parses the reply.

    Args:
        text: Report text to analyse.
        num_shots: Number of few-shot examples to prepend.
        timeout: Seconds to wait for the HTTP request before giving up, so
            an unresponsive server cannot block the caller indefinitely.

    Returns:
        dict: ``{"<n>": {"tokens": ..., "label": ...}}`` as produced by
        ``_parse_entity_tuples``.

    Raises:
        requests.RequestException: On connection errors, timeouts or HTTP
            error statuses.
        KeyError: If the response JSON lacks ``message``/``content``.
    """
    url = 'http://ollama.corinth.informatik.rwth-aachen.de/api/chat'

    messages = [{'role': 'system', 'content': system_prompt}]
    messages = create_messages_with_shots(num_shots, messages)
    messages.append({'role': 'user', 'content': get_question_prompt(text)})

    payload = {
        "model": "llama3.1:8b",
        "messages": messages,
        "stream": False,
    }

    response = session.post(url, json=payload, timeout=timeout)
    response.raise_for_status()  # surface HTTP errors to the caller
    reply = response.json()['message']['content']

    return _parse_entity_tuples(reply)
104
+
105
+
106
def create_messages_with_shots(num_shots, messages):
    """Append few-shot example turns from ``train.json`` to ``messages``.

    For each of the first ``num_shots`` entries in ``train.json`` a user
    turn (the question prompt for the example text) and an assistant turn
    (the expected tuples followed by ``</OUTPUT>``) are appended.

    Args:
        num_shots (int): Number of examples to add; values <= 0 add none.
        messages (list): Existing message list; it is extended in place.

    Returns:
        list: The same ``messages`` list, now including the examples.
    """
    from itertools import islice

    # Explicit encoding so the result does not depend on the platform default.
    with open('train.json', 'r', encoding='utf-8') as f:
        train_data = json.load(f)

    for example in islice(train_data.values(), max(num_shots, 0)):
        # repr() quoting keeps every line a valid Python tuple literal even
        # when a token contains an apostrophe or a backslash.
        answer_parts = [
            f"({entity_info['tokens']!r}, {entity_info['label']!r})"
            for entity_info in example['entities'].values()
        ]
        # An example without entities answers "()", the form the system
        # prompt asks for when nothing is extracted.
        answer = '\n'.join(answer_parts) if answer_parts else '()'
        answer += '\n</OUTPUT>'

        messages.append({
            'role': 'user',
            'content': get_question_prompt(example['text']),
        })
        messages.append({
            'role': 'assistant',
            'content': answer,
        })

    return messages
155
+
156
+
157
def main():
    """Run entity extraction on the first report in ``dev.json`` and print it.

    Prints a short notice and returns when ``dev.json`` holds no reports,
    instead of failing on an unbound result.
    """
    # Load test data
    with open('dev.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Only the first example is processed.
    first_report = next(iter(data.values()), None)
    if first_report is None:
        print('dev.json contains no reports')
        return

    entities = extract_entities(first_report['text'], num_shots=50)
    print(entities)
170
+
171
+
172
+ if __name__ == '__main__':
173
+ main()
174
+
175
+
176
def process_sample(text: str, num_shots: int, max_retries: int = 3) -> dict:
    """Extract entities for one sample, retrying when extraction fails.

    Args:
        text: Report text to analyse.
        num_shots: Number of few-shot examples passed to ``extract_entities``.
        max_retries: Total number of attempts before giving up.

    Returns:
        dict: The entity mapping returned by ``extract_entities``.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: Whatever the final failed attempt raised, with its
            original traceback.
    """
    if max_retries < 1:
        raise ValueError('max_retries must be at least 1')

    for attempt in range(1, max_retries + 1):
        try:
            return extract_entities(text, num_shots)
        except Exception as e:
            if attempt == max_retries:
                # Out of attempts: bare raise keeps the original traceback.
                raise
            print(f'第{attempt}次尝试失败: {e}')
train.json ADDED
The diff for this file is too large to render. See raw diff