mike23415 commited on
Commit
d74b38a
·
verified ·
1 Parent(s): eda4fae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -10,43 +10,48 @@ from realesrgan import RealESRGANer
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
- # Initialize model
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
16
- upsampler = RealESRGANer(
17
- scale=4,
18
- model_path='weights/realesr-general-x4v3.pth',
19
- model=model,
20
- tile=0,
21
- tile_pad=10,
22
- pre_pad=0,
23
- half=False,
24
- device=device
25
- )
 
 
 
 
 
26
 
27
  @app.route('/enhance', methods=['POST'])
28
  def enhance_image():
 
 
 
29
  if 'file' not in request.files:
30
  return {'error': 'No file uploaded'}, 400
31
 
32
- file = request.files['file']
33
-
34
  try:
35
- # Read image
36
- input_img = Image.open(file.stream).convert('RGB')
37
-
38
- # Convert to numpy array
39
- img_array = np.array(input_img)
40
-
41
- # Enhance image
42
- output, _ = upsampler.enhance(img_array, outscale=4)
43
 
44
- # Convert back to PIL Image
45
- result_img = Image.fromarray(output)
 
 
 
 
46
 
47
- # Prepare response
48
  img_byte_arr = io.BytesIO()
49
- result_img.save(img_byte_arr, format='JPEG', quality=95)
50
  img_byte_arr.seek(0)
51
 
52
  return send_file(img_byte_arr, mimetype='image/jpeg')
@@ -56,7 +61,8 @@ def enhance_image():
56
 
57
  @app.route('/health', methods=['GET'])
58
  def health_check():
59
- return {'status': 'healthy'}, 200
 
60
 
61
  if __name__ == '__main__':
62
  app.run(host='0.0.0.0', port=5000)
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
+ # Initialize model with explicit device allocation
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
16
+
17
+ try:
18
+ upsampler = RealESRGANer(
19
+ scale=4,
20
+ model_path='weights/realesr-general-x4v3.pth',
21
+ model=model,
22
+ tile=400,
23
+ tile_pad=10,
24
+ pre_pad=0,
25
+ half=False,
26
+ device=device
27
+ )
28
+ except Exception as e:
29
+ print(f"Model initialization failed: {str(e)}")
30
+ upsampler = None
31
 
32
  @app.route('/enhance', methods=['POST'])
33
  def enhance_image():
34
+ if not upsampler:
35
+ return {'error': 'Model not initialized'}, 500
36
+
37
  if 'file' not in request.files:
38
  return {'error': 'No file uploaded'}, 400
39
 
 
 
40
  try:
41
+ file = request.files['file']
42
+ img = Image.open(file.stream).convert('RGB')
43
+ img_array = np.array(img)
 
 
 
 
 
44
 
45
+ # Process image
46
+ output, _ = upsampler.enhance(
47
+ img_array,
48
+ outscale=4,
49
+ alpha_upsampler='realesrgan'
50
+ )
51
 
52
+ # Convert to bytes
53
  img_byte_arr = io.BytesIO()
54
+ Image.fromarray(output).save(img_byte_arr, format='JPEG')
55
  img_byte_arr.seek(0)
56
 
57
  return send_file(img_byte_arr, mimetype='image/jpeg')
 
61
 
62
  @app.route('/health', methods=['GET'])
63
  def health_check():
64
+ status = 'ready' if upsampler else 'initializing'
65
+ return {'status': status}, 200
66
 
67
  if __name__ == '__main__':
68
  app.run(host='0.0.0.0', port=5000)