Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -10,43 +10,48 @@ from realesrgan import RealESRGANer
|
|
10 |
app = Flask(__name__)
|
11 |
CORS(app)
|
12 |
|
13 |
-
# Initialize model
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
@app.route('/enhance', methods=['POST'])
|
28 |
def enhance_image():
|
|
|
|
|
|
|
29 |
if 'file' not in request.files:
|
30 |
return {'error': 'No file uploaded'}, 400
|
31 |
|
32 |
-
file = request.files['file']
|
33 |
-
|
34 |
try:
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
# Convert to numpy array
|
39 |
-
img_array = np.array(input_img)
|
40 |
-
|
41 |
-
# Enhance image
|
42 |
-
output, _ = upsampler.enhance(img_array, outscale=4)
|
43 |
|
44 |
-
#
|
45 |
-
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
#
|
48 |
img_byte_arr = io.BytesIO()
|
49 |
-
|
50 |
img_byte_arr.seek(0)
|
51 |
|
52 |
return send_file(img_byte_arr, mimetype='image/jpeg')
|
@@ -56,7 +61,8 @@ def enhance_image():
|
|
56 |
|
57 |
@app.route('/health', methods=['GET'])
|
58 |
def health_check():
|
59 |
-
|
|
|
60 |
|
61 |
if __name__ == '__main__':
|
62 |
app.run(host='0.0.0.0', port=5000)
|
|
|
10 |
app = Flask(__name__)
|
11 |
CORS(app)
|
12 |
|
13 |
+
# Initialize model with explicit device allocation
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
16 |
+
|
17 |
+
try:
|
18 |
+
upsampler = RealESRGANer(
|
19 |
+
scale=4,
|
20 |
+
model_path='weights/realesr-general-x4v3.pth',
|
21 |
+
model=model,
|
22 |
+
tile=400,
|
23 |
+
tile_pad=10,
|
24 |
+
pre_pad=0,
|
25 |
+
half=False,
|
26 |
+
device=device
|
27 |
+
)
|
28 |
+
except Exception as e:
|
29 |
+
print(f"Model initialization failed: {str(e)}")
|
30 |
+
upsampler = None
|
31 |
|
32 |
@app.route('/enhance', methods=['POST'])
|
33 |
def enhance_image():
|
34 |
+
if not upsampler:
|
35 |
+
return {'error': 'Model not initialized'}, 500
|
36 |
+
|
37 |
if 'file' not in request.files:
|
38 |
return {'error': 'No file uploaded'}, 400
|
39 |
|
|
|
|
|
40 |
try:
|
41 |
+
file = request.files['file']
|
42 |
+
img = Image.open(file.stream).convert('RGB')
|
43 |
+
img_array = np.array(img)
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
+
# Process image
|
46 |
+
output, _ = upsampler.enhance(
|
47 |
+
img_array,
|
48 |
+
outscale=4,
|
49 |
+
alpha_upsampler='realesrgan'
|
50 |
+
)
|
51 |
|
52 |
+
# Convert to bytes
|
53 |
img_byte_arr = io.BytesIO()
|
54 |
+
Image.fromarray(output).save(img_byte_arr, format='JPEG')
|
55 |
img_byte_arr.seek(0)
|
56 |
|
57 |
return send_file(img_byte_arr, mimetype='image/jpeg')
|
|
|
61 |
|
62 |
@app.route('/health', methods=['GET'])
|
63 |
def health_check():
|
64 |
+
status = 'ready' if upsampler else 'initializing'
|
65 |
+
return {'status': status}, 200
|
66 |
|
67 |
if __name__ == '__main__':
|
68 |
app.run(host='0.0.0.0', port=5000)
|