|
import { z } from "zod"; |
|
import { env } from "$env/dynamic/private"; |
|
import type { Endpoint } from "../endpoints"; |
|
import type { TextGenerationStreamOutput } from "@huggingface/inference"; |
|
import type { Cohere, CohereClient } from "cohere-ai"; |
|
import { buildPrompt } from "$lib/buildPrompt"; |
|
|
|
/**
 * Field validators for a Cohere endpoint configuration entry.
 *
 * Kept as a standalone shape so each field can carry its own documentation.
 */
const cohereEndpointShape = {
  /** Positive integer, defaults to 1. NOTE(review): presumably a relative selection weight — confirm against the endpoint picker. */
  weight: z.number().int().positive().default(1),

  /** Model descriptor; accepted without validation. `endpointCohere` reads `parameters`, `id` and `name` from it. */
  model: z.any(),

  /** Discriminator: must be exactly "cohere". */
  type: z.literal("cohere"),

  /** Cohere API token; falls back to the `COHERE_API_TOKEN` private environment variable. */
  apiKey: z.string().default(env.COHERE_API_TOKEN),

  /** When true, `endpointCohere` sends a prompt built by `buildPrompt` with `rawPrompting` enabled. Defaults to false. */
  raw: z.boolean().default(false),
};

/**
 * Zod schema describing the configuration accepted by `endpointCohere`.
 */
export const endpointCohereParametersSchema = z.object(cohereEndpointShape);
|
|
|
/**
 * Creates an endpoint that streams chat completions from the Cohere API.
 *
 * The configuration is validated with `endpointCohereParametersSchema`, and the
 * `cohere-ai` SDK is loaded through a dynamic import when this function runs.
 *
 * The returned endpoint resolves to an async generator that yields:
 * - one non-special token per `text-generation` stream event, and
 * - one final special token (empty text) whose `generated_text` holds the full
 *   response text, emitted when the `stream-end` event arrives.
 *
 * Two request modes exist, selected by the `raw` configuration flag:
 * - raw: the conversation is rendered to a single prompt by `buildPrompt` and
 *   sent with `rawPrompting: true`;
 * - chat (default): non-system messages are mapped to Cohere chat messages; the
 *   last one becomes `message`, the preceding ones `chatHistory`, and the system
 *   text is sent as `preamble`. `continueMessage` is not used in this mode.
 *
 * @param input - Endpoint configuration (see `endpointCohereParametersSchema`).
 * @returns The endpoint function for this Cohere model.
 * @throws If the configuration fails schema validation, or with
 *   "Failed to import cohere-ai" if loading the SDK or constructing the client fails.
 *   While iterating the stream, an `Error` is thrown if the conversation has no
 *   non-system message (chat mode) or if Cohere finishes with an error reason.
 */
export async function endpointCohere(
  input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
  const { apiKey, model, raw } = endpointCohereParametersSchema.parse(input);

  let cohere: CohereClient;

  try {
    // The try block covers both the dynamic import and the client construction.
    cohere = new (await import("cohere-ai")).CohereClient({
      token: apiKey,
    });
  } catch (e) {
    throw new Error("Failed to import cohere-ai", { cause: e });
  }

  return async ({ messages, preprompt, generateSettings, continueMessage }) => {
    // A leading system message overrides the configured preprompt.
    let system = preprompt;
    if (messages?.[0]?.from === "system") {
      system = messages[0].content;
    }

    // Per-request settings take precedence over the model's default parameters.
    const parameters = { ...model.parameters, ...generateSettings };

    return (async function* () {
      let stream;
      let tokenId = 0;

      // Request options common to both modes, kept in one place so the raw and
      // chat requests cannot drift apart.
      const sharedOptions = {
        model: model.id ?? model.name,
        p: parameters.top_p,
        k: parameters.top_k,
        maxTokens: parameters.max_new_tokens,
        temperature: parameters.temperature,
        stopSequences: parameters.stop,
        frequencyPenalty: parameters.frequency_penalty,
      };

      if (raw) {
        const prompt = await buildPrompt({
          messages: messages.filter((message) => message.from !== "system"),
          model,
          preprompt: system,
          continueMessage,
        });

        stream = await cohere.chatStream({
          ...sharedOptions,
          message: prompt,
          rawPrompting: true,
        });
      } else {
        const formattedMessages = messages
          .filter((message) => message.from !== "system")
          .map((message) => ({
            role: message.from === "user" ? "USER" : "CHATBOT",
            message: message.content,
          })) satisfies Cohere.ChatMessage[];

        // Without this guard an empty conversation would fail with an opaque
        // "Cannot read properties of undefined" TypeError.
        const lastMessage = formattedMessages[formattedMessages.length - 1];
        if (lastMessage === undefined) {
          throw new Error("Cohere endpoint requires at least one non-system message");
        }

        stream = await cohere.chatStream({
          ...sharedOptions,
          chatHistory: formattedMessages.slice(0, -1),
          message: lastMessage.message,
          preamble: system,
        });
      }

      for await (const output of stream) {
        if (output.eventType === "text-generation") {
          yield {
            token: {
              id: tokenId++,
              text: output.text,
              logprob: 0,
              special: false,
            },
            generated_text: null,
            details: null,
          } satisfies TextGenerationStreamOutput;
        } else if (output.eventType === "stream-end") {
          // Surface Cohere-side failures instead of emitting a final token.
          if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
            throw new Error(output.finishReason);
          }

          // Final marker token: empty text, with the complete response attached.
          yield {
            token: {
              id: tokenId++,
              text: "",
              logprob: 0,
              special: true,
            },
            generated_text: output.response.text,
            details: null,
          } satisfies TextGenerationStreamOutput;
        }
      }
    })();
  };
}
|
|