Younes Belkada commited on
Commit
b62b9fa
·
1 Parent(s): 0189f5d
Files changed (2) hide show
  1. coco_utils.py +18 -0
  2. cocoevaluate.py +4 -3
coco_utils.py CHANGED
@@ -27,6 +27,24 @@ def is_dist_avail_and_initialized():
27
  return False
28
  return True
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def get_world_size():
32
  if not is_dist_avail_and_initialized():
 
27
  return False
28
  return True
29
 
30
+ class CocoDetection(torchvision.datasets.CocoDetection):
31
+ def __init__(self, img_folder, feature_extractor, ann_file):
32
+ super(CocoDetection, self).__init__(img_folder, ann_file)
33
+ self.feature_extractor = feature_extractor
34
+
35
+ def __getitem__(self, idx):
36
+ # read in PIL image and target in COCO format
37
+ img, target = super(CocoDetection, self).__getitem__(idx)
38
+
39
+ # preprocess image and target (converting target to DETR format, resizing + normalization of both image and target)
40
+ image_id = self.ids[idx]
41
+ target = {'image_id': image_id, 'annotations': target}
42
+ encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
43
+ pixel_values = encoding["pixel_values"].squeeze() # remove batch dimension
44
+ target = encoding["labels"][0] # remove batch dimension
45
+
46
+ return pixel_values, target
47
+
48
 
49
  def get_world_size():
50
  if not is_dist_avail_and_initialized():
cocoevaluate.py CHANGED
@@ -17,7 +17,7 @@ import evaluate
17
  import datasets
18
  import pyarrow as pa
19
 
20
- from .coco_utils import CocoEvaluator, get_coco_api_from_dataset
21
 
22
  # TODO: Add BibTeX citation
23
  _CITATION = """\
@@ -72,9 +72,10 @@ def summarize_if_long_list(obj):
72
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
73
  class COCOEvaluate(evaluate.Metric):
74
  """TODO: Short description of my evaluation module."""
75
- def __init__(self, coco_dataset, iou_types=['bbox'], **kwargs):
76
  super().__init__(**kwargs)
77
- base_ds = get_coco_api_from_dataset(coco_dataset)
 
78
  self.coco_evaluator = CocoEvaluator(base_ds, iou_types)
79
 
80
 
 
17
  import datasets
18
  import pyarrow as pa
19
 
20
+ from .coco_utils import CocoEvaluator, get_coco_api_from_dataset, CocoDetection
21
 
22
  # TODO: Add BibTeX citation
23
  _CITATION = """\
 
72
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
73
  class COCOEvaluate(evaluate.Metric):
74
  """TODO: Short description of my evaluation module."""
75
+ def __init__(self, coco_path, feature_extractor, annotation_path, iou_types=['bbox'], **kwargs):
76
  super().__init__(**kwargs)
77
+ self.coco_dataset = CocoDetection(coco_path, feature_extractor, annotation_path)
78
+ base_ds = get_coco_api_from_dataset(self.coco_dataset)
79
  self.coco_evaluator = CocoEvaluator(base_ds, iou_types)
80
 
81