# radiologist / app.py
# ftx7go's picture
# Create app.py
# c08364f verified
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import HTMLResponse, Response
from transformers import pipeline
from PIL import Image, ImageDraw
import numpy as np
import io
import uvicorn
import base64
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Image as ReportLabImage, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.colors import red, blue, black
from reportlab.lib.units import inch
app = FastAPI()


# Model loading
def load_models():
    """Create the Hugging Face pipelines used by this app, keyed by nickname.

    "KnochenAuge" is an object-detection pipeline (its results carry boxes);
    "KnochenWächter" and "RöntgenMeister" are image-classification pipelines.
    Pipelines are constructed in the order listed below.
    """
    specs = (
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        (
            "RöntgenMeister",
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    )
    return {nickname: pipeline(task, model=repo_id) for nickname, task, repo_id in specs}


# Built once at import time; all requests share these pipeline objects.
models = load_models()
# Model label -> German display text. Keys MUST be lower-case because
# translate_label lower-cases the incoming label before the lookup.
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Knochenbruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
    "f1": "Knochenbruch",
    "nf": "Kein Knochenbruch",
}


def translate_label(label):
    """Translate a model output label to German for display.

    Matching is case-insensitive ("F1", "f1" and "F1".lower() all hit the
    same entry). Labels with no translation are returned unchanged.

    Args:
        label: Label string produced by one of the pipelines.

    Returns:
        The German translation, or ``label`` itself if it is unknown.
    """
    return _LABEL_TRANSLATIONS.get(label.lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer, sized like *image*, highlighting *box*.

    The rectangle colour depends on *score*: red above 0.8, orange above 0.6,
    yellow otherwise. It is filled at alpha 100 and outlined (2 px) at full
    opacity.

    Args:
        image: Source image; only its ``size`` is used.
        box: Mapping with 'xmin', 'ymin', 'xmax', 'ymax' pixel coordinates.
        score: Detection score used to pick the colour.
    """
    if score > 0.8:
        rgb = (255, 0, 0)
    elif score > 0.6:
        rgb = (255, 165, 0)
    else:
        rgb = (255, 255, 0)

    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    rect = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
    painter.rectangle(rect, fill=rgb + (100,))
    painter.rectangle(rect, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Return an annotated RGBA copy of *image* with one highlight per prediction.

    For each prediction a score-coloured rectangle (from
    create_heatmap_overlay) and a text label are alpha-composited onto the
    result. The label shows the translated class name, the score as a
    percentage and a derived pseudo-temperature.

    Args:
        image: PIL image to annotate; it is not modified.
        predictions: Iterable of dicts with 'box' (xmin/ymin/xmax/ymax),
            'score' and 'label' keys.

    Returns:
        A new PIL image in RGBA mode.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        overlay = create_heatmap_overlay(image, box, score)

        # NOTE(review): this is not a measured temperature, just a linear
        # function of the score (36.5 at 0.0 .. 39.0 at 1.0). Consider
        # dropping it from a medical report.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"

        # Put the label just above the box, but never above the image's top
        # edge, where it would be clipped away.
        text_pos = (box['xmin'], max(0, box['ymin'] - 20))

        # Draw the label onto the transparent overlay, not the result image.
        # ImageDraw on an RGBA image replaces pixels instead of blending, so
        # the translucent background would otherwise punch a semi-transparent
        # hole into the output. Compositing the overlay blends it correctly.
        draw = ImageDraw.Draw(overlay)
        text_bbox = draw.textbbox(text_pos, label)
        draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
        draw.text(text_pos, label, fill=(255, 255, 255, 255))

        result_image = Image.alpha_composite(result_image, overlay)
    return result_image
def image_to_base64(image):
    """Encode *image* as PNG and return it as a ``data:image/png;base64,`` URI.

    Args:
        image: Any object with a PIL-style ``save(fp, format=...)`` method.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
def generate_report(patient_name, analyzed_image_bytes, prediction, confidence):
    """Render the PDF radiology report and return it as bytes.

    Args:
        patient_name: Name shown in the greeting. It is XML-escaped before
            being placed in ReportLab paragraph markup, so characters such
            as ``<`` or ``&`` print literally instead of breaking (or
            injecting) markup.
        analyzed_image_bytes: Encoded image bytes to embed (analyze_file
            passes a PNG).
        prediction: Overall result text, e.g. "Fracture Detected".
        confidence: Highest detection score. The report prints "Yes" when it
            exceeds 0.6 and "No" otherwise.

    Returns:
        The PDF document as ``bytes``.
    """
    # Local imports: escape() is stdlib; HRFlowable is part of reportlab,
    # which this module already depends on.
    from xml.sax.saxutils import escape
    from reportlab.platypus import HRFlowable

    buffer = io.BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=letter)
    styles = getSampleStyleSheet()
    title_style = ParagraphStyle(
        name='TitleStyle',
        parent=styles['Normal'],
        fontSize=16,
        textColor=blue,
        alignment=1,  # centre
    )
    heading_style = ParagraphStyle(
        name='HeadingStyle',
        parent=styles['Normal'],
        fontSize=12,
        textColor=red,
    )
    prediction_style = ParagraphStyle(
        name='PredictionStyle',
        parent=styles['Normal'],
        fontSize=14,
        alignment=1,  # centre
    )
    gap = 0.2 * inch

    greeting = (
        f"hello , {escape(patient_name)} thank you for using our services "
        "this is your radiology report"
    )
    story = [
        # Hospital name
        Paragraph("youesh hospital , mumbai ( west )", title_style),
        Spacer(1, gap),
        # Patient greeting
        Paragraph(greeting, heading_style),
        Spacer(1, gap),
        # Horizontal rule. ReportLab's paragraph markup has no <hr/> tag, so
        # use the dedicated line flowable instead.
        HRFlowable(width="100%", thickness=1, color=black),
        Spacer(1, gap),
        # Analysed image, scaled to fit within 400x400 pt while keeping its
        # aspect ratio. A fixed 400x400 would stretch non-square X-rays.
        ReportLabImage(
            io.BytesIO(analyzed_image_bytes), width=400, height=400, kind='proportional'
        ),
        Spacer(1, gap),
        # Prediction summary
        Paragraph(f"<b>Prediction:</b> {escape(prediction.capitalize())}", prediction_style),
        Paragraph(
            f"<b>Confidence:</b> {'Yes' if confidence > 0.6 else 'No'}", prediction_style
        ),
    ]
    doc.build(story)
    return buffer.getvalue()
