Eugene Siow commited on
Commit
3997eb3
·
1 Parent(s): e7c6334

Add fix for output to PIL image.

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -1,8 +1,7 @@
 
1
  import torch
2
- from torchvision import transforms
3
  import gradio as gr
4
- from random import randint
5
- from pathlib import Path
6
  from super_image import ImageLoader, EdsrModel, MsrnModel, MdsrModel, AwsrnModel, A2nModel, CarnModel, PanModel, \
7
  HanModel, DrlnModel, RcanModel
8
 
@@ -40,17 +39,14 @@ def get_model(model_name, scale):
40
 
41
 
42
  def inference(img, scale_str, model_name):
43
- _id = randint(1, 1000)
44
- output_dir = Path('./tmp/')
45
- output_dir.mkdir(parents=True, exist_ok=True)
46
- # output_file = output_dir / ('output_image' + str(_id) + '.jpg')
47
  scale = int(scale_str.replace('x', ''))
48
  model = get_model(model_name, scale)
49
  inputs = ImageLoader.load_image(img)
50
  preds = model(inputs)
51
- # output_file_str = str(output_file.resolve())
52
- # ImageLoader.save_image(preds, output_file_str)
53
- return transforms.ToPILImage(mode='RGB')(preds)
 
54
 
55
 
56
  torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/baby_mini_d3_gaussian.bmp',
@@ -79,5 +75,6 @@ gr.Interface(
79
  ['baby.bmp', 'x2', 'EDSR-base'],
80
  ['woman.bmp', 'x3', 'DRLN']
81
  ],
82
- enable_queue=True
 
83
  ).launch(debug=False)
 
1
+ import cv2
2
  import torch
 
3
  import gradio as gr
4
+ from torchvision import transforms
 
5
  from super_image import ImageLoader, EdsrModel, MsrnModel, MdsrModel, AwsrnModel, A2nModel, CarnModel, PanModel, \
6
  HanModel, DrlnModel, RcanModel
7
 
 
39
 
40
 
41
  def inference(img, scale_str, model_name):
 
 
 
 
42
  scale = int(scale_str.replace('x', ''))
43
  model = get_model(model_name, scale)
44
  inputs = ImageLoader.load_image(img)
45
  preds = model(inputs)
46
+ preds = preds.data.cpu().numpy()
47
+ pred = preds[0].transpose((1, 2, 0)) * 255.0
48
+ pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
49
+ return transforms.ToPILImage(mode='RGB')(pred)
50
 
51
 
52
  torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/baby_mini_d3_gaussian.bmp',
 
75
  ['baby.bmp', 'x2', 'EDSR-base'],
76
  ['woman.bmp', 'x3', 'DRLN']
77
  ],
78
+ enable_queue=True,
79
+ allow_flagging=False,
80
  ).launch(debug=False)