supertskone committed
Commit b88f0e6 · verified · 1 Parent(s): 1aa9434

Delete app/vectorizer.py

Files changed (1)
  1. app/vectorizer.py +0 -98
app/vectorizer.py DELETED
@@ -1,98 +0,0 @@
- import os
- import logging
-
- import numpy as np
- from sentence_transformers import SentenceTransformer
- from datasets import load_dataset
- from pinecone import Pinecone, ServerlessSpec
-
- # Disable parallelism for tokenizers
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
-
- class Vectorizer:
-     def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
-         logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
-         self.model = SentenceTransformer(model_name)
-         self.prompts = []
-         self.batch_size = batch_size
-         self.pinecone_index_name = "prompts-index"
-         self._init_pinecone = init_pinecone
-         self._setup_pinecone()
-         self._load_prompts()
-
-     def _setup_pinecone(self):
-         logger.info("Setting up Pinecone")
-         # Initialize Pinecone
-         pinecone = Pinecone(api_key='b514eb66-8626-4697-8a1c-4c411c06c090')
-         # Check if the Pinecone index exists, if not create it
-         existing_indexes = pinecone.list_indexes()
-
- logger.info(f"self.init_pineconeself.init_pineconeself"
36
- f".init_pineconeself.init_pineconeself.init_pinecone: {self._init_pinecone}")
-         if self.pinecone_index_name not in existing_indexes:
-             logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
-             if self._init_pinecone:
-                 pinecone.create_index(
-                     name=self.pinecone_index_name,
-                     dimension=768,
-                     metric='cosine',
-                     spec=ServerlessSpec(
-                         cloud="aws",
-                         region="us-east-1"
-                     )
-                 )
-         else:
-             logger.info(f"Pinecone index {self.pinecone_index_name} already exists")
-
-         self.index = pinecone.Index(self.pinecone_index_name)
-
-     def _load_prompts(self):
-         logger.info("Loading prompts from Pinecone")
-         self.prompts = []
-         # Fetch vectors from the Pinecone index
-         index_stats = self.index.describe_index_stats()
-         logger.info(f"Index stats: {index_stats}")
-
-         namespaces = index_stats['namespaces']
-         for namespace, stats in namespaces.items():
-             vector_count = stats['vector_count']
-             ids = [str(i) for i in range(vector_count)]
-             for i in range(0, vector_count, self.batch_size):
-                 batch_ids = ids[i:i + self.batch_size]
-                 response = self.index.fetch(ids=batch_ids)
-                 for vector in response.vectors.values():
-                     metadata = vector.get('metadata')
-                     if metadata and 'text' in metadata:
-                         self.prompts.append(metadata['text'])
-         logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")
-
-     def _store_prompts(self, dataset):
-         logger.info("Storing prompts in Pinecone")
-         for i in range(0, len(dataset), self.batch_size):
-             batch = dataset[i:i + self.batch_size]
-             vectors = self.model.encode(batch)
-             # Prepare data for Pinecone
-             pinecone_data = [{'id': str(i + j), 'values': vector.tolist(), 'metadata': {'text': batch[j]}}
-                              for j, vector in enumerate(vectors)]
-             self.index.upsert(vectors=pinecone_data)
-             logger.info(f"Upserted batch {i // self.batch_size + 1}/{len(dataset) // self.batch_size + 1} to Pinecone")
-
-     def transform(self, prompts):
-         return np.array(self.model.encode(prompts))
-
-     def store_from_dataset(self, store_data=False):
-         if store_data:
-             logger.info("Loading dataset")
-             dataset = load_dataset('fantasyfish/laion-art', split='train')
-             logger.info(f"Loaded {len(dataset)} items from dataset")
-             logger.info("Please wait for storing. This may take up to five minutes. ")
-             self._store_prompts([item['text'] for item in dataset])
-             logger.info("Items from dataset are stored.")
-         # Ensure prompts are loaded after storing
-         self._load_prompts()
-         logger.info("Items from dataset are loaded.")