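# NOTE: this source was captured from a Hugging Face Space page whose status
# banner read "Spaces: Runtime error" (the "Runtime error" label appeared twice).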
import os | |
import json | |
import uuid | |
import numpy as np | |
from datetime import datetime | |
from flask import Flask, request, jsonify, send_from_directory | |
from flask_cors import CORS | |
from werkzeug.utils import secure_filename | |
import google.generativeai as genai | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from transformers import pipeline | |
import faiss | |
import markdown | |
# Configuration
# The Gemini API key is read from the environment (e.g. a Spaces secret) and
# must never be committed to source control. The key that was previously
# hard-coded here was published with the source; treat it as compromised and
# revoke it.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY is not set; Gemini requests will likely fail.")
genai.configure(api_key=GEMINI_API_KEY)
# Initialize Flask app. Files under ../frontend are served as static assets,
# mounted at the URL root because static_url_path is the empty string.
app = Flask(__name__, static_url_path="", static_folder="../frontend")
CORS(app)  # flask_cors with its default settings, applied to the whole app

# RAG Model Initialization
print("π Initializing RAG System...")

# Load the medical guidelines corpus (train split only).
print("π Loading dataset...")
dataset = load_dataset("epfl-llm/guidelines", split="train")
# Column names used everywhere below for a guideline's heading and body text.
TITLE_COL, CONTENT_COL = "title", "clean_text"

# Initialize models: a sentence embedder for retrieval, and an extractive
# question-answering model used to pull patient facts out of uploaded reports.
print("π€ Loading AI models...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline(
    task="question-answering",
    model="distilbert-base-cased-distilled-squad",
)

# Build FAISS index
print("π Building FAISS index...")
def embed_text(batch):
    """Embed one batch of guideline rows for the FAISS index.

    Each row is represented by its title followed by the first 200 characters
    of its cleaned text (presumably to keep encoder inputs short).

    Args:
        batch: A batched ``datasets`` mapping whose ``TITLE_COL`` and
            ``CONTENT_COL`` entries are parallel lists.

    Returns:
        ``{"embeddings": ...}`` with the encoder output for the batch, which
        ``Dataset.map`` stores as the ``embeddings`` column.
    """
    snippets = []
    for heading, body in zip(batch[TITLE_COL], batch[CONTENT_COL]):
        snippets.append(f"{heading} {body[:200]}")
    vectors = embedder.encode(snippets, show_progress_bar=False)
    return {"embeddings": vectors}
# Compute an "embeddings" column 32 rows at a time, then index it with FAISS.
dataset = dataset.map(embed_text, batch_size=32, batched=True)
dataset.add_faiss_index(column="embeddings")
# Processing Functions
def format_response(text):
    """Convert Markdown text to HTML for proper frontend display.

    The input is model-generated text whose content can be steered by the
    uploaded report, so raw HTML inside it is not passed through: the raw-HTML
    block preprocessor and the inline HTML pattern are removed, which makes any
    tags in ``text`` come out escaped (shown as text) instead of as markup.

    NOTE(review): Python-Markdown still does not sanitize link targets such as
    ``[x](javascript:...)``; if the frontend inserts this HTML via innerHTML,
    consider a dedicated HTML sanitizer as well.

    Args:
        text: Markdown source.

    Returns:
        The rendered HTML fragment.
    """
    renderer = markdown.Markdown()
    # Registry item names from Python-Markdown 3.x; strict=False makes each
    # call a no-op if a given version does not register that item.
    renderer.preprocessors.deregister("html_block", strict=False)
    renderer.inlinePatterns.deregister("html", strict=False)
    return renderer.convert(text)
def summarize_report(report):
    """Generate a clinical summary using QA and Gemini model.

    Four fixed questions (age, gender, symptoms, history) are answered by the
    extractive QA pipeline against ``report``; any answer whose score is not
    above 0.1 becomes "Not specified". Gemini then writes the summary, which
    is printed for debugging and rendered to HTML.

    Args:
        report: Raw text of the uploaded patient report.

    Returns:
        HTML produced by ``format_response`` from Gemini's summary.
    """

    def _extract(question):
        # Fall back to a placeholder when the QA model is not confident.
        hit = qa_pipeline(question=question, context=report)
        if hit["score"] > 0.1:
            return hit["answer"]
        return "Not specified"

    age, gender, symptoms, history = [
        _extract(question)
        for question in (
            "Patient's age?",
            "Patient's gender?",
            "Current symptoms?",
            "Medical history?",
        )
    ]
    gemini = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Create clinical summary from:
- Age: {age}
- Gender: {gender}
- Symptoms: {symptoms}
- History: {history}
Format: "[Age] [Gender] with [History], presenting with [Symptoms]"
Add relevant medical context."""
    reply = gemini.generate_content(prompt)
    summary = reply.text.strip()
    print(f"Generated Summary: {summary}")  # Debugging log
    return format_response(summary)
def rag_retrieval(query, k=3):
    """Retrieve relevant guidelines using FAISS.

    Args:
        query: Free text to embed and search for (in this file, the uploaded
            report).
        k: Maximum number of guideline rows to return.

    Returns:
        A list of dicts in the order returned by
        ``Dataset.get_nearest_examples``. Each dict has:

        - ``title``
        - ``content``, truncated to 1000 characters
        - ``source``, or "N/A" when the dataset has no ``source`` column
        - ``score`` as a float

        NOTE(review): with the default index created by ``add_faiss_index``,
        the score is presumably an L2 distance (lower means closer). Confirm
        this before treating it as a similarity.
    """
    query_embedding = embedder.encode([query])
    scores, examples = dataset.get_nearest_examples("embeddings", query_embedding, k=k)
    titles = examples[TITLE_COL]
    # Resolve the fallback once rather than rebuilding it for every row.
    sources = examples.get("source", ["N/A"] * len(titles))
    return [
        {
            "title": title,
            "content": content[:1000],
            "source": source,
            "score": float(score),
        }
        for title, content, source, score in zip(
            titles, examples[CONTENT_COL], sources, scores
        )
    ]
def generate_recommendations(report):
    """Generate treatment recommendations with RAG context.

    The guidelines retrieved for ``report`` are numbered in the prompt so that
    the "[Guideline #]" evidence markers the prompt asks for refer to a
    specific entry.

    Args:
        report: Raw text of the uploaded patient report.

    Returns:
        A tuple ``(html, references)``:

        - ``html`` is Gemini's answer rendered by ``format_response``.
        - ``references`` lists each distinct source other than "N/A" among
          the retrieved guidelines, in retrieval order.
    """
    guidelines = rag_retrieval(report)
    # Number the entries so "[Guideline #]" markers in the answer resolve.
    context = "Relevant Clinical Guidelines:\n" + "\n".join(
        f"[Guideline {i}] {g['title']}: {g['content']} [Source: {g['source']}]"
        for i, g in enumerate(guidelines, start=1)
    )
    model = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Generate treatment recommendations using these guidelines:
{context}
Patient Presentation:
{report}
Format with:
- Bold section headers
- Clear bullet points
- Evidence markers [Guideline #]
- Risk-benefit analysis
- Include references to the sources provided where applicable
"""
    recommendations = model.generate_content(prompt).text.strip()
    # Retrieved guidelines can share a source; list each source only once.
    references = list(
        dict.fromkeys(g["source"] for g in guidelines if g["source"] != "N/A")
    )
    return format_response(recommendations), references
def generate_risk_assessment(summary):
    """Generate risk assessment using the summary.

    Args:
        summary: Clinical summary embedded verbatim in the prompt (the caller
            in this file passes the HTML returned by ``summarize_report``).

    Returns:
        Gemini's reply, whitespace-stripped and rendered by ``format_response``.
    """
    prompt = f"""Analyze clinical risk:
{summary}
Output format:
Risk Score: 0-100
Alert Level: π΄ High/π‘ Medium/π’ Low
Key Risk Factors: bullet points
Recommended Actions: bullet points"""
    gemini = genai.GenerativeModel("gemini-1.5-flash")
    reply_text = gemini.generate_content(prompt).text
    return format_response(reply_text.strip())
# Flask Endpoints
# NOTE(review): no @app.route decorator is visible on the endpoint functions
# here. Unless they are registered elsewhere (or the decorators were lost when
# this source was captured), Flask never dispatches requests to them.
def handle_upload():
    """Handle text file upload and return processed data.

    Expects a multipart form field ``file`` holding a UTF-8 ``.txt`` report
    (extension matched case-insensitively). The report is summarized, given
    RAG-backed recommendations, and risk-assessed.

    Returns:
        A JSON response with ``session_id``, ``timestamp``, ``summary``,
        ``recommendations``, ``risk_assessment`` and ``references``.

        On failure, a JSON ``{"error": ...}`` response with status:

        - 400 if the file is missing, not ``.txt``, not UTF-8, or empty.
        - 500 if processing fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]
    # An upload without a filename is falsy; compare the extension
    # case-insensitively so e.g. "REPORT.TXT" is accepted too.
    if not file or not file.filename.lower().endswith(".txt"):
        return jsonify({"error": "Invalid file, must be a .txt file"}), 400
    try:
        try:
            content = file.read().decode("utf-8")
        except UnicodeDecodeError:
            # Undecodable bytes are a client error, not a server failure.
            return jsonify({"error": "File must be UTF-8 encoded text"}), 400
        if not content.strip():
            return jsonify({"error": "File is empty"}), 400
        summary = summarize_report(content)
        recommendations, references = generate_recommendations(content)
        risk_assessment = generate_risk_assessment(summary)
        response = {
            "session_id": str(uuid.uuid4()),
            "timestamp": datetime.now().isoformat(),
            "summary": summary,
            "recommendations": recommendations,
            "risk_assessment": risk_assessment,
            "references": references,
        }
        print(
            f"Response Sent to Frontend: {json.dumps(response, indent=2)}"
        )  # Debugging log
        return jsonify(response)
    except Exception as e:
        # Keep the traceback in the server log instead of discarding it.
        app.logger.exception("Processing failed")
        return jsonify({"error": f"Processing failed: {str(e)}"}), 500
def serve_index():
    """Send the frontend entry page, index.html, from the app's static folder."""
    frontend_dir = app.static_folder
    return send_from_directory(frontend_dir, "index.html")
def serve_static(path):
    """Send a file from the frontend (static) directory.

    Args:
        path: File path relative to ``app.static_folder``.
    """
    frontend_dir = app.static_folder
    return send_from_directory(frontend_dir, path)
if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary code. Never enable it by default on a server bound to
    # all interfaces; opt in with FLASK_DEBUG=1 (or true/yes).
    debug_mode = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    # PORT overrides the default port 5000 (useful on hosted platforms).
    port = int(os.environ.get("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=debug_mode)