Update app.py
app.py
CHANGED
@@ -51,6 +51,7 @@ app_ui = ui.page_fillable(
         ui.input_switch("show_margins", "Show marginal plots", value=True),
     ),
     ui.output_image("uploaded_image"),  # display the uploaded TIFF sidewalk tile image
+    ui.output_text("processed_output"),
     ui.output_ui("value_boxes"),
     ui.output_plot("scatter", fill=True),
     ui.help_text(
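The hunk above registers a new `processed_output` text slot in the UI, but none of the visible hunks adds a matching renderer inside `server()`. A minimal Shiny for Python sketch of what such a renderer could look like, assuming `process_image()` (from the hunks below) eventually returns something printable; the renderer itself is illustrative, not part of this diff:

```python
# Hypothetical renderer for the new "processed_output" slot; not part of this diff.
# It would sit inside server() next to the other outputs, and assumes that
# process_image() returns something printable (e.g. a short mask summary).
# Requires `render` from shiny, which app.py presumably already imports.
@output
@render.text
def processed_output():
    return str(process_image())
```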
@@ -70,7 +71,29 @@ def tif_bytes_to_pil_image(tif_bytes):
 
     return image
 
+def load_model():
+    """ Get Model """
+    # Load the model configuration
+    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+    # Create an instance of the model architecture with the loaded configuration
+    model = SamModel(config=model_config)
+    # Update the model by loading the weights from the saved file
+    model_state_dict = torch.load(str(dir / "checkpoint.pth"), map_location=torch.device('cpu'))
+    model.load_state_dict(model_state_dict)
+
+    # Set the device to cuda if available, otherwise use cpu
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    return model, processor
+
 def server(input: Inputs, output: Outputs, session: Session):
+
+    # Load the model and processor once
+    model, processor = load_model()
+
     @reactive.Calc
     def uploaded_image_path() -> str:
         """Returns the path to the uploaded image"""
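`load_model()` resolves the fine-tuned weights via `dir / "checkpoint.pth"`, so it depends on a module-level `dir` defined somewhere outside these hunks. A sketch of the kind of definition this usually implies; the exact name and location in app.py are assumptions:

```python
# Assumed module-level setup, not shown in this diff: resolve files relative to app.py
# so checkpoint.pth is found regardless of the working directory the Space runs in.
from pathlib import Path

dir = Path(__file__).resolve().parent  # shadows the built-in dir(); name kept to match the diff
```

Moving the loading into `server()` (instead of inside `process_image()`, as the removed code in the next hunk did) avoids re-fetching the SAM configuration and reloading the checkpoint on every image upload.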
@@ -88,34 +111,75 @@ def server(input: Inputs, output: Outputs, session: Session):
             return img
         else:
             return None  # Return None if no image is uploaded
-
+
+    @reactive.Calc
+    def generate_input_points():
+        """
+        input_points (torch.FloatTensor of shape (batch_size, num_points, 2)):
+        Input 2D spatial points; these are used by the prompt encoder to encode the prompt
+        and generally yield much better results. The points can be obtained by passing a
+        list of lists of lists to the processor, which will create the corresponding torch
+        tensors of dimension 4. The first dimension is the image batch size, the second
+        dimension is the point batch size (i.e. how many segmentation masks the model should
+        predict per input point), the third dimension is the number of points per segmentation
+        mask (it is possible to pass multiple points for a single mask), and the last dimension
+        is the x (vertical) and y (horizontal) coordinates of the point. If a different number
+        of points is passed either for each image or for each mask, the processor will create
+        "PAD" points corresponding to the (0, 0) coordinate, and the computation of the
+        embedding will be skipped for these points using the labels.
+        """
+        # Define the size of the array
+        array_size = 256
+
+        # Define the size of the grid
+        grid_size = 10
+
+        # Generate the grid points
+        x = np.linspace(0, array_size - 1, grid_size)
+        y = np.linspace(0, array_size - 1, grid_size)
+
+        # Generate a grid of coordinates
+        xv, yv = np.meshgrid(x, y)
+
+        # Convert the numpy arrays to lists
+        xv_list = xv.tolist()
+        yv_list = yv.tolist()
+
+        # Combine the x and y coordinates into a list of lists of lists
+        input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
+
+        # Reshape the nxn grid to the expected shape of the input_points tensor:
+        # (batch_size, point_batch_size, num_points_per_image, 2),
+        # where the last dimension of 2 holds the x and y coordinates of each point.
+        # batch_size: the number of images processed at once.
+        # point_batch_size: the number of point sets per image.
+        # num_points_per_image: the number of points in each set.
+        input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
+
+        return input_points
+
     def process_image():
         """Processes the uploaded image, loads the model, and evaluates to get predictions"""
+
+        """ Get Image """
         # Load the uploaded image
         uploaded_image_bytes = input.tile_image()[0].read()
 
         # Convert the uploaded TIFF bytes to a PIL Image object
         uploaded_image = tif_bytes_to_pil_image(uploaded_image_bytes)
 
-
-
-
-        # image_array = preprocess_image(uploaded_image)
-
-        # Load the model configuration
-        model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
-        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+        """ Prepare Inputs """
+        # Get the input points prompt (a grid of points)
+        input_points = generate_input_points()
 
-        #
-
-        # Update the model by loading the weights from the saved file
-        model_state_dict = torch.load(str(dir / "checkpoint.pth"), map_location=torch.device('cpu'))
-        model.load_state_dict(model_state_dict)
-
-        # Set the device to cuda if available, otherwise use cpu
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
+        # Prepare the image and prompt for the model
+        inputs = processor(uploaded_image, input_points=input_points, return_tensors="pt")
 
+        # Remove the batch dimension which the processor adds by default
+        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
+
+        """ Get Predictions """
         # Evaluate the image with the model
         # Example: predictions = model.predict(image_array)
 
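To make the reshape reasoning in `generate_input_points()` concrete, here is a standalone check of the same grid logic; it only reuses the values from the hunk (`array_size = 256`, `grid_size = 10`) and prints the tensor shape the processor will receive:

```python
# Standalone check of the grid logic used in generate_input_points() above.
import numpy as np
import torch

array_size, grid_size = 256, 10

# 10 evenly spaced coordinates along each axis, combined into a 10x10 grid
x = np.linspace(0, array_size - 1, grid_size)
y = np.linspace(0, array_size - 1, grid_size)
xv, yv = np.meshgrid(x, y)

# Nested list of [x, y] points, then reshaped to
# (batch_size, point_batch_size, num_points_per_image, 2)
points = [[[int(a), int(b)] for a, b in zip(xr, yr)]
          for xr, yr in zip(xv.tolist(), yv.tolist())]
input_points = torch.tensor(points).view(1, 1, grid_size * grid_size, 2)

print(input_points.shape)  # torch.Size([1, 1, 100, 2])
```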
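The last hunk still ends at the placeholder comment for the prediction step. A minimal sketch of how the prepared `inputs` could be pushed through the fine-tuned SAM model, following the usual transformers `SamModel` inference pattern; the helper name `predict_mask` and the 0.5 threshold are illustrative, not part of the diff:

```python
# Illustrative continuation; not part of this diff.
import torch

def predict_mask(model, inputs, threshold=0.5):
    """Run the fine-tuned SAM model on processor outputs whose batch dim was squeezed away."""
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"].unsqueeze(0).to(device),
            input_points=inputs["input_points"].unsqueeze(0).to(device),
            multimask_output=False,
        )
    # pred_masks: (batch, point_batch, num_masks, H, W) -> probability map -> binary mask
    probs = torch.sigmoid(outputs.pred_masks.squeeze(1))
    return (probs > threshold).squeeze().cpu().numpy()
```

Inside `process_image()` this would be called as `predictions = predict_mask(model, inputs)`, and the resulting binary mask (or a summary of it) is the kind of value the new `processed_output` slot could display.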