import argparse, os, sys, datetime
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from functools import partial
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from ldm.util import instantiate_from_config

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
def get_parser(**parser_kwargs):
    # Accept common string spellings of booleans, so flags can be passed as e.g. -t true
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    # default
    parser.add_argument("-n", "--name", type=str, nargs="?", default="")
    parser.add_argument("-r", "--resume", type=str, nargs="?", default="")
    parser.add_argument("-t", "--train", type=str2bool, nargs="?", default=True)
    parser.add_argument("-s", "--seed", type=int, nargs="?", default=3407)
    parser.add_argument("-f", "--postfix", type=str, nargs="?", default="")
    parser.add_argument("--train_from_scratch", type=str2bool, nargs="?", default=False)
    parser.add_argument("-d", "--debug", type=str2bool, nargs="?", default=False)
    # train.sh
    parser.add_argument("-b", "--base", type=str, nargs="?", default="configs/train_vitonhd.yaml")
    parser.add_argument("-l", "--logdir", type=str, nargs="?", default="logs")
    parser.add_argument("-p", "--pretrained_model", type=str, nargs="?", default="checkpoints/pbe_dim6.ckpt")
    return parser
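# Example invocation using the defaults above (paths and flag values are illustrative):
#
#   python train.py -b configs/train_vitonhd.yaml -l logs -s 3407 \
#       -p checkpoints/pbe_dim6.ckpt -t true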
class ImageLogger(Callback):
    def __init__(self, batch_frequency=2000, log_steps=[1]):
        super().__init__()
        self.batch_freq = batch_frequency
        self.log_steps = log_steps

    # At the end of each training batch, decide whether images need to be saved.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="train")

    # Sample images from the model and save them
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = pl_module.global_step
        if (self.check_frequency(check_idx)
                and hasattr(pl_module, "sample_log")
                and callable(pl_module.sample_log)):
            is_train = pl_module.training
            if is_train:
                pl_module.eval()  # switch to eval mode for sampling
            with torch.no_grad():
                images = pl_module.sample_log(batch)
            for k in images:
                N = images[k].shape[0]
                images[k] = images[k][:N]  # a no-op as written; kept for parity with loggers that cap N
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    images[k] = torch.clamp(images[k], -1., 1.)
            self.log_local(pl_module.logger.save_dir,
                           split,
                           images,
                           pl_module.global_step,
                           pl_module.current_epoch,
                           batch_idx)
            if is_train:
                pl_module.train()  # restore training mode

    # True when check_idx is a positive multiple of self.batch_freq,
    # or when check_idx appears in self.log_steps
    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and check_idx > 0:
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
            return True
        return False
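    # Example: with batch_frequency=2000 and log_steps=[1], images are written
    # at global steps 1, 2000, 4000, 6000, ...: once very early as a sanity
    # check, then every 2000 steps.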
    # Save images to a local folder
    def log_local(self,
                  save_dir,
                  split,
                  images,
                  global_step,
                  current_epoch,
                  batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            grid = (grid + 1.0) / 2.0  # map values from [-1, 1] to [0, 1]
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)  # CHW -> HWC
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)
class CUDACallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset peak-memory stats and start the epoch timer
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        epoch = trainer.current_epoch
        # Save a checkpoint every 5 epochs, skipping epoch 0
        if epoch % 5 == 0 and epoch != 0:
            trainer.save_checkpoint(f'epoch={epoch}.ckpt')
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20  # bytes -> MiB
        epoch_time = time.time() - self.start_time
        try:
            # Average time and peak memory across DDP ranks; log from rank 0 only
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)
            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f} MiB")
        except AttributeError:
            pass
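# Note: trainer.root_gpu and trainer.training_type_plugin are PyTorch Lightning
# 1.x APIs (training_type_plugin was replaced by strategies in later releases),
# so this callback assumes the PL version the repository pins.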
class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self,
                 batch_size,                # per-GPU batch size, N
                 train=None,                # config dict for the training dataset
                 wrap=False,
                 use_worker_init_fn=False):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        self.wrap = wrap
        self.dataset_configs = dict()
        self.train = train
        self.train_dataloader = self._train_dataloader

    def _train_dataloader(self):
        return DataLoader(instantiate_from_config(self.train),
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True,
                          worker_init_fn=None)
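# A sketch of the `data` section this module expects in the YAML config
# (the dataset target and its params below are hypothetical):
#
#   data:
#     target: train.DataModuleFromConfig
#     params:
#       batch_size: 4
#       train:
#         target: ldm.data.some_dataset.SomeTrainDataset   # hypothetical
#         params:
#           state: train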
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    # =============================================================
    # Get parser and generate opt
    # =============================================================
    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)
    opt, unknown = parser.parse_known_args()

    # =============================================================
    # Generate logdir path
    # =============================================================
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")  # e.g. 2023-05-18T04-27-24
    cfg_fname = os.path.split(opt.base)[-1]        # train_vitonhd.yaml
    cfg_name = os.path.splitext(cfg_fname)[0]      # train_vitonhd
    nowname = now + "_" + cfg_name                 # 2023-05-18T04-27-24_train_vitonhd
    logdir = os.path.join(opt.logdir, nowname)     # logs/2023-05-18T04-27-24_train_vitonhd
    ckptdir = os.path.join(logdir, "checkpoints")  # logs/.../checkpoints
    cfgdir = os.path.join(logdir, "configs")       # logs/.../configs

    # =============================================================
    # Set seed
    # =============================================================
    seed_everything(opt.seed)

    # =============================================================
    # Initialize config
    # =============================================================
    config = OmegaConf.load(opt.base)                                      # load the YAML file into a DictConfig
    lightning_config = config.pop("lightning", OmegaConf.create())         # pop the "lightning" section out of config
    trainer_config = lightning_config.get("trainer", OmegaConf.create())   # extract "trainer" from lightning_config
    trainer_opt = argparse.Namespace(**trainer_config)  # e.g. Namespace(accelerator='ddp', gpus='0,1', max_epochs=200, num_nodes=1)

    # =============================================================
    # Load model and initialize it
    # =============================================================
    # Use config.model["params"] to initialize config.model["target"]
    model = instantiate_from_config(config.model)
    # Load pre-trained model weights (strict=False tolerates missing or extra keys)
    model.load_state_dict(torch.load(opt.pretrained_model, map_location='cpu'), strict=False)
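    # For reference, ldm.util.instantiate_from_config resolves the dotted path
    # in cfg["target"] to a class and constructs it with cfg["params"]. A
    # minimal sketch of that behavior (not the imported implementation itself):
    #
    #   import importlib
    #
    #   def instantiate_from_config(cfg):
    #       module_name, cls_name = cfg["target"].rsplit(".", 1)
    #       cls = getattr(importlib.import_module(module_name), cls_name)
    #       return cls(**cfg.get("params", dict()))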
    # =============================================================
    # Set trainer_kwargs
    # =============================================================
    trainer_kwargs = dict()
    # Gradient accumulation: the effective batch size is 8 * batch_size * num_gpus
    trainer_kwargs["accumulate_grad_batches"] = 8
    # Log the training process to logdir/testtube
    default_logger_cfg = {
        "target": "pytorch_lightning.loggers.TestTubeLogger",
        "params": {"name": "testtube", "save_dir": logdir}
    }
    logger_cfg = OmegaConf.create(default_logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
    # Callbacks
    default_callbacks_cfg = {
        # Save images to a local folder during training
        "image_logger": {
            "target": "train.ImageLogger",
            "params": {"batch_frequency": 2000, "log_steps": [1]}
        },
        # Automatically record the learning rate during training
        "learning_rate_logger": {
            "target": "train.LearningRateMonitor",
            "params": {"logging_interval": "step"}
        },
        # Timing and memory hooks (on_train_epoch_start / on_train_epoch_end)
        "cuda_callback": {
            "target": "train.CUDACallback"
        },
    }
    callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
    # =============================================================
    # Initialize trainer with trainer_kwargs
    # =============================================================
    trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
    trainer.logdir = logdir

    # =============================================================
    # Load dataset
    # =============================================================
    # Use config.data["params"] to initialize config.data["target"]
    data = instantiate_from_config(config.data)
    train_loader = data.train_dataloader()  # build the loader once, then report on it
    print(f"train, {train_loader.__class__.__name__}, {len(train_loader)}")

    # =============================================================
    # Set learning_rate
    # =============================================================
    model.learning_rate = config.model.base_learning_rate

    # =============================================================
    # Training
    # =============================================================
    trainer.fit(model, data)