File size: 2,449 Bytes
f152ae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import path from "node:path";
import { fileURLToPath } from "node:url";
import { getLlama } from "node-llama-cpp";
import { downloadFileFromHuggingFaceRepository } from "./downloadFileFromHuggingFaceRepository";

// Kick off model loading (which includes the call to
// downloadFileFromHuggingFaceRepository) as soon as this module is imported,
// so every rankSearchResults call awaits the same shared promise instead of
// loading the model again.
// NOTE(review): if loadModel() rejects before anything awaits this promise,
// Node presumably reports it as an unhandled rejection — consider attaching a
// handler if module import must not be able to crash the process.
const loadModelPromise = loadModel();

/**
 * Re-ranks search results by embedding similarity to the query.
 *
 * The query and each result (rendered as a lowercase Markdown-style link
 * `[title](url "snippet")`) are embedded with the shared model and scored by
 * the dot product of the two embedding vectors. Results whose score is not
 * strictly greater than half of the best score are dropped. The first
 * surviving result keeps the top position; the next five survivors follow,
 * sorted by descending score, and then all remaining survivors, also sorted
 * by descending score.
 *
 * @param query - The search query to compare results against.
 * @param searchResults - `[title, content, url]` tuples in their original order.
 * @returns The filtered, re-ordered results. Returns an empty array when there
 *   is nothing to rank or when no result passes the score threshold.
 */
export async function rankSearchResults(
  query: string,
  searchResults: [title: string, content: string, url: string][],
): Promise<[title: string, content: string, url: string][]> {
  // With no results, Math.max() would be -Infinity and the destructuring of
  // the (missing) first result below would throw, so bail out early.
  if (searchResults.length === 0) return [];

  const model = await loadModelPromise;

  const embeddingContext = await model.createEmbeddingContext();

  // A fresh context is created on every call; always dispose it so its native
  // resources are released, even when embedding throws.
  try {
    const queryEmbedding = (
      await embeddingContext.getEmbeddingFor(query.toLocaleLowerCase())
    ).vector;

    // Double quotes inside the snippet are swapped for single quotes so they
    // do not terminate the quoted "title" part of the Markdown link.
    const documents = searchResults.map(([title, snippet, url]) =>
      `[${title}](${url} "${snippet.replaceAll('"', "'")}")`.toLocaleLowerCase(),
    );

    const scores: number[] = [];

    // Documents are embedded one at a time, in order.
    for (const document of documents) {
      const { vector } = await embeddingContext.getEmbeddingFor(document);
      scores.push(calculateDotProduct(queryEmbedding, vector));
    }

    const highestScore = Math.max(...scores);

    const scoreThreshold = highestScore / 2;

    const relevantResults = searchResults
      .map((result, index) => ({ result, score: scores[index] }))
      .filter(({ score }) => score > scoreThreshold);

    // If every score is negative, half of the best score lies above all of
    // them and nothing survives; avoid destructuring an undefined first entry.
    if (relevantResults.length === 0) return [];

    const [firstResult, ...nextResults] = relevantResults;

    const nextTopResultsCount = 5;

    // The first surviving result stays first; only the following five are
    // sorted among themselves, and the rest are sorted and appended after.
    const nextTopResults = nextResults
      .slice(0, nextTopResultsCount)
      .sort((a, b) => b.score - a.score);

    const remainingResults = nextResults
      .slice(nextTopResultsCount)
      .sort((a, b) => b.score - a.score);

    return [firstResult, ...nextTopResults, ...remainingResults].map(
      ({ result }) => result,
    );
  } finally {
    await embeddingContext.dispose();
  }
}

/**
 * Sums the element-wise products of two numeric vectors.
 *
 * Iteration follows `firstArray`: extra entries in `secondArray` are ignored,
 * and if `secondArray` is shorter the missing entries make the result NaN.
 */
function calculateDotProduct(
  firstArray: readonly number[],
  secondArray: readonly number[],
) {
  let sum = 0;

  for (const [position, value] of firstArray.entries()) {
    sum += value * secondArray[position];
  }

  return sum;
}

/**
 * Obtains a llama.cpp binding, fetches the MiniLM GGUF embedding model from
 * Hugging Face via downloadFileFromHuggingFaceRepository, and loads it.
 *
 * The model file lives under `models/<repo>/<file>` next to this module; the
 * repository id contains a slash, so it becomes two nested directories.
 */
async function loadModel() {
  const repository = "Felladrin/gguf-Q8_0-all-MiniLM-L6-v2";
  const fileName = "all-minilm-l6-v2-q8_0.gguf";

  const moduleDirectory = path.dirname(fileURLToPath(import.meta.url));
  const modelPath = path.resolve(
    moduleDirectory,
    "models",
    repository,
    fileName,
  );

  const llama = await getLlama();

  await downloadFileFromHuggingFaceRepository(repository, fileName, modelPath);

  return llama.loadModel({ modelPath });
}