Spaces:
Runtime error
Runtime error
import os | |
import io | |
import json | |
import csv | |
import asyncio | |
import xml.etree.ElementTree as ET | |
from typing import Any, Dict, Optional, Tuple, Union, List | |
import httpx | |
import gradio as gr | |
import torch | |
from dotenv import load_dotenv | |
from loguru import logger | |
from huggingface_hub import login | |
from openai import OpenAI | |
from reportlab.pdfgen import canvas | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForSequenceClassification, | |
MarianMTModel, | |
MarianTokenizer, | |
) | |
import pandas as pd | |
import altair as alt | |
import spacy | |
import spacy.cli | |
import PyPDF2 | |
############################################################################### | |
# 1) ENVIRONMENT & LOGGING # | |
############################################################################### | |
# Load the small English spaCy pipeline, installing it on first run if needed.
def _load_spacy_pipeline() -> "spacy.language.Language":
    """Return the ``en_core_web_sm`` pipeline, downloading it when it is missing."""
    try:
        return spacy.load("en_core_web_sm")
    except OSError:
        # spacy.load reports a missing package with OSError: fetch it and retry once.
        logger.info("Downloading SpaCy 'en_core_web_sm' model...")
        spacy.cli.download("en_core_web_sm")
        return spacy.load("en_core_web_sm")


nlp = _load_spacy_pipeline()

# Persist ERROR-and-above log records to a file that rotates at 1 MB.
logger.add("error_logs.log", rotation="1 MB", level="ERROR")

# Read credentials/config; load_dotenv() runs first so a local .env file (if any) is honoured.
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
BIOPORTAL_API_KEY = os.getenv("BIOPORTAL_API_KEY")  # optional: BioPortal integration
ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")

# Hugging Face and OpenAI credentials are mandatory: fail fast at import time.
if not (HUGGINGFACE_TOKEN and OPENAI_API_KEY):
    logger.error("Missing Hugging Face or OpenAI credentials.")
    raise ValueError("Missing credentials for Hugging Face or OpenAI.")

# BioPortal is optional, so a missing key only produces a warning.
if not BIOPORTAL_API_KEY:
    logger.warning("BIOPORTAL_API_KEY is not set. BioPortal fetch calls will fail.")

# Authenticate the Hugging Face Hub session, then build the OpenAI client.
login(HUGGINGFACE_TOKEN)
client = OpenAI(api_key=OPENAI_API_KEY)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
############################################################################### | |
# 2) HUGGING FACE & TRANSLATION MODEL SETUP # | |
############################################################################### | |
MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
try:
    # `token=` replaces `use_auth_token=`, which transformers has deprecated
    # and scheduled for removal.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, token=HUGGINGFACE_TOKEN
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, token=HUGGINGFACE_TOKEN
    )
except Exception as e:
    # Record the failure in the log, then re-raise so startup aborts.
    logger.error(f"Model load error: {e}")
    raise

try:
    translation_model_name = "Helsinki-NLP/opus-mt-en-fr"
    translation_model = MarianMTModel.from_pretrained(
        translation_model_name, token=HUGGINGFACE_TOKEN
    ).to(device)
    translation_tokenizer = MarianTokenizer.from_pretrained(
        translation_model_name, token=HUGGINGFACE_TOKEN
    )
except Exception as e:
    # Record the failure in the log, then re-raise so startup aborts.
    logger.error(f"Translation model load error: {e}")
    raise

# Language map for translation: UI label -> (source, target) language codes.
# NOTE(review): only the English->French Marian model is loaded above, so the
# "French to English" entry has no matching model in this file -- confirm how
# (or whether) fr->en translation is served before exposing that option.
LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
    "English to French": ("en", "fr"),
    "French to English": ("fr", "en"),
}
############################################################################### | |
# 3) API ENDPOINTS & CONSTANTS # | |
############################################################################### | |
PUBMED_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi" | |
PUBMED_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi" | |
EUROPE_PMC_BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search" | |
BIOPORTAL_API_BASE = "https://data.bioontology.org" | |
CROSSREF_API_URL = "https://api.crossref.org/works" | |
############################################################################### | |
# 4) HELPER FUNCTIONS # | |
############################################################################### | |
def safe_json_parse(text: str) -> Union[Dict[str, Any], None]:
    """Decode *text* as JSON.

    Args:
        text: A JSON document.

    Returns:
        The decoded value, or ``None`` when *text* is not valid JSON (the
        decoding error is logged).
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as err:
        logger.error(f"JSON parsing error: {err}")
        return None
    else:
        return parsed
def _element_text(elem: Optional[ET.Element]) -> Optional[str]:
    """Return all text inside *elem*, including text nested in inline markup, or None."""
    if elem is None:
        return None
    # itertext() also walks child elements such as <i>/<sup>; Element.text stops
    # at the first child, which truncated titles and abstracts containing markup.
    return "".join(elem.itertext())


def _pubmed_abstract(article: ET.Element) -> Optional[str]:
    """Join every <AbstractText> section of the article's main <Abstract>, one per line."""
    abstract_elem = article.find(".//Abstract")
    if abstract_elem is not None:
        sections = abstract_elem.findall("AbstractText")
    else:
        # No main <Abstract>: keep the previous behaviour of using the first
        # <AbstractText> found anywhere in the record.
        sections = article.findall(".//AbstractText")[:1]
    if not sections:
        return None
    parts = []
    for section in sections:
        body = _element_text(section)
        label = section.get("Label")
        # Structured abstracts label their sections (e.g. "BACKGROUND"); keep the label.
        parts.append(f"{label}: {body}" if label else body)
    return "\n".join(parts)


