|
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
from typing import Any, Dict, List
|
|
|
|
# Command-line interface for the script. Every "AMG Settings" flag defaults to
# None so that get_amg_kwargs() forwards only the settings the user actually
# passed, leaving everything else to SamAutomaticMaskGenerator's own defaults.
parser = argparse.ArgumentParser(
    description=(
        "Runs automatic mask generation on an input image or directory of images, "
        "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
        "as well as pycocotools if saving in RLE format."
    )
)

parser.add_argument(
    "--input",
    type=str,
    required=True,
    help="Path to either a single input image or folder of images.",
)

parser.add_argument(
    "--output",
    type=str,
    required=True,
    help=(
        "Path to the directory where masks will be output. For each input image, output "
        "will be either a folder of PNGs or a single json with COCO-style masks."
    ),
)

parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM checkpoint to use for mask generation.",
)

parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")

parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in one json file per image instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)

amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)

amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)

amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)

amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)

amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)

amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)

amg_settings.add_argument(
    "--crop-overlap-ratio",
    # This is a fractional ratio (e.g. 0.34), so it must be parsed as a float.
    # With type=int, argparse rejected every non-integer value.
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)

amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)

amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)
|
|
|
|
|
|
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Write each mask as a PNG and summarize all masks in ``metadata.csv``.

    Mask ``i`` is saved as ``{path}/{i}.png`` and becomes data row ``i`` of
    ``{path}/metadata.csv``. The CSV columns are listed in ``header`` below.

    Args:
        masks: Mask records as returned by the generator in ``"binary_mask"``
            output mode. Each record must provide the keys ``segmentation``,
            ``area``, ``bbox``, ``point_coords``, ``predicted_iou``,
            ``stability_score`` and ``crop_box``. ``segmentation`` is presumably
            a boolean array (0/255 PNG output assumes that) -- confirm upstream.
        path: Directory to write into; it must already exist.

    Raises:
        OSError: If a mask PNG cannot be written. ``cv2.imwrite`` signals
            failure by returning False rather than raising, so that result
            is checked explicitly.
    """
    header = (
        "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,"
        "predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"
    )
    metadata = [header]
    for i, mask_data in enumerate(masks):
        mask_path = os.path.join(path, f"{i}.png")
        # Cast explicitly to 8-bit. ``mask * 255`` yields a platform-int array
        # and would rely on OpenCV's implicit down-conversion when saving.
        mask_img = mask_data["segmentation"].astype("uint8") * 255
        if not cv2.imwrite(mask_path, mask_img):
            raise OSError(f"Failed to write mask image '{mask_path}'")

        row = [
            str(i),
            str(mask_data["area"]),
            *[str(x) for x in mask_data["bbox"]],
            # Only the first input point is recorded in the CSV.
            *[str(x) for x in mask_data["point_coords"][0]],
            str(mask_data["predicted_iou"]),
            str(mask_data["stability_score"]),
            *[str(x) for x in mask_data["crop_box"]],
        ]
        metadata.append(",".join(row))

    metadata_path = os.path.join(path, "metadata.csv")
    with open(metadata_path, "w", encoding="utf-8") as f:
        f.write("\n".join(metadata))
|
|
|
|
|
|
def get_amg_kwargs(args):
    """Build the keyword arguments for ``SamAutomaticMaskGenerator``.

    Reads each AMG setting attribute from ``args``. Only settings whose value
    is not None (i.e. flags the user actually passed) are included. Omitted
    settings therefore fall back to the generator's own parameter defaults.
    Falsy-but-set values such as ``0`` or ``0.0`` are kept.

    Args:
        args: Parsed command-line namespace exposing every AMG attribute.

    Returns:
        Dict mapping generator keyword names to the supplied values, in the
        fixed order of ``setting_names``.
    """
    setting_names = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    supplied = {}
    for name in setting_names:
        value = getattr(args, name)
        if value is None:
            continue
        supplied[name] = value
    return supplied
|
|
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """Run automatic mask generation over the inputs described by ``args``.

    Loads the SAM checkpoint onto ``args.device``. It then builds a
    ``SamAutomaticMaskGenerator`` from the AMG flags the user supplied and
    processes ``args.input``. That is either a single file, or every
    non-directory entry of a folder (not recursive), in sorted name order.
    For each readable image, results are written under ``args.output``:

    * ``<stem>/`` containing one PNG per mask plus ``metadata.csv`` (default), or
    * ``<stem>.json`` with the masks in ``"coco_rle"`` output mode when
      ``--convert-to-rle`` is set.

    Files that OpenCV cannot decode are reported and skipped.

    Raises:
        FileExistsError: In PNG mode, if ``<stem>/`` already exists under
            ``args.output`` (e.g. from a previous run, or two inputs sharing
            a stem such as ``a.jpg`` and ``a.png``).
    """
    print("Loading model...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

    if not os.path.isdir(args.input):
        targets = [args.input]
    else:
        # os.listdir returns entries in arbitrary order; sort so runs are
        # reproducible, including which same-stem input hits an existing folder.
        targets = [
            os.path.join(args.input, f)
            for f in sorted(os.listdir(args.input))
            if not os.path.isdir(os.path.join(args.input, f))
        ]

    os.makedirs(args.output, exist_ok=True)

    for t in targets:
        print(f"Processing '{t}'...")
        image = cv2.imread(t)
        if image is None:
            print(f"Could not load '{t}' as an image, skipping...")
            continue
        # cv2.imread yields BGR channel order; convert to RGB before generating.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        masks = generator.generate(image)

        base = os.path.basename(t)
        base = os.path.splitext(base)[0]
        save_base = os.path.join(args.output, base)
        if output_mode == "binary_mask":
            # exist_ok=False: fail instead of writing into an existing folder.
            os.makedirs(save_base, exist_ok=False)
            write_masks_to_folder(masks, save_base)
        else:
            save_file = save_base + ".json"
            with open(save_file, "w") as f:
                json.dump(masks, f)
    print("Done!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run mask generation.
    # The namespace is bound at module level as ``args`` before being passed on.
    args = parser.parse_args()
    main(args=args)
|
|
|