import torch
from PIL import Image, ImageDraw
from torchvision.transforms import Compose, ToTensor, Normalize
from transformers import DetrForObjectDetection, DetrImageProcessor
import gradio as gr

# Load the pre-trained DETR model and processor.
# NOTE(review): "facebook/detr-resnet-50" is a general-purpose COCO checkpoint,
# not a fracture detector; every detection it produces is labelled "Fracture"
# below. Substitute a checkpoint fine-tuned on fracture data for meaningful
# results.
model_name = "facebook/detr-resnet-50"
model = DetrForObjectDetection.from_pretrained(model_name)
processor = DetrImageProcessor.from_pretrained(model_name)


# Define fracture detection function
def detect_fractures(image, threshold=0.5):
    """Detect fractures in *image* using DETR and return an annotated copy.

    Args:
        image: Input PIL image, or ``None`` (what Gradio passes when the
            user submits without uploading anything).
        threshold: Minimum confidence (exclusive) for a detection to be
            drawn. Defaults to 0.5, the value previously hard-coded here.

    Returns:
        A new RGB PIL image with a red rectangle and score label drawn for
        every kept detection, or ``None`` if *image* is ``None``. The input
        image itself is never modified.
    """
    if image is None:
        return None

    # Work on an RGB copy: the caller's image stays untouched, the processor
    # always gets 3 channels, and red outlines render in colour even for
    # grayscale (mode "L") X-rays.
    annotated = image.convert("RGB")

    # Convert the image to a format suitable for the model
    inputs = processor(images=annotated, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # DETR's raw ``pred_boxes`` are normalized (center_x, center_y, w, h).
    # The post-processor drops the "no object" class, keeps detections whose
    # score is > threshold, and converts boxes to absolute
    # (x_min, y_min, x_max, y_max) pixels for the given (height, width).
    width, height = annotated.size
    result = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    # Draw bounding boxes on the copy
    draw = ImageDraw.Draw(annotated)
    for box, score in zip(result["boxes"].tolist(), result["scores"].tolist()):
        x_min, y_min, x_max, y_max = box
        draw.rectangle(((x_min, y_min), (x_max, y_max)), outline="red", width=3)
        draw.text((x_min, y_min), f"Fracture: {score:.2f}", fill="red")
    return annotated


# Define Gradio interface
def infer(image):
    """Gradio callback: run fracture detection and return the result image."""
    return detect_fractures(image)


iface = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Fracture Detection",
    description="Upload an X-ray or medical image to detect fractures using DETR.",
    # examples=["example1.jpg", "example2.jpg"],
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    iface.launch()