developerskyebrowse committed on
Commit
1d0d99a
·
1 Parent(s): 5120359
app.py CHANGED
@@ -29,7 +29,7 @@ print("loading pipe")
29
  compiled = False
30
  from preprocess import Preprocessor
31
  preprocessor = Preprocessor()
32
- preprocessor.load("NormalBae")
33
  # api = HfApi()
34
 
35
  import spaces
 
29
  compiled = False
30
  from preprocess import Preprocessor
31
  preprocessor = Preprocessor()
32
+ # preprocessor.load("NormalBae")
33
  # api = HfApi()
34
 
35
  import spaces
controlnet_aux_local/normalbae/__init__.py CHANGED
@@ -1,31 +1,19 @@
1
  import os
2
  import types
3
  import warnings
4
-
5
- # import cv2
6
- import numpy as np
7
  import torch
8
  import torchvision.transforms as transforms
9
  from einops import rearrange
10
  from huggingface_hub import hf_hub_download
11
  from PIL import Image
 
12
 
13
  from ..util import HWC3, resize_image
14
  from .nets.NNET import NNET
15
 
16
-
17
- # load model
18
  def load_checkpoint(fpath, model):
19
  ckpt = torch.load(fpath, map_location='cpu')['model']
20
-
21
- load_dict = {}
22
- for k, v in ckpt.items():
23
- if k.startswith('module.'):
24
- k_ = k.replace('module.', '')
25
- load_dict[k_] = v
26
- else:
27
- load_dict[k] = v
28
-
29
  model.load_state_dict(load_dict)
30
  return model
31
 
@@ -37,21 +25,10 @@ class NormalBaeDetector:
37
  @classmethod
38
  def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
39
  filename = filename or "scannet.pt"
 
40
 
41
- if os.path.isdir(pretrained_model_or_path):
42
- model_path = os.path.join(pretrained_model_or_path, filename)
43
- else:
44
- model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
45
-
46
- args = types.SimpleNamespace()
47
- args.mode = 'client'
48
- args.architecture = 'BN'
49
- args.pretrained = 'scannet'
50
- args.sampling_ratio = 0.4
51
- args.importance_ratio = 0.7
52
- model = NNET(args)
53
- model = load_checkpoint(model_path, model)
54
- model.eval()
55
 
56
  return cls(model)
57
 
@@ -59,75 +36,24 @@ class NormalBaeDetector:
59
  self.model.to(device)
60
  return self
61
 
62
- def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
63
- if "return_pil" in kwargs:
64
- warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
65
- output_type = "pil" if kwargs["return_pil"] else "np"
66
- if type(output_type) is bool:
67
- warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
68
- if output_type:
69
- output_type = "pil"
70
-
71
- device = next(iter(self.model.parameters())).device
72
- if not isinstance(input_image, np.ndarray):
73
- input_image = np.array(input_image, dtype=np.uint8)
74
 
 
 
75
  input_image = HWC3(input_image)
76
  input_image = resize_image(input_image, detect_resolution)
77
 
78
- assert input_image.ndim == 3
79
- image_normal = input_image
80
- with torch.no_grad():
81
- image_normal = torch.from_numpy(image_normal).float().to(device)
82
- image_normal = image_normal / 255.0
83
- image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
84
- image_normal = self.norm(image_normal)
85
-
86
- normal = self.model(image_normal)
87
- normal = normal[0][-1][:, :3]
88
- normal = ((normal + 1) * 0.5).clip(0, 1)
89
 
90
- normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
91
- normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
 
92
 
93
- # detected_map = normal_image
94
  detected_map = HWC3(normal_image)
95
 
