# prompt-search-engine / app/vectorizer.py
import os
import logging
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from pinecone import Pinecone, ServerlessSpec
# Disable parallelism for tokenizers
# (silences the HuggingFace tokenizers fork warning when workers are spawned
# after the tokenizer has been used -- presumably why it was set; confirm)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Configure logging: module-level logger, INFO and above to stderr.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Vectorizer:
    """Encode text prompts with a SentenceTransformer model and persist /
    retrieve them through a Pinecone serverless index.

    On construction the Pinecone index is created if missing and every prompt
    already stored in it is loaded into ``self.prompts``.
    """

    def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
        """Build the embedding model and connect to Pinecone.

        Args:
            model_name: SentenceTransformer checkpoint to load.
            batch_size: batch size used for both fetch and upsert round-trips.
            init_pinecone: when True, create the index if it does not exist.
        """
        logger.info("Initializing Vectorizer with model %s and batch size %s",
                    model_name, batch_size)
        self.model = SentenceTransformer(model_name)
        self.prompts = []
        self.batch_size = batch_size
        self.pinecone_index_name = "hfs-search-prompts-index"
        self._init_pinecone = init_pinecone
        self._setup_pinecone()
        self._load_prompts()

    def _setup_pinecone(self):
        """Connect to Pinecone and ensure the target index exists."""
        logger.info("Setting up Pinecone")
        # SECURITY: the API key was hard-coded in source. Read it from the
        # environment; the original literal is kept only as a backward-
        # compatible fallback -- rotate this leaked key and drop the fallback.
        api_key = os.environ.get('PINECONE_API_KEY',
                                 'b514eb66-8626-4697-8a1c-4c411c06c090')
        pinecone = Pinecone(api_key=api_key)
        # BUG FIX: list_indexes() returns an IndexList object, not a list of
        # names, so membership against it never matched; compare names instead.
        existing_indexes = pinecone.list_indexes().names()
        if self.pinecone_index_name not in existing_indexes:
            if self._init_pinecone:
                logger.info("Creating Pinecone index: %s", self.pinecone_index_name)
                pinecone.create_index(
                    name=self.pinecone_index_name,
                    dimension=768,  # output dimension of all-mpnet-base-v2
                    metric='cosine',
                    spec=ServerlessSpec(
                        cloud="aws",
                        region="us-east-1"
                    )
                )
        else:
            # BUG FIX: the original deleted the index here whenever it already
            # existed and then opened the just-deleted index, wiping all stored
            # vectors on every warm start. Reuse the existing index instead.
            logger.info("Pinecone index %s already exists", self.pinecone_index_name)
        self.index = pinecone.Index(self.pinecone_index_name)

    def _load_prompts(self):
        """Populate ``self.prompts`` with the 'text' metadata of every vector
        currently stored in the index.

        NOTE(review): assumes vector ids are the consecutive strings
        "0".."vector_count-1", matching how _store_prompts assigns them --
        vectors written with other ids would be missed.
        """
        logger.info("Loading prompts from Pinecone")
        self.prompts = []
        index_stats = self.index.describe_index_stats()
        logger.info("Index stats: %s", index_stats)
        namespaces = index_stats['namespaces']
        for namespace, stats in namespaces.items():
            vector_count = stats['vector_count']
            ids = [str(i) for i in range(vector_count)]
            # Fetch in batches to keep individual requests small.
            for i in range(0, vector_count, self.batch_size):
                batch_ids = ids[i:i + self.batch_size]
                response = self.index.fetch(ids=batch_ids)
                for vector in response.vectors.values():
                    metadata = vector.get('metadata')
                    if metadata and 'text' in metadata:
                        self.prompts.append(metadata['text'])
        logger.info("Loaded %d prompts from Pinecone", len(self.prompts))

    def _store_prompts(self, dataset):
        """Encode `dataset` (a list of prompt strings) batch-by-batch and
        upsert the vectors, with the source text as metadata.

        Ids are the global positions in `dataset`, so _load_prompts can
        re-fetch them as "0".."len(dataset)-1".
        """
        logger.info("Storing prompts in Pinecone")
        for i in range(0, len(dataset), self.batch_size):
            batch = dataset[i:i + self.batch_size]
            vectors = self.model.encode(batch)
            # Prepare data for Pinecone: id is the global offset i + j.
            pinecone_data = [{'id': str(i + j),
                              'values': vector.tolist(),
                              'metadata': {'text': batch[j]}}
                             for j, vector in enumerate(vectors)]
            self.index.upsert(vectors=pinecone_data)
            logger.info("Upserted batch %d/%d to Pinecone",
                        i // self.batch_size + 1,
                        len(dataset) // self.batch_size + 1)

    def transform(self, prompts):
        """Return the embeddings of `prompts` as a numpy array."""
        return np.array(self.model.encode(prompts))

    def store_from_dataset(self, store_data=False):
        """Optionally (re)populate the index from the laion-art dataset, then
        reload the prompt list from Pinecone.

        Args:
            store_data: when True, download the dataset and upsert every
                'text' field before reloading; when False, only reload.
        """
        if store_data:
            logger.info("Loading dataset")
            dataset = load_dataset('fantasyfish/laion-art', split='train')
            logger.info("Loaded %d items from dataset", len(dataset))
            logger.info("Please wait for storing. This may take up to five minutes. ")
            self._store_prompts([item['text'] for item in dataset])
            logger.info("Items from dataset are stored.")
        # Ensure prompts are loaded after storing
        self._load_prompts()
        logger.info("Items from dataset are loaded.")