# Hugging Face Spaces page header: "Running on Zero"
# Imports
import pdb  # NOTE(review): not used anywhere in this script
import time

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation  # only needed by Option 1

import tqlt.utils as tu
from models.birefnet import BiRefNet
from utils import check_state_dict

# There are three ways to obtain a BiRefNet model. Options 1 and 2 are kept
# for reference; Option 3 (local checkpoint) is the one that runs below.
#
# Option 1: load architecture and weights from the Hugging Face Hub.
# birefnet = AutoModelForImageSegmentation.from_pretrained(
#     "zhengpeng7/BiRefNet", trust_remote_code=True
# )
#
# Option 2: load weights with the BiRefNet code from this repository.
# birefnet = BiRefNet.from_pretrained('zhengpeng7/BiRefNet')

# Input images. Presumably every ".png" file under ./in_the_wild; the exact
# semantics of tqlt.utils.next_files are not visible here. TODO confirm.
imgs = tu.next_files("./in_the_wild", ".png")

# Option 3 (active): construct the network locally, then load a checkpoint
# from disk. bb_pretrained=False presumably skips fetching separate backbone
# weights, since the checkpoint below supplies all parameters. TODO confirm.
birefnet = BiRefNet(bb_pretrained=False)
# torch.load unpickles the file: only load checkpoints from a trusted source.
# map_location="cpu" lets the checkpoint load even where it was saved on GPU.
checkpoint = torch.load("./BiRefNet-general-epoch_244.pth", map_location="cpu")
# check_state_dict (local utils) presumably normalises the checkpoint's keys
# so they match the model's parameter names. Behavior not visible here.
state_dict = check_state_dict(checkpoint)
birefnet.load_state_dict(state_dict)
# Load Model: place the network on the target device in evaluation mode.
device = "cuda"
# "high" lets float32 matmuls use faster reduced-precision internals
# (e.g. TensorFloat32) where available, instead of the "highest"-precision
# default.
torch.set_float32_matmul_precision("high")
# nn.Module.to moves parameters in place and returns the module, so the
# eval-mode switch can be chained onto it.
birefnet.to(device).eval()
print("BiRefNet is ready to use.")
# Input Data: preprocessing applied to every image before inference.
_INPUT_SIZE = (1024, 1024)  # fixed (height, width) fed to the network
# Standard ImageNet per-channel mean / standard deviation (R, G, B).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize -> float tensor in [0, 1] (C, H, W) -> per-channel normalisation.
transform_image = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
# Inference: segment each input image, then save the mask and the cut-out.
import os
from glob import glob

from image_proc import refine_foreground

src_dir = "./images_todo"
# NOTE(review): `image_paths` is collected but never consumed. The loop below
# iterates `imgs` (the ./in_the_wild PNG list built earlier). Iterate
# `image_paths` instead if ./images_todo is the intended input.
image_paths = glob(os.path.join(src_dir, "*"))
dst_dir = "./predictions"
os.makedirs(dst_dir, exist_ok=True)


def _output_path(image_path, out_dir, suffix):
    """Return ``out_dir/<stem><suffix>`` for the input file ``image_path``.

    The basename is used so results always land in ``out_dir``, whichever
    directory the input came from. Only the final extension is replaced,
    so an extension-like substring elsewhere in the path is left untouched.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(out_dir, stem + suffix)


for image_path in imgs:
    print("Processing {} ...".format(image_path))
    # Force 3-channel RGB: PNGs are often RGBA, palette or greyscale, and the
    # 3-channel Normalize in transform_image fails on other channel counts.
    # The `with` block closes the file that Image.open keeps open lazily.
    with Image.open(image_path) as src_image:
        image = src_image.convert("RGB")
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Prediction: sigmoid squashes the model's last output into (0, 1).
    start = time.perf_counter()
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # The .cpu() copy blocks until the device result is ready, so this
    # interval covers the whole forward pass.
    print(time.perf_counter() - start)
    pred = preds[0].squeeze()

    # Save Results: mask resized back to the input resolution, plus the
    # output of refine_foreground (image_proc; behavior not visible here)
    # with the mask attached as its alpha channel.
    pred_pil = transforms.ToPILImage()(pred).resize(image.size)
    pred_pil.save(_output_path(image_path, dst_dir, "-mask.png"))
    image_masked = refine_foreground(image, pred_pil)
    image_masked.putalpha(pred_pil)
    image_masked.save(_output_path(image_path, dst_dir, "-subject.png"))