annie08 commited on
Commit
fd38558
·
1 Parent(s): 365b371

add app.py requirements

Browse files
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image, ImageDraw
3
+ from torchvision.transforms import Compose, ToTensor, Normalize
4
+ from transformers import DetrForObjectDetection, DetrImageProcessor
5
+ import gradio as gr
6
+
7
+ # Load the pre-trained DETR model and processor
8
+ model_name = "facebook/detr-resnet-50"
9
+ model = DetrForObjectDetection.from_pretrained(model_name)
10
+ processor = DetrImageProcessor.from_pretrained(model_name)
11
+
12
+ # Define fracture detection function
13
+ def detect_fractures(image):
14
+ """Detect fractures in the given image using DETR."""
15
+ # Convert the input image to a format suitable for the model
16
+ inputs = processor(images=image, return_tensors="pt")
17
+
18
+ # Perform object detection
19
+ outputs = model(**inputs)
20
+
21
+ # Extract predictions
22
+ logits = outputs.logits
23
+ bboxes = outputs.pred_boxes
24
+ scores = logits.softmax(-1)[..., :-1].max(-1)
25
+
26
+ # Filter predictions
27
+ threshold = 0.5 # confidence threshold
28
+ keep = scores.values > threshold
29
+
30
+ filtered_boxes = bboxes[keep].detach().cpu()
31
+ filtered_scores = scores.values[keep].detach().cpu().tolist()
32
+
33
+ # Convert normalized bounding boxes to absolute coordinates
34
+ width, height = image.size
35
+ filtered_boxes = filtered_boxes * torch.tensor([width, height, width, height])
36
+
37
+ # Draw bounding boxes on the image
38
+ draw = ImageDraw.Draw(image)
39
+ for box, score in zip(filtered_boxes, filtered_scores):
40
+ x_min, y_min, x_max, y_max = box.tolist()
41
+ draw.rectangle(((x_min, y_min), (x_max, y_max)), outline="red", width=3)
42
+ draw.text((x_min, y_min), f"Fracture: {score:.2f}", fill="red")
43
+
44
+ return image
45
+
46
+ # Define Gradio interface
47
+ def infer(image):
48
+ """Run fracture detection and return the result image."""
49
+ return detect_fractures(image)
50
+
51
+ iface = gr.Interface(
52
+ fn=infer,
53
+ inputs=gr.Image(type="pil"),
54
+ outputs=gr.Image(type="pil"),
55
+ title="Fracture Detection",
56
+ description="Upload an X-ray or medical image to detect fractures using DETR.",
57
+ # examples=["example1.jpg", "example2.jpg"],
58
+ )
59
+
60
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ Pillow
3
+ torchvision
4
+ transformers
5
+ gradio