blanchon commited on
Commit
9a1d37d
·
1 Parent(s): 0d7ac74

Refactor image post-processing in gradio_demo_rgb2x.py

Browse files
Files changed (1) hide show
  1. rgb2x/gradio_demo_rgb2x.py +13 -16
rgb2x/gradio_demo_rgb2x.py CHANGED
@@ -1,11 +1,11 @@
 
1
  import spaces
2
- import numpy as np
3
  import os
4
- from typing import cast
5
  import gradio as gr
6
  from PIL import Image
7
  import torch
8
  import torchvision
 
9
  from diffusers import DDIMScheduler
10
  from load_image import load_exr_image, load_ldr_image
11
  from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
@@ -36,15 +36,15 @@ def generate(
36
  num_samples: int,
37
  ) -> list[Image.Image]:
38
  generator = torch.Generator(device="cuda").manual_seed(seed)
39
-
40
- if photo.name.endswith(".exr"):
41
- photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
42
  elif (
43
- photo.name.endswith(".png")
44
- or photo.name.endswith(".jpg")
45
- or photo.name.endswith(".jpeg")
46
  ):
47
- photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
48
 
49
  # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
50
  old_height = photo.shape[1]
@@ -96,10 +96,7 @@ def generate(
96
  generated_image = (generated_image, f"Generated {aov_name} {i}")
97
  return_list.append(generated_image)
98
 
99
- def post_process_image(img, **kwargs):
100
- return (img.cpu().numpy().transpose(1, 2, 0), kwargs.get("label", "Image"))
101
-
102
- return_list.append(post_process_image(photo, label="Input Image"))
103
  return return_list
104
 
105
 
@@ -149,9 +146,9 @@ with gr.Blocks() as demo:
149
  examples = gr.Examples(
150
  examples=[
151
  [
152
- "rgb2x/example/Castlereagh_corridor_photo.png", # Photo
153
- 0, # Seed
154
- 50, # Inference Step
155
  1, # Samples
156
  ]
157
  ],
 
1
+ from typing import cast
2
  import spaces
 
3
  import os
 
4
  import gradio as gr
5
  from PIL import Image
6
  import torch
7
  import torchvision
8
+
9
  from diffusers import DDIMScheduler
10
  from load_image import load_exr_image, load_ldr_image
11
  from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
 
36
  num_samples: int,
37
  ) -> list[Image.Image]:
38
  generator = torch.Generator(device="cuda").manual_seed(seed)
39
+ photo_name = photo.name
40
+ if photo_name.endswith(".exr"):
41
+ photo = load_exr_image(photo_name, tonemaping=True, clamp=True).to("cuda")
42
  elif (
43
+ photo_name.endswith(".png")
44
+ or photo_name.endswith(".jpg")
45
+ or photo_name.endswith(".jpeg")
46
  ):
47
+ photo = load_ldr_image(photo_name, from_srgb=True).to("cuda")
48
 
49
  # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
50
  old_height = photo.shape[1]
 
96
  generated_image = (generated_image, f"Generated {aov_name} {i}")
97
  return_list.append(generated_image)
98
 
99
+ return_list.append((photo_name, "Input Image"))
 
 
 
100
  return return_list
101
 
102
 
 
146
  examples = gr.Examples(
147
  examples=[
148
  [
149
+ "rgb2x/example/Castlereagh_corridor_photo.png", # Photo
150
+ 0, # Seed
151
+ 50, # Inference Step
152
  1, # Samples
153
  ]
154
  ],