mgbam committed on
Commit
4ec8667
·
verified ·
1 Parent(s): 0fc08b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +785 -539
app.py CHANGED
@@ -1,59 +1,67 @@
1
- import streamlit as st
2
- import pandas as pd
3
- from typing import Dict, List, Optional, Any
4
- from pydantic import BaseModel, Field
5
  import base64
6
  import io
 
 
 
 
 
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
- from abc import ABC, abstractmethod
 
 
10
  from sklearn.model_selection import train_test_split
11
  from sklearn.linear_model import LogisticRegression
12
  from sklearn.metrics import accuracy_score
 
 
 
13
  from statsmodels.tsa.seasonal import seasonal_decompose
14
  from statsmodels.tsa.stattools import adfuller
 
 
 
 
15
  from langchain.prompts import PromptTemplate
16
  from groq import Groq
17
- import os
18
- import numpy as np
19
- from scipy.stats import ttest_ind, f_oneway
20
- import json
21
- from Bio import Entrez
22
- from sklearn.feature_extraction.text import TfidfVectorizer
23
- from sklearn.metrics.pairwise import cosine_similarity
24
 
# Module-level Groq client shared by the app; the API key is read from the
# GROQ_API_KEY environment variable (None is passed through when it is unset).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
27
 
28
  # ---------------------- Base Classes and Schemas ---------------------------
 
class ResearchInput(BaseModel):
    """Fields shared by every research tool input schema."""

    # Required: session-state key under which the DataFrame is stored.
    data_key: str = Field(
        ...,
        description="Session state key containing DataFrame",
    )
    # Optional: subset of columns to analyze; defaults to None.
    columns: Optional[List[str]] = Field(
        default=None,
        description="List of columns to analyze",
    )
33
 
class TemporalAnalysisInput(ResearchInput):
    """Input schema for temporal analysis; adds the time and value columns."""

    # Required: column holding the timestamps.
    time_col: str = Field(
        ...,
        description="Name of timestamp column",
    )
    # Required: column holding the values to analyze.
    value_col: str = Field(
        ...,
        description="Name of value column to analyze",
    )
38
 
class HypothesisInput(ResearchInput):
    """Input schema for hypothesis testing; adds the grouping and value columns."""

    # Required: categorical column whose values define the groups.
    group_col: str = Field(
        ...,
        description="Categorical column defining groups",
    )
    # Required: numerical column compared across the groups.
    value_col: str = Field(
        ...,
        description="Numerical column to compare",
    )
43
 
class ModelTrainingInput(ResearchInput):
    """Input schema for model training; adds the target column."""

    # Required: column the model is trained to predict.
    target_col: str = Field(
        ...,
        description="Name of target column",
    )
47
 
class DataAnalyzer(ABC):
    """Interface implemented by every data analysis module."""

    @abstractmethod
    def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
        """Run the analysis on ``data``; concrete subclasses supply the logic."""
53
 
54
  # ---------------------- Concrete Analyzer Implementations ---------------------------
 
55
  class AdvancedEDA(DataAnalyzer):
56
- """Comprehensive Exploratory Data Analysis"""
57
  def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
58
  try:
