const BaseClient = require('./BaseClient');
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const {
  encoding_for_model: encodingForModel,
  get_encoding: getEncoding,
} = require('@dqbd/tiktoken');
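// Cache tokenizer instances by encoding/model name so repeated getTokenizer
// calls don't rebuild the (relatively expensive) tiktoken encoder.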
const tokenizersCache = {};
class GoogleClient extends BaseClient {
  constructor(credentials, options = {}) {
    super('apiKey', options);
    this.client_email = credentials.client_email;
    this.project_id = credentials.project_id;
    this.private_key = credentials.private_key;
    this.sender = 'PaLM2';
    this.setOptions(options);
  }
  /* Google/PaLM2 specific methods */
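  // Builds the Vertex AI `predict` endpoint for the configured model.
  // The region and publisher path are currently hard-coded to us-central1/google.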
  constructUrl() {
    return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
  }
  async getClient() {
    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
    // `authorize()` returns a promise when called without a callback; awaiting
    // it lets authorization failures propagate to the caller instead of being
    // thrown inside a callback where nothing can catch them.
    try {
      await jwtClient.authorize();
    } catch (err) {
      console.error('Failed to authorize Google JWT client:', err);
      throw err;
    }
    return jwtClient;
  }
  /* Required Client methods */
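  // Merges new options into the current ones (nested modelOptions are merged
  // by hand, since a shallow spread would replace the whole object), unless
  // `replaceOptions` is set, in which case the new options win outright.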
  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      // nested options aren't spread properly, so we need to do this manually
      this.options.modelOptions = {
        ...this.options.modelOptions,
        ...options.modelOptions,
      };
      delete options.modelOptions;
      // now we can merge options
      this.options = {
        ...this.options,
        ...options,
      };
    } else {
      this.options = options;
    }
    // Guard against a missing `examples` array before filtering out examples
    // with empty input or output content (previously this threw when no
    // examples were passed at all).
    this.options.examples = (this.options.examples ?? []).filter(
      (obj) => obj.input.content !== '' && obj.output.content !== '',
    );
    const modelOptions = this.options.modelOptions || {};
    this.modelOptions = {
      ...modelOptions,
      // set some good defaults (check for undefined in some cases because they may be 0)
      model: modelOptions.model || 'chat-bison',
      temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended
      topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95
      topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1 - 40, default: 40
      // stop: modelOptions.stop // no stop method for now
    };
    this.isChatModel = this.modelOptions.model.startsWith('chat-');
    this.isTextModel = this.modelOptions.model.startsWith('text-');
    const { isChatModel, isTextModel } = this;
    this.maxContextTokens = this.options.maxContextTokens || (isTextModel ? 8000 : 4096);
    // The max prompt tokens is determined by the max context tokens minus the max response tokens.
    // Earlier messages will be dropped until the prompt is within the limit.
    this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1024;
    this.maxPromptTokens =
      this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
      throw new Error(
        `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
          this.maxPromptTokens + this.maxResponseTokens
        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
      );
    }
    this.userLabel = this.options.userLabel || 'User';
    this.modelLabel = this.options.modelLabel || 'Assistant';
    if (isChatModel) {
      // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
      // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
      // without tripping the stop sequences, so I'm using "||>" instead.
      this.startToken = '||>';
      this.endToken = '';
      this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
    } else if (isTextModel) {
      this.startToken = '<|im_start|>';
      this.endToken = '<|im_end|>';
      this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
        '<|im_start|>': 100264,
        '<|im_end|>': 100265,
      });
    } else {
      // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
      // system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
      // as a single token. So we're using this instead.
      this.startToken = '||>';
      this.endToken = '';
      try {
        this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
      } catch {
        this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
      }
    }
    if (!this.modelOptions.stop) {
      const stopTokens = [this.startToken];
      if (this.endToken && this.endToken !== this.startToken) {
        stopTokens.push(this.endToken);
      }
      stopTokens.push(`\n${this.userLabel}:`);
      stopTokens.push('<|diff_marker|>');
      // I chose not to do one for `modelLabel` because I've never seen it happen
      this.modelOptions.stop = stopTokens;
    }
    if (this.options.reverseProxyUrl) {
      this.completionsUrl = this.options.reverseProxyUrl;
    } else {
      this.completionsUrl = this.constructUrl();
    }
    return this;
  }
  getMessageMapMethod() {
    // An arrow function already captures `this`, so no bind is needed.
    return (message) => ({
      author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
      content: message?.content ?? message.text,
    });
  }
  buildMessages(messages = []) {
    const formattedMessages = messages.map(this.getMessageMapMethod());
    let payload = {
      instances: [
        {
          messages: formattedMessages,
        },
      ],
      parameters: this.options.modelOptions,
    };
    if (this.options.promptPrefix) {
      payload.instances[0].context = this.options.promptPrefix;
    }
    if (this.options.examples.length > 0) {
      payload.instances[0].examples = this.options.examples;
    }
    /* TODO: the text model needs more context, since it can't process an array of messages */
    if (this.isTextModel) {
      payload.instances = [
        {
          prompt: messages[messages.length - 1].content,
        },
      ];
    }
    if (this.options.debug) {
      console.debug('GoogleClient buildMessages');
      console.dir(payload, { depth: null });
    }
    return { prompt: payload };
  }
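  /* For reference, a chat-model payload built above looks roughly like this
     (values are illustrative only):
     {
       instances: [{
         context: 'You are a helpful assistant.', // from options.promptPrefix
         examples: [{ input: { content: 'Hi' }, output: { content: 'Hello!' } }],
         messages: [{ author: 'User', content: 'What is PaLM 2?' }],
       }],
       parameters: { model: 'chat-bison', temperature: 0.2, topP: 0.95, topK: 40, ... },
     }
     Text models instead send a single { prompt } instance. */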
  async getCompletion(payload, abortController = null) {
    if (!abortController) {
      abortController = new AbortController();
    }
    const { debug } = this.options;
    const url = this.completionsUrl;
    if (debug) {
      console.debug();
      console.debug(url);
      console.debug(this.modelOptions);
      console.debug();
    }
    // NOTE: these undici agent options are set up for a fetch-based transport
    // (e.g. when using a reverse proxy), but `client.request` below goes
    // through the authorized gaxios client and does not consume them; only
    // the abort signal is forwarded.
    const opts = {
      method: 'POST',
      agent: new Agent({
        bodyTimeout: 0,
        headersTimeout: 0,
      }),
      signal: abortController.signal,
    };
    if (this.options.proxy) {
      opts.agent = new ProxyAgent(this.options.proxy);
    }
    const client = await this.getClient();
    const res = await client.request({ url, method: 'POST', data: payload, signal: opts.signal });
    if (debug) {
      console.dir(res.data, { depth: null });
    }
    return res.data;
  }
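  /* For reference, `getCompletion` resolves with the raw predict response,
     which sendCompletion below parses. For chat models it looks roughly like
     this (illustrative, abbreviated):
     {
       predictions: [{
         candidates: [{ author: '1', content: '...model reply...' }],
         safetyAttributes: { blocked: false, ... },
       }],
     }
     Text models return `predictions[0].content` directly instead. */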
  getSaveOptions() {
    return {
      promptPrefix: this.options.promptPrefix,
      modelLabel: this.options.modelLabel,
      ...this.modelOptions,
    };
  }
  getBuildMessagesOptions() {
    // console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
  }
  async sendCompletion(payload, opts = {}) {
    if (this.options.debug) {
      console.log('GoogleClient: sendCompletion', payload, opts);
    }
    let reply = '';
    let blocked = false;
    try {
      const result = await this.getCompletion(payload, opts.abortController);
      blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
      reply =
        result?.predictions?.[0]?.candidates?.[0]?.content ||
        result?.predictions?.[0]?.content ||
        '';
      if (blocked === true) {
        reply = `Google blocked a proper response to your message:\n${JSON.stringify(
          result.predictions[0].safetyAttributes,
        )}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
      }
      if (this.options.debug) {
        console.debug('result');
        console.debug(result);
      }
    } catch (err) {
      console.error(err);
    }
    if (!blocked) {
      await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
    }
    return reply.trim();
  }
  /* TODO: Handle tokens with Google tokenization. NOTE: these methods are required */
  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
    if (tokenizersCache[encoding]) {
      return tokenizersCache[encoding];
    }
    let tokenizer;
    if (isModelName) {
      tokenizer = encodingForModel(encoding, extendSpecialTokens);
    } else {
      tokenizer = getEncoding(encoding, extendSpecialTokens);
    }
    tokenizersCache[encoding] = tokenizer;
    return tokenizer;
  }
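  // Counts tokens with the tiktoken encoder chosen in setOptions; passing
  // 'all' as allowed_special lets special tokens like '<|im_start|>' be
  // encoded instead of throwing.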
  getTokenCount(text) {
    return this.gptEncoder.encode(text, 'all').length;
  }
}

module.exports = GoogleClient;
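/* A minimal usage sketch (not part of the module; the credentials path,
   message shape, and onProgress handler are illustrative assumptions):

const credentials = require('/path/to/service-account.json'); // hypothetical path
const client = new GoogleClient(credentials, {
  modelOptions: { model: 'chat-bison' },
  examples: [],
  debug: true,
});
const { prompt } = client.buildMessages([{ isCreatedByUser: true, text: 'Hello!' }]);
client
  .sendCompletion(prompt, { onProgress: (token) => process.stdout.write(token) })
  .then((reply) => console.log('\nFinal reply:', reply));
*/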