Spaces:
Build error
Build error
joselobenitezg
commited on
Commit
·
9930f16
1
Parent(s):
2e49a94
add normal
Browse files- app.py +5 -0
- inference/normal.py +176 -0
app.py
CHANGED
@@ -8,6 +8,7 @@ import spaces
|
|
8 |
from inference.seg import process_image_or_video as process_seg
|
9 |
from inference.pose import process_image_or_video as process_pose
|
10 |
from inference.depth import process_image_or_video as process_depth
|
|
|
11 |
from config import SAPIENS_LITE_MODELS_PATH
|
12 |
|
13 |
def update_model_choices(task):
|
@@ -24,6 +25,8 @@ def process_image(input_image, task, version):
|
|
24 |
result = process_pose(input_image, task=task.lower(), version=version)
|
25 |
elif task.lower() == 'depth':
|
26 |
result = process_depth(input_image, task=task.lower(), version=version)
|
|
|
|
|
27 |
else:
|
28 |
result = None
|
29 |
print(f"Tarea no soportada: {task}")
|
@@ -50,6 +53,8 @@ def process_video(input_video, task, version):
|
|
50 |
processed_frame = process_pose(frame_rgb, task=task.lower(), version=version)
|
51 |
elif task.lower() == 'depth':
|
52 |
processed_frame = process_depth(frame_rgb, task=task.lower(), version=version)
|
|
|
|
|
53 |
else:
|
54 |
processed_frame = None
|
55 |
print(f"Tarea no soportada: {task}")
|
|
|
8 |
from inference.seg import process_image_or_video as process_seg
|
9 |
from inference.pose import process_image_or_video as process_pose
|
10 |
from inference.depth import process_image_or_video as process_depth
|
11 |
+
from inference.normal import process_image_or_video as process_normal
|
12 |
from config import SAPIENS_LITE_MODELS_PATH
|
13 |
|
14 |
def update_model_choices(task):
|
|
|
25 |
result = process_pose(input_image, task=task.lower(), version=version)
|
26 |
elif task.lower() == 'depth':
|
27 |
result = process_depth(input_image, task=task.lower(), version=version)
|
28 |
+
elif task.lower() == 'normal':
|
29 |
+
result = process_normal(input_image, task=task.lower(), version=version)
|
30 |
else:
|
31 |
result = None
|
32 |
print(f"Tarea no soportada: {task}")
|
|
|
53 |
processed_frame = process_pose(frame_rgb, task=task.lower(), version=version)
|
54 |
elif task.lower() == 'depth':
|
55 |
processed_frame = process_depth(frame_rgb, task=task.lower(), version=version)
|
56 |
+
elif task.lower() == 'normal':
|
57 |
+
processed_frame = process_normal(frame_rgb, task=task.lower(), version=version)
|
58 |
else:
|
59 |
processed_frame = None
|
60 |
print(f"Tarea no soportada: {task}")
|
inference/normal.py
CHANGED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import torch
|
2 |
+
# import torch.nn.functional as F
|
3 |
+
# import numpy as np
|
4 |
+
# import cv2
|
5 |
+
# from PIL import Image
|
6 |
+
# from config import SAPIENS_LITE_MODELS_PATH
|
7 |
+
|
8 |
+
# # Example usage
|
9 |
+
# TASK = 'normal'
|
10 |
+
# VERSION = 'sapiens_0.3b'
|
11 |
+
|
12 |
+
# model_path = get_model_path(TASK, VERSION)
|
13 |
+
# print(model_path)
|
14 |
+
|
15 |
+
# model = torch.jit.load(model_path)
|
16 |
+
# model.eval()
|
17 |
+
# model.to("cuda")
|
18 |
+
|
19 |
+
# import torch
|
20 |
+
# import torch.nn.functional as F
|
21 |
+
# import numpy as np
|
22 |
+
# import cv2
|
23 |
+
|
24 |
+
# def get_normal(image, normal_model, input_shape=(3, 1024, 768), device="cuda"):
|
25 |
+
# # Preprocess the image
|
26 |
+
# img = preprocess_image(image, input_shape)
|
27 |
+
|
28 |
+
# # Run the model
|
29 |
+
# with torch.no_grad():
|
30 |
+
# result = normal_model(img.to(device))
|
31 |
+
|
32 |
+
# # Post-process the output
|
33 |
+
# normal_map = post_process_normal(result, (image.shape[0], image.shape[1]))
|
34 |
+
|
35 |
+
# # Visualize the normal map
|
36 |
+
# normal_image = visualize_normal(normal_map)
|
37 |
+
|
38 |
+
# return normal_image, normal_map
|
39 |
+
|
40 |
+
# def preprocess_image(image, input_shape):
|
41 |
+
# img = cv2.resize(image, (input_shape[2], input_shape[1]), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
|
42 |
+
# img = torch.from_numpy(img)
|
43 |
+
# img = img[[2, 1, 0], ...].float()
|
44 |
+
# mean = torch.tensor([123.5, 116.5, 103.5]).view(-1, 1, 1)
|
45 |
+
# std = torch.tensor([58.5, 57.0, 57.5]).view(-1, 1, 1)
|
46 |
+
# img = (img - mean) / std
|
47 |
+
# return img.unsqueeze(0)
|
48 |
+
|
49 |
+
# def post_process_normal(result, original_shape):
|
50 |
+
# # Check the dimensionality of the result
|
51 |
+
# if result.dim() == 3:
|
52 |
+
# result = result.unsqueeze(0)
|
53 |
+
# elif result.dim() == 4:
|
54 |
+
# pass
|
55 |
+
# else:
|
56 |
+
# raise ValueError(f"Unexpected result dimension: {result.dim()}")
|
57 |
+
|
58 |
+
# # Ensure we're interpolating to the correct dimensions
|
59 |
+
# seg_logits = F.interpolate(result, size=original_shape, mode="bilinear", align_corners=False).squeeze(0)
|
60 |
+
# normal_map = seg_logits.float().cpu().numpy().transpose(1, 2, 0) # H x W x 3
|
61 |
+
# return normal_map
|
62 |
+
|
63 |
+
# def visualize_normal(normal_map):
|
64 |
+
# normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
|
65 |
+
# normal_map_normalized = normal_map / (normal_map_norm + 1e-5) # Add a small epsilon to avoid division by zero
|
66 |
+
|
67 |
+
# # Convert to 0-255 range and BGR format for visualization
|
68 |
+
# normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
|
69 |
+
# normal_map_vis = normal_map_vis[:, :, ::-1] # RGB to BGR
|
70 |
+
|
71 |
+
# return normal_map_vis
|
72 |
+
|
73 |
+
# def load_normal_model(checkpoint, use_torchscript=False):
|
74 |
+
# if use_torchscript:
|
75 |
+
# return torch.jit.load(checkpoint)
|
76 |
+
# else:
|
77 |
+
# model = torch.export.load(checkpoint).module()
|
78 |
+
# model = model.to("cuda")
|
79 |
+
# model = torch.compile(model, mode="max-autotune", fullgraph=True)
|
80 |
+
# return model
|
81 |
+
|
82 |
+
# import cv2
|
83 |
+
# import numpy as np
|
84 |
+
|
85 |
+
# # Load the model
|
86 |
+
# normal_model = load_normal_model(model_path, use_torchscript='_torchscript')
|
87 |
+
|
88 |
+
# # Load the image
|
89 |
+
# image = cv2.imread("/home/user/app/assets/image.webp")
|
90 |
+
|
91 |
+
# # Get the normal map and visualization
|
92 |
+
# normal_image, normal_map = get_normal(image, normal_model)
|
93 |
+
|
94 |
+
# # Save the results
|
95 |
+
# cv2.imwrite("output_normal_image.png", normal_image)
|
96 |
+
|
97 |
+
import torch
|
98 |
+
import torch.nn.functional as F
|
99 |
+
import numpy as np
|
100 |
+
import cv2
|
101 |
+
from PIL import Image
|
102 |
+
from config import SAPIENS_LITE_MODELS_PATH
|
103 |
+
|
104 |
+
def load_model(task, version):
|
105 |
+
try:
|
106 |
+
model_path = SAPIENS_LITE_MODELS_PATH[task][version]
|
107 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
108 |
+
model = torch.jit.load(model_path)
|
109 |
+
model.eval()
|
110 |
+
model.to(device)
|
111 |
+
return model, device
|
112 |
+
except KeyError as e:
|
113 |
+
print(f"Error: Tarea o versión inválida. {e}")
|
114 |
+
return None, None
|
115 |
+
|
116 |
+
def preprocess_image(image, input_shape):
|
117 |
+
img = cv2.resize(image, (input_shape[2], input_shape[1]), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
|
118 |
+
img = torch.from_numpy(img)
|
119 |
+
img = img[[2, 1, 0], ...].float()
|
120 |
+
mean = torch.tensor([123.5, 116.5, 103.5]).view(-1, 1, 1)
|
121 |
+
std = torch.tensor([58.5, 57.0, 57.5]).view(-1, 1, 1)
|
122 |
+
img = (img - mean) / std
|
123 |
+
return img.unsqueeze(0)
|
124 |
+
|
125 |
+
def post_process_normal(result, original_shape):
|
126 |
+
if result.dim() == 3:
|
127 |
+
result = result.unsqueeze(0)
|
128 |
+
elif result.dim() == 4:
|
129 |
+
pass
|
130 |
+
else:
|
131 |
+
raise ValueError(f"Unexpected result dimension: {result.dim()}")
|
132 |
+
|
133 |
+
seg_logits = F.interpolate(result, size=original_shape, mode="bilinear", align_corners=False).squeeze(0)
|
134 |
+
normal_map = seg_logits.float().cpu().numpy().transpose(1, 2, 0) # H x W x 3
|
135 |
+
return normal_map
|
136 |
+
|
137 |
+
def visualize_normal(normal_map):
|
138 |
+
normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
|
139 |
+
normal_map_normalized = normal_map / (normal_map_norm + 1e-5) # Add a small epsilon to avoid division by zero
|
140 |
+
|
141 |
+
normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
|
142 |
+
normal_map_vis = normal_map_vis[:, :, ::-1] # RGB to BGR
|
143 |
+
|
144 |
+
return normal_map_vis
|
145 |
+
|
146 |
+
def process_image_or_video(input_data, task='normal', version='sapiens_0.3b'):
|
147 |
+
model, device = load_model(task, version)
|
148 |
+
if model is None or device is None:
|
149 |
+
return None
|
150 |
+
|
151 |
+
input_shape = (3, 1024, 768)
|
152 |
+
|
153 |
+
def process_frame(frame):
|
154 |
+
if isinstance(frame, Image.Image):
|
155 |
+
frame = np.array(frame)
|
156 |
+
|
157 |
+
if frame.shape[2] == 4: # RGBA
|
158 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
|
159 |
+
|
160 |
+
img = preprocess_image(frame, input_shape)
|
161 |
+
|
162 |
+
with torch.no_grad():
|
163 |
+
result = model(img.to(device))
|
164 |
+
|
165 |
+
normal_map = post_process_normal(result, (frame.shape[0], frame.shape[1]))
|
166 |
+
normal_image = visualize_normal(normal_map)
|
167 |
+
|
168 |
+
return Image.fromarray(cv2.cvtColor(normal_image, cv2.COLOR_BGR2RGB))
|
169 |
+
|
170 |
+
if isinstance(input_data, np.ndarray): # Video frame
|
171 |
+
return process_frame(input_data)
|
172 |
+
elif isinstance(input_data, Image.Image): # Imagen
|
173 |
+
return process_frame(input_data)
|
174 |
+
else:
|
175 |
+
print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
|
176 |
+
return None
|