Spaces:
Sleeping
Sleeping
import torch | |
import pytest | |
from torch import nn | |
from lzero.model.stochastic_muzero_model import ChanceEncoder | |
# Initialize a ChanceEncoder instance for testing
@pytest.fixture
def encoder():
    """Provide a fresh ``ChanceEncoder`` to each test that requests ``encoder``.

    The encoder is built with the positional arguments ``(3, 32, 32)`` and
    ``4``.
    NOTE(review): these are presumably the per-frame observation shape and the
    size of the chance space -- confirm against ``ChanceEncoder.__init__``.

    Returns:
        A newly constructed ``ChanceEncoder`` instance.
    """
    # Without the fixture decorator pytest cannot resolve the ``encoder``
    # argument of the test functions and reports "fixture 'encoder' not found".
    return ChanceEncoder((3, 32, 32), 4)
def test_ChanceEncoder(encoder):
    """Check the output shapes of the encoder and that its second output is one-hot.

    A random tensor of shape ``(1, 6, 32, 32)`` is passed through ``encoder``;
    both returned tensors must have shape ``(1, 4)``, and the second one must
    contain only zeros and ones with every row summing to exactly one.
    """
    # Random stand-in input for a single sample.
    stacked_obs = torch.randn(1, 6, 32, 32)

    # Run the encoder once and split its two outputs.
    encoding, onehot = encoder(stacked_obs)

    # Both outputs are expected to be (batch=1, 4).
    expected_shape = (1, 4)
    assert encoding.shape == expected_shape
    assert onehot.shape == expected_shape

    # Every entry must be exactly 0 or exactly 1 ...
    is_binary = (onehot == 0) | (onehot == 1)
    assert is_binary.all()
    # ... and each row must contain a single 1.
    row_totals = onehot.sum(dim=1)
    assert (row_totals == 1).all()
def test_ChanceEncoder_gradients_chance_encoding(encoder):
    """Check that a loss on the first encoder output reaches every parameter.

    An MSE loss between the first output and a random ``(1, 4)`` target is
    back-propagated; afterwards every parameter of ``encoder`` must hold a
    gradient, and that gradient must have the parameter's own shape.
    """
    # Random stand-in input for a single sample.
    stacked_obs = torch.randn(1, 6, 32, 32)

    # Only the first output takes part in this loss.
    encoding, _ = encoder(stacked_obs)

    # Arbitrary regression target; the test only cares about gradient flow.
    dummy_target = torch.randn(1, 4)
    criterion = nn.MSELoss()
    criterion(encoding, dummy_target).backward()

    params = list(encoder.parameters())
    # First make sure every parameter received a gradient at all ...
    assert all(p.grad is not None for p in params)
    # ... then that each gradient matches its parameter's shape.
    assert all(p.grad.shape == p.shape for p in params)
def test_ChanceEncoder_gradients_chance_onehot_t(encoder):
    """Check that a loss on the second (one-hot) output reaches every parameter.

    An MSE loss between the second output and a random ``(1, 4)`` target is
    back-propagated; afterwards every parameter of ``encoder`` must hold a
    gradient, and that gradient must have the parameter's own shape.
    """
    # Random stand-in input for a single sample.
    stacked_obs = torch.randn(1, 6, 32, 32)

    # Only the second output takes part in this loss.
    _, onehot = encoder(stacked_obs)

    # Arbitrary regression target; the test only cares about gradient flow.
    dummy_target = torch.randn(1, 4)
    criterion = nn.MSELoss()
    criterion(onehot, dummy_target).backward()

    params = list(encoder.parameters())
    # First make sure every parameter received a gradient at all ...
    assert all(p.grad is not None for p in params)
    # ... then that each gradient matches its parameter's shape.
    assert all(p.grad.shape == p.shape for p in params)