bfd_report_gen / app.py
ftx7go's picture
Update app.py
3bba4e3 verified
import base64
import html
import io
import random

import numpy as np
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import HTMLResponse
from PIL import Image, ImageDraw
from transformers import pipeline
app = FastAPI()

# (task, checkpoint) for each Hugging Face pipeline, keyed by the name the rest
# of this module uses to look the pipeline up in ``models``.
_MODEL_SPECS = {
    "BoneEye": ("object-detection", "D3STRON/bone-fracture-detr"),
    "BoneGuard": ("image-classification", "Heem2/bone-fracture-detection-using-xray"),
    "XRayMaster": (
        "image-classification",
        "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
    ),
}


def load_models():
    """Build every pipeline listed in ``_MODEL_SPECS`` and return them by name."""
    loaded = {}
    for name, (task, checkpoint) in _MODEL_SPECS.items():
        loaded[name] = pipeline(task, model=checkpoint)
    return loaded


# Built once at import time; request handlers reuse these pipelines.
models = load_models()
# Raw model label (lower-cased) -> display text. Keys MUST be lower-case because
# translate_label() lower-cases its input before looking it up.
_LABEL_TRANSLATIONS = {
    "fracture": "Fracture",
    "no fracture": "No Fracture",
    "normal": "Normal",
    "abnormal": "Abnormal",
    "f1": "Fracture",
    "nf": "No Fracture",
}


