from fastapi import FastAPI, File, UploadFile from fastapi.responses import HTMLResponse from transformers import pipeline from PIL import Image, ImageDraw import numpy as np import io import uvicorn import base64 import random app = FastAPI() # Loading the models def load_models(): return { "BoneEye": pipeline("object-detection", model="D3STRON/bone-fracture-detr"), "BoneGuard": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"), "XRayMaster": pipeline("image-classification", model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388") } models = load_models() def translate_label(label): translations = { "fracture": "Fracture", "no fracture": "No Fracture", "normal": "Normal", "abnormal": "Abnormal", "F1": "Fracture", "NF": "No Fracture" } return translations.get(label.lower(), label) def create_heatmap_overlay(image, box, score): overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) x1, y1 = box['xmin'], box['ymin'] x2, y2 = box['xmax'], box['ymax'] if score > 0.8: fill_color = (255, 0, 0, 100) border_color = (255, 0, 0, 255) elif score > 0.6: fill_color = (255, 165, 0, 100) border_color = (255, 165, 0, 255) else: fill_color = (255, 255, 0, 100) border_color = (255, 255, 0, 255) draw.rectangle([x1, y1, x2, y2], fill=fill_color) draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2) return overlay def draw_boxes(image, predictions): result_image = image.copy().convert('RGBA') for pred in predictions: box = pred['box'] score = pred['score'] overlay = create_heatmap_overlay(image, box, score) result_image = Image.alpha_composite(result_image, overlay) draw = ImageDraw.Draw(result_image) temp = 36.5 + (score * 2.5) label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)" text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label) draw.rectangle(text_bbox, fill=(0, 0, 0, 180)) draw.text( (box['xmin'], box['ymin']-20), label, fill=(255, 255, 255, 255) ) return result_image def image_to_base64(image): buffered = io.BytesIO() image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() return f"data:image/png;base64,{img_str}" COMMON_STYLES = """ body { font-family: system-ui, -apple-system, sans-serif; background: #f0f2f5; margin: 0; padding: 20px; color: #1a1a1a; } ::-webkit-scrollbar { width: 8px; height: 8px; } ::-webkit-scrollbar-track { background: transparent; } ::-webkit-scrollbar-thumb { background-color: rgba(156, 163, 175, 0.5); border-radius: 4px; } .container { max-width: 1200px; margin: 0 auto; background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); } .button { background: #2d2d2d; color: white; border: none; padding: 12px 30px; border-radius: 8px; cursor: pointer; font-size: 1.1em; transition: all 0.3s ease; position: relative; } .button:hover { background: #404040; } @keyframes progress { 0% { width: 0; } 100% { width: 100%; } } .button-progress { position: absolute; bottom: 0; left: 0; height: 4px; background: rgba(255, 255, 255, 0.5); width: 0; } .button:active .button-progress { animation: progress 2s linear forwards; } img { max-width: 100%; height: auto; border-radius: 8px; } @keyframes blink { 0% { opacity: 1; } 50% { opacity: 0; } 100% { opacity: 1; } } #loading { display: none; color: white; margin-top: 10px; animation: blink 1s infinite; text-align: center; } """ SAMPLE_IMAGES = [ {"id": "sample1", "filename": "sample1.png", "label": "Fracture"}, {"id": "sample2", "filename": "sample2.png", "label": "No Fracture"}, {"id": "sample3", "filename": "sample3.png", "label": "Fracture"}, {"id": "sample4", "filename": "sample4.png", "label": "No Fracture"}, {"id": "sample5", "filename": "sample5.png", "label": "Fracture"}, {"id": "sample6", "filename": "sample6.png", "label": "No Fracture"}, {"id": "sample7", "filename": "sample7.png", "label": "Fracture"}, {"id": "sample8", "filename": "sample8.png", "label": "No Fracture"}, {"id": "sample9", "filename": "sample9.png", "label": "Fracture"}, {"id": "sample10", "filename": "sample10.png", "label": "No Fracture"}, ] @app.get("/", response_class=HTMLResponse) async def main(): image_options = "".join( f'' for img in SAMPLE_IMAGES ) content = f"""
Sample image '{sample_image}' not found. Please ensure the image exists in the 'sample_images' directory.
{str(e)}