lukiod committed on
Commit 8d7600a · 1 Parent(s): 4d6bb3d
Files changed (2)
  1. app.py +575 -0
  2. config/config.yaml +111 -0
app.py ADDED
@@ -0,0 +1,575 @@
+ # -*- coding: utf-8 -*-
+ import os
+ import base64
+ import logging  # For better logging
+ # Import specific handlers and formatter
+ from logging.handlers import RotatingFileHandler
+ import traceback  # For detailed exception logging
+ from flask import Flask, request, jsonify, send_from_directory
+ from flask_cors import CORS  # To handle Cross-Origin requests from your frontend
+ import torch
+ import cv2
+ import numpy as np
+ from torchvision import transforms
+ from transformers import SegformerForSemanticSegmentation
+ from omegaconf import OmegaConf  # Import OmegaConf itself
+ import torch.nn.functional as F
+ from werkzeug.utils import secure_filename  # For safer filenames
+ from werkzeug.exceptions import NotFound  # Raised by send_from_directory for missing files
+ 
+ # --- Configuration ---
+ # Use absolute paths for robustness
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # Directory where this script is running
+ 
+ # >>> Point this to your actual config file <<<
+ CONFIG_PATH = os.path.join(BASE_DIR, "config/config.yaml")  # config.yaml lives in the config/ subdirectory
+ 
+ # >>> Point this to your actual checkpoint file <<<
+ CHECKPOINT_PATH = os.path.join(BASE_DIR, "ckpt_000-vloss_0.4685_vf1_0.6469.ckpt")
+ 
+ UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads')
+ RESULT_FOLDER = os.path.join(BASE_DIR, 'results')
+ LOG_FILE_PATH = os.path.join(BASE_DIR, 'flask_app.log')  # Define log file path
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp', 'tif', 'tiff'}
+ 
+ # --- Logging Setup ---
+ # Clear existing handlers from the root logger to avoid duplicates on reload
+ logging.getLogger().handlers.clear()
+ # Create formatter
+ log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s')
+ # Create console handler
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(log_formatter)
+ # Create file handler (RotatingFileHandler for log rotation)
+ file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=5*1024*1024, backupCount=3)
+ file_handler.setFormatter(log_formatter)
+ # Get the root logger and add handlers
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)  # Minimum level for the logger (e.g., INFO, DEBUG)
+ logger.addHandler(console_handler)
+ logger.addHandler(file_handler)
+ 
+ # --- Ensure upload and result directories exist ---
+ try:
+     os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+     os.makedirs(RESULT_FOLDER, exist_ok=True)
+     logger.info(f"Ensured directories exist: {UPLOAD_FOLDER}, {RESULT_FOLDER}")
+ except OSError as e:
+     logger.error(f"Error creating directories: {e}")
+     exit(1)  # Exit if we can't create essential folders
+ 
+ # --- Load Config ---
+ config = None
+ try:
+     # Load the YAML file using OmegaConf
+     config = OmegaConf.load(CONFIG_PATH)
+     # Note: OmegaConf.create() is not needed when loading directly from a file
+     logger.info(f"Configuration loaded successfully from: {CONFIG_PATH}")
+     # Log some key values to confirm loading
+     logger.info(f"Config check: num_classes={config.data.num_classes}, model_name={config.training.model_name}")
+ except FileNotFoundError:
+     logger.error(f"Configuration file not found: {CONFIG_PATH}")
+     exit(1)
+ except Exception as e:  # Catch broader errors during loading/parsing
+     logger.error(f"Error loading or parsing configuration file '{CONFIG_PATH}': {e}")
+     logger.error(traceback.format_exc())
+     exit(1)
+ 
+ # --- Model Definition ---
+ class InferenceModel(torch.nn.Module):
+     def __init__(self, model_config):  # Use local name 'model_config'
+         super().__init__()
+         try:
+             # Access config values needed for model init
+             model_name = model_config.training.model_name
+             num_classes = model_config.data.num_classes
+             logger.info(f"Initializing SegformerForSemanticSegmentation with model='{model_name}', num_labels={num_classes}")
+             self.model = SegformerForSemanticSegmentation.from_pretrained(
+                 model_name,
+                 num_labels=num_classes,
+                 ignore_mismatched_sizes=True  # Important if the fine-tuned head size differs
+             )
+             logger.info("Segformer model part initialized.")
+         except AttributeError as ae:
+             logger.error(f"Config error during model init: Missing key? {ae}")
+             logger.error(f"Check if 'training.model_name' and 'data.num_classes' exist in {CONFIG_PATH}")
+             raise  # Re-raise error to stop execution
+         except Exception as e:
+             logger.error(f"Error initializing Segformer model from Hugging Face: {e}")
+             logger.error(traceback.format_exc())
+             raise  # Re-raise error to stop execution
+ 
+     def forward(self, x):
+         # Expects pixel_values as input
+         outputs = self.model(pixel_values=x, return_dict=True)
+         # Upsample logits to the original input size
+         logits = F.interpolate(
+             outputs.logits,
+             size=x.shape[-2:],  # Get H, W from input tensor x
+             mode="bilinear",
+             align_corners=False
+         )
+         return logits
+ 
+ # --- Utility Functions ---
+ def num_to_rgb(num_arr, color_map_dict):
+     """Converts a label mask (numpy array) to an RGB color mask."""
+     single_layer = np.squeeze(num_arr)
+     output = np.zeros(single_layer.shape[:2] + (3,), dtype=np.uint8)  # Initialize with uint8 zeros
+ 
+     # Expects color_map_dict to be a standard Python dict {int_label: [R, G, B]}
+     if not isinstance(color_map_dict, dict):
+         logger.error(f"Invalid color_map provided to num_to_rgb: {type(color_map_dict)}. Expected dict.")
+         return np.float32(output) / 255.0  # Return black float image
+ 
+     unique_labels = np.unique(single_layer)
+     for k in unique_labels:
+         label_key = int(k)  # Ensure the key is a standard int for lookup
+         if label_key in color_map_dict:
+             # Assign color; the value must be a list/tuple of 3 ints
+             color = color_map_dict[label_key]
+             if isinstance(color, (list, tuple)) and len(color) == 3:
+                 output[single_layer == k] = color
+             else:
+                 logger.warning(f"Invalid color format for label {label_key} in color map: {color}. Skipping.")
+         else:
+             if label_key != 0:  # Often 0 is background and may not be in the map
+                 logger.warning(f"Label key {label_key} found in mask but not in provided color map.")
+             # The default color (black) is already set by np.zeros
+ 
+     return np.float32(output) / 255.0  # Return float32 RGB image [0, 1]
+ 
+ def denormalize(tensor, mean, std):
+     """Denormalizes a torch tensor (BCHW); returns an HWC float32 numpy image in [0, 1]."""
+     # Expects standard Python lists/tuples for mean/std
+     if not isinstance(mean, (list, tuple)) or not isinstance(std, (list, tuple)):
+         logger.error(f"Mean ({type(mean)}) or std ({type(std)}) are not lists/tuples in denormalize.")
+         return None
+     # Input tensor expected shape: Batch, Channel, Height, Width
+     if tensor.dim() != 4:  # B C H W
+         # Attempt to add a batch dim if it's 3D (CHW)
+         if tensor.dim() == 3:
+             logger.warning("Denormalize received 3D tensor, adding batch dimension.")
+             tensor = tensor.unsqueeze(0)
+         else:
+             logger.error(f"Unexpected tensor dimension {tensor.dim()} in denormalize. Expected 4 (BCHW).")
+             return None  # Cannot handle other dims
+ 
+     num_channels = tensor.shape[1]  # Channel dimension
+     if len(mean) != num_channels or len(std) != num_channels:
+         logger.error(f"Mean/std length ({len(mean)}/{len(std)}) mismatch with tensor channels ({num_channels})")
+         return None
+ 
+     # Clone to avoid modifying the original tensor
+     tensor = tensor.clone().cpu()  # Work on a CPU copy
+ 
+     # Denormalize each channel
+     for c in range(num_channels):
+         tensor[:, c, :, :] = tensor[:, c, :, :] * std[c] + mean[c]  # Apply to all items in the batch
+ 
+     # Clamp values, remove the batch dimension, permute to HWC for display/saving
+     # Assumes one image at a time for the inference result
+     denormalized_img_tensor = torch.clamp(tensor.squeeze(0), 0, 1).permute(1, 2, 0)
+ 
+     return denormalized_img_tensor.numpy()  # numpy array (HWC, float32, [0, 1])
+ 
+ # --- Load Model (Corrected Version) ---
+ def load_trained_model(checkpoint_path, model_config):
+     """Loads the trained model from a checkpoint, handling potential key mismatches."""
+     try:
+         model_instance = InferenceModel(model_config)  # Create model structure
+         logger.info(f"Attempting to load checkpoint from: {checkpoint_path}")
+         if not os.path.exists(checkpoint_path):
+             raise FileNotFoundError(f"Checkpoint file not found at specified path: {checkpoint_path}")
+ 
+         checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+         logger.info(f"Checkpoint loaded into memory. Type: {type(checkpoint)}")
+ 
+         # Extract the state dictionary - flexible based on common saving patterns
+         if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+             state_dict = checkpoint["state_dict"]
+             logger.info("Using 'state_dict' key from checkpoint.")
+         elif isinstance(checkpoint, dict):
+             # Assume the dict *is* the state_dict if the 'state_dict' key is absent
+             state_dict = checkpoint
+             logger.info("Using checkpoint dictionary directly as state_dict (no 'state_dict' key found).")
+         else:
+             # Could be the model itself was saved directly (less common with frameworks)
+             logger.warning(f"Checkpoint is not a dictionary. Attempting to load directly into model. Type was: {type(checkpoint)}")
+             try:
+                 model_instance.load_state_dict(checkpoint)  # Try loading directly
+                 logger.info("Loaded state_dict directly from checkpoint object.")
+                 model_instance.eval()
+                 return model_instance
+             except Exception as e:
+                 logger.error(f"Failed to load state_dict directly from checkpoint object: {e}")
+                 return None  # Failed direct load
+ 
+         # --- Key Prefix Correction Logic ---
+         target_keys = set(model_instance.state_dict().keys())
+         loaded_keys = set(state_dict.keys())
+         if not loaded_keys:  # Check if the state_dict is empty
+             logger.warning("Loaded state_dict is empty!")
+             return None
+         first_loaded_key = next(iter(loaded_keys), None)
+         first_target_key = next(iter(target_keys), None)
+         corrected_state_dict = {}
+ 
+         # Check if the 'model.' prefix needs to be ADDED to loaded keys
+         if first_loaded_key and not first_loaded_key.startswith('model.') and \
+            first_target_key and first_target_key.startswith('model.'):
+             logger.warning("Checkpoint keys missing 'model.' prefix. Attempting to add it.")
+             keys_not_prefixed_properly = []
+             for k, v in state_dict.items():
+                 new_key = f"model.{k}"
+                 if new_key in target_keys:
+                     corrected_state_dict[new_key] = v
+                 else:
+                     # Keep the original key if the prefixed version is not expected
+                     keys_not_prefixed_properly.append(k)
+                     corrected_state_dict[k] = v
+             if keys_not_prefixed_properly:
+                 logger.warning(f"Keys kept without prefix (target doesn't expect): {keys_not_prefixed_properly}")
+             logger.info("Finished attempting prefix addition.")
+         # Check if the 'model.' prefix needs to be REMOVED from loaded keys
+         elif first_loaded_key and first_loaded_key.startswith('model.') and \
+              first_target_key and not first_target_key.startswith('model.'):
+             logger.warning("Checkpoint keys HAVE 'model.' prefix, but target model DOES NOT. Attempting to remove it.")
+             keys_not_stripped_properly = []
+             for k, v in state_dict.items():
+                 if k.startswith('model.'):
+                     new_key = k.partition('model.')[2]  # Get the part after 'model.'
+                     if new_key in target_keys:
+                         corrected_state_dict[new_key] = v
+                     else:
+                         # Keep the original key if the stripped version is not expected
+                         keys_not_stripped_properly.append(k)
+                         corrected_state_dict[k] = v
+                 else:
+                     # Keep keys that didn't have the prefix anyway
+                     corrected_state_dict[k] = v
+             if keys_not_stripped_properly:
+                 logger.warning(f"Keys kept with prefix (target doesn't expect stripped): {keys_not_stripped_properly}")
+             logger.info("Finished attempting prefix removal.")
+         else:
+             logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
+             corrected_state_dict = state_dict  # Use the original dict
+ 
+         # --- Load the State Dictionary ---
+         logger.info("Attempting to load state_dict with strict=False for checking...")
+         missing_keys, unexpected_keys = model_instance.load_state_dict(corrected_state_dict, strict=False)
+ 
+         # Report detailed findings
+         final_msg = []
+         is_load_successful = True
+         if missing_keys:
+             final_msg.append(f"MISSING keys in checkpoint: {missing_keys}")
+             logger.error("CRITICAL FAILURE: Model is missing required keys.")
+             is_load_successful = False
+         if unexpected_keys:
+             final_msg.append(f"UNEXPECTED keys in checkpoint (exist in file but not in model): {unexpected_keys}")
+             # Decide if unexpected keys are acceptable
+             acceptable_unexpected = [k for k in unexpected_keys if k.endswith('num_batches_tracked')]
+             unacceptable_unexpected = [k for k in unexpected_keys if not k.endswith('num_batches_tracked')]
+             if unacceptable_unexpected:
+                 logger.error(f"CRITICAL FAILURE: Model received unacceptable unexpected keys: {unacceptable_unexpected}")
+                 is_load_successful = False
+             elif acceptable_unexpected:
+                 logger.warning(f"Ignoring acceptable unexpected keys: {acceptable_unexpected}")
+ 
+         if not is_load_successful:
+             logger.error(f"State dict loading failed. Issues: {'; '.join(final_msg)}")
+             return None  # Failed to load properly
+ 
+         logger.info(f"State dictionary loaded successfully. Issues (if any): {final_msg if final_msg else 'None'}")
+         model_instance.eval()  # Set to evaluation mode
+         logger.info(f"Model loading process complete for {checkpoint_path}")
+         return model_instance
+ 
+     except FileNotFoundError as fnf_error:
+         logger.error(f"{fnf_error}")  # Log the specific FileNotFoundError message
+         return None
+     except Exception as e:
+         logger.error(f"Unexpected error during model loading: {e}")
+         logger.error(traceback.format_exc())  # Log the full traceback
+         return None
+ 
+ 
+ # --- Determine device & Load Model Globally ---
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"Using device: {device}")
+ # Load the model using the global config object
+ model = load_trained_model(CHECKPOINT_PATH, config)  # Pass the loaded config
+ if model is None:
+     logger.critical("CRITICAL: Failed to load model. Application cannot continue.")
+     exit(1)  # Critical error, stop the application
+ model.to(device)  # Move model to the appropriate device
+ 
+ # --- Inference Pipeline (Corrected Config Handling) ---
+ def run_inference_on_bytes(image_bytes, inference_model, model_config, device):
+     """Runs inference on image bytes; returns denormalized image, color mask, and overlay."""
+     try:
+         nparr = np.frombuffer(image_bytes, np.uint8)
+         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+         if img is None:
+             logger.error("Failed cv2.imdecode.")
+             return None, None, None
+         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         logger.debug("Image decoded and converted to RGB.")
+ 
+         # --- Preprocessing (with config conversion) ---
+         # Check necessary config attributes exist before the conversion attempt
+         # (compare against None explicitly so legitimate falsy values still pass)
+         required_data_keys = ['image_size', 'mean', 'std', 'num_classes']
+         for key in required_data_keys:
+             if OmegaConf.select(model_config, f'data.{key}', default=None) is None:
+                 logger.error(f"Config missing required data field: data.{key}")
+                 return None, None, None
+         if OmegaConf.select(model_config, 'id2color', default=None) is None:
+             logger.error("Config missing required field: id2color")
+             return None, None, None
+         if OmegaConf.select(model_config, 'training.model_name', default=None) is None:
+             logger.error("Config missing required field: training.model_name")
+             return None, None, None
+ 
+         try:
+             # Convert OmegaConf structures to standard Python types using OmegaConf.to_container
+             # resolve=True handles variable interpolation (like ${data.base_dir}) if used in relevant fields
+             img_size = tuple(OmegaConf.to_container(model_config.data.image_size, resolve=True))
+             mean = list(OmegaConf.to_container(model_config.data.mean, resolve=True))
+             std = list(OmegaConf.to_container(model_config.data.std, resolve=True))
+             # Ensure keys in id2color are standard integers
+             id2color_map = {int(k): v for k, v in OmegaConf.to_container(model_config.id2color, resolve=True).items()}
+             num_classes = int(model_config.data.num_classes)  # Ensure int
+ 
+             logger.debug(f"Converted config values: size={img_size}, mean={mean}, std={std}, id2color keys={list(id2color_map.keys())}, num_classes={num_classes}")
+ 
+             # Basic validation after conversion
+             if not isinstance(mean, list) or not isinstance(std, list) or not isinstance(id2color_map, dict):
+                 raise TypeError("Config values did not convert to list/dict.")
+             if len(mean) != 3 or len(std) != 3:  # Assuming 3 color channels
+                 raise ValueError("Incorrect mean/std length. Expected 3.")
+             if len(img_size) != 2:
+                 raise ValueError("Incorrect image_size length. Expected 2 (H, W).")
+ 
+         except Exception as e:
+             logger.error(f"Error processing/converting configuration values: {e}")
+             logger.error(traceback.format_exc())
+             return None, None, None
+ 
+         # Define the image transformation pipeline
+         transform = transforms.Compose([
+             transforms.ToTensor(),  # HWC [0,255] numpy -> CHW [0,1] torch
+             transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),  # Bilinear is common for images before the model
+             transforms.Normalize(mean=mean, std=std)  # Use converted lists
+         ])
+         logger.debug(f"Image transform defined for size {img_size}.")
+         input_tensor = transform(img_rgb).unsqueeze(0).to(device)  # Add batch dim (B=1), move to device
+         logger.debug(f"Input tensor created with shape: {input_tensor.shape}")  # Should be [1, 3, H, W]
+ 
+         # --- Run Prediction ---
+         with torch.no_grad():
+             logits = inference_model(input_tensor)  # Expect [B, C, H, W] logits
+             logger.debug(f"Logits received with shape: {logits.shape}")
+             # Check logits shape again after potential upsampling in model forward
+             if logits.dim() != 4 or logits.shape[1] != num_classes:
+                 logger.error(f"Unexpected final logits shape or class number: {logits.shape}. Expected B x {num_classes} x H x W.")
+                 return None, None, None
+             # Argmax along class dimension (C), remove batch dim, move to CPU, convert type
+             pred_mask = logits.argmax(1).squeeze(0).cpu().numpy().astype(np.uint8)  # H W, uint8
+             logger.debug(f"Prediction mask generated with shape: {pred_mask.shape}")  # Should be [H, W]
+ 
+         # --- Post-processing ---
+         color_mask = num_to_rgb(pred_mask, id2color_map)  # Use the converted map
+         if color_mask is None:
+             logger.error("num_to_rgb failed.")
+             return None, None, None
+         logger.debug("Color mask generated.")
+ 
+         # Denormalize the *input tensor* for overlay display
+         denorm_img = denormalize(input_tensor, mean, std)  # Use converted mean/std
+         if denorm_img is None:
+             logger.error("denormalize failed.")
+             return None, None, None
+         logger.debug("Input tensor denormalized for overlay.")  # HWC, float32, [0,1]
+ 
+         # --- Create Overlay ---
+         # Ensure shapes match before blending (resize color mask to match denorm_img)
+         if denorm_img.shape[:2] != color_mask.shape[:2]:
+             logger.warning(f"Denorm img shape {denorm_img.shape[:2]} != Color mask shape {color_mask.shape[:2]}. Resizing color mask using INTER_NEAREST.")
+             # Resize color_mask (HWC float32) to match denorm_img (HWC float32)
+             color_mask = cv2.resize(color_mask, (denorm_img.shape[1], denorm_img.shape[0]), interpolation=cv2.INTER_NEAREST)  # INTER_NEAREST keeps label colors crisp
+ 
+         # Blend images: original (denorm_img) * alpha + mask (color_mask) * beta + gamma
+         overlay = cv2.addWeighted(denorm_img, 0.7, color_mask, 0.3, 0)
+         logger.debug("Overlay created using cv2.addWeighted.")
+         # overlay is HWC, float32, [0, 1], RGB
+ 
+         return denorm_img, color_mask, overlay
+ 
+     except Exception as e:
+         logger.error(f"Exception during inference pipeline for image: {e}")
+         logger.error(traceback.format_exc())
+         return None, None, None
+ 
+ 
+ # --- Flask App ---
+ app = Flask(__name__)
+ # Allow all origins; to restrict, pass e.g. resources={r"/api/*": {"origins": "*"}, r"/Result/*": {"origins": "*"}}
+ CORS(app)
+ logger.info("Flask app created and CORS enabled.")
+ 
+ def allowed_file(filename):
+     """Checks if the filename has an allowed extension."""
+     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+ 
+ # --- API Endpoints ---
+ @app.route('/api/analyze', methods=['POST'])
+ def analyze_image():
+     """Receives a Base64 image, runs inference, saves the original and the overlay."""
+     global model, config, device  # Access global vars
+     endpoint_log_prefix = "[POST /api/analyze]"
+     logger.info(f"{endpoint_log_prefix} Received request.")
+ 
+     # --- Basic Checks ---
+     if model is None:
+         logger.error(f"{endpoint_log_prefix} Model not loaded.")
+         return jsonify({"success": False, "message": "Model not loaded"}), 500
+     if not request.is_json:
+         logger.warning(f"{endpoint_log_prefix} Not JSON.")
+         return jsonify({"success": False, "message": "Request must be JSON"}), 400
+     data = request.get_json()
+     if not data or 'image' not in data or 'filename' not in data:
+         logger.warning(f"{endpoint_log_prefix} Missing image/filename in JSON body. Data received: {data}")
+         return jsonify({"success": False, "message": "Missing 'image' (base64) or 'filename' in JSON body"}), 400
+ 
+     base64_image_data = data['image']
+     original_filename = data['filename']
+     logger.info(f"{endpoint_log_prefix} Original filename from request: '{original_filename}'")
+     safe_original_filename = secure_filename(original_filename)  # Sanitize
+     if not safe_original_filename or not allowed_file(safe_original_filename):
+         logger.warning(f"{endpoint_log_prefix} Invalid/disallowed filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
+         return jsonify({"success": False, "message": "Invalid or disallowed filename/extension"}), 400
+     logger.info(f"{endpoint_log_prefix} Sanitized filename for saving/processing: '{safe_original_filename}'")
+ 
+     try:
+         # --- Decode Base64 (strip any data-URL header such as 'data:image/png;base64,') ---
+         if ',' in base64_image_data:
+             header, encoded = base64_image_data.split(',', 1)
+         else:
+             encoded = base64_image_data  # Assume no header
+         image_bytes = base64.b64decode(encoded)
+         logger.info(f"{endpoint_log_prefix} Base64 image decoded ({len(image_bytes)} bytes).")
+ 
+         # --- Save Original Image ---
+         original_path = os.path.join(UPLOAD_FOLDER, safe_original_filename)
+         try:
+             with open(original_path, "wb") as f:
+                 f.write(image_bytes)
+             logger.info(f"{endpoint_log_prefix} Original image saved to: '{original_path}'")
+         except Exception as e:
+             logger.error(f"{endpoint_log_prefix} Failed to save original image to '{original_path}': {e}")
+             return jsonify({"success": False, "message": "Failed to save uploaded image on server"}), 500
+ 
+         # --- Run Inference ---
+         logger.info(f"{endpoint_log_prefix} Starting inference for '{safe_original_filename}'...")
+         # Pass the global config object here
+         denorm_img, color_mask, overlay = run_inference_on_bytes(image_bytes, model, config, device)
+         if overlay is None:  # Check if inference failed
+             logger.error(f"{endpoint_log_prefix} Inference pipeline returned None for '{safe_original_filename}'.")
+             return jsonify({"success": False, "message": "Inference process failed on server"}), 500
+         logger.info(f"{endpoint_log_prefix} Inference completed successfully for '{safe_original_filename}'.")
+ 
+         # --- Save Overlay Image ---
+         name_part, ext = os.path.splitext(safe_original_filename)
+         # Create a consistent overlay filename (crucial for the toggle endpoint)
+         overlay_filename = f"analyzed_{name_part}{ext}"
+         overlay_path = os.path.join(RESULT_FOLDER, overlay_filename)
+         logger.info(f"{endpoint_log_prefix} Determined overlay filename: '{overlay_filename}' -> path: '{overlay_path}'")
+ 
+         # Convert overlay (float32 HWC RGB [0,1]) to uint8 HWC BGR [0,255] for cv2.imwrite
+         try:
+             overlay_to_save_uint8 = (overlay * 255).astype(np.uint8)
+             overlay_to_save_bgr = cv2.cvtColor(overlay_to_save_uint8, cv2.COLOR_RGB2BGR)
+             save_success = cv2.imwrite(overlay_path, overlay_to_save_bgr)
+             if not save_success:
+                 raise IOError(f"cv2.imwrite failed to save the overlay image to {overlay_path}")
+             logger.info(f"{endpoint_log_prefix} Overlay image saved successfully to: '{overlay_path}'")
+         except Exception as e:
+             logger.error(f"{endpoint_log_prefix} Failed to convert or save overlay image to '{overlay_path}': {e}")
+             logger.error(traceback.format_exc())
+             return jsonify({"success": False, "message": "Failed to save analysis result image"}), 500
+ 
+         # --- Success Response ---
+         logger.info(f"{endpoint_log_prefix} Analysis successful for '{safe_original_filename}'. Returning success.")
+         return jsonify({
+             "success": True,
+             "message": "Analysis complete",
+             # Relative paths are informational; the client mainly needs overlay_filename
+             "paths": {"original": os.path.relpath(original_path, BASE_DIR), "overlay": os.path.relpath(overlay_path, BASE_DIR)},
+             "overlay_filename": overlay_filename  # Return the *exact* filename saved
+         }), 200
+ 
+     except base64.binascii.Error as e:
+         logger.error(f"{endpoint_log_prefix} Invalid Base64 data received: {e}")
+         return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
+     except Exception as e:
+         logger.error(f"{endpoint_log_prefix} Unexpected error during analysis request processing: {e}")
+         logger.error(traceback.format_exc())
+         return jsonify({"success": False, "message": "Internal server error during analysis processing"}), 500
+ 
+ 
+ @app.route('/api/toggle-image', methods=['GET'])
+ def get_analysis_path():
+     """Checks if the analyzed version of a given original filename exists."""
+     endpoint_log_prefix = "[GET /api/toggle-image]"
+     logger.info(f"{endpoint_log_prefix} Received request.")
+     logger.info(f"{endpoint_log_prefix} Full request URL: {request.url}")
+     logger.info(f"{endpoint_log_prefix} Request Query Args: {request.args}")  # Log received args
+ 
+     original_filename = request.args.get('filename')  # Get filename from ?filename=...
+     if not original_filename:
+         logger.warning(f"{endpoint_log_prefix} Missing 'filename' query parameter.")
+         return jsonify({"message": "Missing 'filename' query parameter"}), 400
+ 
+     logger.info(f"{endpoint_log_prefix} Original filename received from query: '{original_filename}'")
+     safe_original_filename = secure_filename(original_filename)  # Sanitize
+     if not safe_original_filename:
+         logger.warning(f"{endpoint_log_prefix} Invalid filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
+         return jsonify({"message": "Invalid filename format"}), 400
+     logger.info(f"{endpoint_log_prefix} Sanitized filename for lookup: '{safe_original_filename}'")
+ 
+     # --- Construct Expected Overlay Path (MUST match /analyze logic) ---
+     name_part, ext = os.path.splitext(safe_original_filename)
+     expected_overlay_filename = f"analyzed_{name_part}{ext}"
+     expected_overlay_path = os.path.join(RESULT_FOLDER, expected_overlay_filename)
+     logger.info(f"{endpoint_log_prefix} Expecting overlay file at: '{expected_overlay_path}'")
+ 
+     # --- Check if File Exists ---
+     if os.path.exists(expected_overlay_path):
+         logger.info(f"{endpoint_log_prefix} Found analysis result file: '{expected_overlay_filename}'")
+         # Return just the filename; the client constructs the full /Result/ URL
+         return jsonify({"filepath": expected_overlay_filename}), 200
+     else:
+         # Explicitly log the path that was checked and not found
+         logger.warning(f"{endpoint_log_prefix} Analysis result file NOT FOUND at checked path: '{expected_overlay_path}'")
+         # Return a 404 Not Found status code
+         return jsonify({"message": f"Analysis result not found for '{original_filename}'"}), 404
+ 
+ 
+ @app.route('/Result/<filename>')
+ def serve_result_image(filename):
+     """Serves images from the RESULT_FOLDER."""
+     endpoint_log_prefix = "[GET /Result]"
+     # Sanitize the filename received in the URL path for security
+     safe_filename = secure_filename(filename)
+     if safe_filename != filename:
+         # Log if the requested filename was changed by sanitization
+         logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' was sanitized to '{safe_filename}'. Serving sanitized version.")
+ 
+     logger.info(f"{endpoint_log_prefix} Attempting to serve file: '{safe_filename}' from directory: '{RESULT_FOLDER}'")
+     try:
+         # Use Flask's send_from_directory - safer than manual path joining
+         # as_attachment=False means display in browser if possible
+         return send_from_directory(RESULT_FOLDER, safe_filename, as_attachment=False)
+     except NotFound:
+         # send_from_directory raises werkzeug's NotFound (not FileNotFoundError) for missing files
+         logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
+         # Return 404 Not Found
+         return jsonify({"message": "Requested analysis image not found"}), 404
+     except Exception as e:
+         # Catch other potential errors (e.g., permission issues)
+         logger.error(f"{endpoint_log_prefix} Error serving file '{safe_filename}': {e}")
+         logger.error(traceback.format_exc())
+         # Return 500 Internal Server Error
+         return jsonify({"message": "Error serving analysis image"}), 500
+ 
+ 
+ # --- Main Execution ---
+ if __name__ == '__main__':
+     # Ensure the model loaded successfully before starting the server
+     if model:
+         logger.info("Model loaded successfully. Starting Flask development server...")
+         # Use debug=True for development (auto-reload, debugger)
+         # Use debug=False for production!
+         # host='0.0.0.0' makes it accessible on the network
+         app.run(host='0.0.0.0', port=7860, debug=True)
+     else:
+         # This message should appear if load_trained_model returned None
+         logger.critical("APPLICATION FAILED TO START: MODEL COULD NOT BE LOADED.")
+         # Exit code 1 indicates an error
+         exit(1)
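
For reference, here is a minimal client sketch for exercising the three routes above. It is an illustration, not part of the commit: it assumes the server is running locally on port 7860 (matching app.run above) and that a test image named bitewing.png sits next to the script; both names are assumptions.

# client_example.py -- hypothetical client for the API above (illustration only)
import base64
import requests

API_BASE = "http://localhost:7860"   # assumed local dev server
FILENAME = "bitewing.png"            # assumed local test image

# /api/analyze expects JSON with a Base64-encoded image and its filename
with open(FILENAME, "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")

resp = requests.post(f"{API_BASE}/api/analyze",
                     json={"image": encoded, "filename": FILENAME})
resp.raise_for_status()
overlay_filename = resp.json()["overlay_filename"]  # e.g. "analyzed_bitewing.png"

# /api/toggle-image reports whether the analyzed version exists (200 or 404)
check = requests.get(f"{API_BASE}/api/toggle-image", params={"filename": FILENAME})
print(check.status_code, check.json())

# /Result/<filename> serves the saved overlay from RESULT_FOLDER
img = requests.get(f"{API_BASE}/Result/{overlay_filename}")
img.raise_for_status()
with open(overlay_filename, "wb") as f:
    f.write(img.content)
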
config/config.yaml ADDED
@@ -0,0 +1,111 @@
+ defaults:
+   - _self_
+ 
+ data:
+   num_classes: 14
+   image_size: [512, 512]
+   mean: [0.4810, 0.4810, 0.4810] # [0.4513, 0.4513, 0.4513]
+   std: [0.2492, 0.2492, 0.2492] # [0.1879, 0.1879, 0.1879]
+   background_cls_id: 0
+   split_ratio: 0.8
+   base_dir: 'E:\Seg\Test\medical_segmentation\Bitewing'
+   mask_dir: "${data.base_dir}/newAnnotations/"
+   original_dir: "${data.base_dir}/Images/"
+   dataset_path: "${hydra:runtime.cwd}/Bitwening_dataset1"
+ 
+ # Normalization settings
+ normalization:
+   always_apply: true
+ 
+ # Image resize configuration
+ resize:
+   interpolation: "INTER_NEAREST" # cv2 interpolation method
+ 
+ # Complete augmentation configuration with all parameters
+ augmentation:
+   # HorizontalFlip
+   use_horizontal_flip: true
+   horizontal_flip_prob: 0.5
+ 
+   # VerticalFlip
+   use_vertical_flip: true
+   vertical_flip_prob: 0.5
+ 
+   # ShiftScaleRotate
+   use_shift_scale_rotate: true
+   shift_scale_rotate_prob: 0.5
+   rotate_limit: 0.15
+   scale_limit: 0.12
+   shift_limit: 0.12
+   border_mode: 4 # cv2.BORDER_REFLECT_101
+ 
+   # RandomBrightnessContrast
+   use_brightness_contrast: true
+   brightness_contrast_prob: 0.5
+   brightness_limit: 0.2
+   contrast_limit: 0.2
+ 
+   # CoarseDropout
+   use_coarse_dropout: true
+   coarse_dropout:
+     max_holes: 8
+     min_holes: 5
+     max_height: 25
+     max_width: 25
+     fill_value: 0
+     mask_fill_value: 0
+     prob: 0.5
+ 
+ training:
+   batch_size: 10
+   num_epochs: 1
+   init_lr: 3e-4
+   optimizer_name: "AdamW"
+   weight_decay: 0.1
+   use_scheduler: true
+   scheduler: "MultiStepLR"
+   model_name: "nvidia/segformer-b4-finetuned-ade-512-512"
+   num_workers: 0
+   pin_memory: true
+   drop_last: true
+   shuffle_train: true
+   shuffle_valid: false
+ 
+ inference:
+   batch_size: 10
+   num_batches: 3
+ 
+ wandb:
+   project: "UM_medical_segmentation"
+   log_model: true
+ 
+ trainer:
+   accelerator: "gpu"
+   devices: "1"
+   strategy: "auto"
+   precision: "16-mixed"
+   enable_model_summary: false
+ 
+ id2color:
+   0: [0, 0, 0]        # Black
+   1: [0, 0, 255]      # Blue
+   2: [0, 255, 0]      # Green
+   3: [255, 0, 0]      # Red
+   4: [255, 255, 0]    # Yellow
+   5: [255, 165, 0]    # Orange
+   6: [128, 0, 128]    # Purple
+   7: [0, 255, 255]    # Cyan
+   8: [255, 20, 147]   # Deep Pink
+   9: [75, 0, 130]     # Indigo
+   10: [139, 69, 19]   # Saddle Brown
+   11: [255, 192, 203] # Pink
+   12: [47, 79, 79]    # Dark Slate Gray
+   13: [173, 255, 47]  # Green Yellow
+   14: [0, 128, 128]   # Teal
+ 
+ experiment:
+   name: "EXPERIMENT_1 bitewing dataset" # Descriptive name for the experiment
+   description: "Fine-tuning the pretrained model on the 14-class dataset without extra layers" # What this experiment is testing
+   goal: "Check the model's performance over 100 epochs" # What you hope to achieve
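
As a quick sanity check of how app.py consumes this file, the sketch below loads the config with OmegaConf and shows which interpolations resolve outside of Hydra. It is an assumption-laden illustration: it presumes execution from the repo root with omegaconf installed.

# config_check.py -- sketch for inspecting config/config.yaml (not part of the commit)
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/config.yaml")  # assumes the repo root as working dir

# Plain values and OmegaConf-style ${...} interpolation resolve on access:
print(cfg.data.num_classes)                                             # 14
print(list(OmegaConf.to_container(cfg.data.image_size, resolve=True)))  # [512, 512]
print(cfg.data.mask_dir)                                                # base_dir + "/newAnnotations/"

# Caveat: ${hydra:runtime.cwd} relies on Hydra's resolver, so accessing
# cfg.data.dataset_path outside a Hydra run raises an interpolation error.
# app.py sidesteps this by converting only the fields it needs
# (image_size, mean, std, id2color) instead of resolving the whole config.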