"""Flask backend for a retrieval-augmented clinical report assistant.

On import this module loads the ``epfl-llm/guidelines`` dataset, embeds every
row with a SentenceTransformer model and attaches a FAISS index to it. The
``/upload-txt`` endpoint then accepts a plain-text report and returns a
summary, treatment recommendations and a risk assessment generated with
Gemini, all rendered from Markdown to HTML.

Configuration (environment variables):
    GEMINI_API_KEY: required; API key passed to ``genai.configure``.
    FLASK_DEBUG: optional; "1"/"true"/"yes" enables Flask debug mode when the
        module is run as a script.
"""

import json
import os
import uuid
from datetime import datetime

import faiss
import google.generativeai as genai
import markdown
import numpy as np
from datasets import load_dataset
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from werkzeug.utils import secure_filename

# Configuration
# The key is read from the environment so that no credential is ever stored
# in source control. Fail fast with a clear message when it is missing.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError(
        "GEMINI_API_KEY environment variable is not set; "
        "export it before starting the server."
    )
genai.configure(api_key=GEMINI_API_KEY)

# Single place to change the Gemini model used by every generation helper.
GEMINI_MODEL_NAME = "gemini-1.5-flash"

# Initialize Flask app
app = Flask(__name__, static_folder="../frontend", static_url_path="")
CORS(app)

# RAG Model Initialization
print("🚀 Initializing RAG System...")

# Load medical guidelines dataset
print("📂 Loading dataset...")
dataset = load_dataset("epfl-llm/guidelines", split="train")
TITLE_COL = "title"
CONTENT_COL = "clean_text"

# Initialize models
print("🤖 Loading AI models...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline(
    "question-answering", model="distilbert-base-cased-distilled-squad"
)

# Build FAISS index
print("🔍 Building FAISS index...")


def embed_text(batch):
    """Embed one batch of dataset rows for indexing.

    Each row is represented by its title followed by the first 200 characters
    of its content.

    Args:
        batch: Mapping of column name to a list of values, as passed by
            ``dataset.map(..., batched=True)``.

    Returns:
        A dict with a single ``"embeddings"`` key holding the encoder output
        for the batch.
    """
    combined_texts = [
        f"{title} {content[:200]}"
        for title, content in zip(batch[TITLE_COL], batch[CONTENT_COL])
    ]
    return {"embeddings": embedder.encode(combined_texts, show_progress_bar=False)}


dataset = dataset.map(embed_text, batched=True, batch_size=32)
dataset.add_faiss_index(column="embeddings")


# Processing Functions
def format_response(text):
    """Convert Markdown text to HTML for proper frontend display."""
    return markdown.markdown(text)


def summarize_report(report):
    """Generate a clinical summary using the QA pipeline and Gemini.

    Four fixed questions (age, gender, symptoms, history) are asked against
    the report; answers scoring 0.1 or lower are replaced by
    ``"Not specified"``. The answers are then handed to Gemini to phrase a
    summary.

    Args:
        report: Raw report text used as the QA context.

    Returns:
        The generated summary rendered as HTML.
    """
    questions = [
        "Patient's age?",
        "Patient's gender?",
        "Current symptoms?",
        "Medical history?",
    ]
    answers = []
    for question in questions:
        result = qa_pipeline(question=question, context=report)
        answers.append(result["answer"] if result["score"] > 0.1 else "Not specified")

    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = f"""Create clinical summary from:
    - Age: {answers[0]}
    - Gender: {answers[1]}
    - Symptoms: {answers[2]}
    - History: {answers[3]}

    Format: "[Age] [Gender] with [History], presenting with [Symptoms]"
    Add relevant medical context."""
    summary = model.generate_content(prompt).text.strip()
    print(f"Generated Summary: {summary}")  # Debugging log
    return format_response(summary)


def rag_retrieval(query, k=3):
    """Retrieve relevant guidelines using FAISS.

    Args:
        query: Free text to embed and search for.
        k: Number of nearest examples to request from the index.

    Returns:
        A list of dicts with ``title``, ``content`` (truncated to 1000
        characters), ``source`` (``"N/A"`` when the dataset exposes no
        ``source`` column) and ``score`` (the value reported by the index,
        cast to ``float``).
    """
    query_embedding = embedder.encode([query])
    scores, examples = dataset.get_nearest_examples("embeddings", query_embedding, k=k)

    titles = examples[TITLE_COL]
    # Resolve the optional "source" column once instead of once per result.
    sources = examples.get("source", ["N/A"] * len(titles))

    return [
        {
            "title": title,
            "content": content[:1000],
            "source": source,
            "score": float(score),
        }
        for title, content, source, score in zip(
            titles, examples[CONTENT_COL], sources, scores
        )
    ]


