miniSearch / client /modules /textGeneration.ts
github-actions[bot]
Sync to HuggingFace Spaces
f152ae2
import gptTokenizer from "gpt-tokenizer";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import prettyMilliseconds from "pretty-ms";
import { addLogEntry } from "./logEntries";
import {
getQuery,
getSettings,
getTextGenerationState,
listenToSettingsChanges,
updateResponse,
updateSearchPromise,
updateSearchResults,
updateSearchState,
updateTextGenerationState,
} from "./pubSub";
import { search } from "./search";
import { getSystemPrompt } from "./systemPrompt";
import {
ChatGenerationError,
getFormattedSearchResults,
} from "./textGenerationUtilities";
import { isWebGPUAvailable } from "./webGpu";
/**
 * Runs a web search for the current query and, when AI responses are enabled,
 * generates a text response with the configured inference backend.
 *
 * Does nothing when the current query is an empty string. Otherwise the
 * document title is set to the query, the previous response and search
 * results are cleared, and the search is started without being awaited (its
 * promise is handed to `updateSearchPromise`).
 *
 * Errors thrown during generation are logged and the text-generation state
 * is set to "failed", unless the state is already "interrupted". The elapsed
 * generation time is logged in every case.
 */
export async function searchAndRespond() {
  if (getQuery() === "") return;

  document.title = getQuery();

  // Reset what is left over from the previous query.
  updateResponse("");
  updateSearchResults({ textResults: [], imageResults: [] });

  // The search is intentionally not awaited here; only its promise is stored.
  updateSearchPromise(startSearch(getQuery()));

  if (!getSettings().enableAiResponse) return;

  const startedAt = Date.now();

  updateTextGenerationState("loadingModel");

  try {
    const settings = getSettings();

    switch (settings.inferenceType) {
      case "openai": {
        const { generateTextWithOpenAi } = await import(
          "./textGenerationWithOpenAi"
        );
        await generateTextWithOpenAi();
        break;
      }
      case "internal": {
        const { generateTextWithInternalApi } = await import(
          "./textGenerationWithInternalApi"
        );
        await generateTextWithInternalApi();
        break;
      }
      default: {
        // Remaining inference types need the user's download allowance first.
        await canDownloadModels();

        // Waiting for the allowance may have changed the state; restore it.
        updateTextGenerationState("loadingModel");

        const shouldUseWebLlm = isWebGPUAvailable && settings.enableWebGpu;

        if (shouldUseWebLlm) {
          const { generateTextWithWebLlm } = await import(
            "./textGenerationWithWebLlm"
          );
          await generateTextWithWebLlm();
        } else {
          const { generateTextWithWllama } = await import(
            "./textGenerationWithWllama"
          );
          await generateTextWithWllama();
        }
      }
    }

    updateTextGenerationState("completed");
  } catch (error) {
    // An interruption is not reported as a failure.
    if (getTextGenerationState() !== "interrupted") {
      addLogEntry(`Error generating text: ${error}`);
      updateTextGenerationState("failed");
    }
  }

  const elapsedMilliseconds = Date.now() - startedAt;

  addLogEntry(
    `Response generation took ${prettyMilliseconds(elapsedMilliseconds, {
      verbose: true,
    })}`,
  );
}
/**
 * Generates a chat reply for the given messages using the configured
 * inference backend.
 *
 * The conversation sent to the backend starts with the system prompt (built
 * from the formatted search results) as a "user" message, followed by an
 * "Ok!" assistant message, followed by `newMessages`. Walking from the newest
 * message backwards, messages are kept while their cumulative token count (as
 * measured by `gptTokenizer.encode`) stays within 1280; the first message
 * that would exceed it, and everything older, is dropped.
 *
 * @param newMessages - Chat messages to append after the system prompt pair.
 * @param onUpdate - Callback forwarded to the backend's chat generator.
 * @returns The response produced by the selected backend.
 * @throws Rethrows any error raised while preparing or generating, after
 *   writing a log entry for it.
 */
