chychiu commited on
Commit
d10e961
·
1 Parent(s): 010e451
Files changed (1) hide show
  1. script.py +20 -23
script.py CHANGED
@@ -12,12 +12,6 @@ def is_gpu_available():
12
  """Check if the python package `onnxruntime-gpu` is installed."""
13
  return torch.cuda.is_available()
14
 
15
- WIDTH = 224
16
- HEIGHT = 224
17
-
18
- MODEL_PATH = "metaformer-s-224.pth"
19
- MODEL_NAME = "caformer_s18.sail_in22k"
20
-
21
  class PytorchWorker:
22
  """Run inference using ONNX runtime."""
23
 
@@ -37,12 +31,12 @@ class PytorchWorker:
37
 
38
  self.model = _load_model(model_name, model_path)
39
 
40
- self.transforms = T.Compose([T.Resize((HEIGHT, WIDTH)),
41
  T.ToTensor(),
42
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
43
 
44
 
45
- def predict_image(self, image: np.ndarray) -> List:
46
  """Run inference using ONNX runtime.
47
 
48
  :param image: Input image as numpy array.
@@ -62,7 +56,7 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
62
  predictions = []
63
 
64
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
65
- image_path = os.path.join(images_root_path, row.image_path.replace("jpg", "JPG"))
66
 
67
  test_image = Image.open(image_path).convert("RGB")
68
 
@@ -80,23 +74,12 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
80
 
81
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
82
 
83
- def test_submission():
84
-
85
- metadata_file_path = "../trial_test.csv"
86
- test_metadata = pd.read_csv(metadata_file_path)
87
-
88
- make_submission(
89
- test_metadata=test_metadata,
90
- model_path=MODEL_PATH,
91
- model_name=MODEL_NAME,
92
- images_root_path="../data/DF_FULL/"
93
- )
94
-
95
-
96
  if __name__ == "__main__":
97
 
98
- # test_submission()
 
99
 
 
100
  import zipfile
101
 
102
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
@@ -110,3 +93,17 @@ if __name__ == "__main__":
110
  model_path=MODEL_PATH,
111
  model_name=MODEL_NAME
112
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """Check if the python package `onnxruntime-gpu` is installed."""
13
  return torch.cuda.is_available()
14
 
 
 
 
 
 
 
15
  class PytorchWorker:
16
  """Run inference using ONNX runtime."""
17
 
 
31
 
32
  self.model = _load_model(model_name, model_path)
33
 
34
+ self.transforms = T.Compose([T.Resize((224, 224)),
35
  T.ToTensor(),
36
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37
 
38
 
39
+ def predict_image(self, image: np.ndarray):
40
  """Run inference using ONNX runtime.
41
 
42
  :param image: Input image as numpy array.
 
56
  predictions = []
57
 
58
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
59
+ image_path = os.path.join(images_root_path, row.image_path) #.replace("jpg", "JPG"))
60
 
61
  test_image = Image.open(image_path).convert("RGB")
62
 
 
74
 
75
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  if __name__ == "__main__":
78
 
79
+ MODEL_PATH = "metaformer-s-224.pth"
80
+ MODEL_NAME = "caformer_s18.sail_in22k"
81
 
82
+ # Real submission
83
  import zipfile
84
 
85
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
 
93
  model_path=MODEL_PATH,
94
  model_name=MODEL_NAME
95
  )
96
+
97
+ # Test submission
98
+ # metadata_file_path = "../trial_test.csv"
99
+
100
+ # test_metadata = pd.read_csv(metadata_file_path)
101
+
102
+ # make_submission(
103
+ # test_metadata=test_metadata,
104
+ # model_path=MODEL_PATH,
105
+ # model_name=MODEL_NAME,
106
+ # images_root_path="../data/DF"
107
+ # )
108
+
109
+