File size: 2,036 Bytes
fd38558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import Compose, ToTensor, Normalize
from transformers import DetrForObjectDetection, DetrImageProcessor
import gradio as gr

# Load the pre-trained DETR model and processor
model_name = "facebook/detr-resnet-50"
# The detector and its pre/post-processor are built once, at import time,
# from the same checkpoint. Every call to detect_fractures reuses them.
# NOTE(review): this checkpoint is a general-purpose object detector (see its
# model card), not a fracture-specific model. Confirm this is the intended one.
model, processor = (
    DetrForObjectDetection.from_pretrained(model_name),
    DetrImageProcessor.from_pretrained(model_name),
)

# Define fracture detection function
def detect_fractures(image, threshold=0.5):
    """Run DETR on *image* and return an annotated copy.

    NOTE(review): every detection above *threshold* is labelled "Fracture"
    whatever class the model predicted. The loaded checkpoint appears to be
    a generic object detector, so confirm a fracture-trained model is used
    before trusting these labels.

    Args:
        image: PIL image to analyse. It is not modified.
        threshold: Detections whose best (non "no-object") class probability
            is strictly greater than this value are drawn. Defaults to 0.5.

    Returns:
        A new RGB PIL image with a red box and a confidence label drawn for
        each kept detection.
    """
    # Draw on an RGB copy so the caller's image is never mutated. The
    # conversion also gives the processor a 3-channel input even for
    # grayscale or RGBA uploads.
    annotated = image.convert("RGB")

    inputs = processor(images=annotated, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # DETR's raw pred_boxes are normalised (center_x, center_y, width, height).
    # The processor's post-processing does four things:
    #   - applies softmax and drops the trailing "no-object" class;
    #   - thresholds the scores;
    #   - converts the boxes to (x_min, y_min, x_max, y_max);
    #   - scales them to the given (height, width) in pixels.
    width, height = annotated.size
    detections = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    draw = ImageDraw.Draw(annotated)
    boxes = detections["boxes"].tolist()
    scores = detections["scores"].tolist()
    for (x_min, y_min, x_max, y_max), score in zip(boxes, scores):
        draw.rectangle(((x_min, y_min), (x_max, y_max)), outline="red", width=3)
        draw.text((x_min, y_min), f"Fracture: {score:.2f}", fill="red")

    return annotated

# Gradio callback: run the detector on the uploaded image
def infer(image):
    """Gradio callback: return *image* annotated with the detector's boxes."""
    annotated = detect_fractures(image)
    return annotated

# Web UI: one PIL image in, one annotated PIL image out.
iface = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Fracture Detection",
    description="Upload an X-ray or medical image to detect fractures using DETR.",
    # examples=["example1.jpg", "example2.jpg"],
)

# Start the server only when this file is executed as a script, so merely
# importing the module (e.g. from tests or another app) does not launch it.
if __name__ == "__main__":
    iface.launch()