59
  analysis = {
@@ -84,44 +92,44 @@ class AdvancedEDA(DataAnalyzer):
84
  return {"error": f"EDA Failed: {str(e)}"}
85
 
class DistributionVisualizer(DataAnalyzer):
    """Distribution visualizations."""

    def invoke(self, data: pd.DataFrame, columns: List[str], **kwargs) -> str:
        """Draw one density histogram (with KDE overlay) per column, side by side.

        Args:
            data: Frame holding the columns to plot.
            columns: Names of the columns to plot, one subplot each.

        Returns:
            The PNG image encoded as a base64 string, or a string starting
            with ``"Visualization Error: "`` when plotting fails.
        """
        if not columns:
            # A zero-column subplot grid is invalid; report it explicitly.
            return "Visualization Error: no columns were selected"

        fig = None
        try:
            # squeeze=False keeps ``axes`` two-dimensional even for one column.
            fig, axes = plt.subplots(1, len(columns), figsize=(12, 6), squeeze=False)
            for ax, col in zip(axes[0], columns):
                sns.histplot(data[col], kde=True, stat="density", ax=ax)
                ax.set_title(f'Distribution of {col}', fontsize=10)
                ax.tick_params(axis="both", labelsize=8)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
            return base64.b64encode(buf.getvalue()).decode()
        except Exception as e:
            return f"Visualization Error: {str(e)}"
        finally:
            # Close the figure on every path so a failed plot does not leak it.
            if fig is not None:
                plt.close(fig)
105
 
106
  class TemporalAnalyzer(DataAnalyzer):
107
- """Time series analysis"""
108
  def invoke(self, data: pd.DataFrame, time_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
109
  try:
110
  ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
111
  decomposition = seasonal_decompose(ts_data, period=365)
112
-
113
  plt.figure(figsize=(12, 8))
114
  decomposition.plot()
115
  plt.tight_layout()
116
-
117
  buf = io.BytesIO()
118
  plt.savefig(buf, format='png')
119
  plt.close()
120
  plot_data = base64.b64encode(buf.getvalue()).decode()
121
-
122
  return {
123
  "trend_statistics": {
124
- "stationarity": adfuller(ts_data)[1],
125
  "seasonality_strength": max(decomposition.seasonal)
126
  },
127
  "visualization": plot_data
@@ -130,628 +138,866 @@ class TemporalAnalyzer(DataAnalyzer):
130
  return {"error": f"Temporal Analysis Failed: {str(e)}"}
131
 
class HypothesisTester(DataAnalyzer):
    """Statistical hypothesis testing.

    Compares a numeric column across the groups of a grouping column:
    an independent t-test for exactly two groups, one-way ANOVA for more.
    """

    def invoke(self, data: pd.DataFrame, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
        """Run the group comparison.

        Args:
            data: Frame holding both columns.
            group_col: Column whose distinct values define the groups.
            value_col: Numeric column compared across the groups.

        Returns:
            A dict with ``test_type``, ``test_statistic``, ``p_value``,
            ``effect_size`` (``cohens_d`` is None unless there are exactly
            two groups with a positive pooled standard deviation) and
            ``interpretation``; or ``{"error": ...}`` on failure.
        """
        try:
            # Rows with a missing group label or value cannot take part in the
            # test: a NaN label never equals itself and a NaN value makes the
            # statistic NaN.
            usable = data[data[group_col].notna() & data[value_col].notna()]
            groups = usable[group_col].unique()

            if len(groups) < 2:
                return {"error": "Insufficient groups for comparison"}

            group_data = [usable[usable[group_col] == g][value_col] for g in groups]
            if len(groups) == 2:
                stat, p = ttest_ind(*group_data)
                test_type = "Independent t-test"
            else:
                stat, p = f_oneway(*group_data)
                test_type = "ANOVA"

            cohens_d = None
            if len(groups) == 2:
                first, second = group_data
                pooled_sd = np.sqrt((first.var() + second.var()) / 2)
                # A zero or NaN pooled SD (constant or single-row groups)
                # leaves the effect size undefined rather than inf/NaN.
                if pooled_sd > 0:
                    cohens_d = float(abs(first.mean() - second.mean()) / pooled_sd)

            return {
                "test_type": test_type,
                # Plain floats keep the result JSON-friendly.
                "test_statistic": float(stat),
                "p_value": float(p),
                "effect_size": {"cohens_d": cohens_d},
                "interpretation": self.interpret_p_value(p)
            }
        except Exception as e:
            return {"error": f"Hypothesis Testing Failed: {str(e)}"}

    def interpret_p_value(self, p: float) -> str:
        """Translate a p-value into a verbal strength-of-evidence statement."""
        if p != p:
            # NaN compares False against every threshold and would otherwise
            # be reported as "no significant evidence".
            return "Test result undefined (p-value is NaN)"
        levels = (
            (0.001, "Very strong evidence against H0"),
            (0.01, "Strong evidence against H0"),
            (0.05, "Evidence against H0"),
            (0.1, "Weak evidence against H0"),
        )
        for threshold, verdict in levels:
            if p < threshold:
                return verdict
        return "No significant evidence against H0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
class LogisticRegressionTrainer(DataAnalyzer):
    """Logistic Regression Model Trainer."""

    def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
        """Fit a logistic regression on an 80/20 split and report test accuracy.

        Args:
            data: Frame holding the feature and target columns.
            target_col: Column to predict.
            columns: Feature columns.

        Returns:
            A dict with ``model_type``, ``accuracy`` and ``model_params``,
            or ``{"error": ...}`` when training fails.
        """
        try:
            if target_col in columns:
                # Using the target as a feature leaks the label into the model.
                return {"error": "Logistic Regression Model Error: target column cannot also be a feature"}

            # The estimator rejects NaN, so drop incomplete rows up front
            # instead of aborting the whole fit on one missing cell.
            complete = data[list(columns) + [target_col]].dropna()
            X = complete[list(columns)]
            y = complete[target_col]

            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
            model = LogisticRegression(max_iter=1000)
            model.fit(X_train, y_train)
            accuracy = accuracy_score(y_test, model.predict(X_test))
            return {
                "model_type": "Logistic Regression",
                "accuracy": float(accuracy),
                "model_params": model.get_params()
            }
        except Exception as e:
            return {"error": f"Logistic Regression Model Error: {str(e)}"}
 
189
  # ---------------------- Business Logic Layer ---------------------------
190
 
class ClinicalRule(BaseModel):
    """A single clinical rule: a condition string plus the action to report."""

    # Identifier of the rule.
    name: str
    # Condition string for the rule.
    condition: str
    # Action reported when the rule matches.
    action: str
    # Severity label: low, medium or high (not validated by this model).
    severity: str
197
 
class ClinicalRulesEngine():
    """Executes rules against patient data."""

    def __init__(self):
        # Registered rules keyed by rule name.
        self.rules: Dict[str, "ClinicalRule"] = {}

    def add_rule(self, rule: "ClinicalRule"):
        """Register ``rule`` under ``rule.name``, replacing any rule of that name."""
        self.rules[rule.name] = rule

    def execute_rules(self, data: pd.DataFrame):
        """Evaluate every registered rule's condition against ``data``.

        Each condition is a Python expression that sees the frame as ``df``.
        An element-wise result (Series, DataFrame or ndarray, e.g.
        ``df['blood_pressure'] > 140``) matches when at least one element is
        truthy; any other result is taken at its own truth value.

        Returns:
            A dict keyed by rule name. Each value holds ``rule_matched``,
            ``action`` and ``severity``; a failing rule instead holds
            ``rule_matched`` False, ``error`` and ``severity`` None.
        """
        results = {}
        for rule_name, rule in self.rules.items():
            try:
                # SECURITY NOTE(review): eval() runs arbitrary Python from the
                # rule text. Only accept rules from trusted users, or replace
                # this with a restricted expression evaluator.
                outcome = eval(rule.condition, {}, {"df": data})
                if isinstance(outcome, (pd.Series, pd.DataFrame, np.ndarray)):
                    # bool() on an element-wise result raises "truth value is
                    # ambiguous", so reduce it explicitly.
                    matched = bool(np.asarray(outcome).any())
                else:
                    matched = bool(outcome)
            except Exception as e:
                results[rule_name] = {"rule_matched": False, "error": str(e), "severity": None}
                continue

            if matched:
                results[rule_name] = {
                    "rule_matched": True,
                    "action": rule.action,
                    "severity": rule.severity,
                }
            else:
                results[rule_name] = {"rule_matched": False, "action": None, "severity": None}
        return results
220
 
class ClinicalKPI(BaseModel):
    """A named KPI given as a calculation string, with an optional threshold."""

    # Identifier of the KPI.
    name: str
    # Calculation string for the KPI.
    calculation: str
    # Optional numeric threshold; defaults to None.
    threshold: Optional[float] = Field(default=None)
226
 
class ClinicalKPIMonitoring():
    """Calculates KPIs based on data."""

    def __init__(self):
        # Registered KPIs keyed by KPI name.
        self.kpis: Dict[str, "ClinicalKPI"] = {}

    def add_kpi(self, kpi: "ClinicalKPI"):
        """Register ``kpi`` under ``kpi.name``, replacing any KPI of that name."""
        self.kpis[kpi.name] = kpi

    def calculate_kpis(self, data: pd.DataFrame):
        """Evaluate every registered KPI calculation against ``data``.

        Each calculation is a Python expression that sees the frame as ``df``.

        Returns:
            A dict keyed by KPI name holding the computed value, or
            ``{"error": message}`` for a calculation that raised.
        """
        results = {}
        for kpi_name, kpi in self.kpis.items():
            try:
                # SECURITY NOTE(review): eval() runs arbitrary Python from the
                # KPI text. Only accept definitions from trusted users, or
                # replace this with a restricted expression evaluator.
                value = eval(kpi.calculation, {}, {"df": data})
                if isinstance(value, np.generic):
                    # numpy scalars are not JSON-serializable; unwrap them to
                    # the matching built-in Python number.
                    value = value.item()
                results[kpi_name] = value
            except Exception as e:
                results[kpi_name] = {"error": str(e)}
        return results
243
 
 
 
 
 
 
 
 
 
 
class DiagnosisSupport(ABC):
    """Interface for clinical diagnosis implementations."""

    @abstractmethod
    def diagnose(
        self,
        data: pd.DataFrame,
        target_col: str,
        columns: List[str],
        diagnosis_key: str = "diagnosis",
        **kwargs
    ) -> pd.DataFrame:
        """Produce a diagnosis frame for ``data``; subclasses supply the logic."""
 
 
 
 
 
 
 
249
 
class SimpleDiagnosis(DiagnosisSupport):
    """Example diagnosis backed by the logistic regression trainer."""

    def __init__(self):
        # Trainer that is invoked on every diagnose() call.
        self.model: LogisticRegressionTrainer = LogisticRegressionTrainer()

    def diagnose(
        self,
        data: pd.DataFrame,
        target_col: str,
        columns: List[str],
        diagnosis_key: str = "diagnosis",
        **kwargs
    ) -> pd.DataFrame:
        """Train the model and wrap its outcome in a one-row DataFrame.

        The ``diagnosis_key`` column carries the accuracy on success, the raw
        trainer result when it reports no accuracy, or the exception text.
        """
        try:
            outcome = self.model.invoke(data, target_col=target_col, columns=columns)
            if "accuracy" not in outcome:
                return pd.DataFrame({diagnosis_key: [f"Diagnosis failed: {outcome}"]})
            summary = {
                diagnosis_key: [f"Accuracy {outcome['accuracy']}"],
                "model": outcome["model_type"],
            }
            return pd.DataFrame(summary)
        except Exception as e:
            return pd.DataFrame({diagnosis_key: [f"Error during diagnosis {e}"]})
266
-
 
267
 
class TreatmentRecommendation(ABC):
    """Interface for treatment recommendation implementations."""

    @abstractmethod
    def recommend(
        self,
        data: pd.DataFrame,
        condition_col: str,
        treatment_col: str,
        recommendation_key: str = "recommendation",
        **kwargs
    ) -> pd.DataFrame:
        """Produce a recommendation frame for ``data``; subclasses supply the logic."""
 
 
 
 
 
 
 
273
 
class BasicTreatmentRecommendation(TreatmentRecommendation):
    """A placeholder class for basic treatment recommendations."""

    def recommend(
        self,
        data: pd.DataFrame,
        condition_col: str,
        treatment_col: str,
        recommendation_key: str = "recommendation",
        risk_level: str = "High",
        **kwargs
    ) -> pd.DataFrame:
        """List the treatments recorded for rows at the given risk level.

        Args:
            data: Frame holding the condition and treatment columns.
            condition_col: Column compared against ``risk_level``.
            treatment_col: Column whose values are listed for matching rows.
            recommendation_key: Name of the single output column.
            risk_level: Condition value to select; defaults to ``"High"``,
                the value that was previously hard-coded.

        Returns:
            A one-row DataFrame whose ``recommendation_key`` column holds the
            recommendation or an explanatory message.
        """
        if condition_col not in data.columns or treatment_col not in data.columns:
            return pd.DataFrame({recommendation_key: ["Condition or Treatment columns not found!"]})

        treatment = data[data[condition_col] == risk_level][treatment_col].to_list()
        if not treatment:
            return pd.DataFrame({recommendation_key: ["No treatment recommendation found!"]})
        return pd.DataFrame(
            {recommendation_key: [f"Treatment recommended for {risk_level} risk patients: {treatment}"]}
        )
284
-
285
-
class MedicalKnowledgeBase(ABC):
    """Abstract class for Medical Knowledge.

    Inherits from ``ABC`` so that ``@abstractmethod`` is actually enforced;
    on a plain class the decorator has no effect and the base could be
    instantiated.
    """

    @abstractmethod
    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Answer ``query`` as text; subclasses supply the lookup.

        Args:
            query: The question to look up.
            pub_email: Contact email for a PubMed lookup, if the
                implementation performs one.
        """
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
    """Simple Medical Knowledge Class with TF-IDF and PubMed."""

    def __init__(self):
        # Small built-in knowledge base: topic -> advice text.
        self.knowledge_base = {
            "diabetes": "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring.",
            "heart disease": "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking.",
            "fever": "For a fever, you can consider over-the-counter medications like acetaminophen or ibuprofen. Rest and hydration are also important.",
            "headache": "For a headache, try rest, hydration, and over-the-counter pain relievers. Consult a doctor if it is severe or persistent.",
            "cold": "For a cold, get rest, drink plenty of fluids, and use over-the-counter remedies like decongestants."
        }
        # TF-IDF index over the advice texts, in dict order.
        self.vectorizer = TfidfVectorizer()
        self.tfidf_matrix = self.vectorizer.fit_transform(self.knowledge_base.values())

    def search_pubmed(self, query: str, email: str) -> str:
        """Fetch the abstract text of the first PubMed hit for ``query``.

        Returns:
            The abstract text, ``"No abstracts found for this query on PubMed"``
            when there is no hit, or ``"Error searching pubmed ..."`` when
            the lookup raises.
        """
        try:
            Entrez.email = email
            handle = Entrez.esearch(db="pubmed", term=query, retmax=1)
            try:
                record = Entrez.read(handle)
            finally:
                # Close the handle even when parsing fails.
                handle.close()

            if not record["IdList"]:
                return "No abstracts found for this query on PubMed"

            handle = Entrez.efetch(db="pubmed", id=record["IdList"][0], rettype="abstract", retmode="text")
            try:
                return handle.read()
            finally:
                handle.close()
        except Exception as e:
            return f"Error searching pubmed {e}"

    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Answer ``query`` from the local knowledge base plus PubMed.

        The local entry with the highest TF-IDF cosine similarity is used,
        but only when that similarity is positive. With no shared vocabulary
        every score is zero and ``argmax`` would otherwise return the first
        entry regardless of the question.

        Returns:
            The combined answer text, or ``"Medical Knowledge Search Failed ..."``
            when the search raises.
        """
        try:
            query_vector = self.vectorizer.transform([query])
            similarities = cosine_similarity(query_vector, self.tfidf_matrix)[0]
            best_match_index = int(np.argmax(similarities))
            if similarities[best_match_index] > 0:
                best_match_info = list(self.knowledge_base.values())[best_match_index]
                local_answer = f"Based on the query provided, I found this: {best_match_info}"
            else:
                local_answer = "No matching entry was found in the local knowledge base."

            pubmed_result = self.search_pubmed(query, pub_email)
            no_abstract = "No abstracts found for this query on PubMed" in pubmed_result
            lookup_failed = pubmed_result.startswith("Error searching pubmed")
            if no_abstract or lookup_failed:
                # Do not present a status or error message as an abstract.
                return f"{local_answer} \n\n{pubmed_result}"
            return f"{local_answer} \n\nFrom Pubmed I also found the following abstract: \n {pubmed_result}"
        except Exception as e:
            return f"Medical Knowledge Search Failed {e}"
337
-
338
 
class ForecastingEngine(ABC):
    """Interface for forecasting implementations."""

    @abstractmethod
    def predict(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Produce a forecast frame from ``data``; subclasses supply the logic."""
343
 
class SimpleForecasting(ForecastingEngine):
    """Placeholder forecaster that only echoes the requested horizon."""

    def predict(self, data: pd.DataFrame, period: int = 7, **kwargs) -> pd.DataFrame:
        """Return a one-row frame naming the horizon; ``data`` is not used yet."""
        # Placeholder for actual forecasting
        message = f"Forecast for the next {period} days"
        return pd.DataFrame({"forecast": [message]})
 
348
 
349
  # ---------------------- Insights and Reporting Layer ---------------------------
class AutomatedInsights():
    """Runs a selection of registered analyzers over one DataFrame."""

    def __init__(self):
        # Analyzer registry keyed by the short name used in the UI.
        registry: Dict[str, DataAnalyzer] = {}
        registry["EDA"] = AdvancedEDA()
        registry["temporal"] = TemporalAnalyzer()
        registry["distribution"] = DistributionVisualizer()
        registry["hypothesis"] = HypothesisTester()
        registry["model"] = LogisticRegressionTrainer()
        self.analyses = registry

    def generate_insights(self, data: pd.DataFrame, analysis_names: List[str], **kwargs):
        """Invoke each named analyzer with ``data`` and the same ``kwargs``.

        Returns:
            A dict keyed by analysis name holding the analyzer's result,
            ``{"error": "Analysis not found"}`` for an unknown name, or
            ``{"error": message}`` when the analyzer raised.
        """
        outcome = {}
        for label in analysis_names:
            if label not in self.analyses:
                outcome[label] = {"error": "Analysis not found"}
                continue
            runner = self.analyses[label]
            try:
                outcome[label] = runner.invoke(data=data, **kwargs)
            except Exception as exc:
                outcome[label] = {"error": str(exc)}
        return outcome
372
-
class Dashboard():
    """Named visualisations rendered in registration order."""

    def __init__(self):
        # Visualisation name -> type ("table" or "plot").
        self.layout: Dict[str, str] = {}

    def add_visualisation(self, vis_name: str, vis_type: str):
        """Register (or replace) the visualisation ``vis_name`` with type ``vis_type``."""
        self.layout[vis_name] = vis_type

    def display_dashboard(self, data_dict: Dict[str, pd.DataFrame]):
        """Render every registered visualisation with Streamlit.

        Args:
            data_dict: Frames keyed by visualisation name. A "table" entry is
                shown with ``st.table``; a "plot" entry is drawn as a line
                plot and needs more than one column.
        """
        st.header("Dashboard")
        for vis_name, vis_type in self.layout.items():
            st.subheader(vis_name)
            if vis_type == "table":
                if vis_name in data_dict:
                    st.table(data_dict[vis_name])
                else:
                    st.write("Data Not Found")
            elif vis_type == "plot":
                if vis_name not in data_dict:
                    st.write("Data not found")
                    continue
                df = data_dict[vis_name]
                if len(df.columns) > 1:
                    # Draw on an explicit axes and close the figure afterwards;
                    # otherwise every rerun leaves one more open figure behind.
                    fig, ax = plt.subplots()
                    try:
                        sns.lineplot(data=df, ax=ax)
                        st.pyplot(fig)
                    finally:
                        plt.close(fig)
                else:
                    st.write("Please have more than 1 column")
class AutomatedReports():
    """Stores report definitions and renders reports with Streamlit."""

    def __init__(self):
        # Report name -> free-text definition.
        self.report_definition: Dict[str, str] = {}

    def create_report_definition(self, report_name: str, definition: str):
        """Store (or replace) the definition text for ``report_name``."""
        self.report_definition[report_name] = definition

    def generate_report(self, report_name: str, data: Dict[str, pd.DataFrame]):
        """Render the named report: its definition, then every frame as a table.

        Returns:
            ``{"error": "Report name not found"}`` for an unknown report name;
            otherwise nothing is returned and the output goes to the page.
        """
        definition = self.report_definition.get(report_name)
        if report_name not in self.report_definition:
            return {"error": "Report name not found"}
        st.header(f"Report : {report_name}")
        st.write(f"Report Definition: {definition}")
        for df_name in data:
            st.subheader(f"Data: {df_name}")
            st.table(data[df_name])
 
 
 
 
415
 
416
  # ---------------------- Data Acquisition Layer ---------------------------
 
class DataSource(ABC):
    """Interface shared by all data sources."""

    @abstractmethod
    def connect(self) -> None:
        """Open the connection to the underlying source."""

    @abstractmethod
    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
        """Return the data selected by ``query`` as a DataFrame."""
428
-
429
 
class CSVDataSource(DataSource):
    """Data source for CSV files."""

    def __init__(self, file_path: str):
        # Path to the CSV, or an already-open file-like object (the Streamlit
        # uploader in this file passes its uploaded file here).
        self.file_path = file_path
        # Populated by connect(); None until then.
        self.data: Optional[pd.DataFrame] = None

    def connect(self):
        """Read the CSV into memory.

        A file-like source is rewound first, so calling ``connect()`` again
        re-reads the whole file instead of starting at end-of-file.
        """
        source = self.file_path
        if hasattr(source, "seek"):
            source.seek(0)
        self.data = pd.read_csv(source)

    def fetch_data(self, query: str = None, **kwargs) -> pd.DataFrame:
        """Return the loaded frame; ``query`` is accepted but not used.

        Raises:
            Exception: If ``connect()`` has not been called yet.
        """
        if self.data is None:
            raise Exception("No connection is made, call connect()")
        return self.data
443
-
class DatabaseSource(DataSource):
    """Placeholder database source; no real connection is opened."""

    def __init__(self, connection_string: str, database_type: str):
        self.connection_string = connection_string
        self.database_type = database_type
        # Set by connect(); None means "not connected".
        self.connection = None

    def connect(self):
        """Mark the source as connected; only the "sql" type is accepted.

        Raises:
            Exception: For any other database type.
        """
        if self.database_type.lower() != "sql":
            raise Exception(f"Database type '{self.database_type}' is not supported")
        # Placeholder for the actual database connection
        self.connection = "Connected to SQL Database"

    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
        """Return a one-row placeholder frame that echoes ``query``.

        Raises:
            Exception: If ``connect()`` has not been called yet.
        """
        if self.connection is None:
            raise Exception("No connection is made, call connect()")
        # Placeholder for the data fetching
        placeholder = {"result": [f"Fetched data based on query: {query}"]}
        return pd.DataFrame(placeholder)
462
-
463
 
class DataIngestion:
    """Registry of named data sources with a connect-and-fetch entry point."""

    def __init__(self):
        # Source name -> data source.
        self.sources: Dict[str, DataSource] = {}

    def add_source(self, source_name: str, source: DataSource):
        """Register (or replace) ``source`` under ``source_name``."""
        self.sources[source_name] = source

    def ingest_data(self, source_name: str, query: str = None, **kwargs) -> pd.DataFrame:
        """Connect to the named source and return what it fetches for ``query``.

        Raises:
            Exception: If ``source_name`` has not been registered.
        """
        try:
            source = self.sources[source_name]
        except KeyError:
            raise Exception(f"Source '{source_name}' not found") from None
        source.connect()
        return source.fetch_data(query, **kwargs)
477
-
class DataModel(BaseModel):
    """Declarative description of a data model, serializable to JSON."""

    # Identifier of the model.
    name: str
    # Names of KPIs and dimensions; each defaults to a fresh empty list.
    kpis: List[str] = Field(default_factory=list)
    dimensions: List[str] = Field(default_factory=list)
    # Optional string-to-string mapping of custom calculations.
    custom_calculations: Optional[Dict[str, str]] = None
    # Optional table relations. Example {table1: table2}
    relations: Optional[Dict[str, str]] = None

    def to_json(self):
        """Serialize the model's fields to a JSON string."""
        return json.dumps(self.dict())

    @classmethod
    def from_json(cls, json_str):
        """Build a model from a JSON string produced by ``to_json``.

        A classmethod, so a subclass gets an instance of itself back;
        existing ``DataModel.from_json(...)`` calls behave as before.
        """
        return cls(**json.loads(json_str))
491
-
class DataModelling():
    """Registry of data models keyed by model name."""

    def __init__(self):
        # Model name -> model.
        self.models: Dict[str, DataModel] = {}

    def add_model(self, model: DataModel):
        """Register (or replace) ``model`` under ``model.name``."""
        self.models[model.name] = model

    def get_model(self, model_name: str) -> DataModel:
        """Return the model registered under ``model_name``.

        Raises:
            Exception: If no model of that name has been registered.
        """
        try:
            return self.models[model_name]
        except KeyError:
            raise Exception(f"Model '{model_name}' not found") from None
 
503
  # ---------------------- Main Streamlit Application ---------------------------
 
def _init_session_state():
    """Create each session-scoped object once, on the first run of a session."""
    factories = {
        'data': dict,  # ingested DataFrames keyed by data source name
        'data_ingestion': DataIngestion,
        'data_modelling': DataModelling,
        'clinical_rules': ClinicalRulesEngine,
        'kpi_monitoring': ClinicalKPIMonitoring,
        'forecasting_engine': SimpleForecasting,
        'automated_insights': AutomatedInsights,
        'dashboard': Dashboard,
        'automated_reports': AutomatedReports,
        'diagnosis_support': SimpleDiagnosis,
        'treatment_recommendation': BasicTreatmentRecommendation,
        'knowledge_base': SimpleMedicalKnowledge,
        # Load PUB_EMAIL from secrets
        'pub_email': lambda: st.secrets.get("PUB_EMAIL", ""),
    }
    for key, factory in factories.items():
        if key not in st.session_state:
            st.session_state[key] = factory()


def _render_sidebar():
    """Sidebar: register CSV/SQL sources and ingest one into the session data."""
    ingestion = st.session_state.data_ingestion
    with st.sidebar:
        st.header("⚙️ Data Management")
        data_source_selection = st.selectbox("Select Data Source Type", ["CSV", "SQL Database"])
        if data_source_selection == "CSV":
            uploaded_file = st.file_uploader("Upload research dataset (CSV)", type=["csv"])
            if uploaded_file:
                source_name = st.text_input("Data Source Name")
                if source_name:
                    try:
                        csv_source = CSVDataSource(file_path=uploaded_file)
                        ingestion.add_source(source_name, csv_source)
                        st.success(f"Uploaded {uploaded_file.name}")
                    except Exception as e:
                        st.error(f"Error loading dataset: {e}")
        elif data_source_selection == "SQL Database":
            conn_str = st.text_input("Enter connection string for SQL DB")
            if conn_str:
                source_name = st.text_input("Data Source Name")
                if source_name:
                    try:
                        sql_source = DatabaseSource(connection_string=conn_str, database_type="sql")
                        ingestion.add_source(source_name, sql_source)
                        st.success(f"Added SQL DB Source {source_name}")
                    except Exception as e:
                        st.error(f"Error loading database source {e}")

        # The source picker and query box are rendered before the button.
        # Widgets created inside an `if st.button(...)` block exist only for
        # the single rerun triggered by the click, so they could never hold a
        # user-chosen value.
        source_name_to_fetch = None
        query = None
        if ingestion.sources:
            source_name_to_fetch = st.selectbox(
                "Select Data Source to Ingest", list(ingestion.sources.keys())
            )
            query = st.text_area("Optional Query to Fetch data")

        if st.button("Ingest Data"):
            if not ingestion.sources:
                st.error("No data source added, please add data source")
            elif source_name_to_fetch:
                with st.spinner("Ingesting data..."):
                    try:
                        data = ingestion.ingest_data(source_name_to_fetch, query)
                        st.session_state.data[source_name_to_fetch] = data
                        st.success(f"Ingested data from {source_name_to_fetch}")
                    except Exception as e:
                        st.error(f"Ingestion failed: {e}")


def _render_metadata():
    """Dataset picker plus a metadata summary; returns the selected dataset key."""
    st.subheader("Dataset Metadata")
    data_source_keys = list(st.session_state.data.keys())
    selected_data_key = st.selectbox("Select Dataset", data_source_keys)
    if selected_data_key:
        data = st.session_state.data[selected_data_key]
        st.json({
            "Variables": list(data.columns),
            "Time Range": {
                col: {
                    "min": data[col].min(),
                    "max": data[col].max()
                } for col in data.select_dtypes(include='datetime').columns
            },
            "Size": f"{data.memory_usage().sum() / 1e6:.2f} MB"
        })
    return selected_data_key


def _render_analysis_tab(data):
    """Data Analysis tab: run one analyzer chosen by the user on ``data``."""
    analysis_type = st.selectbox("Select Analysis Mode", [
        "Exploratory Data Analysis",
        "Temporal Pattern Analysis",
        "Comparative Statistics",
        "Distribution Analysis",
        "Train Logistic Regression Model"
    ])
    numeric_cols = data.select_dtypes(include=np.number).columns

    if analysis_type == "Exploratory Data Analysis":
        eda_result = AdvancedEDA().invoke(data=data)
        st.subheader("Data Quality Report")
        st.json(eda_result)

    elif analysis_type == "Temporal Pattern Analysis":
        # CSV-loaded timestamps arrive as object columns, so offer those too.
        # The analyzer parses the column itself and reports failures.
        time_col = st.selectbox("Temporal Variable",
                                data.select_dtypes(include=['datetime', 'object']).columns)
        value_col = st.selectbox("Analysis Variable", numeric_cols)
        if time_col and value_col:
            result = TemporalAnalyzer().invoke(data=data, time_col=time_col, value_col=value_col)
            if "visualization" in result:
                st.image(f"data:image/png;base64,{result['visualization']}")
            st.json(result)

    elif analysis_type == "Comparative Statistics":
        # CSV-loaded categorical data is object dtype, never 'category';
        # with 'category' alone this selector is always empty.
        group_col = st.selectbox("Grouping Variable",
                                 data.select_dtypes(include=['category', 'object']).columns)
        value_col = st.selectbox("Metric Variable", numeric_cols)
        if group_col and value_col:
            result = HypothesisTester().invoke(data=data, group_col=group_col, value_col=value_col)
            st.subheader("Statistical Test Results")
            st.json(result)

    elif analysis_type == "Distribution Analysis":
        selected_cols = st.multiselect("Select Variables", numeric_cols.tolist())
        if selected_cols:
            img_data = DistributionVisualizer().invoke(data=data, columns=selected_cols)
            if img_data.startswith("Visualization Error"):
                # The visualizer reports failure as text, which is not an image.
                st.error(img_data)
            else:
                st.image(f"data:image/png;base64,{img_data}")

    elif analysis_type == "Train Logistic Regression Model":
        target_col = st.selectbox("Select Target Variable", data.columns.tolist())
        selected_cols = st.multiselect("Select Feature Variables", numeric_cols.tolist())
        if selected_cols and target_col:
            result = LogisticRegressionTrainer().invoke(data=data, target_col=target_col, columns=selected_cols)
            st.subheader("Logistic Regression Model Results")
            st.json(result)


def _render_clinical_logic_tab(data):
    """Clinical Logic tab: define rules and KPIs, then run them on ``data`` (may be None)."""
    st.header("Clinical Logic")
    st.subheader("Clinical Rules")
    rule_name = st.text_input("Enter Rule Name")
    condition = st.text_area("Enter Rule Condition (use 'df' for data frame), Example df['blood_pressure'] > 140")
    action = st.text_area("Enter Action to be Taken on Rule Match")
    severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])
    if st.button("Add Clinical Rule"):
        try:
            rule = ClinicalRule(name=rule_name, condition=condition, action=action, severity=severity)
            st.session_state.clinical_rules.add_rule(rule)
            st.success("Added Clinical Rule")
        except Exception as e:
            st.error(f"Error in rule definition: {e}")

    st.subheader("Clinical KPI Definition")
    kpi_name = st.text_input("Enter KPI name")
    kpi_calculation = st.text_area("Enter KPI calculation (use 'df' for data frame), Example df['patient_count'].sum()")
    threshold = st.text_input("Enter Threshold for KPI")
    if st.button("Add Clinical KPI"):
        try:
            threshold_value = float(threshold) if threshold else None
            kpi = ClinicalKPI(name=kpi_name, calculation=kpi_calculation, threshold=threshold_value)
            st.session_state.kpi_monitoring.add_kpi(kpi)
            st.success(f"Added KPI {kpi_name}")
        except Exception as e:
            st.error(f"Error creating KPI: {e}")

    if data is None:
        return
    if st.button("Execute Clinical Rules"):
        with st.spinner("Executing Clinical Rules.."):
            result = st.session_state.clinical_rules.execute_rules(data)
        st.json(result)
    if st.button("Calculate Clinical KPIs"):
        with st.spinner("Calculating Clinical KPIs..."):
            result = st.session_state.kpi_monitoring.calculate_kpis(data)
        st.json(result)


def _render_insights_tab(data):
    """Insights tab: automated insights, diagnosis support and treatment recommendation."""
    available_analysis = ["EDA", "temporal", "distribution", "hypothesis", "model"]
    selected_analysis = st.multiselect("Select Analysis", available_analysis)
    if st.button("Generate Automated Insights"):
        with st.spinner("Generating Insights"):
            results = st.session_state.automated_insights.generate_insights(data, analysis_names=selected_analysis)
        st.json(results)

    st.subheader("Diagnosis Support")
    target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
    num_cols = data.select_dtypes(include=np.number).columns.tolist()
    selected_cols_diagnosis = st.multiselect("Select Feature Variables for Diagnosis", num_cols)
    if st.button("Generate Diagnosis"):
        if target_col and selected_cols_diagnosis:
            with st.spinner("Generating Diagnosis"):
                result = st.session_state.diagnosis_support.diagnose(
                    data, target_col=target_col, columns=selected_cols_diagnosis,
                    diagnosis_key="diagnosis_result")
            # diagnose() returns a DataFrame, which st.json cannot render as JSON.
            st.dataframe(result)

    st.subheader("Treatment Recommendation")
    condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
    treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
    if st.button("Generate Treatment Recommendation"):
        if condition_col and treatment_col:
            with st.spinner("Generating Treatment Recommendation"):
                result = st.session_state.treatment_recommendation.recommend(
                    data, condition_col=condition_col, treatment_col=treatment_col,
                    recommendation_key="treatment_recommendation")
            # recommend() returns a DataFrame as well.
            st.dataframe(result)


def _render_reports_tab(has_dataset):
    """Reports tab: store report definitions and render a report over all ingested data."""
    st.header("Reports")
    report_name = st.text_input("Report Name")
    report_def = st.text_area("Report definition")
    if st.button("Create Report Definition"):
        st.session_state.automated_reports.create_report_definition(report_name, report_def)
        st.success("Report definition created")
    if has_dataset:
        data = st.session_state.data
        if st.button("Generate Report"):
            with st.spinner("Generating Report..."):
                report = st.session_state.automated_reports.generate_report(report_name, data)
            # generate_report signals an unknown report name through its return
            # value; show it rather than dropping it.
            if isinstance(report, dict) and "error" in report:
                st.error(report["error"])


def _render_knowledge_tab():
    """Medical Knowledge tab: free-text question answered by the knowledge base."""
    st.header("Medical Knowledge")
    query = st.text_input("Enter your medical question here:")
    if st.button("Search"):
        with st.spinner("Searching..."):
            result = st.session_state.knowledge_base.search_medical_info(
                query, pub_email=st.session_state.pub_email)
        st.write(result)


def main():
    """Entry point of the Streamlit app: page setup, session state, sidebar and tabs."""
    st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")
    st.title("🏥 AI-Powered Clinical Intelligence Hub")

    # Session State
    _init_session_state()

    # Sidebar for Data Management
    _render_sidebar()

    # The analysis area only appears once at least one dataset is ingested.
    if not st.session_state.data:
        return

    col1, col2 = st.columns([1, 3])
    with col1:
        selected_data_key = _render_metadata()

    data = st.session_state.data[selected_data_key] if selected_data_key else None

    with col2:
        analysis_tab, clinical_logic_tab, insights_tab, reports_tab, knowledge_tab = st.tabs([
            "Data Analysis",
            "Clinical Logic",
            "Insights",
            "Reports",
            "Medical Knowledge"
        ])

        with analysis_tab:
            if data is not None:
                _render_analysis_tab(data)
        with clinical_logic_tab:
            _render_clinical_logic_tab(data)
        with insights_tab:
            if data is not None:
                _render_insights_tab(data)
        with reports_tab:
            _render_reports_tab(bool(selected_data_key))
        with knowledge_tab:
            _render_knowledge_tab()


if __name__ == "__main__":
    main()
 
1
+ import os
2
+ import json
 
 
3
  import base64
4
  import io
5
+ from abc import ABC, abstractmethod
6
+ from typing import Dict, List, Optional, Any
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
  import matplotlib.pyplot as plt
11
  import seaborn as sns
12
+ import streamlit as st
13
+
14
+ from scipy.stats import ttest_ind, f_oneway
15
  from sklearn.model_selection import train_test_split
16
  from sklearn.linear_model import LogisticRegression
17
  from sklearn.metrics import accuracy_score
18
+ from sklearn.feature_extraction.text import TfidfVectorizer
19
+ from sklearn.metrics.pairwise import cosine_similarity
20
+
21
  from statsmodels.tsa.seasonal import seasonal_decompose
22
  from statsmodels.tsa.stattools import adfuller
23
+
24
+ from pydantic import BaseModel, Field
25
+ from Bio import Entrez
26
+
27
  from langchain.prompts import PromptTemplate
28
  from groq import Groq
 
 
 
 
 
 
 
29
 
30
+ # ---------------------- Initialize External Clients ---------------------------
31
+ # Initialize Groq Client with API Key from environment variables
32
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
33
 
34
  # ---------------------- Base Classes and Schemas ---------------------------
35
+
36
class ResearchInput(BaseModel):
    """Fields common to every research-tool input schema."""

    # Name of the session-state entry that holds the DataFrame to work on.
    data_key: str = Field(
        default=...,
        description="Session state key containing DataFrame",
    )
    # Optional subset of columns; ``None`` when the caller gives no list.
    columns: Optional[List[str]] = Field(
        default=None,
        description="List of columns to analyze",
    )
40
 
41
class TemporalAnalysisInput(ResearchInput):
    """Input schema for a time-series analysis request."""

    # Column holding the timestamps.
    time_col: str = Field(
        default=...,
        description="Name of timestamp column",
    )
    # Column holding the measured values.
    value_col: str = Field(
        default=...,
        description="Name of value column to analyze",
    )
45
 
46
class HypothesisInput(ResearchInput):
    """Input schema for a group-comparison hypothesis test."""

    # Column whose distinct values define the groups being compared.
    group_col: str = Field(
        default=...,
        description="Categorical column defining groups",
    )
    # Column holding the numeric metric compared across groups.
    value_col: str = Field(
        default=...,
        description="Numerical column to compare",
    )
50
 
51
class ModelTrainingInput(ResearchInput):
    """Input schema for a model-training request."""

    # Column the model should learn to predict.
    target_col: str = Field(
        default=...,
        description="Name of target column",
    )
54
 
55
class DataAnalyzer(ABC):
    """Interface that every data analysis module implements."""

    @abstractmethod
    def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
        """Run the analysis on ``data`` and return its results.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """
        return None
60
 
61
  # ---------------------- Concrete Analyzer Implementations ---------------------------
62
+
63
  class AdvancedEDA(DataAnalyzer):
64
+ """Comprehensive Exploratory Data Analysis."""
65
  def invoke(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
66
  try:
67
  analysis = {
 
92
  return {"error": f"EDA Failed: {str(e)}"}
93
 
94
class DistributionVisualizer(DataAnalyzer):
    """Distribution visualizations."""

    def invoke(self, data: pd.DataFrame, columns: List[str], **kwargs) -> str:
        """Plot a density histogram with KDE overlay for each column.

        Args:
            data: Frame containing the columns to plot.
            columns: Column names; one side-by-side subplot is drawn per name.

        Returns:
            The PNG image as a base64-encoded string, or a message starting
            with ``"Visualization Error: "`` if plotting failed.
        """
        fig = None
        try:
            fig = plt.figure(figsize=(12, 6))
            for i, col in enumerate(columns, 1):
                plt.subplot(1, len(columns), i)
                sns.histplot(data[col], kde=True, stat="density")
                plt.title(f'Distribution of {col}', fontsize=10)
                plt.xticks(fontsize=8)
                plt.yticks(fontsize=8)
            plt.tight_layout()

            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=300, bbox_inches='tight')
            return base64.b64encode(buf.getvalue()).decode()
        except Exception as e:
            return f"Visualization Error: {str(e)}"
        finally:
            # Close the figure on the error path too; otherwise every failed
            # call leaves an open figure registered with pyplot.
            if fig is not None:
                plt.close(fig)
113
 
114
  class TemporalAnalyzer(DataAnalyzer):
115
+ """Time series analysis."""
116
  def invoke(self, data: pd.DataFrame, time_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
117
  try:
118
  ts_data = data.set_index(pd.to_datetime(data[time_col]))[value_col]
119
  decomposition = seasonal_decompose(ts_data, period=365)
120
+
121
  plt.figure(figsize=(12, 8))
122
  decomposition.plot()
123
  plt.tight_layout()
124
+
125
  buf = io.BytesIO()
126
  plt.savefig(buf, format='png')
127
  plt.close()
128
  plot_data = base64.b64encode(buf.getvalue()).decode()
129
+
130
  return {
131
  "trend_statistics": {
132
+ "stationarity_p_value": adfuller(ts_data)[1],
133
  "seasonality_strength": max(decomposition.seasonal)
134
  },
135
  "visualization": plot_data
 
138
  return {"error": f"Temporal Analysis Failed: {str(e)}"}
139
 
140
class HypothesisTester(DataAnalyzer):
    """Statistical hypothesis testing."""

    def invoke(self, data: pd.DataFrame, group_col: str, value_col: str, **kwargs) -> Dict[str, Any]:
        """Compare ``value_col`` across the groups defined by ``group_col``.

        Two groups are compared with an independent t-test (Cohen's d is
        reported as effect size); three or more groups with one-way ANOVA
        (no effect size).

        Args:
            data: Frame holding both columns.
            group_col: Column whose distinct values define the groups.
            value_col: Numeric column to compare.

        Returns:
            Dict with ``test_type``, ``test_statistic``, ``p_value``,
            ``effect_size`` and ``interpretation``; or ``{"error": ...}`` when
            there are fewer than two groups or the test raised.
        """
        try:
            # Rows missing the group label or the value cannot take part in
            # the test; left in, they turn the statistic and p-value into NaN.
            clean = data.dropna(subset=[group_col, value_col])
            groups = clean[group_col].unique()

            if len(groups) < 2:
                return {"error": "Insufficient groups for comparison"}

            group_data = [clean.loc[clean[group_col] == g, value_col] for g in groups]

            if len(groups) == 2:
                stat, p = ttest_ind(*group_data)
                test_type = "Independent t-test"
                effect_size = self.calculate_cohens_d(group_data[0], group_data[1])
            else:
                stat, p = f_oneway(*group_data)
                test_type = "ANOVA"
                effect_size = None

            return {
                "test_type": test_type,
                "test_statistic": stat,
                "p_value": p,
                "effect_size": effect_size,
                "interpretation": self.interpret_p_value(p)
            }
        except Exception as e:
            return {"error": f"Hypothesis Testing Failed: {str(e)}"}

    @staticmethod
    def calculate_cohens_d(x: pd.Series, y: pd.Series) -> Optional[float]:
        """Calculate Cohen's d for effect size.

        The pooled standard deviation is the square root of the plain
        (unweighted) average of the two sample variances.

        Returns:
            ``|mean(x) - mean(y)| / pooled_std``, or ``None`` when it cannot
            be computed (zero or undefined pooled deviation, or any error).
        """
        try:
            mean_diff = abs(x.mean() - y.mean())
            pooled_std = np.sqrt((x.var() + y.var()) / 2)
            # A zero or NaN denominator would silently yield inf/NaN.
            if not np.isfinite(pooled_std) or pooled_std == 0:
                return None
            return float(mean_diff / pooled_std)
        except Exception:
            return None

    @staticmethod
    def interpret_p_value(p: float) -> str:
        """Interpret the p-value as a verbal strength of evidence against H0."""
        # NaN compares False against every cut-off, which would otherwise be
        # reported as "no significant evidence".
        if p != p:
            return "Test result undefined (p-value is NaN)"
        if p < 0.001:
            return "Very strong evidence against H0"
        elif p < 0.01:
            return "Strong evidence against H0"
        elif p < 0.05:
            return "Evidence against H0"
        elif p < 0.1:
            return "Weak evidence against H0"
        else:
            return "No significant evidence against H0"
193
 
194
class LogisticRegressionTrainer(DataAnalyzer):
    """Fits a logistic regression classifier and reports hold-out accuracy."""

    def invoke(self, data: pd.DataFrame, target_col: str, columns: List[str], **kwargs) -> Dict[str, Any]:
        """Train on ``columns`` to predict ``target_col``.

        The data is split 80/20 with a fixed seed (42). Any failure is
        returned as ``{"error": ...}`` instead of being raised.
        """
        try:
            features = data[columns]
            labels = data[target_col]
            train_x, test_x, train_y, test_y = train_test_split(
                features, labels, test_size=0.2, random_state=42
            )
            classifier = LogisticRegression(max_iter=1000)
            classifier.fit(train_x, train_y)
            predictions = classifier.predict(test_x)
            report = {
                "model_type": "Logistic Regression",
                "accuracy": accuracy_score(test_y, predictions),
                "model_params": classifier.get_params(),
            }
            return report
        except Exception as e:
            return {"error": f"Logistic Regression Model Error: {str(e)}"}
212
+
213
  # ---------------------- Business Logic Layer ---------------------------
214
 
215
class ClinicalRule(BaseModel):
    """A single named clinical rule."""

    # Identifier of the rule.
    name: str
    # Expression describing when the rule applies.
    condition: str
    # What should be done when the rule applies.
    action: str
    # One of: low, medium, or high.
    severity: str
221
 
222
class ClinicalRulesEngine:
    """Executes rules against patient data."""

    def __init__(self):
        # Registered rules, keyed by rule name.
        self.rules: Dict[str, "ClinicalRule"] = {}

    def add_rule(self, rule: "ClinicalRule"):
        """Register ``rule``; a rule with the same name is replaced."""
        self.rules[rule.name] = rule

    def execute_rules(self, data: pd.DataFrame) -> Dict[str, Any]:
        """Evaluate every registered rule against ``data``.

        Each rule's ``condition`` is evaluated as a Python expression with
        the frame bound to the name ``df``. A row-wise result (for example
        ``df['blood_pressure'] > 140``) counts as matched when at least one
        element is true.

        Returns:
            Dict keyed by rule name. Each entry has ``rule_matched`` (bool)
            and ``severity``, plus ``action`` on success or ``error`` when
            evaluation raised.
        """
        results = {}
        for rule_name, rule in self.rules.items():
            try:
                # SECURITY NOTE(review): eval() executes arbitrary Python, and
                # empty globals do not sandbox it (builtins remain reachable).
                # Only acceptable while conditions come from trusted users.
                outcome = eval(rule.condition, {}, {"df": data})
                rule_matched = self._as_bool(outcome)
                results[rule_name] = {
                    "rule_matched": rule_matched,
                    "action": rule.action if rule_matched else None,
                    "severity": rule.severity if rule_matched else None
                }
            except Exception as e:
                results[rule_name] = {
                    "rule_matched": False,
                    "error": str(e),
                    "severity": None
                }
        return results

    @staticmethod
    def _as_bool(outcome: Any) -> bool:
        """Collapse a rule outcome to one plain ``bool``.

        The truth value of a Series/DataFrame/array is ambiguous and raises,
        so array-like outcomes are reduced with "any element is true".
        """
        if isinstance(outcome, (pd.Series, pd.DataFrame, np.ndarray)):
            return bool(np.asarray(outcome).any())
        return bool(outcome)
248
 
249
class ClinicalKPI(BaseModel):
    """A single named clinical key performance indicator."""

    # Identifier of the KPI.
    name: str
    # Expression that produces the KPI value.
    calculation: str
    # Optional reference value the KPI is compared against.
    threshold: Optional[float] = None
254
 
255
class ClinicalKPIMonitoring:
    """Calculates KPIs based on data."""

    def __init__(self):
        # Registered KPIs, keyed by KPI name.
        self.kpis: Dict[str, "ClinicalKPI"] = {}

    def add_kpi(self, kpi: "ClinicalKPI"):
        """Register ``kpi``; a KPI with the same name is replaced."""
        self.kpis[kpi.name] = kpi

    def calculate_kpis(self, data: pd.DataFrame) -> Dict[str, Any]:
        """Evaluate every registered KPI against ``data``.

        Each KPI's ``calculation`` is evaluated as a Python expression with
        the frame bound to the name ``df``.

        Returns:
            Dict keyed by KPI name with ``value``, ``threshold`` and
            ``status``, or ``{"error": ...}`` when evaluation raised.
        """
        results = {}
        for kpi_name, kpi in self.kpis.items():
            try:
                # SECURITY NOTE(review): eval() executes arbitrary Python, and
                # empty globals do not sandbox it (builtins remain reachable).
                # Only acceptable while calculations come from trusted users.
                kpi_value = eval(kpi.calculation, {}, {"df": data})
                results[kpi_name] = {
                    "value": kpi_value,
                    "threshold": kpi.threshold,
                    "status": self.evaluate_threshold(kpi_value, kpi.threshold)
                }
            except Exception as e:
                results[kpi_name] = {"error": str(e)}
        return results

    @staticmethod
    def evaluate_threshold(value: Any, threshold: Optional[float]) -> Optional[str]:
        """Compare ``value`` with ``threshold``.

        Returns:
            ``None`` when no threshold is set; ``"Above Threshold"`` when
            ``value > threshold``; ``"Below Threshold"`` otherwise (including
            equality); ``"Threshold Evaluation Not Applicable"`` when the two
            cannot be compared as scalars.
        """
        if threshold is None:
            return None
        try:
            return "Above Threshold" if value > threshold else "Below Threshold"
        except (TypeError, ValueError):
            # TypeError: incomparable types. ValueError: the comparison gave
            # a Series/array whose truth value is ambiguous.
            return "Threshold Evaluation Not Applicable"
285
+
286
class DiagnosisSupport(ABC):
    """Interface for components that produce a clinical diagnosis."""

    @abstractmethod
    def diagnose(
        self,
        data: pd.DataFrame,
        target_col: str,
        columns: List[str],
        diagnosis_key: str = "diagnosis",
        **kwargs
    ) -> pd.DataFrame:
        """Produce a diagnosis frame for ``data``.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """
        return None
298
 
299
class SimpleDiagnosis(DiagnosisSupport):
    """Example diagnosis backed by the logistic regression trainer."""

    def __init__(self):
        self.model: LogisticRegressionTrainer = LogisticRegressionTrainer()

    def diagnose(
        self,
        data: pd.DataFrame,
        target_col: str,
        columns: List[str],
        diagnosis_key: str = "diagnosis",
        **kwargs
    ) -> pd.DataFrame:
        """Train the model and report its accuracy as a one-row frame.

        The text is stored under ``diagnosis_key``. On success a ``model``
        column carries the model type; on failure only the message column
        is present.
        """
        try:
            outcome = self.model.invoke(data, target_col=target_col, columns=columns)
            if "accuracy" not in outcome:
                reason = outcome.get('error', 'Unknown error')
                return pd.DataFrame({diagnosis_key: [f"Diagnosis failed: {reason}"]})
            summary = f"Model Accuracy: {outcome['accuracy']:.2%}"
            return pd.DataFrame({
                diagnosis_key: [summary],
                "model": [outcome["model_type"]],
            })
        except Exception as e:
            return pd.DataFrame({diagnosis_key: [f"Error during diagnosis: {e}"]})
327
 
328
class TreatmentRecommendation(ABC):
    """Interface for components that recommend treatments."""

    @abstractmethod
    def recommend(
        self,
        data: pd.DataFrame,
        condition_col: str,
        treatment_col: str,
        recommendation_key: str = "recommendation",
        **kwargs
    ) -> pd.DataFrame:
        """Produce a recommendation frame for ``data``.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """
        return None
340
 
341
class BasicTreatmentRecommendation(TreatmentRecommendation):
    """Placeholder recommender: lists treatments of rows marked "High"."""

    def recommend(
        self,
        data: pd.DataFrame,
        condition_col: str,
        treatment_col: str,
        recommendation_key: str = "recommendation",
        **kwargs
    ) -> pd.DataFrame:
        """Return a one-row frame holding the recommendation text.

        Rows whose ``condition_col`` equals ``"High"`` are selected and their
        ``treatment_col`` values are listed under ``recommendation_key``.
        """
        def as_frame(message: str) -> pd.DataFrame:
            # Every outcome is reported as a single-cell frame.
            return pd.DataFrame({recommendation_key: [message]})

        both_present = condition_col in data.columns and treatment_col in data.columns
        if not both_present:
            return as_frame("Condition or Treatment columns not found!")

        high_risk_rows = data[data[condition_col] == "High"]
        treatment = high_risk_rows[treatment_col].to_list()
        if not treatment:
            return as_frame("No treatment recommendation found!")
        return as_frame(f"Treatment recommended for High risk patients: {treatment}")
365
+
366
class MedicalKnowledgeBase(ABC):
    """Interface for searchable medical knowledge sources."""

    @abstractmethod
    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Return information relevant to ``query``.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """
        return None
371
 
372
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
    """Simple Medical Knowledge Class with TF-IDF and PubMed."""

    def __init__(self):
        # Small built-in knowledge base: topic keyword -> advice text.
        self.knowledge_base = {
            "diabetes": "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring.",
            "heart disease": "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking.",
            "fever": "For a fever, you can consider over-the-counter medications like acetaminophen or ibuprofen. Rest and hydration are also important.",
            "headache": "For a headache, try rest, hydration, and over-the-counter pain relievers. Consult a doctor if it is severe or persistent.",
            "cold": "For a cold, get rest, drink plenty of fluids, and use over-the-counter remedies like decongestants."
        }
        # The TF-IDF model is fitted on the advice texts (not the keywords).
        self.vectorizer = TfidfVectorizer()
        self.tfidf_matrix = self.vectorizer.fit_transform(self.knowledge_base.values())

    def search_pubmed(self, query: str, email: str) -> str:
        """Search PubMed for abstracts related to the query.

        Args:
            query: Search term passed to PubMed.
            email: Contact address set on ``Entrez.email`` when non-empty.

        Returns:
            The abstract text of the first hit, a "No abstracts found"
            message, or a message starting with ``"Error searching PubMed"``.
        """
        try:
            if email:
                Entrez.email = email
            handle = Entrez.esearch(db="pubmed", term=query, retmax=1)
            try:
                record = Entrez.read(handle)
            finally:
                # Release the connection even if parsing fails.
                handle.close()
            if not record["IdList"]:
                return "No abstracts found for this query on PubMed."
            handle = Entrez.efetch(db="pubmed", id=record["IdList"][0], rettype="abstract", retmode="text")
            try:
                return handle.read()
            finally:
                handle.close()
        except Exception as e:
            return f"Error searching PubMed: {e}"

    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Search the medical knowledge base and PubMed for relevant information.

        The local answer is the advice text with the highest TF-IDF cosine
        similarity to ``query``. If nothing shares a term with the query, a
        "no matching entry" notice is returned instead of an arbitrary entry.
        """
        try:
            query_vector = self.vectorizer.transform([query])
            similarities = cosine_similarity(query_vector, self.tfidf_matrix).ravel()
            best_match_index = int(np.argmax(similarities))
            if similarities[best_match_index] > 0:
                best_match_info = list(self.knowledge_base.values())[best_match_index]
            else:
                # argmax of an all-zero row is 0, which would wrongly surface
                # the first entry for any unrelated query.
                best_match_info = "No matching entry found in the local knowledge base."

            pubmed_result = self.search_pubmed(query, pub_email)
            # Only a real abstract gets the "PubMed Abstract" heading; the
            # not-found and error messages are passed through as they are.
            if pubmed_result.startswith(("No abstracts found", "Error searching PubMed")):
                return (
                    f"**Based on your query:** {best_match_info}\n\n"
                    f"{pubmed_result}"
                )
            return (
                f"**Based on your query:** {best_match_info}\n\n"
                f"**PubMed Abstract:**\n{pubmed_result}"
            )
        except Exception as e:
            return f"Medical Knowledge Search Failed: {e}"
 
424
 
425
class ForecastingEngine(ABC):
    """Interface for forecasting components."""

    @abstractmethod
    def predict(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Produce a forecast frame for ``data``.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """
        return None
430
 
431
class SimpleForecasting(ForecastingEngine):
    """Stub forecasting engine that only echoes the requested horizon."""

    def predict(self, data: pd.DataFrame, period: int = 7, **kwargs) -> pd.DataFrame:
        """Return a one-row placeholder frame; ``data`` is not used yet."""
        # No real forecasting logic is implemented here.
        message = f"Forecast for the next {period} days"
        return pd.DataFrame({"forecast": [message]})
436
 
437
  # ---------------------- Insights and Reporting Layer ---------------------------
438
+
439
class AutomatedInsights:
    """Runs a chosen set of analyzers and collects their results."""

    def __init__(self):
        # Registry of available analyzers, keyed by short analysis name.
        self.analyses: Dict[str, DataAnalyzer] = {
            "EDA": AdvancedEDA(),
            "temporal": TemporalAnalyzer(),
            "distribution": DistributionVisualizer(),
            "hypothesis": HypothesisTester(),
            "model": LogisticRegressionTrainer()
        }

    def generate_insights(self, data: pd.DataFrame, analysis_names: List[str], **kwargs) -> Dict[str, Any]:
        """Invoke each named analyzer on ``data``.

        ``kwargs`` are forwarded unchanged to every analyzer. Unknown names
        and analyzers that raise are reported as ``{"error": ...}`` entries.
        """
        collected = {}
        for name in analysis_names:
            analyzer = self.analyses.get(name)
            if not analyzer:
                collected[name] = {"error": "Analysis not found"}
                continue
            try:
                collected[name] = analyzer.invoke(data=data, **kwargs)
            except Exception as e:
                collected[name] = {"error": str(e)}
        return collected
462
+
463
class Dashboard:
    """Handles the creation and display of the dashboard."""

    def __init__(self):
        # Visualisation name -> visualisation type ("table" or "plot").
        self.layout: Dict[str, str] = {}

    def add_visualisation(self, vis_name: str, vis_type: str):
        """Register (or replace) a visualisation in the layout."""
        self.layout[vis_name] = vis_type

    def display_dashboard(self, data_dict: Dict[str, pd.DataFrame]):
        """Render every registered visualisation with Streamlit.

        Args:
            data_dict: Frames keyed by visualisation name. A "table" entry is
                shown with ``st.table``; a "plot" entry is drawn as a line
                plot and needs more than one column. Other types get only
                their sub-header.
        """
        st.header("Dashboard")
        for vis_name, vis_type in self.layout.items():
            st.subheader(vis_name)
            if vis_type == "table":
                df = data_dict.get(vis_name)
                if df is not None:
                    st.table(df)
                else:
                    st.write("Data Not Found")
            elif vis_type == "plot":
                df = data_dict.get(vis_name)
                if df is None:
                    st.write("Data not found")
                elif len(df.columns) > 1:
                    fig = plt.figure()
                    try:
                        sns.lineplot(data=df)
                        st.pyplot(fig)
                    finally:
                        # Without this, one figure stays open per plot per
                        # rerun and pyplot keeps them all in memory.
                        plt.close(fig)
                else:
                    st.write("Please select a DataFrame with more than 1 column for plotting.")
492
+
493
class AutomatedReports:
    """Stores report definitions and assembles reports from them."""

    def __init__(self):
        # Report name -> free-text definition.
        self.report_definitions: Dict[str, str] = {}

    def create_report_definition(self, report_name: str, definition: str):
        """Store ``definition`` under ``report_name``, replacing any old one."""
        self.report_definitions[report_name] = definition

    def generate_report(self, report_name: str, data: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
        """Build the report dict for ``report_name``.

        Returns:
            ``{"error": "Report name not found"}`` for an unknown name;
            otherwise the name, its definition, and every frame in ``data``
            converted with ``DataFrame.to_dict()``.
        """
        definition = self.report_definitions.get(report_name)
        if definition is None and report_name not in self.report_definitions:
            return {"error": "Report name not found"}

        serialised = {}
        for df_name, df in data.items():
            serialised[df_name] = df.to_dict()
        return {
            "Report Name": report_name,
            "Report Definition": definition,
            "Data": serialised,
        }
510
 
511
  # ---------------------- Data Acquisition Layer ---------------------------
512
+
513
class DataSource(ABC):
    """Interface shared by all data sources."""

    @abstractmethod
    def connect(self) -> None:
        """Open the connection to the underlying source."""
        return None

    @abstractmethod
    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
        """Return the data selected by ``query``."""
        return None
 
524
 
525
class CSVDataSource(DataSource):
    """Data source for CSV files."""

    def __init__(self, file_path: io.BytesIO):
        """Remember the CSV buffer; nothing is read until ``connect()``."""
        self.file_path = file_path
        self.data: Optional[pd.DataFrame] = None

    def connect(self):
        """Parse the CSV into ``self.data``.

        Safe to call repeatedly: the buffer is rewound first, because a
        previous parse leaves it at end-of-file and a second
        ``pd.read_csv`` would then fail with an empty-data error.
        """
        if hasattr(self.file_path, "seek"):
            self.file_path.seek(0)
        self.data = pd.read_csv(self.file_path)

    def fetch_data(self, query: Optional[str] = None, **kwargs) -> pd.DataFrame:
        """Return the parsed frame.

        ``query`` is accepted for interface compatibility and ignored.

        Raises:
            Exception: If ``connect()`` has not been called yet.
        """
        if self.data is None:
            raise Exception("No connection is made, call connect()")
        return self.data
538
+
539
class DatabaseSource(DataSource):
    """Stub data source standing in for a SQL database."""

    def __init__(self, connection_string: str, database_type: str):
        self.connection_string = connection_string
        # Stored lower-cased so the type check in connect() ignores case.
        self.database_type = database_type.lower()
        self.connection = None

    def connect(self):
        """Mark the source as connected; only the "sql" type is accepted."""
        if self.database_type != "sql":
            raise Exception(f"Database type '{self.database_type}' is not supported.")
        # No real connection is opened; a marker string stands in for it.
        self.connection = "Connected to SQL Database"

    def fetch_data(self, query: str, **kwargs) -> pd.DataFrame:
        """Return a one-row placeholder frame that echoes ``query``."""
        if self.connection is None:
            raise Exception("No connection is made, call connect()")
        # No query is actually executed.
        placeholder = f"Fetched data based on query: {query}"
        return pd.DataFrame({"result": [placeholder]})
 
558
 
559
class DataIngestion:
    """Registry of named data sources with a single ingestion entry point."""

    def __init__(self):
        # Source name -> data source object.
        self.sources: Dict[str, "DataSource"] = {}

    def add_source(self, source_name: str, source: "DataSource"):
        """Register ``source`` under ``source_name``, replacing any old one."""
        self.sources[source_name] = source

    def ingest_data(self, source_name: str, query: str = None, **kwargs) -> pd.DataFrame:
        """Connect to the named source and fetch its data.

        Raises:
            Exception: If no source is registered under ``source_name``.
        """
        try:
            chosen = self.sources[source_name]
        except KeyError:
            raise Exception(f"Source '{source_name}' not found.") from None
        chosen.connect()
        return chosen.fetch_data(query, **kwargs)
573
+
574
class DataModel(BaseModel):
    """Description of a data model: its KPIs, dimensions and extras."""

    name: str
    kpis: List[str] = Field(default_factory=list)
    dimensions: List[str] = Field(default_factory=list)
    custom_calculations: Optional[Dict[str, str]] = None
    # Example: {"table1": "table2"}
    relations: Optional[Dict[str, str]] = None

    def to_json(self) -> str:
        """Serialise the model's fields to a JSON string."""
        as_dict = self.dict()
        return json.dumps(as_dict)

    @staticmethod
    def from_json(json_str: str) -> 'DataModel':
        """Build a model from a JSON string of its fields."""
        fields = json.loads(json_str)
        return DataModel(**fields)
588
+
589
class DataModelling:
    """Registry of data models, looked up by name."""

    def __init__(self):
        # Model name -> model object.
        self.models: Dict[str, "DataModel"] = {}

    def add_model(self, model: "DataModel"):
        """Register ``model`` under its own name, replacing any old one."""
        self.models[model.name] = model

    def get_model(self, model_name: str) -> "DataModel":
        """Return the model registered under ``model_name``.

        Raises:
            Exception: If no such model is registered.
        """
        try:
            return self.models[model_name]
        except KeyError:
            raise Exception(f"Model '{model_name}' not found.") from None
601
+
602
  # ---------------------- Main Streamlit Application ---------------------------
603
+
604
def main():
    """Entry point: configure the page and lay out the application."""
    st.set_page_config(page_title="AI Clinical Intelligence Hub", layout="wide")
    st.title("🏥 AI-Powered Clinical Intelligence Hub")

    # Session-state objects must exist before any section reads them.
    initialize_session_state()

    # Data management lives in the sidebar.
    with st.sidebar:
        data_management_section()

    # The main area is rendered only once at least one dataset is ingested.
    if not st.session_state.data:
        return

    metadata_column, tabs_column = st.columns([1, 3])
    with metadata_column:
        dataset_metadata_section()
    with tabs_column:
        main_tabs_section()
625
+
626
def initialize_session_state():
    """Create each session-state entry the app relies on, if still missing."""
    # Key -> zero-argument factory, in creation order. Factories are called
    # only for keys that are absent, so existing state is never overwritten.
    factories = {
        'data': dict,  # Store pd.DataFrame under a name
        'data_ingestion': DataIngestion,
        'data_modelling': DataModelling,
        'clinical_rules': ClinicalRulesEngine,
        'kpi_monitoring': ClinicalKPIMonitoring,
        'forecasting_engine': SimpleForecasting,
        'automated_insights': AutomatedInsights,
        'dashboard': Dashboard,
        'automated_reports': AutomatedReports,
        'diagnosis_support': SimpleDiagnosis,
        'treatment_recommendation': BasicTreatmentRecommendation,
        'knowledge_base': SimpleMedicalKnowledge,
        # Load PUB_EMAIL from secrets
        'pub_email': lambda: st.secrets.get("PUB_EMAIL", ""),
    }
    for key, factory in factories.items():
        if key not in st.session_state:
            st.session_state[key] = factory()
 
654
 
655
def data_management_section():
    """Render the sidebar controls for adding and ingesting data sources."""
    st.header("⚙️ Data Management")
    chosen_type = st.selectbox("Select Data Source Type", ["CSV", "SQL Database"])

    # Each source type has its own input form.
    source_forms = {
        "CSV": handle_csv_upload,
        "SQL Database": handle_sql_database,
    }
    render_form = source_forms.get(chosen_type)
    if render_form is not None:
        render_form()

    if st.button("Ingest Data"):
        ingest_data_action()
667
+
668
def handle_csv_upload():
    """Render the CSV upload form and register the file as a data source."""
    uploaded_file = st.file_uploader("Upload research dataset (CSV)", type=["csv"])
    if not uploaded_file:
        return

    source_name = st.text_input("Data Source Name")
    if not source_name:
        return

    try:
        # The file is only registered here; it is parsed at ingestion time.
        st.session_state.data_ingestion.add_source(
            source_name, CSVDataSource(file_path=uploaded_file)
        )
        st.success(f"Uploaded {uploaded_file.name} as '{source_name}'.")
    except Exception as e:
        st.error(f"Error loading dataset: {e}")
680
+
681
def handle_sql_database():
    """Render the SQL connection form and register it as a data source."""
    conn_str = st.text_input("Enter connection string for SQL DB")
    if not conn_str:
        return

    source_name = st.text_input("Data Source Name")
    if not source_name:
        return

    try:
        # The source is only registered here; connecting happens at ingestion.
        st.session_state.data_ingestion.add_source(
            source_name,
            DatabaseSource(connection_string=conn_str, database_type="sql"),
        )
        st.success(f"Added SQL DB Source '{source_name}'.")
    except Exception as e:
        st.error(f"Error loading database source: {e}")
693
+
694
def ingest_data_action():
    """Ingest data from a registered source into ``st.session_state.data``.

    NOTE(review): the only visible caller invokes this inside
    ``if st.button("Ingest Data")``. The widgets created below therefore
    appear to exist only on the rerun triggered by the click, so ingestion
    runs at once with the default selection and an empty query. Confirm
    against Streamlit's button semantics; the widgets probably need to live
    outside the button branch.
    """
    ingestion = st.session_state.data_ingestion
    if not ingestion.sources:
        st.error("No data source added. Please add a data source.")
        return

    source_name_to_fetch = st.selectbox("Select Data Source to Ingest", list(ingestion.sources.keys()))
    query = st.text_area("Optional Query to Fetch data")
    if not source_name_to_fetch:
        return

    with st.spinner("Ingesting data..."):
        try:
            # Store the frame under the source's own name.
            st.session_state.data[source_name_to_fetch] = ingestion.ingest_data(
                source_name_to_fetch, query
            )
            st.success(f"Ingested data from '{source_name_to_fetch}'.")
        except Exception as e:
            st.error(f"Ingestion failed: {e}")
709
+
710
def dataset_metadata_section():
    """Show column names, datetime ranges and memory size of a dataset."""
    st.subheader("Dataset Metadata")
    available_keys = list(st.session_state.data.keys())
    selected_data_key = st.selectbox("Select Dataset", available_keys)
    if not selected_data_key:
        return

    data = st.session_state.data[selected_data_key]

    # Min/max are reported only for columns with a datetime dtype.
    time_range = {}
    for col in data.select_dtypes(include='datetime').columns:
        time_range[col] = {"min": data[col].min(), "max": data[col].max()}

    size_mb = data.memory_usage().sum() / 1e6
    st.json({
        "Variables": list(data.columns),
        "Time Range": time_range,
        "Size": f"{size_mb:.2f} MB",
    })
729
+
730
def main_tabs_section():
    """Create the main tabs and render each tab's section inside it."""
    # (tab label, renderer) pairs in display order.
    sections = [
        ("Data Analysis", data_analysis_section),
        ("Clinical Logic", clinical_logic_section),
        ("Insights", insights_section),
        ("Reports", reports_section),
        ("Medical Knowledge", medical_knowledge_section),
    ]
    tabs = st.tabs([label for label, _ in sections])
    for tab, (_, render) in zip(tabs, sections):
        with tab:
            render()
754
+
755
def data_analysis_section():
    """Render the Data Analysis tab: pick a dataset and an analysis mode."""
    selected_data_key = st.sidebar.selectbox("Select Dataset for Analysis", list(st.session_state.data.keys()))
    if not selected_data_key:
        st.warning("Please select a dataset to perform analysis.")
        return

    data = st.session_state.data[selected_data_key]

    # Mode label -> function that renders that analysis, in display order.
    modes = {
        "Exploratory Data Analysis": perform_eda,
        "Temporal Pattern Analysis": perform_temporal_analysis,
        "Comparative Statistics": perform_comparative_statistics,
        "Distribution Analysis": perform_distribution_analysis,
        "Train Logistic Regression Model": perform_logistic_regression_training,
    }
    analysis_type = st.selectbox("Select Analysis Mode", list(modes))

    run_analysis = modes.get(analysis_type)
    if run_analysis is not None:
        run_analysis(data)
781
+
782
def perform_eda(data: pd.DataFrame):
    """Run the exploratory analysis on ``data`` and show the result as JSON."""
    report = AdvancedEDA().invoke(data=data)
    st.subheader("Data Quality Report")
    st.json(report)
788
+
789
def perform_temporal_analysis(data: pd.DataFrame):
    """Performs Temporal Pattern Analysis.

    Lets the user choose a time column and a numeric value column, runs
    ``TemporalAnalyzer`` on them, and shows the plot and the statistics.
    """
    time_cols = data.select_dtypes(include='datetime').columns.tolist()
    if not time_cols:
        # CSV data is loaded without date parsing, so timestamps arrive as
        # strings. Offer the non-numeric columns instead; the analyzer
        # converts the chosen column with pd.to_datetime itself.
        time_cols = data.select_dtypes(exclude=np.number).columns.tolist()
    num_cols = data.select_dtypes(include=np.number).columns.tolist()

    time_col = st.selectbox("Select Temporal Variable", time_cols)
    value_col = st.selectbox("Select Analysis Variable", num_cols)

    if time_col and value_col:
        analyzer = TemporalAnalyzer()
        result = analyzer.invoke(data=data, time_col=time_col, value_col=value_col)
        if "visualization" in result:
            st.image(f"data:image/png;base64,{result['visualization']}")
        # The base64 image is already shown above; keep it out of the JSON view.
        st.json({key: value for key, value in result.items() if key != "visualization"})
803
+
804
def perform_comparative_statistics(data: pd.DataFrame):
    """Performs Comparative Statistics.

    Lets the user choose a grouping column and a numeric metric column, runs
    ``HypothesisTester`` on them, and shows the test result as JSON.
    """
    # Frames loaded from CSV hold labels as object (string) or bool columns,
    # never as 'category', so those dtypes must be offered too or the
    # grouping selector stays empty.
    categorical_cols = data.select_dtypes(include=['category', 'object', 'bool']).columns.tolist()
    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()

    group_col = st.selectbox("Select Grouping Variable", categorical_cols)
    value_col = st.selectbox("Select Metric Variable", numeric_cols)

    if group_col and value_col:
        analyzer = HypothesisTester()
        result = analyzer.invoke(data=data, group_col=group_col, value_col=value_col)
        st.subheader("Statistical Test Results")
        st.json(result)
817
+
818
def perform_distribution_analysis(data: pd.DataFrame):
    """Plot the distributions of user-selected numeric columns."""
    numeric_options = data.select_dtypes(include=np.number).columns.tolist()
    selected_cols = st.multiselect("Select Variables for Distribution Analysis", numeric_options)
    if not selected_cols:
        return

    img_data = DistributionVisualizer().invoke(data=data, columns=selected_cols)
    # The visualizer returns either base64 image data or an error message.
    if "Visualization Error" in img_data:
        st.error(img_data)
    else:
        st.image(f"data:image/png;base64,{img_data}")
830
+
831
def perform_logistic_regression_training(data: pd.DataFrame):
    """Trains a Logistic Regression model.

    Lets the user choose a target column and numeric feature columns, runs
    ``LogisticRegressionTrainer`` on them, and shows the result as JSON.
    """
    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
    target_col = st.selectbox("Select Target Variable", data.columns.tolist())
    # The target must not be selectable as a feature: training on the label
    # itself leaks the answer and inflates the reported accuracy.
    feature_options = [col for col in numeric_cols if col != target_col]
    selected_cols = st.multiselect("Select Feature Variables", feature_options)

    if selected_cols and target_col:
        analyzer = LogisticRegressionTrainer()
        result = analyzer.invoke(data=data, target_col=target_col, columns=selected_cols)
        st.subheader("Logistic Regression Model Results")
        st.json(result)
842
+
843
def clinical_logic_section():
    """Render the Clinical Logic tab: define rules and KPIs, then run them."""
    st.header("Clinical Logic")

    # --- Rule definition form ---
    st.subheader("Clinical Rules")
    rule_name = st.text_input("Enter Rule Name")
    condition = st.text_area(
        "Enter Rule Condition (use 'df' for DataFrame)",
        help="Example: df['blood_pressure'] > 140",
    )
    action = st.text_area("Enter Action to be Taken on Rule Match")
    severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])

    if st.button("Add Clinical Rule"):
        try:
            new_rule = ClinicalRule(name=rule_name, condition=condition, action=action, severity=severity)
            st.session_state.clinical_rules.add_rule(new_rule)
            st.success("Added Clinical Rule successfully.")
        except Exception as e:
            st.error(f"Error in rule definition: {e}")

    # --- KPI definition form ---
    st.subheader("Clinical KPI Definition")
    kpi_name = st.text_input("Enter KPI Name")
    kpi_calculation = st.text_area(
        "Enter KPI Calculation (use 'df' for DataFrame)",
        help="Example: df['patient_count'].sum()",
    )
    threshold = st.text_input("Enter Threshold for KPI", help="Optional")

    if st.button("Add Clinical KPI"):
        try:
            # An empty threshold field means "no threshold".
            threshold_value = None if not threshold else float(threshold)
            new_kpi = ClinicalKPI(name=kpi_name, calculation=kpi_calculation, threshold=threshold_value)
            st.session_state.kpi_monitoring.add_kpi(new_kpi)
            st.success(f"Added KPI '{kpi_name}' successfully.")
        except ValueError:
            st.error("Threshold must be a numeric value.")
        except Exception as e:
            st.error(f"Error creating KPI: {e}")

    # --- Run rules / KPIs against a chosen dataset ---
    selected_data_key = st.selectbox("Select Dataset for Clinical Logic", list(st.session_state.data.keys()))
    if not selected_data_key:
        return

    data = st.session_state.data[selected_data_key]
    if st.button("Execute Clinical Rules"):
        with st.spinner("Executing Clinical Rules..."):
            rule_results = st.session_state.clinical_rules.execute_rules(data)
        st.json(rule_results)
    if st.button("Calculate Clinical KPIs"):
        with st.spinner("Calculating Clinical KPIs..."):
            kpi_results = st.session_state.kpi_monitoring.calculate_kpis(data)
        st.json(kpi_results)
902
+
903
def insights_section():
    """Render the Insights tab.

    After a dataset is chosen from ``st.session_state.data`` the tab offers:

    * automated insights for a user-selected subset of analyses, produced by
      ``st.session_state.automated_insights.generate_insights``;
    * diagnosis support via ``st.session_state.diagnosis_support.diagnose``
      for a chosen target column and numeric feature columns;
    * treatment recommendation via
      ``st.session_state.treatment_recommendation.recommend`` for a chosen
      condition column and treatment column.

    Each result is displayed as JSON.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.header("Automated Insights")

    selected_data_key = st.selectbox(
        "Select Dataset for Insights", list(st.session_state.data.keys())
    )
    if not selected_data_key:
        st.warning("Please select a dataset to generate insights.")
        return

    data = st.session_state.data[selected_data_key]
    available_analyses = ["EDA", "temporal", "distribution", "hypothesis", "model"]
    selected_analyses = st.multiselect("Select Analyses for Insights", available_analyses)

    if st.button("Generate Automated Insights"):
        with st.spinner("Generating Insights..."):
            results = st.session_state.automated_insights.generate_insights(
                data, analysis_names=selected_analyses
            )
            st.json(results)

    all_columns = data.columns.tolist()

    # Diagnosis Support
    st.subheader("Diagnosis Support")
    target_col = st.selectbox("Select Target Variable for Diagnosis", all_columns)
    numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
    selected_feature_cols = st.multiselect(
        "Select Feature Variables for Diagnosis", numeric_cols
    )

    if st.button("Generate Diagnosis"):
        # Compare against None rather than relying on truthiness: a column
        # label such as the integer 0 is a valid selection but is falsy.
        if target_col is None or not selected_feature_cols:
            st.error("Please select both target and feature variables for diagnosis.")
        elif target_col in selected_feature_cols:
            # Using the target as one of its own predictors would leak the
            # answer into the features.
            st.error("The target variable cannot also be used as a feature variable.")
        else:
            with st.spinner("Generating Diagnosis..."):
                result = st.session_state.diagnosis_support.diagnose(
                    data,
                    target_col=target_col,
                    columns=selected_feature_cols,
                    diagnosis_key="diagnosis_result",
                )
                st.json(result)

    # Treatment Recommendation
    st.subheader("Treatment Recommendation")
    condition_col = st.selectbox(
        "Select Condition Column for Treatment Recommendation", all_columns
    )
    treatment_col = st.selectbox(
        "Select Treatment Column for Treatment Recommendation", all_columns
    )

    if st.button("Generate Treatment Recommendation"):
        if condition_col is None or treatment_col is None:
            st.error("Please select both condition and treatment columns.")
        elif condition_col == treatment_col:
            # Both select boxes start on the first column, so guard against
            # the default state of relating a column to itself.
            st.error("Condition and treatment columns must be different.")
        else:
            with st.spinner("Generating Treatment Recommendation..."):
                result = st.session_state.treatment_recommendation.recommend(
                    data,
                    condition_col=condition_col,
                    treatment_col=treatment_col,
                    recommendation_key="treatment_recommendation",
                )
                st.json(result)
953
+
954
def reports_section():
    """Render the Reports tab.

    The upper half stores a named report definition through
    ``st.session_state.automated_reports.create_report_definition``. The
    lower half lets the user pick one of the stored definitions and renders
    the generated report: its definition text followed by one table per entry
    of the report's ``"Data"`` mapping. A report containing an ``"error"``
    key is shown as an error message instead.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.header("Automated Reports")

    # --- Define a report -------------------------------------------------
    st.subheader("Create Report Definition")
    report_name = st.text_input("Report Name")
    report_def = st.text_area("Report Definition")

    create_clicked = st.button("Create Report Definition")
    if create_clicked and not (report_name and report_def):
        st.error("Please provide both report name and definition.")
    elif create_clicked:
        st.session_state.automated_reports.create_report_definition(report_name, report_def)
        st.success("Report definition created successfully.")

    # --- Generate a report -----------------------------------------------
    st.subheader("Generate Report")
    defined_reports = list(st.session_state.automated_reports.report_definitions.keys())
    report_name_to_generate = st.selectbox("Select Report to Generate", defined_reports)

    if not st.button("Generate Report"):
        return
    if not report_name_to_generate:
        # Nothing to pick until at least one definition has been created.
        st.error("Please select a report to generate.")
        return

    with st.spinner("Generating Report..."):
        report = st.session_state.automated_reports.generate_report(
            report_name_to_generate, st.session_state.data
        )
        if "error" in report:
            st.error(report["error"])
            return

        st.header(f"Report: {report_name_to_generate}")
        st.write(f"**Definition:** {report['Report Definition']}")
        for df_name, df_content in report["Data"].items():
            st.subheader(f"Data: {df_name}")
            table = pd.DataFrame(df_content)
            st.write(table)
988
+
989
def medical_knowledge_section():
    """Render the Medical Knowledge tab.

    Reads a free-text question and, when the Search button is pressed, passes
    it together with ``st.session_state.pub_email`` to
    ``st.session_state.knowledge_base.search_medical_info``. The returned
    value is rendered as Markdown. Pressing Search with an empty question
    shows an error instead.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.header("Medical Knowledge")
    query = st.text_input("Enter your medical question here:")

    if not st.button("Search"):
        return
    if not query:
        st.error("Please enter a medical question to search.")
        return

    with st.spinner("Searching..."):
        answer = st.session_state.knowledge_base.search_medical_info(
            query, pub_email=st.session_state.pub_email
        )
        st.markdown(answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1001
 
1002
if __name__ == "__main__":
    # Start the app only when this file is executed as a script, not when it
    # is imported as a module.
    main()