import gradio as gr
import h5py
import mrcfile
import numpy as np
from PIL import Image
from omegaconf import DictConfig
import torch
from pathlib import Path
import torchvision.transforms.v2 as v2
import spaces

from draco.configuration import CfgNode
from draco.model import (
    build_model,
    load_pretrained
)
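
# Gradio demo for DRACO cryo-EM micrograph denoising: load a bin-1 micrograph
# (.h5 or .mrc) or pick one of the bundled EMPIAR examples below, pan a
# 1024 x 1024 crop with the offset sliders, and compare raw vs. denoised views.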

example_files = {
    "EMPIAR-10078": "example/empiar-10078-00-000093-full_patch_aligned.h5",
    "EMPIAR-10154": "example/empiar-10154-00-000130-full_patch_aligned.h5",
    "EMPIAR-10185": "example/empiar-10185-00-000032-full_patch_aligned.h5",
    "EMPIAR-10200": "example/empiar-10200-00-000139-full_patch_aligned.h5",
    "EMPIAR-10216": "example/empiar-10216-00-000036-full_patch_aligned.h5"
}


class DRACODenoiser:
    def __init__(self,
        cfg: DictConfig,
        ckpt_path: Path,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Build the preprocessing pipeline and the pretrained denoising model.
        self.transform = self.build_transform()
        self.model = build_model(cfg).to(self.device).eval()
        self.model = load_pretrained(self.model, ckpt_path, self.device)
        self.patch_size = cfg.MODEL.PATCH_SIZE

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        B, C, H, W = image.shape
        P = self.patch_size
        # Zero-pad on the right and bottom so H and W become multiples of P.
        if H % P != 0 or W % P != 0:
            image = torch.nn.functional.pad(image, (0, (P - W % P) % P, 0, (P - H % P) % P), mode='constant', value=0)

        # (B, C, H, W) -> (B, num_patches, P * P * C)
        patches = image.unfold(2, P, P).unfold(3, P, P)
        patches = patches.permute(0, 2, 3, 4, 5, 1)
        patches = patches.reshape(B, -1, P * P * C)
        return patches

    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B = patches.shape[0]
        P = self.patch_size

        # (B, num_patches, P * P * C) -> (B, C, H_pad, W_pad), then crop to (H, W).
        images = patches.reshape(B, (H + P - 1) // P, (W + P - 1) // P, P, P, -1)
        images = images.permute(0, 5, 1, 3, 2, 4)
        images = images.reshape(B, -1, (H + P - 1) // P * P, (W + P - 1) // P * P)
        images = images[..., :H, :W]
        return images
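
    # Round-trip note: for any input x of shape (B, C, H, W),
    # unpatchify(patchify(x), H, W) recovers x, since the zero padding added in
    # patchify is cropped away by the final [..., :H, :W] slice.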

    @classmethod
    def build_transform(cls) -> v2.Compose:
        # Convert a PIL image to a float32 tensor (uint8 inputs are rescaled to [0, 1]).
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])

    @spaces.GPU
    def inference(self, image: Image.Image) -> np.ndarray:
        W, H = image.size

        x = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            y = self.model(x)

        # The model returns patch tokens; fold them back into an image, crop to
        # the original size, and move channels last for NumPy consumption.
        denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0).cpu().numpy()
        return denoised
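
# On Hugging Face ZeroGPU Spaces, callables decorated with @spaces.GPU are
# allocated a GPU for the duration of each call; outside that environment the
# decorator is expected to behave as a pass-through (an assumption about the
# `spaces` package, not verified here).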


# Load the model configuration and checkpoint once at startup; "draco.yaml"
# and "denoise.ckpt" are expected in the working directory.
cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
CfgNode.merge_with_dotlist(cfg, [])
ckpt_path = Path("denoise.ckpt")
denoiser = DRACODenoiser(cfg, ckpt_path)


def auto_contrast_stretch(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
    # Min-max normalize, then linearly remap intensities so the crop mean is
    # displayed at grey level t_mean and one standard deviation spans t_sd.
    image = (image - image.min()) / (image.max() - image.min())
    mean = image.mean()
    std = image.std()

    f = std / t_sd

    black = mean - t_mean * f
    white = mean + (1 - t_mean) * f

    new_image = np.clip(image, black, white)
    new_image = (new_image - black) / (white - black)
    return new_image
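
# Worked example with the defaults: if the min-max-normalized crop has mean 0.5
# and std 0.1, then f = 0.1 / (40/255) = 0.6375, black = 0.5 - (150/255) * 0.6375
# = 0.125, and white = 0.5 + (105/255) * 0.6375 = 0.7625, so the crop mean is
# displayed at grey level 150/255.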


def load_data(file_path) -> np.ndarray:
    if file_path.endswith('.h5'):
        with h5py.File(file_path, "r") as f:
            dataset = f["micrograph"] if "micrograph" in f else f["data"]
            data = dataset[:].astype(np.float32)
            # Prefer statistics stored as HDF5 attributes; otherwise compute them.
            full_mean = dataset.attrs["mean"] if "mean" in dataset.attrs else data.mean()
            full_std = dataset.attrs["std"] if "std" in dataset.attrs else data.std()
    elif file_path.endswith('.mrc'):
        with mrcfile.open(file_path, "r") as f:
            data = f.data[:].astype(np.float32)
            full_mean = data.mean()
            full_std = data.std()
    else:
        raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")
    # Standardize to zero mean and unit variance using full-micrograph statistics.
    data = (data - full_mean) / full_std
    return data
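
# Example (assuming the bundled files are present):
#     data = load_data("example/empiar-10078-00-000093-full_patch_aligned.h5")
# yields a float32 array normalized to zero mean and unit variance.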


def display_crop(data, x_offset, y_offset, auto_contrast) -> Image.Image:
    if data is None:
        return None

    # Take a 1024 x 1024 window and map it to an 8-bit image for display.
    crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
    normalized = auto_contrast_stretch(crop) if auto_contrast else (crop - crop.min()) / (crop.max() - crop.min())
    return Image.fromarray((normalized * 255).astype(np.uint8))


@spaces.GPU
def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> Image.Image:
    if data is None:
        return None

    # Denoise the same 1024 x 1024 window that display_crop shows.
    crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
    denoised = denoiser.inference(Image.fromarray(crop)).squeeze()

    normalized = auto_contrast_stretch(denoised) if auto_contrast else (denoised - denoised.min()) / (denoised.max() - denoised.min())
    return Image.fromarray((normalized * 255).astype(np.uint8))


def clear_images() -> tuple:
    # Reset both images, the cached micrograph, and the offset sliders.
    return None, None, None, gr.update(value=0, maximum=1024), gr.update(value=0, maximum=1024)
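
# Note: the 1024 maximum mirrors the sliders' initial gr.Slider(0, 1024) range;
# loading a micrograph replaces it with the actual croppable extent
# (width - 1024 and height - 1024).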


with gr.Blocks(css="""
    .custom-size {
        width: 735px;
        height: 127px;
    }
""") as demo:

    gr.Markdown(
        '''
        <div style="text-align: center;">
            <h1>DRACO Denoising Demo</h1>
            <p style="font-size:16px;">Upload a raw micrograph or select an example to visualize the original and denoised results.</p>
            <p style="font-size:16px;">The denoising model expects a bin-1 micrograph (.mrc or .h5). For optimal performance, the input should be <strong>motion corrected</strong> before being passed to the model.</p>
        </div>
        '''
    )

    with gr.Row():
        with gr.Column():
            example_selector = gr.Radio(label="Choose an example raw micrograph", choices=list(example_files.keys()))
            file_input = gr.File(label="Or upload a micrograph file (.h5 or .mrc)")

        with gr.Column():
            auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False, elem_classes=["custom-size"])
            x_slider = gr.Slider(0, 1024, step=10, label="X Offset", elem_classes=["custom-size"])
            y_slider = gr.Slider(0, 1024, step=10, label="Y Offset", elem_classes=["custom-size"])

    with gr.Row():
        denoise_button = gr.Button("Denoise")
        clear_button = gr.Button("Clear")

    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="pil", label="Original Image")
        with gr.Column():
            denoised_image = gr.Image(type="pil", label="Denoised Image")

    # Holds the normalized micrograph array for the currently loaded file.
    active_data = gr.State()

    def load_image_and_update_sliders(file_path, use_auto_contrast) -> tuple:
        data = load_data(file_path)
        h, w = data.shape[:2]
        # Show the top-left crop and clamp the slider ranges to the croppable extent.
        preview = display_crop(data, 0, 0, use_auto_contrast)
        return data, preview, None, gr.update(value=0, maximum=max(w - 1024, 0)), gr.update(value=0, maximum=max(h - 1024, 0))

    # Pass the checkbox through the event inputs so its current value (not the
    # component object) reaches display_crop.
    example_selector.change(
        lambda choice, contrast: load_image_and_update_sliders(example_files[choice], contrast),
        inputs=[example_selector, auto_contrast],
        outputs=[active_data, original_image, denoised_image, x_slider, y_slider]
    )

    file_input.clear(
        clear_images,
        inputs=None,
        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
    )

    file_input.change(
        lambda file, contrast: load_image_and_update_sliders(file.name, contrast) if file else (None, None, None, gr.update(value=0, maximum=1024), gr.update(value=0, maximum=1024)),
        inputs=[file_input, auto_contrast],
        outputs=[active_data, original_image, denoised_image, x_slider, y_slider]
    )

    # Re-render the original crop whenever an offset slider moves.
    x_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )

    y_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )

    denoise_button.click(
        process_and_denoise,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=denoised_image
    )

    clear_button.click(clear_images, inputs=None, outputs=[original_image, denoised_image, active_data, x_slider, y_slider])

demo.launch()