mike23415 commited on
Commit
b40c2f4
·
verified ·
1 Parent(s): 8bd4072

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -26
app.py CHANGED
@@ -1,42 +1,62 @@
1
- from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
- import os
4
  from PIL import Image
5
  import io
6
  import torch
7
- from gfpgan import GFPGANer
 
8
 
9
  app = Flask(__name__)
10
  CORS(app)
11
 
12
- model = GFPGANer(
13
- model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.3.pth',
14
- upscale=2,
15
- arch='clean',
16
- channel_multiplier=2,
17
- bg_upsampler=None,
18
- device='cpu'
 
 
 
 
 
19
  )
20
 
21
  @app.route('/enhance', methods=['POST'])
22
- def enhance():
23
- if 'image' not in request.files:
24
- return jsonify({'error': 'No image uploaded'}), 400
25
-
26
- img_file = request.files['image']
27
- img = Image.open(img_file.stream).convert('RGB')
28
-
29
  try:
30
- _, _, restored_img = model.enhance(np.array(img), has_aligned=False, only_center_face=True, paste_back=True)
31
- result_img = Image.fromarray(restored_img)
32
-
33
- img_io = io.BytesIO()
34
- result_img.save(img_io, 'JPEG')
35
- img_io.seek(0)
36
-
37
- return send_file(img_io, mimetype='image/jpeg')
 
 
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
- return jsonify({'error': str(e)}), 500
 
 
 
 
40
 
41
  if __name__ == '__main__':
42
  app.run(host='0.0.0.0', port=5000)
 
1
+ from flask import Flask, request, send_file
2
  from flask_cors import CORS
3
+ import numpy as np
4
  from PIL import Image
5
  import io
6
  import torch
7
+ from basicsr.archs.rrdbnet_arch import RRDBNet
8
+ from realesrgan import RealESRGANer
9
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
+ # Initialize model
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
16
+ upsampler = RealESRGANer(
17
+ scale=4,
18
+ model_path='weights/realesr-general-x4v3.pth',
19
+ model=model,
20
+ tile=0,
21
+ tile_pad=10,
22
+ pre_pad=0,
23
+ half=False,
24
+ device=device
25
  )
26
 
27
  @app.route('/enhance', methods=['POST'])
28
+ def enhance_image():
29
+ if 'file' not in request.files:
30
+ return {'error': 'No file uploaded'}, 400
31
+
32
+ file = request.files['file']
33
+
 
34
  try:
35
+ # Read image
36
+ input_img = Image.open(file.stream).convert('RGB')
37
+
38
+ # Convert to numpy array
39
+ img_array = np.array(input_img)
40
+
41
+ # Enhance image
42
+ output, _ = upsampler.enhance(img_array, outscale=4)
43
+
44
+ # Convert back to PIL Image
45
+ result_img = Image.fromarray(output)
46
+
47
+ # Prepare response
48
+ img_byte_arr = io.BytesIO()
49
+ result_img.save(img_byte_arr, format='JPEG', quality=95)
50
+ img_byte_arr.seek(0)
51
+
52
+ return send_file(img_byte_arr, mimetype='image/jpeg')
53
+
54
  except Exception as e:
55
+ return {'error': str(e)}, 500
56
+
57
+ @app.route('/health', methods=['GET'])
58
+ def health_check():
59
+ return {'status': 'healthy'}, 200
60
 
61
  if __name__ == '__main__':
62
  app.run(host='0.0.0.0', port=5000)