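"""Integration tests for a semantic-search pipeline: text is embedded with the
OpenAI embeddings API and stored in / queried from a Pinecone serverless index.

Requires PINECONE_API_KEY and OPENAI_API_KEY in the environment (loaded from a
.env file) and network access to both services. Tests are ordered: later tests
call earlier ones to build the index state they depend on.
"""
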
import pytest
import os
import time
from openai import OpenAI
from pinecone import Pinecone, ServerlessSpec
from datasets import load_dataset
from dotenv import load_dotenv


class TestPineconeIntegration:
    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up the test environment and clean up resources afterwards."""
        # Load environment variables from the parent directory's .env file
        # (load_dotenv needs the file path itself; a bare directory is silently ignored)
        load_dotenv("../.env")
        # Initialize clients
        self.pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        # Constants
        self.MODEL = "text-embedding-3-small"
        self.index_name = "test-semantic-search-openai"
        yield  # This is where the test runs
        # Cleanup after each test
        try:
            if self.index_name in self.pc.list_indexes().names():
                self.pc.delete_index(self.index_name)
        except Exception as e:
            print(f"Cleanup failed: {e}")

    def test_01_create_embeddings(self):
        """Test OpenAI embedding creation"""
        sample_texts = [
            "Sample document text goes here",
            "there will be several phrases in each batch",
        ]
        res = self.client.embeddings.create(
            input=sample_texts,
            model=self.MODEL,
        )
        embeds = [record.embedding for record in res.data]
        assert len(embeds) == 2
        assert len(embeds[0]) > 0  # Check that embeddings are non-empty
        return embeds[0]  # Return for use in other tests
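
    # Note: test_02 derives the index dimension from a live embedding rather than
    # hard-coding it (text-embedding-3-small returns 1536-dimensional vectors by
    # default, but reading the length keeps the tests model-agnostic).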

    def test_02_create_index(self):
        """Test Pinecone index creation"""
        # Get sample embedding dimension from the previous test
        sample_embed = self.test_01_create_embeddings()
        embedding_dimension = len(sample_embed)
        spec = ServerlessSpec(cloud="aws", region="us-east-1")
        # Create the index if it doesn't exist
        if self.index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                self.index_name,
                dimension=embedding_dimension,
                metric='dotproduct',
                spec=spec,
            )
        # Wait for the index to be ready
        max_retries = 60  # Maximum number of seconds to wait
        retries = 0
        while not self.pc.describe_index(self.index_name).status['ready']:
            if retries >= max_retries:
                raise TimeoutError("Index creation timed out")
            time.sleep(1)
            retries += 1
        # Verify the index exists and is ready
        assert self.index_name in self.pc.list_indexes().names()
        assert self.pc.describe_index(self.index_name).status['ready']
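
    # Note (an assumption about intent, not stated in the original): with the
    # 'dotproduct' metric, scores here match cosine similarity, since OpenAI
    # embedding vectors come back unit-normalized.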

    def test_03_upload_data(self):
        """Test data upload to Pinecone"""
        # Ensure the index exists first
        self.test_02_create_index()
        # Connect to the index
        index = self.pc.Index(self.index_name)
        # Load a small slice of the TREC question dataset to keep the test fast
        trec = load_dataset('trec', split='train[:10]')
        batch_size = 5
        total_processed = 0
        for i in range(0, len(trec['text']), batch_size):
            i_end = min(i + batch_size, len(trec['text']))
            lines_batch = trec['text'][i:i_end]
            ids_batch = [str(n) for n in range(i, i_end)]
            # Create embeddings
            res = self.client.embeddings.create(input=lines_batch, model=self.MODEL)
            embeds = [record.embedding for record in res.data]
            # Prepare metadata and build (id, vector, metadata) tuples
            meta = [{'text': line} for line in lines_batch]
            to_upsert = zip(ids_batch, embeds, meta)
            # Upsert the batch to Pinecone
            index.upsert(vectors=list(to_upsert))
            total_processed += len(lines_batch)
        # Give the index a moment to make the new vectors visible
        time.sleep(5)
        # Verify data was uploaded
        stats = index.describe_index_stats()
        print(f'stats: {stats}')
        # Indexing is eventually consistent, so the count may still lag the upsert:
        # assert stats.total_vector_count == total_processed
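
    # A minimal polling sketch (not part of the original suite): instead of a
    # fixed sleep, a test can poll describe_index_stats() until the expected
    # number of vectors is visible or a deadline passes, e.g.
    #     assert self._wait_for_vector_count(index, total_processed)
    def _wait_for_vector_count(self, index, expected, timeout=30):
        """Poll index stats until `expected` vectors are visible or `timeout` elapses."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            if index.describe_index_stats().total_vector_count >= expected:
                return True
            time.sleep(1)
        return False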

    def test_04_query_index(self):
        """Test querying the Pinecone index"""
        # Ensure data is uploaded first
        self.test_03_upload_data()
        index = self.pc.Index(self.index_name)
        # Create the query embedding
        query = "What caused the Great Depression?"
        xq = self.client.embeddings.create(input=query, model=self.MODEL).data[0].embedding
        # Query the index
        res = index.query(vector=xq, top_k=5, include_metadata=True)
        # Verify the response format
        assert 'matches' in res
        assert len(res['matches']) <= 5  # Should return up to 5 results
        # Verify the match format
        for match in res['matches']:
            assert 'score' in match
            assert 'metadata' in match
            assert 'text' in match['metadata']

    def test_05_delete_index(self):
        """Test index deletion"""
        # Ensure the index exists first
        self.test_02_create_index()
        # Delete the index
        self.pc.delete_index(self.index_name)
        # Verify deletion
        assert self.index_name not in self.pc.list_indexes().names()
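

# Usage sketch (assumes this file is saved under a pytest-discoverable name such
# as test_pinecone_integration.py and that both API keys are set):
#     pytest test_pinecone_integration.py -v
# Each test hits live APIs, and the autouse fixture deletes the test index after
# every test, so later tests rebuild the index via test_02 as needed.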