Kevin Wu committed on
Commit
9d23c0f
·
1 Parent(s): 8b47fae
Files changed (2) hide show
  1. app.py +71 -105
  2. structures.py +73 -0
app.py CHANGED
@@ -4,117 +4,46 @@ import os
4
  import time
5
  import gradio as gr
6
  from openai import OpenAI
7
- import xml.etree.ElementTree as ET
8
- import re
9
- import pandas as pd
10
  import prompts
11
  import traceback
12
  from io import StringIO
 
 
 
 
 
 
 
13
 
14
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
15
 
16
  model_name = "gpt-4o-2024-08-06"
17
-
18
  try:
19
  demo = client.beta.assistants.create(
20
  name="Information Extractor",
21
- instructions="Extract information from this note.",
22
  model=model_name,
23
  tools=[{"type": "file_search"}],
24
  )
 
25
  except Exception as e:
26
  print(f"Error creating assistant: {str(e)}")
27
  raise
28
 
29
- def parse_xml_response(xml_string: str) -> pd.DataFrame:
30
- """
31
- Parse the XML response from the model and extract all fields into a dictionary,
32
- then convert it to a pandas DataFrame with a nested index.
33
- """
34
- try:
35
- # Extract only the XML content between the outermost tags
36
- xml_content = re.findall(r'<[^>]+>.*?</[^>]+>', xml_string, re.DOTALL)
37
- if not xml_content:
38
- print("No valid XML content found.")
39
- return pd.DataFrame()
40
-
41
- # Wrap the content in a root element to ensure there's only one root
42
- xml_string = f"<root>{''.join(xml_content)}</root>"
43
-
44
- # Parse the XML
45
- root = ET.fromstring(xml_string)
46
-
47
- result = {}
48
-
49
- for element in root:
50
- tag = element.tag
51
- if tag in ['patient_name', 'date_of_birth', 'sex', 'weight', 'date_of_death']:
52
- result[tag] = {
53
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
54
- **{child.tag: child.text.strip() if child.text else None
55
- for child in element if child.tag != 'reasoning'}
56
- }
57
- elif tag in ['traditional_chemo', 'other_cancer_treatments', 'other_conmeds']:
58
- if tag not in result:
59
- result[tag] = []
60
- reasoning = element.find('reasoning')
61
- for item in element:
62
- if item.tag in ['drug', 'treatment', 'medication']:
63
- date_element = element.find('date')
64
- result[tag].append({
65
- 'reasoning': reasoning.text.strip() if reasoning is not None else None,
66
- 'name': item.text.strip() if item.text else None,
67
- 'date': date_element.text.strip() if date_element is not None and date_element.text else None
68
- })
69
- elif tag in ['surgery', 'surgery_outcome', 'metastasis_at_time_of_diagnosis']:
70
- result[tag] = {
71
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
72
- **{child.tag: child.text.strip() if child.text else None
73
- for child in element if child.tag != 'reasoning'}
74
- }
75
- elif tag == 'compounding_pharmacy':
76
- result[tag] = {
77
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
78
- 'pharmacy': element.find('pharmacy').text.strip() if element.find('pharmacy') is not None else None
79
- }
80
- elif tag == 'adverse_effects':
81
- if tag not in result:
82
- result[tag] = []
83
- effect = {
84
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None
85
- }
86
- for child in element:
87
- if child.tag != 'reasoning':
88
- effect[child.tag] = child.text.strip() if child.text else None
89
- if effect:
90
- result[tag].append(effect)
91
-
92
- # Convert to nested DataFrame
93
- df_data = {}
94
- for key, value in result.items():
95
- if isinstance(value, dict):
96
- for sub_key, sub_value in value.items():
97
- df_data[(key, '1', sub_key)] = [sub_value]
98
- elif isinstance(value, list):
99
- for i, item in enumerate(value):
100
- for sub_key, sub_value in item.items():
101
- df_data[(key, f"{i+1}", sub_key)] = [sub_value]
102
- else:
103
- df_data[(key, '1', '')] = [value]
104
 
105
- # Create multi-index DataFrame
106
- df = pd.DataFrame(df_data)
107
- df.columns = pd.MultiIndex.from_tuples(df.columns)
108
-
109
- return df
110
- except ET.ParseError as e:
111
- print(f"XML parsing error: {str(e)}")
112
- print(f"Problematic XML content: {xml_string[:500]}...") # Print first 500 chars of XML
113
- return pd.DataFrame()
114
- except Exception as e:
115
- print(f"Error in parse_xml_response: {str(e)}")
116
- print(f"Traceback: {traceback.format_exc()}")
117
- return pd.DataFrame()
118
 
