BioTrove-CLIP-Demo / make_txt_embedding.py
BGLab's picture
BioTrove
2c39a46 verified
"""
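Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M. Besides the raw per-name embeddings,
it also writes an `_avgs` variant (higher ranks averaged from their descendants)
and a `_species` variant (species-level columns only, plus a JSON list of names).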
"""
import argparse
import csv
import json
import os
import logging
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
import lib
from templates import openai_imagenet_template
# Timestamped, level-tagged log records, emitted through the root logger.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)
logger = logging.getLogger()

# Identifiers handed to open_clip's create_model / get_tokenizer.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Taxonomic ranks from broadest to most specific; a name made of k parts
# belongs to ranks[k - 1].
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
@torch.no_grad()
def write_txt_features(name_lookup):
    """Compute one text embedding per name in ``name_lookup`` and save it.

    Every name is expanded with each template in ``openai_imagenet_template``.
    The per-template embeddings are L2-normalized, averaged and re-normalized
    into one column of a ``(512, len(name_lookup))`` matrix, which is saved to
    ``args.out_path`` with ``np.save``.

    The work is resumable. If ``args.out_path`` already exists it is loaded,
    and any batch whose columns already contain a non-zero value is skipped.
    A checkpoint is written every 100 batches and once more at the end.

    Relies on the module-level ``args``, ``model``, ``tokenizer`` and ``device``.

    Args:
        name_lookup: object whose ``values()`` are fed to ``lib.batched``;
            each batch is unpacked as ``(names, indices)``, where ``indices``
            are column positions in the output matrix.
    """
    n_templates = len(openai_imagenet_template)

    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)

    # args.batch_size counts prompts (names x templates), so divide by the
    # template count to get names per batch. Clamp to 1 so a small
    # --batch-size cannot produce an empty batch size.
    batch_size = max(1, args.batch_size // n_templates)
    # Ceiling division so the progress bar also counts the final partial batch.
    n_batches = (len(name_lookup) + batch_size - 1) // batch_size

    for batch, (names, indices) in enumerate(
        tqdm(
            lib.batched(name_lookup.values(), batch_size),
            desc="txt feats",
            total=n_batches,
        )
    ):
        # Resume support: any non-zero value means this batch was already done.
        if all_features[:, indices].any():
            logger.info(f"Skipping batch {batch}")
            continue

        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Prompts were built name-major, so this groups each name's template
        # embeddings along dim 1.
        txt_features = torch.reshape(txt_features, (len(names), n_templates, 512))
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        # Periodic checkpoint so an interrupted run can pick up where it stopped.
        if batch % 100 == 0:
            np.save(args.out_path, all_features)

    np.save(args.out_path, all_features)
def convert_txt_features_to_avgs(name_lookup):
    """Replace each non-species column with the mean of its deeper descendants.

    Loads the matrix saved at ``args.out_path`` (one column per node of
    ``name_lookup``). For every node above species level, it averages the
    columns of descendants whose names have at least 6 parts (genus or
    species), L2-normalizes the mean and stores it in the node's column.
    Nodes with no such descendants are set to zero. The result is saved to
    ``<out_path>_avgs<ext>``.

    Relies on the module-level ``args``, ``device``, ``ranks`` and ``logger``.

    Raises:
        FileNotFoundError: if ``args.out_path`` does not exist.
    """
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(f"Text features not found: {args.out_path}")

    # Move the whole matrix to `device` (the GPU when available) so the
    # per-node means are computed there.
    all_features = torch.from_numpy(np.load(args.out_path)).to(device)
    logger.info("Loaded text features from disk to %s.", device)

    # Bucket every (name, column) pair by rank: a k-part name sits at ranks[k - 1].
    names_by_rank = [set() for _ in ranks]
    for name, index in tqdm(name_lookup.values()):
        names_by_rank[len(name) - 1].add((name, index))

    zeroed = 0
    # Walk ranks bottom-up (Genus, Family, ..., Kingdom); species columns keep
    # their own embeddings. Columns are overwritten in place, so when a higher
    # rank is processed, any genus-level columns it averages over already hold
    # their averaged values from earlier in this loop.
    for rank_idx, rank in reversed(list(enumerate(ranks))):
        if rank == "Species":
            continue
        for name, index in tqdm(names_by_rank[rank_idx], desc=rank):
            # Columns of descendants with genus (6 parts) or species (7 parts) names.
            child_indices = [
                child_index
                for child_name, child_index in name_lookup.descendants(prefix=name)
                if len(child_name) >= 6
            ]
            if not child_indices:
                logger.warning("No species for %s.", " ".join(name))
                all_features[:, index] = 0.0
                zeroed += 1
                continue
            mean = all_features[:, child_indices].mean(dim=1)
            all_features[:, index] = F.normalize(mean, dim=0)

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())

    if zeroed:
        logger.warning(
            "Zeroed out %d nodes because they didn't have any genus or species-level labels.",
            zeroed,
        )
def convert_txt_features_to_species_only(name_lookup):
    """Extract the species-level columns into their own matrix and name list.

    Loads the matrix saved at ``args.out_path`` and keeps only the columns of
    names with 7 parts (species), in the order ``name_lookup.descendants()``
    yields them. It writes two files:

    - ``<out_path>_species<ext>``: float32 matrix, one column per species.
    - ``<out_path>_species.json``: the matching names, in the same order.

    Relies on the module-level ``args`` and ``logger``.

    Raises:
        FileNotFoundError: if ``args.out_path`` does not exist.
    """
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(f"Text features not found: {args.out_path}")

    all_features = np.load(args.out_path)
    logger.info("Loaded text features from disk.")

    # Species names carry all 7 rank components.
    species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
    species_names = [name for name, _ in species]
    old_indices = np.asarray([index for _, index in species], dtype=np.intp)

    # One fancy-indexing copy of the selected columns; cast to float32 to keep
    # the output dtype fixed regardless of the loaded file's dtype.
    species_features = all_features[:, old_indices].astype(np.float32)

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_species{ext}", species_features)
    with open(f"{out_path}_species.json", "w", encoding="utf-8") as fd:
        json.dump(species_names, fd, indent=2)
def get_name_lookup(catalog_path, cache_path):
    """Load the taxonomic name tree from ``cache_path``, or build and cache it.

    If ``cache_path`` exists, its JSON is converted with
    ``lib.TaxonomicTree.from_dict`` and returned.

    Otherwise each row of the CSV at ``catalog_path`` becomes a
    kingdom-to-species list, truncated at the first empty rank, and is added
    to a new ``lib.TaxonomicTree``. The tree is then written to
    ``cache_path`` so later runs can skip the CSV.

    Args:
        catalog_path: path to the TreeOfLife-10M ``catalog.csv``.
        cache_path: JSON cache file to read from, or to write to after building.

    Returns:
        The ``lib.TaxonomicTree`` of all names.
    """
    if os.path.isfile(cache_path):
        with open(cache_path, encoding="utf-8") as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    rank_columns = (
        "kingdom",
        "phylum",
        "class",
        "order",
        "family",
        "genus",
        "species",
    )

    lookup = lib.TaxonomicTree()
    # newline="" is what the csv module requires for correct quoted-field parsing.
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader, desc="catalog"):
            name = [row[column] for column in rank_columns]
            # Keep only the ranks before the first blank one. DictReader yields
            # None (not "") for fields missing from short rows, so test
            # falsiness instead of searching for "".
            for depth, value in enumerate(name):
                if not value:
                    name = name[:depth]
                    break
            lookup.add(name)

    # Cache to the caller-supplied path.
    with open(cache_path, "w", encoding="utf-8") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
    return lookup
if __name__ == "__main__":
    # NOTE: args, model and tokenizer are deliberately module-level globals;
    # the functions above read them directly.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default="name_lookup.json",
    )
    parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
    args = parser.parse_args()

    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    logger.info("Got name lookup.")

    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # Inference only: switch off training-mode behaviour such as dropout or
    # stochastic depth, since the model is not guaranteed to start in eval mode.
    model.eval()
    logger.info("Created model.")
    model = torch.compile(model)
    logger.info("Compiled model.")
    tokenizer = get_tokenizer(tokenizer_str)

    # Raw per-name embeddings first; the two conversions read that output file.
    write_txt_features(name_lookup)
    convert_txt_features_to_avgs(name_lookup)
    convert_txt_features_to_species_only(name_lookup)