# Web file-viewer header captured together with the source (not part of the program):
# annie08's picture
# add app.py requirements
# fd38558
# raw
# history blame
# 2.04 kB
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import Compose, ToTensor, Normalize
from transformers import DetrForObjectDetection, DetrImageProcessor
import gradio as gr
# Load the pre-trained DETR model and processor
# Checkpoint identifier shared by the detection model and its image processor.
model_name = "facebook/detr-resnet-50"

# Instantiate the detector first, then the matching processor, both from the
# same checkpoint so that preprocessing agrees with what the model expects.
model, processor = (
    loader.from_pretrained(model_name)
    for loader in (DetrForObjectDetection, DetrImageProcessor)
)
# Define fracture detection function
def detect_fractures(image, threshold=0.5):
    """Run the module-level DETR model on ``image`` and draw the kept boxes.

    Args:
        image: A PIL image, or ``None`` (e.g. nothing was uploaded).
        threshold: Minimum class probability a query must reach for its box
            to be drawn. Defaults to ``0.5``.

    Returns:
        A new RGB PIL image with one red rectangle and a
        ``"Fracture: <score>"`` caption per kept detection, or ``None`` when
        ``image`` is ``None``. The input image is not modified.
    """
    if image is None:
        return None

    # Work on an RGB copy so the caller's image is left untouched and the
    # "red" outline is drawable whatever the original mode was.
    canvas = image.convert("RGB")

    inputs = processor(images=canvas, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Highest class probability per query; the last logit column is dropped
    # before taking the maximum (DETR reserves it for "no object").
    scores = outputs.logits.softmax(-1)[..., :-1].max(-1).values
    keep = scores > threshold
    kept_scores = scores[keep].cpu().tolist()
    kept_boxes = outputs.pred_boxes[keep].cpu()

    # DETR emits boxes as normalized (center_x, center_y, width, height).
    # Convert them to corner form before scaling to pixels; treating them as
    # (x_min, y_min, x_max, y_max) directly would misplace every rectangle.
    center_x, center_y, box_w, box_h = kept_boxes.unbind(-1)
    corners = torch.stack(
        (
            center_x - box_w / 2,
            center_y - box_h / 2,
            center_x + box_w / 2,
            center_y + box_h / 2,
        ),
        dim=-1,
    )
    width, height = canvas.size
    corners = corners * torch.tensor(
        [width, height, width, height], dtype=corners.dtype
    )

    # NOTE(review): every kept detection is captioned "Fracture"; the
    # predicted class index is not consulted. Confirm the loaded checkpoint
    # is actually trained for fracture detection.
    draw = ImageDraw.Draw(canvas)
    for corner_box, score in zip(corners.tolist(), kept_scores):
        x_min, y_min, x_max, y_max = corner_box
        draw.rectangle(((x_min, y_min), (x_max, y_max)), outline="red", width=3)
        draw.text((x_min, y_min), f"Fracture: {score:.2f}", fill="red")
    return canvas
# Define Gradio interface
def infer(image):
    """Callback for the UI: hand the image to ``detect_fractures`` and return its result."""
    annotated = detect_fractures(image)
    return annotated
# Input and output components are both configured with type="pil".
_pil_input = gr.Image(type="pil")
_pil_output = gr.Image(type="pil")

# Single-function web UI: one image in, one annotated image out.
iface = gr.Interface(
    title="Fracture Detection",
    description="Upload an X-ray or medical image to detect fractures using DETR.",
    fn=infer,
    inputs=_pil_input,
    outputs=_pil_output,
    # examples=["example1.jpg", "example2.jpg"],
)

# Start serving the interface as soon as the module is executed.
iface.launch()