// Spaces:
// Runtime error
// Runtime error
const { initializeFakeClient } = require('./FakeClient');

// No factory is supplied, so Jest substitutes its automatic mock for the DB connector module.
jest.mock('../../../lib/db/connectDb');

// The models module is replaced by a function that hands back an object whose
// persistence methods are all bare jest.fn() stubs.
jest.mock('../../../models', () => {
  return function () {
    return {
      save: jest.fn(),
      deleteConvos: jest.fn(),
      getConvo: jest.fn(),
      getMessages: jest.fn(),
      saveMessage: jest.fn(),
      updateMessage: jest.fn(),
      saveConvo: jest.fn(),
    };
  };
});

// The text splitter stub resolves createDocuments() to an empty document list.
jest.mock('langchain/text_splitter', () => {
  return {
    RecursiveCharacterTextSplitter: jest.fn().mockImplementation(() => {
      return { createDocuments: jest.fn().mockResolvedValue([]) };
    }),
  };
});

// ChatOpenAI is stubbed to produce an empty object.
jest.mock('langchain/chat_models/openai', () => {
  return {
    ChatOpenAI: jest.fn().mockImplementation(() => {
      return {};
    }),
  };
});

// The summarization chain stub always resolves call() to a fixed output_text
// ('Refined answer').
jest.mock('langchain/chains', () => {
  return {
    loadSummarizationChain: jest.fn().mockReturnValue({
      call: jest.fn().mockResolvedValue({ output_text: 'Refined answer' }),
    }),
  };
});
// Assigned by the sendMessage tests from the response they receive and read by
// later tests in the same suite, so those tests depend on execution order.
let parentMessageId;
let conversationId;
// One array instance is passed to every initializeFakeClient() call in beforeEach.
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
/**
 * Suite for the client produced by `initializeFakeClient`. It covers the
 * instruction and concatenation helpers, the token-limit / context-strategy
 * logic, and the `sendMessage` flow with its callbacks.
 */
