henry000 commited on
Commit
7daf6f0
·
1 Parent(s): 3b24b35

♻️ [Refactor] move to_tensor to transform

Browse files
yolo/__init__.py CHANGED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from yolo.config.config import Config
2
+ from yolo.model.yolo import create_model
3
+ from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
+ from yolo.tools.drawer import draw_bboxes
5
+ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
+ from yolo.utils.bounding_box_utils import bbox_nms
7
+ from yolo.utils.deploy_utils import FastModelLoader
8
+ from yolo.utils.logging_utils import custom_logger
9
+
10
+ all = [
11
+ "create_model",
12
+ "Config",
13
+ "custom_logger",
14
+ "validate_log_directory",
15
+ "draw_bboxes",
16
+ "bbox_nms",
17
+ "AugmentationComposer",
18
+ "create_dataloader",
19
+ "FastModelLoader",
20
+ "ModelTester",
21
+ "ModelTrainer",
22
+ "ModelValidator",
23
+ ]
yolo/tools/data_augmentation.py CHANGED
@@ -21,6 +21,7 @@ class AugmentationComposer:
21
  for transform in self.transforms:
22
  image, boxes = transform(image, boxes)
23
  image, boxes = self.pad_resize(image, boxes)
 
24
  return image, boxes
25
 
26
 
 
21
  for transform in self.transforms:
22
  image, boxes = transform(image, boxes)
23
  image, boxes = self.pad_resize(image, boxes)
24
+ image = TF.to_tensor(image)
25
  return image, boxes
26
 
27
 
yolo/tools/data_loader.py CHANGED
@@ -151,9 +151,7 @@ class YoloDataset(Dataset):
151
 
152
  def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
153
  img, bboxes = self.get_data(idx)
154
- if self.transform:
155
- img, bboxes = self.transform(img, bboxes)
156
- img = TF.to_tensor(img)
157
  return img, bboxes
158
 
159
  def __len__(self) -> int:
@@ -269,7 +267,7 @@ class StreamDataLoader:
269
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
270
  frame = Image.fromarray(frame)
271
  frame, _ = self.transform(frame, torch.zeros(0, 5))
272
- frame = TF.to_tensor(frame)[None]
273
  if not self.is_stream:
274
  self.queue.put(frame)
275
  else:
 
151
 
152
  def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
153
  img, bboxes = self.get_data(idx)
154
+ img, bboxes = self.transform(img, bboxes)
 
 
155
  return img, bboxes
156
 
157
  def __len__(self) -> int:
 
267
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
268
  frame = Image.fromarray(frame)
269
  frame, _ = self.transform(frame, torch.zeros(0, 5))
270
+ frame = frame[None]
271
  if not self.is_stream:
272
  self.queue.put(frame)
273
  else: