Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,9 @@ model8.load_weights('weights/RealESRGAN_x8.pth', download=True)
|
|
13 |
|
14 |
|
15 |
def inference(image, size):
|
|
|
|
|
|
|
16 |
if image is None:
|
17 |
raise gr.Error("Image not uploaded")
|
18 |
|
@@ -22,13 +25,29 @@ def inference(image, size):
|
|
22 |
|
23 |
if torch.cuda.is_available():
|
24 |
torch.cuda.empty_cache()
|
25 |
-
|
26 |
if size == '2x':
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
elif size == '4x':
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
30 |
else:
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
print(f"Image size ({device}): {size} ... OK")
|
33 |
return result
|
34 |
|
|
|
13 |
|
14 |
|
15 |
def inference(image, size):
|
16 |
+
global model2
|
17 |
+
global model4
|
18 |
+
global model8
|
19 |
if image is None:
|
20 |
raise gr.Error("Image not uploaded")
|
21 |
|
|
|
25 |
|
26 |
if torch.cuda.is_available():
|
27 |
torch.cuda.empty_cache()
|
28 |
+
|
29 |
if size == '2x':
|
30 |
+
try:
|
31 |
+
result = model2.predict(image.convert('RGB'))
|
32 |
+
except torch.cuda.OutOfMemoryError as e:
|
33 |
+
model2 = RealESRGAN(device, scale=2)
|
34 |
+
model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
|
35 |
+
result = model2.predict(image.convert('RGB'))
|
36 |
elif size == '4x':
|
37 |
+
try:
|
38 |
+
result = model4.predict(image.convert('RGB'))
|
39 |
+
except torch.cuda.OutOfMemoryError as e:
|
40 |
+
model4 = RealESRGAN(device, scale=4)
|
41 |
+
model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
|
42 |
+
result = model4.predict(image.convert('RGB'))
|
43 |
else:
|
44 |
+
try:
|
45 |
+
result = model8.predict(image.convert('RGB'))
|
46 |
+
except torch.cuda.OutOfMemoryError as e:
|
47 |
+
model8 = RealESRGAN(device, scale=8)
|
48 |
+
model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
|
49 |
+
result = model8.predict(image.convert('RGB'))
|
50 |
+
|
51 |
print(f"Image size ({device}): {size} ... OK")
|
52 |
return result
|
53 |
|