MattGPT's picture
Update app.py
508b442 verified
import os
import cv2
import gradio as gr
import torch
import requests
# ------------------------------------------------------------------------------
# Dependency Management
# ------------------------------------------------------------------------------
# Installing packages at start-up is a stopgap for this demo; in production,
# pin these versions in a requirements.txt file instead.
# For this demo, we ensure that numpy and torchvision are of compatible versions.
import subprocess
import sys

# Invoke pip through the interpreter that is running this script so the
# packages land in the same environment, and pass the arguments as a list so
# no shell is involved. check=False keeps the previous best-effort behaviour:
# a failed install does not abort start-up.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "numpy<2"],
    check=False,
)
# Fixes: ModuleNotFoundError for torchvision.transforms.functional_tensor
subprocess.run(
    [sys.executable, "-m", "pip", "install", "torchvision==0.12.0"],
    check=False,
)
# ------------------------------------------------------------------------------
# Utility Function: Download Weight Files
# ------------------------------------------------------------------------------
def download_file(filename, url, timeout=60):
    """
    ELI5: If the file (like a model weight) isn't on your computer, download it!

    The body is streamed into ``<filename>.part`` and only renamed to
    ``filename`` once the transfer has finished, so an interrupted download
    never leaves a truncated file that a later run would mistake for a
    complete one.

    Args:
        filename: Destination path; nothing happens if it already exists.
        url: Address to fetch the file from.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        None. A non-200 response is reported with ``print`` and no file is
        created. Exceptions raised by ``requests`` or by file I/O propagate
        to the caller after the partial file has been removed.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading {filename} from {url}...")
    partial_path = f"{filename}.part"
    try:
        # The context manager releases the connection even on early return.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download {filename} (HTTP {response.status_code})")
                return
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        # Publish the file only after every chunk has been written.
        os.replace(partial_path, filename)
    finally:
        # Still present only when the download did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)
# ------------------------------------------------------------------------------
# Download Required Model Weights
# ------------------------------------------------------------------------------
# Weight file name -> release asset it is fetched from.
weights = dict((
    ("realesr-general-x4v3.pth", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"),
    ("GFPGANv1.2.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth"),
    ("GFPGANv1.3.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"),
    ("GFPGANv1.4.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"),
    ("RestoreFormer.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"),
    ("CodeFormer.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth"),
))
# Fetch every weight file in table order; download_file decides whether a
# download is actually needed.
for filename in weights:
    url = weights[filename]
    download_file(filename, url)
# ------------------------------------------------------------------------------
# Import Model-Related Modules After Ensuring Dependencies
# ------------------------------------------------------------------------------
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
# ------------------------------------------------------------------------------
# Initialize ESRGAN Upsampler
# ------------------------------------------------------------------------------
# ELI5: This is the background upscaler. It is built once at start-up and
# handed to the face enhancer on every request.
model_path = 'realesr-general-x4v3.pth'
# Half-precision is switched on only when CUDA is available.
half = torch.cuda.is_available()
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
# Enhanced images are written into this directory.
os.makedirs('output', exist_ok=True)
# ------------------------------------------------------------------------------
# Image Inference Function
# ------------------------------------------------------------------------------
def inference(img, version, scale):
    """
    ELI5: This function takes your uploaded image, picks a model version,
    and a scaling factor. It then:
    1. Reads your image.
    2. Checks if it's in a special format (transparency or grayscale).
    3. Doubles images whose height is below 300 pixels before processing.
    4. Uses a face enhancement model (GFPGANer) with the module-level
       upsampler for the background to make the image look better.
    5. Resizes the result when the requested scale is not 2.
    6. Saves and returns the enhanced image.

    Args:
        img: Path of the uploaded image file, or None when nothing was uploaded.
        version: Key into the model table below ('v1.2', 'v1.3', 'v1.4',
            'RestoreFormer', 'CodeFormer' or 'RealESR-General-x4v3').
        scale: Overall rescaling factor. The enhancer always runs at 2x and
            its result is then resized by ``scale / 2``. A missing or
            non-positive value falls back to 2.

    Returns:
        ``(image, save_path)``: the enhanced image converted for display and
        the path of the saved file (None if saving failed). ``(None, None)``
        is returned when the input is unusable or any step raises.
    """
    try:
        if img is None:
            print("Error: Could not read the image. Please check the file.")
            return None, None

        # A zero or negative factor would ask cv2.resize for an empty image.
        if scale is None or scale <= 0:
            print("Warning: invalid rescaling factor; using 2 instead.")
            scale = 2

        # Map the selected model version to its weight file.
        model_paths = {
            'v1.2': 'GFPGANv1.2.pth',
            'v1.3': 'GFPGANv1.3.pth',
            'v1.4': 'GFPGANv1.4.pth',
            'RestoreFormer': 'RestoreFormer.pth',
            'CodeFormer': 'CodeFormer.pth',
            'RealESR-General-x4v3': 'realesr-general-x4v3.pth'
        }
        if version not in model_paths:
            print(f"Error: Unknown model version {version!r}.")
            return None, None

        # Read the image from the provided file path, keeping any alpha channel.
        img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
        if img is None:
            print("Error: Could not read the image. Please check the file.")
            return None, None

        # Remember whether the input has transparency; it decides the output
        # file format further down.
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        else:
            img_mode = None
            if len(img.shape) == 2:
                # If the image is grayscale, convert it to a color image.
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # If the image is too small, double its size.
        h, w = img.shape[:2]
        if h < 300:
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        # Initialize GFPGAN for face enhancement.
        # NOTE(review): non-'v1' versions are passed through as the arch name;
        # whether the installed GFPGANer accepts 'CodeFormer' and
        # 'RealESR-General-x4v3' is not visible from this file — confirm.
        face_enhancer = GFPGANer(
            model_path=model_paths[version],
            upscale=2,  # Face region upscale factor.
            arch='clean' if version.startswith('v1') else version,
            channel_multiplier=2,
            bg_upsampler=upsampler  # Use the ESRGAN upsampler for background.
        )

        # Enhance the image.
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )

        # The enhancer already applied 2x, so only the remaining ratio is
        # applied here; shrink with INTER_AREA, enlarge with LANCZOS4.
        if scale != 2:
            interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            h, w = output.shape[:2]
            # Clamp to one pixel so very small factors cannot yield a zero size.
            target_size = (max(1, int(w * scale / 2)), max(1, int(h * scale / 2)))
            output = cv2.resize(output, target_size, interpolation=interpolation)

        # Decide on file extension based on image mode (PNG keeps transparency).
        extension = 'png' if img_mode == 'RGBA' else 'jpg'
        save_path = os.path.join('output', f'out.{extension}')

        # Save the enhanced image; imwrite reports failure by returning False.
        if not cv2.imwrite(save_path, output):
            print(f"Error: Could not save the enhanced image to {save_path}.")
            save_path = None

        # Convert from OpenCV channel order for proper display in Gradio,
        # keeping the alpha channel when the result has one.
        if output.ndim == 3 and output.shape[2] == 4:
            output_rgb = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        else:
            output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        return output_rgb, save_path
    except Exception as error:
        # UI boundary: report the problem and let Gradio show empty outputs.
        print("Error during inference:", error)
        return None, None
# ------------------------------------------------------------------------------
# Build the Gradio UI
# ------------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Page heading.
    gr.Markdown("## 📸 Image Upscaling & Restoration")
    gr.Markdown("### Enhance your images using GFPGAN & RealESRGAN with a friendly UI!")
    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column():
            image_input = gr.Image(label="Upload Your Image", type="filepath")
            version_selector = gr.Radio(
                label="Select Model Version",
                value="v1.4",
                choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer', 'RealESR-General-x4v3'],
            )
            scale_factor = gr.Number(label="Rescaling Factor (e.g., 2 for default)", value=2)
            enhance_button = gr.Button("Enhance Image 🚀")
        # Right column: the preview and the downloadable file.
        with gr.Column():
            output_image = gr.Image(label="Enhanced Image", type="numpy")
            download_link = gr.File(label="Download Enhanced Image")
    # Clicking the button runs inference on the three inputs and fills both outputs.
    enhance_button.click(
        inference,
        inputs=[image_input, version_selector, scale_factor],
        outputs=[output_image, download_link],
    )
# ------------------------------------------------------------------------------
# Launch the Gradio App
# ------------------------------------------------------------------------------
demo.launch()