File size: 3,786 Bytes
cfa5142
 
 
 
 
 
 
 
 
d3e66e1
cfa5142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60434a4
cfa5142
 
 
 
 
 
 
60434a4
cfa5142
60434a4
cfa5142
 
 
 
 
 
 
 
60434a4
8d52a7d
cfa5142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d52a7d
cfa5142
 
 
60434a4
 
 
 
8d52a7d
cfa5142
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import os
from datetime import datetime
import numpy as np

from modules.model_downloader import (
    AVAILABLE_MODELS, DEFAULT_MODEL_TYPE, OUTPUT_DIR,
    is_sam_exist,
    download_sam_model_url
)
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
from modules.mask_utils import (
    save_psd_with_masks,
    create_mask_combined_images,
    create_mask_gallery
)

# (model type, config file name) pairs; the file names are resolved against
# SAM2_CONFIGS_DIR below to form the CONFIGS lookup table.
_CONFIG_FILENAMES = (
    ("sam2_hiera_tiny", "sam2_hiera_t.yaml"),
    ("sam2_hiera_small", "sam2_hiera_s.yaml"),
    ("sam2_hiera_base_plus", "sam2_hiera_b+.yaml"),
    ("sam2_hiera_large", "sam2_hiera_l.yaml"),
)

# Maps each SAM2 model type name to the path of its YAML config file.
CONFIGS = {
    model_type: os.path.join(SAM2_CONFIGS_DIR, config_filename)
    for model_type, config_filename in _CONFIG_FILENAMES
}


class SamInference:
    """Load a SAM2 model and run automatic mask generation with it.

    The instance keeps the currently loaded model together with the model
    type it was built for, and lazily (re)loads the model when a different
    type is requested through ``generate_mask_app``.
    """

    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUT_DIR
                 ):
        """Set up paths and device; no model is loaded at construction time.

        Args:
            model_dir: Directory the checkpoint file name is joined onto.
            output_dir: Directory under which results are written
                (PSD files go into its ``psd`` sub-directory).
        """
        self.model = None
        self.available_models = list(AVAILABLE_MODELS.keys())
        self.model_type = DEFAULT_MODEL_TYPE
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        self.image_predictor = None
        self.video_predictor = None

    def load_model(self):
        """Build the SAM2 model for ``self.model_type`` into ``self.model``.

        Downloads the checkpoint first when ``is_sam_exist`` reports it as
        missing. Build errors are printed rather than raised; in that case
        ``self.model`` is reset to ``None`` so that a previously loaded model
        of a different type is never mistaken for the requested one.
        """
        config = CONFIGS[self.model_type]
        filename, _ = AVAILABLE_MODELS[self.model_type]
        model_path = os.path.join(self.model_dir, filename)
        # Keep the public attribute in sync with the checkpoint actually used.
        self.model_path = model_path

        if not is_sam_exist(self.model_type):
            print(f"\nNo SAM2 model found, downloading {self.model_type} model...")
            download_sam_model_url(self.model_type)
        print("\nApplying configs to model..")

        try:
            self.model = build_sam2(
                config_file=config,
                ckpt_path=model_path,
                device=self.device
            )
        except Exception as e:
            # Drop any stale model: self.model_type already names the new
            # type, so keeping the old model would silently mismatch them.
            self.model = None
            print(f"Error while Loading SAM2 model! {e}")

    def _require_model(self):
        """Ensure ``self.model`` is loaded, raising if loading fails.

        Raises:
            RuntimeError: If no model is available after a load attempt.
        """
        if self.model is None:
            self.load_model()
        if self.model is None:
            raise RuntimeError(
                f"SAM2 model '{self.model_type}' could not be loaded; "
                "see the error printed above for details."
            )

    def generate_mask(self,
                      image: np.ndarray):
        """Run the current mask generator on ``image``.

        If no generator has been created yet, one is built with the
        library's default hyperparameters on the current model (loading the
        model first when needed).

        Args:
            image: Input image array passed straight to the generator.

        Returns:
            Whatever ``SAM2AutomaticMaskGenerator.generate`` returns.

        Raises:
            RuntimeError: If a model is needed but cannot be loaded.
        """
        if self.mask_generator is None:
            self._require_model()
            self.mask_generator = SAM2AutomaticMaskGenerator(model=self.model)
        return self.mask_generator.generate(image)

    def generate_mask_app(self,
                          image: np.ndarray,
                          model_type: str,
                          *params
                          ):
        """Generate masks for ``image`` and save them as a layered PSD.

        Args:
            image: Input image array.
            model_type: Key of ``AVAILABLE_MODELS`` / ``CONFIGS`` to use; the
                model is (re)loaded when it differs from the current one.
            *params: Ten positional hyperparameters, in this order:
                points_per_side, points_per_batch, pred_iou_thresh,
                stability_score_thresh, stability_score_offset,
                crop_n_layers, box_nms_thresh,
                crop_n_points_downscale_factor, min_mask_region_area,
                use_m2m.

        Returns:
            A tuple ``(images, output_path)`` where ``images`` is a list with
            the combined mask image first, followed by the gallery entries,
            and ``output_path`` is the path of the written PSD file.

        Raises:
            RuntimeError: If the requested model cannot be loaded.
        """
        maskgen_hparams = {
            'points_per_side': int(params[0]),
            'points_per_batch': int(params[1]),
            'pred_iou_thresh': float(params[2]),
            'stability_score_thresh': float(params[3]),
            'stability_score_offset': float(params[4]),
            'crop_n_layers': int(params[5]),
            'box_nms_thresh': float(params[6]),
            'crop_n_points_downscale_factor': int(params[7]),
            'min_mask_region_area': int(params[8]),
            'use_m2m': bool(params[9])
        }
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"result-{timestamp}.psd"
        output_path = os.path.join(self.output_dir, "psd", output_file_name)

        if self.model is None or self.model_type != model_type:
            self.model_type = model_type
            self.load_model()
        # Fail with a clear message instead of handing None to the generator.
        self._require_model()

        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            **maskgen_hparams
        )

        masks = self.mask_generator.generate(image)

        # Make sure the "psd" sub-directory exists before writing into it.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        save_psd_with_masks(image, masks, output_path)
        combined_image = create_mask_combined_images(image, masks)
        gallery = create_mask_gallery(image, masks)
        return [combined_image] + gallery, output_path