def _pubmed_pub_date(pub_date_elem: Optional[ET.Element]) -> Optional[str]:
    """Format a <PubDate> as Y-M-D, Y-M or Y, falling back to <MedlineDate>."""
    if pub_date_elem is None:
        return None
    year = pub_date_elem.findtext("Year")
    if not year:
        # Some records carry a free-text <MedlineDate> (e.g. "1998 Dec-1999 Jan")
        # instead of Year/Month/Day.
        return pub_date_elem.findtext("MedlineDate") or year
    parts = [year]
    month = pub_date_elem.findtext("Month")
    if month:
        parts.append(month)
        day = pub_date_elem.findtext("Day")
        if day:
            parts.append(day)
    return "-".join(parts)


def parse_pubmed_xml(xml_data: str) -> List[Dict[str, Any]]:
    """Parse a PubMed efetch XML payload into a list of article dicts.

    Args:
        xml_data: XML text (or bytes) containing ``<PubmedArticle>`` records.

    Returns:
        One dict per ``<PubmedArticle>`` with keys ``PMID``, ``Title``,
        ``Abstract``, ``Journal`` and ``PublicationDate``. A value is ``None``
        when the corresponding element is absent. An empty or whitespace-only
        payload yields ``[]``.

    Raises:
        xml.etree.ElementTree.ParseError: If non-empty *xml_data* is not well-formed XML.
    """
    if not xml_data or not xml_data.strip():
        return []
    # NOTE: xml.etree.ElementTree is documented as not secure against maliciously
    # constructed data; consider defusedxml if this ever parses untrusted input.
    root = ET.fromstring(xml_data)
    return [
        {
            "PMID": article.findtext(".//PMID"),
            "Title": _element_text(article.find(".//ArticleTitle")),
            "Abstract": _pubmed_abstract(article),
            "Journal": article.findtext(".//Journal/Title"),
            "PublicationDate": _pubmed_pub_date(article.find(".//JournalIssue/PubDate")),
        }
        for article in root.findall(".//PubmedArticle")
    ]
def explain_clinical_results(results: str) -> str:
    """Ask OpenAI's gpt-3.5-turbo to explain clinical test results.

    Args:
        results: Raw test results / trial outcomes as free text.

    Returns:
        The stripped model reply; "No results provided for explanation." when
        *results* is empty or whitespace; "Failed to generate explanation."
        when the API call fails or the reply carries no text (logged).
    """
    # Match summarize_text: don't spend an API call on an empty prompt.
    if not results or not results.strip():
        return "No results provided for explanation."
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": f"Explain the clinical test results:\n{results}"}],
            max_tokens=500,
            temperature=0.7,
        )
        content = response.choices[0].message.content
    except Exception as e:
        logger.error(f"Explanation error: {e}")
        return "Failed to generate explanation."
    # The SDK types message content as Optional[str]; treat an empty reply as a failure
    # instead of crashing on None.strip().
    if not content or not content.strip():
        logger.error("Explanation error: empty response from model")
        return "Failed to generate explanation."
    return content.strip()
############################################################################### | |
# 6) CORE FUNCTIONS # | |
############################################################################### | |
def summarize_text(text: str) -> str:
    """Summarize clinical text with OpenAI's gpt-3.5-turbo chat model.

    Args:
        text: Clinical text to summarize.

    Returns:
        The stripped model reply; "No text provided for summarization." for
        whitespace-only input; "Summarization failed." if the request or the
        reply handling raises (the error is logged).
    """
    # Guard clause: skip the API round-trip for blank input.
    if text.strip() == "":
        return "No text provided for summarization."
    request = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": f"Summarize this clinical data:\n{text}"}],
        "max_tokens": 200,
        "temperature": 0.7,
    }
    try:
        completion = client.chat.completions.create(**request)
        return completion.choices[0].message.content.strip()
    except Exception as exc:
        logger.error(f"Summarization error: {exc}")
        return "Summarization failed."
