import argparse
import os
import os.path as osp
import random
import sys

import torch
from tqdm import tqdm

sys.path.append('.')
from src.benchmarks import get_semistructured_data, get_qa_dataset
from src.tools.api import get_openai_embeddings


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='amazon',
                        choices=['amazon', 'primekg', 'mag'])
    parser.add_argument('--emb_model', default='text-embedding-ada-002',
                        choices=[
                            'text-embedding-ada-002',
                            'text-embedding-3-small',
                            'text-embedding-3-large'
                        ])
    parser.add_argument('--mode', default='doc', choices=['doc', 'query'])
    parser.add_argument('--emb_dir', default='emb/', type=str)
    parser.add_argument('--add_rel', action='store_true', default=False,
                        help='include relational information in the document text')
    parser.add_argument('--compact', action='store_true', default=False,
                        help='compact the text before it is sent to the model')
    return parser.parse_args()
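

# Example invocation (the script name is illustrative, not part of this
# file; the flags are those defined in parse_args above):
#   python generate_emb.py --dataset amazon --mode doc \
#       --emb_model text-embedding-3-small --add_rel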
if __name__ == '__main__':
    args = parse_args()
    emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, args.mode)
    os.makedirs(emb_dir, exist_ok=True)

    # Collect the IDs to embed and decide where the resulting dict is stored.
    if args.mode == 'doc':
        kb = get_semistructured_data(args.dataset)
        lst = list(kb.candidate_ids)  # copy so the shuffle below cannot mutate kb
        emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
    elif args.mode == 'query':
        qa_dataset = get_qa_dataset(args.dataset)
        lst = [qa_dataset[i][1] for i in range(len(qa_dataset))]
        emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
    random.shuffle(lst)
            
    # Resume from a previous run if a saved embedding dict already exists.
    if osp.exists(emb_path):
        emb_dict = torch.load(emb_path)
        existing_indices = set(emb_dict.keys())
        print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
    else:
        emb_dict = {}
        existing_indices = set()

    # Gather the raw text for every ID that does not have an embedding yet.
    texts, indices = [], []
    for idx in tqdm(lst):
        if idx in existing_indices:
            continue
        if args.mode == 'query':
            text = qa_dataset.get_query_by_qid(idx)
        elif args.mode == 'doc':
            text = kb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact)
        texts.append(text)
        indices.append(idx)
        
    if not texts:
        print('All embeddings already exist; nothing to do.')
        sys.exit(0)

    print(f'Generating embeddings for {len(texts)} texts...')
    embs = get_openai_embeddings(texts, model=args.emb_model).view(len(texts), -1).cpu()
    print('Embedding size:', embs.size())

    # Merge the new embeddings into the dict and persist everything in one file.
    for idx, emb in zip(indices, embs):
        emb_dict[idx] = emb
    torch.save(emb_dict, emb_path)
    print(f'Saved embeddings to {emb_path}!')
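
# A minimal sketch of downstream use, assuming the saved dict maps IDs to
# 1-D tensors of equal dimension; `query_emb` is a (d,)-shaped query
# embedding from the same model (all names and paths here are illustrative):
#
#   emb_dict = torch.load('emb/amazon/text-embedding-ada-002/doc/candidate_emb_dict.pt')
#   ids = list(emb_dict.keys())
#   cand_embs = torch.stack([emb_dict[i] for i in ids])  # (N, d)
#   sims = torch.nn.functional.cosine_similarity(
#       cand_embs, query_emb.unsqueeze(0), dim=-1)       # (N,)
#   best_id = ids[int(sims.argmax())]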