mednow commited on
Commit
1136c1e
·
verified ·
1 Parent(s): 8f5e607

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -31
app.py CHANGED
@@ -4,49 +4,47 @@ import torch
4
  from RealESRGAN import RealESRGAN
5
  from io import BytesIO
6
 
7
- # Define the target size for the image (used for initial resizing before enhancement)
8
- TARGET_SIZE = (240, 240)
9
-
10
  # Function to load the model based on scale and anime toggle
11
  def load_model(scale, anime=False):
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
- model = RealESRGAN(device, scale=scale, anime=anime)
14
- model_path = {
15
- (2, False): 'model/RealESRGAN_x2.pth',
16
- (4, False): 'model/RealESRGAN_x4plus.pth',
17
- (8, False): 'model/RealESRGAN_x8.pth',
18
- (4, True): 'model/RealESRGAN_x4plus_anime_6B.pth'
19
- }[(scale, anime)]
20
- model.load_weights(model_path)
21
- return model
 
 
 
 
22
 
23
  def enhance_image(image, scale, anime):
 
 
 
 
24
  try:
25
- model = load_model(scale, anime=anime)
26
-
27
  # Convert image to RGB if it has an alpha channel
28
  if image.mode != 'RGB':
29
  image = image.convert('RGB')
30
 
31
- # Store original image size
32
- original_size = image.size
33
-
34
- # Resize image to target dimensions for processing
35
- image = image.resize(TARGET_SIZE)
36
-
37
- # Perform image enhancement
38
  sr_image = model.predict(image)
39
 
40
- # Resize the enhanced image back to the original size
41
- sr_image = sr_image.resize(original_size)
42
 
 
43
  buffer = BytesIO()
44
  sr_image.save(buffer, format="PNG")
45
  buffer.seek(0)
46
-
47
- return sr_image, buffer, None
48
  except Exception as e:
49
- return None, None, str(e)
 
50
 
51
  def main():
52
  st.title("Generative AI Image Restoration")
@@ -70,11 +68,9 @@ def main():
70
 
71
  # Enhance button
72
  if st.button("Restore Image"):
73
- enhanced_image, buffer, error_message = enhance_image(image, scale_value, anime)
74
 
75
- if error_message:
76
- st.error(f"An error occurred: {error_message}")
77
- else:
78
  # Show images side by side
79
  col1, col2 = st.columns(2)
80
  with col1:
 
4
  from RealESRGAN import RealESRGAN
5
  from io import BytesIO
6
 
 
 
 
7
  # Function to load the model based on scale and anime toggle
8
  def load_model(scale, anime=False):
9
+ try:
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+ model = RealESRGAN(device, scale=scale, anime=anime)
12
+ model_path = {
13
+ (2, False): 'model/RealESRGAN_x2.pth',
14
+ (4, False): 'model/RealESRGAN_x4plus.pth',
15
+ (8, False): 'model/RealESRGAN_x8.pth',
16
+ (4, True): 'model/RealESRGAN_x4plus_anime_6B.pth'
17
+ }[(scale, anime)]
18
+ model.load_weights(model_path)
19
+ return model
20
+ except Exception as e:
21
+ st.error(f"Error loading the model: {e}")
22
+ return None
23
 
24
  def enhance_image(image, scale, anime):
25
+ model = load_model(scale, anime=anime)
26
+ if model is None:
27
+ return None, None
28
+
29
  try:
 
 
30
  # Convert image to RGB if it has an alpha channel
31
  if image.mode != 'RGB':
32
  image = image.convert('RGB')
33
 
34
+ # Process the image with the model
 
 
 
 
 
 
35
  sr_image = model.predict(image)
36
 
37
+ # Ensure the enhanced image has the same size as the original
38
+ sr_image = sr_image.resize(image.size)
39
 
40
+ # Save enhanced image to buffer
41
  buffer = BytesIO()
42
  sr_image.save(buffer, format="PNG")
43
  buffer.seek(0)
44
+ return sr_image, buffer
 
45
  except Exception as e:
46
+ st.error(f"Error enhancing the image: {e}")
47
+ return None, None
48
 
49
  def main():
50
  st.title("Generative AI Image Restoration")
 
68
 
69
  # Enhance button
70
  if st.button("Restore Image"):
71
+ enhanced_image, buffer = enhance_image(image, scale_value, anime)
72
 
73
+ if enhanced_image:
 
 
74
  # Show images side by side
75
  col1, col2 = st.columns(2)
76
  with col1: