Spaces:
Sleeping
Sleeping
import glob | |
import json | |
import os | |
import time | |
import gradio as gr | |
from openai import OpenAI | |
import xml.etree.ElementTree as ET | |
import re | |
import pandas as pd | |
import api_keys | |
import note_extraction.hf_hosting.prompts as prompts | |
# Shared OpenAI client used by every request handled by this module.
client = OpenAI(api_key=api_keys.OPENAI_API_KEY)
model_name = "gpt-4o-2024-08-06"

# NOTE: this runs at import time, so a new assistant is created on every
# start of the module. `process` refers to it through `demo.id`.
demo = client.beta.assistants.create(
    model=model_name,
    name="Information Extractor",
    instructions="Extract information from this note.",
    tools=[dict(type="file_search")],
)
# Tags whose children are copied verbatim into one flat record.
_RECORD_TAGS = frozenset({
    'patient_name', 'date_of_birth', 'sex', 'weight', 'date_of_death',
    'surgery', 'surgery_outcome', 'metastasis_at_time_of_diagnosis',
})
# Tags that accumulate one entry per drug/treatment/medication child.
_ITEM_LIST_TAGS = frozenset({
    'traditional_chemo', 'other_cancer_treatments', 'other_conmeds',
})
# Child tags that name an individual item inside an _ITEM_LIST_TAGS element.
_ITEM_TAGS = frozenset({'drug', 'treatment', 'medication'})


def _node_text(node):
    """Return the stripped text of *node*, or None if the node or its text is missing."""
    if node is None or not node.text:
        return None
    return node.text.strip()


def _record_from(element):
    """Build a flat record: 'reasoning' first, then every other child keyed by its tag."""
    record = {'reasoning': _node_text(element.find('reasoning'))}
    for child in element:
        if child.tag != 'reasoning':
            record[child.tag] = _node_text(child)
    return record


