Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- app/__init__.py +0 -0
- app/similarity.py +22 -0
- app/vectorizer.py +98 -0
app/__init__.py
ADDED
File without changes
|
app/similarity.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def cosine_similarity(
|
5 |
+
query_vector: np.ndarray,
|
6 |
+
corpus_vectors: np.ndarray
|
7 |
+
) -> np.ndarray:
|
8 |
+
"""
|
9 |
+
Calculate cosine similarity between a query vector and a corpus of vectors.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
query_vector: Vectorized prompt query of shape (D,).
|
13 |
+
corpus_vectors: Vectorized prompt corpus of shape (N, D).
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
np.ndarray: The vector of shape (N,) with values in range [-1, 1] where 1
|
17 |
+
is max similarity i.e., two vectors are the same.
|
18 |
+
"""
|
19 |
+
dot_product = np.dot(corpus_vectors, query_vector)
|
20 |
+
norm_query = np.linalg.norm(query_vector)
|
21 |
+
norm_corpus = np.linalg.norm(corpus_vectors, axis=1)
|
22 |
+
return dot_product / (norm_query * norm_corpus)
|
app/vectorizer.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
from datasets import load_dataset
|
7 |
+
from pinecone import Pinecone, ServerlessSpec
|
8 |
+
|
9 |
+
# Disable parallelism for tokenizers
|
10 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
11 |
+
|
12 |
+
# Configure logging
|
13 |
+
logging.basicConfig(level=logging.INFO)
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class Vectorizer:
|
18 |
+
def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
|
19 |
+
logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
|
20 |
+
self.model = SentenceTransformer(model_name)
|
21 |
+
self.prompts = []
|
22 |
+
self.batch_size = batch_size
|
23 |
+
self.pinecone_index_name = "prompts-index"
|
24 |
+
self._init_pinecone = init_pinecone
|
25 |
+
self._setup_pinecone()
|
26 |
+
self._load_prompts()
|
27 |
+
|
28 |
+
def _setup_pinecone(self):
|
29 |
+
logger.info("Setting up Pinecone")
|
30 |
+
# Initialize Pinecone
|
31 |
+
pinecone = Pinecone(api_key='b514eb66-8626-4697-8a1c-4c411c06c090')
|
32 |
+
# Check if the Pinecone index exists, if not create it
|
33 |
+
existing_indexes = pinecone.list_indexes()
|
34 |
+
|
35 |
+
logger.info(f"self.init_pineconeself.init_pineconeself"
|
36 |
+
f".init_pineconeself.init_pineconeself.init_pinecone: {self._init_pinecone}")
|
37 |
+
if self.pinecone_index_name not in existing_indexes:
|
38 |
+
logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
|
39 |
+
if self._init_pinecone:
|
40 |
+
pinecone.create_index(
|
41 |
+
name=self.pinecone_index_name,
|
42 |
+
dimension=768,
|
43 |
+
metric='cosine',
|
44 |
+
spec=ServerlessSpec(
|
45 |
+
cloud="aws",
|
46 |
+
region="us-east-1"
|
47 |
+
)
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
logger.info(f"Pinecone index {self.pinecone_index_name} already exists")
|
51 |
+
|
52 |
+
self.index = pinecone.Index(self.pinecone_index_name)
|
53 |
+
|
54 |
+
def _load_prompts(self):
|
55 |
+
logger.info("Loading prompts from Pinecone")
|
56 |
+
self.prompts = []
|
57 |
+
# Fetch vectors from the Pinecone index
|
58 |
+
index_stats = self.index.describe_index_stats()
|
59 |
+
logger.info(f"Index stats: {index_stats}")
|
60 |
+
|
61 |
+
namespaces = index_stats['namespaces']
|
62 |
+
for namespace, stats in namespaces.items():
|
63 |
+
vector_count = stats['vector_count']
|
64 |
+
ids = [str(i) for i in range(vector_count)]
|
65 |
+
for i in range(0, vector_count, self.batch_size):
|
66 |
+
batch_ids = ids[i:i + self.batch_size]
|
67 |
+
response = self.index.fetch(ids=batch_ids)
|
68 |
+
for vector in response.vectors.values():
|
69 |
+
metadata = vector.get('metadata')
|
70 |
+
if metadata and 'text' in metadata:
|
71 |
+
self.prompts.append(metadata['text'])
|
72 |
+
logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")
|
73 |
+
|
74 |
+
def _store_prompts(self, dataset):
|
75 |
+
logger.info("Storing prompts in Pinecone")
|
76 |
+
for i in range(0, len(dataset), self.batch_size):
|
77 |
+
batch = dataset[i:i + self.batch_size]
|
78 |
+
vectors = self.model.encode(batch)
|
79 |
+
# Prepare data for Pinecone
|
80 |
+
pinecone_data = [{'id': str(i + j), 'values': vector.tolist(), 'metadata': {'text': batch[j]}} for j, vector
|
81 |
+
in enumerate(vectors)]
|
82 |
+
self.index.upsert(vectors=pinecone_data)
|
83 |
+
logger.info(f"Upserted batch {i // self.batch_size + 1}/{len(dataset) // self.batch_size + 1} to Pinecone")
|
84 |
+
|
85 |
+
def transform(self, prompts):
|
86 |
+
return np.array(self.model.encode(prompts))
|
87 |
+
|
88 |
+
def store_from_dataset(self, store_data=False):
|
89 |
+
if store_data:
|
90 |
+
logger.info("Loading dataset")
|
91 |
+
dataset = load_dataset('fantasyfish/laion-art', split='train')
|
92 |
+
logger.info(f"Loaded {len(dataset)} items from dataset")
|
93 |
+
logger.info("Please wait for storing. This may take up to five minutes. ")
|
94 |
+
self._store_prompts([item['text'] for item in dataset])
|
95 |
+
logger.info("Items from dataset are stored.")
|
96 |
+
# Ensure prompts are loaded after storing
|
97 |
+
self._load_prompts()
|
98 |
+
logger.info("Items from dataset are loaded.")
|