import os
import subprocess
import sys

import cv2
import gradio as gr
import requests
import torch

# ------------------------------------------------------------------------------
# Dependency Management
# ------------------------------------------------------------------------------
# Installing packages at runtime is a demo-only convenience; in production,
# pin these versions in a requirements.txt file instead.
# For this demo, we ensure that numpy and torchvision are of compatible versions.
# `sys.executable -m pip` is used (rather than a bare `pip` shell command) so
# the packages land in the environment of the interpreter running this script.
# Failures are deliberately ignored (check=False): the pins are best-effort.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "numpy<2"],
    check=False,
)
# Fixes: ModuleNotFoundError for torchvision.transforms.functional_tensor
subprocess.run(
    [sys.executable, "-m", "pip", "install", "torchvision==0.12.0"],
    check=False,
)

# (connect timeout, read timeout) in seconds for weight downloads. The read
# timeout bounds the wait between received chunks, not the whole transfer.
DOWNLOAD_TIMEOUT = (10, 60)


# ------------------------------------------------------------------------------
# Utility Function: Download Weight Files
# ------------------------------------------------------------------------------
def download_file(filename, url):
    """Download ``url`` to ``filename`` unless the file already exists.

    The payload is streamed into ``<filename>.part`` and only moved to
    ``filename`` once the transfer has completed, so an interrupted download
    never leaves a truncated file that a later run would mistake for a
    complete one.

    Args:
        filename: Destination path of the downloaded file.
        url: HTTP(S) location to fetch.

    Returns:
        None. A non-200 response is reported on stdout and leaves no file.

    Raises:
        requests.RequestException: On connection errors or timeouts.
        OSError: If the file cannot be written or moved into place.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading {filename} from {url}...")
    partial_path = f"{filename}.part"
    # The context manager releases the streamed connection on every path.
    with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
        if response.status_code != 200:
            print(f"Failed to download {filename}")
            return
        try:
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        except BaseException:
            # Do not leave a half-written temp file behind on failure.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    # Publish the finished file under its final name in a single step.
    os.replace(partial_path, filename)


# ------------------------------------------------------------------------------
# Download Required Model Weights
# ------------------------------------------------------------------------------
weights = {
    "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
    "GFPGANv1.3.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
    "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
    "RestoreFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth",
    "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
}

for filename, url in weights.items():
    download_file(filename, url)

# ------------------------------------------------------------------------------
# Import Model-Related Modules After Ensuring Dependencies
# ------------------------------------------------------------------------------
# These imports intentionally come after the dependency pinning above.
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer

# ------------------------------------------------------------------------------
# Initialize ESRGAN Upsampler
# ------------------------------------------------------------------------------
# Compact SRVGG network used by RealESRGAN to upscale the image background.
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
model_path = 'realesr-general-x4v3.pth'
half = torch.cuda.is_available()  # Use half-precision if you have a GPU.
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)

# Create output directory for saving enhanced images.
os.makedirs('output', exist_ok=True)

# Maps each selectable model version to its weight file. Insertion order is
# also the order in which the choices are shown in the UI.
GFPGAN_MODEL_PATHS = {
    'v1.2': 'GFPGANv1.2.pth',
    'v1.3': 'GFPGANv1.3.pth',
    'v1.4': 'GFPGANv1.4.pth',
    'RestoreFormer': 'RestoreFormer.pth',
    'CodeFormer': 'CodeFormer.pth',
    'RealESR-General-x4v3': 'realesr-general-x4v3.pth',
}


# ------------------------------------------------------------------------------
# Image Inference Function
# ------------------------------------------------------------------------------
def inference(img, version, scale):
    """Enhance an image with GFPGAN (faces) and RealESRGAN (background).

    Steps:
      1. Read the image from disk, keeping any alpha channel.
      2. Convert grayscale input to 3-channel BGR.
      3. Double the size of images shorter than 300 px.
      4. Run the selected face-enhancement model, which upscales by 2.
      5. Resize the result when ``scale`` differs from 2.
      6. Save the result under ``output/`` and return it.

    Args:
        img: Path of the input image file.
        version: Key of ``GFPGAN_MODEL_PATHS`` selecting the model weights.
        scale: Final rescaling factor relative to the (possibly doubled)
            input; 2 keeps the enhancer's native output size.

    Returns:
        ``(rgb_image, save_path)`` on success, or ``(None, None)`` when the
        input is invalid or any processing step fails. Errors are printed
        rather than raised so the UI callback never crashes.
    """
    try:
        # Reject unusable factors before any expensive model work is done.
        if scale is None or scale <= 0:
            print("Error: Rescaling factor must be a positive number.")
            return None, None

        # Read the image from the provided file path.
        img_path = str(img)
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print("Error: Could not read the image. Please check the file.")
            return None, None

        # Determine the image mode: RGBA (has transparency) or not.
        img_mode = None
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        elif len(img.shape) == 2:
            # If the image is grayscale, convert it to a color image.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # If the image is too small, double its size.
        h, w = img.shape[:2]
        if h < 300:
            img = cv2.resize(
                img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4
            )

        # Initialize GFPGAN for face enhancement.
        # NOTE(review): for non-"v1.x" choices the version name itself is
        # passed as `arch`; confirm the installed GFPGANer accepts
        # 'CodeFormer' and 'RealESR-General-x4v3' as architectures.
        face_enhancer = GFPGANer(
            model_path=GFPGAN_MODEL_PATHS[version],
            upscale=2,  # Face region upscale factor.
            arch='clean' if version.startswith('v1') else version,
            channel_multiplier=2,
            bg_upsampler=upsampler,  # Use the ESRGAN upsampler for background.
        )

        # Enhance the image.
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )

        # Optionally, further rescale the enhanced image. The enhancer
        # already upscaled by 2, hence the division by 2 below.
        if scale != 2:
            interpolation = (
                cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            )
            h, w = output.shape[:2]
            # Clamp to 1 px so a tiny factor cannot request an empty image.
            target_size = (
                max(1, int(w * scale / 2)),
                max(1, int(h * scale / 2)),
            )
            output = cv2.resize(
                output, target_size, interpolation=interpolation
            )

        # Decide on file extension based on image mode (PNG keeps alpha).
        extension = 'png' if img_mode == 'RGBA' else 'jpg'
        save_path = os.path.join('output', f'out.{extension}')

        # Save the enhanced image.
        cv2.imwrite(save_path, output)

        # Convert BGR to RGB for proper display in Gradio.
        output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        return output_rgb, save_path

    except Exception as error:
        # Top-level UI boundary: report the problem and return empty outputs.
        print("Error during inference:", error)
        return None, None


# ------------------------------------------------------------------------------
# Build the Gradio UI
# ------------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 📸 Image Upscaling & Restoration")
    gr.Markdown("### Enhance your images using GFPGAN & RealESRGAN with a friendly UI!")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Your Image")
            version_selector = gr.Radio(
                choices=list(GFPGAN_MODEL_PATHS),
                label="Select Model Version",
                value="v1.4",
            )
            scale_factor = gr.Number(
                value=2, label="Rescaling Factor (e.g., 2 for default)"
            )
            enhance_button = gr.Button("Enhance Image 🚀")
        with gr.Column():
            output_image = gr.Image(type="numpy", label="Enhanced Image")
            download_link = gr.File(label="Download Enhanced Image")

    # Link the button click to the inference function.
    enhance_button.click(
        fn=inference,
        inputs=[image_input, version_selector, scale_factor],
        outputs=[output_image, download_link],
    )

# ------------------------------------------------------------------------------
# Launch the Gradio App
# ------------------------------------------------------------------------------
demo.launch()