import gptTokenizer from "gpt-tokenizer";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import prettyMilliseconds from "pretty-ms";
import { addLogEntry } from "./logEntries";
import {
  getQuery,
  getSettings,
  getTextGenerationState,
  listenToSettingsChanges,
  updateResponse,
  updateSearchPromise,
  updateSearchResults,
  updateSearchState,
  updateTextGenerationState,
} from "./pubSub";
import { search } from "./search";
import { getSystemPrompt } from "./systemPrompt";
import {
  ChatGenerationError,
  getFormattedSearchResults,
} from "./textGenerationUtilities";
import { isWebGPUAvailable } from "./webGpu";

/** Token budget for the chat prompt sent to the model (system prompt + history). */
const maxChatPromptTokens = 1280;

/** Queries longer than this (in characters) are reduced to keywords before searching. */
const maxDirectSearchQueryLength = 2000;

/**
 * Starts a search for the current query and, when AI responses are enabled,
 * generates a streamed answer with the inference backend chosen in settings.
 *
 * Resets the previous response and search results, kicks off the search in the
 * background, then dispatches to the OpenAI, internal-API, WebLLM (WebGPU) or
 * Wllama implementation. Total generation time is logged; errors move the
 * generation state to "failed" unless the user interrupted generation.
 */
export async function searchAndRespond() {
  if (getQuery() === "") return;

  document.title = getQuery();

  updateResponse("");
  updateSearchResults({ textResults: [], imageResults: [] });
  updateSearchPromise(startSearch(getQuery()));

  if (!getSettings().enableAiResponse) return;

  const responseGenerationStartTime = Date.now();

  updateTextGenerationState("loadingModel");

  try {
    const settings = getSettings();

    if (settings.inferenceType === "openai") {
      const { generateTextWithOpenAi } = await import(
        "./textGenerationWithOpenAi"
      );
      await generateTextWithOpenAi();
    } else if (settings.inferenceType === "internal") {
      const { generateTextWithInternalApi } = await import(
        "./textGenerationWithInternalApi"
      );
      await generateTextWithInternalApi();
    } else {
      // Local models may need to be downloaded first — wait for the user's
      // permission before proceeding.
      await canDownloadModels();
      updateTextGenerationState("loadingModel");

      if (isWebGPUAvailable && settings.enableWebGpu) {
        const { generateTextWithWebLlm } = await import(
          "./textGenerationWithWebLlm"
        );
        await generateTextWithWebLlm();
      } else {
        const { generateTextWithWllama } = await import(
          "./textGenerationWithWllama"
        );
        await generateTextWithWllama();
      }
    }

    updateTextGenerationState("completed");
  } catch (error) {
    // An interruption is user-initiated; only genuine failures are logged and
    // reflected in the generation state.
    if (getTextGenerationState() !== "interrupted") {
      addLogEntry(`Error generating text: ${error}`);
      updateTextGenerationState("failed");
    }
  }

  addLogEntry(
    `Response generation took ${prettyMilliseconds(
      Date.now() - responseGenerationStartTime,
      { verbose: true },
    )}`,
  );
}

/**
 * Generates a follow-up chat response for the given conversation.
 *
 * Prepends the system prompt (with formatted search results) and a fixed
 * acknowledgment, trims the history to the most recent messages that fit
 * within {@link maxChatPromptTokens}, then dispatches to the backend selected
 * in settings.
 *
 * @param newMessages - Conversation messages to respond to, oldest first.
 * @param onUpdate - Called with the partial response as tokens stream in.
 * @returns The complete generated response text.
 * @throws ChatGenerationError when generation is interrupted; rethrows any
 *   other backend error after logging it.
 */
