Guill-Bla committed on
Commit
78f1a97
·
verified ·
1 Parent(s): ada3b45

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +36 -22
tasks/image.py CHANGED
@@ -7,6 +7,10 @@ import random
7
  import os
8
 
9
  from torch.utils.data import DataLoader
 
 
 
 
10
 
11
  from ultralytics import YOLO
12
  from .utils.evaluation import ImageEvaluationRequest
@@ -36,22 +40,12 @@ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobile
36
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
37
  model.eval()
38
 
39
- from torch.utils.data import Dataset
40
-
41
- def preprocess(image):
42
- # Ensure input image is resized to a fixed size (512, 512)
43
- image = image.resize((512, 512))
44
-
45
- # Convert to NumPy and ensure BGR normalization
46
- image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
47
- image = np.array(image, dtype=np.float32) / 255.0
48
-
49
- # Return as a PIL Image for feature extractor compatibility
50
- return Image.fromarray((image * 255).astype(np.uint8))
51
 
52
  class SmokeDataset(torch.utils.data.Dataset):
53
- def __init__(self, dataset):
54
  self.dataset = dataset
 
 
55
 
56
  def __len__(self):
57
  return len(self.dataset)
@@ -60,15 +54,34 @@ class SmokeDataset(torch.utils.data.Dataset):
60
  example = self.dataset[idx]
61
  image = example["image"]
62
  annotation = example.get("annotations", "").strip()
63
-
64
- # Resize image and preprocess
65
- image = preprocess(image) # Apply resizing and preprocessing
66
 
67
- # Extract features with padding set to True
68
- features = feature_extractor(images=image, return_tensors="pt", padding=True)
69
-
70
- # Return pixel values directly as tensors
71
- return features.pixel_values.squeeze(0), annotation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  def preprocess_batch(images):
@@ -178,7 +191,8 @@ async def evaluate_image(request: ImageEvaluationRequest):
178
  # Update the code below to replace the random baseline with your model inference
179
  #--------------------------------------------------------------------------------------------
180
  smoke_dataset = SmokeDataset(test_dataset)
181
- dataloader = DataLoader(smoke_dataset, batch_size=16, shuffle=False)
 
182
 
183
  predictions = []
184
  true_labels = []
 
7
  import os
8
 
9
  from torch.utils.data import DataLoader
10
+ from torch.utils.data import Dataset
11
+ from PIL import Image
12
+ import torch
13
+
14
 
15
  from ultralytics import YOLO
16
  from .utils.evaluation import ImageEvaluationRequest
 
40
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
41
  model.eval()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  class SmokeDataset(torch.utils.data.Dataset):
45
+ def __init__(self, dataset, feature_extractor, target_size=(224, 224)):
46
  self.dataset = dataset
47
+ self.feature_extractor = feature_extractor
48
+ self.target_size = target_size
49
 
50
  def __len__(self):
51
  return len(self.dataset)
 
54
  example = self.dataset[idx]
55
  image = example["image"]
56
  annotation = example.get("annotations", "").strip()
 
 
 
57
 
58
+ # Ensure image is resized to a fixed target size using PIL
59
+ if isinstance(image, torch.Tensor):
60
+ image = Image.fromarray(image.numpy())
61
+ resized_image = image.resize(self.target_size, Image.ANTIALIAS)
62
+
63
+ # Process image using feature extractor
64
+ features = self.feature_extractor(images=resized_image, return_tensors="pt").pixel_values
65
+
66
+ return features.squeeze(0), annotation
67
+
68
+
69
+ def collate_fn(batch):
70
+ images, annotations = zip(*batch)
71
+ images = torch.stack(images) # Ensure batch has uniform shape
72
+ return images, annotations
73
+
74
+
75
+ def preprocess(image):
76
+ # Ensure input image is resized to a fixed size (512, 512)
77
+ image = image.resize((512, 512))
78
+
79
+ # Convert to NumPy and ensure BGR normalization
80
+ image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
81
+ image = np.array(image, dtype=np.float32) / 255.0
82
+
83
+ # Return as a PIL Image for feature extractor compatibility
84
+ return Image.fromarray((image * 255).astype(np.uint8))
85
 
86
 
87
  def preprocess_batch(images):
 
191
  # Update the code below to replace the random baseline with your model inference
192
  #--------------------------------------------------------------------------------------------
193
  smoke_dataset = SmokeDataset(test_dataset)
194
+ # dataloader = DataLoader(smoke_dataset, batch_size=16, shuffle=False)
195
+ dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
196
 
197
  predictions = []
198
  true_labels = []