erfaneshrati committed on
Commit
024f5b3
·
1 Parent(s): 25bd72d

jp2 as input type

Browse files
Files changed (2) hide show
  1. config.pbtxt +16 -7
  2. model.py +88 -15
config.pbtxt CHANGED
@@ -1,16 +1,25 @@
1
  backend: "python"
2
- max_batch_size: 0 # Disable batching if not supported
 
3
  input [
4
  {
5
- name: "input_array"
6
- data_type: TYPE_FP32
7
- dims: [3, -1, -1] # Channels, height, width
8
  }
9
  ]
 
10
  output [
11
  {
12
  name: "output_mask"
13
- data_type: TYPE_UINT8 # Adjust based on actual output type
14
- dims: [-1, -1] # Height, width
15
  }
16
- ]
 
 
 
 
 
 
 
 
1
  backend: "python"
2
+ max_batch_size: 0 # Keep batching disabled as per original config
3
+
4
  input [
5
  {
6
+ name: "input_jp2_bytes" # New input name for JP2 bytes
7
+ data_type: TYPE_STRING # Use TYPE_STRING for bytes
8
+ dims: [ 3 ] # Expecting 3 elements: Red, Green, NIR bytes
9
  }
10
  ]
11
+
12
  output [
13
  {
14
  name: "output_mask"
15
+ data_type: TYPE_UINT8
16
+ dims: [-1, -1] # Variable height, width
17
  }
18
+ ]
19
+
20
+ # Optional: Specify instance_group if running on GPU
21
+ # instance_group [
22
+ # {
23
+ # kind: KIND_GPU
24
+ # }
25
+ # ]
model.py CHANGED
@@ -1,28 +1,101 @@
1
  import numpy as np
2
  import triton_python_backend_utils as pb_utils
3
  from omnicloudmask import predict_from_array
 
 
 
4
 
5
class TritonPythonModel:
    """Triton Python backend model: runs omnicloudmask cloud detection on a
    float32 CHW array supplied in the "input_array" tensor and returns the
    predicted mask as uint8 in "output_mask"."""

    def initialize(self, args):
        """Called once when the model is loaded; no setup is required."""
        pass

    def execute(self, requests):
        """Process a batch of inference requests.

        For each request: read the "input_array" tensor, run
        predict_from_array on it, and emit the mask as a uint8
        "output_mask" tensor. Failures are reported per-request via
        pb_utils.TritonError so one bad input does not abort the whole
        batch (previously any exception crashed execute() entirely).
        """
        responses = []
        for request in requests:
            try:
                # Get input tensor
                input_tensor = pb_utils.get_input_tensor_by_name(request, "input_array")
                input_array = input_tensor.as_numpy()

                # Perform inference
                pred_mask = predict_from_array(input_array)

                # Create output tensor
                output_tensor = pb_utils.Tensor(
                    "output_mask",
                    pred_mask.astype(np.uint8)
                )
                responses.append(pb_utils.InferenceResponse([output_tensor]))
            except Exception as e:
                # Surface the failure as a per-request error response
                # instead of letting it propagate and fail the batch.
                error = pb_utils.TritonError(f"Error during inference: {e}")
                responses.append(
                    pb_utils.InferenceResponse(output_tensors=[], error=error)
                )
        return responses

    def finalize(self):
        """Called when the model is unloaded; no cleanup is required."""
        pass
 
 
 
 
 
1
  import numpy as np
2
  import triton_python_backend_utils as pb_utils
3
  from omnicloudmask import predict_from_array
4
+ import rasterio
5
+ from rasterio.io import MemoryFile
6
+ from rasterio.enums import Resampling
7
 
8
class TritonPythonModel:
    """Triton Python backend model: decodes three JP2-encoded bands
    (Red, Green, NIR) received as byte strings, resamples them onto the
    red band's grid, and runs omnicloudmask cloud detection on the
    stacked CHW array."""

    def initialize(self, args):
        """
        Initialize the model. This function is called once when the model is loaded.
        """
        # Ensure rasterio is installed in the Python backend environment.
        print('Initialized Cloud Detection model with JP2 input')

    def _read_band(self, jp2_bytes, target_shape=None):
        """Decode one JP2 byte string into a 2-D float32 array.

        If target_shape (height, width) is given and differs from the
        band's native size, the band is bilinearly resampled to it.

        Returns:
            (data, shape): the 2-D float32 band and its (height, width).

        Note: a single-band read takes a 2-D out_shape; the previous code
        passed (1, H, W), and a 3-D result would have broken the np.stack
        of rank-2 bands downstream. We use (H, W) and defensively squeeze
        any leading band axis.
        """
        with MemoryFile(jp2_bytes) as memfile:
            with memfile.open() as src:
                native_shape = (src.height, src.width)
                if target_shape is None or native_shape == target_shape:
                    data = src.read(1)
                else:
                    # Bilinear matches the original resampling choice
                    # (e.g. upsampling 20 m B8A to the 10 m B04 grid).
                    data = src.read(
                        1,
                        out_shape=target_shape,
                        resampling=Resampling.bilinear,
                    )
        data = np.asarray(data, dtype=np.float32)
        if data.ndim == 3:
            # Normalize (1, H, W) to (H, W) so every band is rank-2.
            data = data[0]
        return data, (target_shape if target_shape is not None else native_shape)

    def execute(self, requests):
        """
        Process inference requests.
        """
        responses = []
        # Every request must contain three JP2 byte strings (Red, Green, NIR).
        for request in requests:
            # Get the input tensor containing the byte arrays
            input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
            # as_numpy() for TYPE_STRING gives an ndarray of Python bytes objects
            jp2_bytes_list = input_tensor.as_numpy()

            if len(jp2_bytes_list) != 3:
                # Send an error response if the input shape is incorrect
                error = pb_utils.TritonError(f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}")
                response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                responses.append(response)
                continue  # Skip to the next request

            # Assume order: Red, Green, NIR based on client logic
            red_bytes, green_bytes, nir_bytes = jp2_bytes_list

            try:
                # Red (B04) defines the target grid; Green (B03) usually
                # matches it already, NIR (B8A) is coarser and resampled up.
                red_data, target_shape = self._read_band(red_bytes)
                green_data, _ = self._read_band(green_bytes, target_shape)
                nir_data, _ = self._read_band(nir_bytes, target_shape)

                # Stack bands in CHW format (Red, Green, NIR) — the channel
                # order expected by predict_from_array.
                input_array = np.stack([red_data, green_data, nir_data], axis=0)

                # Perform inference using the original function
                pred_mask = predict_from_array(input_array)

                # Create output tensor
                output_tensor = pb_utils.Tensor(
                    "output_mask",
                    pred_mask.astype(np.uint8)
                )
                response = pb_utils.InferenceResponse([output_tensor])

            except Exception as e:
                # Handle errors during processing (e.g., invalid JP2 data)
                error = pb_utils.TritonError(f"Error processing JP2 data: {str(e)}")
                response = pb_utils.InferenceResponse(output_tensors=[], error=error)

            responses.append(response)

        # Return a list of responses
        return responses

    def finalize(self):
        """
        Called when the model is unloaded. Perform any necessary cleanup.
        """
        print('Finalizing Cloud Detection model')