rlawjdghek commited on
Commit
929fb7f
β€’
1 Parent(s): c72201c
ldm/models/autoencoder.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- import pytorch_lightning as pl
3
  import torch.nn.functional as F
4
  from contextlib import contextmanager
5
 
@@ -10,7 +10,7 @@ from ldm.util import instantiate_from_config
10
  from ldm.modules.ema import LitEma
11
 
12
 
13
- class AutoencoderKL(pl.LightningModule):
14
  def __init__(self,
15
  ddconfig,
16
  lossconfig,
 
1
  import torch
2
+ # import pytorch_lightning as pl
3
  import torch.nn.functional as F
4
  from contextlib import contextmanager
5
 
 
10
  from ldm.modules.ema import LitEma
11
 
12
 
13
+ class AutoencoderKL(nn.Module):
14
  def __init__(self,
15
  ddconfig,
16
  lossconfig,
ldm/models/diffusion/ddpm.py CHANGED
@@ -9,7 +9,7 @@ https://github.com/CompVis/taming-transformers
9
  import torch
10
  import torch.nn as nn
11
  import numpy as np
12
- import pytorch_lightning as pl
13
  from torch.optim.lr_scheduler import LambdaLR
14
  from einops import rearrange, repeat
15
  from contextlib import contextmanager, nullcontext
@@ -17,7 +17,7 @@ from functools import partial
17
  import itertools
18
  from tqdm import tqdm
19
  from torchvision.utils import make_grid
20
- from pytorch_lightning.utilities.distributed import rank_zero_only
21
  from omegaconf import ListConfig
22
  from torchvision.transforms.functional import resize
23
  import torchvision.transforms as T
@@ -47,7 +47,7 @@ def disabled_train(self, mode=True):
47
  def uniform_on_device(r1, r2, shape, device):
48
  return (r1 - r2) * torch.rand(*shape, device=device) + r2
49
 
50
- class DDPM(pl.LightningModule):
51
  # classic DDPM with Gaussian diffusion, in image space
52
  def __init__(self,
53
  unet_config,
@@ -614,7 +614,7 @@ class LatentDiffusion(DDPM):
614
  ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
615
  self.cond_ids[:self.num_timesteps_cond] = ids
616
 
617
- @rank_zero_only
618
  @torch.no_grad()
619
  def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
620
  # only for very first batch
@@ -1387,7 +1387,7 @@ class LatentDiffusion(DDPM):
1387
  return x
1388
 
1389
 
1390
- class DiffusionWrapper(pl.LightningModule):
1391
  def __init__(self, diff_model_config, conditioning_key):
1392
  super().__init__()
1393
  self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
 
9
  import torch
10
  import torch.nn as nn
11
  import numpy as np
12
+ # import pytorch_lightning as pl
13
  from torch.optim.lr_scheduler import LambdaLR
14
  from einops import rearrange, repeat
15
  from contextlib import contextmanager, nullcontext
 
17
  import itertools
18
  from tqdm import tqdm
19
  from torchvision.utils import make_grid
20
+ # from pytorch_lightning.utilities.distributed import rank_zero_only
21
  from omegaconf import ListConfig
22
  from torchvision.transforms.functional import resize
23
  import torchvision.transforms as T
 
47
  def uniform_on_device(r1, r2, shape, device):
48
  return (r1 - r2) * torch.rand(*shape, device=device) + r2
49
 
50
+ class DDPM(nn.Module):
51
  # classic DDPM with Gaussian diffusion, in image space
52
  def __init__(self,
53
  unet_config,
 
614
  ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
615
  self.cond_ids[:self.num_timesteps_cond] = ids
616
 
617
+ # @rank_zero_only
618
  @torch.no_grad()
619
  def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
620
  # only for very first batch
 
1387
  return x
1388
 
1389
 
1390
+ class DiffusionWrapper(nn.Module):
1391
  def __init__(self, diff_model_config, conditioning_key):
1392
  super().__init__()
1393
  self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
requirements.txt CHANGED
@@ -20,5 +20,4 @@ cloudpickle
20
  fvcore
21
  omegaconf==2.1
22
  hydra-core
23
- pycocotools
24
- pytorch-lightning==1.5.0
 
20
  fvcore
21
  omegaconf==2.1
22
  hydra-core
23
+ pycocotools