draco / app.py
Felix-Xu's picture
denoise model update
3bf7d18
raw
history blame
7.45 kB
import gradio as gr
import h5py
import mrcfile
import numpy as np
from PIL import Image
from omegaconf import DictConfig
import torch
from pathlib import Path
from torchvision.transforms import functional as F
import torchvision.transforms.v2 as v2
from draco.configuration import CfgNode
from draco.model import (
build_model,
load_pretrained
)
class DRACODenoiser(object):
    """Run a pretrained model on single images and rebuild the output image.

    The model is built with ``build_model(cfg)``, its weights are loaded with
    ``load_pretrained`` and it is kept on the GPU when one is available.
    The model output is treated as a sequence of flattened square patches
    that :meth:`unpatchify` stitches back into an image.
    """

    def __init__(self, cfg: "DictConfig", ckpt_path: Path) -> None:
        """Build the model, load its checkpoint and cache the patch size.

        Args:
            cfg: Configuration passed to ``build_model``; must expose
                ``cfg.MODEL.PATCH_SIZE``.
            ckpt_path: Checkpoint location handed to ``load_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = self.build_transform()
        self.model = build_model(cfg).to(self.device).eval()
        self.model = load_pretrained(self.model, ckpt_path, self.device)
        self.patch_size = cfg.MODEL.PATCH_SIZE

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a ``(B, C, H, W)`` tensor into flattened ``P x P`` patches.

        The bottom and right edges are zero-padded so that both spatial sizes
        become multiples of the patch size ``P``.

        Returns:
            Tensor of shape ``(B, num_patches, P * P * C)`` with patches in
            row-major order.
        """
        B, C, H, W = image.shape
        P = self.patch_size
        pad_h = (P - H % P) % P
        pad_w = (P - W % P) % P
        if pad_h or pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            image = torch.nn.functional.pad(
                image, (0, pad_w, 0, pad_h), mode="constant", value=0
            )
        # (B, C, rows, cols, P, P) -> (B, rows, cols, P, P, C)
        patches = image.unfold(2, P, P).unfold(3, P, P)
        patches = patches.permute(0, 2, 3, 4, 5, 1)
        return patches.reshape(B, -1, P * P * C)

    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Inverse of :meth:`patchify`.

        Args:
            patches: Tensor of shape ``(B, num_patches, P * P * C)``.
            H: Height of the original (unpadded) image.
            W: Width of the original (unpadded) image.

        Returns:
            Tensor of shape ``(B, C, H, W)``; the padding added by
            :meth:`patchify` is cropped away.
        """
        B = patches.shape[0]
        P = self.patch_size
        rows = (H + P - 1) // P  # ceil(H / P)
        cols = (W + P - 1) // P  # ceil(W / P)
        images = patches.reshape(B, rows, cols, P, P, -1)
        # (B, rows, cols, P, P, C) -> (B, C, rows, P, cols, P)
        images = images.permute(0, 5, 1, 3, 2, 4)
        images = images.reshape(B, -1, rows * P, cols * P)
        return images[..., :H, :W]

    @classmethod
    def build_transform(cls) -> "v2.Compose":
        """Return the preprocessing pipeline: to image tensor, then float32."""
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

    @torch.inference_mode()
    def inference(self, image: "Image.Image") -> np.ndarray:
        """Denoise one image and return the result as a NumPy array.

        Args:
            image: Input image; ``image.size`` must be ``(width, height)``.

        Returns:
            Array of shape ``(H, W, C)`` on the CPU.
        """
        W, H = image.size
        x = self.transform(image).unsqueeze(0).to(self.device)
        # NOTE(review): the model output is assumed to be patch tokens laid
        # out like the result of ``patchify`` -- confirm against the model.
        y = self.model(x)
        restored = self.unpatchify(y, H, W)
        return restored.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
# Model Initialization
# Everything below runs once, at import time.
# Both paths are relative, so they resolve against the current working directory.
cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
# NOTE(review): merging an empty dotlist presumably leaves cfg unchanged and
# only serves as a hook for overrides -- confirm against CfgNode.
CfgNode.merge_with_dotlist(cfg, [])
ckpt_path = Path("denoise.ckpt")
# Module-level denoiser instance used by process_and_denoise.
denoiser = DRACODenoiser(cfg, ckpt_path)
def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
    """Stretch the contrast of ``image`` around its mean.

    The image is first min-max normalised to ``[0, 1]``. A black point and a
    white point are then placed around the mean, ``std / t_sd`` apart and
    split in the ratio ``t_mean : 1 - t_mean``. Values are clipped to that
    window and rescaled to ``[0, 1]``.

    Args:
        image: Array-like of pixel values.
        t_mean: Target mean of the output, as a fraction of full scale.
        t_sd: Target standard deviation, as a fraction of full scale.

    Returns:
        Array with the shape of ``image`` and values in ``[0, 1]``. An image
        without contrast (all pixels equal) yields an all-zero float32 array
        instead of NaNs.
    """
    image = np.asarray(image)
    lowest = image.min()
    value_range = image.max() - lowest
    # A flat (or NaN-ranged) image cannot be normalised; avoid dividing by 0.
    if not value_range > 0:
        return np.zeros(image.shape, dtype=np.float32)

    image = (image - lowest) / value_range
    mean = image.mean()
    std = image.std()
    f = std / t_sd
    black = mean - t_mean * f
    white = mean + (1 - t_mean) * f
    # Guards the final division should the window collapse numerically.
    if not white > black:
        return np.zeros(image.shape, dtype=np.float32)

    new_image = np.clip(image, black, white)
    return (new_image - black) / (white - black)
def load_data(file_path) -> np.ndarray:
    """Load a micrograph from an ``.h5`` or ``.mrc`` file and standardise it.

    For HDF5 files the dataset ``"micrograph"`` is used when present,
    otherwise ``"data"``. The dataset attributes ``"mean"`` and ``"std"`` are
    used when available; otherwise both statistics are computed from the
    pixel data. MRC files always use statistics computed from the data.

    Args:
        file_path: Path to the file (``str`` or path-like). The extension is
            matched case-insensitively.

    Returns:
        float32 array ``(data - mean) / std``. When the standard deviation is
        zero the data is only mean-centred.

    Raises:
        ValueError: If the file extension is neither ``.h5`` nor ``.mrc``.
    """
    path = str(file_path)
    lowered = path.lower()
    if lowered.endswith('.h5'):
        with h5py.File(path, "r") as f:
            full_micrograph = f["micrograph"] if "micrograph" in f else f["data"]
            # Read the dataset once and reuse it for the fallback statistics.
            data = full_micrograph[:].astype(np.float32)
            attrs = full_micrograph.attrs
            full_mean = attrs["mean"] if "mean" in attrs else data.mean()
            full_std = attrs["std"] if "std" in attrs else data.std()
    elif lowered.endswith('.mrc'):
        with mrcfile.open(path, "r") as f:
            data = f.data[:].astype(np.float32)
        full_mean = data.mean()
        full_std = data.std()
    else:
        raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")

    # A constant micrograph has zero spread; dividing would give NaN/inf.
    if float(full_std) == 0.0:
        full_std = 1.0
    data = (data - full_mean) / full_std
    return data
def display_crop(data, x_offset, y_offset, auto_contrast) -> "Image.Image | None":
    """Render a crop of up to 1024 x 1024 pixels from ``data`` as an 8-bit image.

    Args:
        data: 2-D micrograph array, or ``None`` when nothing is loaded.
        x_offset: Column of the crop's left edge.
        y_offset: Row of the crop's top edge.
        auto_contrast: Use ``Auto_contrast`` instead of plain min-max scaling.

    Returns:
        The crop as a PIL image, or ``None`` when there is no data or the
        crop window lies entirely outside the array.
    """
    # Sliders can fire before any file has been loaded.
    if data is None:
        return None

    # Slice bounds must be integers even if the UI hands over floats.
    left, top = int(x_offset), int(y_offset)
    crop = data[top:top + 1024, left:left + 1024]
    if crop.size == 0:
        return None

    lowest = crop.min()
    value_range = crop.max() - lowest
    if not value_range > 0:
        # A flat crop has no contrast to stretch; show it as black rather
        # than dividing by zero.
        normalized = np.zeros(crop.shape, dtype=np.float32)
    elif auto_contrast:
        normalized = Auto_contrast(crop)
    else:
        normalized = (crop - lowest) / value_range
    return Image.fromarray((normalized * 255).astype(np.uint8))
def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> "Image.Image | None":
    """Denoise a crop of up to 1024 x 1024 pixels and render it as an 8-bit image.

    The crop is passed to the module-level ``denoiser`` as a PIL image. The
    result is scaled to ``[0, 1]`` with ``Auto_contrast`` or with plain
    min-max scaling, and then converted to ``uint8``.

    Args:
        data: 2-D micrograph array, or ``None`` when nothing is loaded.
        x_offset: Column of the crop's left edge.
        y_offset: Row of the crop's top edge.
        auto_contrast: Use ``Auto_contrast`` instead of plain min-max scaling.

    Returns:
        The denoised crop as a PIL image, or ``None`` when there is no data
        or the crop window lies entirely outside the array.
    """
    # The button can be pressed before any file has been loaded.
    if data is None:
        return None

    # Slice bounds must be integers even if the UI hands over floats.
    left, top = int(x_offset), int(y_offset)
    crop = data[top:top + 1024, left:left + 1024]
    if crop.size == 0:
        return None

    denoised_data = denoiser.inference(Image.fromarray(crop)).squeeze()
    lowest = denoised_data.min()
    value_range = denoised_data.max() - lowest
    if not value_range > 0:
        # A flat output has no contrast to stretch; avoid dividing by zero.
        normalized = np.zeros(denoised_data.shape, dtype=np.float32)
    elif auto_contrast:
        normalized = Auto_contrast(denoised_data)
    else:
        normalized = (denoised_data - lowest) / value_range
    return Image.fromarray((normalized * 255).astype(np.uint8))
def clear_images() -> tuple:
    """Return reset values for the five components wired to the clear events.

    The first three entries are ``None``; the last two are updates that set
    a slider maximum back to 512.
    """
    blank_original = blank_denoised = blank_state = None
    x_reset = gr.update(maximum=512)
    y_reset = gr.update(maximum=512)
    return blank_original, blank_denoised, blank_state, x_reset, y_reset
with gr.Blocks(css="""
.gradio-container {
background-color: #f7f9fc;
font-family: Arial, sans-serif;
}
.title-text {
text-align: center;
font-size: 30px;
font-weight: bold;
margin-bottom: 10px;
}
.description-text {
text-align: center;
font-size: 18px;
margin-bottom: 20px;
}
""") as demo:
    # Centered Title and Description
    with gr.Column():
        gr.Markdown(
            """