export async function generateChatResponse(
  newMessages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const settings = getSettings();
  let response = "";

  try {
    const allMessages: ChatMessage[] = [
      {
        role: "user",
        content: getSystemPrompt(getFormattedSearchResults(true)),
      },
      { role: "assistant", content: "Ok!" },
      ...newMessages,
    ];

    // Walk the history newest-first (over a copy, so allMessages is not
    // mutated) and keep as many recent messages as fit in the token budget.
    const lastMessagesReversed: ChatMessage[] = [];
    let totalTokens = 0;

    for (const message of [...allMessages].reverse()) {
      const newTotalTokens =
        totalTokens + gptTokenizer.encode(message.content).length;

      if (newTotalTokens > maxChatPromptTokens) break;

      totalTokens = newTotalTokens;
      lastMessagesReversed.push(message);
    }

    // Restore chronological (oldest-first) order for the backend.
    const lastMessages = lastMessagesReversed.reverse();

    if (settings.inferenceType === "openai") {
      const { generateChatWithOpenAi } = await import(
        "./textGenerationWithOpenAi"
      );
      response = await generateChatWithOpenAi(lastMessages, onUpdate);
    } else if (settings.inferenceType === "internal") {
      const { generateChatWithInternalApi } = await import(
        "./textGenerationWithInternalApi"
      );
      response = await generateChatWithInternalApi(lastMessages, onUpdate);
    } else {
      if (isWebGPUAvailable && settings.enableWebGpu) {
        const { generateChatWithWebLlm } = await import(
          "./textGenerationWithWebLlm"
        );
        response = await generateChatWithWebLlm(lastMessages, onUpdate);
      } else {
        const { generateChatWithWllama } = await import(
          "./textGenerationWithWllama"
        );
        response = await generateChatWithWllama(lastMessages, onUpdate);
      }
    }
  } catch (error) {
    if (error instanceof ChatGenerationError) {
      addLogEntry(`Chat generation interrupted: ${error.message}`);
    } else {
      addLogEntry(`Error generating chat response: ${error}`);
    }

    throw error;
  }

  return response;
}

/**
 * Extracts up to `limit` English keywords from `text`.
 * The keyword-extractor dependency is imported lazily to keep it off the
 * critical path.
 */
async function getKeywords(text: string, limit?: number) {
  return (await import("keyword-extractor")).default
    .extract(text, { language: "english" })
    .slice(0, limit);
}

/**
 * Runs the search for `query`, updating search state and results via pubSub.
 *
 * Very long queries are first reduced to keywords; if the initial search
 * returns no text results, a second attempt is made with a smaller keyword
 * set. State ends at "completed" when any text result is found, otherwise
 * "failed".
 *
 * @returns The final search results (possibly empty).
 */
async function startSearch(query: string) {
  updateSearchState("running");

  const searchResultsLimit = 15;

  let searchResults = await search(
    query.length > maxDirectSearchQueryLength
      ? (await getKeywords(query, 20)).join(" ")
      : query,
    searchResultsLimit,
  );

  // Fallback: retry with a tighter keyword extraction when nothing was found.
  if (searchResults.textResults.length === 0) {
    const queryKeywords = await getKeywords(query, 10);
    searchResults = await search(queryKeywords.join(" "), searchResultsLimit);
  }

  updateSearchState(
    searchResults.textResults.length === 0 ? "failed" : "completed",
  );

  updateSearchResults(searchResults);

  return searchResults;
}

/**
 * Resolves immediately if the user already allowed AI model downloads;
 * otherwise moves the generation state to "awaitingModelDownloadAllowance"
 * and resolves once the setting is enabled.
 *
 * NOTE(review): the settings listener is never unsubscribed after resolving —
 * if listenToSettingsChanges returns an unsubscribe handle, consider using it
 * here to avoid a lingering listener.
 */
function canDownloadModels(): Promise<void> {
  return new Promise((resolve) => {
    if (getSettings().allowAiModelDownload) {
      resolve();
    } else {
      updateTextGenerationState("awaitingModelDownloadAllowance");
      listenToSettingsChanges((settings) => {
        if (settings.allowAiModelDownload) {
          resolve();
        }
      });
    }
  });
}