dqshuai committed
Commit 83c32da · unverified · 2 Parent(s): 6059ecb 2826c26

Merge pull request #11 from BakingBrains/master


added inference script and added get_inference_config() in config.py

Files changed (2)
  1. config.py +25 -4
  2. inference.py +125 -0
config.py CHANGED
@@ -19,7 +19,7 @@ _C.BASE = ['']
 # -----------------------------------------------------------------------------
 _C.DATA = CN()
 # Batch size for a single GPU, could be overwritten by command line argument
-_C.DATA.BATCH_SIZE = 128
+_C.DATA.BATCH_SIZE = 32
 # Path to dataset, could be overwritten by command line argument
 _C.DATA.DATA_PATH = ''
 # Dataset name
@@ -37,7 +37,7 @@ _C.DATA.CACHE_MODE = 'part'
 # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
 _C.DATA.PIN_MEMORY = True
 # Number of data loading threads
-_C.DATA.NUM_WORKERS = 8
+_C.DATA.NUM_WORKERS = 4
 # hdfs data dir
 _C.DATA.TRAIN_PATH = None
 _C.DATA.VAL_PATH = None
@@ -89,9 +89,9 @@ _C.TRAIN.START_EPOCH = 0
 _C.TRAIN.EPOCHS = 300
 _C.TRAIN.WARMUP_EPOCHS = 20
 _C.TRAIN.WEIGHT_DECAY = 0.05
-_C.TRAIN.BASE_LR = 5e-4
+_C.TRAIN.BASE_LR = 1e-4  # 5e-4
 _C.TRAIN.WARMUP_LR = 5e-7
-_C.TRAIN.MIN_LR = 5e-6
+_C.TRAIN.MIN_LR = 1e-5  # 5e-6
 # Clip gradient norm
 _C.TRAIN.CLIP_GRAD = 5.0
 # Auto resume from latest checkpoint
@@ -271,3 +271,24 @@ def get_config(args):
     update_config(config, args)
 
     return config
+
+
+################### For Inferencing ####################
+def update_inference_config(config, args):
+    _update_config_from_file(config, args.cfg)
+
+    config.defrost()
+
+    config.freeze()
+
+
+def get_inference_config(cfg_path):
+    """Get a yacs CfgNode object with default values."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    config = _C.clone()
+    update_inference_config(config, cfg_path)
+
+    return config
+
+################### For Inferencing ####################
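
Note on the new helper: although the parameter is named cfg_path, get_inference_config() actually expects an object exposing a .cfg attribute, because update_inference_config() reads args.cfg. A minimal sketch of the intended call pattern, assuming a YAML config exists at the (placeholder) path below:

# Sketch only: the YAML path is a placeholder, not fixed by this commit.
from types import SimpleNamespace
from config import get_inference_config

args = SimpleNamespace(cfg="configs/MetaFG_meta_bert_1_224.yaml")  # placeholder path
config = get_inference_config(args)  # frozen yacs CfgNode: defaults merged with the YAML
print(config.DATA.BATCH_SIZE)        # 32 after this commit
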
inference.py ADDED
@@ -0,0 +1,125 @@
+from transformers import AutoTokenizer, AutoModel
+import torch
+from PIL import Image
+from config import get_inference_config
+from models import build_model
+from torch.autograd import Variable
+from torchvision.transforms import transforms
+import numpy as np
+import argparse
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+class Namespace:
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+
+
+def model_config(config_path):
+    args = Namespace(cfg=config_path)
+    config = get_inference_config(args)
+    return config
+
+
+def read_class_names(file_path):
+    with open(file_path, 'r') as file:
+        lines = file.readlines()
+    class_list = []
+
+    for l in lines:
+        line = l.strip().split()
+        # class_list.append(line[0])
+        class_list.append(line[1][4:])
+
+    classes = tuple(class_list)
+    return classes
+
+
+class GenerateEmbedding:
+    def __init__(self, text_file):
+        self.text_file = text_file
+
+        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        self.model = AutoModel.from_pretrained("bert-base-uncased")
+
+    def generate(self):
+        text_list = []
+        with open(self.text_file, 'r') as f_text:
+            for line in f_text:
+                line = line.encode(encoding='UTF-8', errors='strict')
+                line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
+                line = line.decode('UTF-8', 'strict')
+                text_list.append(line)
+        # data = f_text.read()
+        select_index = np.random.randint(len(text_list))
+        inputs = self.tokenizer(text_list[select_index], return_tensors="pt", padding="max_length",
+                                truncation=True, max_length=32)
+        outputs = self.model(**inputs)
+        embedding_mean = outputs[1].mean(dim=0).reshape(1, -1).detach().numpy()
+        embedding_full = outputs[1].detach().numpy()
+        embedding_words = outputs[0]  # outputs[0].detach().numpy()
+        return None, None, embedding_words
+
+
+class Inference:
+    def __init__(self, config_path, model_path):
+        self.config_path = config_path
+        self.model_path = model_path
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # self.classes = ("cat", "dog")
+        self.classes = read_class_names(r"D:\dataset\CUB_200_2011\CUB_200_2011\classes_custom.txt")
+
+        self.config = model_config(self.config_path)
+        self.model = build_model(self.config)
+        self.checkpoint = torch.load(self.model_path, map_location='cpu')
+        self.model.load_state_dict(self.checkpoint['model'], strict=False)
+        self.model.eval()
+        self.model.cuda()
+
+        self.transform_img = transforms.Compose([
+            transforms.Resize((224, 224), interpolation=Image.BILINEAR),
+            transforms.ToTensor(),  # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+        ])
+
+    def infer(self, img_path, meta_data_path):
+        _, _, meta = GenerateEmbedding(meta_data_path).generate()
+        meta = meta.cuda()
+        img = Image.open(img_path).convert('RGB')
+        img = self.transform_img(img)
+        img.unsqueeze_(0)
+        img = img.cuda()
+        img = Variable(img).to(self.device)
+        out = self.model(img, meta)
+
+        _, pred = torch.max(out.data, 1)
+        predict = self.classes[pred.data.item()]
+        # print(Fore.MAGENTA + f"The Prediction is: {predict}")
+        return predict
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
+    parser.add_argument('--cfg', type=str, default='D:/pycharmprojects/MetaFormer/configs/MetaFG_meta_bert_1_224.yaml', metavar="FILE", help='path to config file')
+    # easy config modification
+    parser.add_argument('--model-path', default=r'D:\pycharmprojects\MetaFormer\output\MetaFG_meta_1\cub_200\ckpt_epoch_92.pth', type=str, help="path to model data")
+    parser.add_argument('--img-path', default=r"D:\dataset\CUB_200_2011\CUB_200_2011\images\012.Yellow_headed_Blackbird\Yellow_Headed_Blackbird_0003_8337.jpg", type=str, help='path to image')
+    parser.add_argument('--meta-path', default=r"D:\dataset\CUB_200_2011\text_c10\012.Yellow_headed_Blackbird\Yellow_Headed_Blackbird_0003_8337.txt", type=str, help='path to meta data')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_option()
+    result = Inference(config_path=args.cfg,
+                       model_path=args.model_path).infer(img_path=args.img_path, meta_data_path=args.meta_path)
+    print("Predicted: ", result)
+
+# Usage: python inference.py --cfg 'path/to/cfg' --model-path 'path/to/model' --img-path 'path/to/img' --meta-path 'path/to/meta'
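
The script can also be driven from another module instead of the CLI. A minimal sketch, assuming the config, checkpoint, image, and metadata paths below exist locally (all placeholders); note that Inference.__init__ still hardcodes the classes_custom.txt path, so that file must be present as well:

# Sketch only: every path below is a placeholder, not a file fixed by this commit.
from inference import Inference

engine = Inference(config_path="configs/MetaFG_meta_bert_1_224.yaml",  # placeholder config
                   model_path="output/ckpt_epoch_92.pth")              # placeholder checkpoint
label = engine.infer(img_path="bird.jpg",        # placeholder image
                     meta_data_path="bird.txt")  # placeholder text metadata
print("Predicted:", label)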