# sehatech-demo/tests/test_pinecone.py
import pytest
import os
import time
from openai import OpenAI
from pinecone import Pinecone, ServerlessSpec
from datasets import load_dataset
from dotenv import load_dotenv
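
# These tests require PINECONE_API_KEY and OPENAI_API_KEY; the setup fixture
# below loads them from a .env file in the parent directory.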
class TestPineconeIntegration:
    @pytest.fixture(autouse=True)
    def setup(self):
        """Setup test environment and resources"""
        # Load environment variables
        load_dotenv("../.env")
        # Initialize clients
        self.pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        # Constants
        self.MODEL = "text-embedding-3-small"
        self.index_name = "test-semantic-search-openai"

        yield  # This is where the test runs

        # Cleanup after tests
        try:
            if self.index_name in self.pc.list_indexes().names():
                self.pc.delete_index(self.index_name)
        except Exception as e:
            print(f"Cleanup failed: {str(e)}")

    def test_01_create_embeddings(self):
        """Test OpenAI embedding creation"""
        sample_texts = [
            "Sample document text goes here",
            "there will be several phrases in each batch"
        ]
        res = self.client.embeddings.create(
            input=sample_texts,
            model=self.MODEL
        )
        embeds = [record.embedding for record in res.data]
        assert len(embeds) == 2
        assert len(embeds[0]) > 0  # Check if embeddings are non-empty
        return embeds[0]  # Return for use in other tests

    def test_02_create_index(self):
        """Test Pinecone index creation"""
        # Get sample embedding dimension from previous test
        sample_embed = self.test_01_create_embeddings()
        embedding_dimension = len(sample_embed)
        spec = ServerlessSpec(cloud="aws", region="us-east-1")
        # Create index if it doesn't exist
        if self.index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                self.index_name,
                dimension=embedding_dimension,
                metric='dotproduct',
                spec=spec
            )
        # Wait for index to be ready
        max_retries = 60  # Maximum number of seconds to wait
        retries = 0
        while not self.pc.describe_index(self.index_name).status['ready']:
            if retries >= max_retries:
                raise TimeoutError("Index creation timed out")
            time.sleep(1)
            retries += 1
        # Verify index exists and is ready
        assert self.index_name in self.pc.list_indexes().names()
        assert self.pc.describe_index(self.index_name).status['ready']

    def test_03_upload_data(self):
        """Test data upload to Pinecone"""
        # Ensure index exists first
        self.test_02_create_index()
        # Connect to index
        index = self.pc.Index(self.index_name)
        # Load a small slice of the TREC dataset to keep the test fast
        trec = load_dataset('trec', split='train[:10]')
        batch_size = 5
        total_processed = 0
        for i in range(0, len(trec['text']), batch_size):
            i_end = min(i + batch_size, len(trec['text']))
            lines_batch = trec['text'][i:i_end]
            ids_batch = [str(n) for n in range(i, i_end)]
            # Create embeddings
            res = self.client.embeddings.create(input=lines_batch, model=self.MODEL)
            embeds = [record.embedding for record in res.data]
            # Prepare metadata and upsert batch
            meta = [{'text': line} for line in lines_batch]
            to_upsert = zip(ids_batch, embeds, meta)
            # Upsert to Pinecone
            index.upsert(vectors=list(to_upsert))
            total_processed += len(lines_batch)
        # Wait for a moment to ensure data is indexed
        time.sleep(5)
        # Verify data was uploaded
        stats = index.describe_index_stats()
        print(f'stats: {stats}')
        # Exact-count assertion is left disabled because serverless index stats
        # can lag behind recent upserts:
        # assert stats.total_vector_count == total_processed

    def test_04_query_index(self):
        """Test querying the Pinecone index"""
        # Ensure data is uploaded first
        self.test_03_upload_data()
        index = self.pc.Index(self.index_name)
        # Create query embedding
        query = "What caused the Great Depression?"
        xq = self.client.embeddings.create(input=query, model=self.MODEL).data[0].embedding
        # Query index
        res = index.query(vector=xq, top_k=5, include_metadata=True)
        # Verify response format
        assert 'matches' in res
        assert len(res['matches']) <= 5  # Should return up to 5 results
        # Verify match format
        for match in res['matches']:
            assert 'score' in match
            assert 'metadata' in match
            assert 'text' in match['metadata']

    def test_05_delete_index(self):
        """Test index deletion"""
        # Ensure index exists first
        self.test_02_create_index()
        # Delete index
        self.pc.delete_index(self.index_name)
        # Verify deletion
        assert self.index_name not in self.pc.list_indexes().names()
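

# Convenience entry point (a minimal sketch): running this file directly invokes
# pytest on this module so the numbered tests execute in order.
if __name__ == "__main__":
    import sys
    sys.exit(pytest.main([__file__, "-v"]))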