supertskone committed on
Commit
72de987
·
verified ·
1 Parent(s): f63ca0c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app/__init__.py +0 -0
  2. app/similarity.py +22 -0
  3. app/vectorizer.py +98 -0
app/__init__.py ADDED
File without changes
app/similarity.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np


def cosine_similarity(
        query_vector: np.ndarray,
        corpus_vectors: np.ndarray
) -> np.ndarray:
    """
    Calculate cosine similarity between a query vector and a corpus of vectors.

    Args:
        query_vector: Vectorized prompt query of shape (D,).
        corpus_vectors: Vectorized prompt corpus of shape (N, D).

    Returns:
        np.ndarray: The vector of shape (N,) with values in range [-1, 1] where 1
            is max similarity i.e., two vectors are the same. Rows whose norm is
            zero (or a zero query) yield 0.0 instead of NaN.
    """
    dot_product = np.dot(corpus_vectors, query_vector)
    norm_query = np.linalg.norm(query_vector)
    norm_corpus = np.linalg.norm(corpus_vectors, axis=1)
    # Guard against division by zero: a zero-norm vector has no direction,
    # so its similarity is defined as 0 rather than NaN/inf with a warning.
    denominator = norm_query * norm_corpus
    return np.divide(
        dot_product,
        denominator,
        out=np.zeros_like(dot_product, dtype=float),
        where=denominator != 0,
    )
app/vectorizer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import logging

import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from pinecone import Pinecone, ServerlessSpec

# Disable tokenizer parallelism (avoids the HuggingFace tokenizers
# fork-after-parallelism warnings/deadlocks).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Configure module-wide logging; all classes below log through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
15
+
16
+
17
class Vectorizer:
    """Encode text prompts with a SentenceTransformer model and persist/
    retrieve them through a Pinecone serverless index.

    The constructor eagerly connects to Pinecone (optionally creating the
    index) and mirrors every stored prompt text into ``self.prompts``.
    """

    def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
        """
        Args:
            model_name: SentenceTransformer model to load.
            batch_size: Number of items per encode/upsert/fetch batch.
            init_pinecone: When True, create the Pinecone index if missing.
        """
        logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
        self.model = SentenceTransformer(model_name)
        self.prompts = []  # prompt texts mirrored from the Pinecone index
        self.batch_size = batch_size
        self.pinecone_index_name = "prompts-index"
        self._init_pinecone = init_pinecone
        self._setup_pinecone()
        self._load_prompts()

    def _setup_pinecone(self):
        """Connect to Pinecone and bind ``self.index``, creating the index if allowed."""
        logger.info("Setting up Pinecone")
        # SECURITY: the API key used to be hard-coded here (a leaked secret).
        # Read it from the environment; the old literal remains only as a
        # backward-compatible fallback -- rotate that key and drop the fallback.
        api_key = os.environ.get('PINECONE_API_KEY',
                                 'b514eb66-8626-4697-8a1c-4c411c06c090')
        pinecone = Pinecone(api_key=api_key)
        # list_indexes() returns index descriptions, not plain strings, so a
        # membership test against it never matches the name; compare names.
        existing_index_names = pinecone.list_indexes().names()

        logger.info(f"init_pinecone flag: {self._init_pinecone}")
        if self.pinecone_index_name not in existing_index_names:
            if self._init_pinecone:
                # Log creation only when we actually create (previously this
                # was logged even when creation was disabled).
                logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
                pinecone.create_index(
                    name=self.pinecone_index_name,
                    dimension=768,  # output dimension of all-mpnet-base-v2
                    metric='cosine',
                    spec=ServerlessSpec(
                        cloud="aws",
                        region="us-east-1"
                    )
                )
            else:
                logger.info(f"Pinecone index {self.pinecone_index_name} is missing "
                            f"and init_pinecone is disabled")
        else:
            logger.info(f"Pinecone index {self.pinecone_index_name} already exists")

        self.index = pinecone.Index(self.pinecone_index_name)

    def _load_prompts(self):
        """Rebuild ``self.prompts`` from the 'text' metadata of every stored vector."""
        logger.info("Loading prompts from Pinecone")
        self.prompts = []
        index_stats = self.index.describe_index_stats()
        logger.info(f"Index stats: {index_stats}")

        namespaces = index_stats['namespaces']
        for namespace, stats in namespaces.items():
            vector_count = stats['vector_count']
            # NOTE(review): assumes ids were assigned as "0".."count-1" by
            # _store_prompts; vectors written with other id schemes are missed.
            ids = [str(i) for i in range(vector_count)]
            for start in range(0, vector_count, self.batch_size):
                batch_ids = ids[start:start + self.batch_size]
                response = self.index.fetch(ids=batch_ids)
                for vector in response.vectors.values():
                    metadata = vector.get('metadata')
                    if metadata and 'text' in metadata:
                        self.prompts.append(metadata['text'])
        logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")

    def _store_prompts(self, dataset):
        """Encode every text in ``dataset`` and upsert (id, vector, metadata) batches."""
        logger.info("Storing prompts in Pinecone")
        # Ceiling division so the "x/y" progress log is exact even when the
        # dataset length is an exact multiple of the batch size.
        total_batches = (len(dataset) + self.batch_size - 1) // self.batch_size
        for i in range(0, len(dataset), self.batch_size):
            batch = dataset[i:i + self.batch_size]
            vectors = self.model.encode(batch)
            # Ids are the global positions, so _load_prompts can fetch "0".."N-1".
            pinecone_data = [
                {'id': str(i + j), 'values': vector.tolist(), 'metadata': {'text': batch[j]}}
                for j, vector in enumerate(vectors)
            ]
            self.index.upsert(vectors=pinecone_data)
            logger.info(f"Upserted batch {i // self.batch_size + 1}/{total_batches} to Pinecone")

    def transform(self, prompts):
        """Return the embeddings of ``prompts`` as a 2-D numpy array of shape (N, 768)."""
        return np.array(self.model.encode(prompts))

    def store_from_dataset(self, store_data=False):
        """Optionally upsert the laion-art 'text' column, then reload ``self.prompts``.

        Args:
            store_data: When True, download the dataset and store every text
                before reloading; when False, only reload what is in Pinecone.
        """
        if store_data:
            logger.info("Loading dataset")
            dataset = load_dataset('fantasyfish/laion-art', split='train')
            logger.info(f"Loaded {len(dataset)} items from dataset")
            logger.info("Please wait for storing. This may take up to five minutes. ")
            self._store_prompts([item['text'] for item in dataset])
            logger.info("Items from dataset are stored.")
        # Ensure prompts reflect the index contents after (optional) storing.
        self._load_prompts()
        logger.info("Items from dataset are loaded.")