Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import pandas as pd
|
|
10 |
import matplotlib.pyplot as plt
|
11 |
import seaborn as sns
|
12 |
import streamlit as st
|
|
|
13 |
|
14 |
from scipy.stats import ttest_ind, f_oneway
|
15 |
from sklearn.model_selection import train_test_split
|
@@ -22,7 +23,7 @@ from statsmodels.tsa.seasonal import seasonal_decompose
|
|
22 |
from statsmodels.tsa.stattools import adfuller
|
23 |
|
24 |
from pydantic import BaseModel, Field
|
25 |
-
from Bio import Entrez
|
26 |
|
27 |
from langchain.prompts import PromptTemplate
|
28 |
from groq import Groq
|
@@ -31,6 +32,14 @@ from groq import Groq
|
|
31 |
# Initialize Groq Client with API Key from environment variables
|
32 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# ---------------------- Base Classes and Schemas ---------------------------
|
35 |
|
36 |
class ResearchInput(BaseModel):
|
@@ -68,7 +77,7 @@ class AdvancedEDA(DataAnalyzer):
|
|
68 |
"dimensionality": {
|
69 |
"rows": len(data),
|
70 |
"columns": list(data.columns),
|
71 |
-
"
|
72 |
},
|
73 |
"statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
|
74 |
"temporal_analysis": {
|
@@ -127,10 +136,12 @@ class TemporalAnalyzer(DataAnalyzer):
|
|
127 |
plt.close()
|
128 |
plot_data = base64.b64encode(buf.getvalue()).decode()
|
129 |
|
|
|
|
|
130 |
return {
|
131 |
"trend_statistics": {
|
132 |
-
"stationarity_p_value":
|
133 |
-
"seasonality_strength": max(decomposition.seasonal)
|
134 |
},
|
135 |
"visualization": plot_data
|
136 |
}
|
@@ -197,7 +208,9 @@ class LogisticRegressionTrainer(DataAnalyzer):
|
|
197 |
try:
|
198 |
X = data[columns]
|
199 |
y = data[target_col]
|
200 |
-
X_train, X_test, y_train, y_test = train_test_split(
|
|
|
|
|
201 |
model = LogisticRegression(max_iter=1000)
|
202 |
model.fit(X_train, y_train)
|
203 |
y_pred = model.predict(X_test)
|
@@ -232,7 +245,8 @@ class ClinicalRulesEngine:
|
|
232 |
for rule_name, rule in self.rules.items():
|
233 |
try:
|
234 |
# Evaluate the condition using the dataframe 'df'
|
235 |
-
|
|
|
236 |
results[rule_name] = {
|
237 |
"rule_matched": rule_matched,
|
238 |
"action": rule.action if rule_matched else None,
|
@@ -264,7 +278,8 @@ class ClinicalKPIMonitoring:
|
|
264 |
results = {}
|
265 |
for kpi_name, kpi in self.kpis.items():
|
266 |
try:
|
267 |
-
|
|
|
268 |
results[kpi_name] = {
|
269 |
"value": kpi_value,
|
270 |
"threshold": kpi.threshold,
|
@@ -363,6 +378,8 @@ class BasicTreatmentRecommendation(TreatmentRecommendation):
|
|
363 |
recommendation_key: ["No treatment recommendation found!"]
|
364 |
})
|
365 |
|
|
|
|
|
366 |
class MedicalKnowledgeBase(ABC):
|
367 |
"""Abstract class for Medical Knowledge."""
|
368 |
@abstractmethod
|
@@ -370,14 +387,15 @@ class MedicalKnowledgeBase(ABC):
|
|
370 |
pass
|
371 |
|
372 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
373 |
-
"""Simple Medical Knowledge Class with TF-IDF and PubMed."""
|
374 |
def __init__(self):
|
375 |
self.knowledge_base = {
|
376 |
"diabetes": "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring.",
|
377 |
"heart disease": "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking.",
|
378 |
"fever": "For a fever, you can consider over-the-counter medications like acetaminophen or ibuprofen. Rest and hydration are also important.",
|
379 |
"headache": "For a headache, try rest, hydration, and over-the-counter pain relievers. Consult a doctor if it is severe or persistent.",
|
380 |
-
"cold": "For a cold, get rest, drink plenty of fluids, and use over-the-counter remedies like decongestants."
|
|
|
381 |
}
|
382 |
self.vectorizer = TfidfVectorizer()
|
383 |
self.tfidf_matrix = self.vectorizer.fit_transform(self.knowledge_base.values())
|
@@ -402,26 +420,61 @@ class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
|
402 |
def search_medical_info(self, query: str, pub_email: str = "") -> str:
|
403 |
"""Search the medical knowledge base and PubMed for relevant information."""
|
404 |
try:
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
|
|
|
|
|
|
417 |
else:
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
except Exception as e:
|
423 |
return f"Medical Knowledge Search Failed: {e}"
|
424 |
|
|
|
|
|
425 |
class ForecastingEngine(ABC):
|
426 |
"""Abstract class for forecasting."""
|
427 |
@abstractmethod
|
@@ -472,23 +525,19 @@ class Dashboard:
|
|
472 |
st.header("Dashboard")
|
473 |
for vis_name, vis_type in self.layout.items():
|
474 |
st.subheader(vis_name)
|
475 |
-
|
476 |
-
|
477 |
-
if
|
478 |
st.table(df)
|
479 |
-
|
480 |
-
st.write("Data Not Found")
|
481 |
-
elif vis_type == "plot":
|
482 |
-
df = data_dict.get(vis_name)
|
483 |
-
if df is not None:
|
484 |
if len(df.columns) > 1:
|
485 |
fig = plt.figure()
|
486 |
sns.lineplot(data=df)
|
487 |
st.pyplot(fig)
|
488 |
else:
|
489 |
st.write("Please select a DataFrame with more than 1 column for plotting.")
|
490 |
-
|
491 |
-
|
492 |
|
493 |
class AutomatedReports:
|
494 |
"""Manages automated report definitions and generation."""
|
@@ -754,9 +803,9 @@ def main_tabs_section():
|
|
754 |
|
755 |
def data_analysis_section():
|
756 |
"""Handles the Data Analysis tab."""
|
757 |
-
selected_data_key = st.
|
758 |
if not selected_data_key:
|
759 |
-
st.warning("Please select a dataset
|
760 |
return
|
761 |
|
762 |
data = st.session_state.data[selected_data_key]
|
@@ -791,21 +840,33 @@ def perform_temporal_analysis(data: pd.DataFrame):
|
|
791 |
time_cols = data.select_dtypes(include='datetime').columns
|
792 |
num_cols = data.select_dtypes(include=np.number).columns
|
793 |
|
|
|
|
|
|
|
|
|
794 |
time_col = st.selectbox("Select Temporal Variable", time_cols)
|
795 |
value_col = st.selectbox("Select Analysis Variable", num_cols)
|
796 |
|
797 |
if time_col and value_col:
|
798 |
analyzer = TemporalAnalyzer()
|
799 |
result = analyzer.invoke(data=data, time_col=time_col, value_col=value_col)
|
800 |
-
if "visualization" in result:
|
801 |
-
st.image(f"data:image/png;base64,{result['visualization']}")
|
802 |
st.json(result)
|
803 |
|
804 |
def perform_comparative_statistics(data: pd.DataFrame):
|
805 |
"""Performs Comparative Statistics."""
|
806 |
-
categorical_cols = data.select_dtypes(include='category').columns
|
807 |
numeric_cols = data.select_dtypes(include=np.number).columns
|
808 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
809 |
group_col = st.selectbox("Select Grouping Variable", categorical_cols)
|
810 |
value_col = st.selectbox("Select Metric Variable", numeric_cols)
|
811 |
|
@@ -823,10 +884,12 @@ def perform_distribution_analysis(data: pd.DataFrame):
|
|
823 |
if selected_cols:
|
824 |
analyzer = DistributionVisualizer()
|
825 |
img_data = analyzer.invoke(data=data, columns=selected_cols)
|
826 |
-
if "Visualization Error"
|
827 |
-
st.image(f"data:image/png;base64,{img_data}")
|
828 |
else:
|
829 |
st.error(img_data)
|
|
|
|
|
830 |
|
831 |
def perform_logistic_regression_training(data: pd.DataFrame):
|
832 |
"""Trains a Logistic Regression model."""
|
@@ -839,6 +902,8 @@ def perform_logistic_regression_training(data: pd.DataFrame):
|
|
839 |
result = analyzer.invoke(data=data, target_col=target_col, columns=selected_cols)
|
840 |
st.subheader("Logistic Regression Model Results")
|
841 |
st.json(result)
|
|
|
|
|
842 |
|
843 |
def clinical_logic_section():
|
844 |
"""Handles the Clinical Logic tab."""
|
@@ -853,39 +918,45 @@ def clinical_logic_section():
|
|
853 |
severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])
|
854 |
|
855 |
if st.button("Add Clinical Rule"):
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
|
865 |
-
|
866 |
-
|
|
|
|
|
|
|
867 |
|
868 |
# Clinical KPI Management
|
869 |
st.subheader("Clinical KPI Definition")
|
870 |
kpi_name = st.text_input("Enter KPI Name")
|
871 |
kpi_calculation = st.text_area("Enter KPI Calculation (use 'df' for DataFrame)",
|
872 |
help="Example: df['patient_count'].sum()")
|
873 |
-
threshold = st.text_input("Enter Threshold for KPI", help="
|
874 |
|
875 |
if st.button("Add Clinical KPI"):
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
|
|
|
|
|
|
889 |
|
890 |
# Execute Clinical Rules and Calculate KPIs
|
891 |
selected_data_key = st.selectbox("Select Dataset for Clinical Logic", list(st.session_state.data.keys()))
|
@@ -899,6 +970,8 @@ def clinical_logic_section():
|
|
899 |
with st.spinner("Calculating Clinical KPIs..."):
|
900 |
result = st.session_state.kpi_monitoring.calculate_kpis(data)
|
901 |
st.json(result)
|
|
|
|
|
902 |
|
903 |
def insights_section():
|
904 |
"""Handles the Insights tab."""
|
@@ -914,11 +987,14 @@ def insights_section():
|
|
914 |
selected_analyses = st.multiselect("Select Analyses for Insights", available_analyses)
|
915 |
|
916 |
if st.button("Generate Automated Insights"):
|
917 |
-
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
|
|
|
|
|
|
922 |
|
923 |
# Diagnosis Support
|
924 |
st.subheader("Diagnosis Support")
|
@@ -958,7 +1034,7 @@ def reports_section():
|
|
958 |
# Create Report Definition
|
959 |
st.subheader("Create Report Definition")
|
960 |
report_name = st.text_input("Report Name")
|
961 |
-
report_def = st.text_area("Report Definition")
|
962 |
|
963 |
if st.button("Create Report Definition"):
|
964 |
if report_name and report_def:
|
@@ -969,22 +1045,22 @@ def reports_section():
|
|
969 |
|
970 |
# Generate Report
|
971 |
st.subheader("Generate Report")
|
972 |
-
|
973 |
-
|
974 |
-
|
975 |
-
if
|
976 |
with st.spinner("Generating Report..."):
|
977 |
report = st.session_state.automated_reports.generate_report(report_name_to_generate, st.session_state.data)
|
978 |
if "error" not in report:
|
979 |
-
st.header(f"Report: {
|
980 |
-
st.
|
981 |
for df_name, df_content in report["Data"].items():
|
982 |
st.subheader(f"Data: {df_name}")
|
983 |
-
st.
|
984 |
else:
|
985 |
st.error(report["error"])
|
986 |
-
|
987 |
-
|
988 |
|
989 |
def medical_knowledge_section():
|
990 |
"""Handles the Medical Knowledge tab."""
|
@@ -992,9 +1068,11 @@ def medical_knowledge_section():
|
|
992 |
query = st.text_input("Enter your medical question here:")
|
993 |
|
994 |
if st.button("Search"):
|
995 |
-
if query:
|
996 |
with st.spinner("Searching..."):
|
997 |
-
result = st.session_state.knowledge_base.search_medical_info(
|
|
|
|
|
998 |
st.markdown(result)
|
999 |
else:
|
1000 |
st.error("Please enter a medical question to search.")
|
|
|
10 |
import matplotlib.pyplot as plt
|
11 |
import seaborn as sns
|
12 |
import streamlit as st
|
13 |
+
import spacy
|
14 |
|
15 |
from scipy.stats import ttest_ind, f_oneway
|
16 |
from sklearn.model_selection import train_test_split
|
|
|
23 |
from statsmodels.tsa.stattools import adfuller
|
24 |
|
25 |
from pydantic import BaseModel, Field
|
26 |
+
from Bio import Entrez # Ensure BioPython is installed
|
27 |
|
28 |
from langchain.prompts import PromptTemplate
|
29 |
from groq import Groq
|
|
|
32 |
# Initialize Groq Client with API Key from environment variables
|
33 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
34 |
|
35 |
+
# Load spaCy model with error handling
|
36 |
+
try:
|
37 |
+
nlp = spacy.load("en_core_web_sm")
|
38 |
+
except OSError:
|
39 |
+
st.write("Downloading en_core_web_sm spaCy model...")
|
40 |
+
spacy.cli.download("en_core_web_sm")
|
41 |
+
nlp = spacy.load("en_core_web_sm")
|
42 |
+
|
43 |
# ---------------------- Base Classes and Schemas ---------------------------
|
44 |
|
45 |
class ResearchInput(BaseModel):
|
|
|
77 |
"dimensionality": {
|
78 |
"rows": len(data),
|
79 |
"columns": list(data.columns),
|
80 |
+
"memory_usage_MB": f"{data.memory_usage().sum() / 1e6:.2f} MB"
|
81 |
},
|
82 |
"statistical_profile": data.describe(percentiles=[.25, .5, .75]).to_dict(),
|
83 |
"temporal_analysis": {
|
|
|
136 |
plt.close()
|
137 |
plot_data = base64.b64encode(buf.getvalue()).decode()
|
138 |
|
139 |
+
stationarity_p_value = adfuller(ts_data)[1]
|
140 |
+
|
141 |
return {
|
142 |
"trend_statistics": {
|
143 |
+
"stationarity_p_value": stationarity_p_value,
|
144 |
+
"seasonality_strength": float(max(decomposition.seasonal))
|
145 |
},
|
146 |
"visualization": plot_data
|
147 |
}
|
|
|
208 |
try:
|
209 |
X = data[columns]
|
210 |
y = data[target_col]
|
211 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
212 |
+
X, y, test_size=0.2, random_state=42
|
213 |
+
)
|
214 |
model = LogisticRegression(max_iter=1000)
|
215 |
model.fit(X_train, y_train)
|
216 |
y_pred = model.predict(X_test)
|
|
|
245 |
for rule_name, rule in self.rules.items():
|
246 |
try:
|
247 |
# Evaluate the condition using the dataframe 'df'
|
248 |
+
# **Warning**: Using eval can be dangerous. Ensure that user inputs are sanitized.
|
249 |
+
rule_matched = eval(rule.condition, {"__builtins__": None}, {"df": data})
|
250 |
results[rule_name] = {
|
251 |
"rule_matched": rule_matched,
|
252 |
"action": rule.action if rule_matched else None,
|
|
|
278 |
results = {}
|
279 |
for kpi_name, kpi in self.kpis.items():
|
280 |
try:
|
281 |
+
# **Warning**: Using eval can be dangerous. Ensure that user inputs are sanitized.
|
282 |
+
kpi_value = eval(kpi.calculation, {"__builtins__": None}, {"df": data})
|
283 |
results[kpi_name] = {
|
284 |
"value": kpi_value,
|
285 |
"threshold": kpi.threshold,
|
|
|
378 |
recommendation_key: ["No treatment recommendation found!"]
|
379 |
})
|
380 |
|
381 |
+
# ---------------------- Medical Knowledge Base ---------------------------
|
382 |
+
|
383 |
class MedicalKnowledgeBase(ABC):
|
384 |
"""Abstract class for Medical Knowledge."""
|
385 |
@abstractmethod
|
|
|
387 |
pass
|
388 |
|
389 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
390 |
+
"""Simple Medical Knowledge Class with TF-IDF, NER, and PubMed."""
|
391 |
def __init__(self):
|
392 |
self.knowledge_base = {
|
393 |
"diabetes": "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring.",
|
394 |
"heart disease": "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking.",
|
395 |
"fever": "For a fever, you can consider over-the-counter medications like acetaminophen or ibuprofen. Rest and hydration are also important.",
|
396 |
"headache": "For a headache, try rest, hydration, and over-the-counter pain relievers. Consult a doctor if it is severe or persistent.",
|
397 |
+
"cold": "For a cold, get rest, drink plenty of fluids, and use over-the-counter remedies like decongestants.",
|
398 |
+
"cancer drugs": "Please consult with your doctor to get personalized treatment and the latest drug information for your type of cancer."
|
399 |
}
|
400 |
self.vectorizer = TfidfVectorizer()
|
401 |
self.tfidf_matrix = self.vectorizer.fit_transform(self.knowledge_base.values())
|
|
|
420 |
def search_medical_info(self, query: str, pub_email: str = "") -> str:
|
421 |
"""Search the medical knowledge base and PubMed for relevant information."""
|
422 |
try:
|
423 |
+
query_lower = query.lower()
|
424 |
+
doc = nlp(query_lower)
|
425 |
+
entities = [ent.text for ent in doc.ents]
|
426 |
+
|
427 |
+
if entities:
|
428 |
+
best_match_keyword = ""
|
429 |
+
best_match_score = -1
|
430 |
+
for entity in entities:
|
431 |
+
query_vector = self.vectorizer.transform([entity])
|
432 |
+
similarities = cosine_similarity(query_vector, self.tfidf_matrix)
|
433 |
+
current_best_match_index = np.argmax(similarities)
|
434 |
+
current_best_score = np.max(similarities)
|
435 |
+
if current_best_score > best_match_score:
|
436 |
+
best_match_keyword = list(self.knowledge_base.keys())[current_best_match_index]
|
437 |
+
best_match_score = current_best_score
|
438 |
else:
|
439 |
+
query_vector = self.vectorizer.transform([query_lower])
|
440 |
+
similarities = cosine_similarity(query_vector, self.tfidf_matrix)
|
441 |
+
best_match_index = np.argmax(similarities)
|
442 |
+
best_match_keyword = list(self.knowledge_base.keys())[best_match_index]
|
443 |
+
|
444 |
+
best_match_info = self.knowledge_base.get(best_match_keyword, "No specific information is available based on the query provided.")
|
445 |
+
|
446 |
+
pubmed_result = self.search_pubmed(best_match_keyword, pub_email)
|
447 |
+
feedback_key = f"feedback_{query_lower}" # Unique key for feedback
|
448 |
+
|
449 |
+
response = f"**Based on your query:** {best_match_info}\n\n"
|
450 |
+
|
451 |
+
if "No abstracts found for this query on PubMed" not in pubmed_result:
|
452 |
+
response += f"**PubMed Abstract:**\n{pubmed_result}"
|
453 |
+
else:
|
454 |
+
response += f"{pubmed_result}"
|
455 |
+
|
456 |
+
# Initialize feedback in session state
|
457 |
+
if feedback_key not in st.session_state:
|
458 |
+
st.session_state[feedback_key] = {"feedback": None}
|
459 |
+
|
460 |
+
# Display feedback buttons only if a valid response is generated
|
461 |
+
if "error" not in pubmed_result:
|
462 |
+
col1, col2 = st.columns([1, 1])
|
463 |
+
with col1:
|
464 |
+
if st.button("Good Result", key=f"good_{feedback_key}"):
|
465 |
+
st.session_state[feedback_key]["feedback"] = "positive"
|
466 |
+
st.success("Thank you for the feedback!")
|
467 |
+
with col2:
|
468 |
+
if st.button("Bad Result", key=f"bad_{feedback_key}"):
|
469 |
+
st.session_state[feedback_key]["feedback"] = "negative"
|
470 |
+
st.error("Thank you for the feedback!")
|
471 |
+
|
472 |
+
return response
|
473 |
except Exception as e:
|
474 |
return f"Medical Knowledge Search Failed: {e}"
|
475 |
|
476 |
+
# ---------------------- Forecasting Engine ---------------------------
|
477 |
+
|
478 |
class ForecastingEngine(ABC):
|
479 |
"""Abstract class for forecasting."""
|
480 |
@abstractmethod
|
|
|
525 |
st.header("Dashboard")
|
526 |
for vis_name, vis_type in self.layout.items():
|
527 |
st.subheader(vis_name)
|
528 |
+
df = data_dict.get(vis_name)
|
529 |
+
if df is not None:
|
530 |
+
if vis_type == "table":
|
531 |
st.table(df)
|
532 |
+
elif vis_type == "plot":
|
|
|
|
|
|
|
|
|
533 |
if len(df.columns) > 1:
|
534 |
fig = plt.figure()
|
535 |
sns.lineplot(data=df)
|
536 |
st.pyplot(fig)
|
537 |
else:
|
538 |
st.write("Please select a DataFrame with more than 1 column for plotting.")
|
539 |
+
else:
|
540 |
+
st.write("Data Not Found")
|
541 |
|
542 |
class AutomatedReports:
|
543 |
"""Manages automated report definitions and generation."""
|
|
|
803 |
|
804 |
def data_analysis_section():
|
805 |
"""Handles the Data Analysis tab."""
|
806 |
+
selected_data_key = st.session_state.get('selected_data_key', None)
|
807 |
if not selected_data_key:
|
808 |
+
st.warning("Please select a dataset from the metadata section.")
|
809 |
return
|
810 |
|
811 |
data = st.session_state.data[selected_data_key]
|
|
|
840 |
time_cols = data.select_dtypes(include='datetime').columns
|
841 |
num_cols = data.select_dtypes(include=np.number).columns
|
842 |
|
843 |
+
if len(time_cols) == 0:
|
844 |
+
st.warning("No datetime columns available for temporal analysis.")
|
845 |
+
return
|
846 |
+
|
847 |
time_col = st.selectbox("Select Temporal Variable", time_cols)
|
848 |
value_col = st.selectbox("Select Analysis Variable", num_cols)
|
849 |
|
850 |
if time_col and value_col:
|
851 |
analyzer = TemporalAnalyzer()
|
852 |
result = analyzer.invoke(data=data, time_col=time_col, value_col=value_col)
|
853 |
+
if "visualization" in result and result["visualization"]:
|
854 |
+
st.image(f"data:image/png;base64,{result['visualization']}", use_column_width=True)
|
855 |
st.json(result)
|
856 |
|
857 |
def perform_comparative_statistics(data: pd.DataFrame):
|
858 |
"""Performs Comparative Statistics."""
|
859 |
+
categorical_cols = data.select_dtypes(include=['category', 'object']).columns
|
860 |
numeric_cols = data.select_dtypes(include=np.number).columns
|
861 |
|
862 |
+
if len(categorical_cols) == 0:
|
863 |
+
st.warning("No categorical columns available for hypothesis testing.")
|
864 |
+
return
|
865 |
+
|
866 |
+
if len(numeric_cols) == 0:
|
867 |
+
st.warning("No numerical columns available for hypothesis testing.")
|
868 |
+
return
|
869 |
+
|
870 |
group_col = st.selectbox("Select Grouping Variable", categorical_cols)
|
871 |
value_col = st.selectbox("Select Metric Variable", numeric_cols)
|
872 |
|
|
|
884 |
if selected_cols:
|
885 |
analyzer = DistributionVisualizer()
|
886 |
img_data = analyzer.invoke(data=data, columns=selected_cols)
|
887 |
+
if not img_data.startswith("Visualization Error"):
|
888 |
+
st.image(f"data:image/png;base64,{img_data}", use_column_width=True)
|
889 |
else:
|
890 |
st.error(img_data)
|
891 |
+
else:
|
892 |
+
st.info("Please select at least one numerical column to visualize.")
|
893 |
|
894 |
def perform_logistic_regression_training(data: pd.DataFrame):
|
895 |
"""Trains a Logistic Regression model."""
|
|
|
902 |
result = analyzer.invoke(data=data, target_col=target_col, columns=selected_cols)
|
903 |
st.subheader("Logistic Regression Model Results")
|
904 |
st.json(result)
|
905 |
+
else:
|
906 |
+
st.warning("Please select both target and feature variables for model training.")
|
907 |
|
908 |
def clinical_logic_section():
|
909 |
"""Handles the Clinical Logic tab."""
|
|
|
918 |
severity = st.selectbox("Enter Severity for the Rule", ["low", "medium", "high"])
|
919 |
|
920 |
if st.button("Add Clinical Rule"):
|
921 |
+
if rule_name and condition and action and severity:
|
922 |
+
try:
|
923 |
+
rule = ClinicalRule(
|
924 |
+
name=rule_name,
|
925 |
+
condition=condition,
|
926 |
+
action=action,
|
927 |
+
severity=severity
|
928 |
+
)
|
929 |
+
st.session_state.clinical_rules.add_rule(rule)
|
930 |
+
st.success("Added Clinical Rule successfully.")
|
931 |
+
except Exception as e:
|
932 |
+
st.error(f"Error in rule definition: {e}")
|
933 |
+
else:
|
934 |
+
st.error("Please fill in all fields to add a clinical rule.")
|
935 |
|
936 |
# Clinical KPI Management
|
937 |
st.subheader("Clinical KPI Definition")
|
938 |
kpi_name = st.text_input("Enter KPI Name")
|
939 |
kpi_calculation = st.text_area("Enter KPI Calculation (use 'df' for DataFrame)",
|
940 |
help="Example: df['patient_count'].sum()")
|
941 |
+
threshold = st.text_input("Enter Threshold for KPI (Optional)", help="Leave blank if not applicable")
|
942 |
|
943 |
if st.button("Add Clinical KPI"):
|
944 |
+
if kpi_name and kpi_calculation:
|
945 |
+
try:
|
946 |
+
threshold_value = float(threshold) if threshold else None
|
947 |
+
kpi = ClinicalKPI(
|
948 |
+
name=kpi_name,
|
949 |
+
calculation=kpi_calculation,
|
950 |
+
threshold=threshold_value
|
951 |
+
)
|
952 |
+
st.session_state.kpi_monitoring.add_kpi(kpi)
|
953 |
+
st.success(f"Added KPI '{kpi_name}' successfully.")
|
954 |
+
except ValueError:
|
955 |
+
st.error("Threshold must be a numeric value.")
|
956 |
+
except Exception as e:
|
957 |
+
st.error(f"Error creating KPI: {e}")
|
958 |
+
else:
|
959 |
+
st.error("Please provide both KPI name and calculation.")
|
960 |
|
961 |
# Execute Clinical Rules and Calculate KPIs
|
962 |
selected_data_key = st.selectbox("Select Dataset for Clinical Logic", list(st.session_state.data.keys()))
|
|
|
970 |
with st.spinner("Calculating Clinical KPIs..."):
|
971 |
result = st.session_state.kpi_monitoring.calculate_kpis(data)
|
972 |
st.json(result)
|
973 |
+
else:
|
974 |
+
st.warning("Please ingest data to execute clinical rules and calculate KPIs.")
|
975 |
|
976 |
def insights_section():
|
977 |
"""Handles the Insights tab."""
|
|
|
987 |
selected_analyses = st.multiselect("Select Analyses for Insights", available_analyses)
|
988 |
|
989 |
if st.button("Generate Automated Insights"):
|
990 |
+
if selected_analyses:
|
991 |
+
with st.spinner("Generating Insights..."):
|
992 |
+
results = st.session_state.automated_insights.generate_insights(
|
993 |
+
data, analysis_names=selected_analyses
|
994 |
+
)
|
995 |
+
st.json(results)
|
996 |
+
else:
|
997 |
+
st.warning("Please select at least one analysis to generate insights.")
|
998 |
|
999 |
# Diagnosis Support
|
1000 |
st.subheader("Diagnosis Support")
|
|
|
1034 |
# Create Report Definition
|
1035 |
st.subheader("Create Report Definition")
|
1036 |
report_name = st.text_input("Report Name")
|
1037 |
+
report_def = st.text_area("Report Definition", help="Describe the structure and content of the report.")
|
1038 |
|
1039 |
if st.button("Create Report Definition"):
|
1040 |
if report_name and report_def:
|
|
|
1045 |
|
1046 |
# Generate Report
|
1047 |
st.subheader("Generate Report")
|
1048 |
+
report_names = list(st.session_state.automated_reports.report_definitions.keys())
|
1049 |
+
if report_names:
|
1050 |
+
report_name_to_generate = st.selectbox("Select Report to Generate", report_names)
|
1051 |
+
if st.button("Generate Report"):
|
1052 |
with st.spinner("Generating Report..."):
|
1053 |
report = st.session_state.automated_reports.generate_report(report_name_to_generate, st.session_state.data)
|
1054 |
if "error" not in report:
|
1055 |
+
st.header(f"Report: {report['Report Name']}")
|
1056 |
+
st.markdown(f"**Definition:** {report['Report Definition']}")
|
1057 |
for df_name, df_content in report["Data"].items():
|
1058 |
st.subheader(f"Data: {df_name}")
|
1059 |
+
st.dataframe(pd.DataFrame(df_content))
|
1060 |
else:
|
1061 |
st.error(report["error"])
|
1062 |
+
else:
|
1063 |
+
st.info("No report definitions found. Please create a report definition first.")
|
1064 |
|
1065 |
def medical_knowledge_section():
|
1066 |
"""Handles the Medical Knowledge tab."""
|
|
|
1068 |
query = st.text_input("Enter your medical question here:")
|
1069 |
|
1070 |
if st.button("Search"):
|
1071 |
+
if query.strip():
|
1072 |
with st.spinner("Searching..."):
|
1073 |
+
result = st.session_state.knowledge_base.search_medical_info(
|
1074 |
+
query, pub_email=st.session_state.pub_email
|
1075 |
+
)
|
1076 |
st.markdown(result)
|
1077 |
else:
|
1078 |
st.error("Please enter a medical question to search.")
|