def generate_recommendations(report):
    """Generate treatment recommendations with RAG context.

    Args:
        report: Raw report text; used both as the retrieval query and as the
            patient presentation in the prompt.

    Returns:
        A ``(recommendations_html, references)`` tuple where ``references``
        lists the ``source`` of every retrieved guideline that is not
        ``"N/A"``.
    """
    guidelines = rag_retrieval(report)
    context = "Relevant Clinical Guidelines:\n" + "\n".join(
        f"• {g['title']}: {g['content']} [Source: {g['source']}]" for g in guidelines
    )

    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = f"""Generate treatment recommendations using these guidelines:
    {context}

    Patient Presentation: {report}

    Format with:
    - Bold section headers
    - Clear bullet points
    - Evidence markers [Guideline #]
    - Risk-benefit analysis
    - Include references to the sources provided where applicable
    """
    recommendations = model.generate_content(prompt).text.strip()
    references = [g["source"] for g in guidelines if g["source"] != "N/A"]
    return format_response(recommendations), references


def generate_risk_assessment(summary):
    """Generate a risk assessment from a clinical summary.

    Args:
        summary: Summary text to analyse. NOTE(review): the only caller in
            this file passes the HTML returned by ``summarize_report`` --
            confirm that sending HTML to the model is intended.

    Returns:
        The generated assessment rendered as HTML.
    """
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = f"""Analyze clinical risk:
    {summary}

    Output format:
    Risk Score: 0-100
    Alert Level: 🔴 High/🟡 Medium/🟢 Low
    Key Risk Factors: bullet points
    Recommended Actions: bullet points"""
    return format_response(model.generate_content(prompt).text.strip())


# Flask Endpoints
@app.route("/upload-txt", methods=["POST"])
def handle_upload():
    """Handle text file upload and return processed data.

    Expects a multipart form field named ``file`` containing a UTF-8 ``.txt``
    file.

    Returns:
        200 with a JSON body (session id, timestamp, summary,
        recommendations, risk assessment, references) on success;
        400 for a missing, non-``.txt``, undecodable or empty file;
        500 when processing raises.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    uploaded = request.files["file"]
    filename = uploaded.filename or ""
    # Compare case-insensitively so "REPORT.TXT" is accepted as well.
    if not uploaded or not filename.lower().endswith(".txt"):
        return jsonify({"error": "Invalid file, must be a .txt file"}), 400

    # A badly encoded upload is a client error, not a server failure.
    try:
        content = uploaded.read().decode("utf-8")
    except UnicodeDecodeError:
        return jsonify({"error": "File must be UTF-8 encoded text"}), 400

    if not content.strip():
        return jsonify({"error": "File is empty"}), 400

    try:
        summary = summarize_report(content)
        recommendations, references = generate_recommendations(content)
        risk_assessment = generate_risk_assessment(summary)

        response = {
            "session_id": str(uuid.uuid4()),
            "timestamp": datetime.now().isoformat(),
            "summary": summary,
            "recommendations": recommendations,
            "risk_assessment": risk_assessment,
            "references": references,
        }
        print(
            f"Response Sent to Frontend: {json.dumps(response, indent=2)}"
        )  # Debugging log
        return jsonify(response)
    except Exception as e:
        # Top-level boundary: report any model/retrieval failure as a 500.
        return jsonify({"error": f"Processing failed: {str(e)}"}), 500


@app.route("/")
def serve_index():
    """Serve the index.html file."""
    return send_from_directory(app.static_folder, "index.html")


# The rule must declare the ``path`` variable; registering this view on "/"
# would collide with ``serve_index`` and never supply the argument.
@app.route("/<path:path>")
def serve_static(path):
    """Serve other static files from the frontend directory."""
    return send_from_directory(app.static_folder, path)


if __name__ == "__main__":
    # Debug mode enables the interactive debugger, which must not be exposed
    # on a server bound to all interfaces; require an explicit opt-in.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)