# Shared CSS inserted into the <style> element of both the upload form page
# and the /analyze error page. It is a plain (non-f) string, so the literal
# braces of the CSS rules are harmless here. Any template that embeds it must
# itself escape or avoid literal braces if it is an f-string.
# NOTE(review): the error page uses a "back-button" class that has no rule
# here; it only gets the generic ".button" styling.
COMMON_STYLES = """
body {
font-family: system-ui, -apple-system, sans-serif;
background: #f0f2f5;
margin: 0;
padding: 20px;
color: #1a1a1a;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background-color: rgba(156, 163, 175, 0.5);
border-radius: 4px;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.button {
background: #404040; /* Changed button background color */
color: white;
border: none;
padding: 12px 30px;
border-radius: 8px;
cursor: pointer;
font-size: 1.1em;
transition: all 0.3s ease;
position: relative;
}
.button:hover {
background: #555;
}
@keyframes progress {
0% { width: 0; }
100% { width: 100%; }
}
.button-progress {
position: absolute;
bottom: 0;
left: 0;
height: 4px;
background: rgba(255, 255, 255, 0.5);
width: 0;
}
.button:active .button-progress {
animation: progress 2s linear forwards;
}
img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
@keyframes blink {
0% { opacity: 1; }
50% { opacity: 0; }
100% { opacity: 1; }
}
#loading {
display: none;
color: white;
margin-top: 10px;
animation: blink 1s infinite;
text-align: center;
}
"""
@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the upload form (patient name + image file) that posts to /analyze.

    The page is built by plain string concatenation, not an f-string. Its CSS
    rule bodies contain literal braces, which an f-string parses as
    replacement fields. For example, ``.input-group { margin-bottom: 20px; }``
    was evaluated as the expression ``margin - bottom`` and raised NameError,
    so the page could never be served.
    """
    head = """
<!DOCTYPE html>
<html>
<head>
<title>Fraktur Detektion</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
"""
    page_styles_and_body = """
