|
import React, { createContext, useContext, useEffect, useState } from 'react'; |
|
import { |
|
APIMessage, |
|
CanvasData, |
|
Conversation, |
|
Message, |
|
PendingMessage, |
|
ViewingChat, |
|
} from './types'; |
|
import StorageUtils from './storage'; |
|
import { |
|
filterThoughtFromMsgs, |
|
normalizeMsgsForAPI, |
|
getSSEStreamAsync, |
|
} from './misc'; |
|
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config'; |
|
import { matchPath, useLocation, useNavigate } from 'react-router'; |
|
|
|
/**
 * Shape of the value exposed through `AppContext`.
 *
 * Groups three areas: conversation/message state and actions, canvas state,
 * and user configuration.
 */
interface AppContextValue {
  // ---- conversations and messages ----

  /** Conversation currently addressed by the URL together with its messages, or `null` when none is loaded. */
  viewingChat: ViewingChat | null;

  /** Assistant messages that are still being generated, keyed by conversation id. */
  pendingMessages: Record<Conversation['id'], PendingMessage>;

  /** Returns `true` while `pendingMessages` holds an entry for `convId`. */
  isGenerating: (convId: string) => boolean;

  /**
   * Stores a user message and generates the assistant reply.
   * A new conversation is created when `convId` is null/empty or `leafNodeId` is null.
   * Resolves to `false` when the content is blank, a generation is already
   * running for the conversation, or the generation failed.
   */
  sendMessage: (
    convId: string | null,
    leafNodeId: Message['id'] | null,
    content: string,
    extra: Message['extra'],
    onChunk: CallbackGeneratedChunk
  ) => Promise<boolean>;

  /** Drops the pending message of `convId` and aborts its in-flight request. */
  stopGenerating: (convId: string) => void;

  /**
   * Generates a new assistant reply below `parentNodeId`.
   * When `content` is not null, a new user message with that content is first
   * appended under `parentNodeId` and the reply is generated below it instead.
   */
  replaceMessageAndGenerate: (
    convId: string,
    parentNodeId: Message['id'],
    content: string | null,
    extra: Message['extra'],
    onChunk: CallbackGeneratedChunk
  ) => Promise<void>;

  // ---- canvas ----

  /** Current canvas payload; the provider resets it to `null` when the conversation id in the URL changes. */
  canvasData: CanvasData | null;

  /** Replaces (or clears, with `null`) the canvas payload. */
  setCanvasData: (data: CanvasData | null) => void;

  // ---- config ----

  /** Active configuration, same shape as `CONFIG_DEFAULT`. */
  config: typeof CONFIG_DEFAULT;

  /** Persists the configuration through `StorageUtils` and updates the in-memory copy. */
  saveConfig: (config: typeof CONFIG_DEFAULT) => void;

  /** Flag held by the provider; NOTE(review): presumably controls visibility of a settings dialog — confirm with consumers. */
  showSettings: boolean;

  /** Updates `showSettings`. */
  setShowSettings: (show: boolean) => void;
}
|
|
|
|
|
/**
 * Callback invoked by the provider while a reply is produced.
 *
 * It is called without an argument after every streamed chunk, and with a
 * message id when a message has just been stored (the new user message, or the
 * finished assistant message).
 * NOTE(review): presumably used by the UI to scroll and to switch to the given
 * leaf node — confirm against the callers.
 */
export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;

/**
 * Context carrying the application state.
 *
 * The default value is an empty placeholder object; real values are only
 * available below an `AppContextProvider`.
 */
const AppContext = createContext<AppContextValue>({} as AppContextValue);
|
|
|
/**
 * Loads a conversation and its messages from storage.
 *
 * The messages are returned exactly as `StorageUtils.getMessages` yields them
 * for the conversation; no filtering is applied here.
 *
 * @param convId Id of the conversation to load.
 * @returns The conversation with its messages, or `null` when the conversation does not exist.
 */
