File size: 4,227 Bytes
fe51e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e5232f
fe51e27
 
 
 
 
 
 
 
 
 
 
 
 
 
a1bf242
fe51e27
 
 
 
 
 
 
 
 
 
 
000d89c
fe51e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import logging

import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from pinecone import Pinecone, ServerlessSpec

# Disable parallelism for tokenizers
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Vectorizer:
    def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
        """Set up the embedding model, the Pinecone index handle, and the prompt cache.

        Args:
            model_name: SentenceTransformer model to load for encoding prompts.
            batch_size: number of items per encode/fetch/upsert batch.
            init_pinecone: when True, ``_setup_pinecone`` is allowed to create
                the index if it does not already exist.
        """
        logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
        # Plain configuration first, then the expensive model load,
        # then the Pinecone round-trips.
        self.batch_size = batch_size
        self.prompts = []
        self.pinecone_index_name = "hfs-search-prompts-index"
        self._init_pinecone = init_pinecone
        self.model = SentenceTransformer(model_name)
        self._setup_pinecone()
        self._load_prompts()

    def _setup_pinecone(self):
        """Connect to Pinecone and attach ``self.index`` to the prompts index.

        Creates the index (768-dim cosine, AWS us-east-1 serverless) when it is
        missing and ``self._init_pinecone`` is True; otherwise reuses the
        existing index as-is.

        Raises:
            KeyError: if the PINECONE_API_KEY environment variable is not set.
        """
        logger.info("Setting up Pinecone")
        # SECURITY FIX: the API key was previously hard-coded in source control.
        # A committed key must be treated as leaked and rotated; read it from
        # the environment instead.
        pinecone = Pinecone(api_key=os.environ['PINECONE_API_KEY'])

        existing_indexes = pinecone.list_indexes()
        if self.pinecone_index_name not in existing_indexes:
            if self._init_pinecone:
                logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
                pinecone.create_index(
                    name=self.pinecone_index_name,
                    dimension=768,  # embedding size of all-mpnet-base-v2
                    metric='cosine',
                    spec=ServerlessSpec(
                        cloud="aws",
                        region="us-east-1"
                    )
                )
        else:
            # BUG FIX: the previous version deleted the existing index here and
            # then handed back a handle to the now-deleted index — destroying
            # all stored vectors and leaving self.index unusable. Reuse it.
            logger.info(f"Pinecone index {self.pinecone_index_name} already exists")

        self.index = pinecone.Index(self.pinecone_index_name)

    def _load_prompts(self):
        """Rebuild ``self.prompts`` from the 'text' metadata stored in Pinecone.

        Assumes vector ids are the stringified integers 0..vector_count-1 in
        each namespace — the id scheme used by ``_store_prompts``. TODO(review):
        confirm that assumption holds for every namespace in the index.
        """
        logger.info("Loading prompts from Pinecone")
        self.prompts = []
        # Fetch vectors from the Pinecone index
        index_stats = self.index.describe_index_stats()
        logger.info(f"Index stats: {index_stats}")

        for namespace, stats in index_stats['namespaces'].items():
            vector_count = stats['vector_count']
            ids = [str(i) for i in range(vector_count)]
            for start in range(0, vector_count, self.batch_size):
                batch_ids = ids[start:start + self.batch_size]
                # BUG FIX: fetch from the namespace being iterated. The old call
                # omitted the namespace, so every iteration read only the
                # default namespace and the others were silently skipped.
                response = self.index.fetch(ids=batch_ids, namespace=namespace)
                for vector in response.vectors.values():
                    metadata = vector.get('metadata')
                    if metadata and 'text' in metadata:
                        self.prompts.append(metadata['text'])
        logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")

    def _store_prompts(self, dataset):
        """Encode each text in *dataset* and upsert the vectors into Pinecone.

        Vector ids are the stringified positions 0..len(dataset)-1, the id
        scheme ``_load_prompts`` relies on when fetching them back.

        Args:
            dataset: sequence of prompt strings to embed and store.
        """
        logger.info("Storing prompts in Pinecone")
        # BUG FIX (ceil division): the old total, len // batch_size + 1,
        # overshot by one whenever len(dataset) was an exact multiple of the
        # batch size (e.g. reported 3 batches for 128 items at batch_size=64).
        total_batches = -(-len(dataset) // self.batch_size)
        for start in range(0, len(dataset), self.batch_size):
            batch = dataset[start:start + self.batch_size]
            vectors = self.model.encode(batch)
            # Prepare data for Pinecone: id, raw values, and the original text
            # kept in metadata so it can be recovered later.
            pinecone_data = [
                {'id': str(start + offset), 'values': vec.tolist(), 'metadata': {'text': batch[offset]}}
                for offset, vec in enumerate(vectors)
            ]
            self.index.upsert(vectors=pinecone_data)
            logger.info(f"Upserted batch {start // self.batch_size + 1}/{total_batches} to Pinecone")

    def transform(self, prompts):
        return np.array(self.model.encode(prompts))

    def store_from_dataset(self, store_data=False):
        """Download the laion-art prompts, index them, and refresh the cache.

        Does nothing unless *store_data* is True, so an accidental call cannot
        trigger the (slow) dataset download and re-indexing.
        """
        if not store_data:
            return
        logger.info("Loading dataset")
        dataset = load_dataset('fantasyfish/laion-art', split='train')
        logger.info(f"Loaded {len(dataset)} items from dataset")
        logger.info("Please wait for storing. This may take up to five minutes. ")
        self._store_prompts([item['text'] for item in dataset])
        logger.info("Items from dataset are stored.")
        # Re-read from Pinecone so self.prompts reflects what was just stored.
        self._load_prompts()
        logger.info("Items from dataset are loaded.")