File size: 1,147 Bytes
902bb3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import Accelerate
import CoreML
import Embeddings
import Foundation
import SQLiteVec

/// Calls `SQLiteVec.initialize()` and then opens a `Database` for `filePath`.
///
/// - Parameter filePath: Location string handed to `Database` wrapped as a `.uri` value.
/// - Returns: The newly created `Database`.
/// - Throws: Any error raised by `SQLiteVec.initialize()` or by the `Database` initializer.
public func initializeDatabase(_ filePath: String) throws -> Database {
    // Initialization has to happen first; the database is only opened if it succeeds.
    try SQLiteVec.initialize()
    let database = try Database(.uri(filePath))
    return database
}

/// Averages the embedding vectors stored for a set of row ids.
///
/// Runs `SELECT embedding FROM embeddings WHERE rowid IN <query>` with `tokenIds` bound as
/// statement parameters and returns the element-wise mean of every row that could be decoded
/// into a vector of `vectorSize` floats.
///
/// - Parameters:
///   - db: Database to run the lookup against.
///   - query: SQL fragment spliced verbatim after `IN` — presumably a placeholder list such as
///     `(?, ?, ?)` with one slot per entry of `tokenIds` (confirm against callers). It is not
///     escaped, so it must never carry untrusted text.
///   - tokenIds: Values bound to the statement's parameters.
///   - vectorSize: Expected number of floats per embedding; also the length of the result.
/// - Returns: The mean vector (length `vectorSize`). When no usable row is found the result is
///   all zeros rather than NaN.
/// - Throws: Any error raised by `db.query`.
@discardableResult
public func queryEmbeddings(db: Database, query: String, tokenIds: [Int], vectorSize: Int = 1024) async throws -> [Float] {
    let rows = try await db.query("SELECT embedding FROM embeddings WHERE rowid IN \(query)", params: tokenIds)
    var sum = [Float](repeating: 0, count: vectorSize)
    // Number of rows actually folded into `sum`; this — not `rows.count` — is the divisor,
    // otherwise every skipped row would drag the mean towards zero.
    var usedRowCount = 0
    for row in rows {
        guard let blob = row["embedding"] as? Data else {
            continue
        }
        let vector: [Float] = blob.toArray()
        // vDSP.add expects both operands to have the same length; treat a row of the wrong
        // dimension like any other undecodable row instead of handing it to vDSP.
        guard vector.count == vectorSize else {
            continue
        }
        sum = vDSP.add(vector, sum)
        usedRowCount += 1
    }
    // Nothing accumulated: dividing by zero would fill the result with NaN.
    guard usedRowCount > 0 else {
        return sum
    }
    return vDSP.divide(sum, Float(usedRowCount))
}

/// Gathers the entries of `embeddings` selected by `tokenIds` along axis 0 and averages them.
///
/// - Parameters:
///   - embeddings: Tensor to read from; presumably shaped `(vocabulary, dimension)` — confirm
///     with the caller.
///   - tokenIds: Indices along axis 0 of `embeddings` to gather.
/// - Returns: The scalars of the gathered tensor after taking the mean along axis 0.
@discardableResult
public func queryEmbeddings(embeddings: MLTensor, tokenIds: [Int32]) async -> [Float] {
    // One-dimensional index tensor holding the requested ids.
    let rowSelector = MLTensor(shape: [tokenIds.count], scalars: tokenIds)
    let selectedRows = embeddings.gathering(atIndices: rowSelector, alongAxis: 0)
    let averaged = selectedRows.mean(alongAxes: 0)
    // Reading the values out of the tensor is the only asynchronous step.
    let materialized = await averaged.shapedArray(of: Float.self)
    return materialized.scalars
}