import torch from PIL import Image from torchvision import transforms from matplotlib import pyplot as plt import gradio as gr from models import MainModel # Import class for your main model from utils import lab_to_rgb, build_res_unet#, build_mobile_unet # Utility to convert LAB to RGB device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def load_model(generator_model_path, colorization_model_path): #, model_type='resnet') #if model_type == 'resnet': net_G = build_res_unet(n_input=1, n_output=2, size=256) # elif model_type == 'mobilenet': # net_G = build_mobile_unet(n_input=1, n_output=2, size=256) net_G.load_state_dict(torch.load(generator_model_path, map_location=device)) # Create MainModel and load weights model = MainModel(net_G=net_G) model.load_state_dict(torch.load(colorization_model_path, map_location=device)) # Move model to device and set to eval mode model.to(device) model.eval() return model # Load pretrained models resnet_model = load_model( "weight/pascal_res18-unet.pt", "weight/pascal_final_model_weights.pt" # model_type='resnet' ) # mobilenet_model = load_model( # "weight/mobile-unet.pt", # "weight/mobile_pascal_final_model_weights.pt", # model_type='mobilenet' # ) # Transformations def preprocess_image(image): image = image.resize((256, 256)) image = transforms.ToTensor()(image)[:1] * 2. - 1. # Normalize to [-1, 1] return image def postprocess_image(grayscale, prediction): return lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0] # Prediction function def colorize_image(input_image): # Convert input to grayscale input_image = Image.fromarray(input_image).convert('L') grayscale = preprocess_image(input_image).to(device) # Generate predictions with torch.no_grad(): resnet_output = resnet_model.net_G(grayscale.unsqueeze(0)) # mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0)) # Post-process results resnet_colorized = postprocess_image(grayscale, resnet_output) # mobilenet_colorized = postprocess_image(grayscale, mobilenet_output) return ( input_image, # Grayscale image resnet_colorized # ResNet18 colorized image # mobilenet_colorized # MobileNet colorized image ) # Gradio Interface interface = gr.Interface( fn=colorize_image, inputs=gr.Image(type="numpy", label="Upload a Color Image"), outputs=[ gr.Image(label="Grayscale Image"), gr.Image(label="Colorized Image (ResNet18)") # gr.Image(label="Colorized Image (MobileNet)") ], title="Image Colorization", description="Upload a color image" ) # Launch Gradio app if __name__ == '__main__': interface.launch()