|
import torch |
|
from PIL import Image, ImageDraw |
|
from torchvision.transforms import Compose, ToTensor, Normalize |
|
from transformers import DetrForObjectDetection, DetrImageProcessor |
|
import gradio as gr |
|
|
|
|
|
# Hugging Face Hub checkpoint used for both the detector and its image
# preprocessor.
# NOTE(review): "facebook/detr-resnet-50" appears to be a general-purpose
# detection checkpoint rather than a fracture-specific one — confirm that a
# fracture-trained checkpoint is not intended here.
model_name = "facebook/detr-resnet-50"

# Load the detector and its matching image processor once, at import time, so
# every request reuses the same objects. The tuple is evaluated left to right,
# so the model is fetched before the processor.
model, processor = (
    DetrForObjectDetection.from_pretrained(model_name),
    DetrImageProcessor.from_pretrained(model_name),
)
|
|
|
|
|
def detect_fractures(image, threshold=0.5):
    """Detect objects in *image* with DETR and draw them as "Fracture" boxes.

    Args:
        image: A ``PIL.Image``. It is not modified; an annotated copy is
            returned instead.
        threshold: Minimum class probability (excluding DETR's trailing
            "no object" class) a prediction must strictly exceed to be drawn.
            Defaults to 0.5.

    Returns:
        A new RGB ``PIL.Image`` with a red rectangle and a score label drawn
        for every kept prediction.
    """
    # convert() always returns a new image, so drawing below never mutates the
    # caller's object. It also normalises grayscale/RGBA inputs (common for
    # X-rays) to 3-channel RGB before preprocessing and drawing.
    annotated = image.convert("RGB")
    width, height = annotated.size

    inputs = processor(images=annotated, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # DETR's raw pred_boxes are normalised (center_x, center_y, w, h), not
    # corners. The processor's post-processing applies the same softmax /
    # drop-"no object" / score > threshold filtering and converts boxes to
    # absolute (x_min, y_min, x_max, y_max) pixels. Note that target_sizes is
    # (height, width), the reverse of PIL's .size.
    detections = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    draw = ImageDraw.Draw(annotated)
    # NOTE(review): every kept detection is labelled "Fracture" regardless of
    # the predicted class id — confirm the checkpoint only predicts fractures.
    for score, box in zip(
        detections["scores"].tolist(), detections["boxes"].tolist()
    ):
        x_min, y_min, x_max, y_max = box
        draw.rectangle(((x_min, y_min), (x_max, y_max)), outline="red", width=3)
        draw.text((x_min, y_min), f"Fracture: {score:.2f}", fill="red")

    return annotated
|
|
|
|
|
def infer(image):
    """Gradio callback: return *image* annotated with fracture detections.

    Args:
        image: The ``PIL.Image`` supplied by the Gradio input component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The annotated image produced by ``detect_fractures``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    return detect_fractures(image)
|
|
|
# Web UI: one image in, one annotated image out. The interface object is built
# at module level, but the server is started only when this file is executed
# directly, so importing the module does not block on launch().
iface = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Fracture Detection",
    description="Upload an X-ray or medical image to detect fractures using DETR.",
)

if __name__ == "__main__":
    iface.launch()
|
|