chirmy commited on
Commit
9cf45b7
·
verified ·
1 Parent(s): f9274a6

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +16 -13
script.py CHANGED
@@ -59,16 +59,15 @@ class PytorchWorker:
59
  return outputs.cpu() # Convert tensor to list
60
 
61
 
62
- def make_submission(test_metadata, model_path, model_path2, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
63
  """Make submission with given """
64
-
65
- model = PytorchWorker(model_path, model_name)
66
- model2 = PytorchWorker(model_path2, model_name)
67
 
68
  predictions = []
69
  correct_max_values = []
70
  incorrect_max_values = []
71
-
72
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
73
  image_path = os.path.join(images_root_path, row.image_path)
74
  test_image = Image.open(image_path).convert("RGB")
@@ -77,11 +76,14 @@ def make_submission(test_metadata, model_path, model_path2, model_name, output_c
77
  outputs2 = model2.predict_image(test_image)
78
 
79
  # max_value = torch.max(outputs+outputs2)
80
- _, preds = torch.max(outputs+outputs2, 1)
 
 
81
  pred_class_id = preds.tolist()
 
82
  # max_value2 = torch.max(outputs2)
83
 
84
- pred_class_id = pred_class_id[0] if pred_class_id[0] != 1604 else -1
85
 
86
  predictions.append(pred_class_id)
87
 
@@ -105,10 +107,10 @@ if __name__ == "__main__":
105
  # MODEL_PATH = './efficientnet_b3_epoch_21_trick1.2.5_a0.7237_l17.1662.pth'
106
  # MODEL_PATH = './efficientnet_b3_epoch_21_trcik1.5.2.pth'
107
  # MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
108
- MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
109
- MODEL_PATH2 = './efficientnet_b3_epoch_23_trick1.4.1.pth'
110
- MODEL_NAME = 'tf_efficientnet_b3_ns' #"tf_efficientnet_b1.ap_in1k"
111
-
112
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
113
  test_metadata = pd.read_csv(metadata_file_path)
114
 
@@ -116,5 +118,6 @@ if __name__ == "__main__":
116
  test_metadata=test_metadata,
117
  model_path=MODEL_PATH,
118
  model_path2=MODEL_PATH2,
119
- model_name=MODEL_NAME
120
- )
 
 
59
  return outputs.cpu() # Convert tensor to list
60
 
61
 
62
+ def make_submission(test_metadata, model_path, model_path2, model_name, model_name2, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
63
  """Make submission with given """
64
+ model = PytorchWorker(model_path, model_name, number_of_categories=1604)
65
+ model2 = PytorchWorker(model_path2, model_name2)
 
66
 
67
  predictions = []
68
  correct_max_values = []
69
  incorrect_max_values = []
70
+
71
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
72
  image_path = os.path.join(images_root_path, row.image_path)
73
  test_image = Image.open(image_path).convert("RGB")
 
76
  outputs2 = model2.predict_image(test_image)
77
 
78
  # max_value = torch.max(outputs+outputs2)
79
+ _, preds = torch.max(outputs, 1) # baseline
80
+ _, preds2 = torch.max(outputs2, 1) # 1.4.3
81
+
82
  pred_class_id = preds.tolist()
83
+ pred_class_id2 = preds2.tolist()
84
  # max_value2 = torch.max(outputs2)
85
 
86
+ pred_class_id = pred_class_id[0] if pred_class_id2[0] != 1604 else -1
87
 
88
  predictions.append(pred_class_id)
89
 
 
107
  # MODEL_PATH = './efficientnet_b3_epoch_21_trick1.2.5_a0.7237_l17.1662.pth'
108
  # MODEL_PATH = './efficientnet_b3_epoch_21_trcik1.5.2.pth'
109
  # MODEL_PATH = './efficientnet_b3_epoch_28_1.4.3.pth'
110
+ MODEL_PATH = './pytorch_model.bin'
111
+ MODEL_PATH2 = './efficientnet_b3_epoch_28_1.4.3.pth'
112
+ MODEL_NAME = "tf_efficientnet_b1_ap"
113
+ MODEL_NAME2 = "tf_efficientnet_b3_ns"
114
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
115
  test_metadata = pd.read_csv(metadata_file_path)
116
 
 
118
  test_metadata=test_metadata,
119
  model_path=MODEL_PATH,
120
  model_path2=MODEL_PATH2,
121
+ model_name=MODEL_NAME,
122
+ model_name2=MODEL_NAME2
123
+ )