const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
  const conversation = await StorageUtils.getOneConversation(convId);
  if (conversation) {
    // Only query the messages once the conversation is known to exist.
    const messages = await StorageUtils.getMessages(convId);
    return { conv: conversation, messages };
  }
  return null;
};
|
|
|
/**
 * Provides chat state (viewed conversation, pending messages), generation
 * actions, canvas state and configuration to its subtree through `AppContext`.
 *
 * The conversation id is taken from the `/chat/:convId` route.
 */
export const AppContextProvider = ({
  children,
}: {
  children: React.ReactElement;
}) => {
  const { pathname } = useLocation();
  const navigate = useNavigate();
  const params = matchPath('/chat/:convId', pathname);
  const convId = params?.params?.convId;

  const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
  const [pendingMessages, setPendingMessages] = useState<
    Record<Conversation['id'], PendingMessage>
  >({});
  const [aborts, setAborts] = useState<
    Record<Conversation['id'], AbortController>
  >({});
  // Lazy initializer: read the stored config once instead of on every render.
  const [config, setConfig] = useState(() => StorageUtils.getConfig());
  const [canvasData, setCanvasData] = useState<CanvasData | null>(null);
  const [showSettings, setShowSettings] = useState(false);

  // Reload the viewed conversation whenever the conversation id in the URL
  // changes, and keep it in sync with storage change notifications.
  useEffect(() => {
    // The canvas belongs to the previous conversation; reset it.
    setCanvasData(null);

    // Set by the cleanup so that a load started for a previous convId (or
    // before unmount) cannot overwrite the state of the current one.
    let isStale = false;

    const loadChat = async (id: string) => {
      const chat = await getViewingChat(id);
      if (!isStale) setViewingChat(chat);
    };

    const handleConversationChange = async (changedConvId: string) => {
      if (changedConvId !== convId) return;
      await loadChat(changedConvId);
    };

    StorageUtils.onConversationChanged(handleConversationChange);
    loadChat(convId ?? '').catch((err) => console.error(err));

    return () => {
      isStale = true;
      StorageUtils.offConversationChanged(handleConversationChange);
    };
  }, [convId]);

  /** Stores the pending message of a conversation; `null` removes the entry. */
  const setPending = (convId: string, pendingMsg: PendingMessage | null) => {
    setPendingMessages((prev) => {
      if (pendingMsg) return { ...prev, [convId]: pendingMsg };
      // Nothing to remove: keep the same object to avoid a needless re-render.
      if (!(convId in prev)) return prev;
      const next = { ...prev };
      delete next[convId];
      return next;
    });
  };

  /** Stores the abort controller of a conversation; `null` removes the entry. */
  const setAbort = (convId: string, controller: AbortController | null) => {
    setAborts((prev) => {
      if (controller) return { ...prev, [convId]: controller };
      const next = { ...prev };
      delete next[convId];
      return next;
    });
  };

  /**
   * Removes the abort controller of a conversation, but only if it is still
   * the given one, so a controller registered by a newer generation survives.
   */
  const releaseAbort = (convId: string, controller: AbortController) => {
    setAborts((prev) => {
      if (prev[convId] !== controller) return prev;
      const next = { ...prev };
      delete next[convId];
      return next;
    });
  };

  ////////////////////////////////////////////////////////////////////////
  // public functions

  const isGenerating = (convId: string) => !!pendingMessages[convId];

  /**
   * Streams an assistant reply below `leafNodeId` and stores it.
   *
   * Does nothing when a generation is already running for the conversation.
   * An abort (triggered through `stopGenerating`) is not treated as an error:
   * whatever content was received so far is still stored.
   *
   * @throws Error when the conversation or its messages cannot be found, or
   *   when the request/stream fails (the error is also logged and alerted).
   */
  const generateMessage = async (
    convId: string,
    leafNodeId: Message['id'],
    onChunk: CallbackGeneratedChunk
  ) => {
    if (isGenerating(convId)) return;

    const currConfig = StorageUtils.getConfig();
    const currConversation = await StorageUtils.getOneConversation(convId);
    if (!currConversation) {
      throw new Error('Current conversation is not found');
    }

    const currMessages = StorageUtils.filterByLeafNodeId(
      await StorageUtils.getMessages(convId),
      leafNodeId,
      false
    );
    // Validate before registering the controller so a failure here does not
    // leave an unused controller behind.
    if (!currMessages) {
      throw new Error('Current messages are not found');
    }

    const abortController = new AbortController();
    setAbort(convId, abortController);

    const pendingId = Date.now() + 1;
    let pendingMsg: PendingMessage = {
      id: pendingId,
      convId,
      type: 'text',
      timestamp: pendingId,
      role: 'assistant',
      content: null,
      parent: leafNodeId,
      children: [],
    };
    setPending(convId, pendingMsg);

    try {
      // Prepare the messages for the API.
      let messages: APIMessage[] = [
        ...(currConfig.systemMessage.length === 0
          ? []
          : [{ role: 'system', content: currConfig.systemMessage } as APIMessage]),
        ...normalizeMsgsForAPI(currMessages),
      ];
      if (currConfig.excludeThoughtOnReq) {
        messages = filterThoughtFromMsgs(messages);
      }
      if (isDev) console.log({ messages });

      // Prepare the request parameters; keys from `custom` are spread last
      // and therefore override the ones above.
      const requestBody = {
        messages,
        stream: true,
        cache_prompt: true,
        samplers: currConfig.samplers,
        temperature: currConfig.temperature,
        dynatemp_range: currConfig.dynatemp_range,
        dynatemp_exponent: currConfig.dynatemp_exponent,
        top_k: currConfig.top_k,
        top_p: currConfig.top_p,
        min_p: currConfig.min_p,
        typical_p: currConfig.typical_p,
        xtc_probability: currConfig.xtc_probability,
        xtc_threshold: currConfig.xtc_threshold,
        repeat_last_n: currConfig.repeat_last_n,
        repeat_penalty: currConfig.repeat_penalty,
        presence_penalty: currConfig.presence_penalty,
        frequency_penalty: currConfig.frequency_penalty,
        dry_multiplier: currConfig.dry_multiplier,
        dry_base: currConfig.dry_base,
        dry_allowed_length: currConfig.dry_allowed_length,
        dry_penalty_last_n: currConfig.dry_penalty_last_n,
        max_tokens: currConfig.max_tokens,
        timings_per_token: !!currConfig.showTokensPerSecond,
        ...(currConfig.custom.length ? JSON.parse(currConfig.custom) : {}),
      };

      // Send the request.
      const fetchResponse = await fetch(`${BASE_URL}/v1/chat/completions`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          ...(currConfig.apiKey
            ? { Authorization: `Bearer ${currConfig.apiKey}` }
            : {}),
        },
        body: JSON.stringify(requestBody),
        signal: abortController.signal,
      });
      if (fetchResponse.status !== 200) {
        // The error body may not be JSON; fall back to the generic message
        // instead of surfacing a JSON parse error.
        const body = await fetchResponse.json().catch(() => null);
        throw new Error(body?.error?.message || 'Unknown error');
      }

      const chunks = getSSEStreamAsync(fetchResponse);
      for await (const chunk of chunks) {
        if (chunk.error) {
          throw new Error(chunk.error?.message || 'Unknown error');
        }

        // Tolerate chunks that carry no choice/delta.
        const addedContent = chunk.choices?.[0]?.delta?.content;
        if (addedContent) {
          pendingMsg = {
            ...pendingMsg,
            content: (pendingMsg.content ?? '') + addedContent,
          };
        }

        const timings = chunk.timings;
        if (timings && currConfig.showTokensPerSecond) {
          // Copy instead of mutating: the previous object is already held in
          // React state. Only the fields listed here are kept.
          pendingMsg = {
            ...pendingMsg,
            timings: {
              prompt_n: timings.prompt_n,
              prompt_ms: timings.prompt_ms,
              predicted_n: timings.predicted_n,
              predicted_ms: timings.predicted_ms,
            },
          };
        }

        setPending(convId, pendingMsg);
        onChunk();
      }
    } catch (err) {
      setPending(convId, null);
      if ((err as Error | undefined)?.name === 'AbortError') {
        // The request was aborted through stopGenerating(); not an error.
      } else {
        console.error(err);
        alert((err as Error | undefined)?.message ?? 'Unknown error');
        throw err; // rethrow so callers can react
      }
    } finally {
      // The request is over (finished, failed or aborted).
      releaseAbort(convId, abortController);
    }

    // Store the reply if any content was received (also after an abort).
    if (pendingMsg.content !== null) {
      await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
    }
    setPending(convId, null);
    onChunk(pendingId);
  };

  /**
   * Stores a user message and generates the assistant reply.
   *
   * Creates a new conversation (and navigates to it) when `convId` is
   * null/empty or `leafNodeId` is null.
   *
   * @returns `true` when the generation finished without throwing; `false`
   *   when the content is blank, a generation is already running, or storing
   *   the message / generating the reply failed.
   */
  const sendMessage = async (
    convId: string | null,
    leafNodeId: Message['id'] | null,
    content: string,
    extra: Message['extra'],
    onChunk: CallbackGeneratedChunk
  ): Promise<boolean> => {
    if (isGenerating(convId ?? '') || content.trim().length === 0) return false;

    if (convId === null || convId.length === 0 || leafNodeId === null) {
      // New conversation: its title is the beginning of the first message.
      const conv = await StorageUtils.createConversation(
        content.substring(0, 256)
      );
      convId = conv.id;
      leafNodeId = conv.currNode;
      navigate(`/chat/${convId}`);
    }

    const now = Date.now();
    const currMsgId = now;

    try {
      // Wait for the write: generateMessage re-reads the messages from
      // storage and must see the new user message.
      await StorageUtils.appendMsg(
        {
          id: currMsgId,
          timestamp: now,
          type: 'text',
          convId,
          role: 'user',
          content,
          extra,
          parent: leafNodeId,
          children: [],
        },
        leafNodeId
      );
      onChunk(currMsgId);

      await generateMessage(convId, currMsgId, onChunk);
      return true;
    } catch (err) {
      // generateMessage reports its own failures; logging here ensures a
      // failed write of the user message is not silent either.
      // TODO: roll back the stored user message.
      console.error(err);
    }
    return false;
  };

  /** Drops the pending message of `convId` and aborts its in-flight request. */
  const stopGenerating = (convId: string) => {
    setPending(convId, null);
    aborts[convId]?.abort();
  };

  /**
   * Generates a new assistant reply below `parentNodeId`.
   *
   * When `content` is not null, a new user message with that content is first
   * appended under `parentNodeId`, and the reply is generated below that new
   * message. When `content` is null, only the reply is generated.
   */
  const replaceMessageAndGenerate = async (
    convId: string,
    parentNodeId: Message['id'],
    content: string | null,
    extra: Message['extra'],
    onChunk: CallbackGeneratedChunk
  ) => {
    if (isGenerating(convId)) return;

    if (content !== null) {
      const now = Date.now();
      const currMsgId = now;
      // Wait for the write: generateMessage re-reads the messages from storage.
      await StorageUtils.appendMsg(
        {
          id: currMsgId,
          timestamp: now,
          type: 'text',
          convId,
          role: 'user',
          content,
          extra,
          parent: parentNodeId,
          children: [],
        },
        parentNodeId
      );
      parentNodeId = currMsgId;
    }
    onChunk(parentNodeId);

    await generateMessage(convId, parentNodeId, onChunk);
  };

  /** Persists the configuration and updates the in-memory copy. */
  const saveConfig = (config: typeof CONFIG_DEFAULT) => {
    StorageUtils.setConfig(config);
    setConfig(config);
  };

  return (
    <AppContext.Provider
      value={{
        isGenerating,
        viewingChat,
        pendingMessages,
        sendMessage,
        stopGenerating,
        replaceMessageAndGenerate,
        canvasData,
        setCanvasData,
        config,
        saveConfig,
        showSettings,
        setShowSettings,
      }}
    >
      {children}
    </AppContext.Provider>
  );
};
|
|
|
/**
 * Hook returning the current `AppContext` value.
 * Outside of an `AppContextProvider` this is the context's empty default object.
 */
export const useAppContext = () => {
  return useContext(AppContext);
};
|
|