Update app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,6 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
25 |
# Initialize Groq Client
|
26 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
27 |
|
28 |
-
|
29 |
# ---------------------- Base Classes and Schemas ---------------------------
|
30 |
class ResearchInput(BaseModel):
|
31 |
"""Base schema for research tool inputs"""
|
@@ -287,7 +286,7 @@ class BasicTreatmentRecommendation(TreatmentRecommendation):
|
|
287 |
class MedicalKnowledgeBase():
|
288 |
"""Abstract class for Medical Knowledge"""
|
289 |
@abstractmethod
|
290 |
-
def search_medical_info(self, query: str) -> str:
|
291 |
pass
|
292 |
|
293 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
@@ -534,6 +533,7 @@ def main():
|
|
534 |
if 'pub_email' not in st.session_state:
|
535 |
st.session_state.pub_email = st.secrets.get("PUB_EMAIL", "") # Load PUB_EMAIL from secrets
|
536 |
|
|
|
537 |
# Sidebar for Data Management
|
538 |
with st.sidebar:
|
539 |
st.header("⚙️ Data Management")
|
@@ -706,4 +706,52 @@ def main():
|
|
706 |
st.json(result)
|
707 |
with insights_tab:
|
708 |
if selected_data_key:
|
709 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
# Initialize Groq Client
|
26 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
27 |
|
|
|
28 |
# ---------------------- Base Classes and Schemas ---------------------------
|
29 |
class ResearchInput(BaseModel):
|
30 |
"""Base schema for research tool inputs"""
|
|
|
286 |
class MedicalKnowledgeBase():
|
287 |
"""Abstract class for Medical Knowledge"""
|
288 |
@abstractmethod
|
289 |
+
def search_medical_info(self, query: str, pub_email:str="") -> str:
|
290 |
pass
|
291 |
|
292 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
|
|
533 |
if 'pub_email' not in st.session_state:
|
534 |
st.session_state.pub_email = st.secrets.get("PUB_EMAIL", "") # Load PUB_EMAIL from secrets
|
535 |
|
536 |
+
|
537 |
# Sidebar for Data Management
|
538 |
with st.sidebar:
|
539 |
st.header("⚙️ Data Management")
|
|
|
706 |
st.json(result)
|
707 |
with insights_tab:
|
708 |
if selected_data_key:
|
709 |
+
data = st.session_state.data[selected_data_key]
|
710 |
+
available_analysis = ["EDA", "temporal", "distribution", "hypothesis", "model"]
|
711 |
+
selected_analysis = st.multiselect("Select Analysis", available_analysis)
|
712 |
+
if st.button("Generate Automated Insights"):
|
713 |
+
with st.spinner("Generating Insights"):
|
714 |
+
results = st.session_state.automated_insights.generate_insights(data, analysis_names=selected_analysis)
|
715 |
+
st.json(results)
|
716 |
+
st.subheader("Diagnosis Support")
|
717 |
+
target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
|
718 |
+
num_cols = data.select_dtypes(include=np.number).columns.tolist()
|
719 |
+
selected_cols_diagnosis = st.multiselect("Select Feature Variables for Diagnosis", num_cols)
|
720 |
+
if st.button("Generate Diagnosis"):
|
721 |
+
if target_col
|
722 |
+
if target_col and selected_cols_diagnosis:
|
723 |
+
with st.spinner("Generating Diagnosis"):
|
724 |
+
result = st.session_state.diagnosis_support.diagnose(data, target_col=target_col, columns=selected_cols_diagnosis, diagnosis_key="diagnosis_result")
|
725 |
+
st.json(result)
|
726 |
+
|
727 |
+
st.subheader("Treatment Recommendation")
|
728 |
+
condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
|
729 |
+
treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
|
730 |
+
if st.button("Generate Treatment Recommendation"):
|
731 |
+
if condition_col and treatment_col:
|
732 |
+
with st.spinner("Generating Treatment Recommendation"):
|
733 |
+
result = st.session_state.treatment_recommendation.recommend(data, condition_col = condition_col, treatment_col = treatment_col, recommendation_key="treatment_recommendation")
|
734 |
+
st.json(result)
|
735 |
+
|
736 |
+
with reports_tab:
|
737 |
+
st.header("Reports")
|
738 |
+
report_name = st.text_input("Report Name")
|
739 |
+
report_def = st.text_area("Report definition")
|
740 |
+
if st.button("Create Report Definition"):
|
741 |
+
st.session_state.automated_reports.create_report_definition(report_name, report_def)
|
742 |
+
st.success("Report definition created")
|
743 |
+
if selected_data_key:
|
744 |
+
data = st.session_state.data
|
745 |
+
if st.button("Generate Report"):
|
746 |
+
with st.spinner("Generating Report..."):
|
747 |
+
report = st.session_state.automated_reports.generate_report(report_name, data)
|
748 |
+
with knowledge_tab:
|
749 |
+
st.header("Medical Knowledge")
|
750 |
+
query = st.text_input("Enter your medical question here:")
|
751 |
+
if st.button("Search"):
|
752 |
+
with st.spinner("Searching..."):
|
753 |
+
result = st.session_state.knowledge_base.search_medical_info(query, pub_email=st.session_state.pub_email)
|
754 |
+
st.write(result)
|
755 |
+
|
756 |
+
if __name__ == "__main__":
|
757 |
+
main()
|