NightRaven109 committed on
Commit
c4a40e9
·
verified ·
1 Parent(s): 6ffe57a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ from depth_anything_v2.dpt import DepthAnythingV2
6
+
7
# Model initialization
# Encoder configurations for the three DepthAnythingV2 variants
# (small / base / large ViT backbones).
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}
13
+
14
def initialize_model():
    """Build the DepthAnythingV2 model and load weights from a local checkpoint.

    Returns:
        The model on CPU, in eval mode, with weights loaded from
        ``checkpoints/model2.pth``.
    """
    encoder = 'vitl'  # largest variant; must match the checkpoint's backbone
    max_depth = 1  # presumably the metric-depth cap the checkpoint was trained with — TODO confirm

    model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})

    # Load checkpoint
    checkpoint = torch.load('checkpoints/model2.pth', map_location='cpu')

    # Get state dict
    # NOTE(review): each matching iteration REBINDS state_dict, so only the
    # LAST non-metadata entry of the checkpoint survives. This is correct only
    # if the checkpoint holds exactly one weights entry (e.g. 'model') besides
    # the excluded keys — verify against how the checkpoint was saved.
    state_dict = {}
    for key in checkpoint.keys():
        if key not in ['optimizer', 'epoch', 'previous_best']:
            state_dict = checkpoint[key]

    # Handle module prefix
    # Strip the 'module.' prefix that DataParallel/DDP adds, so keys match
    # the bare (unwrapped) model.
    my_state_dict = {}
    for key in state_dict.keys():
        new_key = key.replace('module.', '')
        my_state_dict[new_key] = state_dict[key]

    model.load_state_dict(my_state_dict)
    model.eval()
    return model
38
+
39
# Initialize model globally — loaded once at import time so every Gradio
# request reuses the same weights instead of reloading the checkpoint.
MODEL = initialize_model()
41
+
42
def process_image(input_image):
    """
    Run depth estimation on one RGB image and return depth maps.

    Parameters:
        input_image: RGB image as delivered by Gradio (PIL image or ndarray),
            or None when nothing was uploaded.

    Returns:
        (depth_normalized, depth_colormap): a uint8 grayscale depth map and an
        RGB INFERNO-colored visualization, or (None, None) for empty input.
    """
    if input_image is None:
        return None, None

    # Convert from RGB to BGR (the model pipeline expects cv2.imread-style BGR).
    # FIX: cv2.COLOR_RGB_BGR does not exist and raised AttributeError;
    # the correct conversion code is cv2.COLOR_RGB2BGR.
    input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Get depth map
    depth = MODEL.infer_image(input_image)

    # Normalize depth for visualization (0-255), guarding against a constant
    # depth map which would otherwise divide by zero.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth_normalized = ((depth - depth_min) / depth_range * 255).astype(np.uint8)
    else:
        depth_normalized = np.zeros_like(depth, dtype=np.uint8)

    # Apply colormap for better visualization
    depth_colormap = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)
    depth_colormap = cv2.cvtColor(depth_colormap, cv2.COLOR_BGR2RGB)  # Convert back to RGB for Gradio

    return depth_normalized, depth_colormap
63
+
64
# Create Gradio interface
def gradio_interface(input_img):
    """Run depth estimation and return [original, raw depth, colored depth]."""
    raw_map, colored_map = process_image(input_img)
    return [input_img, raw_map, colored_map]
68
+
69
# Define interface: one image in, three images out (echo + two depth views).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Raw Depth Map"),
        gr.Image(label="Colored Depth Map")
    ],
    title="Depth Estimation",
    description="Upload an image to generate its depth map.",
    examples=["image.jpg"]  # Add example images here; file must exist at app root
)
82
+
83
# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()