import { isWebGPUAvailable } from "./webGpu"; import { updateSearchResults, updateResponse, getSearchResults, getQuery, updateSearchPromise, getSearchPromise, updateTextGenerationState, updateSearchState, updateModelLoadingProgress, getTextGenerationState, getSettings, listenToSettingsChanges, } from "./pubSub"; import { search } from "./search"; import { addLogEntry } from "./logEntries"; import { getSystemPrompt } from "./systemPrompt"; import prettyMilliseconds from "pretty-ms"; import { getOpenAiClient } from "./openai"; import { getSearchTokenHash } from "./searchTokenHash"; import { ChatCompletionCreateParamsStreaming, ChatCompletionMessageParam, } from "openai/resources/chat/completions.mjs"; import gptTokenizer from "gpt-tokenizer"; import { ChatOptions } from "@mlc-ai/web-llm"; import { defaultSettings } from "./settings"; export async function searchAndRespond() { if (getQuery() === "") return; document.title = getQuery(); updateResponse(""); updateSearchResults({ textResults: [], imageResults: [] }); updateSearchPromise(startSearch(getQuery())); if (!getSettings().enableAiResponse) return; const responseGenerationStartTime = new Date().getTime(); updateTextGenerationState("loadingModel"); try { const settings = getSettings(); if (settings.inferenceType === "openai") { await generateTextWithOpenAI(); } else if (settings.inferenceType === "internal") { await generateTextWithInternalApi(); } else { await canDownloadModels(); updateTextGenerationState("loadingModel"); try { if (!isWebGPUAvailable) throw Error("WebGPU is not available."); if (!settings.enableWebGpu) throw Error("WebGPU is disabled."); await generateTextWithWebLlm(); } catch (error) { addLogEntry(`Skipping text generation with WebLLM: ${error}`); addLogEntry(`Starting text generation with Wllama`); await generateTextWithWllama(); } } if (getTextGenerationState() !== "interrupted") { updateTextGenerationState("completed"); } } catch (error) { addLogEntry(`Error generating text: ${error}`); updateTextGenerationState("failed"); } addLogEntry( `Response generation took ${prettyMilliseconds( new Date().getTime() - responseGenerationStartTime, { verbose: true }, )}`, ); } async function generateTextWithOpenAI() { const settings = getSettings(); const openai = getOpenAiClient({ baseURL: settings.openAiApiBaseUrl, apiKey: settings.openAiApiKey, }); await canStartResponding(); updateTextGenerationState("preparingToGenerate"); const completion = await openai.chat.completions.create({ ...getDefaultChatCompletionCreateParamsStreaming(), model: settings.openAiApiModel, messages: [ { role: "user", content: getSystemPrompt(getFormattedSearchResults(true)), }, { role: "assistant", content: "Ok!" }, { role: "user", content: getQuery() }, ], }); let streamedMessage = ""; for await (const chunk of completion) { const deltaContent = chunk.choices[0].delta.content; if (deltaContent) streamedMessage += deltaContent; if (getTextGenerationState() === "interrupted") { completion.controller.abort(); } else if (getTextGenerationState() !== "generating") { updateTextGenerationState("generating"); } updateResponseRateLimited(streamedMessage); } updateResponse(streamedMessage); } async function generateTextWithInternalApi() { await canStartResponding(); updateTextGenerationState("preparingToGenerate"); const inferenceUrl = new URL("/inference", self.location.origin); const tokenPrefix = "Bearer "; const token = await getSearchTokenHash(); const response = await fetch(inferenceUrl.toString(), { method: "POST", headers: { "Content-Type": "application/json", Authorization: `${tokenPrefix}${token}`, }, body: JSON.stringify({ ...getDefaultChatCompletionCreateParamsStreaming(), messages: [ { role: "user", content: getSystemPrompt(getFormattedSearchResults(true)), }, { role: "assistant", content: "Ok!" }, { role: "user", content: getQuery() }, ], } as ChatCompletionCreateParamsStreaming), }); if (!response.ok || !response.body) { throw new Error(`HTTP error! status: ${response.status}`); } const reader = response.body.getReader(); const decoder = new TextDecoder("utf-8"); let streamedMessage = ""; while (true) { const { done, value } = await reader.read(); if (done) break; const chunk = decoder.decode(value); const lines = chunk.split("\n"); const parsedLines = lines .map((line) => line.replace(/^data: /, "").trim()) .filter((line) => line !== "" && line !== "[DONE]") .map((line) => JSON.parse(line)); for (const parsedLine of parsedLines) { const deltaContent = parsedLine.choices[0].delta.content; if (deltaContent) streamedMessage += deltaContent; if (getTextGenerationState() === "interrupted") { reader.cancel(); } else if (getTextGenerationState() !== "generating") { updateTextGenerationState("generating"); } updateResponseRateLimited(streamedMessage); } } updateResponse(streamedMessage); } async function generateTextWithWebLlm() { const { CreateWebWorkerMLCEngine, CreateMLCEngine, hasModelInCache } = await import("@mlc-ai/web-llm"); type InitProgressCallback = import("@mlc-ai/web-llm").InitProgressCallback; type MLCEngineConfig = import("@mlc-ai/web-llm").MLCEngineConfig; type ChatOptions = import("@mlc-ai/web-llm").ChatOptions; const selectedModelId = getSettings().webLlmModelId; addLogEntry(`Selected WebLLM model: ${selectedModelId}`); const isModelCached = await hasModelInCache(selectedModelId); let initProgressCallback: InitProgressCallback | undefined; if (isModelCached) { updateTextGenerationState("preparingToGenerate"); } else { initProgressCallback = (report) => { updateModelLoadingProgress(Math.round(report.progress * 100)); }; } const engineConfig: MLCEngineConfig = { initProgressCallback, logLevel: "SILENT", }; const chatOptions: ChatOptions = { repetition_penalty: defaultSettings.inferenceRepeatPenalty, }; const engine = Worker ? await CreateWebWorkerMLCEngine( new Worker(new URL("./webLlmWorker.ts", import.meta.url), { type: "module", }), selectedModelId, engineConfig, chatOptions, ) : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions); if (getSettings().enableAiResponse) { await canStartResponding(); updateTextGenerationState("preparingToGenerate"); const completion = await engine.chat.completions.create({ ...getDefaultChatCompletionCreateParamsStreaming(), messages: [ { role: "user", content: getSystemPrompt(getFormattedSearchResults(true)), }, { role: "assistant", content: "Ok!" }, { role: "user", content: getQuery() }, ], }); let streamedMessage = ""; for await (const chunk of completion) { const deltaContent = chunk.choices[0].delta.content; if (deltaContent) streamedMessage += deltaContent; if (getTextGenerationState() === "interrupted") { await engine.interruptGenerate(); } else if (getTextGenerationState() !== "generating") { updateTextGenerationState("generating"); } updateResponseRateLimited(streamedMessage); } updateResponse(streamedMessage); } addLogEntry( `WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`, ); engine.unload(); } async function generateTextWithWllama() { const { initializeWllama, wllamaModels } = await import("./wllama"); let loadingPercentage = 0; const model = wllamaModels[getSettings().wllamaModelId]; const wllama = await initializeWllama(model.url, { wllama: { suppressNativeLog: true, }, model: { n_threads: getSettings().cpuThreads, n_ctx: model.contextSize, cache_type_k: model.cacheType, embeddings: false, allowOffline: true, progressCallback: ({ loaded, total }) => { const progressPercentage = Math.round((loaded / total) * 100); if (loadingPercentage !== progressPercentage) { loadingPercentage = progressPercentage; updateModelLoadingProgress(progressPercentage); } }, }, }); if (getSettings().enableAiResponse) { await canStartResponding(); updateTextGenerationState("preparingToGenerate"); const prompt = await model.buildPrompt( wllama, getQuery(), getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt), ); let streamedMessage = ""; await wllama.createCompletion(prompt, { stopTokens: model.stopTokens, sampling: model.sampling, onNewToken: (_token, _piece, currentText, { abortSignal }) => { if (getTextGenerationState() === "interrupted") { abortSignal(); } else if (getTextGenerationState() !== "generating") { updateTextGenerationState("generating"); } if (model.stopStrings) { for (const stopString of model.stopStrings) { if ( currentText.slice(-(stopString.length * 2)).includes(stopString) ) { abortSignal(); currentText = currentText.slice(0, -stopString.length); break; } } } streamedMessage = currentText; updateResponseRateLimited(streamedMessage); }, }); updateResponse(streamedMessage); } await wllama.exit(); } function getFormattedSearchResults(shouldIncludeUrl: boolean) { const searchResults = getSearchResults().textResults.slice( 0, getSettings().searchResultsToConsider, ); if (searchResults.length === 0) return "None."; if (shouldIncludeUrl) { return searchResults .map( ([title, snippet, url], index) => `${index + 1}. [${title}](${url}) | ${snippet}`, ) .join("\n"); } return searchResults .map(([title, snippet]) => `- ${title} | ${snippet}`) .join("\n"); } async function getKeywords(text: string, limit?: number) { return (await import("keyword-extractor")).default .extract(text, { language: "english" }) .slice(0, limit); } async function startSearch(query: string) { updateSearchState("running"); let searchResults = await search( query.length > 2000 ? (await getKeywords(query, 20)).join(" ") : query, 30, ); if (searchResults.textResults.length === 0) { const queryKeywords = await getKeywords(query, 10); searchResults = await search(queryKeywords.join(" "), 30); } updateSearchState( searchResults.textResults.length === 0 ? "failed" : "completed", ); updateSearchResults(searchResults); return searchResults; } async function canStartResponding() { if (getSettings().searchResultsToConsider > 0) { updateTextGenerationState("awaitingSearchResults"); await getSearchPromise(); } } function updateResponseRateLimited(text: string) { const currentTime = Date.now(); if ( currentTime - updateResponseRateLimited.lastUpdateTime >= updateResponseRateLimited.updateInterval ) { updateResponse(text); updateResponseRateLimited.lastUpdateTime = currentTime; } } updateResponseRateLimited.lastUpdateTime = 0; updateResponseRateLimited.updateInterval = 1000 / 12; class ChatGenerationError extends Error { constructor(message: string) { super(message); this.name = "ChatGenerationError"; } } async function generateChatWithOpenAI( messages: ChatMessage[], onUpdate: (partialResponse: string) => void, ) { const settings = getSettings(); const openai = getOpenAiClient({ baseURL: settings.openAiApiBaseUrl, apiKey: settings.openAiApiKey, }); const completion = await openai.chat.completions.create({ ...getDefaultChatCompletionCreateParamsStreaming(), model: settings.openAiApiModel, messages: messages as ChatCompletionMessageParam[], }); let streamedMessage = ""; for await (const chunk of completion) { const deltaContent = chunk.choices[0].delta.content; if (deltaContent) { streamedMessage += deltaContent; onUpdate(streamedMessage); } if (getTextGenerationState() === "interrupted") { completion.controller.abort(); throw new ChatGenerationError("Chat generation interrupted"); } } return streamedMessage; } async function generateChatWithInternalApi( messages: ChatMessage[], onUpdate: (partialResponse: string) => void, ) { const inferenceUrl = new URL("/inference", self.location.origin); const tokenPrefix = "Bearer "; const token = await getSearchTokenHash(); const response = await fetch(inferenceUrl.toString(), { method: "POST", headers: { "Content-Type": "application/json", Authorization: `${tokenPrefix}${token}`, }, body: JSON.stringify({ ...getDefaultChatCompletionCreateParamsStreaming(), messages, } as ChatCompletionCreateParamsStreaming), }); if (!response.ok || !response.body) { throw new Error(`HTTP error! status: ${response.status}`); } const reader = response.body.getReader(); const decoder = new TextDecoder("utf-8"); let streamedMessage = ""; while (true) { const { done, value } = await reader.read(); if (done) break; const chunk = decoder.decode(value); const lines = chunk.split("\n"); const parsedLines = lines .map((line) => line.replace(/^data: /, "").trim()) .filter((line) => line !== "" && line !== "[DONE]") .map((line) => JSON.parse(line)); for (const parsedLine of parsedLines) { const deltaContent = parsedLine.choices[0].delta.content; if (deltaContent) { streamedMessage += deltaContent; onUpdate(streamedMessage); } if (getTextGenerationState() === "interrupted") { reader.cancel(); throw new ChatGenerationError("Chat generation interrupted"); } } } return streamedMessage; } async function generateChatWithWebLlm( messages: ChatMessage[], onUpdate: (partialResponse: string) => void, ) { const { CreateWebWorkerMLCEngine, CreateMLCEngine } = await import( "@mlc-ai/web-llm" ); type MLCEngineConfig = import("@mlc-ai/web-llm").MLCEngineConfig; type ChatCompletionMessageParam = import("@mlc-ai/web-llm").ChatCompletionMessageParam; const selectedModelId = getSettings().webLlmModelId; addLogEntry(`Selected WebLLM model for chat: ${selectedModelId}`); const engineConfig: MLCEngineConfig = { logLevel: "SILENT", }; const chatOptions: ChatOptions = { repetition_penalty: defaultSettings.inferenceRepeatPenalty, }; const engine = Worker ? await CreateWebWorkerMLCEngine( new Worker(new URL("./webLlmWorker.ts", import.meta.url), { type: "module", }), selectedModelId, engineConfig, chatOptions, ) : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions); const completion = await engine.chat.completions.create({ ...getDefaultChatCompletionCreateParamsStreaming(), messages: messages as ChatCompletionMessageParam[], }); let streamedMessage = ""; for await (const chunk of completion) { const deltaContent = chunk.choices[0].delta.content; if (deltaContent) { streamedMessage += deltaContent; onUpdate(streamedMessage); } if (getTextGenerationState() === "interrupted") { await engine.interruptGenerate(); throw new ChatGenerationError("Chat generation interrupted"); } } addLogEntry( `WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`, ); engine.unload(); return streamedMessage; } async function generateChatWithWllama( messages: ChatMessage[], onUpdate: (partialResponse: string) => void, ) { const { initializeWllama, wllamaModels } = await import("./wllama"); const model = wllamaModels[getSettings().wllamaModelId]; const wllama = await initializeWllama(model.url, { wllama: { suppressNativeLog: true, }, model: { n_threads: getSettings().cpuThreads, n_ctx: model.contextSize, cache_type_k: model.cacheType, embeddings: false, allowOffline: true, }, }); const prompt = await model.buildPrompt( wllama, messages[messages.length - 1].content, getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt), ); let streamedMessage = ""; await wllama.createCompletion(prompt, { stopTokens: model.stopTokens, sampling: model.sampling, onNewToken: (_token, _piece, currentText, { abortSignal }) => { if (getTextGenerationState() === "interrupted") { abortSignal(); throw new ChatGenerationError("Chat generation interrupted"); } if (model.stopStrings) { for (const stopString of model.stopStrings) { if ( currentText.slice(-(stopString.length * 2)).includes(stopString) ) { abortSignal(); currentText = currentText.slice(0, -stopString.length); break; } } } streamedMessage = currentText; onUpdate(streamedMessage); }, }); await wllama.exit(); return streamedMessage; } export async function generateChatResponse( newMessages: ChatMessage[], onUpdate: (partialResponse: string) => void, ) { const settings = getSettings(); let response = ""; try { const allMessages = [ { role: "user", content: getSystemPrompt(getFormattedSearchResults(true)), }, { role: "assistant", content: "Ok!" }, ...newMessages, ]; const lastMessagesReversed: ChatMessage[] = []; let totalTokens = 0; for (const message of allMessages.reverse()) { const newTotalTokens = totalTokens + gptTokenizer.encode(message.content).length; if (newTotalTokens > 1280) break; totalTokens = newTotalTokens; lastMessagesReversed.push(message); } const lastMessages = lastMessagesReversed.reverse(); if (settings.inferenceType === "openai") { response = await generateChatWithOpenAI(lastMessages, onUpdate); } else if (settings.inferenceType === "internal") { response = await generateChatWithInternalApi(lastMessages, onUpdate); } else { if (isWebGPUAvailable && settings.enableWebGpu) { response = await generateChatWithWebLlm(lastMessages, onUpdate); } else { response = await generateChatWithWllama(lastMessages, onUpdate); } } } catch (error) { if (error instanceof ChatGenerationError) { addLogEntry(`Chat generation interrupted: ${error.message}`); } else { addLogEntry(`Error generating chat response: ${error}`); } throw error; } return response; } export interface ChatMessage { role: "user" | "assistant" | string; content: string; } function canDownloadModels(): Promise { return new Promise((resolve) => { if (getSettings().allowAiModelDownload) { resolve(); } else { updateTextGenerationState("awaitingModelDownloadAllowance"); listenToSettingsChanges((settings) => { if (settings.allowAiModelDownload) { resolve(); } }); } }); } function getDefaultChatCompletionCreateParamsStreaming() { return { stream: true, max_tokens: 2048, temperature: defaultSettings.inferenceTemperature, top_p: defaultSettings.inferenceTopP, frequency_penalty: defaultSettings.inferenceFrequencyPenalty, presence_penalty: defaultSettings.inferencePresencePenalty, } as const; }