Ashoka74 committed on
Commit
7417c4f
verified
1 Parent(s): f0943f5

Upload 7 files

Files changed (7)
  1. Dockerfile +47 -0
  2. briarmbg.py +462 -0
  3. db_examples.py +217 -0
  4. gradio_demo.py +1103 -0
  5. gradio_demo_bg.py +1004 -0
  6. requirements.txt +0 -0
  7. xformers-0.0.28.post3.zip +3 -0
Dockerfile ADDED
@@ -0,0 +1,47 @@
+ # Start from NVIDIA CUDA base image
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
+
+ # Set environment variables
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV PYTHONUNBUFFERED=1
+ ENV TORCH_HOME=/app/models
+
+ # Install system dependencies (wget is required by the model-download step below)
+ RUN apt-get update && apt-get install -y \
+     python3.10 \
+     python3-pip \
+     git \
+     wget \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements file
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip3 install --no-cache-dir torch==2.4.1+cu121 torchvision==0.19.1+cu121 --index-url https://download.pytorch.org/whl/cu121
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # Create models and outputs directories
+ RUN mkdir -p /app/models /app/outputs
+
+ # Copy application files
+ COPY . .
+
+ # Download model weights if needed
+ RUN mkdir -p /app/models && \
+     if [ ! -f /app/models/iclight_sd15_fc.safetensors ]; then \
+         wget -P /app/models https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors; \
+     fi && \
+     if [ ! -f /app/models/iclight_sd15_fbc.safetensors ]; then \
+         wget -P /app/models https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors; \
+     fi
+
+ # Expose port for Gradio
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["python3", "gradio_demo.py"]
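Note: a typical way to build and run this image locally is `docker build -t ic-light-demo .` followed by `docker run --gpus all -p 7860:7860 ic-light-demo` (the `ic-light-demo` tag is illustrative; `--gpus all` assumes the NVIDIA Container Toolkit is installed on the host, and the build context must contain requirements.txt and the demo scripts from this commit).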
briarmbg.py ADDED
@@ -0,0 +1,462 @@
+ # RMBG1.4 (diffusers implementation)
+ # Found on huggingface space of several projects
+ # Not sure which project is the source of this file
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+         super(RSU7, self).__init__()
+
+         self.in_ch = in_ch
+         self.mid_ch = mid_ch
+         self.out_ch = out_ch
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         return hx1d + hxin
+
+
+ class myrebnconv(nn.Module):
+     def __init__(
+         self,
+         in_ch=3,
+         out_ch=1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         dilation=1,
+         groups=1,
+     ):
+         super(myrebnconv, self).__init__()
+
+         self.conv = nn.Conv2d(
+             in_ch,
+             out_ch,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.rl = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.rl(self.bn(self.conv(x)))
+
+
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
+         super(BriaRMBG, self).__init__()
+         in_ch = config["in_ch"]
+         out_ch = config["out_ch"]
+         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
+         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage1 = RSU7(64, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1, x)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, x)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, x)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, x)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+
+         return [
+             F.sigmoid(d1),
+             F.sigmoid(d2),
+             F.sigmoid(d3),
+             F.sigmoid(d4),
+             F.sigmoid(d5),
+             F.sigmoid(d6),
+         ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
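Note: a minimal sketch of running this model on its own, mirroring the normalization and upsampling used by run_rmbg in gradio_demo.py below (the input and output file names are illustrative; weights come from the briaai/RMBG-1.4 Hub repo):

import numpy as np
import torch
from PIL import Image
from briarmbg import BriaRMBG

rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4").eval()

img = np.array(Image.open("example.png").convert("RGB"))  # illustrative input image
h, w = img.shape[:2]
feed = torch.from_numpy(img[None]).float() / 127.0 - 1.0  # same scaling as numpy2pytorch in gradio_demo.py
feed = feed.movedim(-1, 1)                                 # NHWC -> NCHW

with torch.inference_mode():
    alpha = rmbg(feed)[0][0]                               # finest-scale side output, values in (0, 1)
    alpha = torch.nn.functional.interpolate(alpha, size=(h, w), mode="bilinear")
    mask = (alpha[0, 0].clamp(0, 1).cpu().numpy() * 255).astype(np.uint8)

Image.fromarray(mask).save("alpha.png")                    # illustrative output path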
db_examples.py ADDED
@@ -0,0 +1,217 @@
+ foreground_conditioned_examples = [
+     [
+         "imgs/i1.webp",
+         "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
+         "Right Light",
+         512,
+         960,
+         12345,
+         "imgs/o1.png",
+     ],
+     [
+         "imgs/i1.webp",
+         "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
+         "Left Light",
+         512,
+         960,
+         50,
+         "imgs/o2.png",
+     ],
+     [
+         "imgs/i3.png",
+         "beautiful woman, detailed face, neon, Wong Kar-wai, warm",
+         "Left Light",
+         512,
+         768,
+         12345,
+         "imgs/o3.png",
+     ],
+     [
+         "imgs/i3.png",
+         "beautiful woman, detailed face, sunshine from window",
+         "Left Light",
+         512,
+         768,
+         12345,
+         "imgs/o4.png",
+     ],
+     [
+         "imgs/i5.png",
+         "beautiful woman, detailed face, warm atmosphere, at home, bedroom",
+         "Left Light",
+         512,
+         768,
+         123,
+         "imgs/o5.png",
+     ],
+     [
+         "imgs/i6.jpg",
+         "beautiful woman, detailed face, sunshine from window",
+         "Right Light",
+         512,
+         768,
+         42,
+         "imgs/o6.png",
+     ],
+     [
+         "imgs/i7.jpg",
+         "beautiful woman, detailed face, shadow from window",
+         "Left Light",
+         512,
+         768,
+         8888,
+         "imgs/o7.png",
+     ],
+     [
+         "imgs/i8.webp",
+         "beautiful woman, detailed face, sunset over sea",
+         "Right Light",
+         512,
+         640,
+         42,
+         "imgs/o8.png",
+     ],
+     [
+         "imgs/i9.png",
+         "handsome boy, detailed face, neon light, city",
+         "Left Light",
+         512,
+         640,
+         12345,
+         "imgs/o9.png",
+     ],
+     [
+         "imgs/i10.png",
+         "beautiful woman, detailed face, light and shadow",
+         "Left Light",
+         512,
+         960,
+         8888,
+         "imgs/o10.png",
+     ],
+     [
+         "imgs/i11.png",
+         "Buddha, detailed face, sci-fi RGB glowing, cyberpunk",
+         "Left Light",
+         512,
+         768,
+         8888,
+         "imgs/o11.png",
+     ],
+     [
+         "imgs/i11.png",
+         "Buddha, detailed face, natural lighting",
+         "Left Light",
+         512,
+         768,
+         12345,
+         "imgs/o12.png",
+     ],
+     [
+         "imgs/i13.png",
+         "toy, detailed face, shadow from window",
+         "Bottom Light",
+         512,
+         704,
+         12345,
+         "imgs/o13.png",
+     ],
+     [
+         "imgs/i14.png",
+         "toy, detailed face, sunset over sea",
+         "Right Light",
+         512,
+         704,
+         100,
+         "imgs/o14.png",
+     ],
+     [
+         "imgs/i15.png",
+         "dog, magic lit, sci-fi RGB glowing, studio lighting",
+         "Bottom Light",
+         512,
+         768,
+         12345,
+         "imgs/o15.png",
+     ],
+     [
+         "imgs/i16.png",
+         "mysteriou human, warm atmosphere, warm atmosphere, at home, bedroom",
+         "Right Light",
+         512,
+         768,
+         100,
+         "imgs/o16.png",
+     ],
+ ]
+
+ bg_samples = [
+     'imgs/bgs/1.webp',
+     'imgs/bgs/2.webp',
+     'imgs/bgs/3.webp',
+     'imgs/bgs/4.webp',
+     'imgs/bgs/5.webp',
+     'imgs/bgs/6.webp',
+     'imgs/bgs/7.webp',
+     'imgs/bgs/8.webp',
+     'imgs/bgs/9.webp',
+     'imgs/bgs/10.webp',
+     'imgs/bgs/11.png',
+     'imgs/bgs/12.png',
+     'imgs/bgs/13.png',
+     'imgs/bgs/14.png',
+     'imgs/bgs/15.png',
+ ]
+
+ background_conditioned_examples = [
+     [
+         "imgs/alter/i3.png",
+         "imgs/bgs/7.webp",
+         "beautiful woman, cinematic lighting",
+         "Use Background Image",
+         512,
+         768,
+         12345,
+         "imgs/alter/o1.png",
+     ],
+     [
+         "imgs/alter/i2.png",
+         "imgs/bgs/11.png",
+         "statue of an angel, natural lighting",
+         "Use Flipped Background Image",
+         512,
+         768,
+         12345,
+         "imgs/alter/o2.png",
+     ],
+     [
+         "imgs/alter/i1.jpeg",
+         "imgs/bgs/2.webp",
+         "beautiful woman, cinematic lighting",
+         "Use Background Image",
+         512,
+         768,
+         12345,
+         "imgs/alter/o3.png",
+     ],
+     [
+         "imgs/alter/i1.jpeg",
+         "imgs/bgs/3.webp",
+         "beautiful woman, cinematic lighting",
+         "Use Background Image",
+         512,
+         768,
+         12345,
+         "imgs/alter/o4.png",
+     ],
+     [
+         "imgs/alter/i6.webp",
+         "imgs/bgs/15.png",
+         "handsome man, cinematic lighting",
+         "Use Background Image",
+         512,
+         768,
+         12345,
+         "imgs/alter/o5.png",
+     ],
+ ]
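Note: a minimal sketch of how these lists can be surfaced in the UI, mirroring the commented-out background gallery block in gradio_demo.py below (assumes the repository's imgs/ folder is present next to the script):

import gradio as gr
import db_examples

with gr.Blocks() as demo:
    # background quick-pick gallery, as in the commented-out bg_gallery block in gradio_demo.py
    bg_gallery = gr.Gallery(
        value=db_examples.bg_samples,
        label="Background Quick List",
        columns=5,
        allow_preview=False,
        height=450,
    )

demo.launch()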
gradio_demo.py ADDED
@@ -0,0 +1,1103 @@
+ import os
+ import math
+ import gradio as gr
+ import numpy as np
+ import torch
+ import safetensors.torch as sf
+ import db_examples
+ import datetime
+ from pathlib import Path
+ from io import BytesIO
+
+ from PIL import Image
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from briarmbg import BriaRMBG
+ from enum import Enum
+ from torch.hub import download_url_to_file
+
+ import cv2
+
+ from typing import Optional
+
+ from Depth.depth_anything_v2.dpt import DepthAnythingV2
+
+
+ # from FLORENCE
+ import spaces
+ import supervision as sv
+ from PIL import Image
+
+ from utils.sam import load_sam_image_model, run_sam_inference
+
+
+ try:
+     import xformers
+     import xformers.ops
+     XFORMERS_AVAILABLE = True
+     print("xformers is available - Using memory efficient attention")
+ except ImportError:
+     XFORMERS_AVAILABLE = False
+     print("xformers not available - Using default attention")
+
+ # Memory optimizations for RTX 2070
+ torch.backends.cudnn.benchmark = True
+ if torch.cuda.is_available():
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+     # Set a smaller attention slice size for RTX 2070
+     torch.backends.cuda.max_split_size_mb = 512
+     device = torch.device('cuda')
+ else:
+     device = torch.device('cpu')
+
+ # 'stablediffusionapi/realistic-vision-v51'
+ # 'runwayml/stable-diffusion-v1-5'
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
+
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
+ model = model.to(device)
+ model.eval()
+
+ # Change UNet
+
+ with torch.no_grad():
+     new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
+     new_conv_in.weight.zero_()
+     new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+     new_conv_in.bias = unet.conv_in.bias
+     unet.conv_in = new_conv_in
+
+
+ unet_original_forward = unet.forward
+
+
+ def enable_efficient_attention():
+     if XFORMERS_AVAILABLE:
+         try:
+             # RTX 2070 specific settings
+             unet.set_use_memory_efficient_attention_xformers(True)
+             vae.set_use_memory_efficient_attention_xformers(True)
+             print("Enabled xformers memory efficient attention")
+         except Exception as e:
+             print(f"Xformers error: {e}")
+             print("Falling back to sliced attention")
+             # Use sliced attention for RTX 2070
+             unet.set_attention_slice_size(4)
+             vae.set_attention_slice_size(4)
+             unet.set_attn_processor(AttnProcessor2_0())
+             vae.set_attn_processor(AttnProcessor2_0())
+     else:
+         # Fallback for when xformers is not available
+         print("Using sliced attention")
+         unet.set_attention_slice_size(4)
+         vae.set_attention_slice_size(4)
+         unet.set_attn_processor(AttnProcessor2_0())
+         vae.set_attn_processor(AttnProcessor2_0())
+
+ # Add memory clearing function
+ def clear_memory():
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+ # Enable efficient attention
+ enable_efficient_attention()
+
+
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
+     c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
+     c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
+     new_sample = torch.cat([sample, c_concat], dim=1)
+     kwargs['cross_attention_kwargs'] = {}
+     return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
+
+
+ unet.forward = hooked_unet_forward
+
+ # Load
+
+ model_path = './models/iclight_sd15_fc.safetensors'
+ # model_path = './models/iclight_sd15_fbc.safetensors'
+
+
+ # if not os.path.exists(model_path):
+ #     download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
+
+ sd_offset = sf.load_file(model_path)
+ sd_origin = unet.state_dict()
+ keys = sd_origin.keys()
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
+ unet.load_state_dict(sd_merged, strict=True)
+ del sd_offset, sd_origin, sd_merged, keys
+
+ # Device
+
+ # device = torch.device('cuda')
+ # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
+ # vae = vae.to(device=device, dtype=torch.bfloat16)
+ # unet = unet.to(device=device, dtype=torch.float16)
+ # rmbg = rmbg.to(device=device, dtype=torch.float32)
+
+
+ # Device and dtype setup
+ device = torch.device('cuda')
+ dtype = torch.float16  # RTX 2070 works well with float16
+
+ # Memory optimizations for RTX 2070
+ torch.backends.cudnn.benchmark = True
+ if torch.cuda.is_available():
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+     # Set a very small attention slice size for RTX 2070 to avoid OOM
+     torch.backends.cuda.max_split_size_mb = 128
+
+ # Move models to device with consistent dtype
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
+ vae = vae.to(device=device, dtype=dtype)  # Changed from bfloat16 to float16
+ unet = unet.to(device=device, dtype=dtype)
+ rmbg = rmbg.to(device=device, dtype=torch.float32)  # Keep this as float32
+
+
+ ddim_scheduler = DDIMScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     beta_schedule="scaled_linear",
+     clip_sample=False,
+     set_alpha_to_one=False,
+     steps_offset=1,
+ )
+
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     steps_offset=1
+ )
+
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     algorithm_type="sde-dpmsolver++",
+     use_karras_sigmas=True,
+     steps_offset=1
+ )
+
+ # Pipelines
+
+ t2i_pipe = StableDiffusionPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=dpmpp_2m_sde_karras_scheduler,
+     safety_checker=None,
+     requires_safety_checker=False,
+     feature_extractor=None,
+     image_encoder=None
+ )
+
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=dpmpp_2m_sde_karras_scheduler,
+     safety_checker=None,
+     requires_safety_checker=False,
+     feature_extractor=None,
+     image_encoder=None
+ )
+
+
+ @torch.inference_mode()
+ def encode_prompt_inner(txt: str):
+     max_length = tokenizer.model_max_length
+     chunk_length = tokenizer.model_max_length - 2
+     id_start = tokenizer.bos_token_id
+     id_end = tokenizer.eos_token_id
+     id_pad = id_end
+
+     def pad(x, p, i):
+         return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+     tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
+     chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
+     chunks = [pad(ck, id_pad, max_length) for ck in chunks]
+
+     token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
+     conds = text_encoder(token_ids).last_hidden_state
+
+     return conds
+
+
+ @torch.inference_mode()
+ def encode_prompt_pair(positive_prompt, negative_prompt):
+     c = encode_prompt_inner(positive_prompt)
+     uc = encode_prompt_inner(negative_prompt)
+
+     c_len = float(len(c))
+     uc_len = float(len(uc))
+     max_count = max(c_len, uc_len)
+     c_repeat = int(math.ceil(max_count / c_len))
+     uc_repeat = int(math.ceil(max_count / uc_len))
+     max_chunk = max(len(c), len(uc))
+
+     c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
+     uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
+
+     c = torch.cat([p[None, ...] for p in c], dim=1)
+     uc = torch.cat([p[None, ...] for p in uc], dim=1)
+
+     return c, uc
+
+
+ @torch.inference_mode()
+ def pytorch2numpy(imgs, quant=True):
+     results = []
+     for x in imgs:
+         y = x.movedim(0, -1)
+
+         if quant:
+             y = y * 127.5 + 127.5
+             y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
+         else:
+             y = y * 0.5 + 0.5
+             y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
+
+         results.append(y)
+     return results
+
+
+ @torch.inference_mode()
+ def numpy2pytorch(imgs):
+     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 must be strictly 0.0
+     h = h.movedim(-1, 1)
+     return h
+
+
+ def resize_and_center_crop(image, target_width, target_height):
+     pil_image = Image.fromarray(image)
+     original_width, original_height = pil_image.size
+     scale_factor = max(target_width / original_width, target_height / original_height)
+     resized_width = int(round(original_width * scale_factor))
+     resized_height = int(round(original_height * scale_factor))
+     resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
+     left = (resized_width - target_width) / 2
+     top = (resized_height - target_height) / 2
+     right = (resized_width + target_width) / 2
+     bottom = (resized_height + target_height) / 2
+     cropped_image = resized_image.crop((left, top, right, bottom))
+     return np.array(cropped_image)
+
+
+ def resize_without_crop(image, target_width, target_height):
+     pil_image = Image.fromarray(image)
+     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
+     return np.array(resized_image)
+
+
+ @torch.inference_mode()
+ def run_rmbg(img, sigma=0.0):
+     # Convert RGBA to RGB if needed
+     if img.shape[-1] == 4:
+         # Use white background for alpha composition
+         alpha = img[..., 3:] / 255.0
+         rgb = img[..., :3]
+         white_bg = np.ones_like(rgb) * 255
+         img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
+
+     H, W, C = img.shape
+     assert C == 3
+     k = (256.0 / float(H * W)) ** 0.5
+     feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
+     feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
+     alpha = rmbg(feed)[0][0]
+     alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
+     alpha = alpha.movedim(1, -1)[0]
+     alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
+
+     # Create RGBA image
+     rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
+     result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
+     return result.clip(0, 255).astype(np.uint8), rgba
+
+
+ @torch.inference_mode()
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
+     clear_memory()
+
+     # Get input dimensions
+     input_height, input_width = input_fg.shape[:2]
+
+     bg_source = BGSource(bg_source)
+     input_bg = None  # default: no background image; the text-to-image branch below checks for None
+
+     if bg_source == BGSource.UPLOAD:
+         pass
+     elif bg_source == BGSource.UPLOAD_FLIP:
+         input_bg = np.fliplr(input_bg)
+     elif bg_source == BGSource.GREY:
+         input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
+     elif bg_source == BGSource.LEFT:
+         gradient = np.linspace(255, 0, input_width)
+         image = np.tile(gradient, (input_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.RIGHT:
+         gradient = np.linspace(0, 255, input_width)
+         image = np.tile(gradient, (input_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.TOP:
+         gradient = np.linspace(255, 0, input_height)[:, None]
+         image = np.tile(gradient, (1, input_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.BOTTOM:
+         gradient = np.linspace(0, 255, input_height)[:, None]
+         image = np.tile(gradient, (1, input_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     else:
+         raise ValueError('Wrong initial latent!')
+
+     rng = torch.Generator(device=device).manual_seed(int(seed))
+
+     # Use input dimensions directly
+     fg = resize_without_crop(input_fg, input_width, input_height)
+
+     concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+
+     conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
+
+     if input_bg is None:
+         latents = t2i_pipe(
+             prompt_embeds=conds,
+             negative_prompt_embeds=unconds,
+             width=input_width,
+             height=input_height,
+             num_inference_steps=steps,
+             num_images_per_prompt=num_samples,
+             generator=rng,
+             output_type='latent',
+             guidance_scale=cfg,
+             cross_attention_kwargs={'concat_conds': concat_conds},
+         ).images.to(vae.dtype) / vae.config.scaling_factor
+     else:
+         bg = resize_without_crop(input_bg, input_width, input_height)
+         bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
+         bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
+         latents = i2i_pipe(
+             image=bg_latent,
+             strength=lowres_denoise,
+             prompt_embeds=conds,
+             negative_prompt_embeds=unconds,
+             width=input_width,
+             height=input_height,
+             num_inference_steps=int(round(steps / lowres_denoise)),
+             num_images_per_prompt=num_samples,
+             generator=rng,
+             output_type='latent',
+             guidance_scale=cfg,
+             cross_attention_kwargs={'concat_conds': concat_conds},
+         ).images.to(vae.dtype) / vae.config.scaling_factor
+
+     pixels = vae.decode(latents).sample
+     pixels = pytorch2numpy(pixels)
+     pixels = [resize_without_crop(
+         image=p,
+         target_width=int(round(input_width * highres_scale / 64.0) * 64),
+         target_height=int(round(input_height * highres_scale / 64.0) * 64))
+         for p in pixels]
+
+     pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
+     latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
+     latents = latents.to(device=unet.device, dtype=unet.dtype)
+
+     highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
+
+     fg = resize_without_crop(input_fg, highres_width, highres_height)
+     concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+
+     latents = i2i_pipe(
+         image=latents,
+         strength=highres_denoise,
+         prompt_embeds=conds,
+         negative_prompt_embeds=unconds,
+         width=highres_width,
+         height=highres_height,
+         num_inference_steps=int(round(steps / highres_denoise)),
+         num_images_per_prompt=num_samples,
+         generator=rng,
+         output_type='latent',
+         guidance_scale=cfg,
+         cross_attention_kwargs={'concat_conds': concat_conds},
+     ).images.to(vae.dtype) / vae.config.scaling_factor
+
+     pixels = vae.decode(latents).sample
+     pixels = pytorch2numpy(pixels)
+
+     # Resize back to input dimensions
+     pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
+     pixels = np.stack(pixels)
+
+     return pixels
+
+
+ @torch.inference_mode()
+ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
+     clear_memory()
+     bg_source = BGSource(bg_source)
+
+     if bg_source == BGSource.UPLOAD:
+         pass
+     elif bg_source == BGSource.UPLOAD_FLIP:
+         input_bg = np.fliplr(input_bg)
+     elif bg_source == BGSource.GREY:
+         input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
+     elif bg_source == BGSource.LEFT:
+         gradient = np.linspace(224, 32, image_width)
+         image = np.tile(gradient, (image_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.RIGHT:
+         gradient = np.linspace(32, 224, image_width)
+         image = np.tile(gradient, (image_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.TOP:
+         gradient = np.linspace(224, 32, image_height)[:, None]
+         image = np.tile(gradient, (1, image_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.BOTTOM:
+         gradient = np.linspace(32, 224, image_height)[:, None]
+         image = np.tile(gradient, (1, image_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     else:
+         raise ValueError('Wrong background source!')
+
+     rng = torch.Generator(device=device).manual_seed(seed)
+
+     fg = resize_and_center_crop(input_fg, image_width, image_height)
+     bg = resize_and_center_crop(input_bg, image_width, image_height)
+     concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
+     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+     concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
+
+     conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
+
+     latents = t2i_pipe(
+         prompt_embeds=conds,
+         negative_prompt_embeds=unconds,
+         width=image_width,
+         height=image_height,
+         num_inference_steps=steps,
+         num_images_per_prompt=num_samples,
+         generator=rng,
+         output_type='latent',
+         guidance_scale=cfg,
+         cross_attention_kwargs={'concat_conds': concat_conds},
+     ).images.to(vae.dtype) / vae.config.scaling_factor
+
+     pixels = vae.decode(latents).sample
+     pixels = pytorch2numpy(pixels)
+     pixels = [resize_without_crop(
+         image=p,
+         target_width=int(round(image_width * highres_scale / 64.0) * 64),
+         target_height=int(round(image_height * highres_scale / 64.0) * 64))
+         for p in pixels]
+
+     pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
+     latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
+     latents = latents.to(device=unet.device, dtype=unet.dtype)
+
+     image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
+     fg = resize_and_center_crop(input_fg, image_width, image_height)
+     bg = resize_and_center_crop(input_bg, image_width, image_height)
+     concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
+     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+     concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
+
+     latents = i2i_pipe(
+         image=latents,
+         strength=highres_denoise,
+         prompt_embeds=conds,
+         negative_prompt_embeds=unconds,
+         width=image_width,
+         height=image_height,
+         num_inference_steps=int(round(steps / highres_denoise)),
+         num_images_per_prompt=num_samples,
+         generator=rng,
+         output_type='latent',
+         guidance_scale=cfg,
+         cross_attention_kwargs={'concat_conds': concat_conds},
+     ).images.to(vae.dtype) / vae.config.scaling_factor
+
+     pixels = vae.decode(latents).sample
+     pixels = pytorch2numpy(pixels, quant=False)
+
+     clear_memory()
+     return pixels, [fg, bg]
+
+
+ @torch.inference_mode()
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
+     input_fg, matting = run_rmbg(input_fg)
+     results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
+     return input_fg, results
+
+
+ @torch.inference_mode()
+ def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
+     bg_source = BGSource(bg_source)
+
+     # Convert numerical inputs to appropriate types
+     image_width = int(image_width)
+     image_height = int(image_height)
+     num_samples = int(num_samples)
+     seed = int(seed)
+     steps = int(steps)
+     cfg = float(cfg)
+     highres_scale = float(highres_scale)
+     highres_denoise = float(highres_denoise)
+
+     if bg_source == BGSource.UPLOAD:
+         pass
+     elif bg_source == BGSource.UPLOAD_FLIP:
+         input_bg = np.fliplr(input_bg)
+     elif bg_source == BGSource.GREY:
+         input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
+     elif bg_source == BGSource.LEFT:
+         gradient = np.linspace(224, 32, image_width)
+         image = np.tile(gradient, (image_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.RIGHT:
+         gradient = np.linspace(32, 224, image_width)
+         image = np.tile(gradient, (image_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.TOP:
+         gradient = np.linspace(224, 32, image_height)[:, None]
+         image = np.tile(gradient, (1, image_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.BOTTOM:
+         gradient = np.linspace(32, 224, image_height)[:, None]
+         image = np.tile(gradient, (1, image_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     else:
+         raise ValueError('Wrong background source!')
+
+     input_fg, matting = run_rmbg(input_fg)
+     results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
+     results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
+     final_results = results + extra_images
+
+     # Save the generated images
+     save_images(results, prefix="relight")
+
+     return results
+
+
+ quick_prompts = [
+     'sunshine from window',
+     'neon light, city',
+     'sunset over sea',
+     'golden time',
+     'sci-fi RGB glowing, cyberpunk',
+     'natural lighting',
+     'warm atmosphere, at home, bedroom',
+     'magic lit',
+     'evil, gothic, Yharnam',
+     'light and shadow',
+     'shadow from window',
+     'soft studio lighting',
+     'home atmosphere, cozy bedroom illumination',
+     'neon, Wong Kar-wai, warm'
+ ]
+ quick_prompts = [[x] for x in quick_prompts]
+
+
+ quick_subjects = [
+     'modern sofa, high quality leather',
+     'elegant dining table, polished wood',
+     'luxurious bed, premium mattress',
+     'minimalist office desk, clean design',
+     'vintage wooden cabinet, antique finish',
+ ]
+ quick_subjects = [[x] for x in quick_subjects]
+
+
+ class BGSource(Enum):
+     UPLOAD = "Use Background Image"
+     UPLOAD_FLIP = "Use Flipped Background Image"
+     LEFT = "Left Light"
+     RIGHT = "Right Light"
+     TOP = "Top Light"
+     BOTTOM = "Bottom Light"
+     GREY = "Ambient"
+
+
+ # Add save function
+ def save_images(images, prefix="relight"):
+     # Create output directory if it doesn't exist
+     output_dir = Path("outputs")
+     output_dir.mkdir(exist_ok=True)
+
+     # Create timestamp for unique filenames
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     saved_paths = []
+     for i, img in enumerate(images):
+         if isinstance(img, np.ndarray):
+             # Convert to PIL Image if numpy array
+             img = Image.fromarray(img)
+
+         # Create filename with timestamp
+         filename = f"{prefix}_{timestamp}_{i+1}.png"
+         filepath = output_dir / filename
+
+         # Save image and record its path
+         img.save(filepath)
+         saved_paths.append(str(filepath))
+
+     # print(f"Saved {len(saved_paths)} images to {output_dir}")
+     return saved_paths
+
+
+ class MaskMover:
+     def __init__(self):
+         self.extracted_fg = None
+         self.original_fg = None  # Store original foreground
+
+     def set_extracted_fg(self, fg_image):
+         """Store the extracted foreground with alpha channel"""
+         if isinstance(fg_image, np.ndarray):
+             self.extracted_fg = fg_image.copy()
+             self.original_fg = fg_image.copy()
+         else:
+             self.extracted_fg = np.array(fg_image)
+             self.original_fg = np.array(fg_image)
+         return self.extracted_fg
+
+     def create_composite(self, background, x_pos, y_pos, scale=1.0):
+         """Create composite with foreground at specified position"""
+         if self.original_fg is None or background is None:
+             return background
+
+         # Convert inputs to PIL Images
+         if isinstance(background, np.ndarray):
+             bg = Image.fromarray(background).convert('RGBA')
+         else:
+             bg = background.convert('RGBA')
+
+         if isinstance(self.original_fg, np.ndarray):
+             fg = Image.fromarray(self.original_fg).convert('RGBA')
+         else:
+             fg = self.original_fg.convert('RGBA')
+
+         # Scale the foreground size
+         new_width = int(fg.width * scale)
+         new_height = int(fg.height * scale)
+         fg = fg.resize((new_width, new_height), Image.LANCZOS)
+
+         # Center the scaled foreground at the position
+         x = int(x_pos - new_width / 2)
+         y = int(y_pos - new_height / 2)
+
+         # Create composite
+         result = bg.copy()
+         result.paste(fg, (x, y), fg)  # Use fg as the mask (requires fg to be in 'RGBA' mode)
+
+         return np.array(result.convert('RGB'))  # Convert back to 'RGB' if needed
+
+
+ def get_depth(image):
+     if image is None:
+         return None
+     # Convert from PIL/gradio format to cv2
+     raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+     # Get depth map
+     depth = model.infer_image(raw_img)  # HxW raw depth map
+     # Normalize depth for visualization
+     depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
+     # Convert to RGB for display
+     depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
+     depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
+     return Image.fromarray(depth_colored)
+
+
+ def compress_image(image):
+     # Convert Gradio image (numpy array) to PIL Image
+     img = Image.fromarray(image)
+
+     # Resize image if dimensions are too large
+     max_size = 1024  # Maximum dimension size
+     if img.width > max_size or img.height > max_size:
+         ratio = min(max_size/img.width, max_size/img.height)
+         new_size = (int(img.width * ratio), int(img.height * ratio))
+         img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+     quality = 95  # Start with high quality
+     img.save("compressed_image.jpg", "JPEG", quality=quality)  # Initial save
+
+     # Check file size and adjust quality if necessary
+     while os.path.getsize("compressed_image.jpg") > 100 * 1024:  # 100KB limit
+         quality -= 5  # Decrease quality
+         img.save("compressed_image.jpg", "JPEG", quality=quality)
+         if quality < 20:  # Prevent quality from going too low
+             break
+
+     # Convert back to numpy array for Gradio
+     compressed_img = np.array(Image.open("compressed_image.jpg"))
+     return compressed_img
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Tab("Text"):
+         with gr.Row():
+             gr.Markdown("## Product Placement from Text")
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     input_fg = gr.Image(type="numpy", label="Image", height=480)
+                     output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
+                 with gr.Group():
+                     prompt = gr.Textbox(label="Prompt")
+                     bg_source = gr.Radio(choices=[e.value for e in BGSource],
+                                          value=BGSource.GREY.value,
+                                          label="Lighting Preference (Initial Latent)", type='value')
+                     example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
+                     example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
+                     relight_button = gr.Button(value="Relight")
+
+                 with gr.Group():
+                     with gr.Row():
+                         num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                         seed = gr.Number(label="Seed", value=12345, precision=0)
+
+                     with gr.Row():
+                         image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
+                         image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
+
+                 with gr.Accordion("Advanced options", open=False):
+                     steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
+                     cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
+                     lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
+                     highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
+                     highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
+                     a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
+                     n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
+             with gr.Column():
+                 result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
+                 with gr.Row():
+                     dummy_image_for_outputs = gr.Image(visible=False, label='Result')
+         # gr.Examples(
+         #     fn=lambda *args: ([args[-1]], None),
+         #     examples=db_examples.foreground_conditioned_examples,
+         #     inputs=[
+         #         input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
+         #     ],
+         #     outputs=[result_gallery, output_bg],
+         #     run_on_click=True, examples_per_page=1024
+         # )
+         ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
+         relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
+         example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
+         example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
+
+     with gr.Tab("Background", visible=False):
+         mask_mover = MaskMover()
+
+         with gr.Row():
+             gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
+             gr.Markdown("πŸ’Ύ Generated images are automatically saved to 'outputs' folder")
+
+         with gr.Row():
+             with gr.Column():
+                 # Step 1: Input and Extract
+                 with gr.Row():
+                     with gr.Group():
+                         gr.Markdown("### Step 1: Extract Foreground")
+                         input_image = gr.Image(type="numpy", label="Input Image", height=480)
+                         # find_objects_button = gr.Button(value="Find Objects")
+                         extract_button = gr.Button(value="Remove Background")
+                         extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
+
+                 with gr.Row():
+                     # Step 2: Background and Position
+                     with gr.Group():
+                         gr.Markdown("### Step 2: Position on Background")
+                         input_bg = gr.Image(type="numpy", label="Background Image", height=480)
+
+                         with gr.Row():
+                             x_slider = gr.Slider(
+                                 minimum=0,
+                                 maximum=1000,
+                                 label="X Position",
+                                 value=500,
+                                 visible=False
+                             )
+                             y_slider = gr.Slider(
+                                 minimum=0,
+                                 maximum=1000,
+                                 label="Y Position",
+                                 value=500,
+                                 visible=False
+                             )
+                             fg_scale_slider = gr.Slider(
+                                 label="Foreground Scale",
+                                 minimum=0.01,
+                                 maximum=3.0,
+                                 value=1.0,
+                                 step=0.01
+                             )
+
+                         editor = gr.ImageEditor(
+                             type="numpy",
+                             label="Position Foreground",
+                             height=480,
+                             visible=False
+                         )
+                         get_depth_button = gr.Button(value="Get Depth")
+                         depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
+
+                 # Step 3: Relighting Options
+                 with gr.Group():
+                     gr.Markdown("### Step 3: Relighting Settings")
+                     prompt = gr.Textbox(label="Prompt")
+                     bg_source = gr.Radio(
+                         choices=[e.value for e in BGSource],
+                         value=BGSource.UPLOAD.value,
+                         label="Background Source",
+                         type='value'
+                     )
+
+                     example_prompts = gr.Dataset(
+                         samples=quick_prompts,
+                         label='Prompt Quick List',
+                         components=[prompt]
+                     )
+                     # bg_gallery = gr.Gallery(
+                     #     height=450,
+                     #     label='Background Quick List',
+                     #     value=db_examples.bg_samples,
+                     #     columns=5,
+                     #     allow_preview=False
+                     # )
+                     relight_button_bg = gr.Button(value="Relight")
+
+                 # Additional settings
+                 with gr.Group():
+                     with gr.Row():
+                         num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                         seed = gr.Number(label="Seed", value=12345, precision=0)
+                     with gr.Row():
+                         image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
+                         image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
+
+                     with gr.Accordion("Advanced options", open=False):
+                         steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                         cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
+                         highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
+                         highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
+                         a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
+                         n_prompt = gr.Textbox(
+                             label="Negative Prompt",
+                             value='lowres, bad anatomy, bad hands, cropped, worst quality'
+                         )
+
+             with gr.Column():
+                 result_gallery = gr.Image(height=832, label='Outputs')
+
+         def extract_foreground(image):
+             if image is None:
+                 return None, gr.update(visible=True), gr.update(visible=True)
+             result, rgba = run_rmbg(image)
+             mask_mover.set_extracted_fg(rgba)
+
+             return result, gr.update(visible=True), gr.update(visible=True)
+
+
+         original_bg = None
+
+         extract_button.click(
+             fn=extract_foreground,
+             inputs=[input_image],
+             outputs=[extracted_fg, x_slider, y_slider]
+         )
+
+         # find_objects_button.click(
+         #     fn=find_objects,
+         #     inputs=[input_image],
+         #     outputs=[extracted_fg]
+         # )
+
+         get_depth_button.click(
+             fn=get_depth,
+             inputs=[input_bg],
+             outputs=[depth_image]
+         )
+
+         # def update_position(background, x_pos, y_pos, scale):
+         #     """Update composite when position changes"""
+         #     global original_bg
+         #     if background is None:
+         #         return None
+
+         #     if original_bg is None:
+         #         original_bg = background.copy()
+
+         #     # Convert string values to float
+         #     x_pos = float(x_pos)
+         #     y_pos = float(y_pos)
+         #     scale = float(scale)
+
+         #     return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
965
+
966
+ class BackgroundManager:
967
+ def __init__(self):
968
+ self.original_bg = None
969
+
970
+ def update_position(self, background, x_pos, y_pos, scale):
971
+ """Update composite when position changes"""
972
+ if background is None:
973
+ return None
974
+
975
+ if self.original_bg is None:
976
+ self.original_bg = background.copy()
977
+
978
+ # Convert string values to float
979
+ x_pos = float(x_pos)
980
+ y_pos = float(y_pos)
981
+ scale = float(scale)
982
+
983
+ return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
984
+
985
+ # Create an instance of BackgroundManager
986
+ bg_manager = BackgroundManager()
987
+
988
+
989
+ x_slider.change(
990
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
991
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
992
+ outputs=[input_bg]
993
+ )
994
+
995
+ y_slider.change(
996
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
997
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
998
+ outputs=[input_bg]
999
+ )
1000
+
1001
+ fg_scale_slider.change(
1002
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1003
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1004
+ outputs=[input_bg]
1005
+ )
1006
+
1007
+ # Update inputs list to include fg_scale_slider
1008
+
1009
+ def process_relight_with_position(*args):
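+ # Only a padded crop around the positioned foreground is relit (crop dimensions snapped to
+ # multiples of 8 for the VAE), and the relit patch is then pasted back into the full background.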
1010
+ if mask_mover.extracted_fg is None:
1011
+ gr.Warning("Please extract foreground first")
1012
+ return None
1013
+
1014
+ background = args[1] # Get background image
1015
+ x_pos = float(args[-3]) # x_slider value
1016
+ y_pos = float(args[-2]) # y_slider value
1017
+ scale = float(args[-1]) # fg_scale_slider value
1018
+
1019
+ # Get original foreground size after scaling
1020
+ fg = Image.fromarray(mask_mover.original_fg)
1021
+ new_width = int(fg.width * scale)
1022
+ new_height = int(fg.height * scale)
1023
+
1024
+ # Calculate crop region around foreground position
1025
+ crop_x = int(x_pos - new_width/2)
1026
+ crop_y = int(y_pos - new_height/2)
1027
+ crop_width = new_width
1028
+ crop_height = new_height
1029
+
1030
+ # Add padding for context (20% extra on each side)
1031
+ padding = 0.2
1032
+ crop_x = int(crop_x - crop_width * padding)
1033
+ crop_y = int(crop_y - crop_height * padding)
1034
+ crop_width = int(crop_width * (1 + 2 * padding))
1035
+ crop_height = int(crop_height * (1 + 2 * padding))
1036
+
1037
+ # Ensure crop dimensions are multiples of 8
1038
+ crop_width = ((crop_width + 7) // 8) * 8
1039
+ crop_height = ((crop_height + 7) // 8) * 8
1040
+
1041
+ # Ensure crop region is within image bounds
1042
+ bg_height, bg_width = background.shape[:2]
1043
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
1044
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
1045
+
1046
+ # Get actual crop dimensions after boundary check
1047
+ crop_width = min(crop_width, bg_width - crop_x)
1048
+ crop_height = min(crop_height, bg_height - crop_y)
1049
+
1050
+ # Ensure dimensions are multiples of 8 again
1051
+ crop_width = (crop_width // 8) * 8
1052
+ crop_height = (crop_height // 8) * 8
1053
+
1054
+ # Crop region from background
1055
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
1056
+
1057
+ # Create composite in cropped region
1058
+ fg_local_x = int(new_width/2 + crop_width*padding)
1059
+ fg_local_y = int(new_height/2 + crop_height*padding)
1060
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
1061
+
1062
+ # Process the cropped region
1063
+ crop_args = list(args)
1064
+ crop_args[0] = cropped_composite
1065
+ crop_args[1] = crop_region
1066
+ crop_args[3] = crop_width
1067
+ crop_args[4] = crop_height
1068
+ crop_args = crop_args[:-3] # Remove position and scale arguments
1069
+
1070
+ # Get relit result
1071
+ relit_crop = process_relight_bg(*crop_args)[0]
1072
+
1073
+ # Resize relit result to match crop dimensions if needed
1074
+ if relit_crop.shape[:2] != (crop_height, crop_width):
1075
+ relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
1076
+
1077
+ # Place relit crop back into original background
1078
+ result = background.copy()
1079
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
1080
+
1081
+ return result
1082
+
1083
+ ips_bg = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source, x_slider, y_slider, fg_scale_slider]  # position/scale sliders appended; process_relight_with_position strips the last three
1084
+
1085
+ # Update button click events with new inputs list
1086
+ relight_button_bg.click(
1087
+ fn=process_relight_with_position,
1088
+ inputs=ips_bg,
1089
+ outputs=[result_gallery]
1090
+ )
1091
+
1092
+
1093
+ example_prompts.click(
1094
+ fn=lambda x: x[0],
1095
+ inputs=example_prompts,
1096
+ outputs=prompt,
1097
+ show_progress=False,
1098
+ queue=False
1099
+ )
1100
+
1101
+
1102
+
1103
+ block.launch(server_name='0.0.0.0', share=True)
gradio_demo_bg.py ADDED
@@ -0,0 +1,1004 @@
1
+ import os
2
+ import math
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import safetensors.torch as sf
7
+ import db_examples
8
+ import datetime
9
+ from pathlib import Path
10
+
11
+ from PIL import Image
12
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
13
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
14
+ from diffusers.models.attention_processor import AttnProcessor2_0
15
+ from transformers import CLIPTextModel, CLIPTokenizer
16
+ from briarmbg import BriaRMBG
17
+ from enum import Enum
18
+ from torch.hub import download_url_to_file
19
+ import cv2
20
+
21
+ from typing import Optional
22
+
23
+ from Depth.depth_anything_v2.dpt import DepthAnythingV2
24
+
25
+
26
+
27
+ # from FLORENCE
28
+ import spaces
29
+ import supervision as sv
30
+ import torch
31
+ from PIL import Image
32
+
33
+
34
+ from utils.florence import load_florence_model, run_florence_inference, \
35
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
36
+ from utils.sam import load_sam_image_model, run_sam_inference
37
+
38
+
39
+ import torch
40
+ DEVICE = torch.device("cuda")
41
+
42
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
43
+ if torch.cuda.get_device_properties(0).major >= 8:
44
+ torch.backends.cuda.matmul.allow_tf32 = True
45
+ torch.backends.cudnn.allow_tf32 = True
46
+
47
+
48
+ FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
49
+ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
50
+
51
+ @spaces.GPU(duration=20)
52
+ @torch.inference_mode()
53
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
54
+ def process_image(image_input, text_input) -> Optional[Image.Image]:
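+ # Open-vocabulary segmentation: Florence-2 detects regions matching the text query,
+ # then SAM turns the first detection into a pixel mask returned as a black/white PIL image.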
55
+ # if not image_input:
56
+ # gr.Info("Please upload an image.")
57
+ # return None
58
+
59
+ # if not text_input:
60
+ # gr.Info("Please enter a text prompt.")
61
+ # return None
62
+
63
+ _, result = run_florence_inference(
64
+ model=FLORENCE_MODEL,
65
+ processor=FLORENCE_PROCESSOR,
66
+ device=DEVICE,
67
+ image=image_input,
68
+ task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
69
+ text=text_input
70
+ )
71
+ detections = sv.Detections.from_lmm(
72
+ lmm=sv.LMM.FLORENCE_2,
73
+ result=result,
74
+ resolution_wh=image_input.size
75
+ )
76
+ detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
77
+ if len(detections) == 0:
78
+ gr.Info("No objects detected.")
79
+ return None
80
+ return Image.fromarray(detections.mask[0].astype("uint8") * 255)
81
+
82
+
83
+ try:
84
+ import xformers
85
+ import xformers.ops
86
+ XFORMERS_AVAILABLE = True
87
+ print("xformers is available - Using memory efficient attention")
88
+ except ImportError:
89
+ XFORMERS_AVAILABLE = False
90
+ print("xformers not available - Using default attention")
91
+
92
+ # 'stablediffusionapi/realistic-vision-v51'
93
+ # 'runwayml/stable-diffusion-v1-5'
94
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
95
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
96
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
97
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
98
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
99
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
100
+
101
+ # Change UNet
102
+
103
+ with torch.no_grad():
104
+ new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
105
+ new_conv_in.weight.zero_()
106
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
107
+ new_conv_in.bias = unet.conv_in.bias
108
+ unet.conv_in = new_conv_in
109
+
110
+ unet_original_forward = unet.forward
111
+
112
+
113
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
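+ # Fetch the VAE-encoded fg/bg condition latents passed via cross_attention_kwargs, repeat them
+ # to match the batch size, and concatenate along the channel axis to feed the 12-channel conv_in.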
114
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
115
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
116
+ new_sample = torch.cat([sample, c_concat], dim=1)
117
+ kwargs['cross_attention_kwargs'] = {}
118
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
119
+
120
+
121
+ unet.forward = hooked_unet_forward
122
+
123
+ # Load
124
+
125
+ model_path = './models/iclight_sd15_fbc.safetensors'
126
+
127
+ if not os.path.exists(model_path):
128
+ download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors', dst=model_path)
129
+
130
+ # Device and dtype setup
131
+ device = torch.device('cuda')
132
+ dtype = torch.float16 # RTX 2070 works well with float16
133
+
134
+ # Memory optimizations for RTX 2070
135
+ torch.backends.cudnn.benchmark = True
136
+ if torch.cuda.is_available():
137
+ torch.backends.cuda.matmul.allow_tf32 = True
138
+ torch.backends.cudnn.allow_tf32 = True
139
+ # Configure the CUDA caching allocator via env var (torch.backends.cuda has no max_split_size_mb attribute; ideally set before CUDA init)
140
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
141
+
142
+ # Move models to device with consistent dtype
143
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
144
+ vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
145
+ unet = unet.to(device=device, dtype=dtype)
146
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
147
+
148
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
149
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
150
+ model.eval()
151
+
152
+
153
+ from utils.florence import load_florence_model, run_florence_inference, \
154
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
155
+ from utils.sam import load_sam_image_model, run_sam_inference
156
+
157
+
158
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
159
+ if torch.cuda.get_device_properties(0).major >= 8:
160
+ torch.backends.cuda.matmul.allow_tf32 = True
161
+ torch.backends.cudnn.allow_tf32 = True
162
+
163
+
164
+ #FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
165
+ SAM_IMAGE_MODEL = load_sam_image_model(device=device)
166
+
167
+ # Update the state dict merging to use correct dtype
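+ # IC-Light distributes its weights as an offset over the base SD 1.5 UNet, so the downloaded
+ # tensors are added onto the original state dict rather than loaded directly.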
168
+ sd_offset = sf.load_file(model_path)
169
+ sd_origin = unet.state_dict()
170
+ sd_merged = {k: sd_origin[k] + sd_offset[k].to(device=device, dtype=dtype) for k in sd_origin.keys()}
171
+ unet.load_state_dict(sd_merged, strict=True)
172
+ del sd_offset, sd_origin, sd_merged
173
+
174
+ def enable_efficient_attention():
175
+ if XFORMERS_AVAILABLE:
176
+ try:
177
+ # RTX 2070 specific settings
178
+ unet.set_use_memory_efficient_attention_xformers(True)
179
+ vae.set_use_memory_efficient_attention_xformers(True)
180
+ print("Enabled xformers memory efficient attention")
181
+ except Exception as e:
182
+ print(f"Xformers error: {e}")
183
+ print("Falling back to sliced attention")
184
+ # Use sliced attention for RTX 2070
185
+ unet.set_attention_slice(4)
186
+ vae.enable_slicing()  # AutoencoderKL has no attention-slicing API; slice VAE batches instead
187
+ unet.set_attn_processor(AttnProcessor2_0())
188
+ vae.set_attn_processor(AttnProcessor2_0())
189
+ else:
190
+ # Fallback for when xformers is not available
191
+ print("Using sliced attention")
192
+ unet.set_attention_slice(4)
193
+ vae.enable_slicing()  # AutoencoderKL has no attention-slicing API; slice VAE batches instead
194
+ unet.set_attn_processor(AttnProcessor2_0())
195
+ vae.set_attn_processor(AttnProcessor2_0())
196
+
197
+ # Add memory clearing function
198
+ def clear_memory():
199
+ if torch.cuda.is_available():
200
+ torch.cuda.empty_cache()
201
+ torch.cuda.synchronize()
202
+
203
+ # Enable efficient attention
204
+ enable_efficient_attention()
205
+
206
+ # Samplers
207
+
208
+ ddim_scheduler = DDIMScheduler(
209
+ num_train_timesteps=1000,
210
+ beta_start=0.00085,
211
+ beta_end=0.012,
212
+ beta_schedule="scaled_linear",
213
+ clip_sample=False,
214
+ set_alpha_to_one=False,
215
+ steps_offset=1,
216
+ )
217
+
218
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
219
+ num_train_timesteps=1000,
220
+ beta_start=0.00085,
221
+ beta_end=0.012,
222
+ steps_offset=1
223
+ )
224
+
225
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
226
+ num_train_timesteps=1000,
227
+ beta_start=0.00085,
228
+ beta_end=0.012,
229
+ algorithm_type="sde-dpmsolver++",
230
+ use_karras_sigmas=True,
231
+ steps_offset=1
232
+ )
233
+
234
+ # Pipelines
235
+
236
+ t2i_pipe = StableDiffusionPipeline(
237
+ vae=vae,
238
+ text_encoder=text_encoder,
239
+ tokenizer=tokenizer,
240
+ unet=unet,
241
+ scheduler=dpmpp_2m_sde_karras_scheduler,
242
+ safety_checker=None,
243
+ requires_safety_checker=False,
244
+ feature_extractor=None,
245
+ image_encoder=None
246
+ )
247
+
248
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
249
+ vae=vae,
250
+ text_encoder=text_encoder,
251
+ tokenizer=tokenizer,
252
+ unet=unet,
253
+ scheduler=dpmpp_2m_sde_karras_scheduler,
254
+ safety_checker=None,
255
+ requires_safety_checker=False,
256
+ feature_extractor=None,
257
+ image_encoder=None
258
+ )
259
+
260
+
261
+ @torch.inference_mode()
262
+ def encode_prompt_inner(txt: str):
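+ # CLIP only attends to 77 tokens at a time, so long prompts are split into 75-token chunks,
+ # each wrapped in BOS/EOS and padded to the full length, then encoded chunk by chunk.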
263
+ max_length = tokenizer.model_max_length
264
+ chunk_length = tokenizer.model_max_length - 2
265
+ id_start = tokenizer.bos_token_id
266
+ id_end = tokenizer.eos_token_id
267
+ id_pad = id_end
268
+
269
+ def pad(x, p, i):
270
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
271
+
272
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
273
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
274
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
275
+
276
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
277
+ conds = text_encoder(token_ids).last_hidden_state
278
+
279
+ return conds
280
+
281
+
282
+ @torch.inference_mode()
283
+ def encode_prompt_pair(positive_prompt, negative_prompt):
284
+ c = encode_prompt_inner(positive_prompt)
285
+ uc = encode_prompt_inner(negative_prompt)
286
+
287
+ c_len = float(len(c))
288
+ uc_len = float(len(uc))
289
+ max_count = max(c_len, uc_len)
290
+ c_repeat = int(math.ceil(max_count / c_len))
291
+ uc_repeat = int(math.ceil(max_count / uc_len))
292
+ max_chunk = max(len(c), len(uc))
293
+
294
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
295
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
296
+
297
+ c = torch.cat([p[None, ...] for p in c], dim=1)
298
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
299
+
300
+ return c, uc
301
+
302
+
303
+ @torch.inference_mode()
304
+ def pytorch2numpy(imgs, quant=True):
305
+ results = []
306
+ for x in imgs:
307
+ y = x.movedim(0, -1)
308
+
309
+ if quant:
310
+ y = y * 127.5 + 127.5
311
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
312
+ else:
313
+ y = y * 0.5 + 0.5
314
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
315
+
316
+ results.append(y)
317
+ return results
318
+
319
+
320
+ @torch.inference_mode()
321
+ def numpy2pytorch(imgs):
322
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
323
+ h = h.movedim(-1, 1)
324
+ return h
325
+
326
+
327
+ def resize_and_center_crop(image, target_width, target_height):
328
+ pil_image = Image.fromarray(image)
329
+ original_width, original_height = pil_image.size
330
+ scale_factor = max(target_width / original_width, target_height / original_height)
331
+ resized_width = int(round(original_width * scale_factor))
332
+ resized_height = int(round(original_height * scale_factor))
333
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
334
+ left = (resized_width - target_width) / 2
335
+ top = (resized_height - target_height) / 2
336
+ right = (resized_width + target_width) / 2
337
+ bottom = (resized_height + target_height) / 2
338
+ cropped_image = resized_image.crop((left, top, right, bottom))
339
+ return np.array(cropped_image)
340
+
341
+
342
+ def resize_without_crop(image, target_width, target_height):
343
+ pil_image = Image.fromarray(image)
344
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
345
+ return np.array(resized_image)
346
+
347
+
348
+ @torch.inference_mode()
349
+ def run_rmbg(img, sigma=0.0):
350
+ # Convert RGBA to RGB if needed
351
+ if img.shape[-1] == 4:
352
+ # Use white background for alpha composition
353
+ alpha = img[..., 3:] / 255.0
354
+ rgb = img[..., :3]
355
+ white_bg = np.ones_like(rgb) * 255
356
+ img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
357
+
358
+ H, W, C = img.shape
359
+ assert C == 3
360
+ k = (256.0 / float(H * W)) ** 0.5
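+ # Resize so the RMBG input is roughly one megapixel, with both sides multiples of 64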
361
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
362
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
363
+ alpha = rmbg(feed)[0][0]
364
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
365
+ alpha = alpha.movedim(1, -1)[0]
366
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
367
+
368
+ # Create RGBA image
369
+ rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
370
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
371
+ return result.clip(0, 255).astype(np.uint8), rgba
372
+
373
+
374
+ @torch.inference_mode()
375
+ def process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
376
+ clear_memory()
377
+ bg_source = BGSource(bg_source)
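+ # Non-upload sources synthesise a grayscale gradient (or flat grey) image that serves as the
+ # lighting condition in place of a real background photo.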
378
+
379
+ if bg_source == BGSource.UPLOAD:
380
+ pass
381
+ elif bg_source == BGSource.UPLOAD_FLIP:
382
+ input_bg = np.fliplr(input_bg)
383
+ elif bg_source == BGSource.GREY:
384
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
385
+ elif bg_source == BGSource.LEFT:
386
+ gradient = np.linspace(224, 32, image_width)
387
+ image = np.tile(gradient, (image_height, 1))
388
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
389
+ elif bg_source == BGSource.RIGHT:
390
+ gradient = np.linspace(32, 224, image_width)
391
+ image = np.tile(gradient, (image_height, 1))
392
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
393
+ elif bg_source == BGSource.TOP:
394
+ gradient = np.linspace(224, 32, image_height)[:, None]
395
+ image = np.tile(gradient, (1, image_width))
396
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
397
+ elif bg_source == BGSource.BOTTOM:
398
+ gradient = np.linspace(32, 224, image_height)[:, None]
399
+ image = np.tile(gradient, (1, image_width))
400
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
401
+ else:
402
+ raise ValueError('Wrong background source!')  # raising a plain string is invalid in Python 3
403
+
404
+ rng = torch.Generator(device=device).manual_seed(seed)
405
+
406
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
407
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
408
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
409
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
410
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
411
+
412
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
413
+
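+ # Two-pass generation: a text-to-image pass at the requested resolution produces latents, which are
+ # then decoded, upscaled by highres_scale, re-encoded and refined with an img2img pass at highres_denoise strength.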
414
+ latents = t2i_pipe(
415
+ prompt_embeds=conds,
416
+ negative_prompt_embeds=unconds,
417
+ width=image_width,
418
+ height=image_height,
419
+ num_inference_steps=steps,
420
+ num_images_per_prompt=num_samples,
421
+ generator=rng,
422
+ output_type='latent',
423
+ guidance_scale=cfg,
424
+ cross_attention_kwargs={'concat_conds': concat_conds},
425
+ ).images.to(vae.dtype) / vae.config.scaling_factor
426
+
427
+ pixels = vae.decode(latents).sample
428
+ pixels = pytorch2numpy(pixels)
429
+ pixels = [resize_without_crop(
430
+ image=p,
431
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
432
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
433
+ for p in pixels]
434
+
435
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
436
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
437
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
438
+
439
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
440
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
441
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
442
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
443
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
444
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
445
+
446
+ latents = i2i_pipe(
447
+ image=latents,
448
+ strength=highres_denoise,
449
+ prompt_embeds=conds,
450
+ negative_prompt_embeds=unconds,
451
+ width=image_width,
452
+ height=image_height,
453
+ num_inference_steps=int(round(steps / highres_denoise)),
454
+ num_images_per_prompt=num_samples,
455
+ generator=rng,
456
+ output_type='latent',
457
+ guidance_scale=cfg,
458
+ cross_attention_kwargs={'concat_conds': concat_conds},
459
+ ).images.to(vae.dtype) / vae.config.scaling_factor
460
+
461
+ pixels = vae.decode(latents).sample
462
+ pixels = pytorch2numpy(pixels, quant=False)
463
+
464
+ clear_memory()
465
+ return pixels, [fg, bg]
466
+
467
+
468
+ # Add save function
469
+ def save_images(images, prefix="relight"):
470
+ # Create output directory if it doesn't exist
471
+ output_dir = Path("outputs")
472
+ output_dir.mkdir(exist_ok=True)
473
+
474
+ # Create timestamp for unique filenames
475
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
476
+
477
+ saved_paths = []
478
+ for i, img in enumerate(images):
479
+ if isinstance(img, np.ndarray):
480
+ # Convert to PIL Image if numpy array
481
+ img = Image.fromarray(img)
482
+
483
+ # Create filename with timestamp
484
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
485
+ filepath = output_dir / filename
486
+
487
+ # Save image
488
+ img.save(filepath)
489
+
490
+
491
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
492
+ return saved_paths
493
+
494
+
495
+ # Modify process_relight to save images
496
+ @torch.inference_mode()
497
+ def process_relight(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
498
+ input_fg, matting = run_rmbg(input_fg)
499
+ # show input_fg in a new image
500
+ input_fg_img = Image.fromarray(input_fg)
501
+ results, extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
502
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
503
+ final_results = results + extra_images
504
+
505
+ # Save the generated images
506
+ save_images(results, prefix="relight")
507
+
508
+ return results
509
+
510
+
511
+ # Modify process_normal to save images
512
+ @torch.inference_mode()
513
+ def process_normal(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
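+ # Rough normal estimation: relight the subject from left/right/bottom/top, divide each render by
+ # their average to isolate directional shading, then build per-pixel (u, v, z) normals masked by the RMBG alpha.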
514
+ input_fg, matting = run_rmbg(input_fg, sigma=16)
515
+
516
+ print('left ...')
517
+ left = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.LEFT.value)[0][0]
518
+
519
+ print('right ...')
520
+ right = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.RIGHT.value)[0][0]
521
+
522
+ print('bottom ...')
523
+ bottom = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.BOTTOM.value)[0][0]
524
+
525
+ print('top ...')
526
+ top = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.TOP.value)[0][0]
527
+
528
+ inner_results = [left * 2.0 - 1.0, right * 2.0 - 1.0, bottom * 2.0 - 1.0, top * 2.0 - 1.0]
529
+
530
+ ambient = (left + right + bottom + top) / 4.0
531
+ h, w, _ = ambient.shape
532
+ matting = resize_and_center_crop((matting[..., 0] * 255.0).clip(0, 255).astype(np.uint8), w, h).astype(np.float32)[..., None] / 255.0
533
+
534
+ def safa_divide(a, b):
535
+ e = 1e-5
536
+ return ((a + e) / (b + e)) - 1.0
537
+
538
+ left = safa_divide(left, ambient)
539
+ right = safa_divide(right, ambient)
540
+ bottom = safa_divide(bottom, ambient)
541
+ top = safa_divide(top, ambient)
542
+
543
+ u = (right - left) * 0.5
544
+ v = (top - bottom) * 0.5
545
+
546
+ sigma = 10.0
547
+ u = np.mean(u, axis=2)
548
+ v = np.mean(v, axis=2)
549
+ h = (1.0 - u ** 2.0 - v ** 2.0).clip(0, 1e5) ** (0.5 * sigma)
550
+ z = np.zeros_like(h)
551
+
552
+ normal = np.stack([u, v, h], axis=2)
553
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
554
+ normal = normal * matting + np.stack([z, z, 1 - z], axis=2) * (1 - matting)
555
+
556
+ results = [normal, left, right, bottom, top] + inner_results
557
+ results = [(x * 127.5 + 127.5).clip(0, 255).astype(np.uint8) for x in results]
558
+
559
+
560
+ # Save the generated images
561
+ save_images(results, prefix="normal")
562
+
563
+ return results
564
+
565
+
566
+
567
+
568
+ quick_prompts = [
569
+ 'modern sofa in living room',
570
+ 'elegant dining table with chairs',
571
+ 'luxurious bed in bedroom, cinematic lighting',
572
+ 'minimalist office desk, natural lighting',
573
+ 'vintage wooden cabinet, warm lighting',
574
+ 'contemporary bookshelf, ambient lighting',
575
+ 'designer armchair, dramatic lighting',
576
+ 'modern kitchen island, bright lighting',
577
+ ]
578
+ quick_prompts = [[x] for x in quick_prompts]
579
+
580
+
581
+ class BGSource(Enum):
582
+ UPLOAD = "Use Background Image"
583
+ UPLOAD_FLIP = "Use Flipped Background Image"
584
+ LEFT = "Left Light"
585
+ RIGHT = "Right Light"
586
+ TOP = "Top Light"
587
+ BOTTOM = "Bottom Light"
588
+ GREY = "Ambient"
589
+
590
+
591
+ class MaskMover:
592
+ def __init__(self):
593
+ self.extracted_fg = None
594
+ self.original_fg = None # Store original foreground
595
+
596
+ def set_extracted_fg(self, fg_image):
597
+ """Store the extracted foreground with alpha channel"""
598
+ self.extracted_fg = fg_image.copy()
599
+ self.original_fg = fg_image.copy() # Keep original
600
+ return fg_image
601
+
602
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
603
+ """Create composite with foreground at specified position"""
604
+ if self.original_fg is None or background is None:
605
+ return background
606
+
607
+ # Convert inputs to PIL Images
608
+ if isinstance(background, np.ndarray):
609
+ bg = Image.fromarray(background)
610
+ else:
611
+ bg = background
612
+
613
+ if isinstance(self.original_fg, np.ndarray):
614
+ fg = Image.fromarray(self.original_fg)
615
+ else:
616
+ fg = self.original_fg
617
+
618
+ # Scale the foreground size
619
+ new_width = int(fg.width * scale)
620
+ new_height = int(fg.height * scale)
621
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
622
+
623
+ # Center the scaled foreground at the position
624
+ x = int(x_pos - new_width / 2)
625
+ y = int(y_pos - new_height / 2)
626
+
627
+ # Create composite
628
+ result = bg.copy()
629
+ if fg.mode == 'RGBA': # If foreground has alpha channel
630
+ result.paste(fg, (x, y), fg.split()[3]) # Use alpha channel as mask
631
+ else:
632
+ result.paste(fg, (x, y))
633
+
634
+ return np.array(result)
635
+
636
+ def get_depth(image):
637
+ if image is None:
638
+ return None
639
+ # Convert from PIL/gradio format to cv2
640
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
641
+ # Get depth map
642
+ depth = model.infer_image(raw_img) # HxW raw depth map
643
+ # Normalize depth for visualization
644
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
645
+ # Convert to RGB for display
646
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
647
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
648
+ return Image.fromarray(depth_colored)
649
+
650
+ # def find_objects(image_input):
651
+ # detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
652
+ # if len(detections) == 0:
653
+ # gr.Info("No objects detected.")
654
+ # return None
655
+ # return Image.fromarray(detections.mask[0].astype("uint8") * 255)
656
+
657
+
658
+
659
+ block = gr.Blocks().queue()
660
+ with block:
661
+ mask_mover = MaskMover()
662
+
663
+ with gr.Row():
664
+ gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
665
+ gr.Markdown("πŸ’Ύ Generated images are automatically saved to 'outputs' folder")
666
+
667
+ with gr.Row():
668
+ with gr.Column():
669
+ # Step 1: Input and Extract
670
+ with gr.Group():
671
+ gr.Markdown("### Step 1: Extract Foreground")
672
+ input_image = gr.Image(type="numpy", label="Input Image", height=480)
673
+ input_text = gr.Textbox(label="Describe target object")
674
+
675
+ find_objects_button = gr.Button(value="Find Objects")
676
+ extract_button = gr.Button(value="Remove Background")
677
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
678
+
679
+
680
+ # Step 2: Background and Position
681
+ with gr.Group():
682
+ gr.Markdown("### Step 2: Position on Background")
683
+ input_bg = gr.Image(type="numpy", label="Background Image", height=480)
684
+
685
+ with gr.Row():
686
+ x_slider = gr.Slider(
687
+ minimum=0,
688
+ maximum=1000,
689
+ label="X Position",
690
+ value=500,
691
+ visible=False
692
+ )
693
+ y_slider = gr.Slider(
694
+ minimum=0,
695
+ maximum=1000,
696
+ label="Y Position",
697
+ value=500,
698
+ visible=False
699
+ )
700
+ fg_scale_slider = gr.Slider(
701
+ label="Foreground Scale",
702
+ minimum=0.01,
703
+ maximum=3.0,
704
+ value=1.0,
705
+ step=0.01
706
+ )
707
+
708
+ get_depth_button = gr.Button(value="Get Depth")
709
+
710
+ depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
711
+
712
+
713
+ editor = gr.ImageEditor(
714
+ type="numpy",
715
+ label="Position Foreground",
716
+ height=480,
717
+ visible=False
718
+ )
719
+
720
+ # Step 3: Relighting Options
721
+ with gr.Group():
722
+ gr.Markdown("### Step 3: Relighting Settings")
723
+ prompt = gr.Textbox(label="Prompt")
724
+ bg_source = gr.Radio(
725
+ choices=[e.value for e in BGSource],
726
+ value=BGSource.UPLOAD.value,
727
+ label="Background Source",
728
+ type='value'
729
+ )
730
+
731
+ example_prompts = gr.Dataset(
732
+ samples=quick_prompts,
733
+ label='Prompt Quick List',
734
+ components=[prompt]
735
+ )
736
+ # bg_gallery = gr.Gallery(
737
+ # height=450,
738
+ # label='Background Quick List',
739
+ # value=db_examples.bg_samples,
740
+ # columns=5,
741
+ # allow_preview=False
742
+ # )
743
+ relight_button = gr.Button(value="Relight")
744
+
745
+ # Additional settings
746
+ with gr.Group():
747
+ with gr.Row():
748
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
749
+ seed = gr.Number(label="Seed", value=12345, precision=0)
750
+ with gr.Row():
751
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
752
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
753
+
754
+ with gr.Accordion("Advanced options", open=False):
755
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
756
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
757
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
758
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
759
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
760
+ n_prompt = gr.Textbox(
761
+ label="Negative Prompt",
762
+ value='lowres, bad anatomy, bad hands, cropped, worst quality'
763
+ )
764
+ normal_button = gr.Button(value="Compute Normal (4x Slower)")
765
+
766
+ with gr.Column():
767
+ result_gallery = gr.Image(height=832, label='Outputs')
768
+
769
+ # Event handlers
770
+ def extract_foreground(image):
771
+ if image is None:
772
+ return None, gr.update(visible=True), gr.update(visible=True)
773
+ result, rgba = run_rmbg(image)
774
+ mask_mover.set_extracted_fg(rgba)
775
+
776
+ return result, gr.update(visible=True), gr.update(visible=True)
777
+
778
+ original_bg = None
779
+
780
+ extract_button.click(
781
+ fn=extract_foreground,
782
+ inputs=[input_image],
783
+ outputs=[extracted_fg, x_slider, y_slider]
784
+ )
785
+
786
+ find_objects_button.click(
787
+ fn=process_image,
788
+ inputs=[input_image, input_text],
789
+ outputs=[extracted_fg]
790
+ )
791
+
792
+ get_depth_button.click(
793
+ fn=get_depth,
794
+ inputs=[input_bg],
795
+ outputs=[depth_image]
796
+ )
797
+
798
+ def update_position(background, x_pos, y_pos, scale):
799
+ """Update composite when position changes"""
800
+ global original_bg
801
+ if background is None:
802
+ return None
803
+
804
+ if original_bg is None:
805
+ original_bg = background.copy()
806
+
807
+ # Convert string values to float
808
+ x_pos = float(x_pos)
809
+ y_pos = float(y_pos)
810
+ scale = float(scale)
811
+
812
+ return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
813
+
814
+ x_slider.change(
815
+ fn=update_position,
816
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
817
+ outputs=[input_bg]
818
+ )
819
+
820
+ y_slider.change(
821
+ fn=update_position,
822
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
823
+ outputs=[input_bg]
824
+ )
825
+
826
+ fg_scale_slider.change(
827
+ fn=update_position,
828
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
829
+ outputs=[input_bg]
830
+ )
831
+
832
+ # Update inputs list to include fg_scale_slider
833
+ ips = [input_bg, input_bg, prompt, image_width, image_height, num_samples,
834
+ seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise,
835
+ bg_source, x_slider, y_slider, fg_scale_slider] # Added fg_scale_slider
836
+
837
+ def process_relight_with_position(*args):
838
+ if mask_mover.extracted_fg is None:
839
+ gr.Warning("Please extract foreground first")
840
+ return None
841
+
842
+ background = args[1] # Get background image
843
+ x_pos = float(args[-3]) # x_slider value
844
+ y_pos = float(args[-2]) # y_slider value
845
+ scale = float(args[-1]) # fg_scale_slider value
846
+
847
+ # Get original foreground size after scaling
848
+ fg = Image.fromarray(mask_mover.original_fg)
849
+ new_width = int(fg.width * scale)
850
+ new_height = int(fg.height * scale)
851
+
852
+ # Calculate crop region around foreground position
853
+ crop_x = int(x_pos - new_width/2)
854
+ crop_y = int(y_pos - new_height/2)
855
+ crop_width = new_width
856
+ crop_height = new_height
857
+
858
+ # Add padding for context (20% extra on each side)
859
+ padding = 0.2
860
+ crop_x = int(crop_x - crop_width * padding)
861
+ crop_y = int(crop_y - crop_height * padding)
862
+ crop_width = int(crop_width * (1 + 2 * padding))
863
+ crop_height = int(crop_height * (1 + 2 * padding))
864
+
865
+ # Ensure crop dimensions are multiples of 8
866
+ crop_width = ((crop_width + 7) // 8) * 8
867
+ crop_height = ((crop_height + 7) // 8) * 8
868
+
869
+ # Ensure crop region is within image bounds
870
+ bg_height, bg_width = background.shape[:2]
871
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
872
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
873
+
874
+ # Get actual crop dimensions after boundary check
875
+ crop_width = min(crop_width, bg_width - crop_x)
876
+ crop_height = min(crop_height, bg_height - crop_y)
877
+
878
+ # Ensure dimensions are multiples of 8 again
879
+ crop_width = (crop_width // 8) * 8
880
+ crop_height = (crop_height // 8) * 8
881
+
882
+ # Crop region from background
883
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
884
+
885
+ # Create composite in cropped region
886
+ fg_local_x = int(new_width/2 + crop_width*padding)
887
+ fg_local_y = int(new_height/2 + crop_height*padding)
888
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
889
+
890
+ # Process the cropped region
891
+ crop_args = list(args)
892
+ crop_args[0] = cropped_composite
893
+ crop_args[1] = crop_region
894
+ crop_args[3] = crop_width
895
+ crop_args[4] = crop_height
896
+ crop_args = crop_args[:-3] # Remove position and scale arguments
897
+
898
+ # Get relit result
899
+ relit_crop = process_relight(*crop_args)[0]
900
+
901
+ # Resize relit result to match crop dimensions if needed
902
+ if relit_crop.shape[:2] != (crop_height, crop_width):
903
+ relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
904
+
905
+ # Place relit crop back into original background
906
+ result = background.copy()
907
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
908
+
909
+ return result
910
+
911
+ # Update button click events with new inputs list
912
+ relight_button.click(
913
+ fn=process_relight_with_position,
914
+ inputs=ips,
915
+ outputs=[result_gallery]
916
+ )
917
+
918
+ # Update normal_button to use same argument handling
919
+ def process_normal_with_position(*args):
920
+ if mask_mover.extracted_fg is None:
921
+ gr.Warning("Please extract foreground first")
922
+ return None
923
+
924
+ background = args[1]
925
+ x_pos = float(args[-3]) # x_slider value
926
+ y_pos = float(args[-2]) # y_slider value
927
+ scale = float(args[-1]) # fg_scale_slider value
928
+
929
+ # Get original foreground size after scaling
930
+ fg = Image.fromarray(mask_mover.original_fg)
931
+ new_width = int(fg.width * scale)
932
+ new_height = int(fg.height * scale)
933
+
934
+ # Calculate crop region around foreground position
935
+ crop_x = int(x_pos - new_width/2)
936
+ crop_y = int(y_pos - new_height/2)
937
+ crop_width = new_width
938
+ crop_height = new_height
939
+
940
+ # Add padding for context (20% extra on each side)
941
+ padding = 0.2
942
+ crop_x = int(crop_x - crop_width * padding)
943
+ crop_y = int(crop_y - crop_height * padding)
944
+ crop_width = int(crop_width * (1 + 2 * padding))
945
+ crop_height = int(crop_height * (1 + 2 * padding))
946
+
947
+ # Ensure crop dimensions are multiples of 8
948
+ crop_width = ((crop_width + 7) // 8) * 8
949
+ crop_height = ((crop_height + 7) // 8) * 8
950
+
951
+ # Ensure crop region is within image bounds
952
+ bg_height, bg_width = background.shape[:2]
953
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
954
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
955
+
956
+ # Crop region from background
957
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
958
+
959
+ # Create composite in cropped region
960
+ fg_local_x = int(new_width/2 + crop_width*padding)
961
+ fg_local_y = int(new_height/2 + crop_height*padding)
962
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
963
+
964
+ # Process the cropped region
965
+ crop_args = list(args)
966
+ crop_args[0] = cropped_composite
967
+ crop_args[1] = crop_region
968
+ crop_args[3] = crop_width
969
+ crop_args[4] = crop_height
970
+ crop_args = crop_args[:-3]
971
+
972
+ # Get processed result
973
+ processed_crop = process_normal(*crop_args)[0]  # use the normal map (first returned image)
+ if processed_crop.shape[:2] != (crop_height, crop_width):
+ processed_crop = resize_without_crop(processed_crop, crop_width, crop_height)
974
+
975
+ # Place processed crop back into original background
976
+ result = background.copy()
977
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = processed_crop
978
+
979
+ return result
980
+
981
+ normal_button.click(
982
+ fn=process_normal_with_position,
983
+ inputs=ips,
984
+ outputs=[result_gallery]
985
+ )
986
+
987
+ example_prompts.click(
988
+ fn=lambda x: x[0],
989
+ inputs=example_prompts,
990
+ outputs=prompt,
991
+ show_progress=False,
992
+ queue=False
993
+ )
994
+
995
+ # def bg_gallery_selected(gal, evt: gr.SelectData):
996
+ # return gal[evt.index]['name']
997
+
998
+ # bg_gallery.select(
999
+ # fn=bg_gallery_selected,
1000
+ # inputs=bg_gallery,
1001
+ # outputs=input_bg
1002
+ # )
1003
+
1004
+ block.launch(server_name='0.0.0.0')
requirements.txt ADDED
Binary file (5.21 kB). View file
 
xformers-0.0.28.post3.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f227dabb9841f5235c452c988b459ebae1c00a1c7d2d53d4ad1335318807c982
3
+ size 4842597