import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
# Make the project root importable so the `src.*` modules below resolve when
# this script is run from its own subdirectory.
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
project_root = os.path.abspath(os.path.join(script_dir, "..", ".."))
sys.path.append(project_root)
from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_blip_config(model="base"):
    """Return the checkpoint URL and hyperparameters for the given BLIP size."""
    config = dict()
    if model == "base":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        )
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        )
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    # Settings shared by both model sizes.
    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True
    return config
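
# Example (derived directly from the branches above):
#   get_blip_config("base")["vit_ckpt_layer"]   # -> 4
#   get_blip_config("large")["vit_ckpt_layer"]  # -> 12
# Any other value for `model` returns only the shared settings, without a
# "pretrained" entry.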
def main(args):
    dataset = ImageDataset(
        image_dir=args.image_dir,
        img_ext=args.img_ext,
        save_dir=args.save_dir,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    print("Creating model")
    config = get_blip_config(args.model_type)
    model = blip_embs(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
    )
    model = model.to(device)
    model.eval()
    # Embed every batch under no_grad: this is inference only, so skipping
    # autograd bookkeeping saves memory. Take the ViT [CLS] token, project it
    # into the retrieval space, and L2-normalize so dot products between saved
    # features are cosine similarities.
    with torch.no_grad():
        for imgs, video_ids in tqdm(loader):
            imgs = imgs.to(device)
            img_embs = model.visual_encoder(imgs)
            img_feats = F.normalize(model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

            for img_feat, video_id in zip(img_feats, video_ids):
                torch.save(img_feat, args.save_dir / f"{video_id}.pth")
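
# Illustrative sketch, not invoked by this script: because the saved features
# are unit-norm, ranking them against a query reduces to dot products. The
# query tensor is assumed to come from the matching projection of the same
# BLIP model.
def rank_by_similarity(query_feat: torch.Tensor, emb_dir: Path) -> list:
    """Return (video_id, cosine similarity) pairs sorted from best to worst."""
    ids, feats = [], []
    for path in sorted(emb_dir.glob("*.pth")):
        ids.append(path.stem)
        feats.append(torch.load(path))
    sims = torch.stack(feats) @ query_feat  # cosine similarity: unit-norm vectors
    order = sims.argsort(descending=True)
    return [(ids[i], sims[i].item()) for i in order]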
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_dir", type=Path, required=True, help="Path to image directory"
    )
    parser.add_argument("--save_dir", type=Path)
    parser.add_argument("--img_ext", type=str, default="png")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--model_type", type=str, default="large", choices=["base", "large"]
    )
    args = parser.parse_args()
    # If --image_dir has no subdirectories, embed it directly; otherwise treat
    # each subdirectory as its own image set and mirror its name under the
    # output directory. Either way, save_dir is recomputed from image_dir.
    subdirectories = [subdir for subdir in args.image_dir.iterdir() if subdir.is_dir()]
    if len(subdirectories) == 0:
        args.save_dir = args.image_dir.parent / f"blip-embs-{args.model_type}"
        args.save_dir.mkdir(exist_ok=True)
        main(args)
    else:
        for subdir in subdirectories:
            args.image_dir = subdir
            args.save_dir = (
                subdir.parent.parent / f"blip-embs-{args.model_type}" / subdir.name
            )
            args.save_dir.mkdir(exist_ok=True, parents=True)
            main(args)
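
# Example invocation (script name and paths are hypothetical):
#   python compute_embs.py --image_dir data/frames --model_type large
# For a flat data/frames/ directory this writes one <video_id>.pth feature
# file per image into data/blip-embs-large/.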