# tools/embs/save_blip_embs_vids.py
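"""Precompute BLIP visual embeddings for a directory of videos.

Each video is sampled into a fixed number of frames, the frames are passed
through the BLIP visual encoder, and the resulting features are written to one
``<video_id>.pth`` file per video. The work can be split across processes
with ``--num_shards`` / ``--shard_id``.
"""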
import os
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

# Make the repository root importable so the ``src`` package resolves no
# matter where this script is launched from.
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
project_root = os.path.abspath(os.path.join(script_dir, "..", ".."))
sys.path.append(project_root)

from src.data.embs import VideoDataset
from src.model.blip_embs import blip_embs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        )
        # Alternative local checkpoint:
        # config["pretrained"] = "/linkhome/rech/genuvt01/ucp99db/.cache/torch/hub/checkpoints/model_base_retrieval_coco.pth"
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        )
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    # Settings shared by both model sizes.
    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True
    return config
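

# Background: with image_size = 384 and 16x16 ViT patches, the encoder emits
# (384 / 16) ** 2 + 1 = 577 tokens per frame; the token width is 768 for the
# base ViT and 1024 for the large one, so per-token feature shapes depend on
# the chosen encoder.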


@torch.no_grad()
def main(args):
    save_tokens = "tokens-" if args.save_all_tokens else ""
    save_dir = (
        args.video_dir.parent / f"blip-vid-embs-{save_tokens}{args.model_type}-all"
    )
    save_dir.mkdir(exist_ok=True)

    dataset = VideoDataset(
        video_dir=args.video_dir,
        todo_ids=args.todo_ids,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        frames_video=args.frames_video,
        save_dir=save_dir,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    print(f"Creating model {args.model_type}")
    config = get_blip_config(args.model_type)
    model = blip_embs(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
    )
    model = model.to(device)
    model.eval()
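
    # Each batch is assumed (from the unpacking below) to be
    # (video_ids, f_idxs, frames): frames has shape (bs, nf, c, h, w) and
    # f_idxs is padded with -1 where a video yielded fewer frames.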
    for video_ids, f_idxs, frames in tqdm(loader):
        frames = frames.to(device)
        # Fold the frame axis into the batch so all frames are encoded at once.
        bs, nf, c, h, w = frames.shape
        frames = frames.view(bs * nf, c, h, w)
        frm_embs = model.visual_encoder(frames)
        if args.save_all_tokens:
            # Keep every ViT token. The (num_tokens, dim) grid is inferred from
            # the encoder output rather than hardcoded as 577 x 1024, which
            # only holds for the large model.
            frm_feats = frm_embs.cpu()
            frm_feats = frm_feats.view(bs, nf, *frm_embs.shape[1:])
        else:
            # Keep only the projected, L2-normalized [CLS] feature per frame.
            frm_feats = F.normalize(model.vision_proj(frm_embs[:, 0, :]), dim=-1).cpu()
            frm_feats = frm_feats.view(bs, nf, -1)

        for video_id, f_idx, frm_feat in zip(video_ids, f_idxs, frm_feats):
            # Drop padded entries, marked with f_idx == -1.
            frm_feat = frm_feat[f_idx > -1]
            f_idx = f_idx[f_idx > -1]
            if len(f_idx) == 0:
                continue
            save_pth = save_dir / f"{video_id}.pth"
            if save_pth.exists():
                continue
            save_pth.parent.mkdir(exist_ok=True)
            torch.save(frm_feat, save_pth)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--video_dir", type=Path, default="datasets/WebVid/8M/train/")
    parser.add_argument("--todo_ids", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--model_type", type=str, default="large", choices=["base", "large"]
    )
    parser.add_argument("--num_shards", type=int, default=1)
    parser.add_argument("--shard_id", type=int, default=0)
    parser.add_argument("--frames_video", type=int, default=15)
    parser.add_argument("--save_all_tokens", action="store_true")
    args = parser.parse_args()

    assert args.video_dir.exists(), f"{args.video_dir} does not exist"
    main(args)
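
# Example invocation (paths are illustrative; adapt to your layout):
#   python tools/embs/save_blip_embs_vids.py \
#       --video_dir datasets/WebVid/8M/train/ \
#       --model_type large --batch_size 8 --frames_video 15 \
#       --num_shards 4 --shard_id 0
# Outputs land in <video_dir>/../blip-vid-embs-large-all/<video_id>.pth
# (blip-vid-embs-tokens-large-all/ when --save_all_tokens is given).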