wenmengzhou's picture
add code and adapt to zero gpus
703e263 verified
raw
history blame
2.02 kB
import torch
import numpy as np
from .processors import Processor_id
class ControlNetConfigUnit:
    """Configuration record for a single ControlNet.

    Holds a processor identifier, a model path and a scale exactly as given.
    Nothing is validated or loaded here; the arguments are only stored.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        # Store all three constructor arguments verbatim as public attributes.
        self.processor_id, self.model_path, self.scale = processor_id, model_path, scale
class ControlNetUnit:
    """Runtime pairing of a processor, a model and a scale.

    ``MultiControlNetManager`` reads the ``processor``, ``model`` and
    ``scale`` attributes of these objects; this class only stores them.
    """

    def __init__(self, processor, model, scale=1.0):
        # Plain storage of the constructor arguments; no checks are performed.
        self.processor, self.model, self.scale = processor, model, scale
class MultiControlNetManager:
    """Bundle several ControlNet units and combine their outputs.

    Each unit contributes a processor (called on an input image), a model
    (called in ``__call__`` and expected to return an iterable of residuals)
    and a scale that multiplies that model's residuals before they are summed
    across units.
    """

    def __init__(self, controlnet_units=None):
        """Split ``controlnet_units`` into parallel processor/model/scale lists.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes (e.g. ``ControlNetUnit``).
                ``None`` (the default) means no units.
        """
        # Materialize once: the three comprehensions below each iterate the
        # units, so a one-shot iterator (e.g. a generator) would otherwise be
        # exhausted after the first pass, leaving models/scales empty.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, processor_id=None):
        """Run processors on ``image`` and stack their results into one tensor.

        Args:
            image: Input passed unchanged to each processor.
            processor_id: Optional *list index* into ``self.processors``
                (not a ``Processor_id`` value). When given, only that
                processor runs; ``None`` runs every processor in order.

        Returns:
            A float tensor of shape ``(num_outputs, C, H, W)`` whose values are
            the processor outputs divided by 255. Assumes each processor
            returns a 3-D ``H x W x C`` array-like (e.g. a PIL RGB image) —
            TODO confirm for all processors; a 2-D output fails in ``permute``.

        Raises:
            IndexError: if ``processor_id`` is out of range.
            RuntimeError: from ``torch.concat`` when there are no processors.
        """
        if processor_id is None:
            processed_images = [processor(image) for processor in self.processors]
        else:
            processed_images = [self.processors[processor_id](image)]
        # Each HxWxC output becomes a 1xCxHxW float tensor (8-bit input -> [0, 1]),
        # then all are concatenated along the batch dimension.
        return torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_images
        ], dim=0)

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        """Run every model and return the scale-weighted sum of their residuals.

        Args:
            sample, timestep, encoder_hidden_states: Forwarded positionally and
                unchanged to every model.
            conditionings: One conditioning per unit, paired with units by
                position. ``zip`` stops at the shorter side, so surplus
                conditionings or surplus units are silently ignored.
            tiled, tile_size, tile_stride: Forwarded to each model as keywords.
            **kwargs: Extra keywords forwarded to each model. They must not
                repeat ``tiled``/``tile_size``/``tile_stride``/``processor_id``,
                or Python raises ``TypeError`` for a duplicate keyword.

        Returns:
            A list whose element ``i`` is ``sum(scale * output[i])`` over all
            units (truncated to the shortest model output), or ``None`` when
            no unit ran.
        """
        res_stack = None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            # Each model is told which processor produced its conditioning.
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                processor_id=processor.processor_id
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack