MattGPT commited on
Commit
508b442
·
verified ·
1 Parent(s): 79d498b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -37
app.py CHANGED
@@ -2,14 +2,41 @@ import os
2
  import cv2
3
  import gradio as gr
4
  import torch
5
- from basicsr.archs.srvgg_arch import SRVGGNetCompact
6
- from gfpgan.utils import GFPGANer
7
- from realesrgan.utils import RealESRGANer
8
 
9
- # Ensure numpy is compatible
 
 
 
 
 
 
10
  os.system("pip install --upgrade 'numpy<2'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Download necessary model weights
13
  weights = {
14
  "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
15
  "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
@@ -19,38 +46,79 @@ weights = {
19
  "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
20
  }
21
 
22
- for file, url in weights.items():
23
- if not os.path.exists(file):
24
- os.system(f"wget {url} -P .")
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Load ESRGAN model
27
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
28
  model_path = 'realesr-general-x4v3.pth'
29
- half = True if torch.cuda.is_available() else False
30
- upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
31
-
 
 
 
 
 
 
 
 
 
32
  os.makedirs('output', exist_ok=True)
33
 
34
- # Image Processing Function
 
 
 
35
  def inference(img, version, scale):
 
 
 
 
 
 
 
 
 
 
 
36
  try:
 
37
  img_path = str(img)
38
  extension = os.path.splitext(os.path.basename(img_path))[1]
39
  img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
 
 
 
40
 
 
41
  if len(img.shape) == 3 and img.shape[2] == 4:
42
  img_mode = 'RGBA'
43
  elif len(img.shape) == 2:
44
- img_mode = None
45
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
46
  else:
47
  img_mode = None
48
 
 
49
  h, w = img.shape[:2]
50
  if h < 300:
51
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
52
 
53
- # Load Face Enhancement Model
54
  model_paths = {
55
  'v1.2': 'GFPGANv1.2.pth',
56
  'v1.3': 'GFPGANv1.3.pth',
@@ -60,62 +128,73 @@ def inference(img, version, scale):
60
  'RealESR-General-x4v3': 'realesr-general-x4v3.pth'
61
  }
62
 
 
63
  face_enhancer = GFPGANer(
64
  model_path=model_paths[version],
65
- upscale=2,
66
  arch='clean' if version.startswith('v1') else version,
67
  channel_multiplier=2,
68
- bg_upsampler=upsampler
69
  )
70
 
71
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
 
 
 
72
 
 
73
  if scale != 2:
74
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
75
- h, w = img.shape[:2]
76
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
77
 
78
- if img_mode == 'RGBA':
79
- extension = 'png'
80
- else:
81
- extension = 'jpg'
82
 
83
- save_path = f'output/out.{extension}'
84
  cv2.imwrite(save_path, output)
85
- output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
 
86
 
87
- return output, save_path
88
 
89
  except Exception as error:
90
- print("Error:", error)
91
  return None, None
92
 
93
- # Gradio Blocks UI
 
 
 
94
  with gr.Blocks() as demo:
95
  gr.Markdown("## 📸 Image Upscaling & Restoration")
96
- gr.Markdown("### Enhance old or AI-generated images using GFPGAN & RealESRGAN")
97
 
98
  with gr.Row():
99
  with gr.Column():
100
- image_input = gr.Image(type="filepath", label="Upload Image")
101
  version_selector = gr.Radio(
102
- ['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer', 'RealESR-General-x4v3'],
103
- label="Model Version",
104
  value="v1.4"
105
  )
106
- scale_factor = gr.Number(value=2, label="Rescaling Factor")
107
-
108
  enhance_button = gr.Button("Enhance Image 🚀")
109
-
110
  with gr.Column():
111
- output_image = gr.Image(type="numpy", label="Enhanced Output")
112
  download_link = gr.File(label="Download Enhanced Image")
113
 
 
114
  enhance_button.click(
115
  fn=inference,
116
  inputs=[image_input, version_selector, scale_factor],
117
  outputs=[output_image, download_link]
118
  )
119
 
120
- # Launch the App
 
 
 
121
  demo.launch()
 
 
2
  import cv2
3
  import gradio as gr
4
  import torch
5
+ import requests
 
 
6
 
7
+ # ------------------------------------------------------------------------------
8
+ # Dependency Management
9
+ # ------------------------------------------------------------------------------
10
+
11
+ # Instead of using os.system to manage dependencies in production,
12
+ # it's recommended to use a requirements.txt file.
13
+ # For this demo, we ensure that numpy and torchvision are of compatible versions.
14
  os.system("pip install --upgrade 'numpy<2'")
15
+ os.system("pip install torchvision==0.12.0") # Fixes: ModuleNotFoundError for torchvision.transforms.functional_tensor
16
+
17
+ # ------------------------------------------------------------------------------
18
+ # Utility Function: Download Weight Files
19
+ # ------------------------------------------------------------------------------
20
+
21
+ def download_file(filename, url):
22
+ """
23
+ ELI5: If the file (like a model weight) isn't on your computer, download it!
24
+ """
25
+ if not os.path.exists(filename):
26
+ print(f"Downloading {filename} from {url}...")
27
+ response = requests.get(url, stream=True)
28
+ if response.status_code == 200:
29
+ with open(filename, 'wb') as f:
30
+ for chunk in response.iter_content(chunk_size=8192):
31
+ if chunk:
32
+ f.write(chunk)
33
+ else:
34
+ print(f"Failed to download {filename}")
35
+
36
+ # ------------------------------------------------------------------------------
37
+ # Download Required Model Weights
38
+ # ------------------------------------------------------------------------------
39
 
 
40
  weights = {
41
  "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
42
  "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
 
46
  "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
47
  }
48
 
49
+ for filename, url in weights.items():
50
+ download_file(filename, url)
51
+
52
+ # ------------------------------------------------------------------------------
53
+ # Import Model-Related Modules After Ensuring Dependencies
54
+ # ------------------------------------------------------------------------------
55
+
56
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
57
+ from gfpgan.utils import GFPGANer
58
+ from realesrgan.utils import RealESRGANer
59
+
60
+ # ------------------------------------------------------------------------------
61
+ # Initialize ESRGAN Upsampler
62
+ # ------------------------------------------------------------------------------
63
 
64
+ # ELI5: We build a mini brain (model) to help make images look better.
65
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
66
  model_path = 'realesr-general-x4v3.pth'
67
+ half = torch.cuda.is_available() # Use half-precision if you have a GPU.
68
+ upsampler = RealESRGANer(
69
+ scale=4,
70
+ model_path=model_path,
71
+ model=model,
72
+ tile=0,
73
+ tile_pad=10,
74
+ pre_pad=0,
75
+ half=half
76
+ )
77
+
78
+ # Create output directory for saving enhanced images.
79
  os.makedirs('output', exist_ok=True)
80
 
81
+ # ------------------------------------------------------------------------------
82
+ # Image Inference Function
83
+ # ------------------------------------------------------------------------------
84
+
85
  def inference(img, version, scale):
86
+ """
87
+ ELI5: This function takes your uploaded image, picks a model version,
88
+ and a scaling factor. It then:
89
+ 1. Reads your image.
90
+ 2. Checks if it's in a special format (like with transparency).
91
+ 3. Resizes small images for better processing.
92
+ 4. Uses a face enhancement model (GFPGAN) and a background upscaler (RealESRGAN)
93
+ to make the image look better.
94
+ 5. Optionally resizes the final image.
95
+ 6. Saves and returns the enhanced image.
96
+ """
97
  try:
98
+ # Read the image from the provided file path.
99
  img_path = str(img)
100
  extension = os.path.splitext(os.path.basename(img_path))[1]
101
  img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
102
+ if img is None:
103
+ print("Error: Could not read the image. Please check the file.")
104
+ return None, None
105
 
106
+ # Determine the image mode: RGBA (has transparency) or not.
107
  if len(img.shape) == 3 and img.shape[2] == 4:
108
  img_mode = 'RGBA'
109
  elif len(img.shape) == 2:
110
+ # If the image is grayscale, convert it to a color image.
111
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
112
+ img_mode = None
113
  else:
114
  img_mode = None
115
 
116
+ # If the image is too small, double its size.
117
  h, w = img.shape[:2]
118
  if h < 300:
119
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
120
 
121
+ # Map the selected model version to its weight file.
122
  model_paths = {
123
  'v1.2': 'GFPGANv1.2.pth',
124
  'v1.3': 'GFPGANv1.3.pth',
 
128
  'RealESR-General-x4v3': 'realesr-general-x4v3.pth'
129
  }
130
 
131
+ # Initialize GFPGAN for face enhancement.
132
  face_enhancer = GFPGANer(
133
  model_path=model_paths[version],
134
+ upscale=2, # Face region upscale factor.
135
  arch='clean' if version.startswith('v1') else version,
136
  channel_multiplier=2,
137
+ bg_upsampler=upsampler # Use the ESRGAN upsampler for background.
138
  )
139
 
140
+ # Enhance the image.
141
+ _, _, output = face_enhancer.enhance(
142
+ img, has_aligned=False, only_center_face=False, paste_back=True
143
+ )
144
 
145
+ # Optionally, further rescale the enhanced image.
146
  if scale != 2:
147
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
148
+ h, w = output.shape[:2]
149
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
150
 
151
+ # Decide on file extension based on image mode.
152
+ extension = 'png' if img_mode == 'RGBA' else 'jpg'
153
+ save_path = os.path.join('output', f'out.{extension}')
 
154
 
155
+ # Save the enhanced image.
156
  cv2.imwrite(save_path, output)
157
+ # Convert BGR to RGB for proper display in Gradio.
158
+ output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
159
 
160
+ return output_rgb, save_path
161
 
162
  except Exception as error:
163
+ print("Error during inference:", error)
164
  return None, None
165
 
166
+ # ------------------------------------------------------------------------------
167
+ # Build the Gradio UI
168
+ # ------------------------------------------------------------------------------
169
+
170
  with gr.Blocks() as demo:
171
  gr.Markdown("## 📸 Image Upscaling & Restoration")
172
+ gr.Markdown("### Enhance your images using GFPGAN & RealESRGAN with a friendly UI!")
173
 
174
  with gr.Row():
175
  with gr.Column():
176
+ image_input = gr.Image(type="filepath", label="Upload Your Image")
177
  version_selector = gr.Radio(
178
+ choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer', 'RealESR-General-x4v3'],
179
+ label="Select Model Version",
180
  value="v1.4"
181
  )
182
+ scale_factor = gr.Number(value=2, label="Rescaling Factor (e.g., 2 for default)")
 
183
  enhance_button = gr.Button("Enhance Image 🚀")
 
184
  with gr.Column():
185
+ output_image = gr.Image(type="numpy", label="Enhanced Image")
186
  download_link = gr.File(label="Download Enhanced Image")
187
 
188
+ # Link the button click to the inference function.
189
  enhance_button.click(
190
  fn=inference,
191
  inputs=[image_input, version_selector, scale_factor],
192
  outputs=[output_image, download_link]
193
  )
194
 
195
+ # ------------------------------------------------------------------------------
196
+ # Launch the Gradio App
197
+ # ------------------------------------------------------------------------------
198
+
199
  demo.launch()
200
+