Spaces:
Runtime error
Runtime error
tt
Browse files- app.py +575 -0
- config/config.yaml +111 -0
app.py
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import os
|
3 |
+
import base64
|
4 |
+
import io
|
5 |
+
import logging # For better logging
|
6 |
+
# Import specific handlers and formatter
|
7 |
+
from logging.handlers import RotatingFileHandler
|
8 |
+
import traceback # For detailed exception logging
|
9 |
+
from flask import Flask, request, jsonify, send_from_directory
|
10 |
+
from flask_cors import CORS # To handle Cross-Origin requests from your frontend
|
11 |
+
import torch
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import yaml
|
15 |
+
from torchvision import transforms
|
16 |
+
from transformers import SegformerForSemanticSegmentation
|
17 |
+
from omegaconf import OmegaConf # Import OmegaConf itself
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from werkzeug.utils import secure_filename # For safer filenames
|
20 |
+
|
21 |
+
# --- Configuration ---
|
22 |
+
# Use absolute paths for robustness
|
23 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # Directory where this script is running
|
24 |
+
|
25 |
+
# >>> Point this to your actual config file <<<
|
26 |
+
CONFIG_PATH = os.path.join(BASE_DIR, "config/config.yaml") # Assuming config.yaml is in the same dir
|
27 |
+
|
28 |
+
# >>> Point this to your actual checkpoint file <<<
|
29 |
+
CHECKPOINT_PATH = "ckpt_000-vloss_0.4685_vf1_0.6469.ckpt"
|
30 |
+
|
31 |
+
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads')
|
32 |
+
RESULT_FOLDER = os.path.join(BASE_DIR, 'results')
|
33 |
+
LOG_FILE_PATH = os.path.join(BASE_DIR, 'flask_app.log') # Define log file path
|
34 |
+
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp', 'tif', 'tiff'}
|
35 |
+
|
36 |
+
# --- Logging Setup ---
|
37 |
+
# Clear existing handlers from the root logger to avoid duplicates on reload
|
38 |
+
logging.getLogger().handlers.clear()
|
39 |
+
# Create formatter
|
40 |
+
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s')
|
41 |
+
# Create Console Handler
|
42 |
+
console_handler = logging.StreamHandler()
|
43 |
+
console_handler.setFormatter(log_formatter)
|
44 |
+
# Create File Handler (using RotatingFileHandler for log rotation)
|
45 |
+
file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=5*1024*1024, backupCount=3)
|
46 |
+
file_handler.setFormatter(log_formatter)
|
47 |
+
# Get the root logger and add handlers
|
48 |
+
logger = logging.getLogger()
|
49 |
+
logger.setLevel(logging.INFO) # Set minimum level for the logger (e.g., INFO, DEBUG)
|
50 |
+
logger.addHandler(console_handler)
|
51 |
+
logger.addHandler(file_handler)
|
52 |
+
|
53 |
+
# --- Ensure upload and result directories exist ---
|
54 |
+
try:
|
55 |
+
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
56 |
+
os.makedirs(RESULT_FOLDER, exist_ok=True)
|
57 |
+
logger.info(f"Ensured directories exist: {UPLOAD_FOLDER}, {RESULT_FOLDER}")
|
58 |
+
except OSError as e:
|
59 |
+
logger.error(f"Error creating directories: {e}")
|
60 |
+
exit(1) # Exit if we can't create essential folders
|
61 |
+
|
62 |
+
# --- Load Config ---
|
63 |
+
config = None
|
64 |
+
try:
|
65 |
+
# Load the YAML file using OmegaConf
|
66 |
+
config = OmegaConf.load(CONFIG_PATH)
|
67 |
+
# Note: We don't need OmegaConf.create() if loading directly from file
|
68 |
+
logger.info(f"Configuration loaded successfully from: {CONFIG_PATH}")
|
69 |
+
# Log some key values to confirm loading
|
70 |
+
logger.info(f"Config check: num_classes={config.data.num_classes}, model_name={config.training.model_name}")
|
71 |
+
except FileNotFoundError:
|
72 |
+
logger.error(f"Configuration file not found: {CONFIG_PATH}")
|
73 |
+
exit(1)
|
74 |
+
except Exception as e: # Catch broader errors during loading/parsing
|
75 |
+
logger.error(f"Error loading or parsing configuration file '{CONFIG_PATH}': {e}")
|
76 |
+
logger.error(traceback.format_exc())
|
77 |
+
exit(1)
|
78 |
+
|
79 |
+
# --- Model Definition ---
|
80 |
+
class InferenceModel(torch.nn.Module):
|
81 |
+
def __init__(self, model_config): # Use local name 'model_config'
|
82 |
+
super().__init__()
|
83 |
+
try:
|
84 |
+
# Access config values needed for model init
|
85 |
+
model_name = model_config.training.model_name
|
86 |
+
num_classes = model_config.data.num_classes
|
87 |
+
logger.info(f"Initializing SegformerForSemanticSegmentation with model='{model_name}', num_labels={num_classes}")
|
88 |
+
self.model = SegformerForSemanticSegmentation.from_pretrained(
|
89 |
+
model_name,
|
90 |
+
num_labels=num_classes,
|
91 |
+
ignore_mismatched_sizes=True # Important if fine-tuning head size differs
|
92 |
+
)
|
93 |
+
logger.info("Segformer model part initialized.")
|
94 |
+
except AttributeError as ae:
|
95 |
+
logger.error(f"Config error during model init: Missing key? {ae}")
|
96 |
+
logger.error(f"Check if 'training.model_name' and 'data.num_classes' exist in {CONFIG_PATH}")
|
97 |
+
raise # Re-raise error to stop execution
|
98 |
+
except Exception as e:
|
99 |
+
logger.error(f"Error initializing Segformer model from Hugging Face: {e}")
|
100 |
+
logger.error(traceback.format_exc())
|
101 |
+
raise # Re-raise error to stop execution
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
# Expects pixel_values as input
|
105 |
+
outputs = self.model(pixel_values=x, return_dict=True)
|
106 |
+
# Upsample logits to original input size
|
107 |
+
logits = F.interpolate(
|
108 |
+
outputs.logits,
|
109 |
+
size=x.shape[-2:], # Get H, W from input tensor x
|
110 |
+
mode="bilinear",
|
111 |
+
align_corners=False
|
112 |
+
)
|
113 |
+
return logits
|
114 |
+
|
115 |
+
# --- Utility Functions ---
|
116 |
+
def num_to_rgb(num_arr, color_map_dict):
|
117 |
+
"""Converts a label mask (numpy array) to an RGB color mask."""
|
118 |
+
single_layer = np.squeeze(num_arr)
|
119 |
+
output = np.zeros(num_arr.shape[:2] + (3,), dtype=np.uint8) # Initialize with uint8 zeros
|
120 |
+
|
121 |
+
# Expects color_map_dict to be a standard Python dict {int_label: [R, G, B]}
|
122 |
+
if not isinstance(color_map_dict, dict):
|
123 |
+
logger.error(f"Invalid color_map provided to num_to_rgb: {type(color_map_dict)}. Expected dict.")
|
124 |
+
return np.float32(output) / 255.0 # Return black float image
|
125 |
+
|
126 |
+
unique_labels = np.unique(single_layer)
|
127 |
+
for k in unique_labels:
|
128 |
+
label_key = int(k) # Ensure key is standard int for lookup
|
129 |
+
if label_key in color_map_dict:
|
130 |
+
# Assign color, ensure color value is appropriate (e.g., list/tuple of 3 ints)
|
131 |
+
color = color_map_dict[label_key]
|
132 |
+
if isinstance(color, (list, tuple)) and len(color) == 3:
|
133 |
+
output[single_layer == k] = color
|
134 |
+
else:
|
135 |
+
logger.warning(f"Invalid color format for label {label_key} in color map: {color}. Skipping.")
|
136 |
+
else:
|
137 |
+
if label_key != 0: # Often 0 is background, might not be in map
|
138 |
+
logger.warning(f"Label Key {label_key} found in mask but not in provided color map.")
|
139 |
+
# Default color (e.g., black) is already set by np.zeros
|
140 |
+
|
141 |
+
return np.float32(output) / 255.0 # Return float32 RGB image [0, 1]
|
142 |
+
|
143 |
+
def denormalize(tensor, mean, std):
|
144 |
+
"""Denormalizes a torch tensor (CHW format)."""
|
145 |
+
# Expects standard Python lists/tuples for mean/std
|
146 |
+
if not isinstance(mean, (list, tuple)) or not isinstance(std, (list, tuple)):
|
147 |
+
logger.error(f"Mean ({type(mean)}) or std ({type(std)}) are not lists/tuples in denormalize.")
|
148 |
+
return None
|
149 |
+
# Input tensor expected shape: Batch, Channel, Height, Width (e.g., from dataloader or transform)
|
150 |
+
if tensor.dim() != 4: # B C H W
|
151 |
+
logger.error(f"Unexpected tensor dimension {tensor.dim()} in denormalize. Expected 4 (BCHW).")
|
152 |
+
# Attempt to add batch dim if it's 3D (CHW)
|
153 |
+
if tensor.dim() == 3:
|
154 |
+
logger.warning("Denormalize received 3D tensor, adding batch dimension.")
|
155 |
+
tensor = tensor.unsqueeze(0)
|
156 |
+
else:
|
157 |
+
return None # Cannot handle other dims
|
158 |
+
|
159 |
+
num_channels = tensor.shape[1] # Channel dimension
|
160 |
+
if len(mean) != num_channels or len(std) != num_channels:
|
161 |
+
logger.error(f"Mean/std length ({len(mean)}/{len(std)}) mismatch with tensor channels ({num_channels})")
|
162 |
+
return None
|
163 |
+
|
164 |
+
# Clone to avoid modifying original tensor
|
165 |
+
tensor = tensor.clone().cpu() # Work on CPU copy
|
166 |
+
|
167 |
+
# Denormalize each channel
|
168 |
+
for c in range(num_channels):
|
169 |
+
tensor[:, c, :, :] = tensor[:, c, :, :] * std[c] + mean[c] # Apply to all items in batch
|
170 |
+
|
171 |
+
# Clamp values, remove batch dimension, permute to HWC for display/saving
|
172 |
+
# Assumes we are processing one image at a time here for inference result
|
173 |
+
denormalized_img_tensor = torch.clamp(tensor.squeeze(0), 0, 1).permute(1, 2, 0)
|
174 |
+
|
175 |
+
return denormalized_img_tensor.numpy() # Convert to numpy array (HWC, float32, [0,1])
|
176 |
+
|
177 |
+
# --- Load Model (Corrected Version) ---
|
178 |
+
def load_trained_model(checkpoint_path, model_config):
|
179 |
+
"""Loads the trained model from a checkpoint, handling potential key mismatches."""
|
180 |
+
try:
|
181 |
+
model_instance = InferenceModel(model_config) # Create model structure
|
182 |
+
logger.info(f"Attempting to load checkpoint from: {checkpoint_path}")
|
183 |
+
if not os.path.exists(checkpoint_path):
|
184 |
+
raise FileNotFoundError(f"Checkpoint file not found at specified path: {checkpoint_path}")
|
185 |
+
|
186 |
+
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
187 |
+
logger.info(f"Checkpoint loaded into memory. Type: {type(checkpoint)}")
|
188 |
+
|
189 |
+
# Extract the state dictionary - flexible based on common saving patterns
|
190 |
+
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
191 |
+
state_dict = checkpoint["state_dict"]
|
192 |
+
logger.info("Using 'state_dict' key from checkpoint.")
|
193 |
+
elif isinstance(checkpoint, dict):
|
194 |
+
# Assume the dict *is* the state_dict if 'state_dict' key is absent
|
195 |
+
state_dict = checkpoint
|
196 |
+
logger.info("Using checkpoint dictionary directly as state_dict (no 'state_dict' key found).")
|
197 |
+
else:
|
198 |
+
# Could be the model itself was saved directly (less common with frameworks)
|
199 |
+
logger.warning(f"Checkpoint is not a dictionary. Attempting to load directly into model (less common). Type was: {type(checkpoint)}")
|
200 |
+
# This path might need adjustment based on how the model was saved if not a state_dict
|
201 |
+
try:
|
202 |
+
model_instance.load_state_dict(checkpoint) # Try loading directly
|
203 |
+
logger.info("Loaded state_dict directly from checkpoint object.")
|
204 |
+
model_instance.eval()
|
205 |
+
return model_instance
|
206 |
+
except Exception as e:
|
207 |
+
logger.error(f"Failed to load state_dict directly from checkpoint object: {e}")
|
208 |
+
return None # Failed direct load
|
209 |
+
|
210 |
+
# --- Key Prefix Correction Logic ---
|
211 |
+
target_keys = set(model_instance.state_dict().keys())
|
212 |
+
loaded_keys = set(state_dict.keys())
|
213 |
+
if not loaded_keys: logger.warning("Loaded state_dict is empty!"); return None # Check if state_dict is empty
|
214 |
+
first_loaded_key = next(iter(loaded_keys), None)
|
215 |
+
first_target_key = next(iter(target_keys), None)
|
216 |
+
corrected_state_dict = {}
|
217 |
+
prefix_added = False
|
218 |
+
|
219 |
+
# Check if prefix 'model.' needs to be ADDED to loaded keys
|
220 |
+
if first_loaded_key and not first_loaded_key.startswith('model.') and \
|
221 |
+
first_target_key and first_target_key.startswith('model.'):
|
222 |
+
logger.warning("Checkpoint keys missing 'model.' prefix. Attempting to add it.")
|
223 |
+
prefix_added = True
|
224 |
+
keys_not_prefixed_properly = []
|
225 |
+
for k, v in state_dict.items():
|
226 |
+
new_key = f"model.{k}"
|
227 |
+
if new_key in target_keys: corrected_state_dict[new_key] = v
|
228 |
+
else: keys_not_prefixed_properly.append(k); corrected_state_dict[k] = v # Keep original if prefixed version not wanted
|
229 |
+
if keys_not_prefixed_properly: logger.warning(f"Keys kept without prefix (target doesn't expect): {keys_not_prefixed_properly}")
|
230 |
+
logger.info("Finished attempting prefix addition.")
|
231 |
+
# Check if prefix 'model.' needs to be REMOVED from loaded keys
|
232 |
+
elif first_loaded_key and first_loaded_key.startswith('model.') and \
|
233 |
+
first_target_key and not first_target_key.startswith('model.'):
|
234 |
+
logger.warning("Checkpoint keys HAVE 'model.' prefix, but target model DOES NOT. Attempting to remove it.")
|
235 |
+
prefix_added = False # Indicate we removed prefix, not added
|
236 |
+
keys_not_stripped_properly = []
|
237 |
+
for k, v in state_dict.items():
|
238 |
+
if k.startswith('model.'):
|
239 |
+
new_key = k.partition('model.')[2] # Get part after 'model.'
|
240 |
+
if new_key in target_keys: corrected_state_dict[new_key] = v
|
241 |
+
else: keys_not_stripped_properly.append(k); corrected_state_dict[k] = v # Keep original if stripped version not wanted
|
242 |
+
else:
|
243 |
+
# Keep keys that didn't have prefix anyway
|
244 |
+
corrected_state_dict[k] = v
|
245 |
+
if keys_not_stripped_properly: logger.warning(f"Keys kept with prefix (target doesn't expect stripped): {keys_not_stripped_properly}")
|
246 |
+
logger.info("Finished attempting prefix removal.")
|
247 |
+
else:
|
248 |
+
logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
|
249 |
+
corrected_state_dict = state_dict # Use the original dict
|
250 |
+
|
251 |
+
# --- Load the State Dictionary ---
|
252 |
+
logger.info("Attempting to load state_dict with strict=False for checking...")
|
253 |
+
missing_keys, unexpected_keys = model_instance.load_state_dict(corrected_state_dict, strict=False)
|
254 |
+
|
255 |
+
# Report detailed findings
|
256 |
+
final_msg = []
|
257 |
+
is_load_successful = True
|
258 |
+
if missing_keys:
|
259 |
+
final_msg.append(f"MISSING keys in checkpoint: {missing_keys}")
|
260 |
+
logger.error("CRITICAL FAILURE: Model is missing required keys.")
|
261 |
+
is_load_successful = False
|
262 |
+
if unexpected_keys:
|
263 |
+
final_msg.append(f"UNEXPECTED keys in checkpoint (exist in file but not in model): {unexpected_keys}")
|
264 |
+
# Decide if unexpected keys are acceptable
|
265 |
+
acceptable_unexpected = [k for k in unexpected_keys if k.endswith('num_batches_tracked')]
|
266 |
+
unacceptable_unexpected = [k for k in unexpected_keys if not k.endswith('num_batches_tracked')]
|
267 |
+
if unacceptable_unexpected:
|
268 |
+
logger.error(f"CRITICAL FAILURE: Model received unacceptable unexpected keys: {unacceptable_unexpected}")
|
269 |
+
is_load_successful = False
|
270 |
+
elif acceptable_unexpected:
|
271 |
+
logger.warning(f"Ignoring acceptable unexpected keys: {acceptable_unexpected}")
|
272 |
+
|
273 |
+
if not is_load_successful:
|
274 |
+
logger.error(f"State dict loading failed. Issues: {'; '.join(final_msg)}")
|
275 |
+
return None # Failed to load properly
|
276 |
+
|
277 |
+
logger.info(f"State dictionary loaded successfully. Issues (if any): {final_msg if final_msg else 'None'}")
|
278 |
+
model_instance.eval() # Set to evaluation mode
|
279 |
+
logger.info(f"Model loading process complete for {checkpoint_path}")
|
280 |
+
return model_instance
|
281 |
+
|
282 |
+
except FileNotFoundError as fnf_error:
|
283 |
+
logger.error(f"{fnf_error}") # Log the specific FileNotFoundError message
|
284 |
+
return None
|
285 |
+
except Exception as e:
|
286 |
+
logger.error(f"Unexpected error during model loading: {e}")
|
287 |
+
logger.error(traceback.format_exc()) # Log full traceback
|
288 |
+
return None
|
289 |
+
|
290 |
+
|
291 |
+
# --- Determine device & Load Model Globally ---
|
292 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
293 |
+
logger.info(f"Using device: {device}")
|
294 |
+
# Load the model using the global config object
|
295 |
+
model = load_trained_model(CHECKPOINT_PATH, config) # Pass the loaded config
|
296 |
+
if model is None:
|
297 |
+
logger.critical("CRITICAL: Failed to load model. Application cannot continue.")
|
298 |
+
exit(1) # Critical error, stop the application
|
299 |
+
model.to(device) # Move model to the appropriate device
|
300 |
+
|
301 |
+
# --- Inference Pipeline (Corrected Config Handling) ---
|
302 |
+
def run_inference_on_bytes(image_bytes, inference_model, model_config, device):
|
303 |
+
"""Runs inference on image bytes, returns denormalized image, color mask, and overlay."""
|
304 |
+
try:
|
305 |
+
nparr = np.frombuffer(image_bytes, np.uint8)
|
306 |
+
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
307 |
+
if img is None: logger.error("Failed cv2.imdecode."); return None, None, None
|
308 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
309 |
+
logger.debug("Image decoded and converted to RGB.")
|
310 |
+
|
311 |
+
# --- Preprocessing (with config conversion) ---
|
312 |
+
# Check necessary config attributes exist before conversion attempt
|
313 |
+
required_data_keys = ['image_size', 'mean', 'std', 'num_classes']
|
314 |
+
for key in required_data_keys:
|
315 |
+
if not OmegaConf.select(model_config, f'data.{key}', default=None):
|
316 |
+
logger.error(f"Config missing required data field: data.{key}")
|
317 |
+
return None, None, None
|
318 |
+
if not OmegaConf.select(model_config, 'id2color', default=None):
|
319 |
+
logger.error("Config missing required field: id2color")
|
320 |
+
return None, None, None
|
321 |
+
if not OmegaConf.select(model_config, 'training.model_name', default=None):
|
322 |
+
logger.error("Config missing required field: training.model_name")
|
323 |
+
return None, None, None
|
324 |
+
|
325 |
+
try:
|
326 |
+
# Convert OmegaConf structures to standard Python types using OmegaConf.to_container
|
327 |
+
# resolve=True handles variable interpolation (like ${data.base_dir}) if used in relevant fields
|
328 |
+
img_size = tuple(OmegaConf.to_container(model_config.data.image_size, resolve=True))
|
329 |
+
mean = list(OmegaConf.to_container(model_config.data.mean, resolve=True))
|
330 |
+
std = list(OmegaConf.to_container(model_config.data.std, resolve=True))
|
331 |
+
# Ensure keys in id2color are standard integers
|
332 |
+
id2color_map = {int(k): v for k, v in OmegaConf.to_container(model_config.id2color, resolve=True).items()}
|
333 |
+
num_classes = int(model_config.data.num_classes) # Ensure int
|
334 |
+
|
335 |
+
logger.debug(f"Converted config values: size={img_size}, mean={mean}, std={std}, id2color keys={list(id2color_map.keys())}, num_classes={num_classes}")
|
336 |
+
|
337 |
+
# Basic validation after conversion
|
338 |
+
if not isinstance(mean, list) or not isinstance(std, list) or not isinstance(id2color_map, dict): raise TypeError("Config values did not convert to list/dict.")
|
339 |
+
if len(mean) != 3 or len(std) != 3: raise ValueError(f"Incorrect mean/std length. Expected 3.") # Assuming 3 color channels
|
340 |
+
if len(img_size) != 2: raise ValueError(f"Incorrect image_size length. Expected 2 (H, W).")
|
341 |
+
|
342 |
+
except Exception as e:
|
343 |
+
logger.error(f"Error processing/converting configuration values: {e}")
|
344 |
+
logger.error(traceback.format_exc())
|
345 |
+
return None, None, None
|
346 |
+
|
347 |
+
# Define the image transformation pipeline
|
348 |
+
transform = transforms.Compose([
|
349 |
+
transforms.ToTensor(), # HWC [0,255] numpy -> CHW [0,1] torch
|
350 |
+
transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR), # Use converted tuple size, BILINEAR is common for images before model
|
351 |
+
transforms.Normalize(mean=mean, std=std) # Use converted lists
|
352 |
+
])
|
353 |
+
logger.debug(f"Image transform applied for size {img_size}.")
|
354 |
+
input_tensor = transform(img_rgb).unsqueeze(0).to(device) # Add batch dim (B=1), move to device
|
355 |
+
logger.debug(f"Input tensor created with shape: {input_tensor.shape}") # Should be [1, 3, H, W]
|
356 |
+
|
357 |
+
# --- Run Prediction ---
|
358 |
+
with torch.no_grad():
|
359 |
+
logits = inference_model(input_tensor) # Expect [B, C, H, W] logits
|
360 |
+
logger.debug(f"Logits received with shape: {logits.shape}")
|
361 |
+
# Check logits shape again after potential upsampling in model forward
|
362 |
+
if logits.dim() != 4 or logits.shape[1] != num_classes:
|
363 |
+
logger.error(f"Unexpected final logits shape or class number: {logits.shape}. Expected B x {num_classes} x H x W.")
|
364 |
+
return None, None, None
|
365 |
+
# Argmax along class dimension (C), remove batch dim, move to CPU, convert type
|
366 |
+
pred_mask = logits.argmax(1).squeeze(0).cpu().numpy().astype(np.uint8) # H W, uint8
|
367 |
+
logger.debug(f"Prediction mask generated with shape: {pred_mask.shape}") # Should be [H, W]
|
368 |
+
|
369 |
+
# --- Post-processing ---
|
370 |
+
color_mask = num_to_rgb(pred_mask, id2color_map) # Use converted map
|
371 |
+
if color_mask is None: logger.error("num_to_rgb failed."); return None, None, None
|
372 |
+
logger.debug("Color mask generated.")
|
373 |
+
|
374 |
+
# Denormalize the *input tensor* for overlay display
|
375 |
+
denorm_img = denormalize(input_tensor, mean, std) # Use converted mean/std
|
376 |
+
if denorm_img is None: logger.error("denormalize failed."); return None, None, None
|
377 |
+
logger.debug("Input tensor denormalized for overlay.") # HWC, float32, [0,1]
|
378 |
+
|
379 |
+
# --- Create Overlay ---
|
380 |
+
# Ensure shapes match before blending (resize color mask to match denorm_img)
|
381 |
+
if denorm_img.shape[:2] != color_mask.shape[:2]:
|
382 |
+
logger.warning(f"Denorm img shape {denorm_img.shape[:2]} != Color mask shape {color_mask.shape[:2]}. Resizing color mask using INTER_NEAREST.")
|
383 |
+
# Resize color_mask (HWC float32) to match denorm_img (HWC float32)
|
384 |
+
color_mask = cv2.resize(color_mask, (denorm_img.shape[1], denorm_img.shape[0]), interpolation=cv2.INTER_NEAREST) # Use INTER_NEAREST for label masks
|
385 |
+
|
386 |
+
# Blend images: Original (denorm_img) * alpha + Mask (color_mask) * beta + gamma
|
387 |
+
overlay = cv2.addWeighted(denorm_img, 0.7, color_mask, 0.3, 0)
|
388 |
+
logger.debug("Overlay created using cv2.addWeighted.")
|
389 |
+
# overlay is HWC, float32, [0, 1], RGB
|
390 |
+
|
391 |
+
return denorm_img, color_mask, overlay
|
392 |
+
|
393 |
+
except Exception as e:
|
394 |
+
logger.error(f"Exception during inference pipeline for image: {e}")
|
395 |
+
logger.error(traceback.format_exc())
|
396 |
+
return None, None, None
|
397 |
+
|
398 |
+
|
399 |
+
# --- Flask App ---
|
400 |
+
app = Flask(__name__)
|
401 |
+
CORS(app) # Allow all origins for API and Result routes resources={r"/api/*": {"origins": "*"}, r"/Result/*": {"origins": "*"}}
|
402 |
+
logger.info("Flask app created and CORS enabled.")
|
403 |
+
|
404 |
+
def allowed_file(filename):
|
405 |
+
"""Checks if the filename has an allowed extension."""
|
406 |
+
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
407 |
+
|
408 |
+
# --- API Endpoints ---
|
409 |
+
@app.route('/api/analyze', methods=['POST'])
|
410 |
+
def analyze_image():
|
411 |
+
"""Receives Base64 image, runs inference, saves original and overlay."""
|
412 |
+
global model, config, device # Access global vars
|
413 |
+
endpoint_log_prefix = "[POST /api/analyze]"
|
414 |
+
logger.info(f"{endpoint_log_prefix} Received request.")
|
415 |
+
|
416 |
+
# --- Basic Checks ---
|
417 |
+
if model is None: logger.error(f"{endpoint_log_prefix} Model not loaded."); return jsonify({"success": False, "message": "Model not loaded"}), 500
|
418 |
+
if not request.is_json: logger.warning(f"{endpoint_log_prefix} Not JSON."); return jsonify({"success": False, "message": "Request must be JSON"}), 400
|
419 |
+
data = request.get_json()
|
420 |
+
if not data or 'image' not in data or 'filename' not in data:
|
421 |
+
logger.warning(f"{endpoint_log_prefix} Missing image/filename in JSON body. Data received: {data}")
|
422 |
+
return jsonify({"success": False, "message": "Missing 'image' (base64) or 'filename' in JSON body"}), 400
|
423 |
+
|
424 |
+
base64_image_data = data['image']; original_filename = data['filename']
|
425 |
+
logger.info(f"{endpoint_log_prefix} Original filename from request: '{original_filename}'")
|
426 |
+
safe_original_filename = secure_filename(original_filename) # Sanitize
|
427 |
+
if not safe_original_filename or not allowed_file(safe_original_filename):
|
428 |
+
logger.warning(f"{endpoint_log_prefix} Invalid/disallowed filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
|
429 |
+
return jsonify({"success": False, "message": "Invalid or disallowed filename/extension"}), 400
|
430 |
+
logger.info(f"{endpoint_log_prefix} Sanitized filename for saving/processing: '{safe_original_filename}'")
|
431 |
+
|
432 |
+
try:
|
433 |
+
# --- Decode Base64 ---
|
434 |
+
if ',' in base64_image_data: header, encoded = base64_image_data.split(',', 1)
|
435 |
+
else: encoded = base64_image_data # Assume no header
|
436 |
+
image_bytes = base64.b64decode(encoded)
|
437 |
+
logger.info(f"{endpoint_log_prefix} Base64 image decoded ({len(image_bytes)} bytes).")
|
438 |
+
|
439 |
+
# --- Save Original Image ---
|
440 |
+
original_path = os.path.join(UPLOAD_FOLDER, safe_original_filename)
|
441 |
+
try:
|
442 |
+
with open(original_path, "wb") as f: f.write(image_bytes)
|
443 |
+
logger.info(f"{endpoint_log_prefix} Original image saved to: '{original_path}'")
|
444 |
+
except Exception as e:
|
445 |
+
logger.error(f"{endpoint_log_prefix} Failed to save original image to '{original_path}': {e}")
|
446 |
+
return jsonify({"success": False, "message": "Failed to save uploaded image on server"}), 500
|
447 |
+
|
448 |
+
# --- Run Inference ---
|
449 |
+
logger.info(f"{endpoint_log_prefix} Starting inference for '{safe_original_filename}'...")
|
450 |
+
# Pass the global config object here
|
451 |
+
denorm_img, color_mask, overlay = run_inference_on_bytes(image_bytes, model, config, device)
|
452 |
+
if overlay is None: # Check if inference failed
|
453 |
+
logger.error(f"{endpoint_log_prefix} Inference pipeline returned None for '{safe_original_filename}'.")
|
454 |
+
return jsonify({"success": False, "message": "Inference process failed on server"}), 500
|
455 |
+
logger.info(f"{endpoint_log_prefix} Inference completed successfully for '{safe_original_filename}'.")
|
456 |
+
|
457 |
+
# --- Save Overlay Image ---
|
458 |
+
name_part, ext = os.path.splitext(safe_original_filename)
|
459 |
+
# Create consistent overlay filename (crucial for toggle endpoint)
|
460 |
+
overlay_filename = f"analyzed_{name_part}{ext}"
|
461 |
+
overlay_path = os.path.join(RESULT_FOLDER, overlay_filename)
|
462 |
+
logger.info(f"{endpoint_log_prefix} Determined overlay filename: '{overlay_filename}' -> path: '{overlay_path}'")
|
463 |
+
|
464 |
+
# Convert overlay (float32 HWC RGB [0,1]) to uint8 HWC BGR [0,255] for cv2.imwrite
|
465 |
+
try:
|
466 |
+
overlay_to_save_uint8 = (overlay * 255).astype(np.uint8)
|
467 |
+
overlay_to_save_bgr = cv2.cvtColor(overlay_to_save_uint8, cv2.COLOR_RGB2BGR)
|
468 |
+
save_success = cv2.imwrite(overlay_path, overlay_to_save_bgr)
|
469 |
+
if not save_success:
|
470 |
+
raise IOError(f"cv2.imwrite failed to save the overlay image to {overlay_path}")
|
471 |
+
logger.info(f"{endpoint_log_prefix} Overlay image saved successfully to: '{overlay_path}'")
|
472 |
+
except Exception as e:
|
473 |
+
logger.error(f"{endpoint_log_prefix} Failed to convert or save overlay image to '{overlay_path}': {e}")
|
474 |
+
logger.error(traceback.format_exc())
|
475 |
+
return jsonify({"success": False, "message": "Failed to save analysis result image"}), 500
|
476 |
+
|
477 |
+
# --- Success Response ---
|
478 |
+
logger.info(f"{endpoint_log_prefix} Analysis successful for '{safe_original_filename}'. Returning success.")
|
479 |
+
return jsonify({
|
480 |
+
"success": True,
|
481 |
+
"message": "Analysis complete",
|
482 |
+
# Optionally return relative paths for info, client mainly needs overlay_filename
|
483 |
+
"paths": {"original": os.path.relpath(original_path, BASE_DIR), "overlay": os.path.relpath(overlay_path, BASE_DIR)},
|
484 |
+
"overlay_filename": overlay_filename # Return the *exact* filename saved
|
485 |
+
}), 200
|
486 |
+
|
487 |
+
except base64.binascii.Error as e:
|
488 |
+
logger.error(f"{endpoint_log_prefix} Invalid Base64 data received: {e}")
|
489 |
+
return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
|
490 |
+
except Exception as e:
|
491 |
+
logger.error(f"{endpoint_log_prefix} Unexpected error during analysis request processing: {e}")
|
492 |
+
logger.error(traceback.format_exc())
|
493 |
+
return jsonify({"success": False, "message": "Internal server error during analysis processing"}), 500
|
494 |
+
|
495 |
+
|
496 |
+
@app.route('/api/toggle-image', methods=['GET'])
|
497 |
+
def get_analysis_path():
|
498 |
+
"""Checks if the analyzed version of a given original filename exists."""
|
499 |
+
endpoint_log_prefix = "[GET /api/toggle-image]"
|
500 |
+
logger.info(f"{endpoint_log_prefix} Received request.")
|
501 |
+
logger.info(f"{endpoint_log_prefix} Full request URL: {request.url}")
|
502 |
+
logger.info(f"{endpoint_log_prefix} Request Query Args: {request.args}") # Log received args
|
503 |
+
|
504 |
+
original_filename = request.args.get('filename') # Get filename from ?filename=...
|
505 |
+
if not original_filename:
|
506 |
+
logger.warning(f"{endpoint_log_prefix} Missing 'filename' query parameter.")
|
507 |
+
return jsonify({"message": "Missing 'filename' query parameter"}), 400
|
508 |
+
|
509 |
+
logger.info(f"{endpoint_log_prefix} Original filename received from query: '{original_filename}'")
|
510 |
+
safe_original_filename = secure_filename(original_filename) # Sanitize
|
511 |
+
if not safe_original_filename:
|
512 |
+
logger.warning(f"{endpoint_log_prefix} Invalid filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
|
513 |
+
return jsonify({"message": "Invalid filename format"}), 400
|
514 |
+
logger.info(f"{endpoint_log_prefix} Sanitized filename for lookup: '{safe_original_filename}'")
|
515 |
+
|
516 |
+
# --- Construct Expected Overlay Path (MUST match /analyze logic) ---
|
517 |
+
name_part, ext = os.path.splitext(safe_original_filename)
|
518 |
+
expected_overlay_filename = f"analyzed_{name_part}{ext}"
|
519 |
+
expected_overlay_path = os.path.join(RESULT_FOLDER, expected_overlay_filename)
|
520 |
+
logger.info(f"{endpoint_log_prefix} Expecting overlay file at: '{expected_overlay_path}'")
|
521 |
+
|
522 |
+
# --- Check if File Exists ---
|
523 |
+
if os.path.exists(expected_overlay_path):
|
524 |
+
logger.info(f"{endpoint_log_prefix} Found analysis result file: '{expected_overlay_filename}'")
|
525 |
+
# Return just the filename, client constructs the full /Result/ URL
|
526 |
+
return jsonify({"filepath": expected_overlay_filename}), 200
|
527 |
+
else:
|
528 |
+
# Explicitly log the path that was checked and not found
|
529 |
+
logger.warning(f"{endpoint_log_prefix} Analysis result file NOT FOUND at checked path: '{expected_overlay_path}'")
|
530 |
+
# Return 404 Not Found status code
|
531 |
+
return jsonify({"message": f"Analysis result not found for '{original_filename}'"}), 404
|
532 |
+
|
533 |
+
|
534 |
+
@app.route('/Result/<filename>')
|
535 |
+
def serve_result_image(filename):
|
536 |
+
"""Serves images from the RESULT_FOLDER."""
|
537 |
+
endpoint_log_prefix = "[GET /Result]"
|
538 |
+
# Sanitize filename received in URL path for security
|
539 |
+
safe_filename = secure_filename(filename)
|
540 |
+
if safe_filename != filename:
|
541 |
+
# Log if the requested filename was changed by sanitization
|
542 |
+
logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' was sanitized to '{safe_filename}'. Serving sanitized version.")
|
543 |
+
|
544 |
+
logger.info(f"{endpoint_log_prefix} Attempting to serve file: '{safe_filename}' from directory: '{RESULT_FOLDER}'")
|
545 |
+
try:
|
546 |
+
# Use Flask's send_from_directory - safer than manual path joining
|
547 |
+
# as_attachment=False means display in browser if possible
|
548 |
+
return send_from_directory(RESULT_FOLDER, safe_filename, as_attachment=False)
|
549 |
+
except FileNotFoundError:
|
550 |
+
# Log the specific file that was not found
|
551 |
+
logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
|
552 |
+
# Return 404 Not Found
|
553 |
+
return jsonify({"message": "Requested analysis image not found"}), 404
|
554 |
+
except Exception as e:
|
555 |
+
# Catch other potential errors (e.g., permission issues)
|
556 |
+
logger.error(f"{endpoint_log_prefix} Error serving file '{safe_filename}': {e}")
|
557 |
+
logger.error(traceback.format_exc())
|
558 |
+
# Return 500 Internal Server Error
|
559 |
+
return jsonify({"message": "Error serving analysis image"}), 500
|
560 |
+
|
561 |
+
|
562 |
+
# --- Main Execution ---
|
563 |
+
if __name__ == '__main__':
|
564 |
+
# Ensure model loaded successfully before starting server
|
565 |
+
if model:
|
566 |
+
logger.info("Model loaded successfully. Starting Flask development server...")
|
567 |
+
# Use debug=True for development (auto-reload, debugger)
|
568 |
+
# Use debug=False for production!
|
569 |
+
# host='0.0.0.0' makes it accessible on the network
|
570 |
+
app.run(host='0.0.0.0', port=7860, debug=True)
|
571 |
+
else:
|
572 |
+
# This message should appear if load_trained_model returned None
|
573 |
+
logger.critical("APPLICATION FAILED TO START: MODEL COULD NOT BE LOADED.")
|
574 |
+
# Exit code 1 indicates an error
|
575 |
+
exit(1)
|
config/config.yaml
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- _self_
|
3 |
+
|
4 |
+
data:
|
5 |
+
num_classes: 14
|
6 |
+
image_size: [512, 512]
|
7 |
+
mean: [0.4810, 0.4810, 0.4810] #[0.4513, 0.4513, 0.4513]
|
8 |
+
std: [0.2492, 0.2492, 0.2492] #[0.1879, 0.1879, 0.1879]
|
9 |
+
background_cls_id: 0
|
10 |
+
split_ratio: 0.8
|
11 |
+
base_dir: 'E:\Seg\Test\medical_segmentation\Bitewing'
|
12 |
+
mask_dir: "${data.base_dir}/newAnnotations/"
|
13 |
+
original_dir: "${data.base_dir}/Images/"
|
14 |
+
dataset_path: "${hydra:runtime.cwd}/Bitwening_dataset1"
|
15 |
+
|
16 |
+
# Normalization settings
|
17 |
+
normalization:
|
18 |
+
always_apply: true
|
19 |
+
|
20 |
+
# Image resize configuration
|
21 |
+
resize:
|
22 |
+
interpolation: "INTER_NEAREST" # CV2 interpolation method
|
23 |
+
|
24 |
+
# Complete augmentation configuration with all parameters
|
25 |
+
augmentation:
|
26 |
+
# HorizontalFlip
|
27 |
+
use_horizontal_flip: true
|
28 |
+
horizontal_flip_prob: 0.5
|
29 |
+
|
30 |
+
# VerticalFlip
|
31 |
+
use_vertical_flip: true
|
32 |
+
vertical_flip_prob: 0.5
|
33 |
+
|
34 |
+
# ShiftScaleRotate
|
35 |
+
use_shift_scale_rotate: true
|
36 |
+
shift_scale_rotate_prob: 0.5
|
37 |
+
rotate_limit: 0.15
|
38 |
+
scale_limit: 0.12
|
39 |
+
shift_limit: 0.12
|
40 |
+
border_mode: 4 # cv2.BORDER_REFLECT_101
|
41 |
+
|
42 |
+
# RandomBrightnessContrast
|
43 |
+
use_brightness_contrast: true
|
44 |
+
brightness_contrast_prob: 0.5
|
45 |
+
brightness_limit: 0.2
|
46 |
+
contrast_limit: 0.2
|
47 |
+
|
48 |
+
# CoarseDropout
|
49 |
+
use_coarse_dropout: true
|
50 |
+
coarse_dropout:
|
51 |
+
max_holes: 8
|
52 |
+
min_holes: 5
|
53 |
+
max_height: 25
|
54 |
+
max_width: 25
|
55 |
+
fill_value: 0
|
56 |
+
mask_fill_value: 0
|
57 |
+
prob: 0.5
|
58 |
+
|
59 |
+
training:
|
60 |
+
batch_size: 10
|
61 |
+
num_epochs: 1
|
62 |
+
init_lr: 3e-4
|
63 |
+
optimizer_name: "AdamW"
|
64 |
+
weight_decay: 0.1
|
65 |
+
use_scheduler: true
|
66 |
+
scheduler: "MultiStepLR"
|
67 |
+
model_name: "nvidia/segformer-b4-finetuned-ade-512-512"
|
68 |
+
num_workers: 0
|
69 |
+
pin_memory: true
|
70 |
+
drop_last: true
|
71 |
+
shuffle_train: true
|
72 |
+
shuffle_valid: false
|
73 |
+
|
74 |
+
inference:
|
75 |
+
batch_size: 10
|
76 |
+
num_batches: 3
|
77 |
+
|
78 |
+
wandb:
|
79 |
+
project: "UM_medical_segmentation"
|
80 |
+
log_model: true
|
81 |
+
|
82 |
+
trainer:
|
83 |
+
accelerator: "gpu"
|
84 |
+
devices: "1"
|
85 |
+
strategy: "auto"
|
86 |
+
precision: "16-mixed"
|
87 |
+
enable_model_summary: false
|
88 |
+
|
89 |
+
id2color:
|
90 |
+
0: [0, 0, 0] # Black
|
91 |
+
1: [0, 0, 255] # Blue
|
92 |
+
2: [0, 255, 0] # Green
|
93 |
+
3: [255, 0, 0] # Red
|
94 |
+
4: [255, 255, 0] # Yellow
|
95 |
+
5: [255, 165, 0] # Orange
|
96 |
+
6: [128, 0, 128] # Purple
|
97 |
+
7: [0, 255, 255] # Cyan
|
98 |
+
8: [255, 20, 147] # Deep Pink
|
99 |
+
9: [75, 0, 130] # Indigo
|
100 |
+
10: [139, 69, 19] # Saddle Brown
|
101 |
+
11: [255, 192, 203] # Pink
|
102 |
+
12: [47, 79, 79] # Dark Slate Gray
|
103 |
+
13: [173, 255, 47] # Green Yellow
|
104 |
+
14: [0, 128, 128] # Teal
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
experiment:
|
109 |
+
name: "EXPERIMENT_1 bitewing dataset" # Descriptive name for the experiment
|
110 |
+
description: "Testing on 14 classes dataset with no layers and just fine tuning the model " # What this experiment is testing
|
111 |
+
goal: "to check the performance of the model on 100 epochs" # What you hope to achieve
|