# Source captured from a Hugging Face Space page (Space status banner: "Sleeping").
import glob
import json
import os
import time
import traceback
from io import StringIO
from typing import Dict, Any
from typing import List, Optional

import gradio as gr
import pandas as pd
from openai import OpenAI
from pydantic import BaseModel, Field

import prompts
from structures import ClinicalInfo
# Shared OpenAI client; the API key is taken from the OPENAI_API_KEY env var.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
model_name = "gpt-4o-2024-08-06"

# Assistant with the file_search tool, used by get_response() to read the
# uploaded PDF. Any failure here is reported and then re-raised, so the
# module does not import without an assistant.
# NOTE(review): a new assistant is created every time this module is
# imported; consider reusing a stored assistant id instead.
try:
    demo = client.beta.assistants.create(
        model=model_name,
        name="Information Extractor",
        tools=[{"type": "file_search"}],
        instructions="Extract information from this note and return it as a JSON object.",
    )
except Exception as err:
    print("Error creating assistant: " + str(err))
    raise
def parse_response(prompt):
    """Turn free-form model output into a ``ClinicalInfo`` dict using structured outputs.

    Args:
        prompt: User message sent to ``model_name``. ``process`` passes a
            request to reformat the assistant's previous answer.

    Returns:
        The parsed ``ClinicalInfo`` instance converted with ``model_dump()``.

    Raises:
        ValueError: if the model returned no parsed object (e.g. it refused).
    """
    completion = client.beta.chat.completions.parse(
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        response_format=ClinicalInfo,
    )
    message = completion.choices[0].message
    # With structured outputs, ``parsed`` is None when the model refuses.
    # Without this check, calling model_dump() on it would fail with an
    # opaque AttributeError.
    if message.parsed is None:
        reason = getattr(message, "refusal", None) or "no parsed content returned"
        raise ValueError(f"Model did not return structured ClinicalInfo: {reason}")
    return message.parsed.model_dump()
