import torch
import numpy as np
import PIL
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr
import transformers

transformers.utils.move_cache()

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from accelerate import Accelerator

torch.set_num_threads(2)

import warnings
warnings.filterwarnings("ignore")

from models import MainModel, UNetAuto, Autoencoder
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet  # Utility to convert LAB to RGB
from stable import blip_image_captioning, apply_color

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stable Diffusion XL + ControlNet setup
accelerator = Accelerator(mixed_precision="fp16")

controlnet = ControlNetModel.from_pretrained(
    pretrained_model_name_or_path="nickpai/sdxl_light_caption_output",
    subfolder="checkpoint-30000/controlnet",
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
)

blip_processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large",
)
blip_generator = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large",
)

pipe.to(accelerator.device)
blip_generator.to(accelerator.device)


def colorize_image_sdxl(image, positive_prompt=None, negative_prompt=None,
                        caption_generate=True, seed=123, infer_steps=5):
    """Colorize an image with SDXL + ControlNet, optionally guided by a BLIP caption."""
    image = PIL.Image.fromarray(image)
    torch.cuda.empty_cache()

    if caption_generate:
        caption = blip_image_captioning(image=image,
                                        device=accelerator.device,
                                        processor=blip_processor,
                                        generator=blip_generator)
    else:
        caption = ""

    original_size = image.size
    # The ControlNet conditioning image is the grayscale input replicated to 3 channels
    control_image = image.convert("L").convert("RGB").resize((512, 512))
    # Guard against a missing positive prompt (the default is None)
    prompt = [(positive_prompt or "") + ", " + caption]

    colorized_image = pipe(prompt=prompt,
                           num_inference_steps=infer_steps,
                           generator=torch.manual_seed(seed),
                           image=control_image,
                           negative_prompt=negative_prompt).images[0]

    # Transfer the predicted colors onto the original luminance, then restore size
    result_image = apply_color(control_image, colorized_image)
    result_image = result_image.resize(original_size)
    return result_image, caption


# Model loaders for the autoencoder and GAN colorizers
def load_autoencoder_model(auto_model_path):
    unet = UNetAuto(in_channels=1, out_channels=2).to(device)
    model = Autoencoder(unet).to(device)
    model.load_state_dict(torch.load(auto_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    if model_type == 'resnet':
        net_G = build_res_unet(n_input=1, n_output=2, size=256)
    elif model_type == 'mobilenet':
        net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


resnet_model = load_model(
    "weight/pascal_res18-unet.pt",
    "weight/pascal_final_model_weights.pt",
    model_type='resnet'
)
mobilenet_model = load_model(
    "weight/mobile-unet.pt",
    "weight/mobile_pascal_final_model_weights.pt",
    model_type='mobilenet'
)
autoencoder_model = load_autoencoder_model("weight/autoencoder.pt")
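# The SDXL path can also be driven without the UI. A minimal sketch, assuming
# a local file "input.jpg" (hypothetical, not part of this repo):
#
#   img = np.array(Image.open("input.jpg").convert("RGB"))
#   colorized, caption = colorize_image_sdxl(
#       img,
#       positive_prompt="vibrant colors, natural lighting",
#       negative_prompt="low quality, monochrome",
#       caption_generate=True,
#       seed=123,
#       infer_steps=5,
#   )
#   colorized.save("colorized.jpg")
#   print("BLIP caption:", caption)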
# Transformations
def preprocess_image(image):
    image = image.resize((256, 256))
    image = transforms.ToTensor()(image)[:1] * 2. - 1.
    return image


def postprocess_image(grayscale, prediction, original_size):
    # Convert Lab back to RGB and resize to the original image size
    colorized_image = lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
    colorized_image = Image.fromarray((colorized_image * 255).astype("uint8"))
    return colorized_image.resize(original_size)


# Prediction function with output control
def colorize_image(input_image, mode):
    grayscale_image = Image.fromarray(input_image).convert('L')
    original_size = grayscale_image.size  # Store original size
    grayscale = preprocess_image(grayscale_image).to(device)

    with torch.no_grad():
        resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
        mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
        autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))

    # Resize outputs to match the original size
    resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
    mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
    autoencoder_colorized = postprocess_image(grayscale, autoencoder_output, original_size)

    if mode == "ResNet":
        return resnet_colorized, None, None
    elif mode == "MobileNet":
        return None, mobilenet_colorized, None
    elif mode == "Unet":
        return None, None, autoencoder_colorized
    elif mode == "Comparison":
        return resnet_colorized, mobilenet_colorized, autoencoder_colorized
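# The prompt-free models can likewise be called directly. A minimal sketch,
# assuming a local file "photo.jpg" (hypothetical):
#
#   arr = np.array(Image.open("photo.jpg"))
#   resnet_img, mobilenet_img, unet_img = colorize_image(arr, mode="Comparison")
#   resnet_img.save("resnet.jpg")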
def gradio_interface():
    with gr.Blocks() as app:
        with gr.Tab("Prompt-Free"):
            with gr.Blocks():
                input_image = gr.Image(type="numpy", label="Upload an Image")
                output_modes = gr.Radio(
                    choices=["ResNet", "MobileNet", "Unet", "Comparison"],
                    value="ResNet",
                    label="Output Mode"
                )
                submit_button = gr.Button("Submit")

                with gr.Row():  # Place output images in a single row
                    resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
                    mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
                    autoencoder_output = gr.Image(label="Colorized Image (Unet)", visible=False)

                def update_visibility(mode):
                    # Show only the output images relevant to the selected mode
                    if mode == "ResNet":
                        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
                    elif mode == "MobileNet":
                        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
                    elif mode == "Unet":
                        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
                    elif mode == "Comparison":
                        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

                output_modes.change(
                    fn=update_visibility,
                    inputs=[output_modes],
                    outputs=[resnet_output, mobilenet_output, autoencoder_output]
                )
                submit_button.click(
                    fn=colorize_image,
                    inputs=[input_image, output_modes],
                    outputs=[resnet_output, mobilenet_output, autoencoder_output]
                )

        with gr.Tab("Prompt-Guided (ControlNet-SDXL)"):
            with gr.Blocks():
                with gr.Row():
                    with gr.Column(scale=1):
                        sd_image = gr.Image(label="Upload a Color Image")
                        positive_prompt = gr.Textbox(label="Positive Prompt",
                                                     placeholder="Text for positive prompt")
                        negative_prompt = gr.Textbox(
                            value="low quality, bad quality, low contrast, black and white, bw, "
                                  "monochrome, grainy, blurry, historical, restored, desaturate",
                            label="Negative Prompt",
                            placeholder="Text for negative prompt"
                        )
                        generate_caption = gr.Checkbox(label="Generate Caption?", value=True)
                        seed = gr.Number(label="Seed", value=123, precision=0)
                        inference_steps = gr.Number(label="Inference Steps", value=5, precision=0)
                        submit_sd = gr.Button("Generate")
                    with gr.Column(scale=1):
                        sd_output_image = gr.Image(label="Colorized Image")
                        sd_caption = gr.Textbox(label="Captioning Result",
                                                show_copy_button=True,
                                                visible=True)

                submit_sd.click(
                    fn=colorize_image_sdxl,
                    inputs=[sd_image, positive_prompt, negative_prompt,
                            generate_caption, seed, inference_steps],
                    outputs=[sd_output_image, sd_caption]
                )
    return app


# Launch
if __name__ == "__main__":
    gradio_interface().launch()
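# Optional launch variants (standard Gradio options), e.g. to expose the app on
# the local network or create a temporary public share link:
#   gradio_interface().launch(server_name="0.0.0.0", share=True)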