0x90e commited on
Commit
6887d0a
·
1 Parent(s): 4f2336f

Better colab compat

Browse files
Files changed (3) hide show
  1. app.py +7 -2
  2. test.py +8 -5
  3. util.py +6 -0
app.py CHANGED
@@ -1,10 +1,13 @@
1
  import os
2
  import gradio as gr
3
  import random
 
4
  from random import randint
5
  import sys
6
  from subprocess import call
7
 
 
 
8
  def run_cmd(command):
9
  try:
10
  print(command)
@@ -29,8 +32,10 @@ def inference(img):
29
  input_image = gr.inputs.Image(type="pil", label="Input")
30
  output_image = gr.outputs.Image(type="file", label="Output")
31
 
32
- gr.Interface(
33
  inference,
34
  inputs=[input_image],
35
  outputs=[output_image]
36
- ).launch()
 
 
 
1
  import os
2
  import gradio as gr
3
  import random
4
+ import util
5
  from random import randint
6
  import sys
7
  from subprocess import call
8
 
9
+ is_colab = util.is_google_colab()
10
+
11
  def run_cmd(command):
12
  try:
13
  print(command)
 
32
  input_image = gr.inputs.Image(type="pil", label="Input")
33
  output_image = gr.outputs.Image(type="file", label="Output")
34
 
35
+ demo = gr.Interface(
36
  inference,
37
  inputs=[input_image],
38
  outputs=[output_image]
39
+ )
40
+
41
+ demo.launch(debug=is_colab, share=is_colab)
test.py CHANGED
@@ -6,18 +6,21 @@ import numpy as np
6
  import torch
7
  import architecture as arch
8
  import multiprocessing
 
9
 
10
- cpu_count = multiprocessing.cpu_count()
11
-
12
- print(cpu_count)
 
 
13
 
14
  model_path = '4x_eula_digimanga_bw_v2_nc1_307k.pth'
15
  img_path = sys.argv[1]
16
  output_dir = sys.argv[2]
17
- device = torch.device('cpu')
18
 
19
  model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
20
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
21
  model.eval()
22
 
23
  for k, v in model.named_parameters():
 
6
  import torch
7
  import architecture as arch
8
  import multiprocessing
9
+ import util
10
 
11
+ def is_cuda():
12
+ if torch.cuda.is_available() or not util.is_google_colab():
13
+ return True
14
+ else:
15
+ return False
16
 
17
  model_path = '4x_eula_digimanga_bw_v2_nc1_307k.pth'
18
  img_path = sys.argv[1]
19
  output_dir = sys.argv[2]
20
+ device = torch.device('cuda' if is_cuda() else 'cpu')
21
 
22
  model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
23
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda' if is_cuda() else 'cpu')), strict=True)
24
  model.eval()
25
 
26
  for k, v in model.named_parameters():
util.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def is_google_colab():
2
+ try:
3
+ import google.colab
4
+ return True
5
+ except:
6
+ return False