def translate_label(label):
    """Return the display text for a raw model label.

    Matching is case-insensitive; a label with no known mapping is returned
    unchanged.

    >>> translate_label("NF")
    'No Fracture'
    """
    return _LABEL_TRANSLATIONS.get(label.lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer, the size of *image*, highlighting *box*.

    The hue encodes *score*: red when score > 0.8, orange when score > 0.6,
    yellow otherwise. The box interior is tinted with alpha 100 and then
    outlined with a fully opaque 2-pixel border of the same hue.
    """
    if score > 0.8:
        rgb = (255, 0, 0)
    elif score > 0.6:
        rgb = (255, 165, 0)
    else:
        rgb = (255, 255, 0)

    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    pen.rectangle(corners, fill=rgb + (100,))
    pen.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Return a copy of *image* with every detection highlighted and captioned.

    Args:
        image: PIL image to annotate; it is not modified.
        predictions: object-detection results, each a dict with ``'box'``
            (``xmin``/``ymin``/``xmax``/``ymax``), ``'score'`` and ``'label'``.

    Returns:
        A new RGBA image with one highlight and one caption per prediction.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        result_image = Image.alpha_composite(
            result_image, create_heatmap_overlay(image, box, score)
        )

        # Synthetic display value computed only from the score; it is not a
        # measured temperature.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"

        # Place the caption 20 px above the box, but never above the top edge,
        # otherwise boxes touching the top would get an invisible caption.
        text_pos = (box['xmin'], max(box['ymin'] - 20, 0))

        # Drawing directly on an RGBA image replaces pixels rather than blending,
        # so the 180-alpha background would punch a hole in the X-ray. Draw the
        # caption on its own transparent layer and composite it instead.
        caption_layer = Image.new('RGBA', result_image.size, (0, 0, 0, 0))
        caption_draw = ImageDraw.Draw(caption_layer)
        caption_draw.rectangle(
            caption_draw.textbbox(text_pos, label), fill=(0, 0, 0, 180)
        )
        caption_draw.text(text_pos, label, fill=(255, 255, 255, 255))
        result_image = Image.alpha_composite(result_image, caption_layer)
    return result_image
def image_to_base64(image):
    """Encode *image* as PNG and return it as a ``data:image/png;base64,`` URI."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode()
# Base stylesheet interpolated at the top of every page this app renders
# (index, results and error pages); each page appends its own rules after it.
COMMON_STYLES = """
body {
font-family: system-ui, -apple-system, sans-serif;
background: #f0f2f5;
margin: 0;
padding: 20px;
color: #1a1a1a;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background-color: rgba(156, 163, 175, 0.5);
border-radius: 4px;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.button {
background: #2d2d2d;
color: white;
border: none;
padding: 12px 30px;
border-radius: 8px;
cursor: pointer;
font-size: 1.1em;
transition: all 0.3s ease;
position: relative;
}
.button:hover {
background: #404040;
}
@keyframes progress {
0% { width: 0; }
100% { width: 100%; }
}
.button-progress {
position: absolute;
bottom: 0;
left: 0;
height: 4px;
background: rgba(255, 255, 255, 0.5);
width: 0;
}
.button:active .button-progress {
animation: progress 2s linear forwards;
}
img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
@keyframes blink {
0% { opacity: 1; }
50% { opacity: 0; }
100% { opacity: 1; }
}
#loading {
display: none;
color: white;
margin-top: 10px;
animation: blink 1s infinite;
text-align: center;
}
/* "Back" links on the results AND error pages use class="button back-button";
   an <a> needs block-like layout and no underline to render like .button. */
.back-button {
display: inline-block;
text-decoration: none;
margin-top: 20px;
}
"""
# Demo X-rays offered in the index page's <select>; the analyze route opens them
# from the sample_images/ directory. Odd-numbered samples carry the label
# "Fracture", even-numbered ones "No Fracture".
SAMPLE_IMAGES = [
    {
        "id": f"sample{n}",
        "filename": f"sample{n}.png",
        "label": "Fracture" if n % 2 == 1 else "No Fracture",
    }
    for n in range(1, 11)
]
@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the index page.

    The page holds a <select> of SAMPLE_IMAGES (value = filename, text = id),
    a 0-1 confidence slider, and a submit button that POSTs both fields to
    ``/analyze`` as a multipart form.
    """
    option_tags = []
    for sample in SAMPLE_IMAGES:
        option_tags.append(
            '<option value="{}">{}</option>'.format(sample["filename"], sample["id"])
        )
    option_markup = "".join(option_tags)
    return f"""
<!DOCTYPE html>
<html>
<head>
<title>Fracture Detection</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.upload-section {{
background: #2d2d2d;
padding: 40px;
border-radius: 12px;
margin: 20px 0;
text-align: center;
border: 2px dashed #404040;
transition: all 0.3s ease;
color: white;
}}
.upload-section:hover {{
border-color: #555;
}}
.image-selection {{
font-size: 1.1em;
margin: 20px 0;
color: white;
}}
select {{
padding: 10px;
border-radius: 8px;
border: 1px solid #404040;
background: #2d2d2d;
color: white;
transition: all 0.3s ease;
cursor: pointer;
font-size: 1em;
}}
select:hover {{
background: #404040;
}}
.confidence-slider {{
width: 100%;
max-width: 300px;
margin: 20px auto;
}}
input[type="range"] {{
width: 100%;
height: 8px;
border-radius: 4px;
background: #404040;
outline: none;
transition: all 0.3s ease;
-webkit-appearance: none;
}}
input[type="range"]::-webkit-slider-thumb {{
-webkit-appearance: none;
width: 20px;
height: 20px;
border-radius: 50%;
background: white;
cursor: pointer;
border: none;
}}
</style>
</head>
<body>
<div class="container">
<div class="upload-section">
<form action="/analyze" method="post" enctype="multipart/form-data" onsubmit="document.getElementById('loading').style.display = 'block';">
<div class="image-selection">
<label for="sample_image">Select a Sample X-ray:</label>
<select name="sample_image" id="sample_image">
{option_markup}
</select>
</div>
<div class="confidence-slider">
<label for="threshold">Confidence Threshold: <span id="thresholdValue">0.60</span></label>
<input type="range" id="threshold" name="threshold"
min="0" max="1" step="0.05" value="0.60"
oninput="document.getElementById('thresholdValue').textContent = parseFloat(this.value).toFixed(2)">
</div>
<button type="submit" class="button">
Analyze
<div class="button-progress"></div>
</button>
<div id="loading">Loading...</div>
</form>
</div>
</div>
</body>
</html>
"""
def _error_page(message):
    """Return the full HTML error page showing *message*.

    *message* is inserted verbatim, so callers must HTML-escape any text that
    originates from the request or from an exception.
    """
    return f"""
<!DOCTYPE html>
<html>
<head>
<title>Error</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.error-box {{
background: #fee2e2;
border: 1px solid #ef4444;
padding: 20px;
border-radius: 8px;
margin: 20px 0;
}}
</style>
</head>
<body>
<div class="container">
<div class="error-box">
<h3>Error</h3>
<p>{message}</p>
</div>
<a href="/" class="button back-button">
← Back
<div class="button-progress"></div>
</a>
</div>
</body>
</html>
"""


def _sample_not_found(sample_image):
    """Return a 404 error response for an unknown or missing sample image."""
    return HTMLResponse(
        _error_page(
            f"Sample image '{html.escape(sample_image)}' not found. "
            "Please ensure the image exists in the 'sample_images' directory."
        ),
        status_code=404,
    )


def _prediction_rows(predictions):
    """Render classification results as one ``<div>`` row per prediction.

    Scores above 0.7 get the ``score-high`` style, all others ``score-medium``.
    """
    rows = []
    for pred in predictions:
        confidence_class = "score-high" if pred['score'] > 0.7 else "score-medium"
        rows.append(f"""
<div>
<span class="{confidence_class}">{pred['score']:.1%}</span> -
{html.escape(translate_label(pred['label']))}
</div>
""")
    return "".join(rows)


def _results_page(predictions_watcher, predictions_master, result_image_b64):
    """Return the results page: two classifier panels plus the annotated image."""
    return f"""
<!DOCTYPE html>
<html>
<head>
<title>Results</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.results-grid {{
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-top: 20px;
}}
.result-box {{
background: white;
padding: 20px;
border-radius: 12px;
margin: 10px 0;
border: 1px solid #e9ecef;
}}
.score-high {{
color: #0066cc;
font-weight: bold;
}}
.score-medium {{
color: #ffa500;
font-weight: bold;
}}
.back-button {{
display: inline-block;
text-decoration: none;
margin-top: 20px;
}}
h3 {{
color: #0066cc;
margin-top: 0;
}}
@media (max-width: 768px) {{
.results-grid {{
grid-template-columns: 1fr;
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="results-grid">
<div>
<div class="result-box"><h3>BoneGuard</h3>
{_prediction_rows(predictions_watcher)}
</div>
<div class='result-box'><h3>XRayMaster</h3>
{_prediction_rows(predictions_master)}
</div>
</div>
<div class='result-box'>
<h3>Fracture Localization</h3>
<img src="{result_image_b64}" alt="Analyzed image">
</div>
</div>
<a href="/" class="button back-button">
← Back
<div class="button-progress"></div>
</a>
</div>
</body>
</html>
"""


@app.post("/analyze", response_class=HTMLResponse)
async def analyze_file(sample_image: str = Form(...), threshold: float = Form(0.60)):
    """Run all three models on a bundled sample X-ray and render the results.

    Both parameters are declared with ``Form`` because the index page submits
    them as fields of a multipart form body; as plain parameters FastAPI would
    look for them in the query string and reject every form submission.

    Args:
        sample_image: filename of one of the ``SAMPLE_IMAGES`` entries.
        threshold: minimum BoneEye detection score to draw; clamped to [0, 1].

    Returns:
        The results page as HTML; an error page with status 404 for an unknown
        or missing sample, or status 500 for any other failure.
    """
    try:
        # Only bundled samples may be opened: the value comes from the client
        # and is joined into a filesystem path (e.g. "../../secret.png").
        allowed_files = {img["filename"] for img in SAMPLE_IMAGES}
        if sample_image not in allowed_files:
            return _sample_not_found(sample_image)

        image_path = f"sample_images/{sample_image}"
        try:
            # convert() loads the pixels, so the file can be closed right away.
            with Image.open(image_path) as source:
                image = source.convert("RGB")
        except FileNotFoundError:
            return _sample_not_found(sample_image)

        threshold = min(max(threshold, 0.0), 1.0)

        # NOTE(review): these model calls are synchronous, so they block the
        # event loop for the duration of inference.
        predictions_watcher = models["BoneGuard"](image)
        predictions_master = models["XRayMaster"](image)
        # The object-detection pipeline drops detections below its own
        # threshold (0.5 by default); pass the user's value so slider settings
        # under 0.5 actually take effect.
        predictions_locator = models["BoneEye"](image, threshold=threshold)
        filtered_preds = [p for p in predictions_locator if p['score'] >= threshold]
        result_image = draw_boxes(image, filtered_preds) if filtered_preds else image

        # TODO: blacking out the fractured area is planned once images and
        # fracture data are available; for now detections are only marked.
        return _results_page(
            predictions_watcher, predictions_master, image_to_base64(result_image)
        )
    except Exception as e:
        # Exception text may echo request data, so escape it before rendering.
        return HTMLResponse(_error_page(html.escape(str(e))), status_code=500)
if __name__ == "__main__":
    # Serve the app directly when run as a script: all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")