"""FastAPI demo that runs three Hugging Face bone-fracture models on sample
X-ray images and renders their predictions as HTML pages."""

import base64
import html
import io
import logging
import random
from pathlib import Path

import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import HTMLResponse
from PIL import Image, ImageDraw
from transformers import pipeline

logger = logging.getLogger(__name__)

app = FastAPI()

# Only files under this directory may be opened by the /analyze endpoint.
SAMPLE_DIR = Path("sample_images")


def load_models():
    """Build the inference pipelines used by /analyze, keyed by display name.

    ``BoneEye`` is an object-detection pipeline (its predictions carry a
    ``box``); ``BoneGuard`` and ``XRayMaster`` are image-classification
    pipelines (label/score pairs).
    """
    return {
        "BoneEye": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
        "BoneGuard": pipeline(
            "image-classification", model="Heem2/bone-fracture-detection-using-xray"
        ),
        "XRayMaster": pipeline(
            "image-classification",
            model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    }


# Loaded once at import time and shared by every request.
models = load_models()


def translate_label(label):
    """Map a raw model label to a display label (case-insensitive).

    Unknown labels are returned unchanged.
    """
    translations = {
        "fracture": "Fracture",
        "no fracture": "No Fracture",
        "normal": "Normal",
        "abnormal": "Abnormal",
        # Keys must be lowercase: the lookup below lowercases the label.
        "f1": "Fracture",
        "nf": "No Fracture",
    }
    return translations.get(label.lower(), label)


def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer the size of *image* with *box* highlighted.

    The colour encodes confidence: red above 0.8, orange above 0.6,
    yellow otherwise. *box* needs ``xmin``/``ymin``/``xmax``/``ymax`` keys.
    """
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    rect = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]

    if score > 0.8:
        rgb = (255, 0, 0)
    elif score > 0.6:
        rgb = (255, 165, 0)
    else:
        rgb = (255, 255, 0)

    draw.rectangle(rect, fill=rgb + (100,))  # translucent fill
    draw.rectangle(rect, outline=rgb + (255,), width=2)  # opaque border
    return overlay


def draw_boxes(image, predictions):
    """Return an RGBA copy of *image* with every detection highlighted and captioned.

    Each prediction is a dict with ``box``, ``score`` and ``label`` keys, as
    produced by the object-detection pipeline.
    """
    result_image = image.convert("RGBA")  # convert() always returns a new image
    for pred in predictions:
        box = pred["box"]
        score = pred["score"]
        overlay = create_heatmap_overlay(image, box, score)
        result_image = Image.alpha_composite(result_image, overlay)

        caption = f"{translate_label(pred['label'])} ({score:.1%})"
        # Caption sits above the box, clamped so boxes touching the top edge
        # still get a visible caption.
        text_pos = (box["xmin"], max(box["ymin"] - 20, 0))
        draw = ImageDraw.Draw(result_image)
        draw.rectangle(draw.textbbox(text_pos, caption), fill=(0, 0, 0, 180))
        draw.text(text_pos, caption, fill=(255, 255, 255, 255))
    return result_image


def image_to_base64(image):
    """Encode *image* as PNG and return it as a ``data:`` URI for an <img> tag."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


COMMON_STYLES = """
body { font-family: system-ui, -apple-system, sans-serif; background: #f0f2f5; margin: 0; padding: 20px; color: #1a1a1a; }
::-webkit-scrollbar { width: 8px; height: 8px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background-color: rgba(156, 163, 175, 0.5); border-radius: 4px; }
.container { max-width: 1200px; margin: 0 auto; background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
.button { background: #2d2d2d; color: white; border: none; padding: 12px 30px; border-radius: 8px; cursor: pointer; font-size: 1.1em; transition: all 0.3s ease; position: relative; }
.button:hover { background: #404040; }
@keyframes progress { 0% { width: 0; } 100% { width: 100%; } }
.button-progress { position: absolute; bottom: 0; left: 0; height: 4px; background: rgba(255, 255, 255, 0.5); width: 0; }
.button:active .button-progress { animation: progress 2s linear forwards; }
img { max-width: 100%; height: auto; border-radius: 8px; }
@keyframes blink { 0% { opacity: 1; } 50% { opacity: 0; } 100% { opacity: 1; } }
#loading { display: none; color: #4b5563; margin-top: 10px; animation: blink 1s infinite; text-align: center; }
"""

# Extra styles for the results page (classes emitted by _render_classification).
_RESULT_STYLES = """
.model-section { margin-bottom: 20px; }
.score-high { color: #c62828; font-weight: bold; }
.score-medium { color: #ef6c00; }
"""

SAMPLE_IMAGES = [
    {"id": "sample1", "filename": "sample1.png", "label": "Fracture"},
    {"id": "sample2", "filename": "sample2.png", "label": "No Fracture"},
    {"id": "sample3", "filename": "sample3.png", "label": "Fracture"},
    {"id": "sample4", "filename": "sample4.png", "label": "No Fracture"},
    {"id": "sample5", "filename": "sample5.png", "label": "Fracture"},
    {"id": "sample6", "filename": "sample6.png", "label": "No Fracture"},
    {"id": "sample7", "filename": "sample7.png", "label": "Fracture"},
    {"id": "sample8", "filename": "sample8.png", "label": "No Fracture"},
    {"id": "sample9", "filename": "sample9.png", "label": "Fracture"},
    {"id": "sample10", "filename": "sample10.png", "label": "No Fracture"},
]


def _render_page(title, body, extra_styles=""):
    """Wrap *body* (already-safe HTML) in a full document using COMMON_STYLES."""
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>{html.escape(title)}</title>
<style>{COMMON_STYLES}{extra_styles}</style>
</head>
<body>
<div class="container">
{body}
</div>
</body>
</html>"""


def _error_page(message, status_code):
    """Build an HTML error response; *message* is escaped before rendering."""
    body = (
        "<h1>Error</h1>"
        f"<p>{html.escape(message)}</p>"
        '<a href="/">← Back</a>'
    )
    return HTMLResponse(_render_page("Error", body), status_code=status_code)


def _render_classification(model_name, predictions):
    """Render one classifier's score/label list as an HTML section."""
    rows = []
    for pred in predictions:
        css_class = "score-high" if pred["score"] > 0.7 else "score-medium"
        label = html.escape(translate_label(pred["label"]))
        rows.append(f'<div class="{css_class}">{pred["score"]:.1%} - {label}</div>')
    return (
        f'<div class="model-section"><h3>{html.escape(model_name)}</h3>'
        + "".join(rows)
        + "</div>"
    )


def _resolve_sample_path(sample_image):
    """Return the resolved path of *sample_image* under SAMPLE_DIR.

    Returns None when the name escapes SAMPLE_DIR (``..``, absolute paths,
    symlinks pointing elsewhere), so user input cannot open arbitrary files.
    """
    base = SAMPLE_DIR.resolve()
    candidate = (base / sample_image).resolve()
    if base not in candidate.parents:
        return None
    return candidate


@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the landing page: a sample-image picker that posts to /analyze."""
    image_options = "".join(
        f'<option value="{html.escape(img["filename"])}">'
        f'{html.escape(img["id"])} ({html.escape(img["label"])})</option>'
        for img in SAMPLE_IMAGES
    )
    body = f"""
<h1>Fracture Detection</h1>
<form method="post" onsubmit="return startAnalysis(this);">
  <label for="sample_image">Sample image</label>
  <select id="sample_image" name="sample_image">{image_options}</select>
  <label for="threshold">Detection threshold</label>
  <input id="threshold" name="threshold" type="number" min="0" max="1" step="0.05" value="0.60">
  <button class="button" type="submit">Analyze<span class="button-progress"></span></button>
</form>
<div id="loading">Loading...</div>
<script>
function startAnalysis(form) {{
  // /analyze reads its arguments from the query string; for POST forms the
  // browser keeps the query string of the action URL.
  form.action = "/analyze?" + new URLSearchParams(new FormData(form)).toString();
  document.getElementById("loading").style.display = "block";
  return true;
}}
</script>
"""
    return _render_page("Fracture Detection", body)


@app.post("/analyze", response_class=HTMLResponse)
def analyze_file(sample_image: str, threshold: float = 0.60):
    """Run all three models on a sample image and render the results page.

    Declared as a plain ``def`` so FastAPI runs the blocking model inference
    in its threadpool instead of stalling the event loop.

    Args:
        sample_image: file name relative to the ``sample_images`` directory.
        threshold: minimum BoneEye detection score for a box to be drawn.

    Returns a 404 page when the image is missing or outside ``sample_images``,
    and a 500 page when decoding or inference fails.
    """
    not_found_message = (
        f"Sample image '{sample_image}' not found. "
        "Please ensure the image exists in the 'sample_images' directory."
    )
    image_path = _resolve_sample_path(sample_image)
    if image_path is None:
        return _error_page(not_found_message, status_code=404)

    try:
        try:
            # Decode eagerly so the file handle is closed; RGB also normalises
            # grayscale / 16-bit X-rays before drawing and inference.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except FileNotFoundError:
            return _error_page(not_found_message, status_code=404)

        predictions_watcher = models["BoneGuard"](image)
        predictions_master = models["XRayMaster"](image)
        predictions_locator = models["BoneEye"](image)

        filtered_preds = [p for p in predictions_locator if p["score"] >= threshold]
        result_image = draw_boxes(image, filtered_preds) if filtered_preds else image
        result_image_b64 = image_to_base64(result_image)

        body = (
            "<h1>Results</h1>"
            + _render_classification("BoneGuard", predictions_watcher)
            + _render_classification("XRayMaster", predictions_master)
            + '<div class="model-section"><h3>Fracture Localization</h3>'
            + f'<img src="{result_image_b64}" alt="Analyzed image"></div>'
            + '<a href="/">← Back</a>'
        )
        return _render_page("Results", body, _RESULT_STYLES)
    except Exception as e:  # top-level request boundary: log, then report
        logger.exception("Analysis of %r failed", sample_image)
        return _error_page(str(e), status_code=500)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)