Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/supertskone/prompt-search-engine
aad4def
unverified
import os | |
import logging | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from datasets import load_dataset | |
from pinecone import Pinecone, ServerlessSpec | |
# Switch off the tokenizers library's internal parallelism via its env flag.
os.environ.update(TOKENIZERS_PARALLELISM='false')

# Root logging at INFO level, plus the module logger used by Vectorizer.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class Vectorizer:
    """Embed text prompts with a SentenceTransformer and persist them in Pinecone.

    On construction the model is loaded, a handle to the Pinecone index
    ``hfs-search-prompts-index`` is opened (the index is created first when it
    is missing and ``init_pinecone`` is true), and the prompt texts already
    stored in that index are loaded into ``self.prompts``.

    Args:
        model_name: SentenceTransformer model to load.
        batch_size: Number of items encoded/upserted/fetched per request.
        init_pinecone: Whether a missing index may be created.
        api_key: Pinecone API key. When ``None``, the ``PINECONE_API_KEY``
            environment variable is used.

    Raises:
        ValueError: If no Pinecone API key is available.
    """

    def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True, api_key=None):
        logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
        self.model = SentenceTransformer(model_name)
        self.prompts = []  # prompt texts loaded from the index
        self.batch_size = batch_size
        self.pinecone_index_name = "hfs-search-prompts-index"
        self._init_pinecone = init_pinecone
        self._api_key = api_key
        self._setup_pinecone()
        self._load_prompts()

    @staticmethod
    def _num_batches(total, batch_size):
        """Return how many batches of ``batch_size`` are needed to cover ``total`` items."""
        return (total + batch_size - 1) // batch_size

    @staticmethod
    def _index_names(index_list):
        """Return the set of index names in a ``Pinecone.list_indexes()`` result.

        Uses the result's ``names()`` method when it has one; otherwise treats
        the argument as a plain iterable of index names.
        """
        names = getattr(index_list, 'names', None)
        return set(names() if callable(names) else index_list)

    def _setup_pinecone(self):
        """Connect to Pinecone and bind ``self.index`` to the prompt index.

        An existing index is reused as-is, so its stored vectors are kept. A
        missing index is created only when ``init_pinecone`` was true.

        Raises:
            ValueError: If no API key was passed and ``PINECONE_API_KEY`` is unset.
        """
        logger.info("Setting up Pinecone")
        # Credentials must not live in source. The key that used to be
        # hard-coded here has been published and should be revoked.
        api_key = self._api_key or os.environ.get('PINECONE_API_KEY')
        if not api_key:
            raise ValueError(
                "A Pinecone API key is required: pass api_key=... or set PINECONE_API_KEY"
            )
        pinecone = Pinecone(api_key=api_key)

        # list_indexes() yields index descriptions, not strings, so membership
        # must be tested against the index names.
        existing_indexes = self._index_names(pinecone.list_indexes())
        if self.pinecone_index_name in existing_indexes:
            # Reuse the index. Deleting it here would discard the stored prompts
            # and leave self.index pointing at a removed index.
            logger.info(f"Pinecone index {self.pinecone_index_name} already exists")
        elif self._init_pinecone:
            logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
            # Size the index from the model so a non-default model_name works.
            # 768 is the dimension that was previously hard-coded.
            dimension = self.model.get_sentence_embedding_dimension() or 768
            pinecone.create_index(
                name=self.pinecone_index_name,
                dimension=dimension,
                metric='cosine',
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1"
                )
            )
        else:
            logger.warning(
                f"Pinecone index {self.pinecone_index_name} does not exist and init_pinecone is False"
            )
        self.index = pinecone.Index(self.pinecone_index_name)

    def _load_prompts(self):
        """Rebuild ``self.prompts`` from the ``text`` metadata of stored vectors.

        Vector IDs are assumed to be the contiguous strings "0".."count-1"
        written by ``_store_prompts``. Vectors stored under other IDs are not
        found, and vectors without ``text`` metadata are skipped.
        """
        logger.info("Loading prompts from Pinecone")
        self.prompts = []
        index_stats = self.index.describe_index_stats()
        logger.info(f"Index stats: {index_stats}")
        for namespace, stats in index_stats['namespaces'].items():
            vector_count = stats['vector_count']
            ids = [str(i) for i in range(vector_count)]
            for start in range(0, vector_count, self.batch_size):
                batch_ids = ids[start:start + self.batch_size]
                # Read from the namespace whose count is being walked. Without
                # it, every iteration would fetch from the default namespace.
                response = self.index.fetch(ids=batch_ids, namespace=namespace)
                # Walk batch_ids rather than the response dict so prompts come
                # out in ID order and missing IDs are skipped.
                for vector_id in batch_ids:
                    vector = response.vectors.get(vector_id)
                    if vector is None:
                        continue
                    metadata = vector.get('metadata')
                    if metadata and 'text' in metadata:
                        self.prompts.append(metadata['text'])
        logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")

    def _store_prompts(self, dataset):
        """Encode ``dataset`` (a sequence of texts) and upsert it in batches.

        Each item is stored with ID ``str(position)`` and its text as
        ``metadata['text']``, which is the layout ``_load_prompts`` reads back.
        """
        logger.info("Storing prompts in Pinecone")
        total_batches = self._num_batches(len(dataset), self.batch_size)
        for batch_number, start in enumerate(range(0, len(dataset), self.batch_size), start=1):
            batch = dataset[start:start + self.batch_size]
            vectors = self.model.encode(batch)
            pinecone_data = [
                {'id': str(start + offset), 'values': vector.tolist(), 'metadata': {'text': batch[offset]}}
                for offset, vector in enumerate(vectors)
            ]
            self.index.upsert(vectors=pinecone_data)
            logger.info(f"Upserted batch {batch_number}/{total_batches} to Pinecone")

    def transform(self, prompts):
        """Return the model embeddings of ``prompts`` as a NumPy array."""
        return np.array(self.model.encode(prompts))

    def store_from_dataset(self, store_data=False):
        """Optionally upload the ``fantasyfish/laion-art`` prompts, then reload prompts.

        Args:
            store_data: When true, download the dataset's ``train`` split and
                upsert every item's ``text`` field before reloading. When
                false, only reload ``self.prompts`` from the index.
        """
        if store_data:
            logger.info("Loading dataset")
            dataset = load_dataset('fantasyfish/laion-art', split='train')
            logger.info(f"Loaded {len(dataset)} items from dataset")
            logger.info("Please wait for storing. This may take up to five minutes. ")
            self._store_prompts([item['text'] for item in dataset])
            logger.info("Items from dataset are stored.")
        # Ensure prompts are loaded after storing
        self._load_prompts()
        logger.info("Items from dataset are loaded.")