# CELL-E_2-Image_Prediction / celle_main.py
import os
import numpy as np
import torch
import torch.random
from torch.optim import AdamW
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from dataloader import CellLoader
from celle import VQGanVAE, CELLE
from omegaconf import OmegaConf
import argparse, sys, datetime, glob
from celle.celle import gumbel_sample, top_k
torch.random.manual_seed(42)
np.random.seed(42)
from celle_taming_main import (
instantiate_from_config,
nondefault_trainer_args,
get_parser,
)
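# Note: instantiate_from_config and nondefault_trainer_args follow the helpers in
# taming-transformers' main.py: a config node with a `target` import path and an
# optional `params` dict is turned into an object; nondefault_trainer_args returns
# the CLI trainer flags that differ from the parser defaults.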
class CellDataModule(pl.LightningDataModule):
def __init__(
self,
data_csv,
dataset,
sequence_mode="standard",
vocab="bert",
crop_size=256,
resize=600,
batch_size=1,
threshold="median",
text_seq_len=1000,
num_workers=1,
**kwargs,
):
super().__init__()
self.data_csv = data_csv
self.dataset = dataset
self.protein_sequence_length = 0
self.image_folders = []
self.crop_size = crop_size
self.resize = resize
self.batch_size = batch_size
self.sequence_mode = sequence_mode
self.threshold = threshold
self.text_seq_len = int(text_seq_len)
self.vocab = vocab
self.num_workers = num_workers if num_workers is not None else batch_size * 2
def setup(self, stage=None):
# called on every GPU
self.cell_dataset_train = CellLoader(
data_csv=self.data_csv,
dataset=self.dataset,
crop_size=self.crop_size,
resize=self.resize,
split_key="train",
crop_method="random",
sequence_mode=self.sequence_mode,
vocab=self.vocab,
text_seq_len=self.text_seq_len,
threshold=self.threshold,
)
self.cell_dataset_val = CellLoader(
data_csv=self.data_csv,
dataset=self.dataset,
crop_size=self.crop_size,
resize=self.resize,
crop_method="center",
split_key="val",
sequence_mode=self.sequence_mode,
vocab=self.vocab,
text_seq_len=self.text_seq_len,
threshold=self.threshold,
)
def prepare_data(self):
pass
def train_dataloader(self):
return DataLoader(
self.cell_dataset_train,
num_workers=self.num_workers,
shuffle=True,
batch_size=self.batch_size,
)
def val_dataloader(self):
return DataLoader(
self.cell_dataset_val,
num_workers=self.num_workers,
batch_size=self.batch_size,
)
# def test_dataloader(self):
# transforms = ...
# return DataLoader(self.test, batch_size=64)
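# A minimal usage sketch for the data module (illustrative only; the csv path and
# dataset name below are placeholders, not shipped defaults):
#
#   dm = CellDataModule(data_csv="data/example.csv", dataset="HPA", batch_size=4)
#   dm.prepare_data()
#   dm.setup()
#   train_loader = dm.train_dataloader()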
class CELLE_trainer(pl.LightningModule):
def __init__(
self,
vqgan_model_path,
vqgan_config_path,
ckpt_path=None,
image_key="threshold",
condition_model_path=None,
condition_config_path=None,
num_images=2,
dim=2,
num_text_tokens=30,
text_seq_len=1000,
depth=16,
heads=16,
dim_head=64,
attn_dropout=0.1,
ff_dropout=0.1,
attn_types="full",
loss_img_weight=7,
stable=False,
rotary_emb=True,
text_embedding="bert",
fixed_embedding=True,
loss_cond_weight=1,
learning_rate=3e-4,
monitor="val_loss",
):
super().__init__()
vae = VQGanVAE(
vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path
)
self.image_key = image_key
if condition_config_path:
condition_vae = VQGanVAE(
vqgan_model_path=condition_model_path,
vqgan_config_path=condition_config_path,
)
else:
condition_vae = None
self.celle = CELLE(
dim=dim,
vae=vae, # automatically infer (1) image sequence length and (2) number of image tokens
condition_vae=condition_vae,
num_images=num_images,
num_text_tokens=num_text_tokens, # vocab size for text
text_seq_len=text_seq_len, # text sequence length
depth=depth, # should aim to be 64
heads=heads, # attention heads
dim_head=dim_head, # attention head dimension
attn_dropout=attn_dropout, # attention dropout
ff_dropout=ff_dropout, # feedforward dropout
loss_img_weight=loss_img_weight,
stable=stable,
rotary_emb=rotary_emb,
text_embedding=text_embedding,
fixed_embedding=fixed_embedding,
loss_cond_weight=loss_cond_weight,
)
self.learning_rate = learning_rate
self.num_text_tokens = num_text_tokens
self.num_images = num_images
if monitor is not None:
self.monitor = monitor
ignore_keys = []
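# the VAE / condition VAE weights are restored from their own checkpoints above,
# so skip them when loading the CELLE checkpoint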
if condition_model_path:
ignore_keys.append("celle.condition_vae")
if vqgan_model_path:
ignore_keys.append("celle.vae")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
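# load a Lightning checkpoint non-strictly, dropping any state_dict keys that
# start with an entry in ignore_keys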
sd = torch.load(path, map_location="cpu")["state_dict"]
ckpt = sd.copy()
for k in sd.keys():
for ik in ignore_keys:
if k.startswith(ik):
# print("Deleting key {} from state_dict.".format(k))
del ckpt[k]
self.load_state_dict(ckpt, strict=False)
print(f"Restored from {path}")
def forward(self, text, condition, target, return_loss=True):
return self.celle(
text=text, condition=condition, image=target, return_loss=return_loss
)
def get_input(self, batch):
text = batch["sequence"].squeeze(1)
condition = batch["nucleus"]
target = batch[self.image_key]
return text, condition, target
def get_image_from_logits(self, logits, temperature=0.9):
filtered_logits = top_k(logits, thres=0.5)
sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
self.celle.vae.eval()
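# keep only the image-token positions, then shift those ids back into the VAE
# codebook range (they are offset by the text and condition vocab sizes)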
out = self.celle.vae.decode(
sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :]
- (self.celle.num_text_tokens + self.celle.num_condition_tokens)
)
return out
def get_loss(self, text, condition, target):
loss_dict = {}
loss, loss_dict, logits = self(text, condition, target, return_loss=True)
return loss, loss_dict
def total_loss(
self,
loss,
loss_dict,
mode="train",
):
loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()}
for key, value in loss_dict.items():
self.log(
key,
value,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=True,
)
return loss
def training_step(self, batch, batch_idx):
text, condition, target = self.get_input(batch)
loss, log_dict = self.get_loss(text, condition, target)
loss = self.total_loss(loss, log_dict, mode="train")
return loss
def validation_step(self, batch, batch_idx):
with torch.no_grad():
text, condition, target = self.get_input(batch)
loss, log_dict = self.get_loss(text, condition, target)
loss = self.total_loss(loss, log_dict, mode="val")
return loss
def configure_optimizers(self):
optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
return optimizer
def scale_image(self, image):
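# min-max scale each tensor in the batch to [0, 1] in place so images log cleanly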
for tensor in image:
# shifting by the minimum handles negative- and positive-valued tensors alike
tensor -= torch.min(tensor)
tensor /= torch.max(tensor)
return image
@torch.no_grad()
def log_images(self, batch, **kwargs):
log = {}
text, condition, target = self.get_input(batch)
text = text.squeeze(1).to(self.device)
condition = condition.to(self.device)
out = self.celle.generate_images(text=text, condition=condition)
log["condition"] = self.scale_image(condition)
log["output"] = self.scale_image(out)
if self.image_key == "threshold":
log["threshold"] = self.scale_image(target)
log["target"] = self.scale_image(batch["target"])
else:
log["target"] = self.scale_image(target)
return log
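# A minimal construction sketch (illustrative; every path below is a placeholder):
#
#   model = CELLE_trainer(
#       vqgan_model_path="checkpoints/vqgan.ckpt",
#       vqgan_config_path="configs/vqgan.yaml",
#       depth=16,
#       heads=16,
#   )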
# from https://github.com/CompVis/taming-transformers/blob/master/main.py
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
# `nested.key=value` arguments are interpreted as config parameters.
# configs are merged from left-to-right followed by command line parameters.
# model:
# learning_rate: float
# target: path to lightning module
# params:
# key: value
# data:
# target: celle_main.DataModuleFromConfig
# params:
# batch_size: int
# wrap: bool
# train:
# target: path to train dataset
# params:
# key: value
# validation:
# target: path to validation dataset
# params:
# key: value
# test:
# target: path to test dataset
# params:
# key: value
# lightning: (optional, has sane defaults and can be specified on cmdline)
# trainer:
# additional arguments to trainer
# logger:
# logger to instantiate
# modelcheckpoint:
# modelcheckpoint to instantiate
# callbacks:
# callback1:
# target: importpath
# params:
# key: value
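# An illustrative config matching the layout above (all values and paths are
# placeholders, not shipped defaults):
#
#   model:
#     learning_rate: 3.0e-4
#     target: celle_main.CELLE_trainer
#     params:
#       vqgan_model_path: checkpoints/vqgan.ckpt
#       vqgan_config_path: configs/vqgan.yaml
#   data:
#     target: celle_main.CellDataModule
#     params:
#       data_csv: data/example.csv
#       dataset: HPA
#       batch_size: 2
#   lightning:
#     trainer:
#       gpus: "0,"
#       max_epochs: 10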
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python celle_main.py`
# (in particular `celle_main.DataModuleFromConfig`)
sys.path.append(os.getcwd())
parser = get_parser()
parser = Trainer.add_argparse_args(parser)
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
"-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
idx = len(paths) - paths[::-1].index("logs") + 1
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
nowname = _tmp[_tmp.index("logs") + 1]
else:
if opt.name:
name = "_" + opt.name
elif opt.base:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
else:
name = ""
nowname = now + name + opt.postfix
logdir = os.path.join("logs", nowname)
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to ddp
# trainer_config["distributed_backend"] = "ddp"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if not "gpus" in trainer_config:
del trainer_config["distributed_backend"]
cpu = True
else:
gpuinfo = trainer_config["gpus"]
print(f"Running on GPUs {gpuinfo}")
cpu = False
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
# model
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
# default logger configs
# NOTE wandb < 0.10.0 interferes with shutdown
# wandb >= 0.10.0 seems to fix it but still interferes with pudb
# debugging (wrongly sized pudb ui)
# thus prefer testtube for now
default_logger_cfgs = {
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
},
},
"testtube": {
# "target": "pytorch_lightning.loggers.TestTubeLogger",
"target": "pytorch_lightning.loggers.TensorBoardLogger",
"params": {
"name": "testtube",
"save_dir": logdir,
},
},
}
default_logger_cfg = default_logger_cfgs["testtube"]
logger_cfg = lightning_config.get("logger", OmegaConf.create())
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# modelcheckpoint - the model's `monitor` attribute (set below) determines which
# metric is used to pick the best checkpoints
default_modelckpt_cfg = {
"checkpoint_callback": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
},
}
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["checkpoint_callback"]["params"][
"monitor"
] = model.monitor
default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
modelckpt_cfg = lightning_config.get("modelcheckpoint", OmegaConf.create())
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
# trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "celle_taming_main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
},
},
# "image_logger": {
# "target": "celle_taming_main.ImageLogger",
# "params": {
# "batch_frequency": 0,
# "max_images": 0,
# "clamp": False,
# "increase_log_steps": False,
# },
# },
# "learning_rate_logger": {
# "target": "celle_taming_main.LearningRateMonitor",
# "params": {
# "logging_interval": "step",
# # "log_momentum": True
# },
# },
}
callbacks_cfg = lightning_config.get("callbacks", OmegaConf.create())
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
trainer = Trainer.from_argparse_args(
trainer_opt, **trainer_kwargs, profiler="simple"
)
# data
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
# configure learning rate
bs, lr = config.data.params.batch_size, config.model.learning_rate
if not cpu:
ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
else:
ngpu = 1
accumulate_grad_batches = lightning_config.trainer.get("accumulate_grad_batches", 1)
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
model.learning_rate = accumulate_grad_batches * ngpu * bs * lr
print(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format(
model.learning_rate, accumulate_grad_batches, ngpu, bs, lr
)
)
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
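# e.g. from another shell: `kill -USR1 <pid>` saves a checkpoint on rank 0 and
# `kill -USR2 <pid>` drops rank 0 into pudb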
# run
if opt.train:
try:
# model = torch.compile(model, mode="reduce_overhead")
trainer.fit(model, data)
except Exception:
melk()
raise
if not opt.no_test and not trainer.interrupted:
trainer.test(model, data)
except Exception:
if opt.debug and trainer.global_rank == 0:
try:
import pudb as debugger
except ImportError:
import pdb as debugger
debugger.post_mortem()
raise
finally:
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank == 0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)