badal-12 commited on
Commit
ad8d809
·
verified ·
1 Parent(s): 0ce7701

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import uuid
4
+ import numpy as np
5
+ from datetime import datetime
6
+ from flask import Flask, request, jsonify, send_from_directory
7
+ from flask_cors import CORS
8
+ from werkzeug.utils import secure_filename
9
+ import google.generativeai as genai
10
+ from datasets import load_dataset
11
+ from sentence_transformers import SentenceTransformer
12
+ from transformers import pipeline
13
+ import faiss
14
+ import markdown
15
+
16
# Configuration
# SECURITY FIX: the API key was hard-coded (and therefore leaked) in source.
# Read it from the environment instead; never commit credentials.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
if not GEMINI_API_KEY:
    print("WARNING: GEMINI_API_KEY is not set; Gemini API calls will fail.")
genai.configure(api_key=GEMINI_API_KEY)

# Initialize Flask app; the built frontend is served as static files.
app = Flask(__name__, static_folder="../frontend", static_url_path="")
CORS(app)  # allow cross-origin requests from the frontend

# RAG Model Initialization
print("🚀 Initializing RAG System...")

# Load medical guidelines dataset
print("📂 Loading dataset...")
dataset = load_dataset("epfl-llm/guidelines", split="train")
TITLE_COL = "title"  # dataset column holding guideline titles
CONTENT_COL = "clean_text"  # dataset column holding cleaned guideline text

# Initialize models: a sentence embedder for retrieval and an extractive
# QA pipeline for pulling structured facts out of free-text reports.
print("🤖 Loading AI models...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline(
    "question-answering", model="distilbert-base-cased-distilled-squad"
)

# Build FAISS index
print("🔍 Building FAISS index...")
44
+
45
+
46
def embed_text(batch):
    """Compute one sentence embedding per row of *batch*.

    Each row's text is its title followed by the first 200 characters of
    its content. Returns a dict keyed "embeddings", suitable for
    `datasets.Dataset.map(..., batched=True)`.
    """
    texts = []
    for title, content in zip(batch[TITLE_COL], batch[CONTENT_COL]):
        texts.append(f"{title} {content[:200]}")
    vectors = embedder.encode(texts, show_progress_bar=False)
    return {"embeddings": vectors}
52
+
53
+
54
# Embed every guideline (batched for throughput) and attach an in-memory
# FAISS index over the "embeddings" column for nearest-neighbour retrieval.
dataset = dataset.map(embed_text, batched=True, batch_size=32)
dataset.add_faiss_index(column="embeddings")
56
+
57
+
58
+ # Processing Functions
59
def format_response(text):
    """Render *text* (Markdown) as an HTML fragment for the frontend."""
    html = markdown.markdown(text)
    return html
62
+
63
+
64
def summarize_report(report):
    """Build a clinical summary of *report*.

    Extracts age, gender, symptoms and history with the QA pipeline
    (answers below a 0.1 confidence score become "Not specified"), then
    asks Gemini to phrase them as a one-line clinical summary. Returns
    the summary rendered as HTML.
    """
    question_set = (
        "Patient's age?",
        "Patient's gender?",
        "Current symptoms?",
        "Medical history?",
    )

    extracted = []
    for question in question_set:
        hit = qa_pipeline(question=question, context=report)
        if hit["score"] > 0.1:
            extracted.append(hit["answer"])
        else:
            extracted.append("Not specified")

    age, gender, symptoms, history = extracted
    gemini = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Create clinical summary from:
- Age: {age}
- Gender: {gender}
- Symptoms: {symptoms}
- History: {history}

