Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from ding.torch_utils.network.merge import TorchBilinearCustomized, TorchBilinear, BilinearGeneral, FiLM | |
def test_torch_bilinear_customized():
    """
    Overview:
        Test that the forward pass of ``TorchBilinearCustomized`` returns a tensor of shape
        ``(batch_size, out_features)`` for two random 2-D inputs.
    """
    # Problem sizes: number of samples, width of each input, width of the output.
    n_samples, dim_first, dim_second, dim_out = 10, 20, 30, 40
    layer = TorchBilinearCustomized(dim_first, dim_second, dim_out)
    # One random row per sample for each of the two inputs.
    first_input = torch.randn(n_samples, dim_first)
    second_input = torch.randn(n_samples, dim_second)
    result = layer(first_input, second_input)
    expected_shape = (n_samples, dim_out)
    assert result.shape == expected_shape, "Output shape does not match expected shape."
def test_torch_bilinear():
    """
    Overview:
        Test that the forward pass of ``TorchBilinear`` returns a tensor of shape
        ``(batch_size, out_features)`` for two random 2-D inputs.
    """
    sizes = {"batch": 10, "in1": 20, "in2": 30, "out": 40}
    module = TorchBilinear(sizes["in1"], sizes["in2"], sizes["out"])
    # Both inputs share the batch dimension and differ only in feature width.
    left = torch.randn(sizes["batch"], sizes["in1"])
    right = torch.randn(sizes["batch"], sizes["in2"])
    prediction = module(left, right)
    assert prediction.shape == (sizes["batch"], sizes["out"]), "Output shape does not match expected shape."
def test_bilinear_consistency():
    """
    Overview:
        Test that ``TorchBilinearCustomized`` and ``TorchBilinear`` produce the same output when
        both are given identical weights, biases and inputs. The mean squared error between the
        two outputs is printed for diagnostics, and the outputs are required to agree within a
        float32-appropriate tolerance.
    """
    batch_size = 10
    in1_features = 20
    in2_features = 30
    out_features = 40

    # Shared parameter values. Each model receives its own clone so the two never alias storage.
    weight = torch.randn(out_features, in1_features, in2_features)
    bias = torch.randn(out_features)

    # Create both models and overwrite their randomly initialized parameters with the shared values.
    bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
    bilinear_customized.weight.data = weight.clone()
    bilinear_customized.bias.data = bias.clone()
    torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
    torch_bilinear.weight.data = weight.clone()
    torch_bilinear.bias.data = bias.clone()

    # Provide the same input to both models.
    x = torch.randn(batch_size, in1_features)
    z = torch.randn(batch_size, in2_features)

    # Compute outputs.
    out_bilinear_customized = bilinear_customized(x, z)
    out_torch_bilinear = torch_bilinear(x, z)

    # Report the mean squared error between the outputs (useful when the assertion below fails).
    mse = torch.mean((out_bilinear_customized - out_torch_bilinear) ** 2)
    print(f"Mean Squared Error between outputs: {mse.item()}")

    # Check that the outputs agree. Explicit tolerances are used because the default ``allclose``
    # tolerances (rtol=1e-5, atol=1e-8) are too strict for outputs close to zero: each output is a
    # float32 sum over in1_features * in2_features products.
    # NOTE(review): presumably the two implementations reduce in a different order - confirm.
    assert torch.allclose(out_bilinear_customized, out_torch_bilinear, rtol=1e-4, atol=1e-3), \
        "Outputs of TorchBilinearCustomized and TorchBilinear are not the same."
def test_bilinear_general():
    """
    Overview:
        Test for the `BilinearGeneral` class: checks the shape of the forward output, then the
        shape of each of the parameters ``W``, ``U``, ``V`` and ``b``, and finally that each of
        them is a ``torch.nn.Parameter``.
    """
    # Input/output widths and batch size.
    in1_features, in2_features, out_features = 20, 30, 40
    batch_size = 10

    model = BilinearGeneral(in1_features, in2_features, out_features)

    # Random inputs and forward pass.
    first = torch.randn(batch_size, in1_features)
    second = torch.randn(batch_size, in2_features)
    result = model(first, second)

    assert result.shape == (batch_size, out_features), "Output shape does not match expected shape."

    # (attribute name, label used in the failure message, expected shape), in check order.
    parameter_specs = [
        ("W", "Weight W", (out_features, in1_features, in2_features)),
        ("U", "Weight U", (out_features, in2_features)),
        ("V", "Weight V", (out_features, in1_features)),
        ("b", "Bias", (out_features, )),
    ]

    # All shape checks run before any type check, matching the original assertion order.
    for attr, label, expected_shape in parameter_specs:
        assert getattr(model, attr).shape == expected_shape, f"{label} shape does not match expected shape."

    for attr, label, _ in parameter_specs:
        assert isinstance(getattr(model, attr), torch.nn.Parameter), \
            f"{label} is not an instance of torch.nn.Parameter."
def test_film_forward():
    """
    Overview:
        Test the forward pass of ``FiLM``: the output must keep the shape of the input feature
        and must differ from the input feature in at least one element.
    """
    # Batch size, feature width and context width.
    batch_size, feature_dim, context_dim = 32, 128, 256

    modulator = FiLM(feature_dim, context_dim)

    # Random feature and context batches.
    feature = torch.randn((batch_size, feature_dim))
    conditioning = torch.randn((batch_size, context_dim))

    conditioned_feature = modulator(feature, conditioning)

    # The output keeps the shape of the input feature.
    assert conditioned_feature.shape == feature.shape, \
        f'Expected output shape {feature.shape}, but got {conditioned_feature.shape}'

    # At least one element must have changed ("not all equal" is the same as "any not equal").
    assert torch.any(torch.ne(feature, conditioned_feature)), \
        'The output feature is the same as the input feature'