|
import { convertToCoreMessages, streamText as _streamText, type Message } from 'ai'; |
|
import { MAX_TOKENS, type FileMap } from './constants'; |
|
import fs from 'fs'; |
|
import { getSystemPrompt } from '~/lib/common/prompts/prompts'; |
|
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODIFICATIONS_TAG_NAME, PROVIDER_LIST, WORK_DIR } from '~/utils/constants'; |
|
import type { IProviderSetting } from '~/types/model'; |
|
import { PromptLibrary } from '~/lib/common/prompt-library'; |
|
import { allowedHTMLElements } from '~/utils/markdown'; |
|
import { LLMManager } from '~/lib/modules/llm/manager'; |
|
import { createScopedLogger } from '~/utils/logger'; |
|
import { createFilesContext, extractPropertiesFromMessage } from './utils'; |
|
import { getFilePaths } from './select-context'; |
|
|
|
/** Chat history in the AI SDK's UI `Message` shape. */
export type Messages = Array<Message>;

/** Full parameter object accepted by the AI SDK `streamText` call. */
type StreamTextCallOptions = Parameters<typeof _streamText>[0];

/** Every `streamText` option except `model`, which `streamText` below resolves on its own. */
export type StreamingOptions = Omit<StreamTextCallOptions, 'model'>;

/** Logger scoped to this module; messages are tagged with `stream-text`. */
const logger = createScopedLogger('stream-text');
|
|
|
/** Anthropic provider options requesting an `ephemeral` cache-control entry. */
const ANTHROPIC_EPHEMERAL_CACHE = { cacheControl: { type: 'ephemeral' } };

/**
 * Provider metadata that is spread into a message object to attach the Anthropic
 * `ephemeral` cache-control hint to that message.
 */
const CACHE_CONTROL_METADATA = {
  experimental_providerMetadata: { anthropic: ANTHROPIC_EPHEMERAL_CACHE },
};
|
|
|
/**
 * Best-effort dump of `messages` as pretty-printed JSON (2-space indent, UTF-8),
 * overwriting `filePath`.
 *
 * Any failure (serialization or write) is logged to the console and swallowed, so
 * an unwritable or missing filesystem never aborts the caller.
 *
 * @param messages - Messages to serialize.
 * @param filePath - Destination file. Relative paths resolve against the process
 *   working directory. Defaults to `messages.json`, the path this function always
 *   used before it was parameterized.
 */
function persistMessages(messages: Message[], filePath = 'messages.json'): void {
  try {
    fs.writeFileSync(filePath, JSON.stringify(messages, null, 2), 'utf8');
  } catch (error: unknown) {
    console.error('Error writing messages to file:', error);
  }
}
|
|
|
/**
 * Strips the serialized bolt "thought" wrapper from assistant replies. The pattern matches a
 * literal backslash before each quote, i.e. `<div class=\"__boltThought__\">…</div>`
 * (presumably how the wrapper appears in stored message text — confirm against the writer).
 * Global so every occurrence is removed, not just the first.
 */
const BOLT_THOUGHT_PATTERN = /<div class=\\"__boltThought__\\">.*?<\/div>/gs;

/** Strips every `<think>…</think>` reasoning block (including newlines) from assistant replies. */
const THINK_TAG_PATTERN = /<think>.*?<\/think>/gs;

/**
 * Prepares the chat history and system prompt, resolves which model to call, and starts a
 * streaming completion through the AI SDK's `streamText`.
 *
 * Model and provider come from the most recent user message, as parsed by
 * `extractPropertiesFromMessage`. They default to `DEFAULT_MODEL` / `DEFAULT_PROVIDER` when
 * there are no user messages. If the model is not in the provider's static list, the dynamic
 * list is fetched. If it is still not found, the first listed model is used and a warning is
 * logged.
 *
 * @param props.messages - Chat history (without ids). User messages are rewritten to the
 *   content returned by `extractPropertiesFromMessage`. Assistant messages have thought and
 *   `<think>` blocks removed.
 * @param props.env - Server environment, forwarded to model listing and model instantiation.
 * @param props.options - Extra `streamText` options. They are spread last and so override
 *   computed values such as `maxTokens`.
 * @param props.apiKeys - Per-provider API keys forwarded to the provider.
 * @param props.files - All project files. Used (with `contextFiles`) to list project paths
 *   in the system prompt.
 * @param props.providerSettings - Per-provider settings forwarded to the provider.
 * @param props.promptId - Prompt-library id (defaults to `'default'`). Falls back to
 *   `getSystemPrompt()` when the library returns nullish.
 * @param props.contextOptimization - When true and both `files` and `contextFiles` are given,
 *   the file list and the selected file contents are appended to the system prompt.
 * @param props.isPromptCachingEnabled - When true, user messages among the last four
 *   positions and the system prompt carry Anthropic `ephemeral` cache-control metadata. The
 *   system prompt is sent as the first message, and the final message list is written to
 *   `messages.json`.
 * @param props.contextFiles - Files whose contents form the "context buffer".
 * @param props.summary - Chat summary appended to the system prompt. Only used when context
 *   optimization is active. When it is used, the history is trimmed, see `messageSliceId`.
 * @param props.messageSliceId - With a summary, keep messages from this index on. When falsy,
 *   keep only the last message.
 * @returns The AI SDK `streamText` result.
 * @throws Error when the resolved provider reports no models at all.
 */