Format: "[Age] [Gender] with [History], presenting with [Symptoms]"
Add relevant medical context."""
    summary = gemini.generate_content(prompt).text.strip()
    print(f"Generated Summary: {summary}")  # Debugging log
    return format_response(summary)
90
+
91
+
92
def rag_retrieval(query, k=3):
    """Return the *k* guidelines nearest to *query* via the FAISS index.

    Each result dict carries the guideline title, the first 1000 chars of
    its content, its source (or "N/A" when the dataset has no source
    column), and the raw FAISS similarity score as a float.
    """
    vector = embedder.encode([query])
    scores, examples = dataset.get_nearest_examples("embeddings", vector, k=k)

    titles = examples[TITLE_COL]
    contents = examples[CONTENT_COL]
    sources = examples.get("source", ["N/A"] * len(titles))

    results = []
    for idx, (title, content, score) in enumerate(zip(titles, contents, scores)):
        results.append(
            {
                "title": title,
                "content": content[:1000],
                "source": sources[idx],
                "score": float(score),
            }
        )
    return results
107
+
108
+
109
def generate_recommendations(report):
    """Produce treatment recommendations grounded in retrieved guidelines.

    Retrieves the most relevant guidelines for *report*, feeds them to
    Gemini as context, and returns a tuple of
    (HTML recommendations, list of non-"N/A" guideline sources).
    """
    guidelines = rag_retrieval(report)

    bullet_lines = []
    for g in guidelines:
        bullet_lines.append(f"β€’ {g['title']}: {g['content']} [Source: {g['source']}]")
    context = "Relevant Clinical Guidelines:\n" + "\n".join(bullet_lines)

    gemini = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Generate treatment recommendations using these guidelines:
{context}

Patient Presentation:
{report}

Format with:
- Bold section headers
- Clear bullet points
- Evidence markers [Guideline #]
- Risk-benefit analysis
- Include references to the sources provided where applicable
"""
    recommendations = gemini.generate_content(prompt).text.strip()

    references = []
    for g in guidelines:
        if g["source"] != "N/A":
            references.append(g["source"])
    return format_response(recommendations), references
133
+
134
+
135
def generate_risk_assessment(summary):
    """Ask Gemini for a structured clinical risk assessment of *summary*.

    Returns the assessment (score, alert level, risk factors, actions)
    rendered as HTML.
    """
    gemini = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Analyze clinical risk:
{summary}

Output format:
Risk Score: 0-100
Alert Level: πŸ”΄ High/🟑 Medium/🟒 Low
Key Risk Factors: bullet points
Recommended Actions: bullet points"""
    raw = gemini.generate_content(prompt).text
    return format_response(raw.strip())
147
+
148
+
149
+ # Flask Endpoints
150
+ @app.route("/upload-txt", methods=["POST"])
151
+ def handle_upload():
152
+ """Handle text file upload and return processed data."""
153
+ if "file" not in request.files:
154
+ return jsonify({"error": "No file provided"}), 400
155
+
156
+ file = request.files["file"]
157
+ if not file or not file.filename.endswith(".txt"):
158
+ return jsonify({"error": "Invalid file, must be a .txt file"}), 400
159
+
160
+ try:
161
+ content = file.read().decode("utf-8")
162
+ if not content.strip():
163
+ return jsonify({"error": "File is empty"}), 400
164
+
165
+ summary = summarize_report(content)
166
+ recommendations, references = generate_recommendations(content)
167
+ risk_assessment = generate_risk_assessment(summary)
168
+
169
+ response = {
170
+ "session_id": str(uuid.uuid4()),
171
+ "timestamp": datetime.now().isoformat(),
172
+ "summary": summary,
173
+ "recommendations": recommendations,
174
+ "risk_assessment": risk_assessment,
175
+ "references": references,
176
+ }
177
+ print(
178
+ f"Response Sent to Frontend: {json.dumps(response, indent=2)}"
179
+ ) # Debugging log
180
+ return jsonify(response)
181
+ except Exception as e:
182
+ return jsonify({"error": f"Processing failed: {str(e)}"}), 500
183
+
184
+
185
+ @app.route("/")
186
+ def serve_index():
187
+ """Serve the index.html file."""
188
+ return send_from_directory(app.static_folder, "index.html")
189
+
190
+
191
+ @app.route("/<path:path>")
192
+ def serve_static(path):
193
+ """Serve other static files from the frontend directory."""
194
+ return send_from_directory(app.static_folder, path)
195
+
196
+
197
+ if __name__ == "__main__":
198
+ app.run(host="0.0.0.0", port=5000, debug=True)