Anthony Miyaguchi committed on
Commit
43c4ba2
·
1 Parent(s): 6c84b4c

Remove duplicates

Browse files
Files changed (1) hide show
  1. script.py +2 -2
script.py CHANGED
@@ -24,7 +24,7 @@ class ImageDataset(Dataset):
24
  def __getitem__(self, idx):
25
  row = self.metadata.iloc[idx]
26
  image_path = Path(self.images_root_path) / row.filename
27
- img = Image.open(image_path)
28
  img = torch.from_numpy(np.array(img))
29
  return {"features": img, "observation_id": row.observation_id}
30
 
@@ -79,7 +79,7 @@ def make_submission(
79
  for observation_id, class_id in zip(observation_ids, class_ids):
80
  row = {"observation_id": int(observation_id), "class_id": int(class_id)}
81
  rows.append(row)
82
- submission_df = pd.DataFrame(rows)
83
  submission_df.to_csv(output_csv_path, index=False)
84
 
85
 
 
24
  def __getitem__(self, idx):
25
  row = self.metadata.iloc[idx]
26
  image_path = Path(self.images_root_path) / row.filename
27
+ img = Image.open(image_path).convert("RGB")
28
  img = torch.from_numpy(np.array(img))
29
  return {"features": img, "observation_id": row.observation_id}
30
 
 
79
  for observation_id, class_id in zip(observation_ids, class_ids):
80
  row = {"observation_id": int(observation_id), "class_id": int(class_id)}
81
  rows.append(row)
82
+ submission_df = pd.DataFrame(rows).drop_duplicates("observation_id", keep="first")
83
  submission_df.to_csv(output_csv_path, index=False)
84
 
85