export async function generateChatResponse(
  newMessages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const settings = getSettings();

  try {
    const conversation: ChatMessage[] = [
      {
        role: "user",
        content: getSystemPrompt(getFormattedSearchResults(true)),
      },
      { role: "assistant", content: "Ok!" },
      ...newMessages,
    ];

    const tokenBudget = 1280;
    const lastMessages: ChatMessage[] = [];
    let usedTokens = 0;

    // Walk newest-to-oldest and prepend, so the kept messages end up in
    // chronological order.
    for (const message of [...conversation].reverse()) {
      const tokensWithMessage =
        usedTokens + gptTokenizer.encode(message.content).length;

      if (tokensWithMessage > tokenBudget) break;

      usedTokens = tokensWithMessage;
      lastMessages.unshift(message);
    }

    if (settings.inferenceType === "openai") {
      const { generateChatWithOpenAi } = await import(
        "./textGenerationWithOpenAi"
      );
      return await generateChatWithOpenAi(lastMessages, onUpdate);
    }

    if (settings.inferenceType === "internal") {
      const { generateChatWithInternalApi } = await import(
        "./textGenerationWithInternalApi"
      );
      return await generateChatWithInternalApi(lastMessages, onUpdate);
    }

    if (isWebGPUAvailable && settings.enableWebGpu) {
      const { generateChatWithWebLlm } = await import(
        "./textGenerationWithWebLlm"
      );
      return await generateChatWithWebLlm(lastMessages, onUpdate);
    }

    const { generateChatWithWllama } = await import(
      "./textGenerationWithWllama"
    );
    return await generateChatWithWllama(lastMessages, onUpdate);
  } catch (error) {
    const logMessage =
      error instanceof ChatGenerationError
        ? `Chat generation interrupted: ${error.message}`
        : `Error generating chat response: ${error}`;

    addLogEntry(logMessage);

    throw error;
  }
}
/**
 * Extracts English keywords from a text using the lazily imported
 * "keyword-extractor" package.
 *
 * @param text - Text to extract keywords from.
 * @param limit - Maximum number of keywords to return; when omitted, all
 *   extracted keywords are returned.
 * @returns The first `limit` extracted keywords.
 */
async function getKeywords(text: string, limit?: number) {
  const { default: keywordExtractor } = await import("keyword-extractor");
  const extractedKeywords = keywordExtractor.extract(text, {
    language: "english",
  });

  return extractedKeywords.slice(0, limit);
}
/**
 * Searches for the given query and publishes the outcome.
 *
 * Queries longer than 2000 characters are reduced to their first 20 keywords
 * before searching. When the first search yields no text results, a second
 * search is made with the first 10 keywords of the query.
 *
 * The search state is set to "running" at the start, then to "completed" or
 * "failed" depending on whether any text results were found. If searching or
 * keyword extraction throws, the state is set to "failed" instead of being
 * left at "running", the error is logged, and it is rethrown so the returned
 * promise still rejects.
 *
 * @param query - The search query.
 * @returns The search results that were published via `updateSearchResults`.
 * @throws Rethrows any error raised while searching.
 */
async function startSearch(query: string) {
  updateSearchState("running");

  const searchResultsLimit = 15;

  try {
    let searchResults = await search(
      query.length > 2000 ? (await getKeywords(query, 20)).join(" ") : query,
      searchResultsLimit,
    );

    // Retry with a condensed keyword query when nothing was found.
    if (searchResults.textResults.length === 0) {
      const queryKeywords = await getKeywords(query, 10);
      searchResults = await search(queryKeywords.join(" "), searchResultsLimit);
    }

    updateSearchState(
      searchResults.textResults.length === 0 ? "failed" : "completed",
    );

    updateSearchResults(searchResults);

    return searchResults;
  } catch (error) {
    // Without this, a rejected search would leave the state stuck at "running".
    updateSearchState("failed");
    addLogEntry(`Error searching: ${error}`);
    throw error;
  }
}
/**
 * Resolves once downloading AI models is allowed by the settings.
 *
 * Resolves immediately when `allowAiModelDownload` is already set. Otherwise
 * the text-generation state is switched to "awaitingModelDownloadAllowance"
 * and the promise resolves on the first settings change in which
 * `allowAiModelDownload` is truthy. The promise never rejects.
 *
 * NOTE(review): the settings listener registered here is never removed;
 * check whether `listenToSettingsChanges` provides a way to unsubscribe.
 */
function canDownloadModels(): Promise<void> {
  return new Promise((resolve) => {
    if (getSettings().allowAiModelDownload) {
      resolve();
      return;
    }

    updateTextGenerationState("awaitingModelDownloadAllowance");

    listenToSettingsChanges(({ allowAiModelDownload }) => {
      if (allowAiModelDownload) resolve();
    });
  });
}