import os from typing import Tuple, Union import gradio as gr import numpy as np from PIL import Image, ImageChops, ImageOps samples_dir = "./samples" class ImageInfo: def __init__( self, size: Tuple[int, int], channels: int, data_type: str, min_val: float, max_val: float, ): self.size = size self.channels = channels self.data_type = data_type self.min_val = min_val self.max_val = max_val @classmethod def from_pil(cls, pil_image: Image.Image) -> "ImageInfo": size = (pil_image.width, pil_image.height) channels = len(pil_image.getbands()) data_type = str(pil_image.mode) extrema = pil_image.getextrema() if channels > 1: # Multi-band image min_val = min([band[0] for band in extrema]) max_val = max([band[1] for band in extrema]) else: # Single-band image min_val, max_val = extrema return cls(size, channels, data_type, min_val, max_val) @classmethod def from_numpy(cls, np_array: np.ndarray) -> "ImageInfo": if len(np_array.shape) > 3: raise ValueError(f"Unsupported array shape: {np_array.shape}") size = (np_array.shape[1], np_array.shape[0]) channels = 1 if len(np_array.shape) == 2 else np_array.shape[2] data_type = str(np_array.dtype) min_val, max_val = np_array.min(), np_array.max() return cls(size, channels, data_type, min_val, max_val) @classmethod def from_any(cls, image: Union[Image.Image, np.ndarray]) -> "ImageInfo": if isinstance(image, np.ndarray): return cls.from_numpy(image) elif isinstance(image, Image.Image): return cls.from_pil(image) else: raise ValueError(f"Unsupported image type: {type(image)}") def __str__(self) -> str: return f"{str(self.size)} {self.channels}C {self.data_type} {round(self.min_val, 2)}min/{round(self.max_val, 2)}max" @property def aspect_ratio(self) -> float: return self.size[0] / self.size[1] def nextpow2(n): """Find the next power of 2 greater than or equal to `n`.""" return int(2 ** np.ceil(np.log2(n))) def pad_image_nextpow2(image): print("-" * 80) print("pad_image_nextpow2: ") print(ImageInfo.from_any(image)) assert image.ndim in (2, 3), f"Expected 2D or 3D image. Got {image.ndim}D." height, width, channels = image.shape height_new = nextpow2(height) width_new = nextpow2(width) height_diff = height_new - height width_diff = width_new - width image = np.pad( image, ( (height_diff // 2, height_diff - height_diff // 2), (width_diff // 2, width_diff - width_diff // 2), (0, 0), ) if channels == 3 else ( (height_diff // 2, height_diff - height_diff // 2), (width_diff // 2, width_diff - width_diff // 2), ), mode="constant", # mode="edge", # mode="linear_ramp", # mode="maximum", # mode="mean", # mode="median", # mode="minimum", # mode="reflect", # mode="symmetric", # mode="wrap", # mode="empty", ) print(ImageInfo.from_any(image)) return image def get_fft(image): print("-" * 80) print(f"get_fft: {image.shape}") print("image:", ImageInfo.from_any(image)) fft = np.fft.fft2(image, axes=np.arange(image.ndim)) fft = np.fft.fftshift(fft) return fft def get_ifft_image(fft): print("-" * 80) print(f"get_ifft_image: {fft.shape}") ifft = np.fft.ifftshift(fft) ifft = np.fft.ifft2(ifft, axes=np.arange(fft.ndim)) # we only need the real part ifft_image = np.real(ifft) # remove padding # ifft = ifft[ # h_diff // 2 : h_diff // 2 + original_shape[0], # w_diff // 2 : w_diff // 2 + original_shape[1], # ] ifft_image = (ifft_image - np.min(ifft_image)) / ( np.max(ifft_image) - np.min(ifft_image) ) ifft_image = ifft_image * 255 ifft_image = ifft_image.astype(np.uint8) return ifft_image def fft_mag_image(fft): print("-" * 80) print(f"fft_mag_image: {fft.shape}") fft_mag = np.abs(fft) fft_mag = np.log(fft_mag + 1) # scale 0 to 1 fft_mag = (fft_mag - np.min(fft_mag)) / (np.max(fft_mag) - np.min(fft_mag) + 1e-6) # scale to (0, 255) fft_mag = fft_mag * 255 fft_mag = fft_mag.astype(np.uint8) return fft_mag def fft_phase_image(fft): print("-" * 80) print(f"fft_phase_image: {fft.shape}") fft_phase = np.angle(fft) fft_phase = fft_phase + np.pi fft_phase = fft_phase / (2 * np.pi) # scale 0 to 1 fft_phase = (fft_phase - np.min(fft_phase)) / ( np.max(fft_phase) - np.min(fft_phase) + 1e-6 ) # scale to (0, 255) fft_phase = fft_phase * 255 fft_phase = fft_phase.astype(np.uint8) return fft_phase def onclick_process_fft(state, inp_image, mask_opacity, inverted_mask, pad): print("-" * 80) print("onclick_process_fft:") if isinstance(inp_image, dict): if "image" not in inp_image: raise gr.Error("Please upload or select an image first.") image, mask = inp_image["image"], inp_image["mask"] print("image:", ImageInfo.from_any(image)) print("mask:", ImageInfo.from_any(image)) image = Image.fromarray(image) mask = Image.fromarray(mask).convert(image.mode) if not inverted_mask: mask = ImageOps.invert(mask) image_final = ImageChops.multiply(image, mask) image_final = Image.blend(image, image_final, mask_opacity) image_final = image_final.convert(image.mode) image_final = np.array(image_final) elif isinstance(inp_image, np.ndarray): image_final = inp_image else: raise gr.Error("Please upload or select an image first.") print("image_final:", ImageInfo.from_any(image_final)) if pad: image_final = pad_image_nextpow2(image_final) state["inp_image"] = image_final image_mag = fft_mag_image(get_fft(image_final)) image_phase = fft_phase_image(get_fft(image_final)) return ( state, [ (image_final, "Input Image (Final)"), (image_mag, "FFT Magnitude (Original)"), (image_phase, "FFT Phase (Original)"), ], image_mag, image_phase, ) def onclick_process_ifft(state, mag_and_mask, phase_and_mask): print("-" * 80) print("onclick_process_ifft:") if state["inp_image"] is None: raise gr.Error("Please process FFT first.") image = state["inp_image"] # h_new = nextpow2(original_shape[0]) # w_new = nextpow2(original_shape[1]) # h_diff = h_new - original_shape[0] # w_diff = w_new - original_shape[1] mask_mag = mag_and_mask["mask"] print("mag_mask:", ImageInfo.from_any(mask_mag)) mask_phase = phase_and_mask["mask"] print("phase_mask:", ImageInfo.from_any(mask_phase)) fft = get_fft(state["inp_image"]) print(f"fft: {fft.shape}") fft_mag = np.where(mask_mag == 255, 0, np.abs(fft)) fft_phase = np.where(mask_phase == 255, 0, np.angle(fft)) fft = fft_mag * np.exp(1j * fft_phase) ifft_image = get_ifft_image(fft) image_mag = fft_mag_image(fft) image_phase = fft_phase_image(fft) return ( [ (image, "Input Image (Final)"), (image_mag, "FFT Magnitude (Final)"), (image_phase, "FFT Phase (Final)"), ], ifft_image, ) def get_start_image(): return (np.ones((512, 512, 3)) * 255).astype(np.uint8) def update_image_input(state, selection): print("-" * 80) print("update_image_input:") print(f"selection: {selection}") if not selection: white_image = get_start_image() state["inp_image"] = white_image return ( state, white_image, [white_image], None, None, None, ) image_path = os.path.join(samples_dir, selection) print(f"image_path: {image_path}") if not os.path.exists(image_path): raise gr.Error(f"Image not found: {image_path}") image = Image.open(image_path) image = np.array(image) state["inp_image"] = image return ( state, image, [image], None, None, None, ) def clear_image_input(state): print("-" * 80) print("clear_image_input:") state["inp_image"] = None return ( state, None, [], None, None, None, ) css = """ .fft_mag > .image-container > button > div:first-child { display: none; } .fft_phase > .image-container > button > div:first-child { display: none; } .ifft_img > .image-container > button > div:first-child { display: none; } """ with gr.Blocks(css=css) as demo: state = gr.State( { "inp_image": None, }, ) with gr.Row(): with gr.Column(): inp_image = gr.Image( value=get_start_image(), label="Input Image", height=512, type="numpy", interactive=True, tool="sketch", mask_opacity=1.0, elem_classes=["inp_img"], ) files = os.listdir(samples_dir) files = sorted(files) inp_samples = gr.Dropdown( choices=files, label="Select Example Image", ) with gr.Column(): out_gallery = gr.Gallery( label="Input Gallery", height=512, rows=1, columns=3, allow_preview=True, preview=False, selected_index=None, ) with gr.Row(): inp_mask_opacity = gr.Slider( label="Mask Opacity", minimum=0.0, maximum=1.0, step=0.05, value=1.0, ) inp_invert_mask = gr.Checkbox( label="Invert Mask", value=False, ) inp_pad = gr.Checkbox( label="Pad NextPow2", value=True, ) btn_fft = gr.Button("Process FFT") with gr.Row(): out_fft_mag = gr.Image( label="FFT Magnitude", height=512, width=512, type="numpy", interactive=True, # source="canvas", tool="sketch", mask_opacity=1.0, elem_classes=["fft_mag"], ) out_fft_phase = gr.Image( label="FFT Phase", height=512, width=512, type="numpy", interactive=True, # source="canvas", tool="sketch", mask_opacity=1.0, elem_classes=["fft_phase"], ) btn_ifft = gr.Button("Process IFFT") out_ifft = gr.Image( label="IFFT", height=512, type="numpy", interactive=True, show_download_button=True, elem_classes=["ifft_img"], ) inp_image.clear( clear_image_input, [state], [state, inp_samples, out_gallery, out_fft_mag, out_fft_phase, out_ifft], ) # Set up event listener for the Dropdown component to update the image input inp_samples.change( update_image_input, [state, inp_samples], [state, inp_image, out_gallery, out_fft_mag, out_fft_phase, out_ifft], ) # Set up events for fft processing btn_fft.click( onclick_process_fft, [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad], [state, out_gallery, out_fft_mag, out_fft_phase], ) out_fft_mag.clear( onclick_process_fft, [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad], [state, out_gallery, out_fft_mag, out_fft_phase], ) out_fft_phase.clear( onclick_process_fft, [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad], [state, out_gallery, out_fft_mag, out_fft_phase], ) # inp_image.edit( # get_fft_images, # [state, inp_image], # [out_gallery, out_fft_mag, out_fft_phase], # ) # Set up events for ifft processing btn_ifft.click( onclick_process_ifft, [state, out_fft_mag, out_fft_phase], [out_gallery, out_ifft], ) # out_fft_mag.edit( # get_ifft_image, # [state, out_fft_mag, out_fft_phase], # [out_ifft], # ) # out_fft_phase.edit( # get_ifft_image, # [state, out_fft_mag, out_fft_phase], # [out_ifft], # ) if __name__ == "__main__": demo.launch()