joselobenitezg committed
Commit 9930f16 · 1 parent: 2e49a94

add normal

Files changed (2):
  1. app.py (+5, -0)
  2. inference/normal.py (+176, -0)
app.py CHANGED
@@ -8,6 +8,7 @@ import spaces
 from inference.seg import process_image_or_video as process_seg
 from inference.pose import process_image_or_video as process_pose
 from inference.depth import process_image_or_video as process_depth
+from inference.normal import process_image_or_video as process_normal
 from config import SAPIENS_LITE_MODELS_PATH
 
 def update_model_choices(task):
@@ -24,6 +25,8 @@ def process_image(input_image, task, version):
         result = process_pose(input_image, task=task.lower(), version=version)
     elif task.lower() == 'depth':
         result = process_depth(input_image, task=task.lower(), version=version)
+    elif task.lower() == 'normal':
+        result = process_normal(input_image, task=task.lower(), version=version)
     else:
         result = None
         print(f"Unsupported task: {task}")
@@ -50,6 +53,8 @@ def process_video(input_video, task, version):
         processed_frame = process_pose(frame_rgb, task=task.lower(), version=version)
     elif task.lower() == 'depth':
         processed_frame = process_depth(frame_rgb, task=task.lower(), version=version)
+    elif task.lower() == 'normal':
+        processed_frame = process_normal(frame_rgb, task=task.lower(), version=version)
     else:
         processed_frame = None
         print(f"Unsupported task: {task}")
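A side note on this change: the four task handlers share the same call signature, so the growing if/elif dispatch in process_image and process_video could equally be table-driven. A minimal sketch, not part of this commit, assuming the same four imports as in app.py above:

    PROCESSORS = {
        'seg': process_seg,
        'pose': process_pose,
        'depth': process_depth,
        'normal': process_normal,
    }

    def process_image(input_image, task, version):
        # Look up the handler for this task; all four share the same signature.
        processor = PROCESSORS.get(task.lower())
        if processor is None:
            print(f"Unsupported task: {task}")
            return None
        return processor(input_image, task=task.lower(), version=version)

Adding a future task would then be a one-line change to the table rather than a new elif branch in each function.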
inference/normal.py CHANGED
@@ -0,0 +1,176 @@
+# import torch
+# import torch.nn.functional as F
+# import numpy as np
+# import cv2
+# from PIL import Image
+# from config import SAPIENS_LITE_MODELS_PATH
+
+# # Example usage
+# TASK = 'normal'
+# VERSION = 'sapiens_0.3b'
+
+# model_path = get_model_path(TASK, VERSION)
+# print(model_path)
+
+# model = torch.jit.load(model_path)
+# model.eval()
+# model.to("cuda")
+
+# import torch
+# import torch.nn.functional as F
+# import numpy as np
+# import cv2
+
+# def get_normal(image, normal_model, input_shape=(3, 1024, 768), device="cuda"):
+#     # Preprocess the image
+#     img = preprocess_image(image, input_shape)
+
+#     # Run the model
+#     with torch.no_grad():
+#         result = normal_model(img.to(device))
+
+#     # Post-process the output
+#     normal_map = post_process_normal(result, (image.shape[0], image.shape[1]))
+
+#     # Visualize the normal map
+#     normal_image = visualize_normal(normal_map)
+
+#     return normal_image, normal_map
+
+# def preprocess_image(image, input_shape):
+#     img = cv2.resize(image, (input_shape[2], input_shape[1]), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
+#     img = torch.from_numpy(img)
+#     img = img[[2, 1, 0], ...].float()
+#     mean = torch.tensor([123.5, 116.5, 103.5]).view(-1, 1, 1)
+#     std = torch.tensor([58.5, 57.0, 57.5]).view(-1, 1, 1)
+#     img = (img - mean) / std
+#     return img.unsqueeze(0)
+
+# def post_process_normal(result, original_shape):
+#     # Check the dimensionality of the result
+#     if result.dim() == 3:
+#         result = result.unsqueeze(0)
+#     elif result.dim() == 4:
+#         pass
+#     else:
+#         raise ValueError(f"Unexpected result dimension: {result.dim()}")
+
+#     # Ensure we're interpolating to the correct dimensions
+#     seg_logits = F.interpolate(result, size=original_shape, mode="bilinear", align_corners=False).squeeze(0)
+#     normal_map = seg_logits.float().cpu().numpy().transpose(1, 2, 0)  # H x W x 3
+#     return normal_map

+# def visualize_normal(normal_map):
+#     normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
+#     normal_map_normalized = normal_map / (normal_map_norm + 1e-5)  # add a small epsilon to avoid division by zero
+
+#     # Convert to 0-255 range and BGR format for visualization
+#     normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
+#     normal_map_vis = normal_map_vis[:, :, ::-1]  # RGB to BGR
+
+#     return normal_map_vis
+
+# def load_normal_model(checkpoint, use_torchscript=False):
+#     if use_torchscript:
+#         return torch.jit.load(checkpoint)
+#     else:
+#         model = torch.export.load(checkpoint).module()
+#         model = model.to("cuda")
+#         model = torch.compile(model, mode="max-autotune", fullgraph=True)
+#         return model
+
+# import cv2
+# import numpy as np
+
+# # Load the model
+# normal_model = load_normal_model(model_path, use_torchscript='_torchscript')
+
+# # Load the image
+# image = cv2.imread("/home/user/app/assets/image.webp")
+
+# # Get the normal map and visualization
+# normal_image, normal_map = get_normal(image, normal_model)
+
+# # Save the results
+# cv2.imwrite("output_normal_image.png", normal_image)
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import cv2
+from PIL import Image
+from config import SAPIENS_LITE_MODELS_PATH
+
+def load_model(task, version):
+    try:
+        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = torch.jit.load(model_path)
+        model.eval()
+        model.to(device)
+        return model, device
+    except KeyError as e:
+        print(f"Error: invalid task or version. {e}")
+        return None, None
+
+def preprocess_image(image, input_shape):
+    # Resize to the model input size (W, H), then HWC -> CHW
+    img = cv2.resize(image, (input_shape[2], input_shape[1]), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
+    img = torch.from_numpy(img)
+    img = img[[2, 1, 0], ...].float()  # swap channel order
+    mean = torch.tensor([123.5, 116.5, 103.5]).view(-1, 1, 1)
+    std = torch.tensor([58.5, 57.0, 57.5]).view(-1, 1, 1)
+    img = (img - mean) / std
+    return img.unsqueeze(0)
+
+def post_process_normal(result, original_shape):
+    if result.dim() == 3:
+        result = result.unsqueeze(0)
+    elif result.dim() == 4:
+        pass
+    else:
+        raise ValueError(f"Unexpected result dimension: {result.dim()}")
+
+    # Upsample the prediction back to the original image resolution
+    seg_logits = F.interpolate(result, size=original_shape, mode="bilinear", align_corners=False).squeeze(0)
+    normal_map = seg_logits.float().cpu().numpy().transpose(1, 2, 0)  # H x W x 3
+    return normal_map
+
+def visualize_normal(normal_map):
+    normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
+    normal_map_normalized = normal_map / (normal_map_norm + 1e-5)  # add a small epsilon to avoid division by zero
+
+    # Map unit normals from [-1, 1] to [0, 255] and switch RGB -> BGR
+    normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
+    normal_map_vis = normal_map_vis[:, :, ::-1]
+
+    return normal_map_vis
+
+def process_image_or_video(input_data, task='normal', version='sapiens_0.3b'):
+    model, device = load_model(task, version)
+    if model is None or device is None:
+        return None
+
+    input_shape = (3, 1024, 768)
+
+    def process_frame(frame):
+        if isinstance(frame, Image.Image):
+            frame = np.array(frame)
+
+        if frame.shape[2] == 4:  # RGBA
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+
+        img = preprocess_image(frame, input_shape)
+
+        with torch.no_grad():
+            result = model(img.to(device))
+
+        normal_map = post_process_normal(result, (frame.shape[0], frame.shape[1]))
+        normal_image = visualize_normal(normal_map)
+
+        return Image.fromarray(cv2.cvtColor(normal_image, cv2.COLOR_BGR2RGB))
+
+    if isinstance(input_data, np.ndarray):  # video frame
+        return process_frame(input_data)
+    elif isinstance(input_data, Image.Image):  # image
+        return process_frame(input_data)
+    else:
+        print("Unsupported input type. Please provide a PIL image or a numpy video frame.")
+        return None
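For quick local verification, the new module can be exercised on its own. A minimal usage sketch, not part of this commit; the image path is illustrative:

    import numpy as np
    from PIL import Image
    from inference.normal import process_image_or_video

    # Single image: returns a PIL image of the visualized surface normals
    # (unit normals mapped from [-1, 1] to [0, 255]), or None if the model
    # could not be loaded for the requested task/version.
    image = Image.open("assets/example.jpg")  # hypothetical test asset
    normal_image = process_image_or_video(image, task='normal', version='sapiens_0.3b')
    if normal_image is not None:
        normal_image.save("output_normal.png")

    # Video frames: an H x W x 3 RGB numpy array is accepted the same way.
    frame = np.asarray(image)
    normal_frame = process_image_or_video(frame, task='normal', version='sapiens_0.3b')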