chychiu commited on
Commit
894f232
·
1 Parent(s): 815a0c4

fix submission script lmao

Browse files
Files changed (1) hide show
  1. script.py +18 -13
script.py CHANGED
@@ -71,13 +71,18 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
71
  predictions.append(np.argmax(logits))
72
 
73
  test_metadata["class_id"] = predictions
74
-
75
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
 
 
 
 
 
76
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
77
 
78
  def test_submission():
79
 
80
- metadata_file_path = "../val_mini.csv"
81
  test_metadata = pd.read_csv(metadata_file_path)
82
 
83
  make_submission(
@@ -90,18 +95,18 @@ def test_submission():
90
 
91
  if __name__ == "__main__":
92
 
93
- # test_submission()
94
 
95
- import zipfile
96
 
97
- with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
98
- zip_ref.extractall("/tmp/data")
99
 
100
- metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
101
- test_metadata = pd.read_csv(metadata_file_path)
102
 
103
- make_submission(
104
- test_metadata=test_metadata,
105
- model_path=MODEL_PATH,
106
- model_name=MODEL_NAME
107
- )
 
71
  predictions.append(np.argmax(logits))
72
 
73
  test_metadata["class_id"] = predictions
74
+
75
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
76
+
77
+ for ix, row in user_pred_df.iterrows():
78
+ if row['class_id'] == 1604:
79
+ user_pred_df.loc[ix, 'class_id'] = -1
80
+
81
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
82
 
83
  def test_submission():
84
 
85
+ metadata_file_path = "../trial_test.csv"
86
  test_metadata = pd.read_csv(metadata_file_path)
87
 
88
  make_submission(
 
95
 
96
  if __name__ == "__main__":
97
 
98
+ test_submission()
99
 
100
+ # import zipfile
101
 
102
+ # with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
103
+ # zip_ref.extractall("/tmp/data")
104
 
105
+ # metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
106
+ # test_metadata = pd.read_csv(metadata_file_path)
107
 
108
+ # make_submission(
109
+ # test_metadata=test_metadata,
110
+ # model_path=MODEL_PATH,
111
+ # model_name=MODEL_NAME
112
+ # )