Spaces:
Sleeping
Sleeping
from typing import Union, List | |
import torch | |
def is_differentiable(
    loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
) -> None:
    """
    Overview:
        Judge whether the model/models are differentiable with respect to ``loss``. First check that \
        every parameter's grad is ``None``, then run ``loss.backward()``, and finally check that every \
        parameter's grad is a ``torch.Tensor``.
    Arguments:
        - loss (:obj:`torch.Tensor`): loss tensor of the model
        - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): model or models to be checked
        - print_instead (:obj:`bool`): Whether to print the name and grad of each parameter that did \
            not receive a gradient, instead of raising. Default set to ``False``.
    Raises:
        - AssertionError: ``loss`` is not a ``torch.Tensor``; an element of the ``model`` list is not a \
            ``torch.nn.Module``; a parameter already has a grad before the backward pass; or \
            (when ``print_instead`` is ``False``) a parameter has no tensor grad after the backward pass. \
            For the per-parameter failures the exception message is the parameter name.
        - TypeError: ``model`` is neither a list nor a ``torch.nn.Module``.

    .. note::
        This function calls ``loss.backward()``, so the parameters' ``grad`` attributes are populated \
        as a side effect.
    """
    # Explicit raises (rather than ``assert`` statements) keep the checks alive under ``python -O``
    # while preserving the AssertionError type that callers see.
    if not isinstance(loss, torch.Tensor):
        raise AssertionError(f"loss must be a torch.Tensor, got {type(loss).__name__}")

    # Normalise to a list of modules so each check below is written exactly once.
    if isinstance(model, list):
        modules = model
    elif isinstance(model, torch.nn.Module):
        modules = [model]
    else:
        raise TypeError('model must be list or nn.Module')

    for m in modules:
        if not isinstance(m, torch.nn.Module):
            raise AssertionError(f"every element of model must be a torch.nn.Module, got {type(m).__name__}")

    # Pre-condition: no parameter may already hold a gradient, otherwise the post-check
    # could not tell whether the grad came from this backward pass.
    for m in modules:
        for k, p in m.named_parameters():
            if p.grad is not None:
                raise AssertionError(k)

    loss.backward()

    # Post-condition: every parameter must now hold a tensor gradient.
    for m in modules:
        for k, p in m.named_parameters():
            if isinstance(p.grad, torch.Tensor):
                continue
            if print_instead:
                print(k, "grad is:", p.grad)
            else:
                raise AssertionError(k)