"""Gradio app that recolors the sky of an image to match a reference image.

The sky is located with LangSAM (text-prompted segmentation), the masked
region is color-matched against the reference image with ``color_matcher``,
and the result is composited back over the original image.
"""

import gradio as gr
from PIL import Image, ImageEnhance
import numpy as np
import cv2
from lang_sam import LangSAM
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer

# Load the LangSAM model once at import time so every request reuses it.
# Use the default model or specify a custom checkpoint: LangSAM("", "")
model = LangSAM()


def apply_color_matching(source_img_np: np.ndarray, ref_img_np: np.ndarray) -> np.ndarray:
    """Transfer the color distribution of ``ref_img_np`` onto ``source_img_np``.

    Args:
        source_img_np: Image array whose colors should be adjusted.
        ref_img_np: Image array providing the target color distribution.

    Returns:
        The color-matched image, normalized to ``uint8`` by
        ``Normalizer.uint8_norm``.
    """
    matcher = ColorMatcher()
    matched = matcher.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
    return Normalizer(matched).uint8_norm()


def _mask_to_numpy(mask) -> np.ndarray:
    """Return ``mask`` as a NumPy array, accepting tensors as well as arrays."""
    # Depending on the installed lang_sam version the masks may be torch
    # tensors rather than ndarrays; tensors expose ``detach`` and have no
    # ``astype`` method, so convert them explicitly.
    if hasattr(mask, "detach"):
        mask = mask.detach().cpu().numpy()
    return np.asarray(mask)


def extract_and_color_match_sky(
    image_pil: Image.Image,
    reference_image_pil: Image.Image,
    text_prompt: str = "sky",
) -> Image.Image:
    """Color-match the region described by ``text_prompt`` to a reference image.

    Args:
        image_pil: Source image whose sky (or other prompted region) is recolored.
        reference_image_pil: Image supplying the target colors.
        text_prompt: Text prompt passed to LangSAM to select the region.

    Returns:
        A new RGB PIL image in which the masked region has been color-matched.
        If LangSAM returns no mask for the prompt, the source image (converted
        to RGB) is returned unchanged.
    """
    # Force 3-channel RGB: RGBA / greyscale / palette inputs would not line up
    # with the 3-channel mask used below.
    image_pil = image_pil.convert("RGB")
    reference_image_pil = reference_image_pil.convert("RGB")

    # Use LangSAM to predict the mask(s) for the prompted region.
    masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

    # ``len`` instead of truthiness: arrays/tensors have ambiguous truth values.
    if masks is None or len(masks) == 0:
        return image_pil

    # Only the first predicted mask is used, scaled to a 0/255 uint8 image.
    sky_mask = _mask_to_numpy(masks[0]).astype(np.uint8) * 255

    img_np = np.array(image_pil)
    ref_img_np = np.array(reference_image_pil)

    # Replicate the mask over three channels so it can be combined with RGB data.
    sky_mask_3ch = cv2.merge([sky_mask, sky_mask, sky_mask])

    # Keep only the masked pixels; everything else becomes zero.
    sky_region = cv2.bitwise_and(img_np, sky_mask_3ch)

    # NOTE(review): the zeroed non-sky pixels are still part of the array handed
    # to the color matcher, so they presumably influence the transfer — confirm
    # whether only the masked pixels should be passed instead.
    sky_region_color_matched = apply_color_matching(sky_region, ref_img_np)

    # Take matched pixels inside the mask and original pixels everywhere else.
    result_img_np = np.where(sky_mask_3ch > 0, sky_region_color_matched, img_np)

    return Image.fromarray(result_img_np.astype(np.uint8))


def gradio_interface() -> None:
    """Build and launch the Gradio UI for sky extraction and color matching."""

    def process_image(source_img, ref_img):
        """Gradio callback: recolor the sky of ``source_img`` using ``ref_img``."""
        # Gradio passes None for an image input that was left empty.
        if source_img is None or ref_img is None:
            raise gr.Error("Please provide both a source image and a reference image.")
        return extract_and_color_match_sky(source_img, ref_img)

    # Two image inputs: the image to edit and the color reference.
    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),
    ]

    outputs = gr.Image(type="pil", label="Resulting Image")

    gr.Interface(
        fn=process_image,
        inputs=inputs,
        outputs=outputs,
        title="Sky Extraction and Color Matching",
    ).launch()


# Run the Gradio interface when executed as a script.
if __name__ == "__main__":
    gradio_interface()