Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,889 Bytes
0c3992e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import sys
import os
import os.path as osp
import torch
import random
import json
import time
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
sys.path.append('.')
from src.benchmarks import get_semistructured_data, get_qa_dataset
from src.tools.api import get_openai_embeddings
def parse_args():
    """Parse command-line options for the embedding-generation script.

    Returns an ``argparse.Namespace`` with: dataset name, embedding model,
    mode (embed candidate docs or queries), output directory, and two
    flags controlling how document text is rendered.
    """
    ap = argparse.ArgumentParser()

    ap.add_argument('--dataset', default='amazon',
                    choices=['amazon', 'primekg', 'mag'])

    # OpenAI embedding model used to encode the texts.
    model_choices = [
        'text-embedding-ada-002',
        'text-embedding-3-small',
        'text-embedding-3-large'
    ]
    ap.add_argument('--emb_model', default='text-embedding-ada-002',
                    choices=model_choices)

    ap.add_argument('--mode', default='doc', choices=['doc', 'query'])
    ap.add_argument("--emb_dir", default="emb/", type=str)

    # Text-rendering flags (only meaningful for --mode doc).
    ap.add_argument('--add_rel', action='store_true', default=False,
                    help='add relation to the text')
    ap.add_argument('--compact', action='store_true', default=False,
                    help='make the text compact when input to the model')

    return ap.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Embeddings are cached per dataset/model/mode, e.g. emb/amazon/<model>/doc/.
    emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, args.mode)
    os.makedirs(emb_dir, exist_ok=True)

    if args.mode == 'doc':
        kb = get_semistructured_data(args.dataset)
        lst = kb.candidate_ids
        emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
    if args.mode == 'query':
        qa_dataset = get_qa_dataset(args.dataset)
        # NOTE(review): assumes qa_dataset[i] is a tuple whose second element
        # is the query id — confirm against the dataset class.
        lst = [qa_dataset[i][1] for i in range(len(qa_dataset))]
        emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
    random.shuffle(lst)

    # Resume support: reload any previously saved embeddings and skip their ids.
    if osp.exists(emb_path):
        emb_dict = torch.load(emb_path)
        # Use a set — the original kept a list, making the membership test in
        # the loop below O(n) per id (O(n^2) overall on a resumed large run).
        existing_indices = set(emb_dict.keys())
        print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
    else:
        emb_dict = {}
        existing_indices = set()

    # Gather the texts that still need embeddings, keeping ids aligned.
    texts, indices = [], []
    for idx in tqdm(lst):
        if idx in existing_indices:
            continue
        if args.mode == 'query':
            text = qa_dataset.get_query_by_qid(idx)
        elif args.mode == 'doc':
            text = kb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact)
        texts.append(text)
        indices.append(idx)

    if not texts:
        # Guard: with an empty batch, .view(0, -1) below would raise.
        print('All embeddings already exist. Nothing to do.')
        sys.exit(0)

    print(f'Generating embeddings for {len(texts)} texts...')
    embs = get_openai_embeddings(texts, model=args.emb_model).view(len(texts), -1).cpu()
    print('Embedding size:', embs.size())
    for idx, emb in zip(indices, embs):
        emb_dict[idx] = emb
    torch.save(emb_dict, emb_path)
    print(f'Saved embeddings to {emb_path}!')
|