mike23415 commited on
Commit
7f980fd
·
verified ·
1 Parent(s): 1127754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -25
app.py CHANGED
@@ -10,57 +10,87 @@ from realesrgan import RealESRGANer
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
- # Initialize model
14
- def init_upsampler():
15
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
17
-
18
- return RealESRGANer(
19
- scale=4,
20
- model_path='weights/realesr-general-x4v3.pth',
21
- model=model,
22
- tile=400,
23
- tile_pad=10,
24
- pre_pad=0,
25
- half=False,
26
- device=device
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- try:
30
- upsampler = init_upsampler()
31
- print("Model initialized successfully!")
32
- except Exception as e:
33
- print(f"Model initialization failed: {str(e)}")
34
- upsampler = None
35
 
36
  @app.route('/enhance', methods=['POST'])
37
  def enhance_image():
38
  if not upsampler:
39
- return {'error': 'Model not initialized'}, 500
40
 
41
  if 'file' not in request.files:
42
  return {'error': 'No file uploaded'}, 400
43
 
44
  try:
 
45
  file = request.files['file']
 
 
 
46
  img = Image.open(file.stream).convert('RGB')
47
  img_array = np.array(img)
48
 
49
- output, _ = upsampler.enhance(img_array, outscale=4)
 
 
 
 
 
50
 
 
51
  img_byte_arr = io.BytesIO()
52
- Image.fromarray(output).save(img_byte_arr, format='JPEG')
53
  img_byte_arr.seek(0)
54
 
55
  return send_file(img_byte_arr, mimetype='image/jpeg')
56
 
57
  except Exception as e:
58
- return {'error': str(e)}, 500
59
 
60
  @app.route('/health', methods=['GET'])
61
  def health_check():
62
  status = 'ready' if upsampler else 'unavailable'
63
  return {'status': status}, 200
64
 
 
 
 
 
 
 
 
 
 
 
65
  if __name__ == '__main__':
66
  app.run(host='0.0.0.0', port=5000)
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
+ # Initialize Real-ESRGAN upsampler
14
+ def initialize_enhancer():
15
+ try:
16
+ # Configuration for RealESRGAN x4v3
17
+ model = RRDBNet(
18
+ num_in_ch=3,
19
+ num_out_ch=3,
20
+ num_feat=64,
21
+ num_block=6, # Critical parameter for x4v3 model
22
+ num_grow_ch=32,
23
+ scale=4
24
+ )
25
+
26
+ # Force CPU usage for Hugging Face compatibility
27
+ device = torch.device('cpu')
28
+
29
+ return RealESRGANer(
30
+ scale=4,
31
+ model_path='weights/realesr-general-x4v3.pth',
32
+ model=model,
33
+ tile=0, # Set to 0 for small images, increase for large images
34
+ tile_pad=10,
35
+ pre_pad=0,
36
+ half=False, # CPU doesn't support half precision
37
+ device=device
38
+ )
39
+ except Exception as e:
40
+ print(f"Initialization error: {str(e)}")
41
+ return None
42
 
43
+ # Global upsampler instance
44
+ upsampler = initialize_enhancer()
 
 
 
 
45
 
46
  @app.route('/enhance', methods=['POST'])
47
  def enhance_image():
48
  if not upsampler:
49
+ return {'error': 'Model failed to initialize'}, 500
50
 
51
  if 'file' not in request.files:
52
  return {'error': 'No file uploaded'}, 400
53
 
54
  try:
55
+ # Read and validate image
56
  file = request.files['file']
57
+ if file.filename == '':
58
+ return {'error': 'Empty file submitted'}, 400
59
+
60
  img = Image.open(file.stream).convert('RGB')
61
  img_array = np.array(img)
62
 
63
+ # Enhance image
64
+ output, _ = upsampler.enhance(
65
+ img_array,
66
+ outscale=4, # 4x super-resolution
67
+ alpha_upsampler='realesrgan'
68
+ )
69
 
70
+ # Convert to JPEG bytes
71
  img_byte_arr = io.BytesIO()
72
+ Image.fromarray(output).save(img_byte_arr, format='JPEG', quality=95)
73
  img_byte_arr.seek(0)
74
 
75
  return send_file(img_byte_arr, mimetype='image/jpeg')
76
 
77
  except Exception as e:
78
+ return {'error': f'Processing error: {str(e)}'}, 500
79
 
80
  @app.route('/health', methods=['GET'])
81
  def health_check():
82
  status = 'ready' if upsampler else 'unavailable'
83
  return {'status': status}, 200
84
 
85
+ @app.route('/')
86
+ def home():
87
+ return {
88
+ 'message': 'Image Enhancement API',
89
+ 'endpoints': {
90
+ 'POST /enhance': 'Process images (4x upscale)',
91
+ 'GET /health': 'Service status check'
92
+ }
93
+ }, 200
94
+
95
  if __name__ == '__main__':
96
  app.run(host='0.0.0.0', port=5000)