import numpy as np
import cv2 as cv
from PIL import Image

def norm_mat(mat):
    """Linearly stretch *mat* so its values span 0..255, then cast to uint8.

    Args:
        mat: array accepted by ``cv.normalize``.

    Returns:
        A ``np.uint8`` array with the same shape as ``mat``.
    """
    stretched = cv.normalize(mat, None, alpha=0, beta=255, norm_type=cv.NORM_MINMAX)
    return stretched.astype(np.uint8)

def minmax_dev(patch, mask):
    """Compare the centre pixel of a 3x3 patch with its masked neighbourhood.

    Args:
        patch: single-channel array whose element ``[1, 1]`` is the pixel
            under test.
        mask: mask for ``cv.minMaxLoc``; only positions where it is non-zero
            contribute to the neighbourhood minimum and maximum.

    Returns:
        -1 if the centre is strictly below the neighbourhood minimum,
        +1 if it is strictly above the neighbourhood maximum, otherwise 0.
    """
    centre = patch[1, 1]
    lo, hi = cv.minMaxLoc(patch, mask)[:2]
    if centre < lo:
        return -1
    return +1 if centre > hi else 0

def blk_filter(img, radius):
    """Replace every square tile of *img* by that tile's standard deviation.

    The image is cut into non-overlapping tiles of side ``2 * radius + 1``,
    starting at the top-left corner. Each tile is filled with ``np.std`` of
    its pixels. Tiles on the bottom/right edges may be smaller than the
    nominal size and are still processed, so no region is left unfiltered.

    Args:
        img: 2-D array (e.g. a boolean mask).
        radius: half-size of a tile; the tile side is ``2 * radius + 1``.

    Returns:
        The per-tile deviations min-max rescaled to 0..127 as ``CV_8UC1``
        (uint8). The rescaling is done by ``cv.normalize``.
    """
    result = np.zeros_like(img, np.float32)
    rows, cols = result.shape
    block = 2 * radius + 1
    # Iterate over tile origins rather than tile centres. A centre-based loop
    # skips a trailing partial tile whose centre lies past the image edge,
    # which leaves that strip at 0.
    for top in range(0, rows, block):
        for left in range(0, cols, block):
            tile = (slice(top, top + block), slice(left, left + block))
            result[tile] = np.std(img[tile])
    return cv.normalize(result, None, 0, 127, cv.NORM_MINMAX, cv.CV_8UC1)

def _local_extrema(img):
    """Return an int32 map: -1 at strict 3x3 local minima, +1 at strict maxima, else 0.

    A pixel is an extremum when it is strictly below (or strictly above) all
    8 of its immediate neighbours. The outermost 1-pixel frame has no full
    neighbourhood and is always 0, so the map has the same height and width
    as ``img``.

    Raises:
        ValueError: if ``img`` is smaller than 3x3.
    """
    height, width = img.shape[:2]
    if height < 3 or width < 3:
        raise ValueError(f"image must be at least 3x3 pixels, got {height}x{width}")
    rows, cols = height - 2, width - 2
    centre = img[1:-1, 1:-1]
    # One shifted view per neighbour offset (the centre itself is excluded).
    # Slicing avoids materialising a copy of every 3x3 patch.
    neighbours = [
        img[dy : dy + rows, dx : dx + cols]
        for dy in range(3)
        for dx in range(3)
        if (dy, dx) != (1, 1)
    ]
    lowest = highest = neighbours[0]
    for plane in neighbours[1:]:
        lowest = np.minimum(lowest, plane)
        highest = np.maximum(highest, plane)
    extrema = np.zeros((rows, cols), dtype=np.int32)
    extrema[centre < lowest] = -1
    extrema[centre > highest] = +1
    # Zero 1-pixel border so the map lines up with the input image.
    return np.pad(extrema, 1)


def preprocess(image, channel=4, radius=2):
    """Render a map of pixels that are strict local extrema of one image plane.

    A pixel is marked when its value is strictly lower, or strictly higher,
    than all 8 of its immediate neighbours in the selected plane.

    Args:
        image: 3-channel image as a NumPy array, or anything ``np.array``
            accepts (e.g. a PIL image). Channel handling uses OpenCV's B, G, R
            order (``cv.COLOR_BGR2GRAY`` and ``b, g, r = cv.split(...)``).
            NOTE(review): non-ndarray input is not reordered from RGB, so
            confirm callers pass BGR data.
        channel: plane to analyse:
            0 is grayscale via ``cv.COLOR_BGR2GRAY``;
            1, 2 and 3 are ``image[:, :, 2]``, ``[:, :, 1]`` and ``[:, :, 0]``;
            4 is the per-pixel Euclidean norm of the three channels.
        radius: if > 0, the minima and maxima masks are each summarised by
            per-tile standard deviation over tiles of side
            ``2 * (radius + 3) + 1``. They are drawn into one colour plane
            (channels 0-2) or as gray (channels 3-4), then stretched to
            0..255. Otherwise every extremum pixel is painted individually.

    Returns:
        A ``PIL.Image`` built from an array with the same shape and dtype as
        ``image``.

    Raises:
        ValueError: if ``channel`` is not in 0..4, or the image is smaller
            than 3x3.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)  # Ensure image is a NumPy array
    if not 0 <= channel <= 4:
        raise ValueError(f"channel must be between 0 and 4, got {channel!r}")

    if channel == 0:
        img = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
    elif channel == 4:
        # Work in float64 so squaring uint8 values cannot overflow.
        b, g, r = cv.split(image.astype(np.float64))
        img = cv.sqrt(cv.pow(b, 2) + cv.pow(g, 2) + cv.pow(r, 2))
    else:
        img = image[:, :, 3 - channel]

    extrema = _local_extrema(img)
    low = extrema == -1
    high = extrema == +1

    minmax = np.zeros_like(image)
    if radius > 0:
        radius += 3
        low = blk_filter(low, radius)
        high = blk_filter(high, radius)
        if channel <= 2:
            # Each blk_filter output is at most 127, so the uint8 sum cannot overflow.
            minmax[:, :, 2 - channel] = low
            minmax[:, :, 2 - channel] += high
        else:
            minmax = np.repeat(low[:, :, np.newaxis], 3, axis=2)
            minmax += np.repeat(high[:, :, np.newaxis], 3, axis=2)
        minmax = norm_mat(minmax)
    else:
        # Channels 0-2 paint one colour plane. Channels 3 and 4 paint white,
        # matching the gray rendering of the tiled branch above; without the
        # fallback, channel 4 would produce an all-black image.
        colour = {0: [0, 0, 255], 1: [0, 255, 0], 2: [255, 0, 0]}.get(
            channel, [255, 255, 255]
        )
        minmax[low | high] = colour
    return Image.fromarray(minmax)