import cv2 import numpy as np import scipy.sparse as sp import scipy.sparse.linalg as splin from numba import jit import gradio as gr @jit(nopython=True) def build_poisson_sparse_matrix(ys, xs, im2var, img_s, img_t, mask): nnz = len(ys) img_s_h, img_s_w = img_s.shape A_data = np.zeros(16 * nnz, dtype=np.float64) A_rows = np.zeros(16 * nnz, dtype=np.int32) A_cols = np.zeros(16 * nnz, dtype=np.int32) b = np.zeros(4 * nnz, dtype=np.float64) offsets = np.array([(0, 1), (0, -1), (1, 0), (-1, 0)]) idx = 0 for n in range(nnz): y, x = ys[n], xs[n] for i in range(4): dy, dx = offsets[i] n_y, n_x = y + dy, x + dx e = 4 * n + i if 0 <= n_y < img_s_h and 0 <= n_x < img_s_w: A_data[idx] = 1 A_rows[idx] = e A_cols[idx] = im2var[y, x] idx += 1 b[e] = img_s[y, x] - img_s[n_y, n_x] if im2var[n_y, n_x] != -1: A_data[idx] = -1 A_rows[idx] = e A_cols[idx] = im2var[n_y, n_x] idx += 1 else: b[e] += img_t[n_y, n_x] return A_data[:idx], A_rows[:idx], A_cols[:idx], b def poisson_blend_fast_jit(img_s: np.ndarray, mask: np.ndarray, img_t: np.ndarray) -> np.ndarray: nnz = np.sum(mask > 0) im2var = np.full(mask.shape, -1, dtype=np.int32) im2var[mask > 0] = np.arange(nnz) ys, xs = np.nonzero(mask) A_data, A_rows, A_cols, b = build_poisson_sparse_matrix(ys, xs, im2var, img_s, img_t, mask) A = sp.csr_matrix((A_data, (A_rows, A_cols)), shape=(4*nnz, nnz)) v = splin.lsqr(A, b)[0] img_t_out = img_t.copy() img_t_out[mask > 0] = v[im2var[mask > 0]] return np.clip(img_t_out, 0, 1) @jit(nopython=True) def neighbours(i: int, j: int, max_i: int, max_j: int): pairs = [] for n in (-1, 1): if 0 <= i+n <= max_i: pairs.append((i+n, j)) if 0 <= j+n <= max_j: pairs.append((i, j+n)) return pairs @jit(nopython=True) def build_mixed_blend_sparse_matrix(ys, xs, im2var, img_s, img_t, mask): nnz = len(ys) img_s_h, img_s_w = img_s.shape A_data = np.zeros(8 * nnz, dtype=np.float64) A_rows = np.zeros(8 * nnz, dtype=np.int32) A_cols = np.zeros(8 * nnz, dtype=np.int32) b = np.zeros(4 * nnz, dtype=np.float64) idx = 0 e = 0 for n in range(nnz): y, x = ys[n], xs[n] for n_y, n_x in neighbours(y, x, img_s_h-1, img_s_w-1): ds = img_s[y, x] - img_s[n_y, n_x] dt = img_t[y, x] - img_t[n_y, n_x] d = ds if abs(ds) > abs(dt) else dt A_data[idx] = 1 A_rows[idx] = e A_cols[idx] = im2var[y, x] idx += 1 b[e] = d if im2var[n_y, n_x] != -1: A_data[idx] = -1 A_rows[idx] = e A_cols[idx] = im2var[n_y, n_x] idx += 1 else: b[e] += img_t[n_y, n_x] e += 1 return A_data[:idx], A_rows[:idx], A_cols[:idx], b[:e] def mixed_blend_fast_jit(img_s: np.ndarray, mask: np.ndarray, img_t: np.ndarray) -> np.ndarray: nnz = np.sum(mask > 0) im2var = np.full(mask.shape, -1, dtype=np.int32) im2var[mask > 0] = np.arange(nnz) ys, xs = np.nonzero(mask) A_data, A_rows, A_cols, b = build_mixed_blend_sparse_matrix(ys, xs, im2var, img_s, img_t, mask) A = sp.csr_matrix((A_data, (A_rows, A_cols)), shape=(len(b), nnz)) v = splin.spsolve(A.T @ A, A.T @ b) img_t_out = img_t.copy() img_t_out[mask > 0] = v[im2var[mask > 0]] return np.clip(img_t_out, 0, 1) def _2d_gaussian(sigma: float) -> np.ndarray: ksize = np.int64(np.ceil(sigma)*6+1) gaussian_1d = cv2.getGaussianKernel(ksize, sigma) return gaussian_1d * np.transpose(gaussian_1d) def _low_pass_filter(img: np.ndarray, sigma: float) -> np.ndarray: return cv2.filter2D(img, -1, _2d_gaussian(sigma)) def _high_pass_filter(img: np.ndarray, sigma: float) -> np.ndarray: return img - _low_pass_filter(img, sigma) def _gaus_pyramid(img: np.ndarray, depth: int, sigma: int): _im = img.copy() pyramid = [] for d in range(depth-1): _im = _low_pass_filter(_im.copy(), sigma) pyramid.append(_im) _im = cv2.pyrDown(_im) return pyramid def _lap_pyramid(img: np.ndarray, depth: int, sigma: int): _im = img.copy() pyramid = [] for d in range(depth-1): lap = _high_pass_filter(_im.copy(), sigma) pyramid.append(lap) _im = cv2.pyrDown(_im) return pyramid def _blend(img1: np.ndarray, img2: np.ndarray, mask: np.ndarray) -> np.ndarray: return img1 * mask + img2 * (1.0 - mask) def laplacian_blend(img1: np.ndarray, img2: np.ndarray, mask: np.ndarray, depth: int, sigma: int) -> np.ndarray: mask_gaus_pyramid = _gaus_pyramid(mask, depth, sigma) img1_lap_pyramid, img2_lap_pyramid = _lap_pyramid(img1, depth, sigma), _lap_pyramid(img2, depth, sigma) blended = [_blend(obj, bg, mask) for obj, bg, mask in zip(img1_lap_pyramid, img2_lap_pyramid, mask_gaus_pyramid)][::-1] h, w = blended[0].shape[:2] img1 = cv2.resize(img1, (w, h)) img2 = cv2.resize(img2, (w, h)) mask = cv2.resize(mask, (w, h)) blanded_img = _blend(img1, img2, mask) blanded_img = cv2.resize(blanded_img, blended[0].shape[:2]) imgs = [] for d in range(0, depth-1): gaussian_img = _low_pass_filter(blanded_img.copy(), sigma) reconstructed_img = cv2.add(blended[d], gaussian_img) imgs.append(reconstructed_img) blanded_img = cv2.pyrUp(reconstructed_img) return np.clip(imgs[-1], 0, 1) def get_image(img_input, mask=False, scale=True): if isinstance(img_input, dict) and 'composite' in img_input: img = img_input['composite'] elif isinstance(img_input, np.ndarray): img = img_input elif isinstance(img_input, str): img = cv2.imread(img_input) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) else: raise ValueError("Unsupported image input type") if mask: if len(img.shape) == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) return np.where(img > 127, 1, 0) # Threshold at 127 for the mask if scale and img.dtype != np.float64: return img.astype('float64') / 255.0 return img def blend_images(bg_img, obj_img, mask_img, method): bg_img = get_image(bg_img) obj_img = get_image(obj_img) mask_img = get_image(mask_img, mask=True) if method == "Poisson": blend_img = np.zeros_like(bg_img) for b in range(3): blend_img[:,:,b] = poisson_blend_fast_jit(obj_img[:,:,b], mask_img, bg_img[:,:,b].copy()) elif method == "Mixed Gradient": blend_img = np.zeros_like(bg_img) for b in range(3): blend_img[:,:,b] = mixed_blend_fast_jit(obj_img[:,:,b], mask_img, bg_img[:,:,b].copy()) elif method == "Laplacian": mask_stack = np.stack((mask_img.astype(float),) * 3, axis=-1) blend_img = laplacian_blend(obj_img, bg_img, mask_stack, 5, 25.0) return (blend_img * 255).astype(np.uint8) with gr.Blocks(theme='bethecloud/storj_theme') as iface: gr.HTML("