supertskone commited on
Commit
fe51e27
·
verified ·
1 Parent(s): 3faef87

Upload 14 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile
2
+ FROM python:3.9-slim
3
+
4
+ WORKDIR /app
5
+
6
+ COPY requirements.txt .
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "run:app", "--host", "0.0.0.0", "--port", "8000"]
app/__init__.py ADDED
File without changes
app/search_engine.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Tuple
3
+
4
+ from .similarity import cosine_similarity
5
+ from .vectorizer import Vectorizer
6
+ import logging
7
+
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class PromptSearchEngine:
14
+ def __init__(self):
15
+ self.vectorizer = Vectorizer(init_pinecone=False)
16
+ self.vectorizer._data_loaded = True
17
+ self.prompts = self.vectorizer.prompts
18
+ self.corpus_vectors = self.vectorizer.transform(self.prompts)
19
+ self.index_name = self.vectorizer.pinecone_index_name
20
+
21
+ def most_similar(self, query: str, n: int = 5, use_pinecone=True) -> List[Tuple[float, str]]:
22
+ logger.info(f"Encoding query: {query}")
23
+ query_vector = self.vectorizer.transform([query])[0]
24
+ logger.info(f"Encoded query vector: {query_vector}")
25
+ if use_pinecone:
26
+ logger.info(f"I'm doing pinecone vector search because the use_pinecone is: {use_pinecone}")
27
+ try:
28
+ # Convert numpy array to list of native Python floats
29
+ query_vector_list = query_vector.tolist()
30
+ search_result = self.vectorizer.index.query(
31
+ vector=query_vector_list,
32
+ top_k=n,
33
+ include_metadata=True
34
+ )
35
+ logger.info(f"Search result: {search_result}")
36
+
37
+ # Retrieve and format the results
38
+ results = [(match['score'], match['metadata']['text']) for match in search_result['matches'] if
39
+ 'text' in match['metadata']]
40
+ except Exception as e:
41
+ logger.error(f"Pinecone query failed: {e}")
42
+ logger.info("Falling back to cosine similarity search.")
43
+
44
+ # Fallback to cosine similarity search
45
+ similarities = cosine_similarity(query_vector, self.corpus_vectors)
46
+ top_n_indices = np.argsort(similarities)[-n:][::-1]
47
+ results = [(float(similarities[i]), self.prompts[i]) for i in top_n_indices]
48
+ else:
49
+ logger.info(f"I'm cosine similarity search because the use_pinecone is: {use_pinecone}")
50
+ logger.info("Using cosine similarity for search")
51
+ similarities = cosine_similarity(query_vector, self.corpus_vectors)
52
+ top_n_indices = np.argsort(similarities)[-n:][::-1]
53
+ results = [(float(similarities[i]), self.prompts[i]) for i in top_n_indices]
54
+ return results
app/similarity.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def cosine_similarity(
5
+ query_vector: np.ndarray,
6
+ corpus_vectors: np.ndarray
7
+ ) -> np.ndarray:
8
+ """
9
+ Calculate cosine similarity between a query vector and a corpus of vectors.
10
+
11
+ Args:
12
+ query_vector: Vectorized prompt query of shape (D,).
13
+ corpus_vectors: Vectorized prompt corpus of shape (N, D).
14
+
15
+ Returns:
16
+ np.ndarray: The vector of shape (N,) with values in range [-1, 1] where 1
17
+ is max similarity i.e., two vectors are the same.
18
+ """
19
+ dot_product = np.dot(corpus_vectors, query_vector)
20
+ norm_query = np.linalg.norm(query_vector)
21
+ norm_corpus = np.linalg.norm(corpus_vectors, axis=1)
22
+ return dot_product / (norm_query * norm_corpus)
app/vectorizer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from datasets import load_dataset
7
+ from pinecone import Pinecone, ServerlessSpec
8
+
9
+ # Disable parallelism for tokenizers
10
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class Vectorizer:
18
+ def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
19
+ logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
20
+ self.model = SentenceTransformer(model_name)
21
+ self.prompts = []
22
+ self.batch_size = batch_size
23
+ self.pinecone_index_name = "prompts-index"
24
+ self._init_pinecone = init_pinecone
25
+ self._setup_pinecone()
26
+ self._load_prompts()
27
+
28
+ def _setup_pinecone(self):
29
+ logger.info("Setting up Pinecone")
30
+ # Initialize Pinecone
31
+ pinecone = Pinecone(api_key='b514eb66-8626-4697-8a1c-4c411c06c090')
32
+ # Check if the Pinecone index exists, if not create it
33
+ existing_indexes = pinecone.list_indexes()
34
+
35
+ logger.info(f"self.init_pineconeself.init_pineconeself"
36
+ f".init_pineconeself.init_pineconeself.init_pinecone: {self._init_pinecone}")
37
+ if self.pinecone_index_name not in existing_indexes:
38
+ logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
39
+ if self._init_pinecone:
40
+ pinecone.create_index(
41
+ name=self.pinecone_index_name,
42
+ dimension=768,
43
+ metric='cosine',
44
+ spec=ServerlessSpec(
45
+ cloud="aws",
46
+ region="us-east-1"
47
+ )
48
+ )
49
+ else:
50
+ logger.info(f"Pinecone index {self.pinecone_index_name} already exists")
51
+
52
+ self.index = pinecone.Index(self.pinecone_index_name)
53
+
54
+ def _load_prompts(self):
55
+ logger.info("Loading prompts from Pinecone")
56
+ self.prompts = []
57
+ # Fetch vectors from the Pinecone index
58
+ index_stats = self.index.describe_index_stats()
59
+ logger.info(f"Index stats: {index_stats}")
60
+
61
+ namespaces = index_stats['namespaces']
62
+ for namespace, stats in namespaces.items():
63
+ vector_count = stats['vector_count']
64
+ ids = [str(i) for i in range(vector_count)]
65
+ for i in range(0, vector_count, self.batch_size):
66
+ batch_ids = ids[i:i + self.batch_size]
67
+ response = self.index.fetch(ids=batch_ids)
68
+ for vector in response.vectors.values():
69
+ metadata = vector.get('metadata')
70
+ if metadata and 'text' in metadata:
71
+ self.prompts.append(metadata['text'])
72
+ logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")
73
+
74
+ def _store_prompts(self, dataset):
75
+ logger.info("Storing prompts in Pinecone")
76
+ for i in range(0, len(dataset), self.batch_size):
77
+ batch = dataset[i:i + self.batch_size]
78
+ vectors = self.model.encode(batch)
79
+ # Prepare data for Pinecone
80
+ pinecone_data = [{'id': str(i + j), 'values': vector.tolist(), 'metadata': {'text': batch[j]}} for j, vector
81
+ in enumerate(vectors)]
82
+ self.index.upsert(vectors=pinecone_data)
83
+ logger.info(f"Upserted batch {i // self.batch_size + 1}/{len(dataset) // self.batch_size + 1} to Pinecone")
84
+
85
+ def transform(self, prompts):
86
+ return np.array(self.model.encode(prompts))
87
+
88
+ def store_from_dataset(self, store_data=False):
89
+ if store_data:
90
+ logger.info("Loading dataset")
91
+ dataset = load_dataset('fantasyfish/laion-art', split='train')
92
+ logger.info(f"Loaded {len(dataset)} items from dataset")
93
+ logger.info("Please wait for storing. This may take up to five minutes. ")
94
+ self._store_prompts([item['text'] for item in dataset])
95
+ logger.info("Items from dataset are stored.")
96
+ # Ensure prompts are loaded after storing
97
+ self._load_prompts()
98
+ logger.info("Items from dataset are loaded.")
img.png ADDED
load_data.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from app.vectorizer import Vectorizer
2
+
3
+ if __name__ == "__main__":
4
+ vectorizer = Vectorizer()
5
+ vectorizer.store_from_dataset(store_data=True) # Run this once to load the dataset into Pinecone
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ requests
3
+ streamlit
4
+ transformers
5
+ numpy
6
+ sentence-transformers
7
+ datasets
8
+ pinecone
9
+ unittest
run.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from flask import Flask, request, jsonify
4
+ from app.search_engine import PromptSearchEngine
5
+
6
+ app = Flask(__name__)
7
+
8
+ # Disable parallelism for tokenizers
9
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ search_engine = PromptSearchEngine()
16
+
17
+ @app.route('/search', methods=['POST'])
18
+ def search():
19
+ data = request.get_json()
20
+ query = data.get('query')
21
+ n = data.get('n', 5)
22
+ use_pinecone = data.get('use_pinecone', True)
23
+
24
+ logger.info(f"Received query: {query} with n: {n} and use_pinecone: {use_pinecone}")
25
+ results = search_engine.most_similar(query, n, use_pinecone)
26
+ formatted_results = [{'score': score, 'prompt': prompt} for score, prompt in results]
27
+ logger.info(f"Returning results: {formatted_results}")
28
+ return jsonify(formatted_results)
29
+
30
+
31
+ if __name__ == '__main__':
32
+ logger.info("Starting Flask server")
33
+ app.run(debug=True)
run_tests.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+
4
+ def run_all_tests():
5
+ test_loader = unittest.TestLoader()
6
+ test_suite = test_loader.discover('tests', pattern='test_*.py')
7
+
8
+ test_runner = unittest.TextTestRunner(verbosity=2)
9
+ test_runner.run(test_suite)
10
+
11
+
12
+ if __name__ == '__main__':
13
+ run_all_tests()
tests/test_search_engine.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch
3
+ import numpy as np
4
+ from app.search_engine import PromptSearchEngine
5
+
6
+
7
+ class TestPromptSearchEngine(unittest.TestCase):
8
+
9
+ @patch('app.vectorizer.Vectorizer')
10
+ def setUp(self, mock_vectorizer):
11
+ self.mock_vectorizer = mock_vectorizer.return_value
12
+ self.mock_vectorizer.transform.return_value = np.random.rand(10, 768)
13
+ self.mock_vectorizer.prompts = ['prompt'] * 10
14
+ self.search_engine = PromptSearchEngine()
15
+
16
+ def test_most_similar_with_cosine_similarity(self):
17
+ self.mock_vectorizer.index.query.side_effect = Exception('Pinecone error')
18
+ results = self.search_engine.most_similar('query', use_pinecone=False)
19
+ self.assertEqual(len(results), 5)
20
+ self.assertIsInstance(results[0][0], float)
21
+ self.assertIsInstance(results[0][1], str)
22
+
23
+ def test_most_similar_with_pinecone(self):
24
+ mock_search_result = {
25
+ 'matches': [
26
+ {'score': np.float32(0.9), 'metadata': {'text': 'prompt1'}},
27
+ {'score': np.float32(0.8), 'metadata': {'text': 'prompt2'}}
28
+ ]
29
+ }
30
+ self.mock_vectorizer.index.query.return_value = mock_search_result
31
+
32
+ results = self.search_engine.most_similar('query', use_pinecone=True)
33
+ self.assertEqual(len(results), 5)
34
+ self.assertIsInstance(results[0][0], float)
35
+ self.assertIsInstance(results[0][1], str)
36
+
37
+
38
+ if __name__ == '__main__':
39
+ unittest.main()
tests/test_similarity.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ from app.similarity import cosine_similarity
4
+
5
+
6
+ class TestSimilarity(unittest.TestCase):
7
+
8
+ def test_cosine_similarity(self):
9
+ query_vector = np.array([1, 2, 3])
10
+ corpus_vectors = np.array([
11
+ [1, 2, 3],
12
+ [4, 5, 6],
13
+ [7, 8, 9]
14
+ ])
15
+
16
+ expected_result = np.array([1.0, 0.9746318461970762, 0.9594119455666703])
17
+ result = cosine_similarity(query_vector, corpus_vectors)
18
+
19
+ np.testing.assert_almost_equal(result, expected_result, decimal=6)
20
+
21
+
22
+ if __name__ == '__main__':
23
+ unittest.main()
tests/test_vectorizer.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch, MagicMock
3
+ import numpy as np
4
+
5
+ from app.vectorizer import Vectorizer
6
+
7
+
8
+ class TestVectorizer(unittest.TestCase):
9
+
10
+ @patch('app.vectorizer.Pinecone')
11
+ @patch('app.vectorizer.SentenceTransformer')
12
+ def test_vectorizer_initialization(self, mock_sentence_transformer, mock_pinecone):
13
+ mock_sentence_transformer.return_value.encode.return_value = np.random.rand(1, 768)
14
+ vectorizer = Vectorizer(init_pinecone=False)
15
+ self.assertEqual(vectorizer.batch_size, 64)
16
+ self.assertEqual(vectorizer.pinecone_index_name, "prompts-index")
17
+
18
+ @patch('app.vectorizer.load_dataset')
19
+ @patch('app.vectorizer.Pinecone')
20
+ def test_store_from_dataset(self, mock_pinecone, mock_load_dataset):
21
+ mock_pinecone_instance = MagicMock()
22
+ mock_pinecone.return_value = mock_pinecone_instance
23
+ mock_load_dataset.return_value = [{'text': 'sample text'}]
24
+
25
+ vectorizer = Vectorizer(init_pinecone=False)
26
+ vectorizer.store_from_dataset(store_data=True)
27
+
28
+ mock_load_dataset.assert_called_once_with('fantasyfish/laion-art', split='train')
29
+ mock_pinecone_instance.Index.return_value.upsert.assert_called()
30
+
31
+ def test_transform(self):
32
+ with patch('app.vectorizer.SentenceTransformer') as mock_sentence_transformer:
33
+ mock_sentence_transformer.return_value.encode.return_value = np.random.rand(1, 768)
34
+ vectorizer = Vectorizer(init_pinecone=False)
35
+ vectors = vectorizer.transform(['sample prompt'])
36
+ self.assertEqual(vectors.shape, (1, 768))
ui/app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import streamlit as st
4
+ import requests
5
+
6
+ st.title("Prompt Search Engine")
7
+
8
+ query = st.text_input("Enter your query:")
9
+ use_pinecone = st.radio(
10
+ "Choose search method:",
11
+ ('Pinecone Vector Search', 'Cosine Similarity')
12
+ )
13
+ n = st.number_input("Number of results:", min_value=1, max_value=20, value=5)
14
+
15
+ if st.button("Search"):
16
+ search_method = use_pinecone == 'Pinecone Vector Search'
17
+ response = requests.post("http://localhost:5000/search", json={"query": query, "n": n, "use_pinecone": search_method})
18
+
19
+ # Log the response for debugging
20
+ st.write("Response Status Code:", response.status_code)
21
+
22
+ try:
23
+ results = response.json()
24
+ # for score, prompt in results:
25
+ # st.write(f"{score:.2f} - {prompt}")
26
+ for result in results:
27
+ score = float(result['score'])
28
+ prompt = result['prompt']
29
+ st.write(f"{score:.2f} - {prompt}")
30
+ except json.JSONDecodeError as e:
31
+ st.error(f"Failed to decode JSON response: {e}")
32
+ st.write(response.content)