const BaseClient = require('./BaseClient');
const ChatGPTClient = require('./ChatGPTClient');
const {
  encoding_for_model: encodingForModel,
  get_encoding: getEncoding,
} = require('@dqbd/tiktoken');
const { maxTokensMap, genAzureChatCompletion } = require('../../utils');

// Cache to store Tiktoken instances
const tokenizersCache = {};
// Counter for keeping track of the number of tokenizer calls
let tokenizerCallsCount = 0;

class OpenAIClient extends BaseClient {
  constructor(apiKey, options = {}) {
    super(apiKey, options);
    this.ChatGPTClient = new ChatGPTClient();
    this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
    this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
    this.sender = options.sender ?? 'ChatGPT';
    this.contextStrategy = options.contextStrategy
      ? options.contextStrategy.toLowerCase()
      : 'discard';
    this.shouldRefineContext = this.contextStrategy === 'refine';
    this.azure = options.azure || false;
    if (this.azure) {
      this.azureEndpoint = genAzureChatCompletion(this.azure);
    }
    this.setOptions(options);
  }

  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      this.options.modelOptions = {
        ...this.options.modelOptions,
        ...options.modelOptions,
      };
      delete options.modelOptions;
      this.options = {
        ...this.options,
        ...options,
      };
    } else {
      this.options = options;
    }

    if (this.options.openaiApiKey) {
      this.apiKey = this.options.openaiApiKey;
    }

    const modelOptions = this.options.modelOptions || {};
    if (!this.modelOptions) {
      this.modelOptions = {
        ...modelOptions,
        model: modelOptions.model || 'gpt-3.5-turbo',
        temperature:
          typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
        top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
        presence_penalty:
          typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
        stop: modelOptions.stop,
      };
    }

    this.isChatCompletion =
      this.options.reverseProxyUrl ||
      this.options.localAI ||
      this.modelOptions.model.startsWith('gpt-');
    this.isChatGptModel = this.isChatCompletion;
    if (this.modelOptions.model === 'text-davinci-003') {
      this.isChatCompletion = false;
      this.isChatGptModel = false;
    }
    const { isChatGptModel } = this;
    this.isUnofficialChatGptModel =
      this.modelOptions.model.startsWith('text-chat') ||
      this.modelOptions.model.startsWith('text-davinci-002-render');

    this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum
    this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
    this.maxPromptTokens =
      this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
      throw new Error(
        `maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
          this.maxPromptTokens + this.maxResponseTokens
        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
      );
    }

    this.userLabel = this.options.userLabel || 'User';
    this.chatGptLabel = this.options.chatGptLabel || 'Assistant';

    this.setupTokens();

    if (!this.modelOptions.stop) {
      const stopTokens = [this.startToken];
      if (this.endToken && this.endToken !== this.startToken) {
        stopTokens.push(this.endToken);
      }
      stopTokens.push(`\n${this.userLabel}:`);
      stopTokens.push('<|diff_marker|>');
      this.modelOptions.stop = stopTokens;
    }

    if (this.options.reverseProxyUrl) {
      this.completionsUrl = this.options.reverseProxyUrl;
    } else if (isChatGptModel) {
      this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
    } else {
      this.completionsUrl = 'https://api.openai.com/v1/completions';
    }

    if (this.azureEndpoint) {
      this.completionsUrl = this.azureEndpoint;
    }

    if (this.azureEndpoint && this.options.debug) {
      console.debug(`Using Azure endpoint: ${this.azureEndpoint}`, this.azure);
    }

    return this;
  }

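  // Illustrative note on the merge behavior above (a sketch inferred from this
  // method, not a call made elsewhere in this file): on an already-configured
  // client,
  //   client.setOptions({ modelOptions: { temperature: 0 } });
  // merges the new modelOptions into this.options.modelOptions unless
  // `options.replaceOptions` is set; the cached `this.modelOptions` defaults
  // are only built on the first call.
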
  setupTokens() {
    if (this.isChatCompletion) {
      this.startToken = '||>';
      this.endToken = '';
    } else if (this.isUnofficialChatGptModel) {
      this.startToken = '<|im_start|>';
      this.endToken = '<|im_end|>';
    } else {
      this.startToken = '||>';
      this.endToken = '';
    }
  }

  // Selects an appropriate tokenizer based on the current configuration of the client instance.
  // It takes into account factors such as whether it's a chat completion, an unofficial ChatGPT model, etc.
  selectTokenizer() {
    let tokenizer;
    this.encoding = 'text-davinci-003';
    if (this.isChatCompletion) {
      this.encoding = 'cl100k_base';
      tokenizer = this.constructor.getTokenizer(this.encoding);
    } else if (this.isUnofficialChatGptModel) {
      const extendSpecialTokens = {
        '<|im_start|>': 100264,
        '<|im_end|>': 100265,
      };
      tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens);
    } else {
      try {
        this.encoding = this.modelOptions.model;
        tokenizer = this.constructor.getTokenizer(this.modelOptions.model, true);
      } catch {
        // The model name is unknown to tiktoken; fall back to the default
        // encoding instead of retrying the same failing model name.
        this.encoding = 'text-davinci-003';
        tokenizer = this.constructor.getTokenizer(this.encoding, true);
      }
    }
    return tokenizer;
  }

  // Retrieves a tokenizer either from the cache or creates a new one if one doesn't exist in the cache.
  // If a tokenizer is being created, it's also added to the cache.
  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
    let tokenizer;
    if (tokenizersCache[encoding]) {
      tokenizer = tokenizersCache[encoding];
    } else {
      if (isModelName) {
        tokenizer = encodingForModel(encoding, extendSpecialTokens);
      } else {
        tokenizer = getEncoding(encoding, extendSpecialTokens);
      }
      tokenizersCache[encoding] = tokenizer;
    }
    return tokenizer;
  }

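  // Illustrative usage (an assumption for demonstration; this module never
  // calls getTokenizer directly like this): repeated lookups for the same
  // encoding return the cached Tiktoken instance rather than constructing a
  // new one.
  //   const enc = OpenAIClient.getTokenizer('cl100k_base');
  //   enc === OpenAIClient.getTokenizer('cl100k_base'); // true (cache hit)
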
  // Frees all encoders in the cache and resets the count.
  static freeAndResetAllEncoders() {
    try {
      Object.keys(tokenizersCache).forEach((key) => {
        if (tokenizersCache[key]) {
          tokenizersCache[key].free();
          delete tokenizersCache[key];
        }
      });
      // Reset count
      tokenizerCallsCount = 1;
    } catch (error) {
      console.error('freeAndResetAllEncoders error:', error);
    }
  }

  // Checks if the cache of tokenizers has reached a certain size. If it has, it frees and resets all tokenizers.
  resetTokenizersIfNecessary() {
    if (tokenizerCallsCount >= 25) {
      if (this.options.debug) {
        console.debug('freeAndResetAllEncoders: reached 25 encodings, resetting...');
      }
      this.constructor.freeAndResetAllEncoders();
    }
    tokenizerCallsCount++;
  }

  // Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
  getTokenCount(text) {
    this.resetTokenizersIfNecessary();
    try {
      const tokenizer = this.selectTokenizer();
      return tokenizer.encode(text, 'all').length;
    } catch (error) {
      // If encoding fails (e.g. against a freed tokenizer), reset the cache and retry once.
      this.constructor.freeAndResetAllEncoders();
      const tokenizer = this.selectTokenizer();
      return tokenizer.encode(text, 'all').length;
    }
  }

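  // Hedged example of the counting path: with a chat model configured, the
  // cl100k_base encoding is selected, so a call such as
  //   client.getTokenCount('Hello world');
  // returns that encoding's token length (2 for this string with current
  // tiktoken data; exact counts can vary by tokenizer version).
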
  getSaveOptions() {
    return {
      chatGptLabel: this.options.chatGptLabel,
      promptPrefix: this.options.promptPrefix,
      ...this.modelOptions,
    };
  }

  getBuildMessagesOptions(opts) {
    return {
      isChatCompletion: this.isChatCompletion,
      promptPrefix: opts.promptPrefix,
      abortController: opts.abortController,
    };
  }

  async buildMessages(
    messages,
    parentMessageId,
    { isChatCompletion = false, promptPrefix = null } = {},
  ) {
    if (!isChatCompletion) {
      return await this.buildPrompt(messages, parentMessageId, {
        isChatGptModel: isChatCompletion,
        promptPrefix,
      });
    }

    let payload;
    let instructions;
    let tokenCountMap;
    let promptTokens;
    let orderedMessages = this.constructor.getMessagesForConversation(messages, parentMessageId);

    promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
    if (promptPrefix) {
      promptPrefix = `Instructions:\n${promptPrefix}`;
      instructions = {
        role: 'system',
        name: 'instructions',
        content: promptPrefix,
      };

      if (this.contextStrategy) {
        instructions.tokenCount = this.getTokenCountForMessage(instructions);
      }
    }

    const formattedMessages = orderedMessages.map((message) => {
      let { role: _role, sender, text } = message;
      const role = _role ?? sender;
      const content = text ?? '';
      const formattedMessage = {
        role: role?.toLowerCase() === 'user' ? 'user' : 'assistant',
        content,
      };

      if (this.options?.name && formattedMessage.role === 'user') {
        formattedMessage.name = this.options.name;
      }

      if (this.contextStrategy) {
        formattedMessage.tokenCount =
          message.tokenCount ?? this.getTokenCountForMessage(formattedMessage);
      }

      return formattedMessage;
    });

    // TODO: need to handle interleaving instructions better
    if (this.contextStrategy) {
      ({ payload, tokenCountMap, promptTokens, messages } = await this.handleContextStrategy({
        instructions,
        orderedMessages,
        formattedMessages,
      }));
    }

    const result = {
      prompt: payload,
      promptTokens,
      messages,
    };

    if (tokenCountMap) {
      tokenCountMap.instructions = instructions?.tokenCount;
      result.tokenCountMap = tokenCountMap;
    }

    return result;
  }

  async sendCompletion(payload, opts = {}) {
    let reply = '';
    let result = null;
    if (typeof opts.onProgress === 'function') {
      await this.getCompletion(
        payload,
        (progressMessage) => {
          if (progressMessage === '[DONE]') {
            return;
          }
          const token = this.isChatCompletion
            ? progressMessage.choices?.[0]?.delta?.content
            : progressMessage.choices?.[0]?.text;
          // first event's delta content is always undefined
          if (!token) {
            return;
          }
          if (this.options.debug) {
            // console.debug(token);
          }
          if (token === this.endToken) {
            return;
          }
          opts.onProgress(token);
          reply += token;
        },
        opts.abortController || new AbortController(),
      );
    } else {
      result = await this.getCompletion(
        payload,
        null,
        opts.abortController || new AbortController(),
      );
      if (this.options.debug) {
        console.debug(JSON.stringify(result));
      }
      if (this.isChatCompletion) {
        reply = result.choices[0].message.content;
      } else {
        reply = result.choices[0].text.replace(this.endToken, '');
      }
    }
    return reply.trim();
  }

  getTokenCountForResponse(response) {
    return this.getTokenCountForMessage({
      role: 'assistant',
      content: response.text,
    });
  }
}

module.exports = OpenAIClient;
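
// Minimal usage sketch (assumptions: a valid API key in OPENAI_API_KEY, a
// chat-format payload accepted by the bound ChatGPTClient.getCompletion, and
// an enclosing async context for `await`; illustrative, not a tested example):
//
//   const OpenAIClient = require('./OpenAIClient');
//   const client = new OpenAIClient(process.env.OPENAI_API_KEY, {
//     modelOptions: { model: 'gpt-3.5-turbo', temperature: 0.8 },
//   });
//   const reply = await client.sendCompletion(
//     [{ role: 'user', content: 'Hello!' }],
//     { onProgress: (token) => process.stdout.write(token) },
//   );
//   console.log(reply);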