chirmy commited on
Commit
220d059
·
verified ·
1 Parent(s): e2fe2c4

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +19 -25
script.py CHANGED
@@ -49,47 +49,39 @@ class PytorchWorker:
49
  # T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
50
 
51
 
52
- def predict_image(self, image: np.ndarray) -> list():
53
  """Run inference using ONNX runtime.
54
  :param image: Input image as numpy array.
55
  :return: A list with logits and confidences.
56
  """
57
-
58
- # logits = self.model(self.transforms(image).unsqueeze(0).to(self.device))
59
-
60
  self.model.eval()
61
-
62
  outputs = self.model(self.transforms(image).unsqueeze(0).to(self.device))
63
-
64
- _, preds = torch.max(outputs, 1)
65
-
66
- preds = preds.cpu() # Move tensor to CPU
67
-
68
- # post process
69
- # max_value = torch.max(outputs)
70
- # if max_value < -20:
71
- # preds[0]=1604
72
-
73
- print("preds: ", preds)
74
 
75
- return preds.tolist() # Convert tensor to list
76
 
77
-
78
- def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
79
  """Make submission with given """
80
 
81
  model = PytorchWorker(model_path, model_name)
82
-
 
83
  predictions = []
 
 
84
 
85
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
86
  image_path = os.path.join(images_root_path, row.image_path)
87
-
88
  test_image = Image.open(image_path).convert("RGB")
89
-
90
- logits = model.predict_image(test_image)
91
 
92
- pred_class_id = logits[0] if logits[0] !=1604 else -1
 
 
 
 
 
 
 
 
93
 
94
  predictions.append(pred_class_id)
95
 
@@ -113,7 +105,8 @@ if __name__ == "__main__":
113
  # MODEL_PATH = './efficientnet_b3_epoch_21_trick1.2.5_a0.7237_l17.1662.pth'
114
  # MODEL_PATH = './efficientnet_b3_epoch_21_trcik1.5.2.pth'
115
  # MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
116
- MODEL_PATH = './efficientnet_b3_epoch_28_trick1.4.3.2.pth'
 
117
  MODEL_NAME = 'tf_efficientnet_b3_ns' #"tf_efficientnet_b1.ap_in1k"
118
 
119
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
@@ -122,5 +115,6 @@ if __name__ == "__main__":
122
  make_submission(
123
  test_metadata=test_metadata,
124
  model_path=MODEL_PATH,
 
125
  model_name=MODEL_NAME
126
  )
 
49
  # T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
50
 
51
 
52
+ def predict_image(self, image: np.ndarray) -> list:
53
  """Run inference using ONNX runtime.
54
  :param image: Input image as numpy array.
55
  :return: A list with logits and confidences.
56
  """
 
 
 
57
  self.model.eval()
 
58
  outputs = self.model(self.transforms(image).unsqueeze(0).to(self.device))
59
+ return outputs.cpu() # Convert tensor to list
 
 
 
 
 
 
 
 
 
 
60
 
 
61
 
62
+ def make_submission(test_metadata, model_path, model_path2, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
 
63
  """Make submission with given """
64
 
65
  model = PytorchWorker(model_path, model_name)
66
+ model2 = PytorchWorker(model_path2, model_name)
67
+
68
  predictions = []
69
+ correct_max_values = []
70
+ incorrect_max_values = []
71
 
72
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
73
  image_path = os.path.join(images_root_path, row.image_path)
 
74
  test_image = Image.open(image_path).convert("RGB")
 
 
75
 
76
+ outputs = model.predict_image(test_image)
77
+ outputs2 = model2.predict_image(test_image)
78
+
79
+ # max_value = torch.max(outputs+outputs2)
80
+ _, preds = torch.max(outputs+outputs2, 1)
81
+ pred_class_id = preds.tolist()
82
+ # max_value2 = torch.max(outputs2)
83
+
84
+ pred_class_id = pred_class_id[0] if pred_class_id[0] != 1604 else -1
85
 
86
  predictions.append(pred_class_id)
87
 
 
105
  # MODEL_PATH = './efficientnet_b3_epoch_21_trick1.2.5_a0.7237_l17.1662.pth'
106
  # MODEL_PATH = './efficientnet_b3_epoch_21_trcik1.5.2.pth'
107
  # MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
108
+ MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
109
+ MODEL_PATH2 = './efficientnet_b3_epoch_23_trick1.4.1.pth'
110
  MODEL_NAME = 'tf_efficientnet_b3_ns' #"tf_efficientnet_b1.ap_in1k"
111
 
112
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
 
115
  make_submission(
116
  test_metadata=test_metadata,
117
  model_path=MODEL_PATH,
118
+ model_path2=MODEL_PATH2,
119
  model_name=MODEL_NAME
120
  )