# Spaces: Runtime error
# (Status text that appears to come from the hosting page rather than the program; kept as a comment.)
import os | |
import cv2 | |
import gradio as gr | |
import torch | |
import requests | |
# ------------------------------------------------------------------------------ | |
# Dependency Management | |
# ------------------------------------------------------------------------------ | |
# Instead of installing dependencies at runtime in production,
# it's recommended to pin them in a requirements.txt file.
# For this demo, we ensure that numpy and torchvision are of compatible versions.
import subprocess
import sys

for pip_args in (
    ["install", "--upgrade", "numpy<2"],
    ["install", "torchvision==0.12.0"],  # Fixes: ModuleNotFoundError for torchvision.transforms.functional_tensor
):
    # Run pip through the current interpreter so the packages land in the
    # environment this script imports from, and pass the arguments as a list
    # so the "<" in the version specifier never goes through a shell.
    completed = subprocess.run([sys.executable, "-m", "pip", *pip_args], check=False)
    if completed.returncode != 0:
        # Best-effort, like the original: report the failure and keep going.
        print(f"Warning: 'pip {' '.join(pip_args)}' exited with status {completed.returncode}")
# ------------------------------------------------------------------------------ | |
# Utility Function: Download Weight Files | |
# ------------------------------------------------------------------------------ | |
def download_file(filename, url):
    """
    Download `url` to `filename` unless the file already exists.

    ELI5: If the file (like a model weight) isn't on your computer, download it!

    The body is streamed into a temporary ``<filename>.part`` file and only
    renamed to `filename` once everything has been written, so an interrupted
    download is not mistaken for a complete file by the existence check on
    the next run.

    Args:
        filename: Destination path of the file.
        url: HTTP(S) location to fetch the file from.

    Returns:
        None. A non-200 response is reported on stdout and no file is created.

    Raises:
        requests.RequestException: On connection errors or timeouts.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading {filename} from {url}...")
    partial_path = f"{filename}.part"
    # (connect, read) timeouts so a stalled server cannot hang startup forever;
    # the context manager releases the connection on every exit path.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            print(f"Failed to download {filename}")
            return
        try:
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
            # Atomic rename: `filename` only ever appears fully written.
            os.replace(partial_path, filename)
        except BaseException:
            # Do not leave a truncated temp file behind, then re-raise.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
# ------------------------------------------------------------------------------ | |
# Download Required Model Weights | |
# ------------------------------------------------------------------------------ | |
# Weight file name -> release URL it is fetched from.
weights = dict([
    ("realesr-general-x4v3.pth", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"),
    ("GFPGANv1.2.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth"),
    ("GFPGANv1.3.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"),
    ("GFPGANv1.4.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"),
    ("RestoreFormer.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"),
    ("CodeFormer.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth"),
])

# Fetch every weight file that is not already present locally.
for weight_name, weight_url in weights.items():
    download_file(weight_name, weight_url)
# ------------------------------------------------------------------------------ | |
# Import Model-Related Modules After Ensuring Dependencies | |
# ------------------------------------------------------------------------------ | |
from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
from gfpgan.utils import GFPGANer | |
from realesrgan.utils import RealESRGANer | |
# ------------------------------------------------------------------------------ | |
# Initialize ESRGAN Upsampler | |
# ------------------------------------------------------------------------------ | |
# ELI5: A small network (the "mini brain") that helps make images look better.
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)

# Half-precision is switched on only when a CUDA GPU is available.
half = torch.cuda.is_available()

# Background upsampler shared by every inference call.
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)

# Directory that receives the enhanced images.
os.makedirs('output', exist_ok=True)
# ------------------------------------------------------------------------------ | |
# Image Inference Function | |
# ------------------------------------------------------------------------------ | |
def inference(img, version, scale):
    """
    Enhance an uploaded image and return it together with a downloadable file.

    ELI5: This function takes your uploaded image, picks a model version,
    and a scaling factor. It then:
      1. Checks that the scaling factor is a positive number.
      2. Reads your image.
      3. Checks if it's in a special format (like with transparency).
      4. Doubles images whose height is below 300 pixels before processing.
      5. Uses a face enhancement model (GFPGAN) and a background upscaler (RealESRGAN)
         to make the image look better.
      6. Optionally resizes the final image.
      7. Saves and returns the enhanced image.

    Args:
        img: Path of the image file to enhance.
        version: Key selecting the weight file (one of the `model_paths` keys).
        scale: Final rescaling factor; 2 keeps the enhancer's output size.

    Returns:
        A ``(rgb_image, saved_path)`` tuple on success, or ``(None, None)``
        when the input is unusable or any processing step fails. `saved_path`
        is ``None`` when the result could not be written to disk.
    """
    try:
        # Reject unusable scale values up front instead of discovering them
        # after the expensive enhancement pass.
        try:
            scale = float(scale)
        except (TypeError, ValueError):
            scale = 0.0
        if not scale > 0:
            print("Error: Rescaling factor must be a positive number.")
            return None, None

        if img is None:
            print("Error: Could not read the image. Please check the file.")
            return None, None

        # Read the image from the provided file path.
        img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
        if img is None:
            print("Error: Could not read the image. Please check the file.")
            return None, None

        # Determine the image mode: RGBA (has transparency) or not.
        if img.ndim == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        else:
            img_mode = None
            if img.ndim == 2:
                # If the image is grayscale, convert it to a color image.
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # If the image is too small (height under 300), double its size.
        h, w = img.shape[:2]
        if h < 300:
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        # Map the selected model version to its weight file.
        model_paths = {
            'v1.2': 'GFPGANv1.2.pth',
            'v1.3': 'GFPGANv1.3.pth',
            'v1.4': 'GFPGANv1.4.pth',
            'RestoreFormer': 'RestoreFormer.pth',
            'CodeFormer': 'CodeFormer.pth',
            'RealESR-General-x4v3': 'realesr-general-x4v3.pth'
        }

        # Initialize GFPGAN for face enhancement.
        # NOTE(review): versions not starting with 'v1' are passed through
        # verbatim as `arch` -- confirm GFPGANer accepts each of those names.
        face_enhancer = GFPGANer(
            model_path=model_paths[version],
            upscale=2,  # Face region upscale factor.
            arch='clean' if version.startswith('v1') else version,
            channel_multiplier=2,
            bg_upsampler=upsampler  # Use the ESRGAN upsampler for background.
        )

        # Enhance the image.
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )

        # Optionally, further rescale the enhanced image. The enhancer runs
        # with upscale=2, hence the division by 2.
        if scale != 2:
            interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            h, w = output.shape[:2]
            # Clamp to 1 so a very small factor cannot request a 0-pixel side.
            new_size = (max(1, int(w * scale / 2)), max(1, int(h * scale / 2)))
            output = cv2.resize(output, new_size, interpolation=interpolation)

        # Decide on file extension based on image mode.
        extension = 'png' if img_mode == 'RGBA' else 'jpg'
        save_path = os.path.join('output', f'out.{extension}')

        # Save the enhanced image; imwrite signals failure via its return value.
        if not cv2.imwrite(save_path, output):
            print(f"Warning: Could not save the enhanced image to {save_path}.")
            save_path = None

        # Convert BGR to RGB for proper display in Gradio.
        output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        return output_rgb, save_path
    except Exception as error:
        # UI boundary: report the problem and hand empty outputs back.
        print("Error during inference:", error)
        return None, None
# ------------------------------------------------------------------------------ | |
# Build the Gradio UI | |
# ------------------------------------------------------------------------------ | |
with gr.Blocks() as demo:
    gr.Markdown("## 📸 Image Upscaling & Restoration")
    gr.Markdown("### Enhance your images using GFPGAN & RealESRGAN with a friendly UI!")

    with gr.Row():
        # Left column: everything the user supplies, plus the trigger button.
        with gr.Column():
            source_image = gr.Image(type="filepath", label="Upload Your Image")
            model_choices = ['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer', 'RealESR-General-x4v3']
            model_version = gr.Radio(choices=model_choices, label="Select Model Version", value="v1.4")
            rescale_factor = gr.Number(value=2, label="Rescaling Factor (e.g., 2 for default)")
            run_button = gr.Button("Enhance Image 🚀")

        # Right column: the enhanced preview and the file offered for download.
        with gr.Column():
            enhanced_preview = gr.Image(type="numpy", label="Enhanced Image")
            enhanced_file = gr.File(label="Download Enhanced Image")

    # Wire the button to the inference function.
    click_inputs = [source_image, model_version, rescale_factor]
    click_outputs = [enhanced_preview, enhanced_file]
    run_button.click(fn=inference, inputs=click_inputs, outputs=click_outputs)

# ------------------------------------------------------------------------------
# Launch the Gradio App
# ------------------------------------------------------------------------------
demo.launch()