ftx7go's picture
Create app.py
bf254f3 verified
raw
history blame
11.8 kB
# NOTE(review): this file appears to contain two independent apps pasted
# together: this FastAPI fracture-detection service and, further below, a
# Gradio PDF-report generator. Consider splitting them into separate modules.
import base64
import html
import io

import numpy as np
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import HTMLResponse
from PIL import Image, ImageDraw, ImageOps
from transformers import pipeline

app = FastAPI()

# Detection threshold used when the client does not send one; matches the
# initial value of the slider on the upload page.
DEFAULT_THRESHOLD = 0.6


def load_models():
    """Build the Hugging Face pipelines used by the service.

    Returns:
        dict mapping a display name to a ``transformers`` pipeline:
        ``"KnochenAuge"`` (object detection), ``"KnochenWächter"`` and
        ``"RöntgenMeister"`` (image classification).
    """
    return {
        "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
        "KnochenWächter": pipeline(
            "image-classification", model="Heem2/bone-fracture-detection-using-xray"
        ),
        "RöntgenMeister": pipeline(
            "image-classification",
            model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    }


# Loaded once at import time so every request reuses the same pipelines.
models = load_models()


def translate_label(label):
    """Return the German display text for a model label.

    The lookup is case-insensitive; unknown labels are returned unchanged.
    """
    translations = {
        "fracture": "Knochenbruch",
        "no fracture": "Kein Knochenbruch",
        "normal": "Normal",
        "abnormal": "Auffällig",
        # Keys must be lower-case because the label is lower-cased before the
        # lookup (previously "F1"/"NF", which could never match).
        "f1": "Knochenbruch",
        "nf": "Kein Knochenbruch",
    }
    return translations.get(label.lower(), label)


def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer with a tinted, outlined rectangle over *box*.

    The colour encodes confidence: red above 0.8, orange above 0.6, yellow
    otherwise.

    Args:
        image: source image; only its size is used.
        box: dict with ``xmin``, ``ymin``, ``xmax``, ``ymax`` pixel coordinates.
        score: detection confidence score.
    """
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    if score > 0.8:
        rgb = (255, 0, 0)
    elif score > 0.6:
        rgb = (255, 165, 0)
    else:
        rgb = (255, 255, 0)
    rect = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]
    draw.rectangle(rect, fill=rgb + (100,))
    draw.rectangle(rect, outline=rgb + (255,), width=2)
    return overlay


def draw_boxes(image, predictions):
    """Render detection boxes and captions onto a copy of *image*.

    Args:
        image: the PIL image that was passed to the detector.
        predictions: detector output items, each with ``label``, ``score`` and
            a ``box`` dict.

    Returns:
        A new RGBA image; *image* itself is not modified.
    """
    result_image = image.convert("RGBA")
    for pred in predictions:
        box = pred["box"]
        score = pred["score"]
        result_image = Image.alpha_composite(
            result_image, create_heatmap_overlay(result_image, box, score)
        )

        # NOTE(review): this "temperature" is derived purely from the score and
        # is not a measurement; consider removing it from a medical UI.
        temp = 36.5 + (score * 2.5)
        # NOTE(review): "•" is outside Latin-1; with Pillow's legacy bitmap
        # default font (Pillow < 10.1) drawing it may raise UnicodeEncodeError.
        label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
        # Keep the caption on-canvas for boxes touching the top edge.
        text_pos = (box["xmin"], max(0, box["ymin"] - 20))

        # Draw the caption on its own layer and composite it. Drawing the
        # translucent background directly onto the RGBA result would replace
        # those pixels with alpha 180 instead of darkening them.
        caption = Image.new("RGBA", result_image.size, (0, 0, 0, 0))
        caption_draw = ImageDraw.Draw(caption)
        caption_draw.rectangle(caption_draw.textbbox(text_pos, label), fill=(0, 0, 0, 180))
        caption_draw.text(text_pos, label, fill=(255, 255, 255, 255))
        result_image = Image.alpha_composite(result_image, caption)
    return result_image


def image_to_base64(image):
    """Encode *image* as PNG and return it as a ``data:image/png;base64,...`` URL."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


COMMON_STYLES = """
    body {
        font-family: system-ui, -apple-system, sans-serif;
        background: #f0f2f5;
        margin: 0;
        padding: 20px;
        color: #1a1a1a;
    }
    ::-webkit-scrollbar { width: 8px; height: 8px; }
    ::-webkit-scrollbar-track { background: transparent; }
    ::-webkit-scrollbar-thumb {
        background-color: rgba(156, 163, 175, 0.5);
        border-radius: 4px;
    }
    .container {
        max-width: 1200px;
        margin: 0 auto;
        background: white;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .button {
        background: #2d2d2d;
        color: white;
        border: none;
        padding: 12px 30px;
        border-radius: 8px;
        cursor: pointer;
        font-size: 1.1em;
        transition: all 0.3s ease;
        position: relative;
    }
    .button:hover { background: #404040; }
    @keyframes progress {
        0% { width: 0; }
        100% { width: 100%; }
    }
    .button-progress {
        position: absolute;
        bottom: 0;
        left: 0;
        height: 4px;
        background: rgba(255, 255, 255, 0.5);
        width: 0;
    }
    .button:active .button-progress { animation: progress 2s linear forwards; }
    img { max-width: 100%; height: auto; border-radius: 8px; }
    @keyframes blink {
        0% { opacity: 1; }
        50% { opacity: 0; }
        100% { opacity: 1; }
    }
    #loading {
        display: none;
        color: white;
        margin-top: 10px;
        animation: blink 1s infinite;
        text-align: center;
    }
"""


@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the upload form (image file plus confidence-threshold slider)."""
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Fraktur Detektion</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            {COMMON_STYLES}
            .upload-section {{
                background: #2d2d2d;
                padding: 40px;
                border-radius: 12px;
                margin: 20px 0;
                text-align: center;
                border: 2px dashed #404040;
                transition: all 0.3s ease;
                color: white;
            }}
            .upload-section:hover {{ border-color: #555; }}
            input[type="file"] {{
                font-size: 1.1em;
                margin: 20px 0;
                color: white;
            }}
            input[type="file"]::file-selector-button {{
                font-size: 1em;
                padding: 10px 20px;
                border-radius: 8px;
                border: 1px solid #404040;
                background: #2d2d2d;
                color: white;
                transition: all 0.3s ease;
                cursor: pointer;
            }}
            input[type="file"]::file-selector-button:hover {{ background: #404040; }}
            .confidence-slider {{
                width: 100%;
                max-width: 300px;
                margin: 20px auto;
            }}
            input[type="range"] {{
                width: 100%;
                height: 8px;
                border-radius: 4px;
                background: #404040;
                outline: none;
                transition: all 0.3s ease;
                -webkit-appearance: none;
            }}
            input[type="range"]::-webkit-slider-thumb {{
                -webkit-appearance: none;
                width: 20px;
                height: 20px;
                border-radius: 50%;
                background: white;
                cursor: pointer;
                border: none;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="upload-section">
                <form action="/analyze" method="post" enctype="multipart/form-data"
                      onsubmit="document.getElementById('loading').style.display = 'block';">
                    <div>
                        <input type="file" name="file" accept="image/*" required>
                    </div>
                    <div class="confidence-slider">
                        <label for="threshold">Konfidenzschwelle: <span id="thresholdValue">0.60</span></label>
                        <input type="range" id="threshold" name="threshold" min="0" max="1" step="0.05" value="0.60"
                               oninput="document.getElementById('thresholdValue').textContent = parseFloat(this.value).toFixed(2)">
                    </div>
                    <button type="submit" class="button">
                        Analysieren
                        <div class="button-progress"></div>
                    </button>
                    <div id="loading">Loading...</div>
                </form>
            </div>
        </div>
    </body>
    </html>
    """


def _classification_rows(predictions):
    """Return one HTML row per classifier prediction: coloured score plus escaped label."""
    rows = []
    for pred in predictions:
        # Scores above 0.7 are highlighted as high confidence.
        confidence_class = "score-high" if pred["score"] > 0.7 else "score-medium"
        rows.append(
            f'<div><span class="{confidence_class}">{pred["score"]:.1%}</span>'
            f" - {html.escape(translate_label(pred['label']))}</div>"
        )
    return "\n".join(rows)


@app.post("/analyze", response_class=HTMLResponse)
async def analyze_file(
    file: UploadFile = File(...),
    threshold: float = Form(DEFAULT_THRESHOLD),
):
    """Run all three models on the uploaded image and render the results page.

    Args:
        file: the uploaded image.
        threshold: minimum detector score for a box to be drawn, posted by the
            slider on the upload page; clamped to [0, 1].

    Returns:
        The HTML results page, or an HTML error page if anything fails.
    """
    try:
        contents = await file.read()
        # Apply EXIF orientation and normalise to RGB so detector box
        # coordinates line up with the image we draw on.
        # NOTE(review): assumes transformers' load_image does the same
        # (exif_transpose + RGB) for PIL inputs — verify for the installed version.
        image = ImageOps.exif_transpose(Image.open(io.BytesIO(contents))).convert("RGB")
        threshold = min(max(threshold, 0.0), 1.0)

        predictions_watcher = models["KnochenWächter"](image)
        predictions_master = models["RöntgenMeister"](image)
        predictions_locator = models["KnochenAuge"](image)

        filtered_preds = [p for p in predictions_locator if p["score"] >= threshold]
        result_image = draw_boxes(image, filtered_preds) if filtered_preds else image
        result_image_b64 = image_to_base64(result_image)

        return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Ergebnisse</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            {COMMON_STYLES}
            .results-grid {{
                display: grid;
                grid-template-columns: 1fr 1fr;
                gap: 20px;
                margin-top: 20px;
            }}
            .result-box {{
                background: white;
                padding: 20px;
                border-radius: 12px;
                margin: 10px 0;
                border: 1px solid #e9ecef;
            }}
            .score-high {{ color: #0066cc; font-weight: bold; }}
            .score-medium {{ color: #ffa500; font-weight: bold; }}
            .back-button {{
                display: inline-block;
                text-decoration: none;
                margin-top: 20px;
            }}
            h3 {{ color: #0066cc; margin-top: 0; }}
            @media (max-width: 768px) {{
                .results-grid {{ grid-template-columns: 1fr; }}
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="results-grid">
                <div>
                    <div class="result-box">
                        <h3>KnochenWächter</h3>
                        {_classification_rows(predictions_watcher)}
                    </div>
                    <div class="result-box">
                        <h3>RöntgenMeister</h3>
                        {_classification_rows(predictions_master)}
                    </div>
                </div>
                <div class="result-box">
                    <h3>Fraktur Lokalisation</h3>
                    <img src="{result_image_b64}" alt="Analyzed image">
                </div>
            </div>
            <a href="/" class="button back-button">
                ← Zurück
                <div class="button-progress"></div>
            </a>
        </div>
    </body>
    </html>
    """
    except Exception as e:
        # Escape the message: it may echo attacker-controlled input into HTML.
        return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Fehler</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            {COMMON_STYLES}
            .error-box {{
                background: #fee2e2;
                border: 1px solid #ef4444;
                padding: 20px;
                border-radius: 8px;
                margin: 20px 0;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="error-box">
                <h3>Fehler</h3>
                <p>{html.escape(str(e))}</p>
            </div>
            <a href="/" class="button back-button">
                ← Zurück
                <div class="button-progress"></div>
            </a>
        </div>
    </body>
    </html>
    """


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)


# ---------------------------------------------------------------------------
# Gradio report generator (second app pasted into this file).
# ---------------------------------------------------------------------------
import gradio as gr
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing import image
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
import os
# Keras classifier, loaded a single time at import and shared by every call.
_MODEL_FILE = "my_keras_model.h5"
model = tf.keras.models.load_model(_MODEL_FILE)
# (width, height) passed to PIL's resize for every X-ray before prediction.
# NOTE(review): presumably the model's training input size — confirm.
image_size = (224, 224)
def analyze_injury(prediction):
    """Map a fracture score to a severity bucket with treatment and cost hints.

    Args:
        prediction: the averaged model score computed by the report generator,
            compared against the 0.3 and 0.7 cut-offs (presumably a probability
            in [0, 1] — confirm against the model).

    Returns:
        Tuple ``(severity, treatment, government_cost, private_cost)`` of
        display strings.
    """
    # NOTE(review): this uses the model's confidence as a proxy for clinical
    # severity, which are not the same thing; confirm with the product owner.
    if prediction < 0.3:
        return "Mild", "Rest and pain relief.", "₹2,000 - ₹5,000", "₹10,000 - ₹20,000"
    if prediction < 0.7:
        return "Moderate", "Plaster cast or minor surgery.", "₹8,000 - ₹15,000", "₹30,000 - ₹60,000"
    return "Severe", "Major surgery with metal implants.", "₹20,000 - ₹50,000", "₹1,00,000+"
def _predict_fracture_probability(xray_path):
    """Return the model's scalar output for the X-ray image stored at *xray_path*."""
    # The context manager closes the underlying file handle.
    with Image.open(xray_path) as img:
        # Convert before resizing so palette/bilevel images are resampled in RGB.
        rgb = img.convert("RGB").resize(image_size)
    batch = np.expand_dims(image.img_to_array(rgb), axis=0) / 255.0
    return float(model.predict(batch)[0][0])


def generate_report(patient_name, age, gender, xray1_path, xray2_path):
    """Run the fracture model on two X-rays and write a one-page PDF report.

    Args:
        patient_name: name printed on the report; a sanitised form is also used
            in the PDF file name.
        age: patient age as entered in the form.
        gender: selected gender option.
        xray1_path: file path of the first uploaded image (``None`` if missing).
        xray2_path: file path of the second uploaded image (``None`` if missing).

    Returns:
        Path of the generated PDF.

    Raises:
        gr.Error: if an image is missing or the report cannot be produced.
            Gradio shows the message to the user. Returning the message as a
            string would not work, because the output component is a
            ``gr.File`` and treats any string as a file path.
    """
    import re
    import tempfile

    for path in (xray1_path, xray2_path):
        # gr.Image passes None when nothing was uploaded, and
        # os.path.exists(None) raises TypeError, so test for falsy first.
        if not path or not os.path.exists(path):
            raise gr.Error("Error: One or both X-ray images are missing!")

    try:
        prediction1 = _predict_fracture_probability(xray1_path)
        prediction2 = _predict_fracture_probability(xray2_path)
        avg_prediction = (prediction1 + prediction2) / 2

        # NOTE(review): the diagnosis threshold (0.5) is independent of the
        # severity buckets in analyze_injury (0.3 / 0.7). A "Normal" diagnosis
        # can therefore still report "Moderate" severity — confirm intent.
        predicted_class = "Fractured" if avg_prediction > 0.5 else "Normal"
        severity, treatment, gov_cost, private_cost = analyze_injury(avg_prediction)

        # The name is free text: strip anything that could form a path, and
        # write into a fresh temp directory so concurrent requests cannot
        # overwrite each other's reports.
        safe_name = re.sub(r"[^\w-]+", "_", str(patient_name or "")).strip("_") or "patient"
        report_path = os.path.join(
            tempfile.mkdtemp(prefix="fracture_report_"),
            f"{safe_name}_fracture_report.pdf",
        )

        report_lines = [
            f"Patient Name: {patient_name}",
            f"Age: {age}",
            f"Gender: {gender}",
            f"Diagnosis: {predicted_class}",
            f"Injury Severity: {severity}",
            f"Recommended Treatment: {treatment}",
            f"Estimated Cost (Govt Hospital): {gov_cost}",
            f"Estimated Cost (Private Hospital): {private_cost}",
        ]
        pdf = canvas.Canvas(report_path, pagesize=letter)
        pdf.setFont("Helvetica", 12)
        for row, text in enumerate(report_lines):
            # The standard Helvetica font has no rupee glyph; ReportLab would
            # draw a placeholder box, so spell the currency out instead.
            # NOTE(review): other non-Latin-1 text (e.g. names) has the same issue.
            pdf.drawString(100, 750 - 20 * row, text.replace("₹", "INR "))
        pdf.save()
    except Exception as e:
        raise gr.Error(f"Error generating report: {e}") from e

    if not os.path.exists(report_path):
        raise gr.Error("Error: Report generation failed!")
    return report_path
# Gradio UI: collects patient details and two X-ray uploads, and returns the
# PDF produced by generate_report as a downloadable file.
interface = gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Textbox(label="Patient Name"),
        # precision=0 makes Gradio round and pass an int, so the report prints
        # "42" rather than "42.0".
        gr.Number(label="Age", precision=0),
        gr.Radio(["Male", "Female", "Other"], label="Gender"),
        gr.Image(type="filepath", label="Upload X-ray Image 1"),
        gr.Image(type="filepath", label="Upload X-ray Image 2"),
    ],
    outputs=gr.File(label="Download Report"),
    title="Bone Fracture Detection & Medical Report",
    description="Enter patient details, upload two X-ray images, and generate a detailed medical report with treatment suggestions and cost estimates.",
)
if __name__ == "__main__":
    # Bind on all interfaces at port 7860. This restores the launch call that
    # was corrupted when the two apps were pasted together.
    interface.launch(server_name="0.0.0.0", server_port=7860)