mike23415 commited on
Commit
1e8b30f
·
verified ·
1 Parent(s): 4f8e0e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -22
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from flask import Flask, request, send_file
2
  from flask_cors import CORS
3
  import numpy as np
4
  from PIL import Image
@@ -6,13 +6,23 @@ import io
6
  import torch
7
  from basicsr.archs.rrdbnet_arch import RRDBNet
8
  from realesrgan import RealESRGANer
 
 
9
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
- # Initialize Real-ESRGAN upsampler
14
- def initialize_enhancer():
 
 
 
 
 
 
 
15
  try:
 
16
  # Configuration for RealESRGAN x4v3
17
  model = RRDBNet(
18
  num_in_ch=3,
@@ -26,44 +36,69 @@ def initialize_enhancer():
26
  # Force CPU usage for Hugging Face compatibility
27
  device = torch.device('cpu')
28
 
29
- return RealESRGANer(
 
30
  scale=4,
31
  model_path='weights/realesr-general-x4v3.pth',
32
  model=model,
33
- tile=0, # Set to 0 for small images, increase for large images
34
  tile_pad=10,
35
  pre_pad=0,
36
- half=False, # CPU doesn't support half precision
37
  device=device
38
  )
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
- print(f"Initialization error: {str(e)}")
41
- return None
 
 
42
 
43
- # Global upsampler instance
44
- upsampler = initialize_enhancer()
45
 
46
  @app.route('/enhance', methods=['POST'])
47
  def enhance_image():
48
- if not upsampler:
49
- return {'error': 'Model failed to initialize'}, 500
 
 
 
50
 
 
 
 
 
 
51
  if 'file' not in request.files:
52
- return {'error': 'No file uploaded'}, 400
53
 
54
  try:
55
  # Read and validate image
56
  file = request.files['file']
57
  if file.filename == '':
58
- return {'error': 'Empty file submitted'}, 400
59
-
 
60
  img = Image.open(file.stream).convert('RGB')
61
  img_array = np.array(img)
62
 
 
 
 
 
 
63
  # Enhance image
64
  output, _ = upsampler.enhance(
65
  img_array,
66
- outscale=4, # 4x super-resolution
67
  alpha_upsampler='realesrgan'
68
  )
69
 
@@ -75,22 +110,34 @@ def enhance_image():
75
  return send_file(img_byte_arr, mimetype='image/jpeg')
76
 
77
  except Exception as e:
78
- return {'error': f'Processing error: {str(e)}'}, 500
79
 
80
  @app.route('/health', methods=['GET'])
81
  def health_check():
82
- status = 'ready' if upsampler else 'unavailable'
83
- return {'status': status}, 200
 
 
 
 
 
 
 
 
 
 
84
 
85
  @app.route('/')
86
  def home():
87
- return {
88
  'message': 'Image Enhancement API',
89
  'endpoints': {
90
  'POST /enhance': 'Process images (4x upscale)',
91
  'GET /health': 'Service status check'
92
- }
93
- }, 200
 
94
 
95
  if __name__ == '__main__':
 
96
  app.run(host='0.0.0.0', port=5000)
 
1
+ from flask import Flask, request, send_file, jsonify
2
  from flask_cors import CORS
3
  import numpy as np
4
  from PIL import Image
 
6
  import torch
7
  from basicsr.archs.rrdbnet_arch import RRDBNet
8
  from realesrgan import RealESRGANer
9
+ import threading
10
+ import time
11
 
12
  app = Flask(__name__)
13
  CORS(app)
14
 
15
+ # Global variables
16
+ upsampler = None
17
+ initialization_status = "initializing" # "initializing", "ready", or "failed"
18
+ initialization_error = None
19
+
20
+ # Initialize Real-ESRGAN upsampler in a separate thread
21
+ def initialize_enhancer_thread():
22
+ global upsampler, initialization_status, initialization_error
23
+
24
  try:
25
+ print("Starting model initialization...")
26
  # Configuration for RealESRGAN x4v3
27
  model = RRDBNet(
28
  num_in_ch=3,
 
36
  # Force CPU usage for Hugging Face compatibility
37
  device = torch.device('cpu')
38
 
39
+ # Initialize the upsampler
40
+ upsampler = RealESRGANer(
41
  scale=4,
42
  model_path='weights/realesr-general-x4v3.pth',
43
  model=model,
44
+ tile=0, # Set to 0 for small images, increase for large images
45
  tile_pad=10,
46
  pre_pad=0,
47
+ half=False, # CPU doesn't support half precision
48
  device=device
49
  )
50
+
51
+ # Run a small test to ensure everything is loaded
52
+ test_img = np.zeros((64, 64, 3), dtype=np.uint8)
53
+ upsampler.enhance(test_img, outscale=4, alpha_upsampler='realesrgan')
54
+
55
+ print("Model initialization completed successfully")
56
+ initialization_status = "ready"
57
+
58
  except Exception as e:
59
+ error_msg = f"Model initialization failed: {str(e)}"
60
+ print(error_msg)
61
+ initialization_status = "failed"
62
+ initialization_error = error_msg
63
 
64
+ # Start initialization in background thread
65
+ threading.Thread(target=initialize_enhancer_thread, daemon=True).start()
66
 
67
  @app.route('/enhance', methods=['POST'])
68
  def enhance_image():
69
+ global upsampler, initialization_status, initialization_error
70
+
71
+ # Check if model is ready
72
+ if initialization_status == "initializing":
73
+ return jsonify({'error': 'Enhancement model is still initializing. Please try again in a few minutes.'}), 503
74
 
75
+ # Check if initialization failed
76
+ if initialization_status == "failed":
77
+ return jsonify({'error': f'Model initialization failed: {initialization_error}'}), 500
78
+
79
+ # Check if file was uploaded
80
  if 'file' not in request.files:
81
+ return jsonify({'error': 'No file uploaded'}), 400
82
 
83
  try:
84
  # Read and validate image
85
  file = request.files['file']
86
  if file.filename == '':
87
+ return jsonify({'error': 'Empty file submitted'}), 400
88
+
89
+ # Process the image
90
  img = Image.open(file.stream).convert('RGB')
91
  img_array = np.array(img)
92
 
93
+ # Check image size
94
+ h, w = img_array.shape[:2]
95
+ if h > 2000 or w > 2000:
96
+ return jsonify({'error': 'Image too large. Maximum size is 2000x2000 pixels'}), 400
97
+
98
  # Enhance image
99
  output, _ = upsampler.enhance(
100
  img_array,
101
+ outscale=4, # 4x super-resolution
102
  alpha_upsampler='realesrgan'
103
  )
104
 
 
110
  return send_file(img_byte_arr, mimetype='image/jpeg')
111
 
112
  except Exception as e:
113
+ return jsonify({'error': f'Processing error: {str(e)}'}), 500
114
 
115
  @app.route('/health', methods=['GET'])
116
  def health_check():
117
+ global initialization_status, initialization_error
118
+
119
+ status_info = {
120
+ 'status': initialization_status,
121
+ 'timestamp': time.time()
122
+ }
123
+
124
+ if initialization_status == "failed" and initialization_error:
125
+ status_info['error'] = initialization_error
126
+
127
+ # Return 200 OK even if not ready, but include status in response
128
+ return jsonify(status_info)
129
 
130
  @app.route('/')
131
  def home():
132
+ return jsonify({
133
  'message': 'Image Enhancement API',
134
  'endpoints': {
135
  'POST /enhance': 'Process images (4x upscale)',
136
  'GET /health': 'Service status check'
137
+ },
138
+ 'status': initialization_status
139
+ })
140
 
141
  if __name__ == '__main__':
142
+ # Start the Flask app
143
  app.run(host='0.0.0.0', port=5000)