hujiecpp committed
Commit 399850e · Parent: ba148f1

init project
app.py CHANGED
@@ -2,7 +2,7 @@ import os
  import sys
  sys.path.append(os.path.abspath('./modules'))
 
- import math
+ # import math
  import tempfile
  import gradio
  import torch
@@ -11,23 +11,23 @@ import numpy as np
  import functools
  import trimesh
  import copy
- from PIL import Image
+ # from PIL import Image
  from scipy.spatial.transform import Rotation
 
  from modules.pe3r.images import Images
 
  from modules.dust3r.inference import inference
  from modules.dust3r.image_pairs import make_pairs
- from modules.dust3r.utils.image import load_images, rgb
+ from modules.dust3r.utils.image import load_images #, rgb
  from modules.dust3r.utils.device import to_numpy
  from modules.dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
  from modules.dust3r.cloud_opt import global_aligner, GlobalAlignerMode
- from copy import deepcopy
- import cv2
- from typing import Any, Dict, Generator,List
- import matplotlib.pyplot as pl
+ # from copy import deepcopy
+ # import cv2
+ # from typing import Any, Dict, Generator,List
+ # import matplotlib.pyplot as pl
 
- from modules.mobilesamv2.utils.transforms import ResizeLongestSide
+ # from modules.mobilesamv2.utils.transforms import ResizeLongestSide
  # from modules.pe3r.models import Models
  import torchvision.transforms as tvf
 
@@ -447,7 +447,7 @@ def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False,
  # return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
 
 
- @spaces.GPU(duration=60)
+ @spaces.GPU(duration=30)
  def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_conf_thr=3.0,
                              as_pointcloud=True, mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05,
                              scenegraph_type='complete', winsize=1, refid=0):
@@ -541,7 +541,7 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
      torch.cuda.empty_cache()
      return scene, outfile
 
- # @spaces.GPU(duration=60)
+ # @spaces.GPU(duration=30)
  # def get_3D_object_from_scene(outdir, text, threshold, scene, min_conf_thr, as_pointcloud,
  #                              mask_sky, clean_depth, transparent_cams, cam_size):
@@ -561,65 +561,36 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
  #                              clean_depth, transparent_cams, cam_size)
  #     return outfile
 
- with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
-     recon_fun = functools.partial(get_reconstructed_scene, tmpdirname)
-     # model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
-     # get_3D_object_from_scene_fun = functools.partial(get_3D_object_from_scene, tmpdirname)
-
-     with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="PE3R Demo") as demo:
-         # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
-         scene = gradio.State(None)
-         gradio.HTML('<h2 style="text-align: center;">PE3R Demo</h2>')
-         with gradio.Column():
-             inputfiles = gradio.File(file_count="multiple")
-             # with gradio.Row():
-             #     schedule = gradio.Dropdown(["linear", "cosine"],
-             #                                value='linear', label="schedule", info="For global alignment!",
-             #                                visible=False)
-             #     niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
-             #                           label="num_iterations", info="For global alignment!",
-             #                           visible=False)
-             #     scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
-             #                                        ("swin: sliding window", "swin"),
-             #                                        ("oneref: match one image with all", "oneref")],
-             #                                       value='complete', label="Scenegraph",
-             #                                       info="Define how to make pairs",
-             #                                       interactive=True,
-             #                                       visible=False)
-             #     winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
-             #                             minimum=1, maximum=1, step=1, visible=False)
-             #     refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
-
-             run_btn = gradio.Button("Reconstruct")
-
-             # with gradio.Row():
-             # adjust the confidence threshold
-             # min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
-             # adjust the camera size in the output pointcloud
-             # cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001, visible=False)
-             # with gradio.Row():
-             # as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud", visible=False)
-             # two post process implemented
-             # mask_sky = gradio.Checkbox(value=False, label="Mask sky", visible=False)
-             # clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
-             # transparent_cams = gradio.Checkbox(value=True, label="Transparent cameras", visible=False)
-
-             with gradio.Row():
-                 text_input = gradio.Textbox(label="Query Text")
-                 threshold = gradio.Slider(label="Threshold", value=0.85, minimum=0.0, maximum=1.0, step=0.01)
-
-             find_btn = gradio.Button("Find")
-
-             outmodel = gradio.Model3D()
-             # events
-
-             run_btn.click(fn=recon_fun,
-                           inputs=[inputfiles],
-                           outputs=[scene, outmodel])  # , outgallery
-
-             # find_btn.click(fn=get_3D_object_from_scene_fun,
-             #                inputs=[text_input, threshold, scene, min_conf_thr, as_pointcloud, mask_sky,
-             #                        clean_depth, transparent_cams, cam_size],
-             #                outputs=outmodel)
-     demo.launch(show_error=True, share=None, server_name=None, server_port=None)
+ tmpdirname = tempfile.mkdtemp(suffix='pe3r_gradio_demo')
+
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname)
+ # model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
+ # get_3D_object_from_scene_fun = functools.partial(get_3D_object_from_scene, tmpdirname)
+
+ with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="PE3R Demo") as demo:
+     # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
+     scene = gradio.State(None)
+     gradio.HTML('<h2 style="text-align: center;">PE3R Demo</h2>')
+     with gradio.Column():
+         inputfiles = gradio.File(file_count="multiple")
+
+         run_btn = gradio.Button("Reconstruct")
+
+         with gradio.Row():
+             text_input = gradio.Textbox(label="Query Text")
+             threshold = gradio.Slider(label="Threshold", value=0.85, minimum=0.0, maximum=1.0, step=0.01)
+
+         find_btn = gradio.Button("Find")
+
+         outmodel = gradio.Model3D()
+         # events
+
+         run_btn.click(fn=recon_fun,
+                       inputs=[inputfiles],
+                       outputs=[scene, outmodel])  # , outgallery
+
+         # find_btn.click(fn=get_3D_object_from_scene_fun,
+         #                inputs=[text_input, threshold, scene, min_conf_thr, as_pointcloud, mask_sky,
+         #                        clean_depth, transparent_cams, cam_size],
+         #                outputs=outmodel)
+ demo.launch(show_error=True, share=None, server_name=None, server_port=None)
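Note on this last hunk: `tempfile.TemporaryDirectory` removes the directory as soon as its `with` block exits, while `tempfile.mkdtemp` only creates the directory and leaves cleanup to the caller, so files written there stay readable for the app's whole lifetime. The same commit also halves the ZeroGPU slot requested per call (`@spaces.GPU(duration=60)` → `duration=30`). A minimal sketch of the tempdir difference (the `suffix='demo'` name is illustrative):

    import os
    import tempfile

    # Context manager: the directory is deleted the moment the block exits,
    # so anything written there afterwards would dangle.
    with tempfile.TemporaryDirectory(suffix='demo') as d:
        kept = d
    print(os.path.exists(kept))   # False -- already cleaned up

    # mkdtemp only creates the directory; cleanup is the caller's job,
    # so output files written by later callbacks remain on disk.
    kept = tempfile.mkdtemp(suffix='demo')
    print(os.path.exists(kept))   # True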
 
modules/dust3r/__pycache__/inference.cpython-312.pyc CHANGED
Binary files a/modules/dust3r/__pycache__/inference.cpython-312.pyc and b/modules/dust3r/__pycache__/inference.cpython-312.pyc differ
 
modules/dust3r/inference.py CHANGED
@@ -41,12 +41,12 @@ def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, u
      if symmetrize_batch:
          view1, view2 = make_batch_symmetric(batch)
 
-     with torch.cuda.amp.autocast(enabled=bool(use_amp)):
-         pred1, pred2 = model(view1, view2)
+     # with torch.cuda.amp.autocast(enabled=bool(use_amp)):
+     pred1, pred2 = model(view1, view2)
 
      # loss is supposed to be symmetric
-     with torch.cuda.amp.autocast(enabled=False):
-         loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None
+     # with torch.cuda.amp.autocast(enabled=False):
+     loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None
 
      result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
      return result[ret] if ret else result
 
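Note on this hunk: commenting out the `torch.cuda.amp.autocast` wrappers disables mixed precision, so the forward pass and the loss both run in the model's native dtype instead of letting eligible ops downcast to half precision. A minimal sketch of the behavioral difference (assumes a CUDA device; the tiny `Linear` model is a stand-in, using the same autocast API as the code above):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device='cuda')

    # Before this commit: the forward ran under autocast, so matmuls downcast.
    with torch.cuda.amp.autocast(enabled=True):
        y = model(x)
    print(y.dtype)   # torch.float16

    # After this commit: no autocast wrapper, outputs keep the parameters' dtype.
    y = model(x)
    print(y.dtype)   # torch.float32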
modules/dust3r/utils/image.py.bak DELETED
@@ -1,163 +0,0 @@
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
- #
- # --------------------------------------------------------
- # utilitary functions about images (loading/converting...)
- # --------------------------------------------------------
- import os
- import torch
- import numpy as np
- import PIL.Image
- from PIL.ImageOps import exif_transpose
- import torchvision.transforms as tvf
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
- import cv2  # noqa
-
- try:
-     from pillow_heif import register_heif_opener  # noqa
-     register_heif_opener()
-     heif_support_enabled = True
- except ImportError:
-     heif_support_enabled = False
-
- ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-
-
- def img_to_arr( img ):
-     if isinstance(img, str):
-         img = imread_cv2(img)
-     return img
-
- def imread_cv2(path, options=cv2.IMREAD_COLOR):
-     """ Open an image or a depthmap with opencv-python.
-     """
-     if path.endswith(('.exr', 'EXR')):
-         options = cv2.IMREAD_ANYDEPTH
-     img = cv2.imread(path, options)
-     if img is None:
-         raise IOError(f'Could not load image={path} with {options=}')
-     if img.ndim == 3:
-         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-     return img
-
-
- def rgb(ftensor, true_shape=None):
-     if isinstance(ftensor, list):
-         return [rgb(x, true_shape=true_shape) for x in ftensor]
-     if isinstance(ftensor, torch.Tensor):
-         ftensor = ftensor.detach().cpu().numpy()  # H,W,3
-     if ftensor.ndim == 3 and ftensor.shape[0] == 3:
-         ftensor = ftensor.transpose(1, 2, 0)
-     elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
-         ftensor = ftensor.transpose(0, 2, 3, 1)
-     if true_shape is not None:
-         H, W = true_shape
-         ftensor = ftensor[:H, :W]
-     if ftensor.dtype == np.uint8:
-         img = np.float32(ftensor) / 255
-     else:
-         img = (ftensor * 0.5) + 0.5
-     return img.clip(min=0, max=1)
-
-
- def _resize_pil_image(img, long_edge_size):
-     S = max(img.size)
-     if S > long_edge_size:
-         interp = PIL.Image.LANCZOS
-     elif S <= long_edge_size:
-         interp = PIL.Image.BICUBIC
-     new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
-     return img.resize(new_size, interp)
-
-
- def load_images(folder_or_list, cog_seg_maps, size, square_ok=False, verbose=True):
-     """ open and convert all images in a list or folder to proper input format for DUSt3R
-     """
-     if isinstance(folder_or_list, str):
-         if verbose:
-             print(f'>> Loading images from {folder_or_list}')
-         root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
-
-     elif isinstance(folder_or_list, list):
-         if verbose:
-             print(f'>> Loading a list of {len(folder_or_list)} images')
-         root, folder_content = '', folder_or_list
-
-     else:
-         raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
-
-     supported_images_extensions = ['.jpg', '.jpeg', '.png']
-     if heif_support_enabled:
-         supported_images_extensions += ['.heic', '.heif']
-     supported_images_extensions = tuple(supported_images_extensions)
-
-     imgs = []
-     for path in enumerate(folder_content):
-         if not path.lower().endswith(supported_images_extensions):
-             continue
-         img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
-
-         W1, H1 = img.size
-         if size == 224:
-             # resize short side to 224 (then crop)
-             img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
-         else:
-             # resize long side to 512
-             img = _resize_pil_image(img, size)
-
-         W, H = img.size
-         cx, cy = W//2, H//2
-         if size == 224:
-             half = min(cx, cy)
-             img = img.crop((cx-half, cy-half, cx+half, cy+half))
-         else:
-             halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
-             if not (square_ok) and W == H:
-                 halfh = 3*halfw/4
-             img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
-
-         W2, H2 = img.size
-         if verbose:
-             print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
-         imgs.append(dict(img=img, ori_img=ImgNorm(img)[None], true_shape=np.int32(
-             [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
-
-     mean_colors = {}
-     mean_colors_cnt = {}
-     for i in range(len(imgs)):
-         img_np = imgs[i]['img']
-         seg_map = cog_seg_maps[i]
-         unique_labels = np.unique(seg_map)
-         for label in unique_labels:
-             if label == -1:
-                 continue
-             mask = (seg_map == label)
-             mean_color = img_np[mask].mean(axis=0)
-             if label in mean_colors.keys():
-                 mean_colors[label] += mean_color
-                 mean_colors_cnt[label] += 1
-             else:
-                 mean_colors[label] = mean_color
-                 mean_colors_cnt[label] = 1
-     for key in mean_colors.keys():
-         mean_colors[key] /= mean_colors_cnt[key]
-
-     for i in range(len(imgs)):
-         img_np = np.array(imgs[i]['img'])
-         smoothed_image = np.zeros_like(img_np)
-         seg_map = cog_seg_maps[i]
-         unique_labels = np.unique(seg_map)
-         for label in unique_labels:
-             if label == -1:
-                 continue
-             mask = (seg_map == label)
-             mean_color = mean_colors[label]
-             smoothed_image[mask] = mean_color
-         smoothed_image = cv2.addWeighted(img_np, 0.1, smoothed_image, 0.9, 0)
-         smoothed_image = PIL.Image.fromarray(smoothed_image)
-         imgs[i]['img'] = ImgNorm(smoothed_image)[None]
-
-     assert imgs, 'no images foud at '+root
-     if verbose:
-         print(f' (Found {len(imgs)} images)')
-     return imgs
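Worth noting why this backup could be deleted safely: its image loop reads `for path in enumerate(folder_content)`, so `path` is an `(index, name)` tuple and `path.lower()` raises `AttributeError` on the first item; the smoothing pass below it also boolean-masks `imgs[i]['img']` while that entry still holds a PIL image, which would likewise fail. A minimal reproduction of the first bug, with hypothetical file names:

    folder_content = ['a.jpg', 'b.png']

    # The .bak form: each item is an (index, name) tuple, so string methods fail.
    for path in enumerate(folder_content):
        print(type(path))          # <class 'tuple'>
        # path.lower()             # AttributeError: 'tuple' object has no attribute 'lower'

    # The working form, which also keeps the index needed for cog_seg_maps[i]:
    for i, path in enumerate(folder_content):
        print(i, path.lower())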
modules/dust3r/utils/image.py.ori DELETED
@@ -1,143 +0,0 @@
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
- #
- # --------------------------------------------------------
- # utilitary functions about images (loading/converting...)
- # --------------------------------------------------------
- import os
- import torch
- import numpy as np
- import PIL.Image
- from PIL.ImageOps import exif_transpose
- import torchvision.transforms as tvf
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
- import cv2  # noqa
-
- try:
-     from pillow_heif import register_heif_opener  # noqa
-     register_heif_opener()
-     heif_support_enabled = True
- except ImportError:
-     heif_support_enabled = False
-
- ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-
-
- def img_to_arr( img ):
-     if isinstance(img, str):
-         img = imread_cv2(img)
-     return img
-
- def imread_cv2(path, options=cv2.IMREAD_COLOR):
-     """ Open an image or a depthmap with opencv-python.
-     """
-     if path.endswith(('.exr', 'EXR')):
-         options = cv2.IMREAD_ANYDEPTH
-     img = cv2.imread(path, options)
-     if img is None:
-         raise IOError(f'Could not load image={path} with {options=}')
-     if img.ndim == 3:
-         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-     return img
-
-
- def rgb(ftensor, true_shape=None):
-     if isinstance(ftensor, list):
-         return [rgb(x, true_shape=true_shape) for x in ftensor]
-     if isinstance(ftensor, torch.Tensor):
-         ftensor = ftensor.detach().cpu().numpy()  # H,W,3
-     if ftensor.ndim == 3 and ftensor.shape[0] == 3:
-         ftensor = ftensor.transpose(1, 2, 0)
-     elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
-         ftensor = ftensor.transpose(0, 2, 3, 1)
-     if true_shape is not None:
-         H, W = true_shape
-         ftensor = ftensor[:H, :W]
-     if ftensor.dtype == np.uint8:
-         img = np.float32(ftensor) / 255
-     else:
-         img = (ftensor * 0.5) + 0.5
-     return img.clip(min=0, max=1)
-
-
- def _resize_pil_image(img, long_edge_size):
-     S = max(img.size)
-     if S > long_edge_size:
-         interp = PIL.Image.LANCZOS
-     elif S <= long_edge_size:
-         interp = PIL.Image.BICUBIC
-     new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
-     return img.resize(new_size, interp)
-
-
- def load_images(folder_or_list, cog_seg_maps, size, square_ok=False, verbose=True):
-     """ open and convert all images in a list or folder to proper input format for DUSt3R
-     """
-     if isinstance(folder_or_list, str):
-         if verbose:
-             print(f'>> Loading images from {folder_or_list}')
-         root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
-
-     elif isinstance(folder_or_list, list):
-         if verbose:
-             print(f'>> Loading a list of {len(folder_or_list)} images')
-         root, folder_content = '', folder_or_list
-
-     else:
-         raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
-
-     supported_images_extensions = ['.jpg', '.jpeg', '.png']
-     if heif_support_enabled:
-         supported_images_extensions += ['.heic', '.heif']
-     supported_images_extensions = tuple(supported_images_extensions)
-
-     imgs = []
-     for i, path in enumerate(folder_content):
-         if not path.lower().endswith(supported_images_extensions):
-             continue
-         img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
-
-         img_np = np.array(img)
-         smoothed_image = np.zeros_like(img_np)
-         seg_map = cog_seg_maps[i]
-         unique_labels = np.unique(seg_map)
-         for label in unique_labels:
-             mask = (seg_map == label)
-             mean_color = img_np[mask].mean(axis=0)
-             smoothed_image[mask] = mean_color
-         smoothed_image = cv2.addWeighted(img_np, 0.05, smoothed_image, 0.95, 0)
-         smoothed_image = PIL.Image.fromarray(smoothed_image)
-
-         W1, H1 = img.size
-         if size == 224:
-             # resize short side to 224 (then crop)
-             img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
-             smoothed_image = _resize_pil_image(smoothed_image, round(size * max(W1/H1, H1/W1)))
-         else:
-             # resize long side to 512
-             img = _resize_pil_image(img, size)
-             smoothed_image = _resize_pil_image(smoothed_image, size)
-
-         W, H = img.size
-         cx, cy = W//2, H//2
-         if size == 224:
-             half = min(cx, cy)
-             img = img.crop((cx-half, cy-half, cx+half, cy+half))
-             smoothed_image = smoothed_image.crop((cx-half, cy-half, cx+half, cy+half))
-         else:
-             halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
-             if not (square_ok) and W == H:
-                 halfh = 3*halfw/4
-             img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
-             smoothed_image = smoothed_image.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
-
-         W2, H2 = img.size
-         if verbose:
-             print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
-         imgs.append(dict(img=ImgNorm(smoothed_image)[None], ori_img=ImgNorm(img)[None], true_shape=np.int32(
-             [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
-
-     assert imgs, 'no images foud at '+root
-     if verbose:
-         print(f' (Found {len(imgs)} images)')
-     return imgs
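The two deleted variants differ mainly in where the segment smoothing happens: image.py.ori blends each image with its own per-segment mean colors before resizing (weights 0.05/0.95, no label filtering), while image.py.bak accumulates mean colors per label across all views after resizing (weights 0.1/0.9, skipping the -1 label). A toy sketch of the shared idea, with all values illustrative:

    import numpy as np
    import cv2

    # Toy image and segmentation map standing in for cog_seg_maps entries.
    img_np = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)
    seg_map = np.array([[0, 0, 1, 1]] * 4)

    # Paint every segment with its mean color...
    smoothed = np.zeros_like(img_np)
    for label in np.unique(seg_map):
        mask = seg_map == label
        smoothed[mask] = img_np[mask].mean(axis=0)

    # ...then blend with the original (here the .bak weights, 10% / 90%).
    blended = cv2.addWeighted(img_np, 0.1, smoothed, 0.9, 0)
    print(blended.shape)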