File size: 2,233 Bytes
1060621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import numpy as np
import cv2
from tqdm import tqdm
import argparse
from mmseg.apis import init_model, inference_model


def process_single_img(img_path, model, outpath, palette_dict):
    """Run segmentation on one image and save a color-coded prediction mask.

    Args:
        img_path: Path of the input image; it is read with ``cv2.imread``.
        model: Segmentor passed to ``mmseg.apis.inference_model``.
        outpath: Existing directory to write the mask into. The output file
            keeps the input's basename.
        palette_dict: Mapping from predicted class id to a 3-element color.
            The mask is written with ``cv2.imwrite``, so colors are presumably
            in BGR order (``main`` uses [0, 0, 255] for 'red', consistent
            with BGR).

    Returns:
        True if a mask was written; False if ``cv2.imread`` could not decode
        the file and it was skipped.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the mask could not be saved.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals unreadable / non-image files by returning None
        # instead of raising; skip them rather than failing inside inference.
        return False

    result = inference_model(model, img_bgr)
    # NOTE(review): assumes pred_sem_seg.data is (1, H, W) of class ids, so
    # [0] yields an (H, W) id map.
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Map each predicted class id to its palette color. Ids absent from the
    # palette stay black (zeros).
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for class_id, color in palette_dict.items():
        pred_mask_bgr[pred_mask == class_id] = color

    # NOTE: the mask reuses the input's filename and extension; a lossy format
    # such as .jpg will perturb the exact palette colors in the saved mask.
    save_path = os.path.join(outpath, os.path.basename(img_path))
    if not cv2.imwrite(save_path, pred_mask_bgr):
        raise OSError(f"Failed to write mask to {save_path!r}")
    return True



def main(args):
    """Segment every image file in ``args.data_folder`` and save color masks.

    Args:
        args: Parsed CLI namespace with ``config_file``, ``checkpoint_file``,
            ``device``, ``data_folder`` and ``outpath`` attributes.

    Raises:
        ValueError: If no output directory was supplied.
    """
    if not args.outpath:
        # Validate before loading the model: os.path.exists(None) would
        # otherwise raise an unhelpful TypeError after an expensive init.
        raise ValueError("An output directory (-o/--outpath) is required.")

    # Initialize model
    model = init_model(args.config_file, args.checkpoint_file, device=args.device)

    # Class palette: index in this list == predicted class id.
    palette = [
        ['background', [0, 0, 0]],
        ['red', [0, 0, 255]]
    ]
    palette_dict = {idx: color for idx, (_name, color) in enumerate(palette)}

    # makedirs creates missing parent directories and tolerates an existing one.
    os.makedirs(args.outpath, exist_ok=True)

    # Only regular files can be images; sort for a deterministic order.
    img_names = sorted(
        name for name in os.listdir(args.data_folder)
        if os.path.isfile(os.path.join(args.data_folder, name))
    )
    for img_name in tqdm(img_names):
        img_path = os.path.join(args.data_folder, img_name)
        process_single_img(img_path, model, args.outpath, palette_dict)


if __name__ == '__main__':
    # Command-line entry point. All paths are required: main() cannot run
    # without an output directory, so -o is required like the others.
    parser = argparse.ArgumentParser(description="Process images for semantic segmentation inference.")
    parser.add_argument('-d', '--data_folder', type=str, required=True, help="Path to the folder containing input images.")
    parser.add_argument('-m', '--config_file', type=str, required=True, help="Path to the model config file.")
    parser.add_argument('-pth', '--checkpoint_file', type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument('-o', '--outpath', type=str, required=True, help="Path to save the output images.")
    parser.add_argument('--device', type=str, default='cuda:0', help="Device to run the model (e.g., 'cuda:0', 'cpu').")

    args = parser.parse_args()
    main(args)