export async function streamText(props: {
  messages: Omit<Message, 'id'>[];
  env?: Env;
  options?: StreamingOptions;
  apiKeys?: Record<string, string>;
  files?: FileMap;
  providerSettings?: Record<string, IProviderSetting>;
  promptId?: string;
  contextOptimization?: boolean;
  isPromptCachingEnabled?: boolean;
  contextFiles?: FileMap;
  summary?: string;
  messageSliceId?: number;
}) {
  const {
    messages,
    env: serverEnv,
    options,
    apiKeys,
    files,
    providerSettings,
    promptId,
    contextOptimization,
    isPromptCachingEnabled,
    contextFiles,
    summary,
    messageSliceId,
  } = props;

  let currentModel = DEFAULT_MODEL;
  let currentProvider = DEFAULT_PROVIDER.name;

  // Only user messages at one of the last four positions get cache-control metadata.
  const cacheWindowStart = messages.length - 4;

  let processedMessages = messages.map((message, idx) => {
    if (message.role === 'user') {
      const { model, provider, content } = extractPropertiesFromMessage(message);

      // The last user message wins: it decides which model/provider is called.
      currentModel = model;
      currentProvider = provider;

      const putCacheControl = isPromptCachingEnabled && idx >= cacheWindowStart;

      return {
        ...message,
        content,
        ...(putCacheControl ? CACHE_CONTROL_METADATA : {}),
      };
    }

    if (message.role === 'assistant') {
      // Reasoning output is not fed back to the model as history.
      const content = message.content.replace(BOLT_THOUGHT_PATTERN, '').replace(THINK_TAG_PATTERN, '');

      return { ...message, content };
    }

    return message;
  });

  const provider = PROVIDER_LIST.find((p) => p.name === currentProvider) || DEFAULT_PROVIDER;
  const staticModels = LLMManager.getInstance().getStaticModelListFromProvider(provider);
  let modelDetails = staticModels.find((m) => m.name === currentModel);

  if (!modelDetails) {
    // Not a known static model: also consult the provider's dynamically listed models.
    const modelsList = [
      ...(provider.staticModels || []),
      ...(await LLMManager.getInstance().getModelListFromProvider(provider, {
        apiKeys,
        providerSettings,
        serverEnv: serverEnv as any,
      })),
    ];

    if (!modelsList.length) {
      throw new Error(`No models found for provider ${provider.name}`);
    }

    modelDetails = modelsList.find((m) => m.name === currentModel);

    if (!modelDetails) {
      logger.warn(
        `MODEL [${currentModel}] not found in provider [${provider.name}]. Falling back to first model. ${modelsList[0].name}`,
      );
      modelDetails = modelsList[0];
    }
  }

  const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;

  let systemPrompt =
    PromptLibrary.getPropmtFromLibrary(promptId || 'default', {
      cwd: WORK_DIR,
      allowedHtmlElements: allowedHTMLElements,
      modificationTagName: MODIFICATIONS_TAG_NAME,
    }) ?? getSystemPrompt();

  if (files && contextFiles && contextOptimization) {
    const codeContext = createFilesContext(contextFiles, true);
    const filePaths = getFilePaths(files);

    systemPrompt = `${systemPrompt}
Below are all the files present in the project:
---
${filePaths.join('\n')}
---

Below is the artifact containing the context loaded into context buffer for you to have knowledge of and might need changes to fullfill current user request.
CONTEXT BUFFER:
---
${codeContext}
---
`;

    if (summary) {
      systemPrompt = `${systemPrompt}
below is the chat history till now
CHAT SUMMARY:
---
${summary}
---
`;

      // The summary stands in for older history, so only recent messages are sent.
      if (messageSliceId) {
        processedMessages = processedMessages.slice(messageSliceId);
      } else {
        processedMessages = processedMessages.slice(-1);
      }
    }
  }

  logger.info(`Sending llm call to ${provider.name} with model ${modelDetails.name}`);

  const model = provider.getModelInstance({
    model: modelDetails.name,
    serverEnv,
    apiKeys,
    providerSettings,
  });

  if (isPromptCachingEnabled) {
    // The system prompt is sent as the first message so it can carry the cache-control
    // metadata (presumably because the `system` option has no slot for provider metadata).
    const promptMessages = [
      {
        role: 'system',
        content: systemPrompt,
        ...CACHE_CONTROL_METADATA,
      },
      ...processedMessages,
    ] as Message[];

    // NOTE(review): this writes the full prompt to messages.json on every cached request.
    // That includes file contents when context optimization is on. It looks like a
    // debugging aid; consider gating it.
    persistMessages(promptMessages);

    return _streamText({
      model,
      maxTokens: dynamicMaxTokens,
      messages: promptMessages,
      ...options,
    });
  }

  return _streamText({
    model,
    system: systemPrompt,
    maxTokens: dynamicMaxTokens,
    messages: convertToCoreMessages(processedMessages as any),
    ...options,
  });
}
|
|