---
license: apache-2.0
datasets:
- detection-datasets/coco
language:
- en
library_name: diffusers
tags:
- pytorch
- controlnet
- image-colorization
- image-to-image
pipeline_tag: image-to-image
---

# Model Card for ColorizeNet

This model is a ControlNet trained to colorize black-and-white (grayscale) images.

## Model Details

### Model Description

ColorizeNet is an image colorization model based on ControlNet. It is built on top of the pre-trained Stable Diffusion 2.1 model released by Stability AI.

- **Finetuned from model:** [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)

### Model Sources

- **Repository:** [rensortino/ColorizeNet](https://github.com/rensortino/ColorizeNet)

### Training Data

The model was trained on [COCO](https://huggingface.co/datasets/detection-datasets/coco) using all of the images in the dataset. Each image was converted to grayscale, and the grayscale version was used to condition the ControlNet.

## Usage

### Run the model

The snippets below import helpers from the repository's `utils` package and read `./models/cldm_v21.yaml`, so they are meant to be run from a clone of the repository above. They also require a CUDA GPU.

Instantiate the model from its configuration file and load the trained weights:

```python
import cv2
import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything

from utils.data import HWC3, apply_color, resize_image
from utils.ddim import DDIMSampler
from utils.model import create_model, load_state_dict

# Build the model from its config on the CPU, then load the trained weights
model = create_model('./models/cldm_v21.yaml').cpu()
# Path to the ColorizeNet checkpoint; adjust it to where your checkpoint is stored
model.load_state_dict(load_state_dict(
    'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
```

Read and preprocess the image to be colorized:

```python
# cv2.imread returns an H x W x 3 uint8 array in BGR order, or None on failure
input_image = cv2.imread("sample_data/sample1_bw.jpg")
if input_image is None:
    raise FileNotFoundError("could not read sample_data/sample1_bw.jpg")
input_image = HWC3(input_image)
img = resize_image(input_image, resolution=512)
H, W, C = img.shape

num_samples = 1
# Scale to [0, 1], replicate once per sample and move channels first: (B, C, H, W)
control = torch.from_numpy(img.copy()).float().cuda() / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
```

Prepare the conditioning inputs and sampling parameters:

```python
seed = 1294574436
seed_everything(seed)

prompt = "Colorize this image"
n_prompt = ""        # negative prompt, used for the unconditional branch
guess_mode = False
strength = 1.0       # ControlNet conditioning strength
eta = 0.0            # DDIM eta; 0.0 makes the DDIM updates deterministic
ddim_steps = 20
scale = 9.0          # classifier-free guidance scale

cond = {"c_concat": [control],
        "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]}
un_cond = {"c_concat": None if guess_mode else [control],
           "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
# Latent shape: 4 channels at 1/8 of the image resolution
shape = (4, H // 8, W // 8)

# One multiplier per ControlNet output (13 in total)
model.control_scales = ([strength * (0.825 ** float(12 - i)) for i in range(13)]
                        if guess_mode else [strength] * 13)
```

Run DDIM sampling and post-process the results:

```python
samples, intermediates = ddim_sampler.sample(
    ddim_steps, num_samples, shape, cond, verbose=False, eta=eta,
    unconditional_guidance_scale=scale,
    unconditional_conditioning=un_cond)

# Decode the latents and map them from [-1, 1] to H x W x C uint8 images
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

results = [x_samples[i] for i in range(num_samples)]
# Apply the generated colors to the resized input image
colored_results = [apply_color(img, result) for result in results]
```

## Results

| BW Input | Colorized |
|:-------------------------:|:-------------------------:|
| ![Sample 1, grayscale input](docs/sample1_bw.jpg) | ![Sample 1, colorized](docs/sample1.png) |
| ![Sample 2, grayscale input](docs/sample2_bw.jpg) | ![Sample 2, colorized](docs/sample2.png) |
| ![Sample 3, grayscale input](docs/sample3_bw.jpg) | ![Sample 3, colorized](docs/sample3.png) |
| ![Sample 4, grayscale input](docs/sample4_bw.jpg) | ![Sample 4, colorized](docs/sample4.png) |
| ![Sample 5, grayscale input](docs/sample5_bw.jpg) | ![Sample 5, colorized](docs/sample5.png) |
| ![Sample 6, grayscale input](docs/sample6_bw.jpg) | ![Sample 6, colorized](docs/sample6.png) |
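The training data is described as COCO images converted to grayscale and used to condition the ControlNet. The card does not state the exact conversion, so the sketch below assumes ITU-R BT.601 luma weights (the weights OpenCV documents for its RGB-to-gray conversion). It replicates the single luma channel to three channels so the condition keeps the `H x W x 3` layout that the inference code feeds to the model. `make_gray_condition` is a hypothetical helper, not part of the ColorizeNet repository.

```python
import numpy as np

# Assumption: BT.601 luma weights for R, G, B; the card does not specify the formula.
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def make_gray_condition(rgb: np.ndarray) -> np.ndarray:
    """Hypothetical helper: H x W x 3 uint8 RGB image -> H x W x 3 uint8 grayscale condition."""
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError("expected an H x W x 3 image")
    # Weighted sum over the channel axis gives one luma value per pixel (H x W).
    luma = rgb.astype(np.float32) @ LUMA_WEIGHTS
    luma = np.clip(np.round(luma), 0, 255).astype(np.uint8)
    # Repeat the luma channel so the condition keeps three identical channels.
    return np.repeat(luma[:, :, None], 3, axis=2)


rng = np.random.default_rng(0)
color = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
cond = make_gray_condition(color)
print(cond.shape)
```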
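`HWC3` normalizes the channel layout of the input before resizing. Assuming it matches the helper of the same name in the original ControlNet repository, it turns 2-D or single-channel images into three identical channels and composites 4-channel (RGBA) images onto a white background. `cv2.imread` without flags already returns a 3-channel BGR image, so the call mainly matters for images loaded by other means (for example with `cv2.IMREAD_UNCHANGED`). `to_hwc3` below is a hypothetical re-implementation for illustration only.

```python
import numpy as np


def to_hwc3(x: np.ndarray) -> np.ndarray:
    """Hypothetical HWC3-style helper: 2-D, 1-, 3- or 4-channel uint8 -> H x W x 3 uint8."""
    if x.dtype != np.uint8:
        raise TypeError("expected uint8 input")
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError("expected a 2-D or 3-D array")
    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)  # replicate gray into three channels
    if channels == 4:
        color = x[:, :, :3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        blended = color * alpha + 255.0 * (1.0 - alpha)  # composite onto white
        return blended.clip(0, 255).astype(np.uint8)
    raise ValueError(f"unsupported channel count: {channels}")


print(to_hwc3(np.zeros((2, 3), dtype=np.uint8)).shape)
```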
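The call `resize_image(input_image, resolution=512)` and the latent `shape = (4, H // 8, W // 8)` are linked: the Stable Diffusion VAE works with 4-channel latents at 1/8 of the image resolution. Assuming `utils.data.resize_image` behaves like the helper of the same name in the original ControlNet repository (scale the shorter side to `resolution`, then round both sides to the nearest multiple of 64), the sketch below computes only the resulting sizes. `resized_hw` and `latent_shape` are hypothetical names used for illustration.

```python
def resized_hw(h: int, w: int, resolution: int = 512) -> tuple[int, int]:
    """Hypothetical: output (H, W) of a ControlNet-style resize (assumed behavior)."""
    k = resolution / min(h, w)                 # maps the shorter side to `resolution`
    new_h = int(round(h * k / 64.0)) * 64      # snap each side to a multiple of 64
    new_w = int(round(w * k / 64.0)) * 64
    return new_h, new_w


def latent_shape(h: int, w: int, channels: int = 4) -> tuple[int, int, int]:
    """Hypothetical: latent shape passed to the sampler (4 channels at 1/8 resolution)."""
    return (channels, h // 8, w // 8)


H, W = resized_hw(480, 640)
print((H, W), latent_shape(H, W))
```

Because both sides end up as multiples of 64, `H // 8` and `W // 8` divide evenly, so the decoded image has the same size as the resized input.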
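The preprocessing step turns the `H x W x 3` uint8 image into a float batch of shape `(num_samples, 3, H, W)` with values in [0, 1], which is what the code passes as `c_concat`. The NumPy sketch below mirrors the `/ 255.0`, `torch.stack` and `einops.rearrange(..., 'b h w c -> b c h w')` calls so the shapes can be checked without a GPU. `to_control_batch` is a hypothetical helper.

```python
import numpy as np


def to_control_batch(img: np.ndarray, num_samples: int = 1) -> np.ndarray:
    """Hypothetical NumPy equivalent of the control preparation: H x W x C uint8 -> B x C x H x W float."""
    control = img.astype(np.float32) / 255.0               # scale pixel values to [0, 1]
    control = np.stack([control] * num_samples, axis=0)    # b h w c
    return np.ascontiguousarray(control.transpose(0, 3, 1, 2))  # b c h w


img = np.zeros((512, 704, 3), dtype=np.uint8)
print(to_control_batch(img, num_samples=2).shape)
```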
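The `model.control_scales` line sets one multiplier for each of the 13 ControlNet outputs. With `guess_mode = False` every output is scaled by `strength`; with guess mode on, the multipliers form a geometric ramp with ratio 0.825, from `strength * 0.825**12` (about 0.1 x `strength`) for the first output up to `strength` for the last. The sketch below reproduces that expression as a plain function; `control_scales` is a hypothetical name.

```python
def control_scales(strength: float = 1.0, guess_mode: bool = False, n: int = 13) -> list[float]:
    """Hypothetical helper reproducing the card's control_scales expression."""
    if guess_mode:
        # Geometric ramp: index 0 gets strength * 0.825**(n - 1), the last index gets strength.
        return [strength * (0.825 ** float(n - 1 - i)) for i in range(n)]
    return [strength] * n


print([round(s, 3) for s in control_scales(guess_mode=True)])
```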
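`scale = 9.0` is passed as `unconditional_guidance_scale`, and `un_cond` supplies the unconditional branch (the empty `n_prompt`; outside guess mode it still receives the grayscale `control`). In the latent-diffusion `DDIMSampler` that ControlNet builds on, each step combines the two noise predictions with classifier-free guidance, `eps = eps_uncond + scale * (eps_cond - eps_uncond)`. Assuming `utils.ddim.DDIMSampler` keeps that rule, the arithmetic looks like this in NumPy; `guided_eps` is a hypothetical name.

```python
import numpy as np


def guided_eps(eps_uncond: np.ndarray, eps_cond: np.ndarray, scale: float) -> np.ndarray:
    """Hypothetical helper: classifier-free guidance combination of two noise predictions."""
    # scale = 1 returns the conditional prediction; larger values extrapolate past it.
    return eps_uncond + scale * (eps_cond - eps_uncond)


eps_u = np.zeros((4, 2, 2), dtype=np.float32)
eps_c = np.full((4, 2, 2), 0.1, dtype=np.float32)
print(guided_eps(eps_u, eps_c, scale=9.0).max())
```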
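The images returned by `model.decode_first_stage` lie roughly in [-1, 1]; `* 127.5 + 127.5` maps them to [0, 255], and the clip handles values that overshoot before the cast to uint8 (which truncates). The helper below isolates that mapping on a NumPy array in `b c h w` layout; `to_uint8_images` is a hypothetical name.

```python
import numpy as np


def to_uint8_images(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: decoder output (B x C x H x W, ~[-1, 1]) -> B x H x W x C uint8."""
    x = x.transpose(0, 2, 3, 1)              # b c h w -> b h w c
    x = x * 127.5 + 127.5                    # [-1, 1] -> [0, 255]
    return x.clip(0, 255).astype(np.uint8)   # clip overshoot; the cast truncates fractions


decoded = np.array([-1.2, -1.0, 0.0, 1.0], dtype=np.float32).reshape(1, 1, 2, 2)
print(to_uint8_images(decoded)[0, :, :, 0])
```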
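The card does not document what `apply_color(img, result)` does. A common post-processing step in colorization is to keep the luminance of the original grayscale input and take only the chrominance from the generated image, which preserves the input's fine detail. The sketch below shows that idea with the full-range BT.601 YCbCr transform in NumPy. `transfer_chroma` is hypothetical, assumes both inputs are RGB arrays of the same size, and is not necessarily how `utils.data.apply_color` is implemented (it may, for example, work in LAB space instead).

```python
import numpy as np

# Full-range BT.601 RGB -> YCbCr matrix (rows: Y, Cb, Cr). Chroma is kept centered at zero,
# so the usual +128 offset is omitted; it would cancel out in this round trip anyway.
RGB_TO_YCBCR = np.array([
    [0.299, 0.587, 0.114],
    [-0.168736, -0.331264, 0.5],
    [0.5, -0.418688, -0.081312],
])
YCBCR_TO_RGB = np.linalg.inv(RGB_TO_YCBCR)


def transfer_chroma(gray: np.ndarray, colored: np.ndarray) -> np.ndarray:
    """Hypothetical helper: luma of `gray` + chroma of `colored` (H x W x 3 uint8 RGB, same size)."""
    if gray.shape != colored.shape:
        raise ValueError("images must have the same shape")
    src = gray.astype(np.float64) @ RGB_TO_YCBCR.T
    gen = colored.astype(np.float64) @ RGB_TO_YCBCR.T
    ycc = np.concatenate([src[..., :1], gen[..., 1:]], axis=-1)  # Y from input, Cb/Cr from result
    rgb = ycc @ YCBCR_TO_RGB.T
    return np.clip(np.round(rgb), 0, 255).astype(np.uint8)
```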