smarques committed on
Commit
9c5c431
·
1 Parent(s): 450f5e6
Files changed (1)
  1. app.py +218 -4
app.py CHANGED
@@ -9,8 +9,222 @@ from huggingface_hub import snapshot_download
 os.makedirs("checkpoints", exist_ok=True)
 snapshot_download("alex4727/InstantDrag", local_dir="./checkpoints")

-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+from demo.demo_utils import (
+    process_img,
+    get_points,
+    undo_points_image,
+    clear_all,
+    InstantDragPipeline,
+)
+
+LENGTH = 480  # Length of the square area displaying/editing images
+
+with gr.Blocks() as demo:
+    pipeline = InstantDragPipeline(seed=42, device="cuda", dtype=torch.float16)
+
+    with gr.Row():
+        gr.Markdown(
+            """
+            # InstantDrag: Improving Interactivity in Drag-based Image Editing
+            """
+        )
+    with gr.Tab(label="InstantDrag Demo"):
+        selected_points = gr.State([])  # Store points
+        original_image = gr.State(value=None)  # Store original input image
+        with gr.Row():
+            # Upload & Preprocess Image Column
+            with gr.Column():
+                gr.Markdown(
+                    """<p style="text-align: center; font-size: 20px">Upload & Preprocess Image</p>"""
+                )
+                canvas = gr.ImageEditor(
+                    height=LENGTH,
+                    width=LENGTH,
+                    type="numpy",
+                    image_mode="RGB",
+                    label="Preprocess Image",
+                    show_label=True,
+                    interactive=True,
+                )
+                with gr.Row():
+                    save_results = gr.Checkbox(
+                        value=False,
+                        label="Save Results",
+                        scale=1,
+                    )
+                    undo_button = gr.Button("Undo Clicked Points", scale=3)
+
+            # Click Points Column
+            with gr.Column():
+                gr.Markdown(
+                    """<p style="text-align: center; font-size: 20px">Click Points</p>"""
+                )
+                input_image = gr.Image(
+                    type="numpy",
+                    label="Click Points",
+                    show_label=True,
+                    height=LENGTH,
+                    width=LENGTH,
+                    interactive=False,
+                    show_fullscreen_button=False,
+                )
+                with gr.Row():
+                    run_button = gr.Button("Run")
+
+            # Editing Results Column
+            with gr.Column():
+                gr.Markdown(
+                    """<p style="text-align: center; font-size: 20px">Editing Results</p>"""
+                )
+                edited_image = gr.Image(
+                    type="numpy",
+                    label="Editing Results",
+                    show_label=True,
+                    height=LENGTH,
+                    width=LENGTH,
+                    interactive=False,
+                    show_fullscreen_button=False,
+                )
+                with gr.Row():
+                    clear_all_button = gr.Button("Clear All")
+    with gr.Tab("Configs - make sure to check README for details"):
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    flowgen_choices = sorted(
+                        [model for model in os.listdir("checkpoints/") if "flowgen" in model]
+                    )
+                    flowgen_ckpt = gr.Dropdown(
+                        value=flowgen_choices[0],
+                        label="Select FlowGen to use",
+                        choices=flowgen_choices,
+                        info="config2 for most cases, config3 for more fine-grained dragging",
+                        scale=2,
+                    )
+                    flowdiffusion_choices = sorted(
+                        [model for model in os.listdir("checkpoints/") if "flowdiffusion" in model]
+                    )
+                    flowdiffusion_ckpt = gr.Dropdown(
+                        value=flowdiffusion_choices[0],
+                        label="Select FlowDiffusion to use",
+                        choices=flowdiffusion_choices,
+                        info="single model for all cases",
+                        scale=1,
+                    )
+                    image_guidance = gr.Number(
+                        value=1.5,
+                        label="Image Guidance Scale",
+                        precision=2,
+                        step=0.1,
+                        scale=1,
+                        info="typically between 1.0-2.0.",
+                    )
+                    flow_guidance = gr.Number(
+                        value=1.5,
+                        label="Flow Guidance Scale",
+                        precision=2,
+                        step=0.1,
+                        scale=1,
+                        info="typically between 1.0-5.0",
+                    )
+                    num_steps = gr.Number(
+                        value=20,
+                        label="Inference Steps",
+                        precision=0,
+                        step=1,
+                        scale=1,
+                        info="typically between 20-50, 20 is usually enough",
+                    )
+                    flowgen_output_scale = gr.Number(
+                        value=-1.0,
+                        label="FlowGen Output Scale",
+                        precision=1,
+                        step=0.1,
+                        scale=2,
+                        info="-1.0, by default, forces flowgen's output to [-1, 1], could be adjusted to [0, ∞] for stronger/weaker effects",
+                    )
+    gr.Markdown(
+        """
+        <p style="text-align: center; font-size: 18px;">Examples</p>
+        """
+    )
+    with gr.Row():
+        gr.Examples(
+            examples=[
+                "/home/user/app/InstDrag/demo/samples/airplane.jpg",
+                "/home/user/app/InstDrag/demo/samples/anime.jpg",
+                "/home/user/app/InstDrag/demo/samples/caligraphy.jpg",
+                "/home/user/app/InstDrag/demo/samples/crocodile.jpg",
+                "/home/user/app/InstDrag/demo/samples/elephant.jpg",
+                "/home/user/app/InstDrag/demo/samples/meteor.jpg",
+                "/home/user/app/InstDrag/demo/samples/monalisa.jpg",
+                "/home/user/app/InstDrag/demo/samples/portrait.jpg",
+                "/home/user/app/InstDrag/demo/samples/sketch.jpg",
+                "/home/user/app/InstDrag/demo/samples/surreal.jpg",
+            ],
+            inputs=[canvas],
+            outputs=[original_image, selected_points, input_image],
+            fn=process_img,
+            cache_examples=False,
+            examples_per_page=10,
+        )
+    gr.Markdown(
+        """
+        <p style="text-align: center; font-size: 9">[Important] Our base models are solely trained on real-world talking head (facial) videos, with a focus on achieving fine-grained facial editing. <br>
+        Their application to other types of scenes, without fine-tuning, should be considered more of an experimental byproduct and may not perform well in many cases (we currently support only square images).</p>
+        """
+    )
+    # Event Handlers
+    canvas.change(
+        process_img,
+        [canvas],
+        [original_image, selected_points, input_image],
+    )
+
+    input_image.select(
+        get_points,
+        [input_image, selected_points],
+        [input_image],
+    )
+
+    undo_button.click(
+        undo_points_image,
+        [original_image],
+        [input_image, selected_points],
+    )
+
+    run_button.click(
+        pipeline.run,
+        [
+            original_image,
+            selected_points,
+            flowgen_ckpt,
+            flowdiffusion_ckpt,
+            image_guidance,
+            flow_guidance,
+            flowgen_output_scale,
+            num_steps,
+            save_results,
+        ],
+        [edited_image],
+    )
+
+    clear_all_button.click(
+        clear_all,
+        [],
+        [
+            canvas,
+            input_image,
+            edited_image,
+            selected_points,
+            original_image,
+        ],
+    )
+
+demo.queue().launch(ssr_mode=False)
+
+# def greet(name):
+#     return "Hello " + name + "!!"
+
+# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+# demo.launch()