// client/modules/textGeneration.ts
// Synced from https://github.com/felladrin/MiniSearch
import gptTokenizer from "gpt-tokenizer";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import prettyMilliseconds from "pretty-ms";
import { addLogEntry } from "./logEntries";
import {
getQuery,
getSettings,
getTextGenerationState,
listenToSettingsChanges,
updateImageSearchResults,
updateImageSearchState,
updateResponse,
updateSearchPromise,
updateTextGenerationState,
updateTextSearchResults,
updateTextSearchState,
} from "./pubSub";
import { searchImages, searchText } from "./search";
import { getSystemPrompt } from "./systemPrompt";
import {
ChatGenerationError,
defaultContextSize,
getFormattedSearchResults,
} from "./textGenerationUtilities";
import type { ImageSearchResults, TextSearchResults } from "./types";
import { isWebGPUAvailable } from "./webGpu";
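/**
 * Runs a full search-and-answer cycle for the current query: kicks off the
 * text/image searches, then generates an AI response using the inference
 * backend selected in settings (OpenAI-compatible, internal API, AI Horde,
 * or an in-browser model). Each backend module is imported lazily so only
 * the selected one is loaded.
 */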
export async function searchAndRespond() {
if (getQuery() === "") return;
document.title = getQuery();
updateResponse("");
updateTextSearchResults([]);
updateImageSearchResults([]);
updateSearchPromise(startTextSearch(getQuery()));
if (!getSettings().enableAiResponse) return;
  const responseGenerationStartTime = Date.now();
try {
const settings = getSettings();
if (settings.inferenceType === "openai") {
const { generateTextWithOpenAi } = await import(
"./textGenerationWithOpenAi"
);
await generateTextWithOpenAi();
} else if (settings.inferenceType === "internal") {
const { generateTextWithInternalApi } = await import(
"./textGenerationWithInternalApi"
);
await generateTextWithInternalApi();
} else if (settings.inferenceType === "horde") {
const { generateTextWithHorde } = await import(
"./textGenerationWithHorde"
);
await generateTextWithHorde();
} else {
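      // In-browser inference: wait for model-download consent, then prefer
      // WebLLM when WebGPU is available and enabled, falling back to Wllama.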
await canDownloadModels();
updateTextGenerationState("loadingModel");
if (isWebGPUAvailable && settings.enableWebGpu) {
const { generateTextWithWebLlm } = await import(
"./textGenerationWithWebLlm"
);
await generateTextWithWebLlm();
} else {
const { generateTextWithWllama } = await import(
"./textGenerationWithWllama"
);
await generateTextWithWllama();
}
}
updateTextGenerationState("completed");
} catch (error) {
if (getTextGenerationState() !== "interrupted") {
addLogEntry(`Error generating text: ${error}`);
updateTextGenerationState("failed");
}
}
addLogEntry(
`Response generation took ${prettyMilliseconds(
      Date.now() - responseGenerationStartTime,
{ verbose: true },
)}`,
);
}
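/**
 * Generates a follow-up chat response for the given messages, prepending the
 * system prompt and trimming older messages to fit the context budget.
 * Streams partial output through `onUpdate` and returns the full response.
 */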
export async function generateChatResponse(
newMessages: ChatMessage[],
onUpdate: (partialResponse: string) => void,
) {
const settings = getSettings();
let response = "";
try {
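    // Prepend the system prompt (with formatted search results) as a primed
    // user/assistant exchange ahead of the conversation messages.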
const allMessages: ChatMessage[] = [
{
role: "user",
content: getSystemPrompt(getFormattedSearchResults(true)),
},
{ role: "assistant", content: "Ok!" },
...newMessages,
];
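    // Walk the messages newest-first, accumulating token counts until the
    // budget (60% of the default context size) would be exceeded.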
const lastMessagesReversed: ChatMessage[] = [];
let totalTokens = 0;
    // Iterate over a copy so the original message array is not mutated.
    for (const message of [...allMessages].reverse()) {
const newTotalTokens =
totalTokens + gptTokenizer.encode(message.content).length;
if (newTotalTokens > defaultContextSize * 0.6) break;
totalTokens = newTotalTokens;
lastMessagesReversed.push(message);
}
const lastMessages = lastMessagesReversed.reverse();
if (settings.inferenceType === "openai") {
const { generateChatWithOpenAi } = await import(
"./textGenerationWithOpenAi"
);
response = await generateChatWithOpenAi(lastMessages, onUpdate);
} else if (settings.inferenceType === "internal") {
const { generateChatWithInternalApi } = await import(
"./textGenerationWithInternalApi"
);
response = await generateChatWithInternalApi(lastMessages, onUpdate);
} else if (settings.inferenceType === "horde") {
const { generateChatWithHorde } = await import(
"./textGenerationWithHorde"
);
response = await generateChatWithHorde(lastMessages, onUpdate);
} else {
if (isWebGPUAvailable && settings.enableWebGpu) {
const { generateChatWithWebLlm } = await import(
"./textGenerationWithWebLlm"
);
response = await generateChatWithWebLlm(lastMessages, onUpdate);
} else {
const { generateChatWithWllama } = await import(
"./textGenerationWithWllama"
);
response = await generateChatWithWllama(lastMessages, onUpdate);
}
}
} catch (error) {
if (error instanceof ChatGenerationError) {
addLogEntry(`Chat generation interrupted: ${error.message}`);
} else {
addLogEntry(`Error generating chat response: ${error}`);
}
throw error;
}
return response;
}
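/** Extracts up to `limit` English keywords from the given text. */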
async function getKeywords(text: string, limit?: number) {
return (await import("keyword-extractor")).default
.extract(text, { language: "english" })
.slice(0, limit);
}
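/**
 * Starts the text search (and, if enabled, the image search) for the query,
 * publishing progress and results through the pub/sub state. Returns the
 * results object, which the image search continues to populate in the
 * background.
 */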
async function startTextSearch(query: string) {
const results = {
textResults: [] as TextSearchResults,
imageResults: [] as ImageSearchResults,
};
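  // Very long queries are condensed to their top keywords before searching.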
const searchQuery =
query.length > 2000 ? (await getKeywords(query, 20)).join(" ") : query;
if (getSettings().enableImageSearch) {
updateImageSearchState("running");
}
if (getSettings().enableTextSearch) {
updateTextSearchState("running");
let textResults = await searchText(
searchQuery,
getSettings().searchResultsLimit,
);
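    // If the full query yields nothing, retry with extracted keywords only.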
if (textResults.length === 0) {
const queryKeywords = await getKeywords(query, 10);
const keywordResults = await searchText(
queryKeywords.join(" "),
getSettings().searchResultsLimit,
);
textResults = keywordResults;
}
results.textResults = textResults;
updateTextSearchState(
results.textResults.length === 0 ? "failed" : "completed",
);
updateTextSearchResults(textResults);
}
if (getSettings().enableImageSearch) {
startImageSearch(searchQuery, results);
}
return results;
}
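/**
 * Runs the image search and publishes its state and results, mutating the
 * shared results object so callers holding the search promise see them.
 */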
async function startImageSearch(
searchQuery: string,
results: { textResults: TextSearchResults; imageResults: ImageSearchResults },
) {
const imageResults = await searchImages(
searchQuery,
getSettings().searchResultsLimit,
);
results.imageResults = imageResults;
updateImageSearchState(
results.imageResults.length === 0 ? "failed" : "completed",
);
updateImageSearchResults(imageResults);
}
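/**
 * Resolves immediately when model downloads are allowed; otherwise marks the
 * generation state as awaiting permission and resolves once the user enables
 * downloads in the settings.
 */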
function canDownloadModels(): Promise<void> {
return new Promise((resolve) => {
if (getSettings().allowAiModelDownload) {
resolve();
} else {
updateTextGenerationState("awaitingModelDownloadAllowance");
listenToSettingsChanges((settings) => {
if (settings.allowAiModelDownload) {
resolve();
}
});
}
});
}