import contextlib
import json

import torch
import open_clip
from tqdm import tqdm

# Prefer CUDA, fall back to Apple-silicon MPS, then CPU. Keeping the device as a
# plain string avoids comparing a torch.device object against "cpu" below.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu" and torch.backends.mps.is_available():
    device = "mps"
def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
    """Embed every text with the given open_clip model and return a {text: vector} cache."""
    model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)
    cache = {}
    # Mixed precision is only worthwhile on CUDA; elsewhere run in full precision.
    autocast = torch.autocast("cuda") if device == "cuda" else contextlib.nullcontext()
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        tokens = tokenizer(batch).to(device)
        with torch.no_grad(), autocast:
            # normalize=True yields unit-length embeddings, so dot products are cosine similarities.
            embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
        for text, embedding in zip(batch, embeddings):
            cache[text] = embedding.tolist()
    return cache
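
# Sketch only: one way the cached vectors could later be consumed for zero-shot
# classification. The file name and the image-embedding step are assumptions for
# illustration, not something this script does.
#
#   import numpy as np
#   with open("marqo-ecommerce-embeddings-B.json") as f:
#       cache = json.load(f)
#   labels = list(cache)
#   text_matrix = np.array([cache[label] for label in labels])  # (N, D), unit-norm rows
#   # Given a unit-norm image embedding image_vec of shape (D,):
#   # scores = text_matrix @ image_vec
#   # predicted = labels[int(scores.argmax())]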
def flatten_taxonomy(taxonomy: dict) -> list[str]:
    """Recursively collect every category name in a nested taxonomy into a flat list."""
    classes = []
    for key, value in taxonomy.items():
        classes.append(key)
        if isinstance(value, dict):
            # Subtree of subcategories: recurse.
            classes.extend(flatten_taxonomy(value))
        elif isinstance(value, list):
            # Leaf level: a list of category-name strings.
            classes.extend(value)
    return classes
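
# Illustrative only (the exact structure of amazon.json is an assumption): given
#   {"Electronics": {"Audio": ["Headphones", "Speakers"]}}
# flatten_taxonomy returns
#   ["Electronics", "Audio", "Headphones", "Speakers"]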
def main():
    models = [
        "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
        "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
        "ViT-B-16",
    ]
    with open("amazon.json") as f:
        taxonomy = json.load(f)
    print("Loaded taxonomy")
    print("Flattening taxonomy")
    texts = flatten_taxonomy(taxonomy)
    print("Generating cache")
    for model_name in models:
        cache = generate_cache(texts, model_name)
        # Name each cache file after its model, e.g. "marqo-ecommerce-embeddings-B.json".
        with open(f'{model_name.split("/")[-1]}.json', "w") as f:
            json.dump(cache, f)
if __name__ == "__main__":
    main()