Update app.py
Browse files
app.py
CHANGED
@@ -18,10 +18,14 @@ import os
|
|
18 |
import numpy as np
|
19 |
from scipy.stats import ttest_ind, f_oneway
|
20 |
import json
|
|
|
|
|
|
|
21 |
|
22 |
# Initialize Groq Client
|
23 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
24 |
|
|
|
25 |
# ---------------------- Base Classes and Schemas ---------------------------
|
26 |
class ResearchInput(BaseModel):
|
27 |
"""Base schema for research tool inputs"""
|
@@ -287,17 +291,50 @@ class MedicalKnowledgeBase():
|
|
287 |
pass
|
288 |
|
289 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
|
290 |
-
"""Simple Medical Knowledge Class"""
|
291 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
try:
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
297 |
else:
|
298 |
-
|
299 |
except Exception as e:
|
300 |
-
|
301 |
|
302 |
|
303 |
class ForecastingEngine(ABC):
|
@@ -494,8 +531,9 @@ def main():
|
|
494 |
st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
|
495 |
if 'knowledge_base' not in st.session_state:
|
496 |
st.session_state.knowledge_base = SimpleMedicalKnowledge()
|
|
|
|
|
497 |
|
498 |
-
|
499 |
# Sidebar for Data Management
|
500 |
with st.sidebar:
|
501 |
st.header("⚙️ Data Management")
|
@@ -668,51 +706,4 @@ def main():
|
|
668 |
st.json(result)
|
669 |
with insights_tab:
|
670 |
if selected_data_key:
|
671 |
-
|
672 |
-
available_analysis = ["EDA", "temporal", "distribution", "hypothesis", "model"]
|
673 |
-
selected_analysis = st.multiselect("Select Analysis", available_analysis)
|
674 |
-
if st.button("Generate Automated Insights"):
|
675 |
-
with st.spinner("Generating Insights"):
|
676 |
-
results = st.session_state.automated_insights.generate_insights(data, analysis_names=selected_analysis)
|
677 |
-
st.json(results)
|
678 |
-
st.subheader("Diagnosis Support")
|
679 |
-
target_col = st.selectbox("Select Target Variable for Diagnosis", data.columns.tolist())
|
680 |
-
num_cols = data.select_dtypes(include=np.number).columns.tolist()
|
681 |
-
selected_cols_diagnosis = st.multiselect("Select Feature Variables for Diagnosis", num_cols)
|
682 |
-
if st.button("Generate Diagnosis"):
|
683 |
-
if target_col and selected_cols_diagnosis:
|
684 |
-
with st.spinner("Generating Diagnosis"):
|
685 |
-
result = st.session_state.diagnosis_support.diagnose(data, target_col=target_col, columns=selected_cols_diagnosis, diagnosis_key="diagnosis_result")
|
686 |
-
st.json(result)
|
687 |
-
|
688 |
-
st.subheader("Treatment Recommendation")
|
689 |
-
condition_col = st.selectbox("Select Condition Column for Treatment Recommendation", data.columns.tolist())
|
690 |
-
treatment_col = st.selectbox("Select Treatment Column for Treatment Recommendation", data.columns.tolist())
|
691 |
-
if st.button("Generate Treatment Recommendation"):
|
692 |
-
if condition_col and treatment_col:
|
693 |
-
with st.spinner("Generating Treatment Recommendation"):
|
694 |
-
result = st.session_state.treatment_recommendation.recommend(data, condition_col = condition_col, treatment_col = treatment_col, recommendation_key="treatment_recommendation")
|
695 |
-
st.json(result)
|
696 |
-
|
697 |
-
with reports_tab:
|
698 |
-
st.header("Reports")
|
699 |
-
report_name = st.text_input("Report Name")
|
700 |
-
report_def = st.text_area("Report definition")
|
701 |
-
if st.button("Create Report Definition"):
|
702 |
-
st.session_state.automated_reports.create_report_definition(report_name, report_def)
|
703 |
-
st.success("Report definition created")
|
704 |
-
if selected_data_key:
|
705 |
-
data = st.session_state.data
|
706 |
-
if st.button("Generate Report"):
|
707 |
-
with st.spinner("Generating Report..."):
|
708 |
-
report = st.session_state.automated_reports.generate_report(report_name, data)
|
709 |
-
with knowledge_tab:
|
710 |
-
st.header("Medical Knowledge")
|
711 |
-
query = st.text_input("Enter your medical question here:")
|
712 |
-
if st.button("Search"):
|
713 |
-
with st.spinner("Searching..."):
|
714 |
-
result = st.session_state.knowledge_base.search_medical_info(query)
|
715 |
-
st.write(result)
|
716 |
-
|
717 |
-
if __name__ == "__main__":
|
718 |
-
main()
|
|
|
18 |
import numpy as np
|
19 |
from scipy.stats import ttest_ind, f_oneway
|
20 |
import json
|
21 |
+
from Bio import Entrez
|
22 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
23 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
24 |
|
25 |
# Initialize Groq Client
|
26 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
27 |
|
28 |
+
|
29 |
# ---------------------- Base Classes and Schemas ---------------------------
|
30 |
class ResearchInput(BaseModel):
|
31 |
"""Base schema for research tool inputs"""
|
|
|
291 |
pass
|
292 |
|
293 |
class SimpleMedicalKnowledge(MedicalKnowledgeBase):
    """Small medical knowledge base combining TF-IDF lookup with PubMed.

    A handful of canned answers are indexed with a TF-IDF vectorizer. A query
    is matched against them by cosine similarity and, in addition, the first
    PubMed hit for the query is fetched through ``Bio.Entrez``.

    Both public methods report failures as plain strings instead of raising,
    so their return value can be shown to the user as-is.
    """

    # Marker strings produced by ``search_pubmed``; ``search_medical_info``
    # relies on them to tell a real abstract from "nothing found" / "failed".
    _NO_ABSTRACT_MESSAGE = "No abstracts found for this query on PubMed"
    _PUBMED_ERROR_PREFIX = "Error searching pubmed"
    _NO_LOCAL_MATCH_MESSAGE = "No matching entry was found in the local knowledge base."
    _EMPTY_QUERY_MESSAGE = "Please enter a medical question."

    def __init__(self):
        """Build the canned knowledge base and fit the TF-IDF index on it."""
        self.knowledge_base = {
            "diabetes": "The recommended treatment for diabetes includes lifestyle changes, medication, and monitoring.",
            "heart disease": "Risk factors for heart disease include high blood pressure, high cholesterol, and smoking.",
            "fever": "For a fever, you can consider over-the-counter medications like acetaminophen or ibuprofen. Rest and hydration are also important.",
            "headache": "For a headache, try rest, hydration, and over-the-counter pain relievers. Consult a doctor if it is severe or persistent.",
            "cold": "For a cold, get rest, drink plenty of fluids, and use over-the-counter remedies like decongestants.",
        }
        self.vectorizer = TfidfVectorizer()
        # Row i of the matrix corresponds to the i-th value of knowledge_base
        # (dicts preserve insertion order).
        self.tfidf_matrix = self.vectorizer.fit_transform(self.knowledge_base.values())

    def search_pubmed(self, query: str, email: str) -> str:
        """Return the text abstract of the first PubMed hit for *query*.

        Args:
            query: Search term passed to ``Entrez.esearch``.
            email: Contact address assigned to ``Entrez.email``. An empty
                value leaves the current ``Entrez.email`` untouched instead
                of overwriting it with an empty string.

        Returns:
            The abstract text, the "No abstracts found ..." message when the
            search yields no ids, or an "Error searching pubmed ..." message
            when any step raises.
        """
        try:
            if email:
                Entrez.email = email

            handle = Entrez.esearch(db="pubmed", term=query, retmax=1)
            try:
                record = Entrez.read(handle)
            finally:
                # Close the handle even if parsing the response fails.
                handle.close()

            if not record["IdList"]:
                return self._NO_ABSTRACT_MESSAGE

            handle = Entrez.efetch(
                db="pubmed",
                id=record["IdList"][0],
                rettype="abstract",
                retmode="text",
            )
            try:
                return handle.read()
            finally:
                handle.close()
        except Exception as e:
            return f"{self._PUBMED_ERROR_PREFIX} {e}"

    def search_medical_info(self, query: str, pub_email: str = "") -> str:
        """Answer *query* from the local knowledge base plus PubMed.

        Args:
            query: Free-text medical question.
            pub_email: E-mail forwarded to ``search_pubmed``.

        Returns:
            A message containing the best local match (or a note that nothing
            matched) followed by the PubMed abstract, the "no abstracts"
            message, or the PubMed error message. Any unexpected failure is
            reported as "Medical Knowledge Search Failed ...".
        """
        try:
            if not query or not query.strip():
                return self._EMPTY_QUERY_MESSAGE

            query_vector = self.vectorizer.transform([query])
            similarities = cosine_similarity(query_vector, self.tfidf_matrix)
            best_match_index = int(np.argmax(similarities))

            # argmax of an all-zero row is 0, which would always report the
            # first entry for queries that share no term with the knowledge
            # base; only report a match when there is actual similarity.
            if float(similarities.max()) > 0:
                best_match_info = list(self.knowledge_base.values())[best_match_index]
                local_part = f"Based on the query provided, I found this: {best_match_info}"
            else:
                local_part = self._NO_LOCAL_MATCH_MESSAGE

            pubmed_result = self.search_pubmed(query, pub_email)
            # A failed lookup must not be presented as an abstract.
            has_abstract = (
                self._NO_ABSTRACT_MESSAGE not in pubmed_result
                and not pubmed_result.startswith(self._PUBMED_ERROR_PREFIX)
            )
            if has_abstract:
                return f"{local_part} \n\nFrom Pubmed I also found the following abstract: \n {pubmed_result}"
            return f"{local_part} \n\n{pubmed_result}"
        except Exception as e:
            return f"Medical Knowledge Search Failed {e}"
|
338 |
|
339 |
|
340 |
class ForecastingEngine(ABC):
|
|
|
531 |
st.session_state.treatment_recommendation = BasicTreatmentRecommendation()
|
532 |
if 'knowledge_base' not in st.session_state:
|
533 |
st.session_state.knowledge_base = SimpleMedicalKnowledge()
|
534 |
+
if 'pub_email' not in st.session_state:
|
535 |
+
st.session_state.pub_email = st.secrets.get("PUB_EMAIL", "") # Load PUB_EMAIL from secrets
|
536 |
|
|
|
537 |
# Sidebar for Data Management
|
538 |
with st.sidebar:
|
539 |
st.header("⚙️ Data Management")
|
|
|
706 |
st.json(result)
|
707 |
with insights_tab:
|
708 |
if selected_data_key:
|
709 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|