sunshangquan commited on
Commit
048e7e4
·
1 Parent(s): 14e98b6

commit from ssq

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -6,8 +6,12 @@ from skimage import img_as_ubyte
6
 
7
  from Allweather.util import load_img, save_img
8
  from basicsr.models.archs.histoformer_arch import Histoformer
 
 
 
 
9
 
10
- model_restoration = Histoformer.from_pretrained("sunsean/Histoformer-real").to("cuda")
11
 
12
  model_restoration.eval()
13
 
@@ -15,7 +19,7 @@ factor = 8
15
  def predict(input_img):
16
  img = np.float32(load_img(input_img))/255.
17
  img = torch.from_numpy(img).permute(2,0,1)
18
- input_ = img.unsqueeze(0).cuda()
19
 
20
  # Padding in case images are not multiples of 8
21
  h,w = input_.shape[2], input_.shape[3]
@@ -27,7 +31,7 @@ def predict(input_img):
27
  restored = model_restoration(input_)
28
  output_path = "restored.png"
29
  restored = restored[:,:,:h,:w]
30
- restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
31
 
32
  save_img(output_path, img_as_ubyte(restored))
33
 
 
6
 
7
  from Allweather.util import load_img, save_img
8
  from basicsr.models.archs.histoformer_arch import Histoformer
9
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
10
+ # True
11
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
12
+ # Tesla T4
13
 
14
+ model_restoration = Histoformer.from_pretrained("sunsean/Histoformer-real").to("cpu")
15
 
16
  model_restoration.eval()
17
 
 
19
  def predict(input_img):
20
  img = np.float32(load_img(input_img))/255.
21
  img = torch.from_numpy(img).permute(2,0,1)
22
+ input_ = img.unsqueeze(0)
23
 
24
  # Padding in case images are not multiples of 8
25
  h,w = input_.shape[2], input_.shape[3]
 
31
  restored = model_restoration(input_)
32
  output_path = "restored.png"
33
  restored = restored[:,:,:h,:w]
34
+ restored = torch.clamp(restored,0,1).detach().permute(0, 2, 3, 1).squeeze(0).numpy()
35
 
36
  save_img(output_path, img_as_ubyte(restored))
37