sjtu-deepvision commited on
Commit
bc651cc
·
verified ·
1 Parent(s): b2943a0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
@@ -45,6 +46,17 @@ def process_image(input_image):
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  # 处理图像
49
  pipe_out = pipe(
50
  image=input_image,
@@ -58,7 +70,7 @@ def process_image(input_image):
58
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
59
  processed_frame = Image.fromarray(processed_frame)
60
 
61
- return processed_frame
62
 
63
  # 创建 Gradio 界面
64
  def create_gradio_interface():
@@ -70,7 +82,6 @@ def create_gradio_interface():
70
  description = """Official demo for **Dereflection Any Image**.
71
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
72
 
73
-
74
  with gr.Blocks() as demo:
75
  gr.Markdown(title)
76
  gr.Markdown(description)
@@ -79,7 +90,7 @@ def create_gradio_interface():
79
  input_image = gr.Image(label="Input Image", type="numpy")
80
  submit_btn = gr.Button("Remove Reflection", variant="primary")
81
  with gr.Column():
82
- output_image = gr.Image(label="Processed Image")
83
 
84
  # 添加示例
85
  gr.Examples(
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
+ from gradio_imageslider import ImageSlider # 导入 gradio_imageslider
8
 
9
  # 延迟 CUDA 初始化
10
  weight_dtype = torch.float32
 
46
  # 将 Gradio 输入转换为 PIL 图像
47
  input_image = Image.fromarray(input_image)
48
 
49
+ # 调整输入图像的最大边为 768
50
+ width, height = input_image.size
51
+ max_size = 768
52
+ if width > height:
53
+ new_width = max_size
54
+ new_height = int(height * (max_size / width))
55
+ else:
56
+ new_height = max_size
57
+ new_width = int(width * (max_size / height))
58
+ resized_input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
59
+
60
  # 处理图像
61
  pipe_out = pipe(
62
  image=input_image,
 
70
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
71
  processed_frame = Image.fromarray(processed_frame)
72
 
73
+ return resized_input_image, processed_frame
74
 
75
  # 创建 Gradio 界面
76
  def create_gradio_interface():
 
82
  description = """Official demo for **Dereflection Any Image**.
83
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
84
 
 
85
  with gr.Blocks() as demo:
86
  gr.Markdown(title)
87
  gr.Markdown(description)
 
90
  input_image = gr.Image(label="Input Image", type="numpy")
91
  submit_btn = gr.Button("Remove Reflection", variant="primary")
92
  with gr.Column():
93
+ output_image = ImageSlider(label="Processed Image") # 使用 ImageSlider
94
 
95
  # 添加示例
96
  gr.Examples(