from typing import Optional

import gradio as gr
import h5py
import mrcfile
import numpy as np
from PIL import Image
from omegaconf import DictConfig
import torch
from pathlib import Path
from torchvision.transforms import functional as F
import torchvision.transforms.v2 as v2
import spaces

from draco.configuration import CfgNode
from draco.model import (
    build_model,
    load_pretrained
)

# Side length (in pixels) of the square window that is previewed and denoised.
CROP_SIZE = 1024

example_files = {
    "EMPIAR-10078": "example/empiar-10078-00-000093-full_patch_aligned.h5",
    "EMPIAR-10154": "example/empiar-10154-00-000130-full_patch_aligned.h5",
    "EMPIAR-10185": "example/empiar-10185-00-000032-full_patch_aligned.h5",
    "EMPIAR-10200": "example/empiar-10200-00-000139-full_patch_aligned.h5",
    "EMPIAR-10216": "example/empiar-10216-00-000036-full_patch_aligned.h5"
}


class DRACODenoiser(object):
    """Inference wrapper around a pretrained DRACO denoising model.

    Images are fed to the model as (B, C, H, W) tensors; the model output is a
    flattened patch sequence that :meth:`unpatchify` reassembles into an image.
    """

    def __init__(self, cfg: DictConfig, ckpt_path: Path) -> None:
        """Build the model from ``cfg`` and load its weights from ``ckpt_path``.

        The model is placed on CUDA when available, otherwise on the CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = self.build_transform()
        self.model = build_model(cfg).to(self.device).eval()
        self.model = load_pretrained(self.model, ckpt_path, self.device)
        self.patch_size = cfg.MODEL.PATCH_SIZE

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a (B, C, H, W) tensor into flattened, non-overlapping patches.

        H and W are zero-padded on the bottom/right up to a multiple of the
        patch size P. Returns a tensor of shape (B, num_patches, P * P * C).
        """
        B, C, H, W = image.shape
        P = self.patch_size
        pad_h = (P - H % P) % P
        pad_w = (P - W % P) % P
        if pad_h or pad_w:
            image = torch.nn.functional.pad(image, (0, pad_w, 0, pad_h), mode='constant', value=0)
        # (B, C, H, W) -> (B, C, gh, gw, P, P) -> (B, gh, gw, P, P, C)
        patches = image.unfold(2, P, P).unfold(3, P, P)
        patches = patches.permute(0, 2, 3, 4, 5, 1)
        return patches.reshape(B, -1, P * P * C)

    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Inverse of :meth:`patchify`: (B, num_patches, P * P * C) -> (B, C, H, W).

        The padding that :meth:`patchify` adds is cropped away.
        """
        B = patches.shape[0]
        P = self.patch_size
        grid_h = (H + P - 1) // P  # patch rows, including a padded remainder
        grid_w = (W + P - 1) // P  # patch columns, including a padded remainder
        images = patches.reshape(B, grid_h, grid_w, P, P, -1)
        # (B, gh, gw, P, P, C) -> (B, C, gh, P, gw, P) so patch rows/cols interleave.
        images = images.permute(0, 5, 1, 3, 2, 4)
        images = images.reshape(B, -1, grid_h * P, grid_w * P)
        return images[..., :H, :W]

    @classmethod
    def build_transform(cls) -> v2.Compose:
        """Return the PIL -> float32 tensor transform applied before the model."""
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])

    @spaces.GPU
    def inference(self, image: Image.Image) -> np.ndarray:
        """Denoise ``image`` and return the result as an (H, W, C) numpy array."""
        W, H = image.size
        x = self.transform(image).unsqueeze(0).to(self.device)
        # Pure inference: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            y = self.model(x)
        denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0)
        return denoised.detach().cpu().numpy()


# Model initialization (runs once, at import time).
cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
CfgNode.merge_with_dotlist(cfg, [])
ckpt_path = Path("denoise.ckpt")
denoiser = DRACODenoiser(cfg, ckpt_path)


def _min_max_normalize(image: np.ndarray) -> np.ndarray:
    """Linearly rescale ``image`` to [0, 1]; a constant image maps to all zeros."""
    lo = image.min()
    span = image.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN for flat inputs.
        return np.zeros_like(image, dtype=np.float32)
    return (image - lo) / span


def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
    """Contrast-stretch ``image`` into [0, 1].

    The image is first min-max normalised. Then the intensity window
    [mean - t_mean * f, mean + (1 - t_mean) * f], with f = std / t_sd, is
    clipped and mapped to [0, 1]. Before clipping, this places the image mean
    at ``t_mean`` and its standard deviation at ``t_sd``. A constant image
    yields all zeros.
    """
    image = _min_max_normalize(image)
    mean = image.mean()
    std = image.std()
    f = std / t_sd
    black = mean - t_mean * f
    white = mean + (1 - t_mean) * f
    if white <= black:
        # std == 0 (flat image): the window is empty, nothing to stretch.
        return np.zeros_like(image)
    new_image = np.clip(image, black, white)
    return (new_image - black) / (white - black)


def _to_uint8_image(array: np.ndarray, auto_contrast: bool) -> Image.Image:
    """Normalise ``array`` to [0, 1] (optionally auto-contrasted) and wrap it as an 8-bit PIL image."""
    normalized = Auto_contrast(array) if auto_contrast else _min_max_normalize(array)
    return Image.fromarray((normalized * 255).astype(np.uint8))


def load_data(file_path) -> np.ndarray:
    """Load a micrograph from a .h5 or .mrc file and standardise it.

    For HDF5 files the image is read from the "micrograph" dataset (falling back
    to "data"). The dataset's "mean"/"std" attributes are used when present;
    otherwise the statistics are computed from the data. The result is
    ``(data - mean) / std``.

    Raises:
        ValueError: if the file extension is neither .h5 nor .mrc.
    """
    file_path = str(file_path)
    lowered = file_path.lower()
    if lowered.endswith('.h5'):
        with h5py.File(file_path, "r") as f:
            full_micrograph = f["micrograph"] if "micrograph" in f else f["data"]
            # Read the dataset once and reuse it for the statistics fallback.
            data = full_micrograph[:].astype(np.float32)
            attrs = full_micrograph.attrs
            full_mean = attrs["mean"] if "mean" in attrs else data.mean()
            full_std = attrs["std"] if "std" in attrs else data.std()
    elif lowered.endswith('.mrc'):
        with mrcfile.open(file_path, "r") as f:
            data = f.data[:].astype(np.float32)
        full_mean = data.mean()
        full_std = data.std()
    else:
        raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")
    if full_std == 0:
        # Flat micrograph: only centre it rather than dividing by zero.
        full_std = 1.0
    data = (data - full_mean) / full_std
    return data


def display_crop(data, x_offset, y_offset, auto_contrast) -> Optional[Image.Image]:
    """Return the CROP_SIZE x CROP_SIZE window at (x_offset, y_offset) as an 8-bit image.

    Returns None when no data is loaded.
    """
    if data is None:
        return None
    # Slider values may arrive as floats; numpy slicing requires ints.
    x_offset, y_offset = int(x_offset), int(y_offset)
    crop = data[y_offset:y_offset + CROP_SIZE, x_offset:x_offset + CROP_SIZE]
    return _to_uint8_image(crop, auto_contrast)


@spaces.GPU
def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> Optional[Image.Image]:
    """Denoise the CROP_SIZE window at (x_offset, y_offset) and return it as an 8-bit image.

    Returns None when no data is loaded.
    """
    if data is None:
        return None
    x_offset, y_offset = int(x_offset), int(y_offset)
    crop = data[y_offset:y_offset + CROP_SIZE, x_offset:x_offset + CROP_SIZE]
    denoised_data = denoiser.inference(Image.fromarray(crop)).squeeze()
    return _to_uint8_image(denoised_data, auto_contrast)


def clear_images() -> tuple:
    """Reset both images, the loaded data and the offset sliders."""
    return None, None, None, gr.update(value=0, maximum=1024), gr.update(value=0, maximum=1024)


with gr.Blocks(css="""
.custom-size {
    width: 735px;
    height: 127px;
}
""") as demo:
    gr.Markdown(
        '''

Draco Denoising Demo 🙉

Upload a raw micrograph or select an example to visualize the original and denoised results

Our denoising model supports a bin-1 micrograph (ending with .mrc or .h5). To achieve optimal performance, the input should be motion corrected before being passed to the model.

'''
    )
    with gr.Row():
        with gr.Column():
            example_selector = gr.Radio(label="Choose an example Raw Micrograph File", choices=list(example_files.keys()))
            file_input = gr.File(label="Or upload a Micrograph File in .h5 or .mrc format")
        with gr.Column():
            auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False, elem_classes=["custom-size"])
            x_slider = gr.Slider(0, 1024, step=10, label="X Offset", elem_classes=["custom-size"])
            y_slider = gr.Slider(0, 1024, step=10, label="Y Offset", elem_classes=["custom-size"])
    with gr.Row():
        denoise_button = gr.Button("Denoise")
        clear_button = gr.Button("Clear")
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="pil", label="Original Image")
        with gr.Column():
            denoised_image = gr.Image(type="pil", label="Denoised Image")
    active_data = gr.State()

    def load_image_and_update_sliders(file_path, use_auto_contrast=False) -> tuple:
        """Load ``file_path``, preview its top-left crop and resize the offset sliders.

        ``use_auto_contrast`` must be the checkbox *value* (a bool), not the component.
        """
        data = load_data(file_path)
        h, w = data.shape[:2]
        preview = display_crop(data, 0, 0, use_auto_contrast)
        # Clamp so micrographs smaller than CROP_SIZE don't get a negative slider range.
        return (
            data,
            preview,
            None,
            gr.update(value=0, maximum=max(w - CROP_SIZE, 0)),
            gr.update(value=0, maximum=max(h - CROP_SIZE, 0)),
        )

    def _no_data_outputs() -> tuple:
        """Outputs used when there is nothing to load."""
        return None, None, None, gr.update(maximum=1024), gr.update(maximum=1024)

    def on_example_selected(choice, use_auto_contrast) -> tuple:
        """Load the bundled example file chosen in the radio selector."""
        if choice is None:
            return _no_data_outputs()
        return load_image_and_update_sliders(example_files[choice], use_auto_contrast)

    def on_file_changed(file, use_auto_contrast) -> tuple:
        """Load an uploaded file, or reset the outputs when the upload is removed."""
        if not file:
            return _no_data_outputs()
        # Depending on the Gradio version / File ``type``, the upload may be an
        # object exposing ``.name`` or a plain path string; accept both.
        return load_image_and_update_sliders(getattr(file, "name", file), use_auto_contrast)

    example_selector.change(
        on_example_selected,
        inputs=[example_selector, auto_contrast],
        outputs=[active_data, original_image, denoised_image, x_slider, y_slider]
    )
    file_input.clear(
        clear_images,
        inputs=None,
        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
    )
    file_input.change(
        on_file_changed,
        inputs=[file_input, auto_contrast],
        outputs=[active_data, original_image, denoised_image, x_slider, y_slider]
    )
    x_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )
    y_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )
    # Keep the preview in sync with the checkbox instead of waiting for a slider move.
    auto_contrast.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )
    denoise_button.click(
        process_and_denoise,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=denoised_image
    )
    clear_button.click(
        clear_images,
        inputs=None,
        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
    )

demo.launch()