Spaces:
Sleeping
Sleeping
Delete app/vectorizer.py
Browse files- app/vectorizer.py +0 -98
app/vectorizer.py
DELETED
@@ -1,98 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import logging
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
from sentence_transformers import SentenceTransformer
|
6 |
-
from datasets import load_dataset
|
7 |
-
from pinecone import Pinecone, ServerlessSpec
|
8 |
-
|
9 |
-
# Disable parallelism for tokenizers
|
10 |
-
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
11 |
-
|
12 |
-
# Configure logging
|
13 |
-
logging.basicConfig(level=logging.INFO)
|
14 |
-
logger = logging.getLogger(__name__)
|
15 |
-
|
16 |
-
|
17 |
-
class Vectorizer:
|
18 |
-
def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
|
19 |
-
logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
|
20 |
-
self.model = SentenceTransformer(model_name)
|
21 |
-
self.prompts = []
|
22 |
-
self.batch_size = batch_size
|
23 |
-
self.pinecone_index_name = "prompts-index"
|
24 |
-
self._init_pinecone = init_pinecone
|
25 |
-
self._setup_pinecone()
|
26 |
-
self._load_prompts()
|
27 |
-
|
28 |
-
def _setup_pinecone(self):
|
29 |
-
logger.info("Setting up Pinecone")
|
30 |
-
# Initialize Pinecone
|
31 |
-
pinecone = Pinecone(api_key='b514eb66-8626-4697-8a1c-4c411c06c090')
|
32 |
-
# Check if the Pinecone index exists, if not create it
|
33 |
-
existing_indexes = pinecone.list_indexes()
|
34 |
-
|
35 |
-
logger.info(f"self.init_pineconeself.init_pineconeself"
|
36 |
-
f".init_pineconeself.init_pineconeself.init_pinecone: {self._init_pinecone}")
|
37 |
-
if self.pinecone_index_name not in existing_indexes:
|
38 |
-
logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
|
39 |
-
if self._init_pinecone:
|
40 |
-
pinecone.create_index(
|
41 |
-
name=self.pinecone_index_name,
|
42 |
-
dimension=768,
|
43 |
-
metric='cosine',
|
44 |
-
spec=ServerlessSpec(
|
45 |
-
cloud="aws",
|
46 |
-
region="us-east-1"
|
47 |
-
)
|
48 |
-
)
|
49 |
-
else:
|
50 |
-
logger.info(f"Pinecone index {self.pinecone_index_name} already exists")
|
51 |
-
|
52 |
-
self.index = pinecone.Index(self.pinecone_index_name)
|
53 |
-
|
54 |
-
def _load_prompts(self):
|
55 |
-
logger.info("Loading prompts from Pinecone")
|
56 |
-
self.prompts = []
|
57 |
-
# Fetch vectors from the Pinecone index
|
58 |
-
index_stats = self.index.describe_index_stats()
|
59 |
-
logger.info(f"Index stats: {index_stats}")
|
60 |
-
|
61 |
-
namespaces = index_stats['namespaces']
|
62 |
-
for namespace, stats in namespaces.items():
|
63 |
-
vector_count = stats['vector_count']
|
64 |
-
ids = [str(i) for i in range(vector_count)]
|
65 |
-
for i in range(0, vector_count, self.batch_size):
|
66 |
-
batch_ids = ids[i:i + self.batch_size]
|
67 |
-
response = self.index.fetch(ids=batch_ids)
|
68 |
-
for vector in response.vectors.values():
|
69 |
-
metadata = vector.get('metadata')
|
70 |
-
if metadata and 'text' in metadata:
|
71 |
-
self.prompts.append(metadata['text'])
|
72 |
-
logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")
|
73 |
-
|
74 |
-
def _store_prompts(self, dataset):
|
75 |
-
logger.info("Storing prompts in Pinecone")
|
76 |
-
for i in range(0, len(dataset), self.batch_size):
|
77 |
-
batch = dataset[i:i + self.batch_size]
|
78 |
-
vectors = self.model.encode(batch)
|
79 |
-
# Prepare data for Pinecone
|
80 |
-
pinecone_data = [{'id': str(i + j), 'values': vector.tolist(), 'metadata': {'text': batch[j]}} for j, vector
|
81 |
-
in enumerate(vectors)]
|
82 |
-
self.index.upsert(vectors=pinecone_data)
|
83 |
-
logger.info(f"Upserted batch {i // self.batch_size + 1}/{len(dataset) // self.batch_size + 1} to Pinecone")
|
84 |
-
|
85 |
-
def transform(self, prompts):
|
86 |
-
return np.array(self.model.encode(prompts))
|
87 |
-
|
88 |
-
def store_from_dataset(self, store_data=False):
|
89 |
-
if store_data:
|
90 |
-
logger.info("Loading dataset")
|
91 |
-
dataset = load_dataset('fantasyfish/laion-art', split='train')
|
92 |
-
logger.info(f"Loaded {len(dataset)} items from dataset")
|
93 |
-
logger.info("Please wait for storing. This may take up to five minutes. ")
|
94 |
-
self._store_prompts([item['text'] for item in dataset])
|
95 |
-
logger.info("Items from dataset are stored.")
|
96 |
-
# Ensure prompts are loaded after storing
|
97 |
-
self._load_prompts()
|
98 |
-
logger.info("Items from dataset are loaded.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|