import torch
import numpy as np

from .processors import Processor_id


class ControlNetConfigUnit:
    """Configuration record describing one ControlNet unit.

    Stores the three values it is given without validating them:
    ``processor_id`` (annotated as ``Processor_id``), ``model_path`` and
    ``scale`` (defaults to ``1.0``).
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale


class ControlNetUnit:
    """Bundle of a processor, a ControlNet model, and its output scale.

    ``MultiControlNetManager`` calls ``processor(image)``, reads
    ``processor.processor_id``, calls ``model(...)`` and multiplies each
    element of the model output by ``scale``.
    """

    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    """Run several ControlNet units and sum their scaled residual outputs.

    The units are split into three parallel lists (``processors``,
    ``models``, ``scales``) that share the same order.
    """

    def __init__(self, controlnet_units=None):
        """Split ``controlnet_units`` into parallel attribute lists.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes. ``None`` (the default)
                means "no units".
        """
        # ``None`` sentinel instead of a shared mutable ``[]`` default, and
        # materialise once so one-shot iterables (generators) are not
        # exhausted after the first of the three passes below.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, processor_id=None):
        """Apply the processors to ``image`` and stack results as a tensor.

        Args:
            image: Input handed unchanged to each processor.
            processor_id: When ``None``, every processor is applied. Otherwise
                it is used as an index into ``self.processors`` and only that
                processor is applied.

        Returns:
            A tensor built by concatenating the processed images along
            dim 0, one entry per processor used.
        """
        if processor_id is None:
            processed_images = [processor(image) for processor in self.processors]
        else:
            processed_images = [self.processors[processor_id](image)]

        # Each processed image is converted to float32, divided by 255,
        # reordered with permute(2, 0, 1) and given a leading axis of size 1.
        # NOTE(review): presumably processors return HWC images with values
        # in 0..255, making this CHW in 0..1 -- confirm against processors.
        tensors = [
            torch.Tensor(np.array(image_, dtype=np.float32) / 255)
            .permute(2, 0, 1)
            .unsqueeze(0)
            for image_ in processed_images
        ]
        return torch.concat(tensors, dim=0)

    def __call__(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        conditionings,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        **kwargs
    ):
        """Run every ControlNet model and sum the scaled results element-wise.

        Args:
            sample: Forwarded positionally to each model.
            timestep: Forwarded positionally to each model.
            encoder_hidden_states: Forwarded positionally to each model.
            conditionings: Iterable yielding one conditioning per unit, in
                the same order as the units.
            tiled: Forwarded to each model as the ``tiled`` keyword.
            tile_size: Forwarded to each model as ``tile_size``.
            tile_stride: Forwarded to each model as ``tile_stride``.
            **kwargs: Extra keyword arguments forwarded to each model.

        Returns:
            A list holding, position by position, the sum over all units of
            ``scale * model_output[i]``; ``None`` when no unit was run.
        """
        res_stack = None
        # zip stops at the shortest input, so surplus units or conditionings
        # are silently ignored.
        for processor, conditioning, model, scale in zip(
            self.processors, conditionings, self.models, self.scales
        ):
            unit_stack = model(
                sample,
                timestep,
                encoder_hidden_states,
                conditioning,
                **kwargs,
                tiled=tiled,
                tile_size=tile_size,
                tile_stride=tile_stride,
                processor_id=processor.processor_id
            )
            unit_stack = [res * scale for res in unit_stack]
            if res_stack is None:
                res_stack = unit_stack
            else:
                res_stack = [i + j for i, j in zip(res_stack, unit_stack)]
        return res_stack