mednow commited on
Commit
8f5e607
·
verified ·
1 Parent(s): a1074ad

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -32
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from RealESRGAN import RealESRGAN
5
  from io import BytesIO
6
 
7
- # Define the target size for the image
8
  TARGET_SIZE = (240, 240)
9
 
10
  # Function to load the model based on scale and anime toggle
@@ -21,21 +21,32 @@ def load_model(scale, anime=False):
21
  return model
22
 
23
  def enhance_image(image, scale, anime):
24
- model = load_model(scale, anime=anime)
25
-
26
- # Convert image to RGB if it has an alpha channel
27
- if image.mode != 'RGB':
28
- image = image.convert('RGB')
29
-
30
- # Resize image to target dimensions
31
- image = image.resize(TARGET_SIZE)
32
-
33
- sr_image = model.predict(image)
34
-
35
- buffer = BytesIO()
36
- sr_image.save(buffer, format="PNG")
37
- buffer.seek(0)
38
- return sr_image, buffer
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def main():
41
  st.title("Generative AI Image Restoration")
@@ -59,22 +70,25 @@ def main():
59
 
60
  # Enhance button
61
  if st.button("Restore Image"):
62
- enhanced_image, buffer = enhance_image(image, scale_value, anime)
63
-
64
- # Show images side by side
65
- col1, col2 = st.columns(2)
66
- with col1:
67
- st.image(image, caption="Original Image", use_column_width=True)
68
- with col2:
69
- st.image(enhanced_image, caption="Enhanced Image", use_column_width=True)
70
 
71
- # Download button
72
- st.download_button(
73
- label="Download Enhanced Image",
74
- data=buffer,
75
- file_name="enhanced_image.png",
76
- mime="image/png"
77
- )
 
 
 
 
 
 
 
 
 
 
78
 
79
  if __name__ == "__main__":
80
- main()
 
4
  from RealESRGAN import RealESRGAN
5
  from io import BytesIO
6
 
7
+ # Define the target size for the image (used for initial resizing before enhancement)
8
  TARGET_SIZE = (240, 240)
9
 
10
  # Function to load the model based on scale and anime toggle
 
21
  return model
22
 
23
  def enhance_image(image, scale, anime):
24
+ try:
25
+ model = load_model(scale, anime=anime)
26
+
27
+ # Convert image to RGB if it has an alpha channel
28
+ if image.mode != 'RGB':
29
+ image = image.convert('RGB')
30
+
31
+ # Store original image size
32
+ original_size = image.size
33
+
34
+ # Resize image to target dimensions for processing
35
+ image = image.resize(TARGET_SIZE)
36
+
37
+ # Perform image enhancement
38
+ sr_image = model.predict(image)
39
+
40
+ # Resize the enhanced image back to the original size
41
+ sr_image = sr_image.resize(original_size)
42
+
43
+ buffer = BytesIO()
44
+ sr_image.save(buffer, format="PNG")
45
+ buffer.seek(0)
46
+
47
+ return sr_image, buffer, None
48
+ except Exception as e:
49
+ return None, None, str(e)
50
 
51
  def main():
52
  st.title("Generative AI Image Restoration")
 
70
 
71
  # Enhance button
72
  if st.button("Restore Image"):
73
+ enhanced_image, buffer, error_message = enhance_image(image, scale_value, anime)
 
 
 
 
 
 
 
74
 
75
+ if error_message:
76
+ st.error(f"An error occurred: {error_message}")
77
+ else:
78
+ # Show images side by side
79
+ col1, col2 = st.columns(2)
80
+ with col1:
81
+ st.image(image, caption="Original Image", use_column_width=True)
82
+ with col2:
83
+ st.image(enhanced_image, caption="Enhanced Image", use_column_width=True)
84
+
85
+ # Download button
86
+ st.download_button(
87
+ label="Download Enhanced Image",
88
+ data=buffer,
89
+ file_name="enhanced_image.png",
90
+ mime="image/png"
91
+ )
92
 
93
  if __name__ == "__main__":
94
+ main()