96
- # img = resize_image(input_image, image_resolution)
97
- # H, W, C = input_image.shape
98
-
99
- # detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
100
-
101
- if output_type == "pil":
102
- detected_map = Image.fromarray(detected_map)
103
-
104
- return detected_map
105
-
106
- # def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
107
- # if "return_pil" in kwargs:
108
- # warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
109
- # output_type = "pil" if kwargs["return_pil"] else "np"
110
- # if type(output_type) is bool:
111
- # warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
112
- # if output_type:
113
- # output_type = "pil"
114
-
115
- # device = next(iter(self.model.parameters())).device
116
- # input_image = resize_image(input_image, detect_resolution)
117
-
118
- # with torch.no_grad():
119
- # image_normal = torch.from_numpy(input_image).float().to(device)
120
- # image_normal = image_normal / 255.0
121
- # image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
122
- # image_normal = self.norm(image_normal)
123
- # normal = self.model(image_normal)
124
- # normal = normal[0][-1][:, :3]
125
- # normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
126
-
127
- # detected_map = normal
128
-
129
- # if output_type == "pil":
130
- # detected_map = Image.fromarray((detected_map * 255.0).astype(np.uint8))
131
-
132
- # return detected_map
133
-
 
1
  import os
2
  import types
3
  import warnings
 
 
 
4
  import torch
5
  import torchvision.transforms as transforms
6
  from einops import rearrange
7
  from huggingface_hub import hf_hub_download
8
  from PIL import Image
9
+ import numpy as np
10
 
11
  from ..util import HWC3, resize_image
12
  from .nets.NNET import NNET
13
 
 
 
14
  def load_checkpoint(fpath, model):
15
  ckpt = torch.load(fpath, map_location='cpu')['model']
16
+ load_dict = {k.replace('module.', ''): v for k, v in ckpt.items()}
 
 
 
 
 
 
 
 
17
  model.load_state_dict(load_dict)
18
  return model
19
 
 
25
  @classmethod
26
  def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
27
  filename = filename or "scannet.pt"
28
+ model_path = os.path.join(pretrained_model_or_path, filename) if os.path.isdir(pretrained_model_or_path) else hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
29
 
30
+ args = types.SimpleNamespace(mode='client', architecture='BN', pretrained='scannet', sampling_ratio=0.4, importance_ratio=0.7)
31
+ model = load_checkpoint(model_path, NNET(args)).eval()
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  return cls(model)
34
 
 
36
  self.model.to(device)
37
  return self
38
 
39
@torch.no_grad()
def __call__(self, input_image, detect_resolution=512, output_type="pil", **kwargs):
    """Run the wrapped model on an image and return the result as an image.

    Args:
        input_image: ``numpy.ndarray`` or anything ``np.array(..., dtype=np.uint8)``
            accepts (e.g. a PIL image).
        detect_resolution: Passed to ``resize_image`` before inference.
        output_type: ``"pil"`` to return a ``PIL.Image``; any other value
            returns the ``numpy`` array. Booleans are deprecated
            (``True`` -> ``"pil"``, ``False`` -> ``"np"``).
        **kwargs: Only the deprecated ``return_pil`` flag is read; when
            present it overrides ``output_type``.

    Returns:
        ``PIL.Image`` when ``output_type == "pil"``, otherwise a ``uint8``
        array produced by ``HWC3``.
    """
    # Legacy spellings: `return_pil=<bool>` or a boolean `output_type`.
    # `return_pil` takes precedence so that `return_pil=False` is honoured
    # even when `output_type` is left at its string default.
    if "return_pil" in kwargs or isinstance(output_type, bool):
        warnings.warn("Deprecated: Use output_type='pil' or 'np' instead of boolean values.", DeprecationWarning)
        legacy_flag = kwargs["return_pil"] if "return_pil" in kwargs else output_type
        output_type = "pil" if legacy_flag else "np"

    # Run inference on whichever device the model's parameters live on.
    device = next(self.model.parameters()).device
    if not isinstance(input_image, np.ndarray):
        input_image = np.array(input_image, dtype=np.uint8)
    input_image = HWC3(input_image)
    input_image = resize_image(input_image, detect_resolution)

    # (h, w, c) array -> (1, c, h, w) float tensor scaled by 1/255, then normalised.
    image_normal = torch.from_numpy(input_image).float().to(device)
    image_normal = self.norm(image_normal.permute(2, 0, 1).unsqueeze(0) / 255.0)

    # Keep the first three channels of the last entry of the model's first output.
    normal = self.model(image_normal)[0][-1][:, :3]
    # Map from [-1, 1] to [0, 1] before converting to 8-bit.
    normal = ((normal + 1) * 0.5).clip(0, 1)
    normal_image = (normal[0].permute(1, 2, 0).cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = HWC3(normal_image)

    return Image.fromarray(detected_map) if output_type == "pil" else detected_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
controlnet_aux_local/normalbae/__init__backup.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import types
3
+ import warnings
4
+
5
+ # import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from einops import rearrange
10
+ from huggingface_hub import hf_hub_download
11
+ from PIL import Image
12
+
13
+ from ..util import HWC3, resize_image
14
+ from .nets.NNET import NNET
15
+
16
+
17
+ # load model
18
def load_checkpoint(fpath, model):
    """Load the ``'model'`` state dict stored at ``fpath`` into ``model``.

    A leading ``module.`` on a key is removed before loading. Only the
    prefix is stripped, so a ``module.`` substring later in the key is kept.

    Args:
        fpath: Path to a checkpoint readable by ``torch.load`` whose
            top-level object has a ``'model'`` entry.
        model: Object exposing ``load_state_dict``.

    Returns:
        The same ``model`` object, after ``load_state_dict`` has been called.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(fpath, map_location='cpu')['model']

    prefix = 'module.'
    load_dict = {}
    for k, v in ckpt.items():
        if k.startswith(prefix):
            # Slice instead of str.replace so interior matches survive.
            k = k[len(prefix):]
        load_dict[k] = v

    model.load_state_dict(load_dict)
    return model
31
+
32
class NormalBaeDetector:
    """Wrap a loaded network and turn input images into 8-bit output maps."""

    def __init__(self, model):
        """Store ``model`` and the per-channel normalisation applied before inference."""
        self.model = model
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
        """Build a detector from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_or_path: Existing local directory containing the
                checkpoint, or otherwise a value passed to ``hf_hub_download``
                as the repo id.
            filename: Checkpoint file name; falls back to ``"scannet.pt"``.
            cache_dir: Forwarded to ``hf_hub_download``.
            local_files_only: Forwarded to ``hf_hub_download``.

        Returns:
            A new ``cls`` instance wrapping the loaded network in eval mode.
        """
        filename = filename or "scannet.pt"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)

        # Fixed configuration handed to the NNET constructor.
        args = types.SimpleNamespace()
        args.mode = 'client'
        args.architecture = 'BN'
        args.pretrained = 'scannet'
        args.sampling_ratio = 0.4
        args.importance_ratio = 0.7
        model = NNET(args)
        model = load_checkpoint(model_path, model)
        model.eval()

        return cls(model)

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        """Run the wrapped model on an image and return the result as an image.

        Args:
            input_image: ``numpy.ndarray`` or anything
                ``np.array(..., dtype=np.uint8)`` accepts (e.g. a PIL image).
            detect_resolution: Passed to ``resize_image`` before inference.
            image_resolution: Accepted for signature compatibility; not used
                by this implementation.
            output_type: ``"pil"`` to return a ``PIL.Image``; any other value
                returns the ``numpy`` array. Booleans are deprecated.
            **kwargs: Only the deprecated ``return_pil`` flag is read.

        Returns:
            ``PIL.Image`` when ``output_type == "pil"``, otherwise a ``uint8``
            array produced by ``HWC3``.
        """
        if "return_pil" in kwargs:
            warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
            output_type = "pil" if kwargs["return_pil"] else "np"
        if isinstance(output_type, bool):
            warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
            # Normalise both boolean values so later code only sees strings.
            output_type = "pil" if output_type else "np"

        # Run inference on whichever device the model's parameters live on.
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        # Sanity check kept from the original; the rearrange below needs 3 axes.
        assert input_image.ndim == 3
        image_normal = input_image
        with torch.no_grad():
            image_normal = torch.from_numpy(image_normal).float().to(device)
            image_normal = image_normal / 255.0
            image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
            image_normal = self.norm(image_normal)

            normal = self.model(image_normal)
            # Keep the first three channels of the last entry of the first output.
            normal = normal[0][-1][:, :3]
            # Map from [-1, 1] to [0, 1] before converting to 8-bit.
            normal = ((normal + 1) * 0.5).clip(0, 1)

            normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
            normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = HWC3(normal_image)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
133
+