def _run_extraction(file_id, assistant_id):
    """Run one extraction attempt on the assistant and return its reply text."""
    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                "content": prompts.info_prompt,
                "attachments": [
                    {"file_id": file_id, "tools": [{"type": "file_search"}]}
                ],
            }
        ]
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant_id,
        instructions="Please provide your response as a valid JSON object.",
    )
    # Poll only while the run can still make progress. A failed, cancelled,
    # expired, incomplete or requires_action run never becomes "completed",
    # so waiting for "completed" alone would spin forever.
    while run.status in ("queued", "in_progress", "cancelling"):
        time.sleep(1)
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
    if run.status != "completed":
        raise RuntimeError(
            f"Run {run.id} ended with status {run.status!r}: {run.last_error}"
        )

    messages = list(
        client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(messages) != 1:
        raise RuntimeError(f"Expected 1 message, got {len(messages)}")

    message_content = messages[0].content[0].text
    text = message_content.value
    # Remove the file_search citation markers embedded in the reply text.
    for annotation in message_content.annotations:
        text = text.replace(annotation.text, "")
    return text


def get_response(file_id, assistant_id, max_retries=3):
    """Ask the assistant to extract information from an uploaded file.

    Each attempt does the following:
    1. Creates a new thread whose user message is ``prompts.info_prompt``,
       with ``file_id`` attached for ``file_search``.
    2. Runs ``assistant_id`` on that thread and waits for the run to finish.
    3. Returns the single reply with citation annotations stripped.

    Args:
        file_id: Id of a file previously uploaded with ``purpose="assistants"``.
        assistant_id: Id of the assistant to run.
        max_retries: Number of attempts. Attempts are 5 seconds apart, and
            each failure is printed with its traceback. If this is less than
            1, no attempt is made and None is returned.

    Returns:
        The assistant's reply text.

    Raises:
        Exception: after the last attempt fails, chained to that attempt's error.
    """
    for attempt in range(max_retries):
        try:
            return _run_extraction(file_id, assistant_id)
        except Exception as e:
            print(f"Error in get_response (attempt {attempt + 1}): {str(e)}")
            print(f"Traceback: {traceback.format_exc()}")
            if attempt < max_retries - 1:
                print("Retrying in 5 seconds...")
                time.sleep(5)
            else:
                raise Exception(
                    "Max retries reached. Unable to get response from the model."
                ) from e
def clinical_info_to_dataframe(clinical_info: Dict[str, Any]) -> pd.DataFrame:
    """
    Flatten a ClinicalInfo dictionary into a Category / Field / Value table.

    Flattening rules:

    - dict value: one row per key, with Category set to the top-level field.
    - list value: items are numbered from 1 and get Category ``"<field>_<n>"``.
      A dict item yields one row per key; any other item yields one row with
      Field ``"value"``.
    - anything else (None or a scalar): one row with Field ``"value"``.

    All values are rendered with ``str()``, so None becomes ``"None"``. The
    returned frame always has the three columns, even when it has no rows.
    """
    rows = []

    def add_row(category: str, field_name: Any, value: Any) -> None:
        rows.append({'Category': category, 'Field': field_name, 'Value': str(value)})

    for field, value in clinical_info.items():
        if isinstance(value, dict):
            for sub_field, sub_value in value.items():
                add_row(field, sub_field, sub_value)
        elif isinstance(value, list):
            for i, item in enumerate(value, start=1):
                category = f"{field}_{i}"
                if isinstance(item, dict):
                    for sub_field, sub_value in item.items():
                        add_row(category, sub_field, sub_value)
                else:
                    # Lists of plain values (e.g. List[str]) have no .items().
                    add_row(category, 'value', item)
        else:
            # None and scalar fields both get a row. Previously, non-None
            # scalars were silently dropped.
            add_row(field, 'value', value)
    return pd.DataFrame(rows, columns=['Category', 'Field', 'Value'])
def process(file_content): | |
try: | |
if not os.path.exists("cache"): | |
os.makedirs("cache") | |
file_name = f"cache/{time.time()}.pdf" | |
with open(file_name, "wb") as f: | |
f.write(file_content) | |
message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants") | |
response = get_response(message_file.id, demo.id) # This now includes retry logic | |
response_prompt = f"Please parse the following response into the correct format: {response}" | |
clinical_info = parse_response(response_prompt) | |
df = clinical_info_to_dataframe(clinical_info) | |
if df.empty: | |
return "<p>No valid information could be extracted from the provided file.</p>" | |
# Sort the DataFrame | |
df = df.sort_values(['Category', 'Field']) | |
# Convert to HTML with some basic styling | |
html = df.to_html(index=False, classes='table table-striped table-bordered', escape=False) | |
# Add some custom CSS for better readability | |
html = f""" | |
<style> | |
.table {{ | |
width: 100%; | |
max-width: 100%; | |
margin-bottom: 1rem; | |
background-color: transparent; | |
}} | |
.table td, .table th {{ | |
padding: .75rem; | |
vertical-align: top; | |
border-top: 1px solid #dee2e6; | |
}} | |
.table thead th {{ | |
vertical-align: bottom; | |
border-bottom: 2px solid #dee2e6; | |
}} | |
.table tbody + tbody {{ | |
border-top: 2px solid #dee2e6; | |
}} | |
.table-striped tbody tr:nth-of-type(odd) {{ | |
background-color: rgba(0,0,0,.05); | |
}} | |
</style> | |
{html} | |
""" | |
return html | |
except Exception as e: | |
error_message = f"An error occurred while processing the file: {str(e)}" | |
print(error_message) | |
print(f"Traceback: {traceback.format_exc()}") | |
return f"<p>{error_message}</p>" | |
def gradio_interface():
    """Build a Gradio Interface around ``process`` and launch it.

    The single input is a PDF upload delivered as raw bytes. The output is
    the HTML string that ``process`` returns.
    """
    # Named `app` so it does not shadow the module-level assistant `demo`.
    app = gr.Interface(
        fn=process,
        inputs=gr.File(label="Upload PDF", type="binary"),
        outputs=gr.HTML(label="Extracted Information"),
        title="Clinical Note Information Extractor",
        description="This tool extracts key information from clinical notes in PDF format.",
    )
    app.queue()
    app.launch()
def run_in_terminal(file_path="../clinicalnotes_raw/0b7wtxiunxwploe6tnnluh0l84qg.pdf"):
    """Run ``process`` on a local PDF without the web UI.

    On success, the HTML result is written to ``output_<timestamp>.html`` in
    the working directory. Each extracted row is also printed as
    ``Category - Field: Value``. If ``process`` returns a ``<p>...</p>``
    message, that message is printed without its tags.

    Args:
        file_path: PDF to process. The default is the original sample path.
    """
    print("Clinical Note Information Extractor")
    print("This tool extracts key information from clinical notes in PDF format.")
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return
    try:
        with open(file_path, "rb") as file:
            file_content = file.read()
        result = process(file_content)
        if result.startswith("<p>"):
            # process() reports problems as a single <p>...</p>; strip the tags.
            print(result[3:-4])
            return
        output_file = f"output_{time.time()}.html"
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(result)
        print(f"Extraction completed. Results saved to {output_file}")
        # Passing literal HTML to read_html is deprecated in pandas >= 2.1;
        # wrap it in a file-like object instead.
        df = pd.read_html(StringIO(result))[0]
        print("\nExtracted Information:")
        for _, row in df.iterrows():
            print(f"{row['Category']} - {row['Field']}: {row['Value']}")
    except Exception as e:
        print(f"An error occurred while processing the file: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
if __name__ == "__main__":
    # Launch the web UI. Any launch failure is reported with its traceback
    # instead of propagating.
    try:
        gradio_interface()
    except Exception as exc:
        print("Error launching Gradio interface: " + str(exc))
        print("Traceback: " + traceback.format_exc())