0llheaven commited on
Commit
277563c
·
verified ·
1 Parent(s): 5e73f80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw
7
  processor = AutoImageProcessor.from_pretrained("0llheaven/Conditional-detr-finetuned")
8
  model = AutoModelForObjectDetection.from_pretrained("0llheaven/Conditional-detr-finetuned")
9
 
10
- def detect_objects(image):
11
  # Convert image to RGB if it's grayscale
12
  if image.mode != "RGB":
13
  image = image.convert("RGB")
@@ -16,7 +16,7 @@ def detect_objects(image):
16
  inputs = processor(images=image, return_tensors="pt")
17
  outputs = model(**inputs)
18
 
19
- # Filter predictions with confidence greater than 0.5
20
  target_sizes = torch.tensor([image.size[::-1]])
21
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
22
 
@@ -28,17 +28,18 @@ def detect_objects(image):
28
  boxes = result["boxes"]
29
 
30
  for score, label, box in zip(scores, labels, boxes):
31
- box = [round(i, 2) for i in box.tolist()]
32
- label_name = "Pneumonia" if label.item() == 0 else "Other"
33
- draw.rectangle(box, outline="red", width=3)
34
- draw.text((box[0], box[1]), f"{label_name}: {round(score.item(), 3)}", fill="red")
 
35
 
36
  return image
37
 
38
  # Create the Gradio interface
39
  interface = gr.Interface(
40
  fn=detect_objects,
41
- inputs=gr.Image(type="pil"),
42
  outputs=gr.Image(type="pil"), # Corrected output type
43
  title="Object Detection with Transformers",
44
  description="Upload an image to detect objects using a fine-tuned Conditional-DETR model."
 
7
  processor = AutoImageProcessor.from_pretrained("0llheaven/Conditional-detr-finetuned")
8
  model = AutoModelForObjectDetection.from_pretrained("0llheaven/Conditional-detr-finetuned")
9
 
10
+ def detect_objects(image, score_threshold):
11
  # Convert image to RGB if it's grayscale
12
  if image.mode != "RGB":
13
  image = image.convert("RGB")
 
16
  inputs = processor(images=image, return_tensors="pt")
17
  outputs = model(**inputs)
18
 
19
+ # Filter predictions based on the user-defined score threshold
20
  target_sizes = torch.tensor([image.size[::-1]])
21
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
22
 
 
28
  boxes = result["boxes"]
29
 
30
  for score, label, box in zip(scores, labels, boxes):
31
+ if score >= score_threshold: # Only draw if score is above threshold
32
+ box = [round(i, 2) for i in box.tolist()]
33
+ label_name = "Pneumonia" if label.item() == 0 else "Other"
34
+ draw.rectangle(box, outline="red", width=3)
35
+ draw.text((box[0], box[1]), f"{label_name}: {round(score.item(), 3)}", fill="red")
36
 
37
  return image
38
 
39
  # Create the Gradio interface
40
  interface = gr.Interface(
41
  fn=detect_objects,
42
+ inputs=[gr.Image(type="pil"), gr.Slider(0, 1, value=0.5, label="Score Threshold")], # Add slider for score threshold
43
  outputs=gr.Image(type="pil"), # Corrected output type
44
  title="Object Detection with Transformers",
45
  description="Upload an image to detect objects using a fine-tuned Conditional-DETR model."