shortcipher commited on
Commit
a772610
·
verified ·
1 Parent(s): 1da259a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Space: 2D to 3D Stereo Pair Generator using Depth + LaMa Inpainting
2
+
3
+ import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
9
+ import requests
10
+ import tempfile
11
+ import subprocess
12
+ import os
13
+
14
+ # === DEVICE ===
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # === DEPTH MODEL ===
18
+ def load_depth_model():
19
+ model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
20
+ processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
21
+ return model, processor
22
+
23
+ @torch.no_grad()
24
+ def estimate_depth(image: Image.Image, model, processor):
25
+ image = image.resize((384, 384))
26
+ inputs = processor(images=image, return_tensors="pt").to(device)
27
+ depth = model(**inputs).predicted_depth
28
+ depth = torch.nn.functional.interpolate(
29
+ depth.unsqueeze(1),
30
+ size=image.size[::-1],
31
+ mode="bicubic",
32
+ align_corners=False,
33
+ ).squeeze().detach().cpu().numpy()
34
+ depth_min, depth_max = depth.min(), depth.max()
35
+ return (depth - depth_min) / (depth_max - depth_min)
36
+
37
+ def depth_to_disparity(depth, max_disp=32):
38
+ return (1.0 - depth) * max_disp
39
+
40
+ def generate_right_and_mask(image, disparity):
41
+ h, w = image.shape[:2]
42
+ right = np.zeros_like(image)
43
+ mask = np.ones((h, w), dtype=np.uint8)
44
+
45
+ for y in range(h):
46
+ for x in range(w):
47
+ d = int(round(disparity[y, x]))
48
+ x_r = x - d
49
+ if 0 <= x_r < w:
50
+ right[y, x_r] = image[y, x]
51
+ mask[y, x_r] = 0
52
+ return right, mask
53
+
54
+ # === LAMA INPAINTING ===
55
+ LAMA_API = "https://huggingface.co/spaces/saic-mdal/lama-inpainting"
56
+
57
+ def run_lama_inpainting(image_bgr, mask):
58
+ img = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
59
+ mask_img = Image.fromarray(mask * 255).convert("RGB")
60
+
61
+ # Save temporarily
62
+ tmp_dir = tempfile.mkdtemp()
63
+ img_path = os.path.join(tmp_dir, "input.png")
64
+ mask_path = os.path.join(tmp_dir, "mask.png")
65
+ img.save(img_path)
66
+ mask_img.save(mask_path)
67
+
68
+ # Use Hugging Face's API-compatible request
69
+ files = {"image": open(img_path, "rb"), "mask": open(mask_path, "rb")}
70
+ response = requests.post(f"{LAMA_API}/run/predict", files=files)
71
+ if response.status_code == 200:
72
+ result = Image.open(requests.get(response.json()["data"][0]["name"], stream=True).raw)
73
+ return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
74
+ else:
75
+ raise Exception("LAMA inpainting failed")
76
+
77
+ # === APP LOGIC ===
78
+ depth_model, depth_processor = load_depth_model()
79
+
80
+ def stereo_pipeline(image_pil):
81
+ image = image_pil.convert("RGB")
82
+ image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
83
+
84
+ depth = estimate_depth(image, depth_model, depth_processor)
85
+ disparity = depth_to_disparity(depth)
86
+ right_img, mask = generate_right_and_mask(image_cv, disparity)
87
+ right_filled = run_lama_inpainting(right_img, mask)
88
+
89
+ left = image_pil
90
+ right = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
91
+ return left, right
92
+
93
+ # === GRADIO UI ===
94
+ demo = gr.Interface(
95
+ fn=stereo_pipeline,
96
+ inputs=gr.Image(type="pil", label="Upload 2D Image"),
97
+ outputs=[
98
+ gr.Image(label="Left Eye (Original)"),
99
+ gr.Image(label="Right Eye (AI Generated)")
100
+ ],
101
+ title="2D to 3D Stereo Generator with LaMa Inpainting",
102
+ description="Generates a stereo pair from a 2D image using depth estimation and LaMa AI inpainting to handle occluded pixels in the right-eye view."
103
+ )
104
+
105
+ demo.launch()