.input-group {
    margin-bottom: 20px;
}
.input-group label {
    display: block;
    margin-bottom: 5px;
    color: #404040;
    font-weight: bold;
}
.input-group input[type="text"] {
    width: calc(100% - 22px);
    padding: 10px;
    border: 1px solid #ccc;
    border-radius: 4px;
    font-size: 1em;
}
.upload-section {
    background: #2d2d2d;
    padding: 40px;
    border-radius: 12px;
    margin: 20px 0;
    text-align: center;
    border: 2px dashed #404040;
    transition: all 0.3s ease;
    color: white;
}
.upload-section:hover {
    border-color: #555;
}
input[type="file"] {
    font-size: 1.1em;
    margin: 20px 0;
    color: white;
}
input[type="file"]::file-selector-button {
    font-size: 1em;
    padding: 10px 20px;
    border-radius: 8px;
    border: 1px solid #404040;
    background: #2d2d2d;
    color: white;
    transition: all 0.3s ease;
    cursor: pointer;
}
input[type="file"]::file-selector-button:hover {
    background: #404040;
}
</style>
</head>
<body>
<div class="container">
    <form action="/analyze" method="post" enctype="multipart/form-data" onsubmit="document.getElementById('loading').style.display = 'block';">
        <div class="input-group">
            <label for="name">Name:</label>
            <input type="text" id="name" name="name" required>
        </div>
        <div class="upload-section">
            <div>
                <input type="file" name="file" accept="image/*" required>
            </div>
            <button type="submit" class="button">
                Generate Report
                <div class="button-progress"></div>
            </button>
            <div id="loading">Loading...</div>
        </div>
    </form>
</div>
</body>
</html>
"""
    return head + COMMON_STYLES + page_styles_and_body
def _render_error_page(message):
    """Return the (German-language) HTML error page shown when /analyze fails.

    Built by concatenation rather than an f-string because the CSS contains
    literal braces. *message* is HTML-escaped: it comes from an exception
    raised while processing the uploaded file and must not inject markup.
    """
    import html

    return (
        """
<!DOCTYPE html>
<html>
<head>
<title>Fehler</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
"""
        + COMMON_STYLES
        + """
.error-box {
    background: #fee2e2;
    border: 1px solid #ef4444;
    padding: 20px;
    border-radius: 8px;
    margin: 20px 0;
}
</style>
</head>
<body>
<div class="container">
    <div class="error-box">
        <h3>Fehler</h3>
        <p>"""
        + html.escape(message)
        + """</p>
    </div>
    <a href="/" class="button back-button">
        ← Zurück
        <div class="button-progress"></div>
    </a>
</div>
</body>
</html>
"""
    )


@app.post("/analyze", response_class=Response)
async def analyze_file(name: str = Form(...), file: UploadFile = File(...), threshold: float = Form(0.6)):
    """Analyse an uploaded image and return the PDF report as a download.

    Args:
        name: Patient name printed in the report greeting.
        file: Uploaded image file.
        threshold: Minimum "KnochenAuge" detection score for a box to count
            as a fracture finding (default 0.6).

    Returns:
        The PDF report (``application/pdf``, ``report.pdf`` attachment) on
        success, or an HTML error page if any step raises.
    """
    import logging

    try:
        contents = await file.read()
        # Normalise to RGB up front. The report image is re-encoded as PNG,
        # which cannot store some upload modes (e.g. CMYK JPEGs).
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # Only the object detector drives the report. The two classifiers
        # ("KnochenWächter", "RöntgenMeister") were previously run here as
        # well, but their outputs were never used, so they only added
        # inference cost.
        # NOTE(review): inference is synchronous inside an async endpoint and
        # blocks the event loop while it runs.
        predictions_locator = models["KnochenAuge"](image)
        filtered_preds = [p for p in predictions_locator if p['score'] >= threshold]

        analyzed_image = image
        overall_prediction = "No Fracture"
        max_confidence = 0.0
        if filtered_preds:
            analyzed_image = draw_boxes(image, filtered_preds)
            overall_prediction = "Fracture Detected"
            max_confidence = max(p['score'] for p in filtered_preds)

        image_stream = io.BytesIO()
        analyzed_image.save(image_stream, format="PNG")
        pdf_report = generate_report(name, image_stream.getvalue(), overall_prediction, max_confidence)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # show the user a readable error page instead of a bare 500.
        logging.getLogger(__name__).exception("Fracture analysis failed")
        return HTMLResponse(content=_render_error_page(str(e)))

    headers = {
        'Content-Disposition': 'attachment; filename="report.pdf"'
    }
    return Response(content=pdf_report, headers=headers, media_type="application/pdf")
if __name__ == "__main__":
    # Direct execution: serve the app on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")