// SumAI-Search: client/modules/textGeneration.ts
import { isWebGPUAvailable } from "./webGpu";
import {
updateSearchResults,
updateResponse,
getSearchResults,
getQuery,
updateSearchPromise,
getSearchPromise,
updateTextGenerationState,
updateSearchState,
updateModelLoadingProgress,
getTextGenerationState,
getSettings,
listenToSettingsChanges,
} from "./pubSub";
import { search } from "./search";
import { addLogEntry } from "./logEntries";
import { getSystemPrompt } from "./systemPrompt";
import prettyMilliseconds from "pretty-ms";
import { getOpenAiClient } from "./openai";
import { getSearchTokenHash } from "./searchTokenHash";
import type {
ChatCompletionCreateParamsStreaming,
ChatCompletionMessageParam,
} from "openai/resources/chat/completions.mjs";
import gptTokenizer from "gpt-tokenizer";
import type { ChatOptions } from "@mlc-ai/web-llm";
import { defaultSettings } from "./settings";
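/**
 * Starts a search for the current query and, if AI responses are enabled,
 * generates an answer with the configured inference backend: an
 * OpenAI-compatible API, the internal inference endpoint, WebLLM (WebGPU),
 * or Wllama as the CPU fallback.
 */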
export async function searchAndRespond() {
if (getQuery() === "") return;
document.title = getQuery();
updateResponse("");
updateSearchResults({ textResults: [], imageResults: [] });
updateSearchPromise(startSearch(getQuery()));
if (!getSettings().enableAiResponse) return;
const responseGenerationStartTime = new Date().getTime();
updateTextGenerationState("loadingModel");
try {
const settings = getSettings();
if (settings.inferenceType === "openai") {
await generateTextWithOpenAI();
} else if (settings.inferenceType === "internal") {
await generateTextWithInternalApi();
} else {
await canDownloadModels();
updateTextGenerationState("loadingModel");
try {
        if (!isWebGPUAvailable) throw new Error("WebGPU is not available.");
        if (!settings.enableWebGpu) throw new Error("WebGPU is disabled.");
await generateTextWithWebLlm();
} catch (error) {
addLogEntry(`Skipping text generation with WebLLM: ${error}`);
        addLogEntry("Starting text generation with Wllama");
await generateTextWithWllama();
}
}
if (getTextGenerationState() !== "interrupted") {
updateTextGenerationState("completed");
}
} catch (error) {
addLogEntry(`Error generating text: ${error}`);
updateTextGenerationState("failed");
}
addLogEntry(
`Response generation took ${prettyMilliseconds(
new Date().getTime() - responseGenerationStartTime,
{ verbose: true },
)}`,
);
}
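/** Streams the answer from an OpenAI-compatible API using the configured base URL, API key, and model, updating the response as chunks arrive. */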
async function generateTextWithOpenAI() {
const settings = getSettings();
const openai = getOpenAiClient({
baseURL: settings.openAiApiBaseUrl,
apiKey: settings.openAiApiKey,
});
await canStartResponding();
updateTextGenerationState("preparingToGenerate");
const completion = await openai.chat.completions.create({
...getDefaultChatCompletionCreateParamsStreaming(),
model: settings.openAiApiModel,
messages: [
{
role: "user",
content: getSystemPrompt(getFormattedSearchResults(true)),
},
{ role: "assistant", content: "Ok!" },
{ role: "user", content: getQuery() },
],
});
let streamedMessage = "";
for await (const chunk of completion) {
const deltaContent = chunk.choices[0].delta.content;
if (deltaContent) streamedMessage += deltaContent;
if (getTextGenerationState() === "interrupted") {
completion.controller.abort();
} else if (getTextGenerationState() !== "generating") {
updateTextGenerationState("generating");
}
updateResponseRateLimited(streamedMessage);
}
updateResponse(streamedMessage);
}
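/** Streams the answer from the server's /inference endpoint, authorizing with the search token hash and parsing the SSE stream manually. */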
async function generateTextWithInternalApi() {
await canStartResponding();
updateTextGenerationState("preparingToGenerate");
const inferenceUrl = new URL("/inference", self.location.origin);
const tokenPrefix = "Bearer ";
const token = await getSearchTokenHash();
const response = await fetch(inferenceUrl.toString(), {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `${tokenPrefix}${token}`,
},
body: JSON.stringify({
...getDefaultChatCompletionCreateParamsStreaming(),
messages: [
{
role: "user",
content: getSystemPrompt(getFormattedSearchResults(true)),
},
{ role: "assistant", content: "Ok!" },
{ role: "user", content: getQuery() },
],
} as ChatCompletionCreateParamsStreaming),
});
if (!response.ok || !response.body) {
throw new Error(`HTTP error! status: ${response.status}`);
}
  const reader = response.body.getReader();
  const decoder = new TextDecoder("utf-8");
  let streamedMessage = "";
  let buffer = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    // Decode incrementally and keep any trailing partial line in the buffer,
    // so JSON.parse only ever sees complete "data:" lines.
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop() ?? "";
    const parsedLines = lines
      .map((line) => line.replace(/^data: /, "").trim())
      .filter((line) => line !== "" && line !== "[DONE]")
      .map((line) => JSON.parse(line));
for (const parsedLine of parsedLines) {
const deltaContent = parsedLine.choices[0].delta.content;
if (deltaContent) streamedMessage += deltaContent;
if (getTextGenerationState() === "interrupted") {
reader.cancel();
} else if (getTextGenerationState() !== "generating") {
updateTextGenerationState("generating");
}
updateResponseRateLimited(streamedMessage);
}
}
updateResponse(streamedMessage);
}
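/**
 * Generates the response locally with WebLLM (WebGPU). The engine runs in a
 * Web Worker when workers are available, and model-loading progress is
 * reported only when the selected model is not yet cached.
 */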
async function generateTextWithWebLlm() {
const { CreateWebWorkerMLCEngine, CreateMLCEngine, hasModelInCache } =
await import("@mlc-ai/web-llm");
type InitProgressCallback = import("@mlc-ai/web-llm").InitProgressCallback;
type MLCEngineConfig = import("@mlc-ai/web-llm").MLCEngineConfig;
type ChatOptions = import("@mlc-ai/web-llm").ChatOptions;
const selectedModelId = getSettings().webLlmModelId;
addLogEntry(`Selected WebLLM model: ${selectedModelId}`);
const isModelCached = await hasModelInCache(selectedModelId);
let initProgressCallback: InitProgressCallback | undefined;
if (isModelCached) {
updateTextGenerationState("preparingToGenerate");
} else {
initProgressCallback = (report) => {
updateModelLoadingProgress(Math.round(report.progress * 100));
};
}
const engineConfig: MLCEngineConfig = {
initProgressCallback,
logLevel: "SILENT",
};
const chatOptions: ChatOptions = {
repetition_penalty: defaultSettings.inferenceRepeatPenalty,
};
  // Guard with typeof so environments without Worker fall back to the main-thread engine.
  const engine = typeof Worker !== "undefined"
? await CreateWebWorkerMLCEngine(
new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
type: "module",
}),
selectedModelId,
engineConfig,
chatOptions,
)
: await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
if (getSettings().enableAiResponse) {
await canStartResponding();
updateTextGenerationState("preparingToGenerate");
const completion = await engine.chat.completions.create({
...getDefaultChatCompletionCreateParamsStreaming(),
messages: [
{
role: "user",
content: getSystemPrompt(getFormattedSearchResults(true)),
},
{ role: "assistant", content: "Ok!" },
{ role: "user", content: getQuery() },
],
});
let streamedMessage = "";
for await (const chunk of completion) {
const deltaContent = chunk.choices[0].delta.content;
if (deltaContent) streamedMessage += deltaContent;
if (getTextGenerationState() === "interrupted") {
await engine.interruptGenerate();
} else if (getTextGenerationState() !== "generating") {
updateTextGenerationState("generating");
}
updateResponseRateLimited(streamedMessage);
}
updateResponse(streamedMessage);
}
addLogEntry(
`WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`,
);
engine.unload();
}
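/** Generates the response locally with Wllama (CPU/WebAssembly), reporting model loading progress and honoring the model's stop tokens and stop strings. */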
async function generateTextWithWllama() {
const { initializeWllama, wllamaModels } = await import("./wllama");
let loadingPercentage = 0;
const model = wllamaModels[getSettings().wllamaModelId];
const wllama = await initializeWllama(model.url, {
wllama: {
suppressNativeLog: true,
},
model: {
n_threads: getSettings().cpuThreads,
n_ctx: model.contextSize,
cache_type_k: model.cacheType,
embeddings: false,
allowOffline: true,
progressCallback: ({ loaded, total }) => {
const progressPercentage = Math.round((loaded / total) * 100);
if (loadingPercentage !== progressPercentage) {
loadingPercentage = progressPercentage;
updateModelLoadingProgress(progressPercentage);
}
},
},
});
if (getSettings().enableAiResponse) {
await canStartResponding();
updateTextGenerationState("preparingToGenerate");
const prompt = await model.buildPrompt(
wllama,
getQuery(),
getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt),
);
let streamedMessage = "";
await wllama.createCompletion(prompt, {
stopTokens: model.stopTokens,
sampling: model.sampling,
onNewToken: (_token, _piece, currentText, { abortSignal }) => {
if (getTextGenerationState() === "interrupted") {
abortSignal();
} else if (getTextGenerationState() !== "generating") {
updateTextGenerationState("generating");
}
if (model.stopStrings) {
for (const stopString of model.stopStrings) {
if (
currentText.slice(-(stopString.length * 2)).includes(stopString)
) {
abortSignal();
currentText = currentText.slice(0, -stopString.length);
break;
}
}
}
streamedMessage = currentText;
updateResponseRateLimited(streamedMessage);
},
});
updateResponse(streamedMessage);
}
await wllama.exit();
}
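/** Formats the top search results as a Markdown list, optionally with linked URLs, limited by the searchResultsToConsider setting; returns "None." when there are no results. */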
function getFormattedSearchResults(shouldIncludeUrl: boolean) {
const searchResults = getSearchResults().textResults.slice(
0,
getSettings().searchResultsToConsider,
);
if (searchResults.length === 0) return "None.";
if (shouldIncludeUrl) {
return searchResults
.map(
([title, snippet, url], index) =>
`${index + 1}. [${title}](${url}) | ${snippet}`,
)
.join("\n");
}
return searchResults
.map(([title, snippet]) => `- ${title} | ${snippet}`)
.join("\n");
}
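/** Extracts up to `limit` English keywords from the given text using keyword-extractor. */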
async function getKeywords(text: string, limit?: number) {
return (await import("keyword-extractor")).default
.extract(text, { language: "english" })
.slice(0, limit);
}
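/** Runs the search, condensing very long queries (over 2000 characters) to keywords and retrying with extracted keywords when the first attempt returns no text results. */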
async function startSearch(query: string) {
updateSearchState("running");
let searchResults = await search(
query.length > 2000 ? (await getKeywords(query, 20)).join(" ") : query,
30,
);
if (searchResults.textResults.length === 0) {
const queryKeywords = await getKeywords(query, 10);
searchResults = await search(queryKeywords.join(" "), 30);
}
updateSearchState(
searchResults.textResults.length === 0 ? "failed" : "completed",
);
updateSearchResults(searchResults);
return searchResults;
}
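/** When search results are part of the prompt (searchResultsToConsider > 0), waits for the pending search before generation starts. */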
async function canStartResponding() {
if (getSettings().searchResultsToConsider > 0) {
updateTextGenerationState("awaitingSearchResults");
await getSearchPromise();
}
}
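/** Throttles UI response updates to at most ~12 per second (see updateInterval below) so the UI is not updated on every streamed token. */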
function updateResponseRateLimited(text: string) {
const currentTime = Date.now();
if (
currentTime - updateResponseRateLimited.lastUpdateTime >=
updateResponseRateLimited.updateInterval
) {
updateResponse(text);
updateResponseRateLimited.lastUpdateTime = currentTime;
}
}
updateResponseRateLimited.lastUpdateTime = 0;
updateResponseRateLimited.updateInterval = 1000 / 12;
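/** Signals that a streaming chat completion was interrupted by the user, so callers can distinguish it from real failures. */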
class ChatGenerationError extends Error {
constructor(message: string) {
super(message);
this.name = "ChatGenerationError";
}
}
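/** Streams a follow-up chat completion from an OpenAI-compatible API, reporting partial text via onUpdate. */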
async function generateChatWithOpenAI(
messages: ChatMessage[],
onUpdate: (partialResponse: string) => void,
) {
const settings = getSettings();
const openai = getOpenAiClient({
baseURL: settings.openAiApiBaseUrl,
apiKey: settings.openAiApiKey,
});
const completion = await openai.chat.completions.create({
...getDefaultChatCompletionCreateParamsStreaming(),
model: settings.openAiApiModel,
messages: messages as ChatCompletionMessageParam[],
});
let streamedMessage = "";
for await (const chunk of completion) {
const deltaContent = chunk.choices[0].delta.content;
if (deltaContent) {
streamedMessage += deltaContent;
onUpdate(streamedMessage);
}
if (getTextGenerationState() === "interrupted") {
completion.controller.abort();
throw new ChatGenerationError("Chat generation interrupted");
}
}
return streamedMessage;
}
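/** Streams a follow-up chat completion from the internal /inference endpoint, reporting partial text via onUpdate. */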
async function generateChatWithInternalApi(
messages: ChatMessage[],
onUpdate: (partialResponse: string) => void,
) {
const inferenceUrl = new URL("/inference", self.location.origin);
const tokenPrefix = "Bearer ";
const token = await getSearchTokenHash();
const response = await fetch(inferenceUrl.toString(), {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `${tokenPrefix}${token}`,
},
body: JSON.stringify({
...getDefaultChatCompletionCreateParamsStreaming(),
messages,
} as ChatCompletionCreateParamsStreaming),
});
if (!response.ok || !response.body) {
throw new Error(`HTTP error! status: ${response.status}`);
}
  const reader = response.body.getReader();
  const decoder = new TextDecoder("utf-8");
  let streamedMessage = "";
  let buffer = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    // Decode incrementally and keep any trailing partial line in the buffer,
    // so JSON.parse only ever sees complete "data:" lines.
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop() ?? "";
    const parsedLines = lines
      .map((line) => line.replace(/^data: /, "").trim())
      .filter((line) => line !== "" && line !== "[DONE]")
      .map((line) => JSON.parse(line));
for (const parsedLine of parsedLines) {
const deltaContent = parsedLine.choices[0].delta.content;
if (deltaContent) {
streamedMessage += deltaContent;
onUpdate(streamedMessage);
}
if (getTextGenerationState() === "interrupted") {
reader.cancel();
throw new ChatGenerationError("Chat generation interrupted");
}
}
}
return streamedMessage;
}
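/** Streams a follow-up chat completion with WebLLM, preferring a Web Worker engine when workers are available. */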
async function generateChatWithWebLlm(
messages: ChatMessage[],
onUpdate: (partialResponse: string) => void,
) {
const { CreateWebWorkerMLCEngine, CreateMLCEngine } = await import(
"@mlc-ai/web-llm"
);
type MLCEngineConfig = import("@mlc-ai/web-llm").MLCEngineConfig;
type ChatCompletionMessageParam =
import("@mlc-ai/web-llm").ChatCompletionMessageParam;
const selectedModelId = getSettings().webLlmModelId;
addLogEntry(`Selected WebLLM model for chat: ${selectedModelId}`);
const engineConfig: MLCEngineConfig = {
logLevel: "SILENT",
};
const chatOptions: ChatOptions = {
repetition_penalty: defaultSettings.inferenceRepeatPenalty,
};
  // Guard with typeof so environments without Worker fall back to the main-thread engine.
  const engine = typeof Worker !== "undefined"
? await CreateWebWorkerMLCEngine(
new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
type: "module",
}),
selectedModelId,
engineConfig,
chatOptions,
)
: await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
const completion = await engine.chat.completions.create({
...getDefaultChatCompletionCreateParamsStreaming(),
messages: messages as ChatCompletionMessageParam[],
});
let streamedMessage = "";
for await (const chunk of completion) {
const deltaContent = chunk.choices[0].delta.content;
if (deltaContent) {
streamedMessage += deltaContent;
onUpdate(streamedMessage);
}
if (getTextGenerationState() === "interrupted") {
await engine.interruptGenerate();
throw new ChatGenerationError("Chat generation interrupted");
}
}
addLogEntry(
`WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`,
);
engine.unload();
return streamedMessage;
}
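/** Streams a follow-up chat completion with Wllama, building the prompt from the latest message and the formatted search results. */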
async function generateChatWithWllama(
messages: ChatMessage[],
onUpdate: (partialResponse: string) => void,
) {
const { initializeWllama, wllamaModels } = await import("./wllama");
const model = wllamaModels[getSettings().wllamaModelId];
const wllama = await initializeWllama(model.url, {
wllama: {
suppressNativeLog: true,
},
model: {
n_threads: getSettings().cpuThreads,
n_ctx: model.contextSize,
cache_type_k: model.cacheType,
embeddings: false,
allowOffline: true,
},
});
const prompt = await model.buildPrompt(
wllama,
messages[messages.length - 1].content,
getFormattedSearchResults(model.shouldIncludeUrlsOnPrompt),
);
let streamedMessage = "";
await wllama.createCompletion(prompt, {
stopTokens: model.stopTokens,
sampling: model.sampling,
onNewToken: (_token, _piece, currentText, { abortSignal }) => {
if (getTextGenerationState() === "interrupted") {
abortSignal();
throw new ChatGenerationError("Chat generation interrupted");
}
if (model.stopStrings) {
for (const stopString of model.stopStrings) {
if (
currentText.slice(-(stopString.length * 2)).includes(stopString)
) {
abortSignal();
currentText = currentText.slice(0, -stopString.length);
break;
}
}
}
streamedMessage = currentText;
onUpdate(streamedMessage);
},
});
await wllama.exit();
return streamedMessage;
}
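/**
 * Generates a chat response for the follow-up conversation. The system prompt,
 * the "Ok!" acknowledgement, and the most recent messages are kept within a
 * 1280-token budget (counted with gpt-tokenizer) before being sent to the
 * selected inference backend.
 */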
export async function generateChatResponse(
newMessages: ChatMessage[],
onUpdate: (partialResponse: string) => void,
) {
const settings = getSettings();
let response = "";
try {
const allMessages = [
{
role: "user",
content: getSystemPrompt(getFormattedSearchResults(true)),
},
{ role: "assistant", content: "Ok!" },
...newMessages,
];
const lastMessagesReversed: ChatMessage[] = [];
let totalTokens = 0;
for (const message of allMessages.reverse()) {
const newTotalTokens =
totalTokens + gptTokenizer.encode(message.content).length;
if (newTotalTokens > 1280) break;
totalTokens = newTotalTokens;
lastMessagesReversed.push(message);
}
const lastMessages = lastMessagesReversed.reverse();
if (settings.inferenceType === "openai") {
response = await generateChatWithOpenAI(lastMessages, onUpdate);
} else if (settings.inferenceType === "internal") {
response = await generateChatWithInternalApi(lastMessages, onUpdate);
} else {
if (isWebGPUAvailable && settings.enableWebGpu) {
response = await generateChatWithWebLlm(lastMessages, onUpdate);
} else {
response = await generateChatWithWllama(lastMessages, onUpdate);
}
}
} catch (error) {
if (error instanceof ChatGenerationError) {
addLogEntry(`Chat generation interrupted: ${error.message}`);
} else {
addLogEntry(`Error generating chat response: ${error}`);
}
throw error;
}
return response;
}
export interface ChatMessage {
role: "user" | "assistant" | string;
content: string;
}
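/** Resolves immediately if model downloads are allowed; otherwise waits until the user grants permission in the settings. */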
function canDownloadModels(): Promise<void> {
return new Promise((resolve) => {
if (getSettings().allowAiModelDownload) {
resolve();
} else {
updateTextGenerationState("awaitingModelDownloadAllowance");
listenToSettingsChanges((settings) => {
if (settings.allowAiModelDownload) {
resolve();
}
});
}
});
}
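/** Shared streaming parameters (max tokens, temperature, top_p, penalties) applied to every chat completion request. */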
function getDefaultChatCompletionCreateParamsStreaming() {
return {
stream: true,
max_tokens: 2048,
temperature: defaultSettings.inferenceTemperature,
top_p: defaultSettings.inferenceTopP,
frequency_penalty: defaultSettings.inferenceFrequencyPenalty,
presence_penalty: defaultSettings.inferencePresencePenalty,
} as const;
}