# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger
from diffusers.utils.testing_utils import CaptureLogger, torch_device


logger = get_logger(__name__)  # pylint: disable=invalid-name


class DummyBlock(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()

        self.proj_in = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.proj_out = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(x)
        x = self.activation(x)
        x = self.proj_out(x)
        return x


class DummyModel(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x)
        x = self.activation(x)
        for block in self.blocks:
            x = block(x)
        x = self.linear_2(x)
        return x


class AddHook(ModelHook):
    """Adds `value` to every tensor positional argument before the wrapped forward runs."""

    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        logger.debug("AddHook pre_forward")
        # Materialize a tuple (rather than a single-use generator) so args can be
        # consumed again by any hook wrapped inside this one.
        args = tuple((x + self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("AddHook post_forward")
        return output


class MultiplyHook(ModelHook):
    """Multiplies every tensor positional argument by `value` before the wrapped forward runs."""

    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("MultiplyHook pre_forward")
        args = tuple((x * self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        return f"MultiplyHook(value={self.value})"


class StatefulAddHook(ModelHook):
    """Adds `value + increment` to tensor arguments, bumping `increment` on every call."""

    _is_stateful = True

    def __init__(self, value: int):
        super().__init__()
        self.value = value
        self.increment = 0

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("StatefulAddHook pre_forward")
        add_value = self.value + self.increment
        self.increment += 1
        args = tuple((x + add_value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def reset_state(self, module):
        self.increment = 0


class SkipLayerHook(ModelHook):
    """Bypasses the wrapped module when `skip_layer` is True, returning its first input unchanged."""

    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook pre_forward")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook new_forward")
        if self.skip_layer:
            return args[0]
        return self.fn_ref.original_forward(*args, **kwargs)

    def post_forward(self, module, output):
        logger.debug("SkipLayerHook post_forward")
        return output


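# The tests below rely on two HookRegistry behaviors that the expected values
# depend on:
#   * Hooks registered later wrap hooks registered earlier, so the most recently
#     registered hook's pre_forward runs first and its post_forward runs last.
#   * Hooks marked `_is_stateful = True` accumulate state across forward passes
#     until `reset_stateful_hooks()` invokes their `reset_state()`.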
class HookTests(unittest.TestCase):
    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        params = self.get_module_parameters()
        self.model = DummyModel(**params)
        self.model.to(torch_device)

    def tearDown(self):
        super().tearDown()
        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        return torch.manual_seed(0)

    def test_hook_registry(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        registry_repr = repr(registry)
        expected_repr = (
            "HookRegistry(\n"
            "  (0) add_hook - AddHook\n"
            "  (1) multiply_hook - MultiplyHook(value=2)\n"
            ")"
        )

        self.assertEqual(len(registry.hooks), 2)
        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
        self.assertEqual(registry_repr, expected_repr)

        registry.remove_hook("add_hook")

        self.assertEqual(len(registry.hooks), 1)
        self.assertEqual(registry._hook_order, ["multiply_hook"])

    def test_stateful_hook(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "stateful_add_hook")

        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
        num_repeats = 3

        for i in range(num_repeats):
            result = self.model(input)
            if i == 0:
                output1 = result

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)

        registry.reset_stateful_hooks()
        output2 = self.model(input)

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
        self.assertTrue(torch.allclose(output1, output2))

    def test_inference(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        # With both hooks active, MultiplyHook (outermost) scales first, so the model
        # effectively sees `input * 2 + 1`. Feeding the equivalent pre-transformed
        # input after removing each hook must therefore reproduce the same output.
        output1 = self.model(input).mean().detach().cpu().item()

        registry.remove_hook("multiply_hook")
        new_input = input * 2
        output2 = self.model(new_input).mean().detach().cpu().item()

        registry.remove_hook("add_hook")
        new_input = input * 2 + 1
        output3 = self.model(new_input).mean().detach().cpu().item()

        self.assertAlmostEqual(output1, output2, places=5)
        self.assertAlmostEqual(output1, output3, places=5)

    def test_skip_layer_hook(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

        input = torch.zeros(1, 4, device=torch_device)
        output = self.model(input).mean().detach().cpu().item()
        self.assertEqual(output, 0.0)

        registry.remove_hook("skip_layer_hook")
        registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

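    # Attaching SkipLayerHook to `linear_1` makes it return its raw (1, in_features)
    # input, which the first DummyBlock (whose proj_in expects hidden_features-sized
    # inputs) cannot matrix-multiply, hence the RuntimeError asserted below.
    # Skipping an internal block is shape-safe, since every block maps
    # hidden_features -> hidden_features.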
    def test_skip_layer_internal_block(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
        input = torch.zeros(1, 4, device=torch_device)

        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        with self.assertRaises(RuntimeError) as cm:
            self.model(input).mean().detach().cpu().item()
        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

        registry.remove_hook("skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

        registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_invocation_order_stateful_first(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "add_hook")
        registry.register_hook(AddHook(2), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        # pre_forward logs appear in reverse registration order because each newly
        # registered hook wraps the previous ones; StatefulAddHook contributes no
        # post_forward line since it does not override post_forward.
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "StatefulAddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_middle(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(2), "add_hook")
        registry.register_hook(StatefulAddHook(1), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "StatefulAddHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n")
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook_2")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

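    # Same ordering check with the stateful hook registered last: as the outermost
    # hook, its pre_forward is the first log line; it still contributes no
    # post_forward line because StatefulAddHook does not override post_forward.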
    def test_invocation_order_stateful_last(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")
        registry.register_hook(StatefulAddHook(3), "add_hook_2")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "StatefulAddHook pre_forward\n"
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n")
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)
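

# Allow running this file directly with `python`; the suite is otherwise picked
# up by a test runner such as pytest.
if __name__ == "__main__":
    unittest.main()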