import pinecone


class VectorDB:
    """Wrapper around a Pinecone index for storing and querying text passages.

    The wrapper owns two collaborators:

    * ``self.retreiver`` -- the encoder object supplied by the caller. It must
      provide ``get_sentence_embedding_dimension()`` and ``encode(...)`` whose
      result exposes ``.tolist()``. (The attribute keeps its original
      spelling so existing callers continue to work.)
    * ``self.index`` -- the ``pinecone.Index`` handle used for upserts and
      queries.
    """

    def __init__(self, retreiver, API_KEY,
                 index_name='wikiqav2-index', environment='us-east1-gcp'):
        """Initialise the Pinecone client and open (creating if needed) the index.

        Args:
            retreiver: Encoder used both to size the index and to embed text.
            API_KEY: Pinecone API key passed to ``pinecone.init``.
            index_name: Name of the index to open. Defaults to the name the
                class was originally hard-wired to.
            environment: Pinecone environment passed to ``pinecone.init``.
        """
        pinecone.init(api_key=API_KEY, environment=environment)
        self.retreiver = retreiver

        # Only create the index on first use; an existing index is reused as-is.
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                dimension=self.retreiver.get_sentence_embedding_dimension(),
                metric='cosine',
            )
        self.index = pinecone.Index(index_name)

    def upsert_data(self, article_data, batch_size=10):
        """Embed every record in ``article_data`` and upsert it in batches.

        Each record is expected to be a mapping with ``'id'``, ``'text'`` and
        ``'section'`` keys. As a side effect, the computed embedding is stored
        back on the record under ``'encoding'``.

        Args:
            article_data: Sequence of record mappings (iterated twice).
            batch_size: Maximum number of vectors sent per ``upsert`` call.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # range() with a negative step would silently upsert nothing.
            raise ValueError('batch_size must be a positive integer')

        for article in article_data:
            article['encoding'] = self.retreiver.encode(article['text']).tolist()

        # Vectors are (id, values, metadata) tuples; ids are sent as strings.
        upserts = [
            (
                str(article['id']),
                article['encoding'],
                {'text': article['text'], 'section': article['section']},
            )
            for article in article_data
        ]

        # Slicing clamps at the end of the list, so the last batch may be short.
        for start in range(0, len(upserts), batch_size):
            self.index.upsert(vectors=upserts[start:start + batch_size])

    def get_contexts(self, question, top_k=1):
        """Query the index for the passages closest to ``question``.

        Args:
            question: Text to embed and search for.
            top_k: Number of matches requested from the index.

        Returns:
            The response object returned by ``self.index.query`` (metadata
            included).
        """
        # The question is encoded as a one-element batch, so ``xq`` is a nested
        # list; it is passed positionally exactly as the original code did.
        xq = self.retreiver.encode([question]).tolist()
        return self.index.query(xq, top_k=top_k, include_metadata=True)