abiabidali commited on
Commit
8de6861
·
verified ·
1 Parent(s): c8dab2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -52
app.py CHANGED
@@ -1,13 +1,12 @@
1
-
2
  import torch
3
  from PIL import Image
4
  from RealESRGAN import RealESRGAN
5
  import gradio as gr
6
  import numpy as np
7
- import tempfile
8
- import time
9
  import zipfile
10
  import os
 
11
 
12
  # Set the device to CUDA if available, otherwise CPU
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -20,7 +19,6 @@ def load_model(scale):
20
  print(f"Weights for scale {scale} loaded successfully.")
21
  except Exception as e:
22
  print(f"Error loading weights for scale {scale}: {e}")
23
- model.load_weights(weights_path, download=False)
24
  return model
25
 
26
  # Load models for different scales
@@ -33,15 +31,8 @@ def enhance_image(image, scale):
33
  print(f"Enhancing image with scale {scale}...")
34
  start_time = time.time()
35
  image_np = np.array(image.convert('RGB'))
36
- print(f"Image converted to numpy array: shape {image_np.shape}, dtype {image_np.dtype}")
37
-
38
- if scale == '2x':
39
- result = model2.predict(image_np)
40
- elif scale == '4x':
41
- result = model4.predict(image_np)
42
- else:
43
- result = model8.predict(image_np)
44
-
45
  enhanced_image = Image.fromarray(np.uint8(result))
46
  print(f"Image enhanced in {time.time() - start_time:.2f} seconds")
47
  return enhanced_image
@@ -49,55 +40,48 @@ def enhance_image(image, scale):
49
  print(f"Error enhancing image: {e}")
50
  return image
51
 
52
- def muda_dpi(input_image, dpi):
53
- dpi_tuple = (dpi, dpi)
54
- image = Image.fromarray(input_image.astype('uint8'), 'RGB')
55
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
56
- image.save(temp_file, format='JPEG', dpi=dpi_tuple)
57
- temp_file.close()
58
- return Image.open(temp_file.name)
 
59
 
60
- def resize_image(input_image, width, height):
61
- image = Image.fromarray(input_image.astype('uint8'), 'RGB')
62
- resized_image = image.resize((width, height))
63
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
64
- resized_image.save(temp_file, format='JPEG')
65
- temp_file.close()
66
- return Image.open(temp_file.name)
67
 
68
  def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
69
  processed_images = []
70
- temp_dir = tempfile.mkdtemp()
71
 
72
  for image_file in image_files:
73
- input_image = np.array(Image.open(image_file).convert('RGB'))
74
- original_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
75
 
76
  if enhance:
77
- original_image = enhance_image(original_image, scale)
78
 
79
  if adjust_dpi:
80
- original_image = muda_dpi(np.array(original_image), dpi)
81
-
82
  if resize:
83
- original_image = resize_image(np.array(original_image), width, height)
84
 
85
- # Save each image as JPEG, preserving the original filename
86
- file_name = os.path.basename(image_file.name)
87
- output_path = os.path.join(temp_dir, file_name)
88
- original_image.save(output_path, format='JPEG')
89
- processed_images.append(output_path)
90
-
91
- # Create a ZIP file with all processed images
92
- zip_path = os.path.join(temp_dir, 'processed_images.zip')
93
- with zipfile.ZipFile(zip_path, 'w') as zipf:
94
- for file_path in processed_images:
95
- zipf.write(file_path, os.path.basename(file_path))
96
 
97
- # Load images for display in the gallery
98
- display_images = [Image.open(img_path) for img_path in processed_images]
99
-
100
- return display_images, zip_path
101
 
102
  iface = gr.Interface(
103
  fn=process_images,
@@ -120,6 +104,3 @@ iface = gr.Interface(
120
  )
121
 
122
  iface.launch(debug=True)
123
-
124
-
125
-
 
 
1
  import torch
2
  from PIL import Image
3
  from RealESRGAN import RealESRGAN
4
  import gradio as gr
5
  import numpy as np
6
+ import io
 
7
  import zipfile
8
  import os
9
+ import time
10
 
11
  # Set the device to CUDA if available, otherwise CPU
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
19
  print(f"Weights for scale {scale} loaded successfully.")
20
  except Exception as e:
21
  print(f"Error loading weights for scale {scale}: {e}")
 
22
  return model
23
 
24
  # Load models for different scales
 
31
  print(f"Enhancing image with scale {scale}...")
32
  start_time = time.time()
33
  image_np = np.array(image.convert('RGB'))
34
+ model = model2 if scale == '2x' else model4 if scale == '4x' else model8
35
+ result = model.predict(image_np)
 
 
 
 
 
 
 
36
  enhanced_image = Image.fromarray(np.uint8(result))
37
  print(f"Image enhanced in {time.time() - start_time:.2f} seconds")
38
  return enhanced_image
 
40
  print(f"Error enhancing image: {e}")
41
  return image
42
 
43
+ def muda_dpi(image, dpi):
44
+ try:
45
+ with io.BytesIO() as output:
46
+ image.save(output, format='JPEG', dpi=(dpi, dpi))
47
+ return Image.open(output)
48
+ except Exception as e:
49
+ print(f"Error adjusting DPI: {e}")
50
+ return image
51
 
52
+ def resize_image(image, width, height):
53
+ try:
54
+ resized_image = image.resize((width, height))
55
+ return resized_image
56
+ except Exception as e:
57
+ print(f"Error resizing image: {e}")
58
+ return image
59
 
60
  def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
61
  processed_images = []
62
+ zip_buffer = io.BytesIO()
63
 
64
  for image_file in image_files:
65
+ image = Image.open(image_file).convert('RGB')
 
66
 
67
  if enhance:
68
+ image = enhance_image(image, scale)
69
 
70
  if adjust_dpi:
71
+ image = muda_dpi(image, dpi)
72
+
73
  if resize:
74
+ image = resize_image(image, width, height)
75
 
76
+ # Save image to the in-memory ZIP buffer
77
+ buffer = io.BytesIO()
78
+ image.save(buffer, format='JPEG')
79
+ processed_images.append(Image.open(io.BytesIO(buffer.getvalue())))
80
+ with zipfile.ZipFile(zip_buffer, 'a') as zipf:
81
+ zipf.writestr(os.path.basename(image_file.name), buffer.getvalue())
 
 
 
 
 
82
 
83
+ zip_buffer.seek(0)
84
+ return processed_images, zip_buffer
 
 
85
 
86
  iface = gr.Interface(
87
  fn=process_images,
 
104
  )
105
 
106
  iface.launch(debug=True)