chirmy commited on
Commit
918ace3
·
verified ·
1 Parent(s): b1ec5f8

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +32 -28
script.py CHANGED
@@ -49,41 +49,47 @@ class PytorchWorker:
49
  # T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
50
 
51
 
52
- def predict_image(self, image: np.ndarray) -> list:
53
  """Run inference using ONNX runtime.
54
  :param image: Input image as numpy array.
55
  :return: A list with logits and confidences.
56
  """
 
 
 
57
  self.model.eval()
 
58
  outputs = self.model(self.transforms(image).unsqueeze(0).to(self.device))
59
- return outputs.cpu() # Convert tensor to list
 
 
 
60
 
 
 
 
 
 
 
 
 
61
 
62
- def make_submission(test_metadata, model_path, model_path2, model_name, model_name2, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
 
63
  """Make submission with given """
64
- model = PytorchWorker(model_path, model_name, number_of_categories=1604)
65
- model2 = PytorchWorker(model_path2, model_name2)
66
-
67
  predictions = []
68
- correct_max_values = []
69
- incorrect_max_values = []
70
-
71
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
72
  image_path = os.path.join(images_root_path, row.image_path)
 
73
  test_image = Image.open(image_path).convert("RGB")
 
 
74
 
75
- outputs = model.predict_image(test_image)
76
- outputs2 = model2.predict_image(test_image)
77
-
78
- # max_value = torch.max(outputs+outputs2)
79
- _, preds = torch.max(outputs, 1) # baseline
80
- _, preds2 = torch.max(outputs2, 1) # 1.4.3
81
-
82
- pred_class_id = preds.tolist()
83
- pred_class_id2 = preds2.tolist()
84
- # max_value2 = torch.max(outputs2)
85
-
86
- pred_class_id = pred_class_id[0] if pred_class_id2[0] != 1604 else -1
87
 
88
  predictions.append(pred_class_id)
89
 
@@ -107,17 +113,15 @@ if __name__ == "__main__":
107
  # MODEL_PATH = './efficientnet_b3_epoch_21_trick1.2.5_a0.7237_l17.1662.pth'
108
  # MODEL_PATH = './efficientnet_b3_epoch_21_trcik1.5.2.pth'
109
  # MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
110
- MODEL_PATH = './pytorch_model.bin'
111
- MODEL_PATH2 = './efficientnet_b3_epoch_28_1.4.3.pth'
112
- MODEL_NAME = "tf_efficientnet_b1_ap"
113
- MODEL_NAME2 = "tf_efficientnet_b3_ns"
114
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
115
  test_metadata = pd.read_csv(metadata_file_path)
116
 
117
  make_submission(
118
  test_metadata=test_metadata,
119
  model_path=MODEL_PATH,
120
- model_path2=MODEL_PATH2,
121
- model_name=MODEL_NAME,
122
- model_name2=MODEL_NAME2
123
  )
 
49
  # T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
50
 
51
 
52
+ def predict_image(self, image: np.ndarray) -> list():
53
  """Run inference using ONNX runtime.
54
  :param image: Input image as numpy array.
55
  :return: A list with logits and confidences.
56
  """
57
+
58
+ # logits = self.model(self.transforms(image).unsqueeze(0).to(self.device))
59
+
60
  self.model.eval()
61
+
62
  outputs = self.model(self.transforms(image).unsqueeze(0).to(self.device))
63
+
64
+ _, preds = torch.max(outputs, 1)
65
+
66
+ preds = preds.cpu() # Move tensor to CPU
67
 
68
+ # post process
69
+ # max_value = torch.max(outputs)
70
+ # if max_value < -20:
71
+ # preds[0]=1604
72
+
73
+ print("preds: ", preds)
74
+
75
+ return preds.tolist() # Convert tensor to list
76
 
77
+
78
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
79
  """Make submission with given """
80
+
81
+ model = PytorchWorker(model_path, model_name)
82
+
83
  predictions = []
84
+
 
 
85
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
86
  image_path = os.path.join(images_root_path, row.image_path)
87
+
88
  test_image = Image.open(image_path).convert("RGB")
89
+
90
+ logits = model.predict_image(test_image)
91
 
92
+ pred_class_id = logits[0] if logits[0] !=1604 else -1
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  predictions.append(pred_class_id)
95
 
 
113
  # MODEL_PATH = './efficientnet_b3_epoch_21_trick1.2.5_a0.7237_l17.1662.pth'
114
  # MODEL_PATH = './efficientnet_b3_epoch_21_trcik1.5.2.pth'
115
  # MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
116
+ # MODEL_PATH = './efficientnet_b3_epoch_28_trick1.4.3.2.pth'
117
+ MODEL_PATH = './fused_model_soup.pth'
118
+ MODEL_NAME = 'tf_efficientnet_b3_ns' #"tf_efficientnet_b1.ap_in1k"
119
+
120
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
121
  test_metadata = pd.read_csv(metadata_file_path)
122
 
123
  make_submission(
124
  test_metadata=test_metadata,
125
  model_path=MODEL_PATH,
126
+ model_name=MODEL_NAME
 
 
127
  )