<div style="text-align: center; font-size: 30px; font-weight: bold; margin-bottom: 10px;">
Denoising Demo
</div>
<div style="text-align: center; font-size: 18px;">
Upload a Raw file or select an example to view the original and denoised images
</div>
"""
        )

    # Input controls.
    file_input = gr.File(label="Or upload a Micrograph File in .h5 or .mrc format")
    auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False)
    x_slider = gr.Slider(0, 512, step=10, label="X Offset")
    y_slider = gr.Slider(0, 512, step=10, label="Y Offset")

    with gr.Row():
        denoise_button = gr.Button("Denoise")
        clear_button = gr.Button("Clear")

    # Side-by-side views of the raw crop and the denoised crop.
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="pil", label="Original Image")
        with gr.Column():
            denoised_image = gr.Image(type="pil", label="Denoised Image")

    # Holds the loaded micrograph array between events.
    active_data = gr.State()

    def load_image_and_update_sliders(file_path) -> tuple:
        """Load a micrograph and fit the slider ranges to its size.

        Returns the array plus updates for the X and Y sliders so that a
        1024-pixel crop window stays inside the image. The maxima never drop
        below 0, which covers micrographs smaller than the window.
        """
        data = load_data(file_path)
        h, w = data.shape[:2]
        return (
            data,
            gr.update(maximum=max(0, w - 1024)),
            gr.update(maximum=max(0, h - 1024)),
        )

    def on_file_change(file) -> tuple:
        """Handle an upload or removal; always returns exactly three values."""
        if not file:
            # One value per output: active_data, x_slider, y_slider.
            return None, gr.update(maximum=512), gr.update(maximum=512)
        # The File component may deliver a plain path string or a
        # tempfile-like object exposing ``.name``; accept both.
        return load_image_and_update_sliders(getattr(file, "name", file))

    file_input.clear(
        clear_images,
        inputs=None,
        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
    )
    file_input.change(
        on_file_change,
        inputs=file_input,
        outputs=[active_data, x_slider, y_slider]
    )
    # Moving either slider re-renders the raw crop.
    x_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )
    y_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )
    denoise_button.click(
        process_and_denoise,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=denoised_image
    )
    clear_button.click(
        clear_images,
        inputs=None,
        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
    )

demo.launch()