Spaces:
Runtime error
Runtime error
Update gradio_demo.py
Browse files- gradio_demo.py +8 -8
gradio_demo.py
CHANGED
@@ -86,7 +86,7 @@ model.eval()
|
|
86 |
# Change UNet
|
87 |
|
88 |
with torch.no_grad():
|
89 |
-
new_conv_in = torch.nn.Conv2d(
|
90 |
new_conv_in.weight.zero_()
|
91 |
new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
|
92 |
new_conv_in.bias = unet.conv_in.bias
|
@@ -142,8 +142,8 @@ unet.forward = hooked_unet_forward
|
|
142 |
# Load
|
143 |
|
144 |
# Model paths
|
145 |
-
model_path = './models/
|
146 |
-
|
147 |
model_path3 = './checkpoints/sam2_hiera_large.pt'
|
148 |
model_path4 = './checkpoints/config.json'
|
149 |
model_path5 = './checkpoints/preprocessor_config.json'
|
@@ -155,7 +155,7 @@ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
|
|
155 |
|
156 |
# Model URLs
|
157 |
model_urls = {
|
158 |
-
model_path: '
|
159 |
model_path2: 'depth_anything_v2_vits.pth',
|
160 |
model_path3: 'sam2_hiera_large.pt',
|
161 |
model_path4: 'config.json',
|
@@ -186,8 +186,8 @@ ensure_directories()
|
|
186 |
download_models()
|
187 |
|
188 |
|
189 |
-
if not os.path.exists(model_path):
|
190 |
-
|
191 |
|
192 |
sd_offset = sf.load_file(model_path)
|
193 |
sd_origin = unet.state_dict()
|
@@ -953,7 +953,7 @@ def process_image(input_image, input_text):
|
|
953 |
|
954 |
block = gr.Blocks().queue()
|
955 |
with block:
|
956 |
-
with gr.Tab("Text"):
|
957 |
with gr.Row():
|
958 |
gr.Markdown("## Product Placement from Text")
|
959 |
with gr.Row():
|
@@ -1044,7 +1044,7 @@ with block:
|
|
1044 |
outputs=[extracted_objects, extracted_fg], show_progress=True
|
1045 |
)
|
1046 |
|
1047 |
-
with gr.Tab("Background", visible=
|
1048 |
# empty cache
|
1049 |
|
1050 |
mask_mover = MaskMover()
|
|
|
86 |
# Change UNet
|
87 |
|
88 |
with torch.no_grad():
|
89 |
+
new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
|
90 |
new_conv_in.weight.zero_()
|
91 |
new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
|
92 |
new_conv_in.bias = unet.conv_in.bias
|
|
|
142 |
# Load
|
143 |
|
144 |
# Model paths
|
145 |
+
model_path = './models/iclight_sd15_fbc.safetensors'
|
146 |
+
model_path = './checkpoints/depth_anything_v2_vits.pth'
|
147 |
model_path3 = './checkpoints/sam2_hiera_large.pt'
|
148 |
model_path4 = './checkpoints/config.json'
|
149 |
model_path5 = './checkpoints/preprocessor_config.json'
|
|
|
155 |
|
156 |
# Model URLs
|
157 |
model_urls = {
|
158 |
+
model_path: 'iclight_sd15_fbc.safetensors',
|
159 |
model_path2: 'depth_anything_v2_vits.pth',
|
160 |
model_path3: 'sam2_hiera_large.pt',
|
161 |
model_path4: 'config.json',
|
|
|
186 |
download_models()
|
187 |
|
188 |
|
189 |
+
# if not os.path.exists(model_path):
|
190 |
+
# download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
|
191 |
|
192 |
sd_offset = sf.load_file(model_path)
|
193 |
sd_origin = unet.state_dict()
|
|
|
953 |
|
954 |
block = gr.Blocks().queue()
|
955 |
with block:
|
956 |
+
with gr.Tab("Text", visible=False):
|
957 |
with gr.Row():
|
958 |
gr.Markdown("## Product Placement from Text")
|
959 |
with gr.Row():
|
|
|
1044 |
outputs=[extracted_objects, extracted_fg], show_progress=True
|
1045 |
)
|
1046 |
|
1047 |
+
with gr.Tab("Background", visible=True):
|
1048 |
# empty cache
|
1049 |
|
1050 |
mask_mover = MaskMover()
|