mike23415 commited on
Commit
1127754
·
verified ·
1 Parent(s): 89d1265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -10,12 +10,12 @@ from realesrgan import RealESRGANer
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
- # Initialize model with explicit device allocation
14
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
16
-
17
- try:
18
- upsampler = RealESRGANer(
19
  scale=4,
20
  model_path='weights/realesr-general-x4v3.pth',
21
  model=model,
@@ -25,6 +25,10 @@ try:
25
  half=False,
26
  device=device
27
  )
 
 
 
 
28
  except Exception as e:
29
  print(f"Model initialization failed: {str(e)}")
30
  upsampler = None
@@ -42,14 +46,8 @@ def enhance_image():
42
  img = Image.open(file.stream).convert('RGB')
43
  img_array = np.array(img)
44
 
45
- # Process image
46
- output, _ = upsampler.enhance(
47
- img_array,
48
- outscale=4,
49
- alpha_upsampler='realesrgan'
50
- )
51
 
52
- # Convert to bytes
53
  img_byte_arr = io.BytesIO()
54
  Image.fromarray(output).save(img_byte_arr, format='JPEG')
55
  img_byte_arr.seek(0)
@@ -61,7 +59,7 @@ def enhance_image():
61
 
62
  @app.route('/health', methods=['GET'])
63
  def health_check():
64
- status = 'ready' if upsampler else 'initializing'
65
  return {'status': status}, 200
66
 
67
  if __name__ == '__main__':
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
+ # Initialize model
14
+ def init_upsampler():
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
17
+
18
+ return RealESRGANer(
19
  scale=4,
20
  model_path='weights/realesr-general-x4v3.pth',
21
  model=model,
 
25
  half=False,
26
  device=device
27
  )
28
+
29
+ try:
30
+ upsampler = init_upsampler()
31
+ print("Model initialized successfully!")
32
  except Exception as e:
33
  print(f"Model initialization failed: {str(e)}")
34
  upsampler = None
 
46
  img = Image.open(file.stream).convert('RGB')
47
  img_array = np.array(img)
48
 
49
+ output, _ = upsampler.enhance(img_array, outscale=4)
 
 
 
 
 
50
 
 
51
  img_byte_arr = io.BytesIO()
52
  Image.fromarray(output).save(img_byte_arr, format='JPEG')
53
  img_byte_arr.seek(0)
 
59
 
60
  @app.route('/health', methods=['GET'])
61
  def health_check():
62
+ status = 'ready' if upsampler else 'unavailable'
63
  return {'status': status}, 200
64
 
65
  if __name__ == '__main__':