// static-retrieval-mrl-en-v1-sqlite-vec/EmbeddingsBenchmark/Sources/EmbeddingsBenchmarkLib/EmbeddingsBenchmarkLib.swift
import Accelerate | |
import CoreML | |
import Embeddings | |
import Foundation | |
import SQLiteVec | |
/// Initializes SQLiteVec and opens the database located at `filePath`.
///
/// - Parameter filePath: Location of the database, handed to `Database` as a `.uri` value.
/// - Returns: The opened `Database`.
/// - Throws: Any error raised by `SQLiteVec.initialize()` or by the `Database` initializer.
public func initializeDatabase(_ filePath: String) throws -> Database {
    // SQLiteVec.initialize() is called before the database is opened.
    try SQLiteVec.initialize()
    let database = try Database(.uri(filePath))
    return database
}
/// Computes the element-wise mean of the embeddings stored for the given row ids.
///
/// Runs `SELECT embedding FROM embeddings WHERE rowid IN <query>` with `tokenIds` bound as
/// statement parameters, decodes every returned `embedding` blob into `[Float]`, sums the
/// decoded vectors and divides by the number of vectors that were actually summed.
///
/// - Parameters:
///   - db: Database the query is executed against.
///   - query: Text spliced verbatim into the SQL after `IN`. Presumably a placeholder list
///     such as `(?, ?, ?)` with one placeholder per entry of `tokenIds` — confirm with callers.
///     Because it becomes part of the SQL text, it must never carry untrusted input.
///   - tokenIds: Values bound to the statement's parameters.
///   - vectorSize: Length of the accumulator and therefore of the returned vector.
///     NOTE(review): each decoded embedding is expected to have exactly this many elements;
///     `vDSP.add` requires both operands to have the same count.
/// - Returns: The mean embedding. If no row yields a decodable `embedding` blob, a zero vector
///   of length `vectorSize` is returned instead of dividing by zero.
/// - Throws: Any error raised by `db.query`.
@discardableResult
public func queryEmbeddings(db: Database, query: String, tokenIds: [Int], vectorSize: Int = 1024) async throws -> [Float] {
    let rows = try await db.query("SELECT embedding FROM embeddings WHERE rowid IN \(query)", params: tokenIds)
    var sum = [Float](repeating: 0, count: vectorSize)
    // Only rows that were actually added to `sum` may count towards the mean; rows without a
    // `Data` blob are skipped and must not inflate the denominator.
    var contributingRows = 0
    for row in rows {
        guard let blob = row["embedding"] as? Data else {
            continue
        }
        let vector: [Float] = blob.toArray()
        sum = vDSP.add(vector, sum)
        contributingRows += 1
    }
    // Dividing the zero accumulator by 0 would produce an all-NaN vector.
    guard contributingRows > 0 else {
        return sum
    }
    return vDSP.divide(sum, Float(contributingRows))
}
/// Averages the slices of `embeddings` selected by `tokenIds`.
///
/// The ids are wrapped in a one-dimensional index tensor, the matching slices are gathered
/// along axis 0, and the gathered slices are reduced with a mean over axis 0.
///
/// - Parameters:
///   - embeddings: Tensor to gather from. Presumably a `(vocabulary, dimension)` matrix —
///     TODO confirm against the caller.
///   - tokenIds: Indices along axis 0 of `embeddings`.
/// - Returns: The scalars of the averaged tensor, read out as `Float`.
@discardableResult
public func queryEmbeddings(embeddings: MLTensor, tokenIds: [Int32]) async -> [Float] {
    let rowIndices = MLTensor(shape: [tokenIds.count], scalars: tokenIds)
    let selectedRows = embeddings.gathering(atIndices: rowIndices, alongAxis: 0)
    let meanRow = selectedRows.mean(alongAxes: 0)
    // Reading the tensor out as a shaped array is the only asynchronous step.
    let materialized = await meanRow.shapedArray(of: Float.self)
    return materialized.scalars
}