# draco / app.py
# Last commit 19e554d by Felix-Xu: "intro update"
import gradio as gr
import h5py
import mrcfile
import numpy as np
from PIL import Image
from omegaconf import DictConfig
import torch
from pathlib import Path
from torchvision.transforms import functional as F
import torchvision.transforms.v2 as v2
import spaces
from draco.configuration import CfgNode
from draco.model import (
build_model,
load_pretrained
)
# Bundled example micrographs keyed by EMPIAR accession. Insertion order is
# preserved, and it is the order in which the choices appear in the example
# selector.
example_files = dict(
    (
        ("EMPIAR-10078", "example/empiar-10078-00-000093-full_patch_aligned.h5"),
        ("EMPIAR-10154", "example/empiar-10154-00-000130-full_patch_aligned.h5"),
        ("EMPIAR-10185", "example/empiar-10185-00-000032-full_patch_aligned.h5"),
        ("EMPIAR-10200", "example/empiar-10200-00-000139-full_patch_aligned.h5"),
        ("EMPIAR-10216", "example/empiar-10216-00-000036-full_patch_aligned.h5"),
    )
)
class DRACODenoiser(object):
    """Run a pretrained DRACO model on single images.

    The model is built from ``cfg``, moved to CUDA when available (CPU
    otherwise), switched to eval mode and loaded with the weights at
    ``ckpt_path`` via ``load_pretrained``.
    """

    def __init__(self,
                 cfg: DictConfig,
                 ckpt_path: Path,
                 ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = self.build_transform()
        self.model = build_model(cfg).to(self.device).eval()
        self.model = load_pretrained(self.model, ckpt_path, self.device)
        # Side length P of the square patches used by patchify/unpatchify.
        self.patch_size = cfg.MODEL.PATCH_SIZE

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a (B, C, H, W) tensor into flattened non-overlapping patches.

        H and W are zero-padded on the bottom/right up to a multiple of the
        patch size P. The result has shape (B, N, P * P * C), where N is the
        number of patches in row-major grid order.
        """
        B, C, H, W = image.shape
        P = self.patch_size
        if H % P != 0 or W % P != 0:
            # Pad order for the last two dims is (left, right, top, bottom).
            image = torch.nn.functional.pad(image, (0, (P - W % P) % P, 0, (P - H % P) % P), mode='constant', value=0)
        patches = image.unfold(2, P, P).unfold(3, P, P)      # (B, C, nh, nw, P, P)
        patches = patches.permute(0, 2, 3, 4, 5, 1)           # (B, nh, nw, P, P, C)
        patches = patches.reshape(B, -1, P * P * C)
        return patches

    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Inverse of :meth:`patchify`: rebuild a (B, C, H, W) tensor.

        The padded grid is reassembled and then cropped back to ``H`` x ``W``.
        The channel count C is inferred from the patch length.
        """
        B = patches.shape[0]
        P = self.patch_size
        n_rows = (H + P - 1) // P
        n_cols = (W + P - 1) // P
        images = patches.reshape(B, n_rows, n_cols, P, P, -1)
        images = images.permute(0, 5, 1, 3, 2, 4)             # (B, C, nh, P, nw, P)
        images = images.reshape(B, -1, n_rows * P, n_cols * P)
        images = images[..., :H, :W]
        return images

    @classmethod
    def build_transform(cls) -> v2.Compose:
        """Return the PIL -> float32 tensor transform applied before the model."""
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])

    @spaces.GPU
    def inference(self, image: Image.Image) -> np.ndarray:
        """Denoise one PIL image.

        Returns:
            A NumPy array of shape (H, W, C), where (W, H) is ``image.size``
            and C is inferred from the model output.
        """
        W, H = image.size
        x = self.transform(image).unsqueeze(0).to(self.device)
        # Inference only: skipping autograd bookkeeping saves memory and time.
        with torch.no_grad():
            y = self.model(x)
        denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0)
        return denoised.detach().cpu().numpy()
# Model initialization. The global `denoiser` is built once at import time and
# is reused by every call to process_and_denoise.
ckpt_path = Path("denoise.ckpt")
cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
# An empty dotlist is passed, presumably meaning "no overrides" -- confirm
# against CfgNode.merge_with_dotlist.
CfgNode.merge_with_dotlist(cfg, [])
denoiser = DRACODenoiser(cfg, ckpt_path)
def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
    """Stretch ``image`` so that its mean and standard deviation hit targets.

    The image is first min-max scaled to [0, 1]. A linear window
    ``[black, white]`` is then chosen so that, ignoring clipping, the output
    has mean ``t_mean`` and standard deviation ``t_sd``. Values outside the
    window are clipped, so the result always lies in [0, 1].

    Args:
        image: Numeric NumPy array of any shape.
        t_mean: Target mean of the output (default 150/255).
        t_sd: Target standard deviation of the output (default 40/255).

    Returns:
        An array with the same shape as ``image`` and values in [0, 1]. A
        constant input returns all zeros instead of NaNs.
    """
    low, high = image.min(), image.max()
    if high == low:
        # Both the min-max step and the window would divide by zero.
        return np.zeros_like(image, dtype=float)
    image = (image - low) / (high - low)
    mean = image.mean()
    std = image.std()
    # Scale factor that maps the image's std onto the target std.
    f = std / t_sd
    black = mean - t_mean * f
    white = mean + (1 - t_mean) * f
    new_image = np.clip(image, black, white)
    new_image = (new_image - black) / (white - black)
    return new_image
