Spaces:
svjack
/
Runtime error

File size: 2,481 Bytes
c614b0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Imports
import pdb
import time

import torch
import tqlt.utils as tu
from models.birefnet import BiRefNet
from PIL import Image
from torchvision import transforms

# # Option 1: loading BiRefNet with weights:
from transformers import AutoModelForImageSegmentation

# # Option-3: Loading model and weights from local disk:
from utils import check_state_dict

# Other ways to obtain the model (reference only, not executed):
#   Option 1 - download through transformers from the Hugging Face Hub:
#     birefnet = AutoModelForImageSegmentation.from_pretrained(
#         "zhengpeng7/BiRefNet", trust_remote_code=True
#     )
#   Option 2 - BiRefNet's own from_pretrained helper:
#     birefnet = BiRefNet.from_pretrained('zhengpeng7/BiRefNet')

# Inputs for the prediction loop further down; presumably the ".png" files
# found under ./in_the_wild (semantics of tu.next_files not visible here).
imgs = tu.next_files("./in_the_wild", ".png")


# Option 3 (active): build the network locally and load a checkpoint from disk.
# bb_pretrained=False presumably skips fetching pretrained backbone weights,
# since the full checkpoint loaded below overwrites them anyway.
birefnet = BiRefNet(bb_pretrained=False)
# Deserialize onto the CPU so loading works without a GPU, then pass the raw
# dict through the project helper check_state_dict before applying it.
state_dict = check_state_dict(
    torch.load("./BiRefNet-general-epoch_244.pth", map_location="cpu")
)
birefnet.load_state_dict(state_dict)


# Load Model
# Prefer the GPU but fall back to the CPU when CUDA is unavailable, instead of
# crashing in birefnet.to("cuda") on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Index 0 selects "high", which permits faster reduced-precision (e.g. TF32)
# float32 matmuls where supported; use index 1 ("highest") for full precision.
torch.set_float32_matmul_precision(["high", "highest"][0])

birefnet.to(device)
# Inference mode for layers such as dropout / batch-norm.
birefnet.eval()
print("BiRefNet is ready to use.")

# Input Data
# Preprocessing applied to every input image:
#   1. resize to a fixed 1024x1024 resolution,
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the usual ImageNet mean / std values.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


import os
from glob import glob

from image_proc import refine_foreground

src_dir = "./images_todo"
# NOTE: image_paths is not consumed by the loop below (which iterates `imgs`);
# it is kept so the module-level name remains available.
image_paths = glob(os.path.join(src_dir, "*"))
dst_dir = "./predictions"
os.makedirs(dst_dir, exist_ok=True)

for image_path in imgs:
    print("Processing {} ...".format(image_path))

    # Force 3-channel RGB: the 3-value Normalize in transform_image fails on
    # RGBA or single-channel PNGs. convert() loads the pixels, so the file
    # handle can be closed by the context manager right away.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Prediction: take the model's last output and map logits to [0, 1].
    start = time.perf_counter()
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    print(time.perf_counter() - start)
    pred = preds[0].squeeze()

    # Save Results into dst_dir, named after the input file's stem. Inputs come
    # from `imgs`, not src_dir, so the paths are built explicitly rather than
    # by substituting a directory prefix into image_path.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    mask_path = os.path.join(dst_dir, stem + "-mask.png")
    subject_path = os.path.join(dst_dir, stem + "-subject.png")

    # Mask scaled back to the original image size (model input was 1024x1024).
    pred_pil = transforms.ToPILImage()(pred).resize(image.size)
    pred_pil.save(mask_path)

    # Refined foreground with the predicted mask attached as alpha channel.
    image_masked = refine_foreground(image, pred_pil)
    image_masked.putalpha(pred_pil)
    image_masked.save(subject_path)