Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import torchvision.transforms as transforms
|
@@ -83,21 +92,26 @@ def apply_gaussian_blur(image, sigma):
|
|
83 |
|
84 |
return Image.fromarray(blurred.astype(np.uint8))
|
85 |
|
86 |
-
# Initialize depth estimation pipeline
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
)
|
|
|
93 |
|
94 |
def process_image(image, blur_type, gaussian_sigma, lens_min_sigma, lens_max_sigma):
|
95 |
"""Main processing function for Gradio interface"""
|
|
|
|
|
|
|
96 |
processed_image = preprocess_image(image)
|
97 |
|
98 |
if blur_type == "Gaussian Blur":
|
99 |
result = apply_gaussian_blur(processed_image, gaussian_sigma)
|
100 |
else: # Lens Blur
|
|
|
101 |
depth_map = estimate_depth(processed_image, pipe)
|
102 |
result = apply_depth_aware_blur(processed_image, depth_map, lens_max_sigma, lens_min_sigma)
|
103 |
|
|
|
1 |
+
try:
|
2 |
+
import torch
|
3 |
+
import torchvision
|
4 |
+
except ImportError:
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision"])
|
8 |
+
import torch
|
9 |
+
import torchvision
|
10 |
+
|
11 |
import gradio as gr
|
|
|
12 |
import numpy as np
|
13 |
from PIL import Image
|
14 |
import torchvision.transforms as transforms
|
|
|
92 |
|
93 |
return Image.fromarray(blurred.astype(np.uint8))
|
94 |
|
95 |
+
# Initialize depth estimation pipeline (moved inside the processing function to avoid CUDA issues)
|
96 |
+
def get_depth_pipeline():
|
97 |
+
return pipeline(
|
98 |
+
task="depth-estimation",
|
99 |
+
model="depth-anything/Depth-Anything-V2-Small-hf",
|
100 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
101 |
+
device=0 if torch.cuda.is_available() else -1
|
102 |
+
)
|
103 |
|
104 |
def process_image(image, blur_type, gaussian_sigma, lens_min_sigma, lens_max_sigma):
|
105 |
"""Main processing function for Gradio interface"""
|
106 |
+
if image is None:
|
107 |
+
return None
|
108 |
+
|
109 |
processed_image = preprocess_image(image)
|
110 |
|
111 |
if blur_type == "Gaussian Blur":
|
112 |
result = apply_gaussian_blur(processed_image, gaussian_sigma)
|
113 |
else: # Lens Blur
|
114 |
+
pipe = get_depth_pipeline()
|
115 |
depth_map = estimate_depth(processed_image, pipe)
|
116 |
result = apply_depth_aware_blur(processed_image, depth_map, lens_max_sigma, lens_min_sigma)
|
117 |
|