fbnnb commited on
Commit
562b42f
Β·
verified Β·
1 Parent(s): 8417181

Update lvdm/common.py

Browse files
Files changed (1) hide show
  1. lvdm/common.py +5 -5
lvdm/common.py CHANGED
@@ -77,7 +77,7 @@ def init_(tensor):
77
  tensor.uniform_(-std, std)
78
  return tensor
79
 
80
- ckpt = torch.utils.checkpoint.checkpoint
81
  def checkpoint(func, inputs, params, flag):
82
  """
83
  Evaluate a function without caching intermediate activations, allowing for
@@ -88,7 +88,7 @@ def checkpoint(func, inputs, params, flag):
88
  explicitly take as arguments.
89
  :param flag: if False, disable gradient checkpointing.
90
  """
91
- if flag:
92
- return ckpt(func, *inputs, use_reentrant=False)
93
- else:
94
- return func(*inputs)
 
77
  tensor.uniform_(-std, std)
78
  return tensor
79
 
80
+ # ckpt = torch.utils.checkpoint.checkpoint
81
  def checkpoint(func, inputs, params, flag):
82
  """
83
  Evaluate a function without caching intermediate activations, allowing for
 
88
  explicitly take as arguments.
89
  :param flag: if False, disable gradient checkpointing.
90
  """
91
+ # if flag:
92
+ # return ckpt(func, *inputs, use_reentrant=False)
93
+ # else:
94
+ return func(*inputs)