def load_data(file_path) -> np.ndarray:
    """Load a micrograph from an HDF5 or MRC file and standardize it.

    HDF5 (``.h5`` / ``.hdf5``): the image is read from the ``"micrograph"``
    dataset, falling back to ``"data"``. The ``"mean"``/``"std"`` attributes
    on that dataset are used when present; otherwise the statistics are
    computed from the pixels.

    MRC (``.mrc``): the statistics are always computed from the pixels.

    Args:
        file_path: Path to the file. The format is chosen from its extension,
            case-insensitively.

    Returns:
        A float32 array shifted by the mean and divided by the std. Leading
        singleton axes (e.g. a one-slice stack) are dropped. A zero std is
        treated as 1 to avoid NaNs.

    Raises:
        ValueError: If the file extension is not supported.
    """
    path = str(file_path)
    lowered = path.lower()
    if lowered.endswith((".h5", ".hdf5")):
        with h5py.File(path, "r") as f:
            dataset = f["micrograph"] if "micrograph" in f else f["data"]
            # Read the pixels once and reuse them for any missing statistic.
            data = dataset[:].astype(np.float32)
            attrs = dataset.attrs
            full_mean = attrs["mean"] if "mean" in attrs else data.mean()
            full_std = attrs["std"] if "std" in attrs else data.std()
    elif lowered.endswith(".mrc"):
        with mrcfile.open(path, "r") as f:
            data = f.data[:].astype(np.float32)
        full_mean = data.mean()
        full_std = data.std()
    else:
        raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")
    # A (1, H, W) stack must behave like an (H, W) image for 2-D cropping.
    while data.ndim > 2 and data.shape[0] == 1:
        data = data[0]
    if full_std == 0:
        # A constant image would otherwise be divided by zero.
        full_std = 1.0
    data = (data - full_mean) / full_std
    # Keep a stable dtype regardless of the precision of stored attributes.
    return data.astype(np.float32, copy=False)
def display_crop(data, x_offset, y_offset, auto_contrast) -> "Image.Image | None":
    """Render a 1024x1024 window of ``data`` as an 8-bit PIL image.

    Args:
        data: Micrograph array (expected 2-D), or ``None`` when nothing is
            loaded.
        x_offset: Column of the window's top-left corner. It is cast to
            ``int`` in case the slider value arrives as a float.
        y_offset: Row of the window's top-left corner, cast the same way.
        auto_contrast: When truthy, intensities are mapped with
            ``Auto_contrast``. Otherwise the crop's min..max range is
            stretched to 0..1.

    Returns:
        The rendered crop. Returns ``None`` when there is no data or the
        window is empty. A flat crop renders as black.
    """
    if data is None:
        return None
    x0, y0 = int(x_offset), int(y_offset)
    crop = data[y0:y0 + 1024, x0:x0 + 1024]
    if crop.size == 0:
        return None
    low, high = crop.min(), crop.max()
    if high == low:
        # Both normalizations divide by the value range; avoid NaNs.
        normalized = np.zeros(crop.shape, dtype=np.float32)
    elif auto_contrast:
        normalized = Auto_contrast(crop)
    else:
        normalized = (crop - low) / (high - low)
    return Image.fromarray((normalized * 255).astype(np.uint8))
@spaces.GPU
def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> "Image.Image | None":
    """Denoise a 1024x1024 window of ``data`` and render it as an 8-bit PIL image.

    Args:
        data: Micrograph array (expected 2-D), or ``None`` when nothing is
            loaded.
        x_offset: Column of the window's top-left corner. It is cast to
            ``int`` in case the slider value arrives as a float.
        y_offset: Row of the window's top-left corner, cast the same way.
        auto_contrast: When truthy, the denoised intensities are mapped with
            ``Auto_contrast``. Otherwise their min..max range is stretched
            to 0..1.

    Returns:
        The rendered denoised crop. Returns ``None`` when there is no data or
        the window is empty. A flat result renders as black.
    """
    if data is None:
        return None
    x0, y0 = int(x_offset), int(y_offset)
    crop = data[y0:y0 + 1024, x0:x0 + 1024]
    if crop.size == 0:
        return None
    # The global denoiser's transform consumes PIL images.
    denoised = denoiser.inference(Image.fromarray(crop)).squeeze()
    low, high = denoised.min(), denoised.max()
    if high == low:
        # Both normalizations divide by the value range; avoid NaNs.
        normalized = np.zeros(denoised.shape, dtype=np.float32)
    elif auto_contrast:
        normalized = Auto_contrast(denoised)
    else:
        normalized = (denoised - low) / (high - low)
    return Image.fromarray((normalized * 255).astype(np.uint8))