119
  def get_response(file_id, assistant_id, max_retries=3):
120
  for attempt in range(max_retries):
@@ -126,16 +55,24 @@ def get_response(file_id, assistant_id, max_retries=3):
126
  "content": prompts.info_prompt,
127
  "attachments": [
128
  {"file_id": file_id, "tools": [{"type": "file_search"}]}
129
- ],
130
- }
131
  ]
132
  )
133
- run = client.beta.threads.runs.create_and_poll(
134
- thread_id=thread.id, assistant_id=assistant_id
 
 
 
135
  )
 
 
 
 
 
136
  messages = list(
137
  client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
138
  )
 
139
  assert len(messages) == 1, f"Expected 1 message, got {len(messages)}"
140
  message_content = messages[0].content[0].text
141
  annotations = message_content.annotations
@@ -150,6 +87,36 @@ def get_response(file_id, assistant_id, max_retries=3):
150
  time.sleep(5)
151
  else:
152
  raise Exception("Max retries reached. Unable to get response from the model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def process(file_content):
155
  try:
@@ -162,18 +129,18 @@ def process(file_content):
162
  message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants")
163
 
164
  response = get_response(message_file.id, demo.id) # This now includes retry logic
165
- df = parse_xml_response(response)
 
 
166
 
167
  if df.empty:
168
  return "<p>No valid information could be extracted from the provided file.</p>"
169
 
170
- # Transpose the DataFrame
171
- df_transposed = df.T.reset_index()
172
- df_transposed.columns = ['Category', 'Index', 'Field', 'Value']
173
- df_transposed = df_transposed.sort_values(['Category', 'Index', 'Field'])
174
 
175
  # Convert to HTML with some basic styling
176
- html = df_transposed.to_html(index=False, classes='table table-striped table-bordered', escape=False)
177
 
178
  # Add some custom CSS for better readability
179
  html = f"""
@@ -227,8 +194,7 @@ def gradio_interface():
227
  def run_in_terminal():
228
  print("Clinical Note Information Extractor")
229
  print("This tool extracts key information from clinical notes in PDF format.")
230
- print("Enter the path to your PDF file:")
231
- file_path = input().strip()
232
 
233
  if not os.path.exists(file_path):
234
  print(f"Error: File not found at {file_path}")
 
4
  import time
5
  import gradio as gr
6
  from openai import OpenAI
 
 
 
7
  import prompts
8
  import traceback
9
  from io import StringIO
10
+ import pandas as pd
11
+ from typing import Dict, Any
12
+
13
+ from typing import List, Optional
14
+ from pydantic import BaseModel, Field
15
+ from structures import ClinicalInfo
16
+
17
 
18
  # Shared OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
19
 
20
  # Model identifier used for the assistant created below.
  model_name = "gpt-4o-2024-08-06"
21
+ # import pdb; pdb.set_trace()
22
  # Create the "Information Extractor" assistant with the file_search tool enabled.
  # Any failure is printed and then re-raised, so it is not swallowed here.
  try:
23
  demo = client.beta.assistants.create(
24
  name="Information Extractor",
25
+ instructions="Extract information from this note and return it as a JSON object.",
26
  model=model_name,
27
  tools=[{"type": "file_search"}],
28
  )
29
+
30
  except Exception as e:
31
  print(f"Error creating assistant: {str(e)}")
32
  raise
33
 
34
def parse_response(prompt):
    """Convert free-form text into a ``ClinicalInfo`` dictionary.

    Sends ``prompt`` as a single user message to
    ``client.beta.chat.completions.parse`` with ``response_format=ClinicalInfo``
    and returns the parsed model as a plain dictionary.

    Args:
        prompt: The text to be structured.

    Returns:
        The ``model_dump()`` of the parsed ``ClinicalInfo`` instance.

    Raises:
        ValueError: If the completion carries no parsed payload (for example
            because the model refused), instead of failing later with an
            opaque ``AttributeError`` on ``None``.
    """
    chat_completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model=model_name,
        response_format=ClinicalInfo,
    )
    message = chat_completion.choices[0].message
    parsed = message.parsed
    if parsed is None:
        # `parsed` is None when the model did not return schema-conforming
        # content; surface the refusal text if the SDK provides one.
        refusal = getattr(message, "refusal", None)
        raise ValueError(
            f"Model returned no structured ClinicalInfo (refusal: {refusal!r})"
        )
    return parsed.model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def get_response(file_id, assistant_id, max_retries=3):
49
  for attempt in range(max_retries):
 
55
  "content": prompts.info_prompt,
56
  "attachments": [
57
  {"file_id": file_id, "tools": [{"type": "file_search"}]}
58
+ ],}
 
59
  ]
60
  )
61
+ # import pdb; pdb.set_trace()
62
+ run = client.beta.threads.runs.create(
63
+ thread_id=thread.id,
64
+ assistant_id=assistant_id,
65
+ instructions="Please provide your response as a valid JSON object.",
66
  )
67
+ run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
68
+ while run.status != "completed":
69
+ time.sleep(1)
70
+ run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
71
+
72
  messages = list(
73
  client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
74
  )
75
+
76
  assert len(messages) == 1, f"Expected 1 message, got {len(messages)}"
77
  message_content = messages[0].content[0].text
78
  annotations = message_content.annotations
 
87
  time.sleep(5)
88
  else:
89
  raise Exception("Max retries reached. Unable to get response from the model.")
90
+
91
def clinical_info_to_dataframe(clinical_info: Dict[str, Any]) -> pd.DataFrame:
    """Flatten a ClinicalInfo dictionary into a long-format DataFrame.

    Each output row has three string columns: ``Category``, ``Field`` and
    ``Value``.

    * A dict value contributes one row per key, with ``Category`` equal to the
      top-level field name.
    * A list value contributes rows per item, with ``Category`` equal to
      ``"<field>_<1-based index>"``. Dict items are expanded per key; any other
      item is emitted as a single row with ``Field == "value"``.
    * ``None`` and any other scalar contribute a single row with
      ``Field == "value"`` (previously non-``None`` scalars were silently
      dropped).

    Args:
        clinical_info: Mapping of field name to a nested dict, a list, or a
            scalar/``None``.

    Returns:
        A DataFrame with columns ``Category``, ``Field``, ``Value``. The
        columns are present even when there are no rows.
    """
    rows = []

    def add_row(category: str, field_name: str, value: Any) -> None:
        # Values are stringified so the table renders uniformly (None -> 'None').
        rows.append({
            'Category': category,
            'Field': field_name,
            'Value': str(value),
        })

    for field, value in clinical_info.items():
        if isinstance(value, dict):
            for sub_field, sub_value in value.items():
                add_row(field, sub_field, sub_value)
        elif isinstance(value, list):
            for position, item in enumerate(value, start=1):
                category = f"{field}_{position}"
                if isinstance(item, dict):
                    for sub_field, sub_value in item.items():
                        add_row(category, sub_field, sub_value)
                else:
                    # Non-dict list entries have no sub-fields to expand.
                    add_row(category, 'value', item)
        else:
            # Covers None as well as plain scalars.
            add_row(field, 'value', value)

    return pd.DataFrame(rows, columns=['Category', 'Field', 'Value'])
119
+
120
 
121
  def process(file_content):
122
  try:
 
129
  message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants")
130
 
131
  response = get_response(message_file.id, demo.id) # This now includes retry logic
132
+ response_prompt = f"Please parse the following response into the correct format: {response}"
133
+ clinical_info = parse_response(response_prompt)
134
+ df = clinical_info_to_dataframe(clinical_info)
135
 
136
  if df.empty:
137
  return "<p>No valid information could be extracted from the provided file.</p>"
138
 
139
+ # Sort the DataFrame
140
+ df = df.sort_values(['Category', 'Field'])
 
 
141
 
142
  # Convert to HTML with some basic styling
143
+ html = df.to_html(index=False, classes='table table-striped table-bordered', escape=False)
144
 
145
  # Add some custom CSS for better readability
146
  html = f"""
 
194
  def run_in_terminal():
195
  print("Clinical Note Information Extractor")
196
  print("This tool extracts key information from clinical notes in PDF format.")
197
+ file_path = "../clinicalnotes_raw/0b7wtxiunxwploe6tnnluh0l84qg.pdf"
 
198
 
199
  if not os.path.exists(file_path):
200
  print(f"Error: File not found at {file_path}")
structures.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, List
2
+ from typing_extensions import Literal
3
+ from pydantic import BaseModel, Field as FieldInfo
4
+
5
# Schema for a person's name split into first and last name; used as the type
# of ClinicalInfo.patient_name. All three fields are required strings.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class Name(BaseModel):
6
+ reasoning: str
7
+ first_name: str
8
+ last_name: str
9
+
10
# Schema for a single date kept as a free-form string (no format is enforced here);
# used for ClinicalInfo.date_of_birth and ClinicalInfo.date_of_death.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class DateInfo(BaseModel):
11
+ reasoning: str
12
+ date: str
13
+
14
# Schema for the `sex` entry of ClinicalInfo; the value is an unconstrained string.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class SexInfo(BaseModel):
15
+ reasoning: str
16
+ sex: str
17
+
18
# One entry of ClinicalInfo.traditional_chemo: a required drug name and an
# optional date string that defaults to None.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class ChemoInfo(BaseModel):
19
+ reasoning: str
20
+ drug: str
21
+ date: Optional[str] = None
22
+
23
# One entry of ClinicalInfo.other_cancer_treatments: a required treatment name
# and an optional date string that defaults to None.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class TreatmentInfo(BaseModel):
24
+ reasoning: str
25
+ treatment: str
26
+ date: Optional[str] = None
27
+
28
# One entry of ClinicalInfo.other_conmeds: a required medication name and an
# optional date string that defaults to None.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class MedicationInfo(BaseModel):
29
+ reasoning: str
30
+ medication: str
31
+ date: Optional[str] = None
32
+
33
# Schema for ClinicalInfo.surgery; `resection` is an unconstrained string.
# NOTE(review): the expected vocabulary for `resection` is not visible here — confirm against the prompt.
+ class SurgeryInfo(BaseModel):
34
+ reasoning: str
35
+ resection: str
36
+
37
# Schema for ClinicalInfo.surgery_outcome; `outcome` is an unconstrained string.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class SurgeryOutcomeInfo(BaseModel):
38
+ reasoning: str
39
+ outcome: str
40
+
41
# Schema for ClinicalInfo.metastasis_at_time_of_diagnosis; `metastasis` is an
# unconstrained string (not a boolean).
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class MetastasisInfo(BaseModel):
42
+ reasoning: str
43
+ metastasis: str
44
+
45
# Schema for ClinicalInfo.compounding_pharmacy; `pharmacy` is an unconstrained string.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class PharmacyInfo(BaseModel):
46
+ reasoning: str
47
+ pharmacy: str
48
+
49
# One entry of ClinicalInfo.adverse_effects. `medication` and `description`
# are required strings; `dosage` and `date` are optional strings defaulting to None.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class AdverseEffectInfo(BaseModel):
50
+ reasoning: str
51
+ medication: str
52
+ dosage: Optional[str] = None
53
+ date: Optional[str] = None
54
+ description: str
55
+
56
# Schema for ClinicalInfo.weight; the weight is kept as a string, so no unit or
# numeric format is enforced here.
# NOTE(review): `reasoning` presumably carries the model's justification — confirm against the prompt.
+ class WeightInfo(BaseModel):
57
+ reasoning: str
58
+ weight: str
59
+
60
# Top-level extraction schema. app.py passes this class as `response_format`
# to client.beta.chat.completions.parse and then calls model_dump() on the result.
# Every field is optional and defaults to None; the three treatment/medication
# fields and `adverse_effects` are lists of sub-models, the rest are single sub-models.
# Avoid adding a class docstring casually: pydantic includes it in the generated
# JSON schema, which would change what is sent to the model.
+ class ClinicalInfo(BaseModel):
61
+ patient_name: Optional[Name] = None
62
+ date_of_birth: Optional[DateInfo] = None
63
+ sex: Optional[SexInfo] = None
64
+ traditional_chemo: Optional[List[ChemoInfo]] = None
65
+ other_cancer_treatments: Optional[List[TreatmentInfo]] = None
66
+ other_conmeds: Optional[List[MedicationInfo]] = None
67
+ surgery: Optional[SurgeryInfo] = None
68
+ surgery_outcome: Optional[SurgeryOutcomeInfo] = None
69
+ metastasis_at_time_of_diagnosis: Optional[MetastasisInfo] = None
70
+ compounding_pharmacy: Optional[PharmacyInfo] = None
71
+ adverse_effects: Optional[List[AdverseEffectInfo]] = None
72
+ date_of_death: Optional[DateInfo] = None
73
+ weight: Optional[WeightInfo] = None