def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
    """Render *text* into a simple single-column PDF report.

    Each input line is word-wrapped to the printable width, so long lines are
    no longer clipped at the page edge. New pages start automatically.

    Args:
        text: Report body; every line break starts a new line in the PDF.
        filename: Output path. An empty value falls back to "clinical_report.pdf".

    Returns:
        The path that was written, or ``None`` if generation failed (the error
        is logged). Blank *text* still produces a header-only report, with a
        warning logged.
    """
    # Local imports: only this function needs these reportlab helpers.
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.utils import simpleSplit

    font_name, font_size = "Helvetica", 12
    left_margin, right_margin = 100, 50
    line_height = 15
    try:
        if not text.strip():
            logger.warning("No text provided for the report.")
        filename = filename or "clinical_report.pdf"
        # A4 is reportlab's default page size; passing it explicitly lets us
        # compute the wrap width from a known page width.
        c = canvas.Canvas(filename, pagesize=A4)
        c.setFont(font_name, font_size)
        c.drawString(left_margin, 750, "Clinical Research Report")
        max_width = A4[0] - left_margin - right_margin
        y = 730
        # splitlines() also handles "\r\n" input, which split("\n") left with stray "\r".
        for raw_line in text.splitlines():
            # simpleSplit returns [] for an empty string; keep blank lines as spacing.
            for line in simpleSplit(raw_line, font_name, font_size, max_width) or [""]:
                if y < 50:
                    c.showPage()
                    c.setFont(font_name, font_size)  # re-apply the font on the new page
                    y = 750
                c.drawString(left_margin, y, line)
                y -= line_height
        c.save()
        logger.info(f"Report generated: {filename}")
        return filename
    except Exception as e:
        logger.error(f"Report generation error: {e}")
        return None
def visualize_predictions(predictions: Dict[str, float]) -> alt.Chart:
    """Plot classification probabilities as an Altair bar chart.

    Args:
        predictions: Mapping of label name to probability.

    Returns:
        A 500x300 bar chart titled "Prediction Probabilities". It has one bar
        per label and label/probability tooltips.
    """
    frame = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
    # sort=None keeps the labels in data order rather than sorting them.
    x_axis = alt.X("Label:N", sort=None)
    bars = alt.Chart(frame).mark_bar().encode(
        x=x_axis,
        y="Probability:Q",
        tooltip=["Label", "Probability"],
    )
    return bars.properties(title="Prediction Probabilities", width=500, height=300)
############################################################################### | |
# 7) BUILDING THE GRADIO APP # | |
############################################################################### | |
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 AI-Driven Clinical Assistant")
    gr.Markdown("""
**Highlights**:
- **Summarize** clinical text (OpenAI GPT-3.5)
- **Explain** clinical test results and trial outcomes
- **Generate** professional PDF reports
""")
    text_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter clinical text or test results...")
    action = gr.Radio(
        [
            "Summarize",
            "Explain Clinical Results",
            "Generate Report",
        ],
        label="Select an Action",
    )
    # Output filename for the "Generate Report" action. submit_btn.click below
    # always referenced this component, but it was never created, so the app
    # crashed with a NameError at startup.
    report_filename_input = gr.Textbox(label="Report Filename", value="clinical_report.pdf")
    output_text = gr.Textbox(label="Output", lines=8)
    output_file = gr.File(label="Generated File")
    submit_btn = gr.Button("Submit")

    def handle_action(
        action: str,
        txt: str,
        report_fn: str,
    ) -> Tuple[Optional[str], Optional[str]]:
        """Dispatch the selected action on the input text.

        A plain (sync) function on purpose: every branch makes blocking calls
        (OpenAI HTTP requests, PDF writing), which an ``async def`` would run
        directly on the event loop. Gradio runs sync handlers in a worker thread.

        Args:
            action: The selected radio option.
            txt: Text from the input box (may be empty or None).
            report_fn: Requested report filename, used only by "Generate Report".

        Returns:
            ``(message, file_path)``. ``file_path`` is set only when a report was generated.
        """
        try:
            combined_text = (txt or "").strip()
            if action == "Summarize":
                return summarize_text(combined_text), None
            if action == "Explain Clinical Results":
                return explain_clinical_results(combined_text), None
            if action == "Generate Report":
                # The filename is user input: keep only the bare name so it cannot
                # point outside the working directory, and force a .pdf suffix.
                report_name = os.path.basename((report_fn or "").strip()) or "clinical_report.pdf"
                if not report_name.lower().endswith(".pdf"):
                    report_name += ".pdf"
                path = generate_report(combined_text, report_name)
                msg = f"Report generated: {path}" if path else "Report generation failed."
                return msg, path
            return "Invalid action.", None
        except Exception as e:
            logger.error(f"Exception: {e}")
            return f"Error: {str(e)}", None

    submit_btn.click(
        fn=handle_action,
        inputs=[action, text_input, report_filename_input],
        outputs=[output_text, output_file],
    )

# Launch the Gradio interface.
# NOTE(review): Gradio ignores share=True (with a warning) on Hugging Face Spaces;
# it only takes effect for local runs.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)