// Chat session store: sessions, messages, auto-titling, memory summarization,
// and persisted-state migrations.
import { trimTopic } from "../utils"; | |
import Locale, { getLang } from "../locales"; | |
import { showToast } from "../components/ui-lib"; | |
import { ModelConfig, ModelType, useAppConfig } from "./config"; | |
import { createEmptyMask, Mask } from "./mask"; | |
import { | |
DEFAULT_INPUT_TEMPLATE, | |
DEFAULT_SYSTEM_TEMPLATE, | |
KnowledgeCutOffDate, | |
StoreKey, | |
SUMMARIZE_MODEL, | |
} from "../constant"; | |
import { api, RequestMessage } from "../client/api"; | |
import { ChatControllerPool } from "../client/controller"; | |
import { prettyObject } from "../utils/format"; | |
import { estimateTokenLength } from "../utils/token"; | |
import { nanoid } from "nanoid"; | |
import { Plugin, usePluginStore } from "../store/plugin"; | |
// A record of one tool/plugin invocation made while generating an
// assistant reply (rendered alongside the streamed message).
export interface ChatToolMessage {
  toolName: string;
  // Raw input handed to the tool, when the agent reported one.
  toolInput?: string;
}
import { createPersistStore } from "../utils/store"; | |
// A message as stored in a session: the request payload (role/content)
// plus client-side bookkeeping.
export type ChatMessage = RequestMessage & {
  date: string; // locale-formatted creation timestamp
  toolMessages?: ChatToolMessage[]; // tool calls recorded during generation
  streaming?: boolean; // true while tokens are still arriving
  isError?: boolean; // true when the request failed; excluded from context
  id: string; // nanoid, unique per message
  model?: ModelType; // model that produced an assistant message
};
export function createMessage(override: Partial<ChatMessage>): ChatMessage { | |
return { | |
id: nanoid(), | |
date: new Date().toLocaleString(), | |
toolMessages: new Array<ChatToolMessage>(), | |
role: "user", | |
content: "", | |
...override, | |
}; | |
} | |
// Rough per-session usage statistics.
export interface ChatStat {
  tokenCount: number; // not yet maintained (see updateStat TODO)
  wordCount: number; // not yet maintained (see updateStat TODO)
  charCount: number; // characters accumulated from completed messages
}
// One conversation: its messages plus summarization/selection bookkeeping.
export interface ChatSession {
  id: string; // nanoid
  topic: string; // title; auto-generated once the chat is long enough
  memoryPrompt: string; // summarized long-term memory of older messages
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number; // epoch millis of the last message activity
  lastSummarizeIndex: number; // messages before this index are summarized
  clearContextIndex?: number; // messages before this index are excluded from context
  mask: Mask; // persona + per-session model config
}
// Placeholder topic for sessions that have not been auto-titled yet.
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
// Canned greeting shown as the assistant's first message.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
function createEmptySession(): ChatSession { | |
return { | |
id: nanoid(), | |
topic: DEFAULT_TOPIC, | |
memoryPrompt: "", | |
messages: [], | |
stat: { | |
tokenCount: 0, | |
wordCount: 0, | |
charCount: 0, | |
}, | |
lastUpdate: Date.now(), | |
lastSummarizeIndex: 0, | |
mask: createEmptyMask(), | |
}; | |
} | |
function getSummarizeModel(currentModel: string) { | |
// if it is using gpt-* models, force to use 3.5 to summarize | |
return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel; | |
} | |
function countMessages(msgs: ChatMessage[]) { | |
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0); | |
} | |
function fillTemplateWith(input: string, modelConfig: ModelConfig) { | |
let cutoff = | |
KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default; | |
const vars = { | |
cutoff, | |
model: modelConfig.model, | |
time: new Date().toLocaleString(), | |
lang: getLang(), | |
input: input, | |
}; | |
let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE; | |
// must contains {{input}} | |
const inputVar = "{{input}}"; | |
if (!output.includes(inputVar)) { | |
output += "\n" + inputVar; | |
} | |
Object.entries(vars).forEach(([name, value]) => { | |
output = output.replaceAll(`{{${name}}}`, value); | |
}); | |
return output; | |
} | |
// Initial persisted state: a single empty session, selected.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
};
export const useChatStore = createPersistStore(
  DEFAULT_CHAT_STATE,
  (set, _get) => {
    // Merge the raw store state with the method table so methods can call
    // each other through get().someMethod().
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }
const methods = { | |
clearSessions() { | |
set(() => ({ | |
sessions: [createEmptySession()], | |
currentSessionIndex: 0, | |
})); | |
}, | |
selectSession(index: number) { | |
set({ | |
currentSessionIndex: index, | |
}); | |
}, | |
moveSession(from: number, to: number) { | |
set((state) => { | |
const { sessions, currentSessionIndex: oldIndex } = state; | |
// move the session | |
const newSessions = [...sessions]; | |
const session = newSessions[from]; | |
newSessions.splice(from, 1); | |
newSessions.splice(to, 0, session); | |
// modify current session id | |
let newIndex = oldIndex === from ? to : oldIndex; | |
if (oldIndex > from && oldIndex <= to) { | |
newIndex -= 1; | |
} else if (oldIndex < from && oldIndex >= to) { | |
newIndex += 1; | |
} | |
return { | |
currentSessionIndex: newIndex, | |
sessions: newSessions, | |
}; | |
}); | |
}, | |
newSession(mask?: Mask) { | |
const session = createEmptySession(); | |
if (mask) { | |
const config = useAppConfig.getState(); | |
const globalModelConfig = config.modelConfig; | |
session.mask = { | |
...mask, | |
modelConfig: { | |
...globalModelConfig, | |
...mask.modelConfig, | |
}, | |
}; | |
session.topic = mask.name; | |
} | |
set((state) => ({ | |
currentSessionIndex: 0, | |
sessions: [session].concat(state.sessions), | |
})); | |
}, | |
nextSession(delta: number) { | |
const n = get().sessions.length; | |
const limit = (x: number) => (x + n) % n; | |
const i = get().currentSessionIndex; | |
get().selectSession(limit(i + delta)); | |
}, | |
deleteSession(index: number) { | |
const deletingLastSession = get().sessions.length === 1; | |
const deletedSession = get().sessions.at(index); | |
if (!deletedSession) return; | |
const sessions = get().sessions.slice(); | |
sessions.splice(index, 1); | |
const currentIndex = get().currentSessionIndex; | |
let nextIndex = Math.min( | |
currentIndex - Number(index < currentIndex), | |
sessions.length - 1, | |
); | |
if (deletingLastSession) { | |
nextIndex = 0; | |
sessions.push(createEmptySession()); | |
} | |
// for undo delete action | |
const restoreState = { | |
currentSessionIndex: get().currentSessionIndex, | |
sessions: get().sessions.slice(), | |
}; | |
set(() => ({ | |
currentSessionIndex: nextIndex, | |
sessions, | |
})); | |
showToast( | |
Locale.Home.DeleteToast, | |
{ | |
text: Locale.Home.Revert, | |
onClick() { | |
set(() => restoreState); | |
}, | |
}, | |
5000, | |
); | |
}, | |
currentSession() { | |
let index = get().currentSessionIndex; | |
const sessions = get().sessions; | |
if (index < 0 || index >= sessions.length) { | |
index = Math.min(sessions.length - 1, Math.max(0, index)); | |
set(() => ({ currentSessionIndex: index })); | |
} | |
const session = sessions[index]; | |
return session; | |
}, | |
      // Bookkeeping after a message completes: bump the session's update
      // time (re-setting messages to trigger subscribers), update stats,
      // and run the auto-title / memory-compression pass.
      onNewMessage(message: ChatMessage) {
        get().updateCurrentSession((session) => {
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message);
        get().summarizeSession();
      },
async onUserInput(content: string) { | |
const session = get().currentSession(); | |
const modelConfig = session.mask.modelConfig; | |
const userContent = fillTemplateWith(content, modelConfig); | |
console.log("[User Input] after template: ", userContent); | |
const userMessage: ChatMessage = createMessage({ | |
role: "user", | |
content: userContent, | |
}); | |
const botMessage: ChatMessage = createMessage({ | |
role: "assistant", | |
streaming: true, | |
model: modelConfig.model, | |
toolMessages: [], | |
}); | |
// get recent messages | |
const recentMessages = get().getMessagesWithMemory(); | |
const sendMessages = recentMessages.concat(userMessage); | |
const messageIndex = get().currentSession().messages.length + 1; | |
const config = useAppConfig.getState(); | |
const pluginConfig = useAppConfig.getState().pluginConfig; | |
const pluginStore = usePluginStore.getState(); | |
const allPlugins = pluginStore | |
.getAll() | |
.filter( | |
(m) => | |
(!getLang() || | |
m.lang === (getLang() == "cn" ? getLang() : "en")) && | |
m.enable, | |
); | |
// save user's and bot's message | |
get().updateCurrentSession((session) => { | |
const savedUserMessage = { | |
...userMessage, | |
content, | |
}; | |
session.messages.push(savedUserMessage); | |
session.messages.push(botMessage); | |
}); | |
if ( | |
config.pluginConfig.enable && | |
session.mask.usePlugins && | |
allPlugins.length > 0 | |
) { | |
console.log("[ToolAgent] start"); | |
const pluginToolNames = allPlugins.map((m) => m.toolName); | |
api.llm.toolAgentChat({ | |
messages: sendMessages, | |
config: { ...modelConfig, stream: true }, | |
agentConfig: { ...pluginConfig, useTools: pluginToolNames }, | |
onUpdate(message) { | |
botMessage.streaming = true; | |
if (message) { | |
botMessage.content = message; | |
} | |
get().updateCurrentSession((session) => { | |
session.messages = session.messages.concat(); | |
}); | |
}, | |
onToolUpdate(toolName, toolInput) { | |
botMessage.streaming = true; | |
if (toolName && toolInput) { | |
botMessage.toolMessages!.push({ | |
toolName, | |
toolInput, | |
}); | |
} | |
get().updateCurrentSession((session) => { | |
session.messages = session.messages.concat(); | |
}); | |
}, | |
onFinish(message) { | |
botMessage.streaming = false; | |
if (message) { | |
botMessage.content = message; | |
get().onNewMessage(botMessage); | |
} | |
ChatControllerPool.remove(session.id, botMessage.id); | |
}, | |
onError(error) { | |
const isAborted = error.message.includes("aborted"); | |
botMessage.content += | |
"\n\n" + | |
prettyObject({ | |
error: true, | |
message: error.message, | |
}); | |
botMessage.streaming = false; | |
userMessage.isError = !isAborted; | |
botMessage.isError = !isAborted; | |
get().updateCurrentSession((session) => { | |
session.messages = session.messages.concat(); | |
}); | |
ChatControllerPool.remove( | |
session.id, | |
botMessage.id ?? messageIndex, | |
); | |
console.error("[Chat] failed ", error); | |
}, | |
onController(controller) { | |
// collect controller for stop/retry | |
ChatControllerPool.addController( | |
session.id, | |
botMessage.id ?? messageIndex, | |
controller, | |
); | |
}, | |
}); | |
} else { | |
// make request | |
api.llm.chat({ | |
messages: sendMessages, | |
config: { ...modelConfig, stream: true }, | |
onUpdate(message) { | |
botMessage.streaming = true; | |
if (message) { | |
botMessage.content = message; | |
} | |
get().updateCurrentSession((session) => { | |
session.messages = session.messages.concat(); | |
}); | |
}, | |
onFinish(message) { | |
botMessage.streaming = false; | |
if (message) { | |
botMessage.content = message; | |
get().onNewMessage(botMessage); | |
} | |
ChatControllerPool.remove(session.id, botMessage.id); | |
}, | |
onError(error) { | |
const isAborted = error.message.includes("aborted"); | |
botMessage.content += | |
"\n\n" + | |
prettyObject({ | |
error: true, | |
message: error.message, | |
}); | |
botMessage.streaming = false; | |
userMessage.isError = !isAborted; | |
botMessage.isError = !isAborted; | |
get().updateCurrentSession((session) => { | |
session.messages = session.messages.concat(); | |
}); | |
ChatControllerPool.remove( | |
session.id, | |
botMessage.id ?? messageIndex, | |
); | |
console.error("[Chat] failed ", error); | |
}, | |
onController(controller) { | |
// collect controller for stop/retry | |
ChatControllerPool.addController( | |
session.id, | |
botMessage.id ?? messageIndex, | |
controller, | |
); | |
}, | |
}); | |
} | |
}, | |
getMemoryPrompt() { | |
const session = get().currentSession(); | |
return { | |
role: "system", | |
content: | |
session.memoryPrompt.length > 0 | |
? Locale.Store.Prompt.History(session.memoryPrompt) | |
: "", | |
date: "", | |
} as ChatMessage; | |
}, | |
      /**
       * Assemble the context sent to the model, concatenated from: the
       * injected system prompt, the long-term memory summary, the mask's
       * in-context prompts, and the most recent messages that fit the
       * token budget.
       */
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
        const systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];
        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }

        // long term memory — only sent when a summary exists and it covers
        // messages newer than the user's last "clear context" point
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts = shouldSendLongTermMemory
          ? [get().getMemoryPrompt()]
          : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        // NOTE(review): max_tokens is normally the reply-token limit; it is
        // reused here as the history token budget — confirm this is intended.
        const maxTokenThreshold = modelConfig.max_tokens;

        // get recent messages as much as possible (newest first)
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          // skip holes and failed messages (errors are excluded from context)
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(msg.content);
          reversedRecentMessages.push(msg);
        }

        // concat all messages
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },
updateMessage( | |
sessionIndex: number, | |
messageIndex: number, | |
updater: (message?: ChatMessage) => void, | |
) { | |
const sessions = get().sessions; | |
const session = sessions.at(sessionIndex); | |
const messages = session?.messages; | |
updater(messages?.at(messageIndex)); | |
set(() => ({ sessions })); | |
}, | |
      // Wipe the current session's messages and its long-term memory.
      resetSession() {
        get().updateCurrentSession((session) => {
          session.messages = [];
          session.memoryPrompt = "";
        });
      },
