"""Integration tests for an OpenAI-embeddings + Pinecone semantic-search pipeline."""
import os
import time

import pytest
from datasets import load_dataset
from dotenv import load_dotenv
from openai import OpenAI
from pinecone import Pinecone, ServerlessSpec
class TestPineconeIntegration:
    """End-to-end tests for OpenAI embedding creation and Pinecone indexing.

    Requires ``PINECONE_API_KEY`` and ``OPENAI_API_KEY`` in the environment
    (or in a ``../.env`` file). All tests perform real network calls; later
    tests build on earlier ones by invoking them directly.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        """Initialize API clients before each test; delete the index after.

        BUG FIX: this was previously a plain ``setup`` method containing
        ``yield``. pytest never iterates an xunit-style setup generator, so
        neither the client initialization nor the post-yield cleanup ever
        ran. Declaring it an autouse fixture restores both.
        """
        # BUG FIX: load_dotenv() expects a file path; passing the directory
        # "../" silently loaded nothing.
        load_dotenv("../.env")

        # Initialize API clients from environment variables.
        self.pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        # Constants shared by all tests.
        self.MODEL = "text-embedding-3-small"
        self.index_name = "test-semantic-search-openai"

        yield  # the test body runs here

        # Best-effort cleanup: never let teardown failures mask test results.
        try:
            if self.index_name in self.pc.list_indexes().names():
                self.pc.delete_index(self.index_name)
        except Exception as e:
            print(f"Cleanup failed: {str(e)}")

    def test_01_create_embeddings(self):
        """Create embeddings for a small batch and verify shape.

        Returns the first embedding vector so test_02 can derive the index
        dimension from a real model response.
        """
        sample_texts = [
            "Sample document text goes here",
            "there will be several phrases in each batch",
        ]
        res = self.client.embeddings.create(
            input=sample_texts,
            model=self.MODEL,
        )
        embeds = [record.embedding for record in res.data]

        # One embedding per input, each non-empty.
        assert len(embeds) == len(sample_texts)
        assert len(embeds[0]) > 0

        return embeds[0]  # Reused by test_02 to size the index

    def test_02_create_index(self):
        """Create the Pinecone serverless index and wait until it is ready."""
        # Derive the index dimension from an actual embedding rather than
        # hard-coding the model's output size.
        sample_embed = self.test_01_create_embeddings()
        embedding_dimension = len(sample_embed)

        spec = ServerlessSpec(cloud="aws", region="us-east-1")

        # Create the index only if it does not already exist (idempotent).
        if self.index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                self.index_name,
                dimension=embedding_dimension,
                metric='dotproduct',
                spec=spec,
            )

        # Poll until the index reports ready, up to ~60 seconds.
        max_retries = 60
        retries = 0
        while not self.pc.describe_index(self.index_name).status['ready']:
            if retries >= max_retries:
                raise TimeoutError("Index creation timed out")
            time.sleep(1)
            retries += 1

        # Verify the index exists and is serving.
        assert self.index_name in self.pc.list_indexes().names()
        assert self.pc.describe_index(self.index_name).status['ready']

    def test_03_upload_data(self):
        """Embed a small slice of the TREC dataset and upsert it in batches."""
        # Ensure the index exists first.
        self.test_02_create_index()
        index = self.pc.Index(self.index_name)

        # Small slice keeps the test fast and cheap.
        trec = load_dataset('trec', split='train[:10]')

        batch_size = 5
        total_processed = 0
        texts = trec['text']
        for start in range(0, len(texts), batch_size):
            end = min(start + batch_size, len(texts))
            lines_batch = texts[start:end]
            ids_batch = [str(n) for n in range(start, end)]

            # Embed the batch in a single API call.
            res = self.client.embeddings.create(input=lines_batch, model=self.MODEL)
            embeds = [record.embedding for record in res.data]

            # Attach the source text as metadata so queries can return it.
            meta = [{'text': line} for line in lines_batch]
            index.upsert(vectors=list(zip(ids_batch, embeds, meta)))
            total_processed += len(lines_batch)

        # Give Pinecone a moment to make the upserts visible to stats/queries.
        time.sleep(5)

        stats = index.describe_index_stats()
        print(f'stats: {stats}')
        # NOTE(review): an exact-count assertion was disabled here, presumably
        # because total_vector_count lags behind upserts — confirm before
        # re-enabling: assert stats.total_vector_count == total_processed

    def test_04_query_index(self):
        """Query the index and verify the response structure."""
        # Ensure data is present first.
        self.test_03_upload_data()
        index = self.pc.Index(self.index_name)

        # Embed the query text with the same model used for the documents.
        query = "What caused the Great Depression?"
        xq = self.client.embeddings.create(input=query, model=self.MODEL).data[0].embedding

        res = index.query(vector=xq, top_k=5, include_metadata=True)

        # Response must contain at most top_k matches, each carrying a score
        # and the original text in its metadata.
        assert 'matches' in res
        assert len(res['matches']) <= 5
        for match in res['matches']:
            assert 'score' in match
            assert 'metadata' in match
            assert 'text' in match['metadata']

    def test_05_delete_index(self):
        """Delete the index and verify it is gone."""
        # Ensure the index exists first so deletion is meaningful.
        self.test_02_create_index()
        self.pc.delete_index(self.index_name)
        assert self.index_name not in self.pc.list_indexes().names()