// src/app/api/chat/route.ts
import {
  createParser,
  ParsedEvent,
  ReconnectInterval,
} from "eventsource-parser";
import { NextRequest, NextResponse } from "next/server";
import {
  ChatCompletionAssistantMessageParam,
  ChatCompletionCreateParamsStreaming,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionUserMessageParam,
} from "openai/resources/index.mjs";
import { encodeChat, tokenLimit } from "@/lib/token-counter";
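
/**
 * Makes sure the chat starts with the configured system prompt: if the first
 * message is already a system message its content is updated, otherwise the
 * prompt is prepended as a new system message.
 */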
const addSystemMessage = (
  messages: ChatCompletionMessageParam[],
  systemPrompt?: string
) => {
  // early exit if the system prompt is empty
  if (!systemPrompt || systemPrompt === "") {
    return messages;
  }

  // add the system prompt to the chat (if it's not already there),
  // checking the first message in the chat
  if (!messages) {
    // if there is no message array at all, start one with the system prompt
    messages = [
      {
        content: systemPrompt,
        role: "system",
      },
    ];
  } else if (messages.length === 0) {
    // if the chat is empty, add the system prompt as the first message
    messages.push({
      content: systemPrompt,
      role: "system",
    });
  } else if (messages[0].role === "system") {
    // if the first message is already a system prompt, update its content
    messages[0].content = systemPrompt;
  } else {
    // otherwise, prepend the system prompt as the first message
    messages.unshift({
      content: systemPrompt,
      role: "system",
    });
  }
  return messages;
};
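
/**
 * Maps the incoming messages to the OpenAI chat-completion format and trims
 * the conversation so it fits within the model's context window, reserving
 * room for the response. Overflowing messages are dropped from the middle of
 * the chat so the system prompt and the most recent turns are kept.
 */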
const formatMessages = (
  messages: ChatCompletionMessageParam[]
): ChatCompletionMessageParam[] => {
  const mappedMessages: ChatCompletionMessageParam[] = [];
  const messagesTokenCounts: number[] = [];
  const responseTokens = 512;
  const tokenLimitRemaining = tokenLimit - responseTokens;
  let tokenCount = 0;

  messages.forEach((m) => {
    if (m.role === "system") {
      mappedMessages.push({
        role: "system",
        content: m.content,
      } as ChatCompletionSystemMessageParam);
    } else if (m.role === "user") {
      mappedMessages.push({
        role: "user",
        content: m.content,
      } as ChatCompletionUserMessageParam);
    } else if (m.role === "assistant") {
      mappedMessages.push({
        role: "assistant",
        content: m.content,
      } as ChatCompletionAssistantMessageParam);
    } else {
      // skip roles that are not forwarded to the model
      return;
    }

    // count the tokens of this message; ignore typing on the tokenizer call
    // @ts-ignore
    const messageTokens = encodeChat([m]);
    messagesTokenCounts.push(messageTokens);
    tokenCount += messageTokens;
  });

  if (tokenCount <= tokenLimitRemaining) {
    return mappedMessages;
  }

  // remove messages from the middle of the chat until the token count is
  // below the limit (index into mappedMessages, which is the array that shrinks)
  while (tokenCount > tokenLimitRemaining && mappedMessages.length > 1) {
    const middleMessageIndex = Math.floor(mappedMessages.length / 2);
    const middleMessageTokens = messagesTokenCounts[middleMessageIndex];
    mappedMessages.splice(middleMessageIndex, 1);
    messagesTokenCounts.splice(middleMessageIndex, 1);
    tokenCount -= middleMessageTokens;
  }
  return mappedMessages;
};
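
/**
 * POST /api/chat
 * Expects a JSON body with `messages` and `chatOptions` (selectedModel,
 * systemPrompt, temperature), forwards the formatted chat to the vLLM server
 * configured via the VLLM_URL / VLLM_API_KEY environment variables, and
 * streams the generated text back to the client.
 */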
export async function POST(req: NextRequest): Promise<NextResponse> {
  try {
    const { messages, chatOptions } = await req.json();

    if (!chatOptions.selectedModel || chatOptions.selectedModel === "") {
      throw new Error("Selected model is required");
    }

    const baseUrl = process.env.VLLM_URL;
    if (!baseUrl) {
      throw new Error("VLLM_URL is not set");
    }
    const apiKey = process.env.VLLM_API_KEY;

    const formattedMessages = formatMessages(
      addSystemMessage(messages, chatOptions.systemPrompt)
    );

    const stream = await getOpenAIStream(
      baseUrl,
      chatOptions.selectedModel,
      formattedMessages,
      chatOptions.temperature,
      apiKey
    );

    return new NextResponse(stream, {
      headers: { "Content-Type": "text/event-stream" },
    });
  } catch (error) {
    console.error(error);
    return NextResponse.json(
      {
        success: false,
        error: error instanceof Error ? error.message : "Unknown error",
      },
      { status: 500 }
    );
  }
}
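
/**
 * Sends a streaming request to the OpenAI-compatible /v1/chat/completions
 * endpoint of the vLLM server and returns a ReadableStream that emits only
 * the text deltas parsed from the server-sent events, closing once the
 * upstream signals "[DONE]".
 */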
const getOpenAIStream = async (
  apiUrl: string,
  model: string,
  messages: ChatCompletionMessageParam[],
  temperature?: number,
  apiKey?: string
): Promise<ReadableStream<Uint8Array>> => {
  const encoder = new TextEncoder();
  const decoder = new TextDecoder();

  const headers = new Headers();
  headers.set("Content-Type", "application/json");
  if (apiKey !== undefined) {
    headers.set("Authorization", `Bearer ${apiKey}`);
    headers.set("api-key", apiKey);
  }

  const chatOptions: ChatCompletionCreateParamsStreaming = {
    model: model,
    // frequency_penalty: 0,
    // max_tokens: 2000,
    messages: messages,
    // presence_penalty: 0,
    stream: true,
    temperature: temperature ?? 0.5,
    // response_format: {
    //   type: "json_object",
    // }
    // top_p: 0.95,
  };

  const res = await fetch(apiUrl + "/v1/chat/completions", {
    headers: headers,
    method: "POST",
    body: JSON.stringify(chatOptions),
  });

  if (res.status !== 200) {
    const statusText = res.statusText;
    const responseBody = await res.text();
    console.error(`vLLM API response error: ${responseBody}`);
    throw new Error(
      `The vLLM API has encountered an error with a status code of ${res.status} ${statusText}: ${responseBody}`
    );
  }

  return new ReadableStream({
    async start(controller) {
      const onParse = (event: ParsedEvent | ReconnectInterval) => {
        if (event.type === "event") {
          const data = event.data;
          // the upstream signals the end of the stream with "[DONE]"
          if (data === "[DONE]") {
            controller.close();
            return;
          }
          try {
            const json = JSON.parse(data);
            // some chunks (e.g. the initial role-only delta) carry no content
            const text = json.choices[0].delta.content ?? "";
            const queue = encoder.encode(text);
            controller.enqueue(queue);
          } catch (e) {
            controller.error(e);
          }
        }
      };

      const parser = createParser(onParse);
      for await (const chunk of res.body as any) {
        // decode with `stream: true` so multi-byte characters split across
        // chunks are reassembled correctly.
        // An extra newline is required to make AzureOpenAI work.
        const str = decoder
          .decode(chunk, { stream: true })
          .replace("[DONE]\n", "[DONE]\n\n");
        parser.feed(str);
      }
    },
  });
};