Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# @Time : 2024/8/3 上午10:46 | |
# @Author : xiaoshun | |
# @Email : [email protected] | |
# @File : wandb_vis.py | |
# @Software: PyCharm | |
import argparse | |
import os | |
import shutil | |
from glob import glob | |
import albumentations as albu | |
import numpy as np | |
import torch | |
import wandb | |
from PIL import Image | |
from albumentations.pytorch.transforms import ToTensorV2 | |
from matplotlib import pyplot as plt | |
from rich.progress import track | |
from thop import profile | |
from src.data.components.hrcwhu import HRCWHU | |
from src.data.hrcwhu_datamodule import HRCWHUDataModule | |
from src.models.components.cdnetv1 import CDnetV1 | |
from src.models.components.cdnetv2 import CDnetV2 | |
from src.models.components.dbnet import DBNet | |
from src.models.components.hrcloudnet import HRCloudNet | |
from src.models.components.kappamask import KappaMask | |
from src.models.components.mcdnet import MCDNet | |
from src.models.components.scnn import SCNN | |
from src.models.components.unetmobv2 import UNetMobV2 | |
class WandbVis: | |
def __init__(self, model_name): | |
self.model_name = model_name | |
self.device = "cuda:1" if torch.cuda.is_available() else "cpu" | |
# self.device = "cpu" | |
self.colors = ((255, 255, 255), (128, 192, 128)) | |
self.num_classes = 2 | |
self.model = self.load_model() | |
self.dataloader = self.load_dataset() | |
self.macs, self.params = None, None | |
wandb.init(project='model_vis', name=self.model_name) | |
def load_weight(self, weight_path: str): | |
weight = torch.load(weight_path, map_location=self.device) | |
state_dict = {} | |
for key, value in weight["state_dict"].items(): | |
new_key = key[4:] | |
state_dict[new_key] = value | |
return state_dict | |
def load_model_by_model_name(self): | |
if self.model_name == 'dbnet': | |
return DBNet(img_size=256, in_channels=3, num_classes=2).to(self.device) | |
if self.model_name == "cdnetv1": | |
return CDnetV1(num_classes=2).to(self.device) | |
if self.model_name == "cdnetv2": | |
return CDnetV2(num_classes=2).to(self.device) | |
if self.model_name == "hrcloud": | |
return HRCloudNet(num_classes=2).to(self.device) | |
if self.model_name == "mcdnet": | |
return MCDNet(in_channels=3, num_classes=2).to(self.device) | |
if self.model_name == "scnn": | |
return SCNN(num_classes=2).to(self.device) | |
if self.model_name == "unetmobv2": | |
return UNetMobV2(num_classes=2).to(self.device) | |
if self.model_name == "kappamask": | |
return KappaMask(num_classes=2, in_channels=3).to(self.device) | |
raise ValueError(f"{self.model_name}模型不存在") | |
def load_model(self): | |
weight_path = glob(f"logs/train/runs/hrcwhu_{self.model_name}/*/checkpoints/*.ckpt")[0] | |
model = self.load_model_by_model_name() | |
state_dict = self.load_weight(weight_path) | |
model.load_state_dict(state_dict) | |
model.eval() | |
return model | |
def load_dataset(self): | |
all_transform = albu.Compose( | |
[ | |
albu.Resize( | |
height=HRCWHU.METAINFO["img_size"][1], | |
width=HRCWHU.METAINFO["img_size"][2], | |
always_apply=True | |
) | |
] | |
) | |
img_transform = albu.Compose([ | |
albu.ToFloat(), | |
ToTensorV2() | |
]) | |
ann_transform = None | |
val_pipeline = dict( | |
all_transform=all_transform, | |
img_transform=img_transform, | |
ann_transform=ann_transform, | |
) | |
dataloader = HRCWHUDataModule( | |
root="/home/liujie/liumin/cloudseg/data/hrcwhu", | |
train_pipeline=val_pipeline, | |
val_pipeline=val_pipeline, | |
test_pipeline=val_pipeline, | |
batch_size=1, | |
) | |
dataloader.setup() | |
test_dataloader = dataloader.test_dataloader() | |
return test_dataloader | |
# for data in test_dataloader: | |
# print(data['img'].shape,data['ann'].shape,data['img_path'],data['ann_path'],data['lac_type']) | |
# break | |
# torch.Size([1, 3, 256, 256]) | |
# torch.Size([1, 256, 256]) | |
# ['/home/liujie/liumin/cloudseg/data/hrcwhu/img_dir/test/barren_30.tif'] | |
# ['/home/liujie/liumin/cloudseg/data/hrcwhu/ann_dir/test/barren_30.tif'] | |
# ['barren'] | |
def give_colors_to_mask(self, mask: np.ndarray): | |
""" | |
赋予mask颜色 | |
""" | |
assert len(mask.shape) == 2, "Value Error,mask的形状为(height,width)" | |
colors_mask = np.zeros((mask.shape[0], mask.shape[1], 3)).astype(np.float32) | |
for color in range(2): | |
segc = (mask == color) | |
colors_mask[:, :, 0] += segc * (self.colors[color][0]) | |
colors_mask[:, :, 1] += segc * (self.colors[color][1]) | |
colors_mask[:, :, 2] += segc * (self.colors[color][2]) | |
return colors_mask | |
def pred_mask(self, x: torch.Tensor): | |
x = x.to(self.device) | |
self.macs, self.params = profile(self.model, inputs=(x,), verbose=False) | |
logits = self.model(x) | |
if isinstance(logits, tuple): | |
logits = logits[0] | |
fake_mask = torch.argmax(logits, 1).detach().cpu().squeeze(0).numpy() | |
return fake_mask | |
def np_pil_np(self, image: np.ndarray, filename="colors_ann"): | |
colors_np = self.give_colors_to_mask(image) | |
pil_np = Image.fromarray(np.uint8(colors_np)) | |
return np.array(pil_np) | |
def run(self, delete_wadb_log=True): | |
# print(len(self.dataloader)) | |
# 30 | |
for data in track(self.dataloader): | |
img = data["img"] | |
ann = data["ann"].squeeze(0).numpy() | |
img_path = data["img_path"] | |
fake_mask = self.pred_mask(img) | |
colors_ann = self.np_pil_np(ann) | |
colors_fake = self.np_pil_np(fake_mask, "colors_fake") | |
image_name = img_path[0].split(os.path.sep)[-1].split(".")[0] | |
plt.subplot(1, 2, 1) | |
plt.axis("off") | |
plt.title("groud true") | |
plt.imshow(colors_ann) | |
plt.subplot(1, 2, 2) | |
plt.axis("off") | |
plt.title("predict mask") | |
plt.imshow(colors_fake) | |
wandb.log({image_name: wandb.Image(plt)}) | |
wandb.log({"macs": self.macs, "params": self.params}) | |
wandb.finish() | |
if delete_wadb_log and os.path.exists("wandb"): | |
shutil.rmtree("wandb") | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--model-name", type=str, default="dbnet") | |
parser.add_argument("--delete-wadb-log", type=bool, default=True) | |
args = parser.parse_args() | |
vis = WandbVis(model_name=args.model_name) | |
vis.run(delete_wadb_log=args.delete_wadb_log) | |