def parse_xml_response(xml_string: str) -> pd.DataFrame:
    """
    Parse the XML response from the model into a one-row pandas DataFrame.

    The XML is located with a regex (first opening tag through the last
    closing tag), so surrounding non-XML text is ignored. Recognized children
    of the root element are collected and flattened into columns indexed by a
    three-level MultiIndex ``(category, position, field)``, where ``position``
    is the 1-based occurrence number rendered as a string.

    Args:
        xml_string: Raw model output that is expected to contain an XML document.

    Returns:
        A single-row DataFrame with MultiIndex columns, or an empty DataFrame
        when no XML is found, the XML is malformed, or it contains none of the
        recognized fields.
    """
    # Extract only the XML content between the first and last tags
    match = re.search(r'<.*?>.*</.*?>', xml_string, re.DOTALL)
    if match is None:
        print("No valid XML content found.")
        return pd.DataFrame()

    try:
        root = ET.fromstring(match.group(0))
    except ET.ParseError as e:
        print(f"Error parsing XML: {e}")
        return pd.DataFrame()

    result = {}
    for element in root:
        tag = element.tag
        if tag in _RECORD_TAGS:
            result[tag] = _record_from(element)
        elif tag in _ITEM_LIST_TAGS:
            entries = result.setdefault(tag, [])
            reasoning = _node_text(element.find('reasoning'))
            # NOTE(review): every item in this element shares the element's
            # first <date>; confirm against the prompt's XML schema whether
            # items can carry individual dates.
            date = _node_text(element.find('date'))
            for item in element:
                if item.tag in _ITEM_TAGS:
                    entries.append({
                        'reasoning': reasoning,
                        'name': _node_text(item),
                        'date': date,
                    })
        elif tag == 'compounding_pharmacy':
            result[tag] = {
                'reasoning': _node_text(element.find('reasoning')),
                'pharmacy': _node_text(element.find('pharmacy')),
            }
        elif tag == 'adverse_effects':
            result.setdefault(tag, []).append(_record_from(element))

    # Flatten to {(category, position, field): [value]}; single records are
    # treated as a one-element list so both shapes share the same path.
    df_data = {}
    for key, value in result.items():
        entries = [value] if isinstance(value, dict) else value
        for position, entry in enumerate(entries, start=1):
            for sub_key, sub_value in entry.items():
                df_data[(key, str(position), sub_key)] = [sub_value]

    # MultiIndex.from_tuples cannot infer the number of levels from an empty
    # sequence, so bail out before building the frame.
    if not df_data:
        print("No recognized fields found in XML.")
        return pd.DataFrame()

    # Create multi-index DataFrame
    df = pd.DataFrame(df_data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df
def get_response(prompt, file_id, assistant_id):
    """
    Ask the assistant about an uploaded file and return its reply as plain text.

    A new thread is created containing a single user message (``prompt``) with
    the file attached for the ``file_search`` tool. The run is polled until it
    finishes, and the text of the resulting message is returned with every
    annotation marker removed.

    Args:
        prompt: Text of the user message sent to the assistant.
        file_id: ID of a file previously uploaded through ``client.files``.
        assistant_id: ID of the assistant that should process the thread.

    Returns:
        The assistant's message text with annotation substrings stripped.

    Raises:
        RuntimeError: If the run did not produce exactly one message.
    """
    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                # Use the caller-supplied prompt rather than a hard-coded one.
                "content": prompt,
                "attachments": [
                    {"file_id": file_id, "tools": [{"type": "file_search"}]}
                ],
            }
        ]
    )
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id, assistant_id=assistant_id
    )
    messages = list(
        client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(messages) != 1:
        raise RuntimeError(
            f"Expected exactly one assistant message, got {len(messages)} "
            f"(run status: {run.status})"
        )

    message_content = messages[0].content[0].text
    # Work on a local copy so the API response object is left untouched.
    text = message_content.value
    for annotation in message_content.annotations:
        # Drop the annotation marker so it cannot end up inside the XML.
        text = text.replace(annotation.text, "")
    return text
# Inline stylesheet prepended to the rendered results table.
_TABLE_STYLE = """<style>
.table {
width: 100%;
max-width: 100%;
margin-bottom: 1rem;
background-color: transparent;
}
.table td, .table th {
padding: .75rem;
vertical-align: top;
border-top: 1px solid #dee2e6;
}
.table thead th {
vertical-align: bottom;
border-bottom: 2px solid #dee2e6;
}
.table tbody + tbody {
border-top: 2px solid #dee2e6;
}
.table-striped tbody tr:nth-of-type(odd) {
background-color: rgba(0,0,0,.05);
}
</style>"""


def process(file_content):
    """
    Extract information from an uploaded PDF and render it as an HTML table.

    The bytes are written to ``cache/<timestamp>.pdf``, uploaded to OpenAI,
    sent to the assistant via ``get_response``, parsed with
    ``parse_xml_response`` and finally rendered as a styled HTML table with
    the columns Category / Index / Field / Value.

    Args:
        file_content: Raw bytes of the uploaded PDF.

    Returns:
        An HTML string: the styled table on success, or a short ``<p>``
        message when nothing was uploaded or nothing could be extracted.
    """
    # Guard against a submit without a file, which would fail in f.write().
    if not file_content:
        return "<p>No file was uploaded.</p>"

    os.makedirs("cache", exist_ok=True)
    file_name = f"cache/{time.time()}.pdf"
    with open(file_name, "wb") as f:
        f.write(file_content)

    # Close the read handle as soon as the upload call returns.
    with open(file_name, "rb") as pdf:
        message_file = client.files.create(file=pdf, purpose="assistants")

    response = get_response(prompts.info_prompt, message_file.id, demo.id)
    df = parse_xml_response(response)
    if df.empty:
        return "<p>No valid information could be extracted from the provided file.</p>"

    # Transpose the DataFrame so each extracted field becomes one row
    df_transposed = df.T.reset_index()
    df_transposed.columns = ['Category', 'Index', 'Field', 'Value']
    df_transposed = df_transposed.sort_values(['Category', 'Index', 'Field'])

    # Cell values are model output derived from an untrusted upload, so they
    # are HTML-escaped rather than injected verbatim into the page.
    table_html = df_transposed.to_html(
        index=False,
        classes='table table-striped table-bordered',
        escape=True,
    )

    # Add some custom CSS for better readability
    return f"\n{_TABLE_STYLE}\n{table_html}\n"
def gradio_interface():
    """
    Build the Gradio UI around `process` and launch it.

    The interface has one file-upload input (configured with type="binary")
    and one HTML output. Request queuing is enabled before `launch()` is called.
    """
    app = gr.Interface(
        fn=process,
        inputs=gr.File(label="Upload PDF", type="binary"),
        outputs=gr.HTML(label="Extracted Information"),
        title="Clinical Note Information Extractor",
        description="This tool extracts key information from clinical notes in PDF format.",
    )
    app.queue()
    app.launch()
if __name__ == "__main__":
    # Start the web UI only when the file is executed as a script.
    gradio_interface()