describe('BaseClient', () => {
  let TestClient;
  const options = {
    // debug: true,
    modelOptions: {
      model: 'gpt-3.5-turbo',
      temperature: 0,
    },
  };

  // Each test gets a new client; the same `fakeMessages` array is passed every time.
  beforeEach(() => {
    TestClient = initializeFakeClient(apiKey, options, fakeMessages);
  });

  test('returns the input messages without instructions when addInstructions() is called with empty instructions', () => {
    const messages = [{ content: 'Hello' }, { content: 'How are you?' }, { content: 'Goodbye' }];
    const instructions = '';
    const result = TestClient.addInstructions(messages, instructions);
    expect(result).toEqual(messages);
  });

  test('returns the input messages with instructions properly added when addInstructions() is called with non-empty instructions', () => {
    const messages = [{ content: 'Hello' }, { content: 'How are you?' }, { content: 'Goodbye' }];
    const instructions = { content: 'Please respond to the question.' };
    const result = TestClient.addInstructions(messages, instructions);
    // The instructions are expected immediately before the final message.
    const expected = [
      { content: 'Hello' },
      { content: 'How are you?' },
      { content: 'Please respond to the question.' },
      { content: 'Goodbye' },
    ];
    expect(result).toEqual(expected);
  });

  test('concats messages correctly in concatenateMessages()', () => {
    const messages = [
      { name: 'User', content: 'Hello' },
      { name: 'Assistant', content: 'How can I help you?' },
      { name: 'User', content: 'I have a question.' },
    ];
    const result = TestClient.concatenateMessages(messages);
    const expected =
      'User:\nHello\n\nAssistant:\nHow can I help you?\n\nUser:\nI have a question.\n\n';
    expect(result).toBe(expected);
  });

  test('refines messages correctly in refineMessages()', async () => {
    const messagesToRefine = [
      { role: 'user', content: 'Hello', tokenCount: 10 },
      { role: 'assistant', content: 'How can I help you?', tokenCount: 20 },
    ];
    const remainingContextTokens = 100;
    const expectedRefinedMessage = {
      role: 'assistant',
      content: 'Refined answer',
      tokenCount: 14, // 'Refined answer'.length
    };
    const result = await TestClient.refineMessages(messagesToRefine, remainingContextTokens);
    expect(result).toEqual(expectedRefinedMessage);
  });

  test('gets messages within token limit (under limit) correctly in getMessagesWithinTokenLimit()', async () => {
    TestClient.maxContextTokens = 100;
    TestClient.shouldRefineContext = true;
    TestClient.refineMessages = jest.fn().mockResolvedValue({
      role: 'assistant',
      content: 'Refined answer',
      tokenCount: 30,
    });
    const messages = [
      { role: 'user', content: 'Hello', tokenCount: 5 },
      { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
      { role: 'user', content: 'I have a question.', tokenCount: 18 },
    ];
    // All three messages fit, so the whole input is expected back as context.
    const expectedContext = [
      { role: 'user', content: 'Hello', tokenCount: 5 }, // 'Hello'.length
      { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
      { role: 'user', content: 'I have a question.', tokenCount: 18 },
    ];
    const expectedRemainingContextTokens = 58; // 100 - 5 - 19 - 18
    const expectedMessagesToRefine = [];
    const result = await TestClient.getMessagesWithinTokenLimit(messages);
    expect(result.context).toEqual(expectedContext);
    expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
    expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
  });

  test('gets messages within token limit (over limit) correctly in getMessagesWithinTokenLimit()', async () => {
    TestClient.maxContextTokens = 50; // Set a lower limit
    TestClient.shouldRefineContext = true;
    TestClient.refineMessages = jest.fn().mockResolvedValue({
      role: 'assistant',
      content: 'Refined answer',
      tokenCount: 4,
    });
    const messages = [
      { role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
      { role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },
      { role: 'user', content: 'Hello', tokenCount: 5 },
      { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
      { role: 'user', content: 'I have a question.', tokenCount: 18 },
    ];
    // Only the three most recent messages are expected to fit under the limit.
    const expectedContext = [
      { role: 'user', content: 'Hello', tokenCount: 5 },
      { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
      { role: 'user', content: 'I have a question.', tokenCount: 18 },
    ];
    const expectedRemainingContextTokens = 8; // 50 - 18 - 19 - 5
    // The two oldest messages are expected to be set aside for refinement.
    const expectedMessagesToRefine = [
      { role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
      { role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },
    ];
    const result = await TestClient.getMessagesWithinTokenLimit(messages);
    expect(result.context).toEqual(expectedContext);
    expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
    expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
  });

  test('handles context strategy correctly in handleContextStrategy()', async () => {
    // Collaborators are stubbed so only handleContextStrategy's own wiring is checked.
    TestClient.addInstructions = jest
      .fn()
      .mockReturnValue([
        { content: 'Hello' },
        { content: 'How can I help you?' },
        { content: 'Please provide more details.' },
        { content: 'I can assist you with that.' },
      ]);
    TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({
      context: [
        { content: 'How can I help you?' },
        { content: 'Please provide more details.' },
        { content: 'I can assist you with that.' },
      ],
      remainingContextTokens: 80,
      messagesToRefine: [{ content: 'Hello' }],
      refineIndex: 3,
    });
    TestClient.refineMessages = jest.fn().mockResolvedValue({
      role: 'assistant',
      content: 'Refined answer',
      tokenCount: 30,
    });
    TestClient.getTokenCountForResponse = jest.fn().mockReturnValue(40);
    const instructions = { content: 'Please provide more details.' };
    const orderedMessages = [
      { content: 'Hello' },
      { content: 'How can I help you?' },
      { content: 'Please provide more details.' },
      { content: 'I can assist you with that.' },
    ];
    const formattedMessages = [
      { content: 'Hello' },
      { content: 'How can I help you?' },
      { content: 'Please provide more details.' },
      { content: 'I can assist you with that.' },
    ];
    // The refined message is expected first, followed by the stubbed context.
    const expectedResult = {
      payload: [
        {
          content: 'Refined answer',
          role: 'assistant',
          tokenCount: 30,
        },
        { content: 'How can I help you?' },
        { content: 'Please provide more details.' },
        { content: 'I can assist you with that.' },
      ],
      promptTokens: expect.any(Number),
      tokenCountMap: {},
      messages: expect.any(Array),
    };
    const result = await TestClient.handleContextStrategy({
      instructions,
      orderedMessages,
      formattedMessages,
    });
    expect(result).toEqual(expectedResult);
  });

  describe('sendMessage', () => {
    test('sendMessage should return a response message', async () => {
      const expectedResult = expect.objectContaining({
        sender: TestClient.sender,
        text: expect.any(String),
        isCreatedByUser: false,
        messageId: expect.any(String),
        parentMessageId: expect.any(String),
        conversationId: expect.any(String),
      });
      const response = await TestClient.sendMessage(userMessage);
      // Stored at module level so the following tests can continue this conversation.
      parentMessageId = response.messageId;
      conversationId = response.conversationId;
      expect(response).toEqual(expectedResult);
    });

    test('sendMessage should work with provided conversationId and parentMessageId', async () => {
      // Named distinctly so it does not shadow the module-level `userMessage`.
      const followUpMessage = 'Second message in the conversation';
      const opts = {
        conversationId,
        parentMessageId,
        getIds: jest.fn(),
        onStart: jest.fn(),
      };
      const expectedResult = expect.objectContaining({
        sender: TestClient.sender,
        text: expect.any(String),
        isCreatedByUser: false,
        messageId: expect.any(String),
        parentMessageId: expect.any(String),
        conversationId: opts.conversationId,
      });
      const response = await TestClient.sendMessage(followUpMessage, opts);
      parentMessageId = response.messageId;
      expect(response.conversationId).toEqual(conversationId);
      expect(response).toEqual(expectedResult);
      expect(opts.getIds).toHaveBeenCalled();
      expect(opts.onStart).toHaveBeenCalled();
      expect(TestClient.getBuildMessagesOptions).toHaveBeenCalled();
      expect(TestClient.getSaveOptions).toHaveBeenCalled();
    });

    test('should return chat history', async () => {
      // Uses the IDs left behind by the two preceding tests.
      const chatMessages = await TestClient.loadHistory(conversationId, parentMessageId);
      expect(TestClient.currentMessages).toHaveLength(4);
      expect(chatMessages[0].text).toEqual(userMessage);
    });

    test('setOptions is called with the correct arguments', async () => {
      TestClient.setOptions = jest.fn();
      const opts = { conversationId: '123', parentMessageId: '456' };
      await TestClient.sendMessage('Hello, world!', opts);
      expect(TestClient.setOptions).toHaveBeenCalledWith(opts);
      TestClient.setOptions.mockClear();
    });

    test('loadHistory is called with the correct arguments', async () => {
      const opts = { conversationId: '123', parentMessageId: '456' };
      await TestClient.sendMessage('Hello, world!', opts);
      expect(TestClient.loadHistory).toHaveBeenCalledWith(
        opts.conversationId,
        opts.parentMessageId,
      );
    });

    test('getIds is called with the correct arguments', async () => {
      const getIds = jest.fn();
      const opts = { getIds };
      const response = await TestClient.sendMessage('Hello, world!', opts);
      expect(getIds).toHaveBeenCalledWith({
        userMessage: expect.objectContaining({ text: 'Hello, world!' }),
        conversationId: response.conversationId,
        responseMessageId: response.messageId,
      });
    });

    test('onStart is called with the correct arguments', async () => {
      const onStart = jest.fn();
      const opts = { onStart };
      await TestClient.sendMessage('Hello, world!', opts);
      expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
    });

    test('saveMessageToDatabase is called with the correct arguments', async () => {
      const saveOptions = TestClient.getSaveOptions();
      const user = {}; // Mock user
      const opts = { user };
      await TestClient.sendMessage('Hello, world!', opts);
      expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith(
        expect.objectContaining({
          sender: expect.any(String),
          text: expect.any(String),
          isCreatedByUser: expect.any(Boolean),
          messageId: expect.any(String),
          parentMessageId: expect.any(String),
          conversationId: expect.any(String),
        }),
        saveOptions,
        user,
      );
    });

    test('sendCompletion is called with the correct arguments', async () => {
      const payload = {}; // Mock payload
      TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null });
      const opts = {};
      await TestClient.sendMessage('Hello, world!', opts);
      expect(TestClient.sendCompletion).toHaveBeenCalledWith(payload, opts);
    });

    test('getTokenCountForResponse is called with the correct arguments', async () => {
      const tokenCountMap = {}; // Mock tokenCountMap
      TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap });
      TestClient.getTokenCountForResponse = jest.fn();
      const response = await TestClient.sendMessage('Hello, world!', {});
      expect(TestClient.getTokenCountForResponse).toHaveBeenCalledWith(response);
    });

    test('returns an object with the correct shape', async () => {
      const response = await TestClient.sendMessage('Hello, world!', {});
      expect(response).toEqual(
        expect.objectContaining({
          sender: expect.any(String),
          text: expect.any(String),
          isCreatedByUser: expect.any(Boolean),
          messageId: expect.any(String),
          parentMessageId: expect.any(String),
          conversationId: expect.any(String),
        }),
      );
    });
  });
});