mgbam commited on
Commit
cd0b15a
·
verified ·
1 Parent(s): 9969a4b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +653 -0
  2. requirements.txt +26 -0
app.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import csv
4
+ import asyncio
5
+ import xml.etree.ElementTree as ET
6
+ from typing import Any, Dict, Optional, Tuple, Union, List
7
+
8
+ import httpx
9
+ import gradio as gr
10
+ import torch
11
+ from dotenv import load_dotenv
12
+ from loguru import logger
13
+ from huggingface_hub import login
14
+ from openai import OpenAI
15
+ from reportlab.pdfgen import canvas
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ AutoModelForSequenceClassification,
19
+ MarianMTModel,
20
+ MarianTokenizer,
21
+ )
22
+ import pandas as pd
23
+ import altair as alt
24
+ import spacy
25
+ import spacy.cli
26
+ import PyPDF2 # For PDF reading
27
+
28
+ # Ensure spaCy model is downloaded
29
+ try:
30
+ nlp = spacy.load("en_core_web_sm")
31
+ except OSError:
32
+ logger.info("Downloading SpaCy 'en_core_web_sm' model...")
33
+ spacy.cli.download("en_core_web_sm")
34
+ nlp = spacy.load("en_core_web_sm")
35
+
36
+ # Logging
37
+ logger.add("error_logs.log", rotation="1 MB", level="ERROR")
38
+
39
+ # Load environment variables
40
+ load_dotenv()
41
+ HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
42
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
43
+ ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")
44
+
45
+ # Basic checks
46
+ if not HUGGINGFACE_TOKEN or not OPENAI_API_KEY:
47
+ logger.error("Missing Hugging Face or OpenAI credentials.")
48
+ raise ValueError("Missing credentials for Hugging Face or OpenAI.")
49
+
50
+ # API endpoints
51
+ PUBMED_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
52
+ PUBMED_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
53
+ EUROPE_PMC_BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
54
+
55
+ # Hugging Face login
56
+ login(HUGGINGFACE_TOKEN)
57
+
58
+ # Initialize OpenAI
59
+ client = OpenAI(api_key=OPENAI_API_KEY)
60
+
61
+ # Device setting
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+ logger.info(f"Using device: {device}")
64
+
65
+ # Model settings
66
+ MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
67
+ try:
68
+ model = AutoModelForSequenceClassification.from_pretrained(
69
+ MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
70
+ ).to(device)
71
+ tokenizer = AutoTokenizer.from_pretrained(
72
+ MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
73
+ )
74
+ except Exception as e:
75
+ logger.error(f"Model load error: {e}")
76
+ raise
77
+
78
+ # Translation model settings
79
+ try:
80
+ translation_model_name = "Helsinki-NLP/opus-mt-en-fr"
81
+ translation_model = MarianMTModel.from_pretrained(
82
+ translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
83
+ ).to(device)
84
+ translation_tokenizer = MarianTokenizer.from_pretrained(
85
+ translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
86
+ )
87
+ except Exception as e:
88
+ logger.error(f"Translation model load error: {e}")
89
+ raise
90
+
91
+ LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
92
+ "English to French": ("en", "fr"),
93
+ "French to English": ("fr", "en"),
94
+ }
95
+
96
+ ### Utility Functions ###
97
+ def safe_json_parse(text: str) -> Union[Dict, None]:
98
+ """Safely parse JSON string into a Python dictionary."""
99
+ try:
100
+ return json.loads(text)
101
+ except json.JSONDecodeError as e:
102
+ logger.error(f"JSON parsing error: {e}")
103
+ return None
104
+
105
+ def parse_pubmed_xml(xml_data: str) -> List[Dict[str, Any]]:
106
+ """Parses PubMed XML data and returns a list of structured articles."""
107
+ root = ET.fromstring(xml_data)
108
+ articles = []
109
+ for article in root.findall(".//PubmedArticle"):
110
+ pmid = article.findtext(".//PMID")
111
+ title = article.findtext(".//ArticleTitle")
112
+ abstract = article.findtext(".//AbstractText")
113
+ journal = article.findtext(".//Journal/Title")
114
+ pub_date_elem = article.find(".//JournalIssue/PubDate")
115
+ pub_date = None
116
+ if pub_date_elem is not None:
117
+ year = pub_date_elem.findtext("Year")
118
+ month = pub_date_elem.findtext("Month")
119
+ day = pub_date_elem.findtext("Day")
120
+ if year and month and day:
121
+ pub_date = f"{year}-{month}-{day}"
122
+ else:
123
+ pub_date = year
124
+ articles.append({
125
+ "PMID": pmid,
126
+ "Title": title,
127
+ "Abstract": abstract,
128
+ "Journal": journal,
129
+ "PublicationDate": pub_date,
130
+ })
131
+ return articles
132
+
133
+ ### Async Functions for Europe PMC ###
134
+ async def fetch_articles_by_nct_id(nct_id: str) -> Dict[str, Any]:
135
+ params = {"query": nct_id, "format": "json"}
136
+ async with httpx.AsyncClient() as client_http:
137
+ try:
138
+ response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
139
+ response.raise_for_status()
140
+ return response.json()
141
+ except Exception as e:
142
+ logger.error(f"Error fetching articles for {nct_id}: {e}")
143
+ return {"error": str(e)}
144
+
145
+ async def fetch_articles_by_query(query_params: str) -> Dict[str, Any]:
146
+ parsed_params = safe_json_parse(query_params)
147
+ if not parsed_params or not isinstance(parsed_params, dict):
148
+ return {"error": "Invalid JSON."}
149
+ query_string = " AND ".join(f"{k}:{v}" for k, v in parsed_params.items())
150
+ params = {"query": query_string, "format": "json"}
151
+ async with httpx.AsyncClient() as client_http:
152
+ try:
153
+ response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
154
+ response.raise_for_status()
155
+ return response.json()
156
+ except Exception as e:
157
+ logger.error(f"Error fetching articles: {e}")
158
+ return {"error": str(e)}
159
+
160
+ ### PubMed Integration ###
161
+ async def fetch_pubmed_by_query(query_params: str) -> Dict[str, Any]:
162
+ parsed_params = safe_json_parse(query_params)
163
+ if not parsed_params or not isinstance(parsed_params, dict):
164
+ return {"error": "Invalid JSON for PubMed."}
165
+
166
+ search_params = {
167
+ "db": "pubmed",
168
+ "retmode": "json",
169
+ "email": ENTREZ_EMAIL,
170
+ "retmax": parsed_params.get("retmax", "10"),
171
+ "term": parsed_params.get("term", ""),
172
+ }
173
+
174
+ async with httpx.AsyncClient() as client_http:
175
+ try:
176
+ search_response = await client_http.get(PUBMED_SEARCH_URL, params=search_params)
177
+ search_response.raise_for_status()
178
+ search_data = search_response.json()
179
+ id_list = search_data.get("esearchresult", {}).get("idlist", [])
180
+ if not id_list:
181
+ return {"result": ""}
182
+
183
+ fetch_params = {
184
+ "db": "pubmed",
185
+ "id": ",".join(id_list),
186
+ "retmode": "xml",
187
+ "email": ENTREZ_EMAIL,
188
+ }
189
+ fetch_response = await client_http.get(PUBMED_FETCH_URL, params=fetch_params)
190
+ fetch_response.raise_for_status()
191
+ return {"result": fetch_response.text}
192
+ except Exception as e:
193
+ logger.error(f"Error fetching PubMed articles: {e}")
194
+ return {"error": str(e)}
195
+
196
+ ### Crossref Integration ###
197
+ async def fetch_crossref_by_query(query_params: str) -> Dict[str, Any]:
198
+ parsed_params = safe_json_parse(query_params)
199
+ if not parsed_params or not isinstance(parsed_params, dict):
200
+ return {"error": "Invalid JSON for Crossref."}
201
+ CROSSREF_API_URL = "https://api.crossref.org/works"
202
+ async with httpx.AsyncClient() as client_http:
203
+ try:
204
+ response = await client_http.get(CROSSREF_API_URL, params=parsed_params)
205
+ response.raise_for_status()
206
+ return response.json()
207
+ except Exception as e:
208
+ logger.error(f"Error fetching Crossref data: {e}")
209
+ return {"error": str(e)}
210
+
211
+ ### Core Functions ###
212
+ def summarize_text(text: str) -> str:
213
+ """Summarize text using OpenAI."""
214
+ if not text.strip():
215
+ return "No text provided for summarization."
216
+ try:
217
+ response = client.chat.completions.create(
218
+ model="gpt-3.5-turbo",
219
+ messages=[{"role": "user", "content": f"Summarize the following clinical data:\n{text}"}],
220
+ max_tokens=200,
221
+ temperature=0.7,
222
+ )
223
+ return response.choices[0].message.content.strip()
224
+ except Exception as e:
225
+ logger.error(f"Summarization Error: {e}")
226
+ return "Summarization failed."
227
+
228
+ def predict_outcome(text: str) -> Union[Dict[str, float], str]:
229
+ """Predict outcomes (classification) using a fine-tuned model."""
230
+ if not text.strip():
231
+ return "No text provided for prediction."
232
+ try:
233
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
234
+ inputs = {k: v.to(device) for k, v in inputs.items()}
235
+ with torch.no_grad():
236
+ outputs = model(**inputs)
237
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
238
+ return {f"Label {i+1}": float(prob.item()) for i, prob in enumerate(probabilities)}
239
+ except Exception as e:
240
+ logger.error(f"Prediction Error: {e}")
241
+ return "Prediction failed."
242
+
243
+ def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
244
+ """Generate a PDF report from the given text."""
245
+ try:
246
+ if not text.strip():
247
+ logger.warning("No text provided for the report.")
248
+ c = canvas.Canvas(filename)
249
+ c.drawString(100, 750, "Clinical Research Report")
250
+ lines = text.split("\n")
251
+ y = 730
252
+ for line in lines:
253
+ if y < 50:
254
+ c.showPage()
255
+ y = 750
256
+ c.drawString(100, y, line)
257
+ y -= 15
258
+ c.save()
259
+ logger.info(f"Report generated: {filename}")
260
+ return filename
261
+ except Exception as e:
262
+ logger.error(f"Report Generation Error: {e}")
263
+ return None
264
+
265
+ def visualize_predictions(predictions: Dict[str, float]) -> Optional[alt.Chart]:
266
+ """Visualize model prediction probabilities using Altair."""
267
+ try:
268
+ data = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
269
+ chart = (
270
+ alt.Chart(data)
271
+ .mark_bar()
272
+ .encode(
273
+ x=alt.X("Label:N", sort=None),
274
+ y="Probability:Q",
275
+ tooltip=["Label", "Probability"],
276
+ )
277
+ .properties(title="Prediction Probabilities", width=500, height=300)
278
+ )
279
+ return chart
280
+ except Exception as e:
281
+ logger.error(f"Visualization Error: {e}")
282
+ return None
283
+
284
+ def translate_text(text: str, translation_option: str) -> str:
285
+ """Translate text between English and French."""
286
+ if not text.strip():
287
+ return "No text provided for translation."
288
+ try:
289
+ if translation_option not in LANGUAGE_MAP:
290
+ return "Unsupported translation option."
291
+ inputs = translation_tokenizer(text, return_tensors="pt", padding=True).to(device)
292
+ translated_tokens = translation_model.generate(**inputs)
293
+ return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
294
+ except Exception as e:
295
+ logger.error(f"Translation Error: {e}")
296
+ return "Translation failed."
297
+
298
+ def perform_named_entity_recognition(text: str) -> str:
299
+ """Perform Named Entity Recognition (NER) using spaCy."""
300
+ if not text.strip():
301
+ return "No text provided for NER."
302
+ try:
303
+ doc = nlp(text)
304
+ entities = [(ent.text, ent.label_) for ent in doc.ents]
305
+ if not entities:
306
+ return "No named entities found."
307
+ return "\n".join(f"{ent_text} -> {ent_label}" for ent_text, ent_label in entities)
308
+ except Exception as e:
309
+ logger.error(f"NER Error: {e}")
310
+ return "Named Entity Recognition failed."
311
+
312
+ ### Enhanced EDA ###
313
+ def perform_enhanced_eda(df: pd.DataFrame) -> Tuple[str, Optional[alt.Chart], Optional[alt.Chart]]:
314
+ """
315
+ Perform a more advanced EDA given a DataFrame:
316
+ - Show dataset info (columns, shape, numeric summary).
317
+ - Generate a correlation heatmap (for numeric columns).
318
+ - Generate distribution plots (histograms) for numeric columns.
319
+ Returns (text_summary, correlation_chart, distribution_chart).
320
+ """
321
+ try:
322
+ # Basic info
323
+ columns_info = f"Columns: {list(df.columns)}"
324
+ shape_info = f"Shape: {df.shape[0]} rows x {df.shape[1]} columns"
325
+
326
+ # Use describe with "include='all'" to show all columns summary
327
+ with pd.option_context("display.max_colwidth", 200, "display.max_rows", None):
328
+ describe_info = df.describe(include="all").to_string()
329
+
330
+ summary_text = (
331
+ f"--- Enhanced EDA Summary ---\n"
332
+ f"{columns_info}\n{shape_info}\n\n"
333
+ f"Summary Statistics:\n{describe_info}\n"
334
+ )
335
+
336
+ # Correlation heatmap
337
+ numeric_cols = df.select_dtypes(include="number")
338
+ corr_chart = None
339
+ if numeric_cols.shape[1] >= 2:
340
+ corr = numeric_cols.corr()
341
+ corr_melted = corr.reset_index().melt(id_vars="index")
342
+ corr_melted.columns = ["Feature1", "Feature2", "Correlation"]
343
+ corr_chart = (
344
+ alt.Chart(corr_melted)
345
+ .mark_rect()
346
+ .encode(
347
+ x="Feature1:O",
348
+ y="Feature2:O",
349
+ color="Correlation:Q",
350
+ tooltip=["Feature1", "Feature2", "Correlation"]
351
+ )
352
+ .properties(width=400, height=400, title="Correlation Heatmap")
353
+ )
354
+
355
+ # Distribution plots (histograms) for numeric columns
356
+ distribution_chart = None
357
+ if numeric_cols.shape[1] >= 1:
358
+ df_long = numeric_cols.melt(var_name='Column', value_name='Value')
359
+ distribution_chart = (
360
+ alt.Chart(df_long)
361
+ .mark_bar()
362
+ .encode(
363
+ alt.X("Value:Q", bin=alt.Bin(maxbins=30)),
364
+ alt.Y('count()'),
365
+ alt.Facet('Column:N', columns=2),
366
+ tooltip=["Value"]
367
+ )
368
+ .properties(
369
+ title='Distribution of Numeric Columns',
370
+ width=300,
371
+ height=200
372
+ )
373
+ .interactive()
374
+ )
375
+
376
+ return summary_text, corr_chart, distribution_chart
377
+
378
+ except Exception as e:
379
+ logger.error(f"Enhanced EDA Error: {e}")
380
+ return f"Enhanced EDA failed: {e}", None, None
381
+
382
+ ### File Handling ###
383
+ def read_uploaded_file(uploaded_file: Optional[gr.File]) -> str:
384
+ """
385
+ Reads the content of an uploaded file (txt, csv, xls, xlsx, pdf).
386
+ Returns the extracted text or CSV-like content.
387
+ """
388
+ if uploaded_file is None:
389
+ return ""
390
+
391
+ file_name = uploaded_file.name
392
+ file_ext = os.path.splitext(file_name)[1].lower()
393
+
394
+ try:
395
+ # For text
396
+ if file_ext == ".txt":
397
+ return uploaded_file.read().decode("utf-8")
398
+
399
+ # For CSV
400
+ elif file_ext == ".csv":
401
+ return uploaded_file.read().decode("utf-8")
402
+
403
+ # For Excel
404
+ elif file_ext in [".xls", ".xlsx"]:
405
+ # We'll just return empty here and parse it later into a DataFrame
406
+ # because we can read the binary directly into pd.read_excel().
407
+ # Or store as bytes for later use in EDA.
408
+ return "EXCEL_FILE_PLACEHOLDER" # We'll handle it differently in EDA step
409
+
410
+ # For PDF
411
+ elif file_ext == ".pdf":
412
+ pdf_reader = PyPDF2.PdfReader(uploaded_file)
413
+ text_content = []
414
+ for page in pdf_reader.pages:
415
+ text_content.append(page.extract_text())
416
+ return "\n".join(text_content)
417
+
418
+ else:
419
+ return f"Unsupported file format: {file_ext}"
420
+ except Exception as e:
421
+ logger.error(f"File read error: {e}")
422
+ return f"Error reading file: {e}"
423
+
424
+ def parse_excel_file(uploaded_file) -> pd.DataFrame:
425
+ """
426
+ Parse an Excel file into a pandas DataFrame.
427
+ We assume the user wants the first sheet or we can guess.
428
+ """
429
+ try:
430
+ # For Excel, we can do:
431
+ df = pd.read_excel(uploaded_file, engine="openpyxl")
432
+ return df
433
+ except Exception as e:
434
+ logger.error(f"Excel parsing error: {e}")
435
+ raise
436
+
437
+ def parse_csv_content(csv_content: str) -> pd.DataFrame:
438
+ """
439
+ Attempt to parse CSV content with both utf-8 and utf-8-sig to handle BOM issues.
440
+ """
441
+ from io import StringIO
442
+ errors = []
443
+ for encoding_try in ["utf-8", "utf-8-sig"]:
444
+ try:
445
+ df = pd.read_csv(StringIO(csv_content), encoding=encoding_try)
446
+ return df
447
+ except Exception as e:
448
+ errors.append(f"Encoding {encoding_try} failed: {e}")
449
+ error_msg = "Could not parse CSV content.\n" + "\n".join(errors)
450
+ logger.error(error_msg)
451
+ raise ValueError(error_msg)
452
+
453
+ ### Gradio Interface ###
454
+ with gr.Blocks() as demo:
455
+ gr.Markdown("# ✨ Advanced Clinical Research Assistant with Enhanced EDA ✨")
456
+ gr.Markdown("""
457
+ Welcome to the **Enhanced** AI-Powered Clinical Assistant!
458
+ - **Summarize** large blocks of clinical text.
459
+ - **Predict** outcomes with a fine-tuned model.
460
+ - **Translate** text between English & French.
461
+ - **Perform Named Entity Recognition** with spaCy.
462
+ - **Fetch** from PubMed, Crossref, Europe PMC.
463
+ - **Generate** professional PDF reports.
464
+ - **Perform Enhanced EDA** on CSV/Excel data with correlation heatmaps & distribution plots.
465
+ """)
466
+
467
+ # Inputs
468
+ with gr.Row():
469
+ text_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter clinical text or query...")
470
+ file_input = gr.File(
471
+ label="Upload File (txt/csv/xls/xlsx/pdf)",
472
+ file_types=[".txt", ".csv", ".xls", ".xlsx", ".pdf"]
473
+ )
474
+
475
+ action = gr.Radio(
476
+ [
477
+ "Summarize",
478
+ "Predict Outcome",
479
+ "Generate Report",
480
+ "Translate",
481
+ "Perform Named Entity Recognition",
482
+ "Perform Enhanced EDA",
483
+ "Fetch Clinical Studies",
484
+ "Fetch PubMed Articles (Legacy)",
485
+ "Fetch PubMed by Query",
486
+ "Fetch Crossref by Query",
487
+ ],
488
+ label="Select an Action",
489
+ )
490
+ translation_option = gr.Dropdown(
491
+ choices=list(LANGUAGE_MAP.keys()),
492
+ label="Translation Option",
493
+ value="English to French"
494
+ )
495
+ query_params_input = gr.Textbox(
496
+ label="Query Parameters (JSON Format)",
497
+ placeholder='{"term": "cancer", "retmax": "5"}'
498
+ )
499
+ nct_id_input = gr.Textbox(label="NCT ID for Article Search")
500
+ report_filename_input = gr.Textbox(
501
+ label="Report Filename",
502
+ placeholder="clinical_report.pdf",
503
+ value="clinical_report.pdf"
504
+ )
505
+ export_format = gr.Dropdown(["None", "CSV", "JSON"], label="Export Format")
506
+
507
+ # Outputs
508
+ output_text = gr.Textbox(label="Output", lines=10)
509
+
510
+ with gr.Row():
511
+ output_chart = gr.Plot(label="Visualization 1")
512
+ output_chart2 = gr.Plot(label="Visualization 2")
513
+
514
+ output_file = gr.File(label="Generated File")
515
+
516
+ submit_button = gr.Button("Submit")
517
+
518
+ # Async function for handling actions
519
+ async def handle_action(
520
+ action: str,
521
+ text: str,
522
+ file_up: gr.File,
523
+ translation_opt: str,
524
+ query_params: str,
525
+ nct_id: str,
526
+ report_filename: str,
527
+ export_format: str
528
+ ) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
529
+
530
+ # Read the uploaded file
531
+ file_content = read_uploaded_file(file_up)
532
+ combined_text = (text + "\n" + file_content).strip() if file_content else text
533
+
534
+ # Branch by action
535
+ if action == "Summarize":
536
+ return summarize_text(combined_text), None, None, None
537
+
538
+ elif action == "Predict Outcome":
539
+ predictions = predict_outcome(combined_text)
540
+ if isinstance(predictions, dict):
541
+ chart = visualize_predictions(predictions)
542
+ return json.dumps(predictions, indent=2), chart, None, None
543
+ return predictions, None, None, None
544
+
545
+ elif action == "Generate Report":
546
+ file_path = generate_report(combined_text, filename=report_filename)
547
+ msg = f"Report generated: {file_path}" if file_path else "Report generation failed."
548
+ return msg, None, None, file_path
549
+
550
+ elif action == "Translate":
551
+ return translate_text(combined_text, translation_opt), None, None, None
552
+
553
+ elif action == "Perform Named Entity Recognition":
554
+ ner_result = perform_named_entity_recognition(combined_text)
555
+ return ner_result, None, None, None
556
+
557
+ elif action == "Perform Enhanced EDA":
558
+ # We expect the user to either upload a CSV or Excel, or paste CSV content.
559
+ if file_up is None and not combined_text:
560
+ return "No data provided for EDA.", None, None, None
561
+
562
+ # If Excel was uploaded
563
+ if file_up and file_up.name.lower().endswith((".xls", ".xlsx")):
564
+ try:
565
+ df_excel = parse_excel_file(file_up)
566
+ eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_excel)
567
+ return eda_summary, corr_chart, dist_chart, None
568
+ except Exception as e:
569
+ return f"Excel EDA failed: {e}", None, None, None
570
+
571
+ # If CSV was uploaded
572
+ if file_up and file_up.name.lower().endswith(".csv"):
573
+ try:
574
+ df_csv = parse_csv_content(file_content)
575
+ eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
576
+ return eda_summary, corr_chart, dist_chart, None
577
+ except Exception as e:
578
+ return f"CSV EDA failed: {e}", None, None, None
579
+
580
+ # If user just pasted CSV content (no file)
581
+ if not file_up and "," in combined_text:
582
+ try:
583
+ df_csv = parse_csv_content(combined_text)
584
+ eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
585
+ return eda_summary, corr_chart, dist_chart, None
586
+ except Exception as e:
587
+ return f"CSV EDA failed: {e}", None, None, None
588
+
589
+ # Otherwise, not supported
590
+ return "No valid CSV/Excel data found for EDA.", None, None, None
591
+
592
+ elif action == "Fetch Clinical Studies":
593
+ if nct_id:
594
+ result = await fetch_articles_by_nct_id(nct_id)
595
+ elif query_params:
596
+ result = await fetch_articles_by_query(query_params)
597
+ else:
598
+ return "Provide either an NCT ID or valid query parameters.", None, None, None
599
+
600
+ articles = result.get("resultList", {}).get("result", [])
601
+ if not articles:
602
+ return "No articles found.", None, None, None
603
+
604
+ formatted_results = "\n\n".join(
605
+ f"Title: {a.get('title')}\nJournal: {a.get('journalTitle')} ({a.get('pubYear')})"
606
+ for a in articles
607
+ )
608
+ return formatted_results, None, None, None
609
+
610
+ elif action in ["Fetch PubMed Articles (Legacy)", "Fetch PubMed by Query"]:
611
+ pubmed_result = await fetch_pubmed_by_query(query_params)
612
+ xml_data = pubmed_result.get("result")
613
+ if xml_data:
614
+ articles = parse_pubmed_xml(xml_data)
615
+ if not articles:
616
+ return "No articles found.", None, None, None
617
+ formatted = "\n\n".join(
618
+ f"{a['Title']} - {a['Journal']} ({a['PublicationDate']})"
619
+ for a in articles if a['Title']
620
+ )
621
+ return formatted if formatted else "No articles found.", None, None, None
622
+ return "No articles found or error fetching data.", None, None, None
623
+
624
+ elif action == "Fetch Crossref by Query":
625
+ crossref_result = await fetch_crossref_by_query(query_params)
626
+ items = crossref_result.get("message", {}).get("items", [])
627
+ if not items:
628
+ return "No results found.", None, None, None
629
+ formatted = "\n\n".join(
630
+ f"Title: {item.get('title', ['No title'])[0]}, DOI: {item.get('DOI')}"
631
+ for item in items
632
+ )
633
+ return formatted, None, None, None
634
+
635
+ return "Invalid action.", None, None, None
636
+
637
+ submit_button.click(
638
+ handle_action,
639
+ inputs=[
640
+ action,
641
+ text_input,
642
+ file_input,
643
+ translation_option,
644
+ query_params_input,
645
+ nct_id_input,
646
+ report_filename_input,
647
+ export_format,
648
+ ],
649
+ outputs=[output_text, output_chart, output_chart2, output_file],
650
+ )
651
+
652
+ # Launch the Gradio app
653
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ gradio
3
+ openai>=0.27.8
4
+ torch>=2.0.0
5
+ transformers>=4.33.0
6
+ huggingface-hub>=0.16.0
7
+ python-dotenv>=1.0.0
8
+ reportlab>=3.6.0
9
+ matplotlib>=3.7.1
10
+ pandas>=2.0.3
11
+ altair>=4.2.2
12
+ loguru>=0.7.0
13
+ spacy>=3.6.0
14
+ PyPDF2>=3.0.0
15
+ pdfplumber>=0.9.0
16
+ Pillow>=10.0.0
17
+ sentencepiece
18
+ sacremoses>=0.0.53
19
+ httpx
20
+ numpy
21
+ reportlab
22
+ requests
23
+ openpyxl
24
+
25
+
26
+