// static-retrieval-mrl-en-v1-sqlite-vec/EmbeddingsBenchmark/Tests/EmbeddingsBenchmarkTests/EmbeddingsBenchmarkTests.swift
import CoreML
import SQLiteVec
import Testing

@testable import EmbeddingsBenchmarkLib
/// Builds an in-memory sqlite-vec database whose `embeddings` virtual table
/// holds one row per element of `data`.
///
/// Each vector is inserted with its position in `data` as the `rowid`, so
/// row `i` of `data` is stored under `rowid = i`.
///
/// - Parameters:
///   - data: The vectors to insert. All rows are expected to share the same
///     length. NOTE(review): a row whose length differs from the column width
///     relies on sqlite-vec itself to make the insert throw — confirm.
///   - dimensions: Declared width `N` of the `embedding float[N]` column.
///     When `nil`, the width of the first row of `data` is used, falling back
///     to `3` when `data` is empty (the width this helper previously always used).
/// - Returns: The open in-memory database.
/// - Throws: Any error raised while initializing sqlite-vec, opening the
///   database, creating the table, or inserting a row.
func createDatabase(_ data: [[Float]], dimensions: Int? = nil) async throws -> Database {
    try SQLiteVec.initialize()
    let db = try Database(.inMemory)

    // Derive the column width from the fixture instead of hard-coding it, so
    // embeddings of any size can reuse this helper.
    let width = dimensions ?? data.first?.count ?? 3
    try await db.execute("CREATE VIRTUAL TABLE embeddings USING vec0(embedding float[\(width)])")

    for (index, row) in data.enumerated() {
        // rowid mirrors the row's index in `data`; presumably this is the token
        // id that `queryEmbeddings(db:...)` looks up — verify against the lib.
        try await db.execute(
            """
            INSERT INTO embeddings(rowid, embedding)
            VALUES (?, ?)
            """,
            params: [index, row]
        )
    }
    return db
}
/// Checks that the CoreML lookup and the sqlite-vec lookup return equal
/// results for the same token ids.
///
/// Row `i` of `data` is element `i` of the 3x3 `MLTensor`, and it is also
/// `rowid = i` in the database built by `createDatabase`. Both backends are
/// queried with token ids `[0, 2]` and their results are compared with `==`.
///
/// Marked `@Test` so Swift Testing discovers and runs it; a plain free
/// function in a Swift Testing target is never executed.
@Test
func testEmbeddingMethods() async throws {
    let data: [[Float]] = [
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]

    // CoreML path: the rows are flattened row-major into a [3, 3] tensor.
    let embeddings = MLTensor(shape: [3, 3], scalars: data.flatMap { $0 })
    let coreMLResult = await queryEmbeddings(embeddings: embeddings, tokenIds: [0, 2])

    // sqlite-vec path: the same rows stored in an in-memory vec0 table.
    // NOTE(review): `query` presumably needs one `?` per token id — keep it in
    // sync with `tokenIds`.
    let db = try await createDatabase(data)
    let sqliteResult = try await queryEmbeddings(
        db: db,
        query: "(?, ?)",
        tokenIds: [0, 2],
        vectorSize: 3
    )

    #expect(coreMLResult == sqliteResult)
}