import {
  createParser,
  ParsedEvent,
  ReconnectInterval,
} from "eventsource-parser";
import { NextRequest, NextResponse } from "next/server";
import {
  ChatCompletionAssistantMessageParam,
  ChatCompletionCreateParamsStreaming,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionUserMessageParam,
} from "openai/resources/index.mjs";

import { encodeChat, tokenLimit } from "@/lib/token-counter";
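
/**
 * Ensure the chat begins with a single system message. An empty
 * systemPrompt leaves the messages untouched; an existing leading
 * system message has its content replaced; otherwise the prompt is
 * prepended as the first message.
 */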
const addSystemMessage = (
  messages: ChatCompletionMessageParam[],
  systemPrompt?: string
) => {
  // early exit if the system prompt is empty
  if (!systemPrompt || systemPrompt === "") {
    return messages;
  }
  // add the system prompt to the chat (if it's not already there)
  if (!messages) {
    // if there is no message array, create one with the system prompt
    messages = [
      {
        content: systemPrompt,
        role: "system",
      },
    ];
  } else if (messages.length === 0) {
    // if the chat is empty, add the system prompt as the first message
    messages.push({
      content: systemPrompt,
      role: "system",
    });
  } else if (messages[0].role === "system") {
    // if the first message is already a system prompt, update it
    messages[0].content = systemPrompt;
  } else {
    // otherwise, prepend the system prompt as the first message
    messages.unshift({
      content: systemPrompt,
      role: "system",
    });
  }
  return messages;
};
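
/**
 * Map messages to the OpenAI chat-completion shape and trim the
 * conversation so it fits the model's context window, reserving 512
 * tokens for the response. When over budget, messages are removed
 * from the middle of the chat, preserving the system prompt and the
 * most recent turns.
 */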
const formatMessages = (
  messages: ChatCompletionMessageParam[]
): ChatCompletionMessageParam[] => {
  const mappedMessages: ChatCompletionMessageParam[] = [];
  const messagesTokenCounts: number[] = [];
  const responseTokens = 512;
  const tokenLimitRemaining = tokenLimit - responseTokens;
  let tokenCount = 0;

  messages.forEach((m) => {
    if (m.role === "system") {
      mappedMessages.push({
        role: "system",
        content: m.content,
      } as ChatCompletionSystemMessageParam);
    } else if (m.role === "user") {
      mappedMessages.push({
        role: "user",
        content: m.content,
      } as ChatCompletionUserMessageParam);
    } else if (m.role === "assistant") {
      mappedMessages.push({
        role: "assistant",
        content: m.content,
      } as ChatCompletionAssistantMessageParam);
    } else {
      // skip roles that are not forwarded to the model
      return;
    }

    // ignore typing
    // tslint:disable-next-line
    const messageTokens = encodeChat([m]);
    messagesTokenCounts.push(messageTokens);
    tokenCount += messageTokens;
  });
  if (tokenCount <= tokenLimitRemaining) {
    return mappedMessages;
  }

  // remove messages from the middle of the chat until the token count
  // is below the limit (index against mappedMessages, which shrinks as
  // we splice — not the original messages array)
  while (tokenCount > tokenLimitRemaining && mappedMessages.length > 0) {
    const middleMessageIndex = Math.floor(mappedMessages.length / 2);
    const middleMessageTokens = messagesTokenCounts[middleMessageIndex];
    mappedMessages.splice(middleMessageIndex, 1);
    messagesTokenCounts.splice(middleMessageIndex, 1);
    tokenCount -= middleMessageTokens;
  }
  return mappedMessages;
};
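
/**
 * POST handler: validates the selected model and the VLLM_URL
 * environment variable, prepends the system prompt, trims the history
 * to the token budget, and streams the completion back to the client
 * as text/event-stream.
 */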
export async function POST(req: NextRequest): Promise<NextResponse> {
  try {
    const { messages, chatOptions } = await req.json();
    if (!chatOptions.selectedModel || chatOptions.selectedModel === "") {
      throw new Error("Selected model is required");
    }

    const baseUrl = process.env.VLLM_URL;
    if (!baseUrl) {
      throw new Error("VLLM_URL is not set");
    }
    const apiKey = process.env.VLLM_API_KEY;

    const formattedMessages = formatMessages(
      addSystemMessage(messages, chatOptions.systemPrompt)
    );

    const stream = await getOpenAIStream(
      baseUrl,
      chatOptions.selectedModel,
      formattedMessages,
      chatOptions.temperature,
      apiKey
    );
    return new NextResponse(stream, {
      headers: { "Content-Type": "text/event-stream" },
    });
  } catch (error) {
    console.error(error);
    return NextResponse.json(
      {
        success: false,
        error: error instanceof Error ? error.message : "Unknown error",
      },
      { status: 500 }
    );
  }
}
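
/**
 * Call the OpenAI-compatible /v1/chat/completions endpoint on the
 * vLLM server with stream: true, parse the SSE response, and re-emit
 * the content deltas as a ReadableStream of UTF-8 bytes. Both the
 * Authorization and api-key headers are set so the request also works
 * against Azure-style deployments.
 */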
const getOpenAIStream = async (
  apiUrl: string,
  model: string,
  messages: ChatCompletionMessageParam[],
  temperature?: number,
  apiKey?: string
): Promise<ReadableStream<Uint8Array>> => {
  const encoder = new TextEncoder();
  const decoder = new TextDecoder();
  const headers = new Headers();
  headers.set("Content-Type", "application/json");
  if (apiKey !== undefined) {
    headers.set("Authorization", `Bearer ${apiKey}`);
    headers.set("api-key", apiKey);
  }

  const chatOptions: ChatCompletionCreateParamsStreaming = {
    model: model,
    // frequency_penalty: 0,
    // max_tokens: 2000,
    messages: messages,
    // presence_penalty: 0,
    stream: true,
    temperature: temperature ?? 0.5,
    // response_format: {
    //   type: "json_object",
    // },
    // top_p: 0.95,
  };
  const res = await fetch(apiUrl + "/v1/chat/completions", {
    headers: headers,
    method: "POST",
    body: JSON.stringify(chatOptions),
  });
  if (res.status !== 200) {
    const statusText = res.statusText;
    const responseBody = await res.text();
    console.error(`vLLM API response error: ${responseBody}`);
    throw new Error(
      `The vLLM API has encountered an error with a status code of ${res.status} ${statusText}: ${responseBody}`
    );
  }
  return new ReadableStream({
    async start(controller) {
      const onParse = (event: ParsedEvent | ReconnectInterval) => {
        if (event.type === "event") {
          const data = event.data;
          if (data === "[DONE]") {
            controller.close();
            return;
          }
          try {
            const json = JSON.parse(data);
            // the first delta of a stream may carry only the role, so
            // content can be undefined; fall back to an empty string
            const text = json.choices[0]?.delta?.content ?? "";
            controller.enqueue(encoder.encode(text));
          } catch (e) {
            controller.error(e);
          }
        }
      };

      const parser = createParser(onParse);
      // res.body is a web ReadableStream; it is async-iterable in the
      // Node.js runtime, but the DOM typings don't declare that
      for await (const chunk of res.body as any) {
        // An extra newline is required to make AzureOpenAI work.
        const str = decoder.decode(chunk).replace("[DONE]\n", "[DONE]\n\n");
        parser.feed(str);
      }
    },
  });
};
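
// Example request to this route (a minimal sketch; the route path,
// model name, and prompt values are illustrative, not part of this file):
//
//   fetch("/api/chat", {
//     method: "POST",
//     body: JSON.stringify({
//       messages: [{ role: "user", content: "Hello!" }],
//       chatOptions: {
//         selectedModel: "my-model", // required
//         systemPrompt: "You are a helpful assistant.",
//         temperature: 0.7,
//       },
//     }),
//   });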