huntrezz's picture
Update app.py
8ad09ea verified
raw
history blame
2.03 kB
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.nn.utils.prune as prune
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32)
model.eval()
# Apply global unstructured pruning
parameters_to_prune = [
(module, "weight") for module in filter(lambda m: isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)), model.modules())
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.4, # Prune 40% of weights
)
for module, _ in parameters_to_prune:
prune.remove(module, "weight")
model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
)
model = model.to(device)
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
def preprocess_image(image):
image = cv2.resize(image, (128, 128))
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
return image / 255.0
@torch.inference_mode()
def process_frame(image):
if image is None:
return None
preprocessed = preprocess_image(image)
predicted_depth = model(preprocessed).predicted_depth
depth_map = predicted_depth.squeeze().cpu().numpy()
# Normalize depth map
depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
# Create a more visually informative depth map
depth_color = cv2.applyColorMap((depth_map * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
# Blend original image with depth map for context
original_resized = cv2.resize(image, (128, 128))
blended = cv2.addWeighted(original_resized, 0.6, depth_color, 0.4, 0)
return blended
interface = gr.Interface(
fn=process_frame,
inputs=gr.Image(sources="webcam", streaming=True),
outputs="image",
live=True
)
interface.launch()