import pytest
import os
import time
from openai import OpenAI
from pinecone import Pinecone, ServerlessSpec
from datasets import load_dataset
from dotenv import load_dotenv
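
# Assumptions for running this suite: PINECONE_API_KEY and OPENAI_API_KEY are
# provided via a .env file (loaded in the setup fixture below) or the
# environment, and the pinecone, openai, datasets, python-dotenv, and pytest
# packages are installed.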

class TestPineconeIntegration:
    @pytest.fixture(autouse=True)
    def setup(self):
        """Setup test environment and resources"""
        # Load environment variables from the .env file one directory up
        # (adjust the path to wherever your .env actually lives)
        load_dotenv("../.env")
        
        # Initialize clients
        self.pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        
        # Constants
        self.MODEL = "text-embedding-3-small"
        self.index_name = "test-semantic-search-openai"
        
        yield  # This is where the test runs
        
        # Cleanup after tests
        try:
            if self.index_name in self.pc.list_indexes().names():
                self.pc.delete_index(self.index_name)
        except Exception as e:
            print(f"Cleanup failed: {str(e)}")

    def test_01_create_embeddings(self):
        """Test OpenAI embedding creation"""
        sample_texts = [
            "Sample document text goes here",
            "there will be several phrases in each batch"
        ]
        
        res = self.client.embeddings.create(
            input=sample_texts,
            model=self.MODEL
        )
        
        embeds = [record.embedding for record in res.data]
        assert len(embeds) == 2
        assert len(embeds[0]) > 0  # Check if embeddings are non-empty
        
        return embeds[0]  # Return for use in other tests

    def test_02_create_index(self):
        """Test Pinecone index creation"""
        # Get sample embedding dimension from previous test
        sample_embed = self.test_01_create_embeddings()
        embedding_dimension = len(sample_embed)
        
        spec = ServerlessSpec(cloud="aws", region="us-east-1")
        
        # Create index if it doesn't exist
        if self.index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                self.index_name,
                dimension=embedding_dimension,
                metric='dotproduct',
                spec=spec
            )
        
        # Wait for index to be ready
        max_retries = 60  # Maximum number of seconds to wait
        retries = 0
        while not self.pc.describe_index(self.index_name).status['ready']:
            if retries >= max_retries:
                raise TimeoutError("Index creation timed out")
            time.sleep(1)
            retries += 1
        
        # Verify index exists and is ready
        assert self.index_name in self.pc.list_indexes().names()
        assert self.pc.describe_index(self.index_name).status['ready']

    def test_03_upload_data(self):
        """Test data upload to Pinecone"""
        # Ensure index exists first
        self.test_02_create_index()
        
        # Connect to index
        index = self.pc.Index(self.index_name)
        # Load a small slice of the TREC dataset to keep the test fast
        trec = load_dataset('trec', split='train[:10]')
        
        batch_size = 5
        total_processed = 0
        
        for i in range(0, len(trec['text']), batch_size):
            i_end = min(i + batch_size, len(trec['text']))
            lines_batch = trec['text'][i:i_end]
            ids_batch = [str(n) for n in range(i, i_end)]
            
            # Create embeddings
            res = self.client.embeddings.create(input=lines_batch, model=self.MODEL)
            embeds = [record.embedding for record in res.data]
            
            # Prepare metadata and upsert batch
            meta = [{'text': line} for line in lines_batch]
            to_upsert = zip(ids_batch, embeds, meta)
            
            # Upsert to Pinecone
            index.upsert(vectors=list(to_upsert))
            total_processed += len(lines_batch)
        
        # Wait for a moment to ensure data is indexed
        time.sleep(5)
        
        # Verify data was uploaded. describe_index_stats is eventually consistent,
        # so the reported vector count may briefly lag behind the upserts.
        stats = index.describe_index_stats()
        print(f'stats: {stats}')
        assert total_processed == len(trec['text'])

    def test_04_query_index(self):
        """Test querying the Pinecone index"""
        # Ensure data is uploaded first
        self.test_03_upload_data()
        
        index = self.pc.Index(self.index_name)
        
        # Create query embedding
        query = "What caused the Great Depression?"
        xq = self.client.embeddings.create(input=query, model=self.MODEL).data[0].embedding
        
        # Query index
        res = index.query(vector=xq, top_k=5, include_metadata=True)
        
        # Verify response format
        assert 'matches' in res
        assert len(res['matches']) <= 5  # Should return up to 5 results
        
        # Verify match format
        for match in res['matches']:
            assert 'score' in match
            assert 'metadata' in match
            assert 'text' in match['metadata']

    def test_05_delete_index(self):
        """Test index deletion"""
        # Ensure index exists first
        self.test_02_create_index()
        
        # Delete index
        self.pc.delete_index(self.index_name)
        
        # Verify deletion
        assert self.index_name not in self.pc.list_indexes().names()
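
# Minimal usage sketch (the file name below is an assumption; use whatever this
# module is actually saved as). The tests build on one another by calling the
# earlier steps directly, so run the whole module in order:
#
#   pytest -v test_pinecone_integration.py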