summarizeSession() { | |
const config = useAppConfig.getState(); | |
const session = get().currentSession(); | |
// remove error messages if any | |
const messages = session.messages; | |
// should summarize topic after chating more than 50 words | |
const SUMMARIZE_MIN_LEN = 50; | |
if ( | |
config.enableAutoGenerateTitle && | |
session.topic === DEFAULT_TOPIC && | |
countMessages(messages) >= SUMMARIZE_MIN_LEN | |
) { | |
const topicMessages = messages.concat( | |
createMessage({ | |
role: "user", | |
content: Locale.Store.Prompt.Topic, | |
}), | |
); | |
api.llm.chat({ | |
messages: topicMessages, | |
config: { | |
model: getSummarizeModel(session.mask.modelConfig.model), | |
}, | |
onFinish(message) { | |
get().updateCurrentSession( | |
(session) => | |
(session.topic = | |
message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC), | |
); | |
}, | |
}); | |
} | |
const modelConfig = session.mask.modelConfig; | |
const summarizeIndex = Math.max( | |
session.lastSummarizeIndex, | |
session.clearContextIndex ?? 0, | |
); | |
let toBeSummarizedMsgs = messages | |
.filter((msg) => !msg.isError) | |
.slice(summarizeIndex); | |
const historyMsgLength = countMessages(toBeSummarizedMsgs); | |
if (historyMsgLength > modelConfig?.max_tokens ?? 4000) { | |
const n = toBeSummarizedMsgs.length; | |
toBeSummarizedMsgs = toBeSummarizedMsgs.slice( | |
Math.max(0, n - modelConfig.historyMessageCount), | |
); | |
} | |
// add memory prompt | |
toBeSummarizedMsgs.unshift(get().getMemoryPrompt()); | |
const lastSummarizeIndex = session.messages.length; | |
console.log( | |
"[Chat History] ", | |
toBeSummarizedMsgs, | |
historyMsgLength, | |
modelConfig.compressMessageLengthThreshold, | |
); | |
if ( | |
historyMsgLength > modelConfig.compressMessageLengthThreshold && | |
modelConfig.sendMemory | |
) { | |
api.llm.chat({ | |
messages: toBeSummarizedMsgs.concat( | |
createMessage({ | |
role: "system", | |
content: Locale.Store.Prompt.Summarize, | |
date: "", | |
}), | |
), | |
config: { | |
...modelConfig, | |
stream: true, | |
model: getSummarizeModel(session.mask.modelConfig.model), | |
}, | |
onUpdate(message) { | |
session.memoryPrompt = message; | |
}, | |
onFinish(message) { | |
console.log("[Memory] ", message); | |
session.lastSummarizeIndex = lastSummarizeIndex; | |
}, | |
onError(err) { | |
console.error("[Summarize] ", err); | |
}, | |
}); | |
} | |
}, | |
      // Accumulate per-session statistics for a completed message.
      // Only charCount is maintained today.
      updateStat(message: ChatMessage) {
        get().updateCurrentSession((session) => {
          session.stat.charCount += message.content.length;
          // TODO: should update chat count and word count
        });
      },
updateCurrentSession(updater: (session: ChatSession) => void) { | |
const sessions = get().sessions; | |
const index = get().currentSessionIndex; | |
updater(sessions[index]); | |
set(() => ({ sessions })); | |
}, | |
      // Nuke ALL persisted app data (every store, not just chat) and
      // reload the page so the UI restarts from defaults.
      clearAllData() {
        localStorage.clear();
        location.reload();
      },
}; | |
return methods; | |
}, | |
  {
    name: StoreKey.Chat,
    version: 3.1,

    // Upgrade persisted state from older schema versions, applied in order.
    migrate(persistedState, version) {
      const state = persistedState as any;
      // deep clone so migrations never mutate the raw persisted object
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        // v2: sessions gained masks with per-session model config;
        // rebuild each old session on top of a fresh one.
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // Enable `enableInjectSystemPrompts` attribute for old sessions.
      // Resolve issue of old sessions not automatically enabling.
      if (version < 3.1) {
        newState.sessions.forEach((s) => {
          if (
            // Exclude those already set by user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // Because users may have changed this configuration,
            // the user's current configuration is used instead of the default
            const config = useAppConfig.getState();
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      return newState as any;
    },
  },
);