File size: 4,982 Bytes
40bbb34
0ab9a32
 
 
 
 
 
40bbb34
 
0ab9a32
 
40bbb34
0ab9a32
40bbb34
 
0ab9a32
40bbb34
0ab9a32
 
 
 
40bbb34
0ab9a32
 
 
 
 
 
 
 
 
 
40bbb34
 
 
0ab9a32
40bbb34
 
 
 
 
 
 
 
 
 
 
 
 
 
0ab9a32
40bbb34
0ab9a32
 
 
 
 
 
 
 
 
 
 
 
1899131
0ab9a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40bbb34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

from segmenter import build_segmenter
import os
import argparse
import pdb
import time
from PIL import Image



class CaptionAnything():
    """Segment an image region from a visual prompt and save the mask.

    Only the segmentation stage is active in this class: no captioner is
    built and ``text_refiner`` is always ``None``.
    """

    def __init__(self, args, api_key=""):
        """Store the options and build the segmenter.

        Args:
            args: Options object. ``args.segmenter`` and ``args.device`` are
                read here, and ``args`` itself is forwarded to
                ``build_segmenter``.
            api_key: Accepted so existing callers keep working; it is not
                used by this class.
        """
        self.args = args
        # NOTE: the captioner is intentionally not built here.
        self.segmenter = build_segmenter(args.segmenter, args.device, args)
        self.text_refiner = None

    def inference(self, image, prompt, controls, disable_gpt=False):
        """Run the segmenter on ``image`` and write the mask to disk.

        Args:
            image: Input forwarded unchanged to ``self.segmenter.inference``.
            prompt: Visual prompt forwarded unchanged to the segmenter.
            controls: Language controls; only printed, since captioning and
                text refinement are disabled.
            disable_gpt: Accepted for interface compatibility; unused.

        Returns:
            dict with two keys:
                ``'mask_save_path'``: path of the PNG written under
                ``result/``.
                ``'mask'``: the saved mask as an RGB ``PIL.Image``.
        """
        #  segment with prompt
        print("CA prompt: ", prompt, "CA controls", controls)
        print(image)
        # Keep only the first entry along the leading axis of the output.
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]

        # The timestamp keeps successive masks from overwriting each other.
        mask_save_path = f'result/mask_{time.time()}.png'
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)

        # Scale the mask to 0/255 (assumes a binary mask -- TODO confirm);
        # multiplying by the float 255. yields a float array.
        seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
        if seg_mask_img.mode != 'RGB':
            seg_mask_img = seg_mask_img.convert('RGB')
        seg_mask_img.save(mask_save_path)
        print('seg_mask path: ', mask_save_path)
        print("seg_mask.shape: ", seg_mask.shape)

        # NOTE: captioning with the mask and caption refinement are disabled,
        # so only the mask and its path are returned.
        out = {
            'mask_save_path': mask_save_path,
            'mask': seg_mask_img}
        return out
    
def parse_augment(argv=None):
    """Define the command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing zero-argument calls behave as before.

    Returns:
        The ``argparse.Namespace`` holding the parsed options. It is also
        printed when ``--debug`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--captioner', type=str, default="blip2")
    parser.add_argument('--segmenter', type=str, default="base")
    parser.add_argument('--text_refiner', type=str, default="base")
    parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
    parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
    parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
    parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
    parser.add_argument('--disable_regular_box', action="store_true", default=False, help="crop image with a regular box")
    parser.add_argument('--device', type=str, default="cpu")
    parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--gradio_share', action="store_true")
    parser.add_argument('--disable_gpt', action="store_true")
    parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
    parser.add_argument('--disable_reuse_features', action="store_true", default=False)
    # argv=None makes argparse read sys.argv[1:]; an explicit list lets
    # callers (and tests) parse options without touching the process args.
    args = parser.parse_args(argv)

    if args.debug:
        print(args)
    return args

if __name__ == "__main__":
    # Demo driver: run the segmenter on one test image with two click prompts.
    args = parse_augment()
    # image_path = 'test_img/img3.jpg'
    image_path = 'test_img/img13.jpg'
    # Visual prompts handed to the segmenter, one inference call per entry.
    prompts = [
        {
            "prompt_type":["click"],
            "input_point":[[500, 300], [1000, 500]],
            "input_label":[1, 0],
            "multimask_output":"True",
        },
        {
            "prompt_type":["click"],
            "input_point":[[900, 800]],
            "input_label":[1],
            "multimask_output":"True",
        }
    ]
    # Language controls; CaptionAnything.inference currently only prints them.
    controls = {
            "length": "30",
            "sentiment": "positive",
            # "imagination": "True",
            "imagination": "False",
            "language": "English",
        }

    # CaptionAnything does not use the API key, so a missing environment
    # variable must not abort the demo with a KeyError.
    model = CaptionAnything(args, os.environ.get('OPENAI_API_KEY', ''))
    for prompt in prompts:
        print('*'*30)
        print('Image path: ', image_path)
        # The image is opened only to print its summary; the context manager
        # closes the file handle instead of leaking one per iteration.
        with Image.open(image_path) as image:
            print(image)
        print('Visual controls (SAM prompt):\n', prompt)
        print('Language controls:\n', controls)
        # The model receives the path, not the opened image.
        out = model.inference(image_path, prompt, controls)