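"""Advanced Clinical Research Assistant (Gradio app).

Summarizes clinical text with OpenAI, classifies outcomes with a fine-tuned
BERT model, translates English <-> French with MarianMT, runs spaCy NER,
fetches literature from PubMed / Europe PMC / Crossref, generates PDF
reports, and performs enhanced EDA on CSV/Excel uploads.
"""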
import os
import io
import json
import xml.etree.ElementTree as ET
from typing import Any, Dict, Optional, Tuple, Union, List

import httpx
import gradio as gr
import torch
from dotenv import load_dotenv
from loguru import logger
from huggingface_hub import login
from openai import OpenAI
from reportlab.pdfgen import canvas
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
)
import pandas as pd
import altair as alt
import spacy
import spacy.cli
import PyPDF2
# Ensure the spaCy model is available, downloading it on first run
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    logger.info("Downloading spaCy 'en_core_web_sm' model...")
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# Logging
logger.add("error_logs.log", rotation="1 MB", level="ERROR")

# Load environment variables
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")

# Basic checks
if not HUGGINGFACE_TOKEN or not OPENAI_API_KEY:
    logger.error("Missing Hugging Face or OpenAI credentials.")
    raise ValueError("Missing credentials for Hugging Face or OpenAI.")

# API endpoints
PUBMED_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
PUBMED_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
EUROPE_PMC_BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"

# Log in to Hugging Face
login(HUGGINGFACE_TOKEN)

# Initialize OpenAI
client = OpenAI(api_key=OPENAI_API_KEY)

# Device setting
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Classification model settings
# Note: `token` replaces the deprecated `use_auth_token` argument in recent
# versions of transformers.
MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
try:
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, token=HUGGINGFACE_TOKEN
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HUGGINGFACE_TOKEN)
except Exception as e:
    logger.error(f"Model load error: {e}")
    raise
# Translation model settings: one MarianMT checkpoint per direction, so both
# LANGUAGE_MAP options actually use the right model.
TRANSLATION_MODEL_NAMES: Dict[Tuple[str, str], str] = {
    ("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
    ("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
}
translation_models: Dict[Tuple[str, str], MarianMTModel] = {}
translation_tokenizers: Dict[Tuple[str, str], MarianTokenizer] = {}
try:
    for direction, name in TRANSLATION_MODEL_NAMES.items():
        translation_models[direction] = MarianMTModel.from_pretrained(
            name, token=HUGGINGFACE_TOKEN
        ).to(device)
        translation_tokenizers[direction] = MarianTokenizer.from_pretrained(
            name, token=HUGGINGFACE_TOKEN
        )
except Exception as e:
    logger.error(f"Translation model load error: {e}")
    raise

LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
    "English to French": ("en", "fr"),
    "French to English": ("fr", "en"),
}
###################################################
#                      UTILS                      #
###################################################
def safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
    """Safely parse a JSON string into a Python dictionary."""
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        return None
def parse_pubmed_xml(xml_data: str) -> List[Dict[str, Any]]:
    """Parse PubMed XML data and return a list of structured articles."""
    root = ET.fromstring(xml_data)
    articles = []
    for article in root.findall(".//PubmedArticle"):
        pmid = article.findtext(".//PMID")
        title = article.findtext(".//ArticleTitle")
        abstract = article.findtext(".//AbstractText")
        journal = article.findtext(".//Journal/Title")
        pub_date_elem = article.find(".//JournalIssue/PubDate")
        pub_date = None
        if pub_date_elem is not None:
            year = pub_date_elem.findtext("Year")
            month = pub_date_elem.findtext("Month")
            day = pub_date_elem.findtext("Day")
            if year and month and day:
                pub_date = f"{year}-{month}-{day}"
            else:
                pub_date = year
        articles.append({
            "PMID": pmid,
            "Title": title,
            "Abstract": abstract,
            "Journal": journal,
            "PublicationDate": pub_date,
        })
    return articles
###################################################
#                  ASYNC FETCHES                  #
###################################################
async def fetch_articles_by_nct_id(nct_id: str) -> Dict[str, Any]:
    """Search Europe PMC for articles mentioning the given NCT ID."""
    params = {"query": nct_id, "format": "json"}
    async with httpx.AsyncClient() as client_http:
        try:
            response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching articles for {nct_id}: {e}")
            return {"error": str(e)}
async def fetch_articles_by_query(query_params: str) -> Dict[str, Any]:
    """Search Europe PMC using a JSON string of field:value pairs."""
    parsed_params = safe_json_parse(query_params)
    if not parsed_params or not isinstance(parsed_params, dict):
        return {"error": "Invalid JSON."}
    query_string = " AND ".join(f"{k}:{v}" for k, v in parsed_params.items())
    params = {"query": query_string, "format": "json"}
    async with httpx.AsyncClient() as client_http:
        try:
            response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching articles: {e}")
            return {"error": str(e)}
async def fetch_pubmed_by_query(query_params: str) -> Dict[str, Any]:
    """Two-step PubMed fetch: ESearch for PMIDs, then EFetch for article XML."""
    parsed_params = safe_json_parse(query_params)
    if not parsed_params or not isinstance(parsed_params, dict):
        return {"error": "Invalid JSON for PubMed."}
    search_params = {
        "db": "pubmed",
        "retmode": "json",
        "email": ENTREZ_EMAIL,
        "retmax": parsed_params.get("retmax", "10"),
        "term": parsed_params.get("term", ""),
    }
    async with httpx.AsyncClient() as client_http:
        try:
            search_response = await client_http.get(PUBMED_SEARCH_URL, params=search_params)
            search_response.raise_for_status()
            search_data = search_response.json()
            id_list = search_data.get("esearchresult", {}).get("idlist", [])
            if not id_list:
                return {"result": ""}
            fetch_params = {
                "db": "pubmed",
                "id": ",".join(id_list),
                "retmode": "xml",
                "email": ENTREZ_EMAIL,
            }
            fetch_response = await client_http.get(PUBMED_FETCH_URL, params=fetch_params)
            fetch_response.raise_for_status()
            return {"result": fetch_response.text}
        except Exception as e:
            logger.error(f"Error fetching PubMed articles: {e}")
            return {"error": str(e)}
async def fetch_crossref_by_query(query_params: str) -> Dict[str, Any]:
    """Query the Crossref works API with a JSON string of request parameters."""
    parsed_params = safe_json_parse(query_params)
    if not parsed_params or not isinstance(parsed_params, dict):
        return {"error": "Invalid JSON for Crossref."}
    CROSSREF_API_URL = "https://api.crossref.org/works"
    async with httpx.AsyncClient() as client_http:
        try:
            response = await client_http.get(CROSSREF_API_URL, params=parsed_params)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching Crossref data: {e}")
            return {"error": str(e)}
###################################################
#                    CORE LOGIC                   #
###################################################
def summarize_text(text: str) -> str:
    """Summarize text using OpenAI."""
    if not text.strip():
        return "No text provided for summarization."
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": f"Summarize the following clinical data:\n{text}"}],
            max_tokens=200,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Summarization Error: {e}")
        return "Summarization failed."
def predict_outcome(text: str) -> Union[Dict[str, float], str]:
    """Predict outcomes (classification) using the fine-tuned model."""
    if not text.strip():
        return "No text provided for prediction."
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        return {f"Label {i+1}": float(prob.item()) for i, prob in enumerate(probabilities)}
    except Exception as e:
        logger.error(f"Prediction Error: {e}")
        return "Prediction failed."
def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
    """Generate a PDF report from the given text."""
    try:
        if not text.strip():
            logger.warning("No text provided for the report.")
        c = canvas.Canvas(filename)
        c.drawString(100, 750, "Clinical Research Report")
        lines = text.split("\n")
        y = 730
        for line in lines:
            # Start a new page when the cursor nears the bottom margin
            if y < 50:
                c.showPage()
                y = 750
            c.drawString(100, y, line)
            y -= 15
        c.save()
        logger.info(f"Report generated: {filename}")
        return filename
    except Exception as e:
        logger.error(f"Report Generation Error: {e}")
        return None
def visualize_predictions(predictions: Dict[str, float]) -> Optional[alt.Chart]:
    """Visualize model prediction probabilities using Altair."""
    try:
        data = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
        chart = (
            alt.Chart(data)
            .mark_bar()
            .encode(
                x=alt.X("Label:N", sort=None),
                y="Probability:Q",
                tooltip=["Label", "Probability"],
            )
            .properties(title="Prediction Probabilities", width=500, height=300)
        )
        return chart
    except Exception as e:
        logger.error(f"Visualization Error: {e}")
        return None
def translate_text(text: str, translation_option: str) -> str:
    """Translate text between English and French."""
    if not text.strip():
        return "No text provided for translation."
    if translation_option not in LANGUAGE_MAP:
        return "Unsupported translation option."
    try:
        # Pick the model/tokenizer pair matching the requested direction
        direction = LANGUAGE_MAP[translation_option]
        mt_tokenizer = translation_tokenizers[direction]
        mt_model = translation_models[direction]
        inputs = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
        translated_tokens = mt_model.generate(**inputs)
        return mt_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    except Exception as e:
        logger.error(f"Translation Error: {e}")
        return "Translation failed."
def perform_named_entity_recognition(text: str) -> str:
    """Perform Named Entity Recognition (NER) using spaCy."""
    if not text.strip():
        return "No text provided for NER."
    try:
        doc = nlp(text)
        entities = [(ent.text, ent.label_) for ent in doc.ents]
        if not entities:
            return "No named entities found."
        return "\n".join(f"{ent_text} -> {ent_label}" for ent_text, ent_label in entities)
    except Exception as e:
        logger.error(f"NER Error: {e}")
        return "Named Entity Recognition failed."
###################################################
#                  ENHANCED EDA                   #
###################################################
def perform_enhanced_eda(df: pd.DataFrame) -> Tuple[str, Optional[alt.Chart], Optional[alt.Chart]]:
    """
    Show columns, shape, numeric summary, a correlation heatmap, and
    distribution histograms.
    Returns (text_summary, correlation_chart, distribution_chart).
    """
    try:
        columns_info = f"Columns: {list(df.columns)}"
        shape_info = f"Shape: {df.shape[0]} rows x {df.shape[1]} columns"
        with pd.option_context("display.max_colwidth", 200, "display.max_rows", None):
            describe_info = df.describe(include="all").to_string()
        summary_text = (
            f"--- Enhanced EDA Summary ---\n"
            f"{columns_info}\n{shape_info}\n\n"
            f"Summary Statistics:\n{describe_info}\n"
        )
        numeric_cols = df.select_dtypes(include="number")
        corr_chart = None
        if numeric_cols.shape[1] >= 2:
            corr = numeric_cols.corr()
            corr_melted = corr.reset_index().melt(id_vars="index")
            corr_melted.columns = ["Feature1", "Feature2", "Correlation"]
            corr_chart = (
                alt.Chart(corr_melted)
                .mark_rect()
                .encode(
                    x="Feature1:O",
                    y="Feature2:O",
                    color="Correlation:Q",
                    tooltip=["Feature1", "Feature2", "Correlation"],
                )
                .properties(width=400, height=400, title="Correlation Heatmap")
            )
        distribution_chart = None
        if numeric_cols.shape[1] >= 1:
            df_long = numeric_cols.melt(var_name="Column", value_name="Value")
            distribution_chart = (
                alt.Chart(df_long)
                .mark_bar()
                .encode(
                    alt.X("Value:Q", bin=alt.Bin(maxbins=30)),
                    alt.Y("count()"),
                    alt.Facet("Column:N", columns=2),
                    tooltip=["Value"],
                )
                .properties(
                    title="Distribution of Numeric Columns",
                    width=300,
                    height=200,
                )
                .interactive()
            )
        return summary_text, corr_chart, distribution_chart
    except Exception as e:
        logger.error(f"Enhanced EDA Error: {e}")
        return f"Enhanced EDA failed: {e}", None, None
###################################################
#                  FILE PARSING                   #
###################################################
def parse_text_file(uploaded_file: gr.File) -> str:
    """Read a .txt upload as UTF-8 text via its temp-file path."""
    with open(uploaded_file.name, "r", encoding="utf-8", errors="replace") as f:
        return f.read()
def parse_csv_file(uploaded_file: gr.File) -> pd.DataFrame:
    """
    Read a CSV upload, tolerating a UTF-8 BOM by decoding with 'utf-8-sig'
    (which also handles plain UTF-8).
    """
    try:
        with open(uploaded_file.name, "r", encoding="utf-8-sig", errors="replace") as f:
            return pd.read_csv(f)
    except Exception as e:
        raise ValueError(f"CSV parse error: {e}")
def parse_excel_file(uploaded_file: gr.File) -> pd.DataFrame:
    """
    Parse an Excel file into a pandas DataFrame.
    1) If the path exists, read directly from the path.
    2) Else read from uploaded_file.file (in-memory) in binary mode.
    """
    excel_path = uploaded_file.name
    # Try the local path first
    if os.path.isfile(excel_path):
        return pd.read_excel(excel_path, engine="openpyxl")
    # Fall back to reading raw bytes from uploaded_file.file
    try:
        excel_bytes = uploaded_file.file.read()
        return pd.read_excel(io.BytesIO(excel_bytes), engine="openpyxl")
    except Exception as e:
        raise ValueError(f"Excel parse error: {e}")
def parse_pdf_file(uploaded_file: gr.File) -> str:
    """Read a PDF with PyPDF2, extracting text from each page."""
    try:
        pdf_reader = PyPDF2.PdfReader(uploaded_file.name)
        # extract_text() can return None for image-only pages
        text_content = [page.extract_text() or "" for page in pdf_reader.pages]
        return "\n".join(text_content)
    except Exception as e:
        logger.error(f"PDF parse error: {e}")
        return f"Error reading PDF file: {e}"
###################################################
#                GRADIO INTERFACE                 #
###################################################
with gr.Blocks() as demo:
    gr.Markdown("# ✨ Advanced Clinical Research Assistant with Enhanced EDA ✨")
    gr.Markdown("""
    Welcome to the **Enhanced** AI-Powered Clinical Assistant!
    - **Summarize** large blocks of clinical text.
    - **Predict** outcomes with a fine-tuned model.
    - **Translate** text (English ↔ French).
    - **Perform Named Entity Recognition** (spaCy).
    - **Fetch** from PubMed, Crossref, Europe PMC.
    - **Generate** professional PDF reports.
    - **Perform Enhanced EDA** on CSV/Excel data (correlation heatmaps + distribution plots).
    """)
    # Inputs
    with gr.Row():
        text_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter clinical text or query...")
        # We'll rely on .name and .file for the path and file handle
        file_input = gr.File(
            label="Upload File (txt/csv/xls/xlsx/pdf)",
            file_types=[".txt", ".csv", ".xls", ".xlsx", ".pdf"]
        )
    action = gr.Radio(
        [
            "Summarize",
            "Predict Outcome",
            "Generate Report",
            "Translate",
            "Perform Named Entity Recognition",
            "Perform Enhanced EDA",
            "Fetch Clinical Studies",
            "Fetch PubMed Articles (Legacy)",
            "Fetch PubMed by Query",
            "Fetch Crossref by Query",
        ],
        label="Select an Action",
    )
    translation_option = gr.Dropdown(
        choices=list(LANGUAGE_MAP.keys()),
        label="Translation Option",
        value="English to French"
    )
    query_params_input = gr.Textbox(
        label="Query Parameters (JSON Format)",
        placeholder='{"term": "cancer", "retmax": "5"}'
    )
    nct_id_input = gr.Textbox(label="NCT ID for Article Search")
    report_filename_input = gr.Textbox(
        label="Report Filename",
        placeholder="clinical_report.pdf",
        value="clinical_report.pdf"
    )
    export_format = gr.Dropdown(["None", "CSV", "JSON"], label="Export Format")
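    # NOTE: export_format is passed to the handler but not yet acted on;
    # the "CSV"/"JSON" choices are placeholders for future export wiring.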
    # Outputs
    output_text = gr.Textbox(label="Output", lines=10)
    with gr.Row():
        output_chart = gr.Plot(label="Visualization 1")
        output_chart2 = gr.Plot(label="Visualization 2")
    output_file = gr.File(label="Generated File")
    submit_button = gr.Button("Submit")
    ################################################################
    #                    MAIN HANDLER FUNCTION                     #
    ################################################################
    async def handle_action(
        action: str,
        text: str,
        file_up: gr.File,
        translation_opt: str,
        query_params: str,
        nct_id: str,
        report_filename: str,
        export_format: str
    ) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
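        """
        Dispatch the selected UI action and return a 4-tuple matching the
        Gradio outputs: (output_text, chart_1, chart_2, generated_file_path).
        """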
        # 1) Start with user-provided text
        combined_text = text.strip()

        # 2) If the user uploaded a file, parse it based on extension
        if file_up is not None:
            file_ext = os.path.splitext(file_up.name)[1].lower()
            if file_ext == ".txt":
                file_text = parse_text_file(file_up)
                combined_text = (combined_text + "\n" + file_text).strip()
            elif file_ext == ".csv":
                # CSV is parsed into a DataFrame for EDA below; for text-based
                # actions (Summarize, etc.) its raw text is merged per action.
                pass
            elif file_ext in [".xls", ".xlsx"]:
                # Excel parsing is handled in the EDA step if needed
                pass
            elif file_ext == ".pdf":
                file_text = parse_pdf_file(file_up)
                combined_text = (combined_text + "\n" + file_text).strip()

        ### ACTIONS ###
        if action == "Summarize":
            if file_up and file_up.name.endswith(".csv"):
                # Merge the CSV's raw text into combined_text so it can be
                # summarized alongside any typed input
                try:
                    df_csv = parse_csv_file(file_up)
                    csv_as_text = df_csv.to_csv(index=False)
                    combined_text = (combined_text + "\n" + csv_as_text).strip()
                except Exception as e:
                    return f"CSV parse error for Summarize: {e}", None, None, None
            return summarize_text(combined_text), None, None, None
        elif action == "Predict Outcome":
            return _action_predict_outcome(combined_text, file_up)

        elif action == "Generate Report":
            # Add CSV content if needed
            if file_up and file_up.name.endswith(".csv"):
                try:
                    df_csv = parse_csv_file(file_up)
                    combined_text += "\n" + df_csv.to_csv(index=False)
                except Exception as e:
                    logger.error(f"Error reading CSV for report: {e}")
            file_path = generate_report(combined_text, filename=report_filename)
            msg = f"Report generated: {file_path}" if file_path else "Report generation failed."
            return msg, None, None, file_path

        elif action == "Translate":
            # Include CSV content in the translation input, if present
            if file_up and file_up.name.endswith(".csv"):
                try:
                    df_csv = parse_csv_file(file_up)
                    combined_text += "\n" + df_csv.to_csv(index=False)
                except Exception as e:
                    return f"CSV parse error for Translate: {e}", None, None, None
            translated = translate_text(combined_text, translation_opt)
            return translated, None, None, None

        elif action == "Perform Named Entity Recognition":
            # Merge CSV as text so NER can run over it too
            if file_up and file_up.name.endswith(".csv"):
                try:
                    df_csv = parse_csv_file(file_up)
                    combined_text += "\n" + df_csv.to_csv(index=False)
                except Exception as e:
                    return f"CSV parse error for NER: {e}", None, None, None
            ner_result = perform_named_entity_recognition(combined_text)
            return ner_result, None, None, None

        elif action == "Perform Enhanced EDA":
            return await _action_eda(combined_text, file_up, text)
        elif action == "Fetch Clinical Studies":
            if nct_id:
                result = await fetch_articles_by_nct_id(nct_id)
            elif query_params:
                result = await fetch_articles_by_query(query_params)
            else:
                return "Provide either an NCT ID or valid query parameters.", None, None, None
            # Surface fetch errors instead of masking them as empty results
            if "error" in result:
                return f"Fetch error: {result['error']}", None, None, None
            articles = result.get("resultList", {}).get("result", [])
            if not articles:
                return "No articles found.", None, None, None
            formatted_results = "\n\n".join(
                f"Title: {a.get('title')}\nJournal: {a.get('journalTitle')} ({a.get('pubYear')})"
                for a in articles
            )
            return formatted_results, None, None, None

        elif action in ["Fetch PubMed Articles (Legacy)", "Fetch PubMed by Query"]:
            pubmed_result = await fetch_pubmed_by_query(query_params)
            if "error" in pubmed_result:
                return f"Fetch error: {pubmed_result['error']}", None, None, None
            xml_data = pubmed_result.get("result")
            if xml_data:
                articles = parse_pubmed_xml(xml_data)
                if not articles:
                    return "No articles found.", None, None, None
                formatted = "\n\n".join(
                    f"{a['Title']} - {a['Journal']} ({a['PublicationDate']})"
                    for a in articles if a['Title']
                )
                return formatted if formatted else "No articles found.", None, None, None
            return "No articles found.", None, None, None

        elif action == "Fetch Crossref by Query":
            crossref_result = await fetch_crossref_by_query(query_params)
            if "error" in crossref_result:
                return f"Fetch error: {crossref_result['error']}", None, None, None
            items = crossref_result.get("message", {}).get("items", [])
            if not items:
                return "No results found.", None, None, None
            formatted = "\n\n".join(
                f"Title: {item.get('title', ['No title'])[0]}, DOI: {item.get('DOI')}"
                for item in items
            )
            return formatted, None, None, None

        return "Invalid action.", None, None, None
    def _action_predict_outcome(combined_text: str, file_up: gr.File) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
        """Run classification, optionally merging an uploaded CSV into the input text."""
        if file_up and file_up.name.endswith(".csv"):
            try:
                df_csv = parse_csv_file(file_up)
                # Merge CSV content into the text to be classified
                combined_text_local = combined_text + "\n" + df_csv.to_csv(index=False)
            except Exception as e:
                return f"CSV parse error for Predict Outcome: {e}", None, None, None
        else:
            combined_text_local = combined_text
        predictions = predict_outcome(combined_text_local)
        if isinstance(predictions, dict):
            chart = visualize_predictions(predictions)
            return json.dumps(predictions, indent=2), chart, None, None
        return predictions, None, None, None
    async def _action_eda(combined_text: str, file_up: Optional[gr.File], raw_text: str) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
        """
        Perform Enhanced EDA on a CSV or Excel file if uploaded.
        If .csv is present, parse as CSV; if .xls/.xlsx, parse as Excel.
        """
        # Require either a file or some data in the text box
        if not file_up and not raw_text.strip():
            return "No data provided for EDA.", None, None, None
        if file_up:
            file_ext = os.path.splitext(file_up.name)[1].lower()
            if file_ext == ".csv":
                try:
                    df_csv = parse_csv_file(file_up)
                    eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
                    return eda_summary, corr_chart, dist_chart, None
                except Exception as e:
                    return f"CSV EDA failed: {e}", None, None, None
            elif file_ext in [".xls", ".xlsx"]:
                try:
                    df_excel = parse_excel_file(file_up)
                    eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_excel)
                    return eda_summary, corr_chart, dist_chart, None
                except Exception as e:
                    return f"Excel EDA failed: {e}", None, None, None
            else:
                # EDA is not supported for PDF or .txt uploads in this example
                return "No valid CSV/Excel data found for EDA.", None, None, None
        else:
            # With no file, the user may have pasted CSV into the text box
            if "," in raw_text:
                try:
                    df_csv = pd.read_csv(io.StringIO(raw_text))
                    eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
                    return eda_summary, corr_chart, dist_chart, None
                except Exception as e:
                    return f"EDA parse error for pasted CSV: {e}", None, None, None
            return "No valid CSV/Excel data found for EDA.", None, None, None
    submit_button.click(
        handle_action,
        inputs=[
            action,
            text_input,
            file_input,
            translation_option,
            query_params_input,
            nct_id_input,
            report_filename_input,
            export_format,
        ],
        outputs=[
            output_text,
            output_chart,
            output_chart2,
            output_file,
        ],
    )
# share=True is unnecessary when hosting on Hugging Face Spaces,
# which already serves the app publicly.
demo.launch(server_name="0.0.0.0", server_port=7860)