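# NOTE: the lines originally here ("Spaces: Sleeping") were a page-status
# banner captured along with the source; they are not part of the program.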
import glob | |
import json | |
import os | |
import time | |
import gradio as gr | |
from openai import OpenAI | |
import xml.etree.ElementTree as ET | |
import re | |
import pandas as pd | |
import prompts | |
import traceback | |
from io import StringIO | |
# Module-level OpenAI client; the API key is read from the OPENAI_API_KEY
# environment variable (None is passed through if it is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Model identifier used when creating the assistant below.
model_name = "gpt-4o-2024-08-06"
# Create the assistant once at import time. `demo` is read later by
# process() (as `demo.id`), so a failure here is logged and re-raised
# rather than leaving the module half-initialised.
try:
    demo = client.beta.assistants.create(
        name="Information Extractor",
        instructions="Extract information from this note.",
        model=model_name,
        # file_search lets the assistant read the file attached to each thread.
        tools=[{"type": "file_search"}],
    )
except Exception as e:
    print(f"Error creating assistant: {str(e)}")
    raise
# Top-level tags that occur once and whose non-reasoning children are copied verbatim.
_SINGLE_RECORD_TAGS = frozenset({
    'patient_name', 'date_of_birth', 'sex', 'weight', 'date_of_death',
    'surgery', 'surgery_outcome', 'metastasis_at_time_of_diagnosis',
})
# Top-level tags that accumulate one record per drug/treatment/medication child.
_ITEM_LIST_TAGS = frozenset({'traditional_chemo', 'other_cancer_treatments', 'other_conmeds'})
# Child tags that count as an "item" inside an _ITEM_LIST_TAGS element.
_ITEM_TAGS = frozenset({'drug', 'treatment', 'medication'})


def _element_text(node):
    """Return the stripped text of *node*, or None if the node or its text is missing."""
    if node is None or not node.text:
        return None
    return node.text.strip()


def _record_with_children(element):
    """Build a record: 'reasoning' first, then every other child tag mapped to its text."""
    record = {'reasoning': _element_text(element.find('reasoning'))}
    for child in element:
        if child.tag != 'reasoning':
            record[child.tag] = _element_text(child)
    return record


def parse_xml_response(xml_string: str) -> pd.DataFrame:
    """
    Parse the XML response from the model into a single-row pandas DataFrame.

    The XML fragment is taken as everything from the first ``<`` to the last
    ``>`` in *xml_string* (so surrounding prose or code fences are ignored),
    wrapped in a synthetic ``<root>`` and parsed. Recognised top-level tags
    are flattened into columns keyed by ``(tag, index, field)`` where
    ``index`` is a 1-based string ("1" for single-record tags).

    Args:
        xml_string: Raw text returned by the model.

    Returns:
        A one-row DataFrame with three-level MultiIndex columns, or an empty
        DataFrame when no XML is found, the XML is malformed, or no
        recognised tag is present. Errors are printed, never raised.
    """
    try:
        # Slice from the first '<' to the last '>'. A tag-by-tag regex loses
        # the closing tags of nested elements and yields malformed XML.
        start = xml_string.find('<')
        end = xml_string.rfind('>')
        if start == -1 or end < start:
            print("No valid XML content found.")
            return pd.DataFrame()
        fragment = xml_string[start:end + 1]
        # An XML declaration is illegal once nested inside <root>.
        fragment = re.sub(r'<\?xml[^>]*\?>', '', fragment)

        # Wrap the content in a root element to ensure there's only one root
        xml_string = f"<root>{fragment}</root>"
        root = ET.fromstring(xml_string)

        result = {}
        for element in root:
            tag = element.tag
            if tag in _SINGLE_RECORD_TAGS:
                # A repeated single-record tag overwrites the earlier one.
                result[tag] = _record_with_children(element)
            elif tag in _ITEM_LIST_TAGS:
                records = result.setdefault(tag, [])
                reasoning = _element_text(element.find('reasoning'))
                # NOTE(review): every item in this element receives the
                # element's FIRST <date>; confirm against the prompt schema
                # whether per-item dates are expected.
                date = _element_text(element.find('date'))
                for item in element:
                    if item.tag in _ITEM_TAGS:
                        records.append({
                            'reasoning': reasoning,
                            'name': _element_text(item),
                            'date': date,
                        })
            elif tag == 'compounding_pharmacy':
                result[tag] = {
                    'reasoning': _element_text(element.find('reasoning')),
                    'pharmacy': _element_text(element.find('pharmacy')),
                }
            elif tag == 'adverse_effects':
                # Each <adverse_effects> element contributes one record.
                result.setdefault(tag, []).append(_record_with_children(element))

        # Flatten into {(tag, index, field): [value]} for a one-row frame.
        df_data = {}
        for key, value in result.items():
            records = [value] if isinstance(value, dict) else value
            for position, record in enumerate(records, start=1):
                for sub_key, sub_value in record.items():
                    df_data[(key, str(position), sub_key)] = [sub_value]

        if not df_data:
            # MultiIndex.from_tuples cannot infer levels from an empty list.
            print("No recognised fields found in the XML content.")
            return pd.DataFrame()

        # Create multi-index DataFrame
        df = pd.DataFrame(df_data)
        df.columns = pd.MultiIndex.from_tuples(df.columns)
        return df
    except ET.ParseError as e:
        print(f"XML parsing error: {str(e)}")
        print(f"Problematic XML content: {xml_string[:500]}...")  # Print first 500 chars of XML
        return pd.DataFrame()
    except Exception as e:
        print(f"Error in parse_xml_response: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        return pd.DataFrame()
def get_response(file_id, assistant_id, max_retries=3):
    """
    Ask the assistant to extract information from an uploaded file.

    A new thread is created with ``prompts.info_prompt`` as the user message
    and the file attached for ``file_search``; the run is polled to
    completion and the single resulting message is returned with every
    annotation's text removed.

    Args:
        file_id: ID of a file already uploaded to OpenAI.
        assistant_id: ID of the assistant that should run the thread.
        max_retries: Total number of attempts (must be >= 1). Failed
            attempts are followed by a 5 second pause.

    Returns:
        The text of the assistant's reply with annotation markers stripped.

    Raises:
        ValueError: If ``max_retries`` is less than 1 (previously this
            silently returned None without contacting the API).
        Exception: When every attempt failed; the last error is chained
            as ``__cause__``.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    for attempt in range(max_retries):
        try:
            thread = client.beta.threads.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompts.info_prompt,
                        "attachments": [
                            {"file_id": file_id, "tools": [{"type": "file_search"}]}
                        ],
                    }
                ]
            )
            run = client.beta.threads.runs.create_and_poll(
                thread_id=thread.id, assistant_id=assistant_id
            )
            messages = list(
                client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
            )
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(messages) != 1:
                raise RuntimeError(f"Expected 1 message, got {len(messages)}")

            message_content = messages[0].content[0].text
            # Work on a local copy rather than mutating the SDK object.
            response_text = message_content.value
            for annotation in message_content.annotations:
                response_text = response_text.replace(annotation.text, "")
            return response_text
        except Exception as e:
            print(f"Error in get_response (attempt {attempt + 1}): {str(e)}")
            print(f"Traceback: {traceback.format_exc()}")
            if attempt < max_retries - 1:
                print("Retrying in 5 seconds...")
                time.sleep(5)
            else:
                raise Exception(
                    "Max retries reached. Unable to get response from the model."
                ) from e
def process(file_content):
    """
    Run the full extraction pipeline on an uploaded PDF and return HTML.

    The bytes are written to ``cache/<timestamp>.pdf``, uploaded to OpenAI,
    passed to ``get_response`` and ``parse_xml_response``, and the resulting
    frame is rendered as a styled HTML table with the columns
    Category / Index / Field / Value.

    Args:
        file_content: Raw PDF bytes (as delivered by ``gr.File(type="binary")``).

    Returns:
        An HTML string. Failures never raise: they are reported as a single
        ``<p>...</p>`` element, which callers can detect via the ``<p>`` prefix.
    """
    # Imported locally under an alias-free name; a module-level `html`
    # import is avoided so this block stays self-contained.
    from html import escape

    # Nothing uploaded (or an empty file): report it without touching disk or the API.
    if not file_content:
        return "<p>No file was provided.</p>"

    try:
        os.makedirs("cache", exist_ok=True)
        file_name = f"cache/{time.time()}.pdf"
        with open(file_name, "wb") as f:
            f.write(file_content)

        # Use a context manager so the read handle is closed after the upload.
        with open(file_name, "rb") as pdf:
            message_file = client.files.create(file=pdf, purpose="assistants")

        response = get_response(message_file.id, demo.id)  # This includes retry logic
        df = parse_xml_response(response)
        if df.empty:
            return "<p>No valid information could be extracted from the provided file.</p>"

        # Transpose so each (category, index, field) triple becomes one row.
        df_transposed = df.T.reset_index()
        df_transposed.columns = ['Category', 'Index', 'Field', 'Value']
        df_transposed = df_transposed.sort_values(['Category', 'Index', 'Field'])

        # Cell values originate from model output over an untrusted document,
        # so leave pandas' default HTML escaping enabled.
        table_html = df_transposed.to_html(
            index=False, classes='table table-striped table-bordered'
        )

        # Add some custom CSS for better readability
        return f"""
<style>
.table {{
    width: 100%;
    max-width: 100%;
    margin-bottom: 1rem;
    background-color: transparent;
}}
.table td, .table th {{
    padding: .75rem;
    vertical-align: top;
    border-top: 1px solid #dee2e6;
}}
.table thead th {{
    vertical-align: bottom;
    border-bottom: 2px solid #dee2e6;
}}
.table tbody + tbody {{
    border-top: 2px solid #dee2e6;
}}
.table-striped tbody tr:nth-of-type(odd) {{
    background-color: rgba(0,0,0,.05);
}}
</style>
{table_html}
"""
    except Exception as e:
        error_message = f"An error occurred while processing the file: {str(e)}"
        print(error_message)
        print(f"Traceback: {traceback.format_exc()}")
        # Exception text may contain markup characters; escape before embedding.
        return f"<p>{escape(error_message)}</p>"
def gradio_interface():
    """Build the Gradio app (PDF upload in, HTML out) around ``process``, then queue and launch it."""
    interface = gr.Interface(
        fn=process,
        inputs=gr.File(label="Upload PDF", type="binary"),
        outputs=gr.HTML(label="Extracted Information"),
        title="Clinical Note Information Extractor",
        description="This tool extracts key information from clinical notes in PDF format.",
    )
    interface.queue()
    interface.launch()
def run_in_terminal():
    """
    Command-line front end: prompt for a PDF path, run ``process`` on it,
    and report the outcome.

    On success the HTML is saved to ``output_<timestamp>.html`` in the
    current directory and a ``Category - Field: Value`` listing is printed.
    When ``process`` returns a ``<p>...</p>`` message (its failure format),
    only the message text is printed. Nothing is returned and errors are
    printed rather than raised.
    """
    print("Clinical Note Information Extractor")
    print("This tool extracts key information from clinical notes in PDF format.")
    print("Enter the path to your PDF file:")
    file_path = input().strip()
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return
    try:
        with open(file_path, "rb") as file:
            file_content = file.read()
        result = process(file_content)
        if result.startswith("<p>"):
            # Error message
            print(result[3:-4])  # Remove <p> tags
        else:
            # Save the HTML output to a file
            output_file = f"output_{time.time()}.html"
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(result)
            print(f"Extraction completed. Results saved to {output_file}")
            # Also print a simplified version to the console.
            # read_html expects a path/URL/file-like object; passing literal
            # HTML is deprecated, so wrap the string in StringIO.
            df = pd.read_html(StringIO(result))[0]
            print("\nExtracted Information:")
            for _, row in df.iterrows():
                print(f"{row['Category']} - {row['Field']}: {row['Value']}")
    except Exception as e:
        print(f"An error occurred while processing the file: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
# Script entry point: start the Gradio web UI. Any exception that escapes
# gradio_interface() is printed with its traceback instead of propagating.
if __name__ == "__main__":
    try:
        gradio_interface()
        # Alternative console front end, currently disabled:
        # run_in_terminal()
    except Exception as e:
        print(f"Error launching Gradio interface: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")