sjtu-deepvision commited on
Commit
30a2134
·
verified ·
1 Parent(s): c0fa2d4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -14
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
- from gradio_imageslider import ImageSlider # 导入 gradio_imageslider
8
 
9
  # 延迟 CUDA 初始化
10
  weight_dtype = torch.float32
@@ -46,17 +45,6 @@ def process_image(input_image):
46
  # 将 Gradio 输入转换为 PIL 图像
47
  input_image = Image.fromarray(input_image)
48
 
49
- # 调整输入图像的最大边为 768
50
- width, height = input_image.size
51
- max_size = 768
52
- if width > height:
53
- new_width = max_size
54
- new_height = int(height * (max_size / width))
55
- else:
56
- new_height = max_size
57
- new_width = int(width * (max_size / height))
58
- resized_input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
59
-
60
  # 处理图像
61
  pipe_out = pipe(
62
  image=input_image,
@@ -70,7 +58,7 @@ def process_image(input_image):
70
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
71
  processed_frame = Image.fromarray(processed_frame)
72
 
73
- return resized_input_image, processed_frame
74
 
75
  # 创建 Gradio 界面
76
  def create_gradio_interface():
@@ -82,6 +70,7 @@ def create_gradio_interface():
82
  description = """Official demo for **Dereflection Any Image**.
83
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
84
 
 
85
  with gr.Blocks() as demo:
86
  gr.Markdown(title)
87
  gr.Markdown(description)
@@ -90,7 +79,7 @@ def create_gradio_interface():
90
  input_image = gr.Image(label="Input Image", type="numpy")
91
  submit_btn = gr.Button("Remove Reflection", variant="primary")
92
  with gr.Column():
93
- output_image = ImageSlider(label="Processed Image") # 使用 ImageSlider
94
 
95
  # 添加示例
96
  gr.Examples(
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
 
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  # 处理图像
49
  pipe_out = pipe(
50
  image=input_image,
 
58
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
59
  processed_frame = Image.fromarray(processed_frame)
60
 
61
+ return processed_frame
62
 
63
  # 创建 Gradio 界面
64
  def create_gradio_interface():
 
70
  description = """Official demo for **Dereflection Any Image**.
71
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
72
 
73
+
74
  with gr.Blocks() as demo:
75
  gr.Markdown(title)
76
  gr.Markdown(description)
 
79
  input_image = gr.Image(label="Input Image", type="numpy")
80
  submit_btn = gr.Button("Remove Reflection", variant="primary")
81
  with gr.Column():
82
+ output_image = gr.Image(label="Processed Image")
83
 
84
  # 添加示例
85
  gr.Examples(