def clear_images() -> tuple:
    """Return the values that reset the viewer to its empty state.

    The result is three ``None`` placeholders followed by two slider updates.
    As wired in this file, the ``None`` values go to the original image, the
    denoised image and the loaded data. Each slider update sets value 0 and
    maximum 1024.
    """
    def _reset_slider():
        return gr.update(value=0, maximum=1024)

    empty = (None,) * 3
    return empty + (_reset_slider(), _reset_slider())
with gr.Blocks(css="""
.custom-size {
width: 735px;
height: 127px;
}
""") as demo:
    # Page header shown above the controls.
    gr.Markdown(
'''
<div style="text-align: center;">
<h1>Draco Denoising Demo 🙉</h1>
<p style="font-size:16px;">Upload a raw micrograph or select an example to visualize the original and denoised results</p>
<p style="font-size:16px;">Our denoising model supports bin-1 micrographs (.mrc or .h5 files). To achieve optimal performance, the input should be <strong>motion corrected</strong> before being passed to the model.</p>
</div>
'''
    )
    with gr.Row():
        with gr.Column():
            example_selector = gr.Radio(label="Choose an example Raw Micrograph File", choices=list(example_files.keys()))
            file_input = gr.File(label="Or upload a Micrograph File in .h5 or .mrc format")
        with gr.Column():
            auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False, elem_classes=["custom-size"])
            x_slider = gr.Slider(0, 1024, step=10, label="X Offset", elem_classes=["custom-size"])
            y_slider = gr.Slider(0, 1024, step=10, label="Y Offset", elem_classes=["custom-size"])
    with gr.Row():
        denoise_button = gr.Button("Denoise")
        clear_button = gr.Button("Clear")
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="pil", label="Original Image")
        with gr.Column():
            denoised_image = gr.Image(type="pil", label="Denoised Image")
    # Holds the standardized micrograph array currently being viewed.
    active_data = gr.State()

    # Outputs updated whenever a micrograph is (re)loaded.
    load_outputs = [active_data, original_image, denoised_image, x_slider, y_slider]
    # Outputs reset by clear_images, in the order it returns them.
    clear_outputs = [original_image, denoised_image, active_data, x_slider, y_slider]
    # Inputs needed to render or denoise the current crop.
    crop_inputs = [active_data, x_slider, y_slider, auto_contrast]

    def load_image_and_update_sliders(file_path, auto_contrast_enabled=False) -> tuple:
        """Load a micrograph and build the values for ``load_outputs``.

        Returns the loaded array, a preview of its top-left 1024x1024 crop,
        ``None`` to clear any stale denoised image, and slider updates whose
        maxima keep the crop window inside the image (never below 0).
        """
        try:
            data = load_data(file_path)
        except ValueError as err:
            # Surface unsupported-format errors in the UI.
            raise gr.Error(str(err)) from err
        h, w = data.shape[:2]
        # Use the checkbox *value*, not the Checkbox component (which is always truthy).
        preview = display_crop(data, 0, 0, auto_contrast_enabled)
        return (
            data,
            preview,
            None,
            gr.update(value=0, maximum=max(w - 1024, 0)),
            gr.update(value=0, maximum=max(h - 1024, 0)),
        )

    def load_example(choice, auto_contrast_enabled):
        """Load the bundled example selected in the radio group."""
        if choice is None:
            return None, None, None, gr.update(maximum=1024), gr.update(maximum=1024)
        return load_image_and_update_sliders(example_files[choice], auto_contrast_enabled)

    def load_uploaded_file(file, auto_contrast_enabled):
        """Load an uploaded file, or reset the outputs when the upload is removed."""
        if not file:
            return None, None, None, gr.update(maximum=1024), gr.update(maximum=1024)
        # Depending on the Gradio version, the value is either a path string or
        # a tempfile-like object that exposes the path as `.name`.
        return load_image_and_update_sliders(getattr(file, "name", file), auto_contrast_enabled)

    example_selector.change(
        load_example,
        inputs=[example_selector, auto_contrast],
        outputs=load_outputs
    )
    file_input.clear(
        clear_images,
        inputs=None,
        outputs=clear_outputs
    )
    file_input.change(
        load_uploaded_file,
        inputs=[file_input, auto_contrast],
        outputs=load_outputs
    )
    # Re-render the original crop when the window moves or the contrast mode changes.
    for control in (x_slider, y_slider, auto_contrast):
        control.change(
            display_crop,
            inputs=crop_inputs,
            outputs=original_image
        )
    denoise_button.click(
        process_and_denoise,
        inputs=crop_inputs,
        outputs=denoised_image
    )
    clear_button.click(clear_images, inputs=None, outputs=clear_outputs)
demo.launch()