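// Text generation backed by Wllama (llama.cpp compiled to WebAssembly),
// which runs GGUF models entirely in the browser.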
import {
  getQuery,
  getSettings,
  getTextGenerationState,
  updateModelLoadingProgress,
  updateModelSizeInMegabytes,
  updateResponse,
  updateTextGenerationState,
} from "./pubSub";
import {
  ChatGenerationError,
  canStartResponding,
  getFormattedSearchResults,
} from "./textGenerationUtilities";

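/**
 * Generates an answer to the current search query with Wllama, streaming
 * partial text via `updateResponse` and publishing the final text when
 * generation finishes. No-op when AI responses are disabled in settings.
 */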
export async function generateTextWithWllama() {
  if (!getSettings().enableAiResponse) return;

  const response = await generateWithWllama(getQuery(), updateResponse, true);

  updateResponse(response);
}

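/**
 * Generates a chat reply with Wllama. Note that only the content of the
 * latest message is passed to the model as input.
 *
 * Illustrative call (message shape follows gpt-tokenizer's ChatMessage):
 *
 *   await generateChatWithWllama(
 *     [{ role: "user", content: "Summarize the search results." }],
 *     (partial) => console.log(partial),
 *   );
 */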
export async function generateChatWithWllama(
  messages: import("gpt-tokenizer/GptEncoding").ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  return generateWithWllama(
    messages[messages.length - 1].content,
    onUpdate,
    false,
  );
}

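/**
 * Loads the Wllama model currently selected in the settings and returns
 * the initialized instance along with its model descriptor. The optional
 * callback receives download progress as `{ loaded, total }` byte counts.
 */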
async function initializeWllamaInstance(
  progressCallback?: ({
    loaded,
    total,
  }: {
    loaded: number;
    total: number;
  }) => void,
) {
  const { initializeWllama, wllamaModels } = await import("./wllama");
  const model = wllamaModels[getSettings().wllamaModelId];

  updateModelSizeInMegabytes(model.fileSizeInMegabytes);

  const wllama = await initializeWllama(model.url, {
    wllama: {
      suppressNativeLog: true,
    },
    model: {
      n_threads: getSettings().cpuThreads,
      n_ctx: model.contextSize,
      cache_type_k: model.cacheType,
      embeddings: false,
      allowOffline: true,
      progressCallback,
    },
  });

  return { wllama, model };
}

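/**
 * Shared generation routine: initializes the model, builds the prompt from
 * the input plus formatted search results, streams the completion through
 * `onUpdate`, and returns the final text. When `shouldCheckCanRespond` is
 * true, it also reports loading progress and drives the global text
 * generation state, honoring user interruption.
 */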
async function generateWithWllama(
  input: string,
  onUpdate: (partialResponse: string) => void,
  shouldCheckCanRespond = false,
) {
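  // Track the last reported percentage so subscribers are only notified
  // when the rounded progress value actually changes.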
  let loadingPercentage = 0;

  const { wllama, model } = await initializeWllamaInstance(
    shouldCheckCanRespond
      ? ({ loaded, total }) => {
          const progressPercentage = Math.round((loaded / total) * 100);
          if (loadingPercentage !== progressPercentage) {
            loadingPercentage = progressPercentage;
            updateModelLoadingProgress(progressPercentage);
          }
        }
      : undefined,
  );

  if (shouldCheckCanRespond) {
    await canStartResponding();
    updateTextGenerationState("preparingToGenerate");
  }

  const prompt = await model.buildPrompt(
    wllama,
    input,
    getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt),
  );

  let streamedMessage = "";

  try {
    await wllama.createCompletion(prompt, {
      nPredict: 2048,
      stopTokens: model.stopTokens,
      sampling: model.sampling,
      onNewToken: (_token, _piece, currentText, { abortSignal }) => {
        // Abort mid-stream as soon as the user interrupts generation.
        if (
          shouldCheckCanRespond &&
          getTextGenerationState() === "interrupted"
        ) {
          abortSignal();
          throw new ChatGenerationError("Chat generation interrupted");
        }

        if (
          shouldCheckCanRespond &&
          getTextGenerationState() !== "generating"
        ) {
          updateTextGenerationState("generating");
        }

        streamedMessage = handleWllamaCompletion(
          model,
          currentText,
          abortSignal,
          onUpdate,
        );
      },
    });
  } finally {
    // Release the model and its resources even when generation is
    // interrupted or fails, so repeated runs don't leak memory.
    await wllama.exit();
  }

  return streamedMessage;
}

function handleWllamaCompletion(
  model: import("./wllama").WllamaModel,
  currentText: string,
  abortSignal: () => void,
  onUpdate: (text: string) => void,
) {
  let text = currentText;

  if (model.stopStrings) {
    for (const stopString of model.stopStrings) {
      // Only the tail needs checking: a stop string can only have appeared
      // among the most recently streamed characters. The window is twice
      // the stop string's length to cover one that completed a few
      // characters before the current end of the text.
      if (text.slice(-(stopString.length * 2)).includes(stopString)) {
        abortSignal();
        // Cut at the stop string itself rather than a fixed number of
        // characters from the end, so residue after it is also removed.
        text = text.slice(0, text.lastIndexOf(stopString));
        break;
      }
    }
  }

  onUpdate(text);
  return text;
}