vision-agent / lib /hooks /useVisionAgent.ts
MingruiZhang's picture
feat: UI Improvements (#72)
0fd8446 unverified
raw
history blame
2.38 kB
import { useChat } from 'ai/react';
import { toast } from 'react-hot-toast';
import { useEffect, useRef, useState } from 'react';
import { ChatWithMessages, MessageUI, MessageUserInput } from '../types';
import {
dbPostCreateMessage,
dbPostUpdateMessageResponse,
} from '../db/functions';
import {
convertAssistantUIMessageToDBMessageResponse,
convertDBMessageToUIMessage,
} from '../utils/message';
/**
 * React hook that drives a Vision Agent chat: streams assistant replies from
 * `/api/vision-agent` via `useChat` and persists user prompts and assistant
 * responses through the db helpers.
 *
 * @param chat - The chat to drive. Its stored messages seed the conversation,
 *   and its `id` / `mediaUrl` are passed to `useChat` as the request `body`.
 * @returns The UI messages, an `append` that sends and persists a new user
 *   prompt, plus `reload` and `isLoading` from `useChat`.
 */
const useVisionAgent = (chat: ChatWithMessages) => {
  const { messages: dbMessages, id, mediaUrl } = chat;
  const latestDbMessage = dbMessages[dbMessages.length - 1];

  // Temporary workaround: mediaUrl has to be sent separately from the
  // messages for now.
  // NOTE(review): `body` below is built when the hook renders, so a mediaUrl
  // set inside `append` may only reach `body` on a later render. Confirm
  // whether the API reads the per-message `mediaUrl` instead.
  const currMediaUrl = useRef<string>(mediaUrl);

  // Resolves to the DB id of the message whose assistant response should be
  // saved when the current stream finishes. It is held as a promise so
  // `onFinish` waits for `dbPostCreateMessage` instead of racing it.
  // Otherwise a fast stream could finish before the new id is known and
  // write its response onto the previous message.
  const pendingMessageId = useRef<Promise<string | undefined>>(
    Promise.resolve(latestDbMessage?.id),
  );

  const { messages, append, isLoading, reload } = useChat({
    api: '/api/vision-agent',
    streamMode: 'text',
    onResponse(response) {
      if (response.status !== 200) {
        toast.error(response.statusText);
      }
    },
    onFinish: async message => {
      try {
        const messageId = await pendingMessageId.current;
        // No stored or newly created message to attach the response to.
        if (!messageId) return;
        await dbPostUpdateMessageResponse(
          messageId,
          convertAssistantUIMessageToDBMessageResponse(message),
        );
      } catch (err) {
        // Surface persistence failures rather than letting the rejection
        // escape this callback.
        toast.error(err instanceof Error ? err.message : String(err));
      }
    },
    sendExtraMessageFields: true,
    initialMessages: convertDBMessageToUIMessage(dbMessages),
    body: {
      mediaUrl: currMediaUrl.current,
      id,
    },
    onError: err => {
      if (err) {
        toast.error(err.message);
      }
    },
  });

  // In case the user navigated here with only an initial prompt (a single
  // user message and no reply yet), trigger the first response exactly once.
  const once = useRef(true);
  useEffect(() => {
    if (
      !isLoading &&
      messages.length === 1 &&
      messages[0].role === 'user' &&
      once.current
    ) {
      once.current = false;
      reload();
    }
  }, [isLoading, messages, reload]);

  return {
    // NOTE(review): assumes MessageUI matches the messages produced by
    // convertDBMessageToUIMessage and by `append` below. Confirm.
    messages: messages as MessageUI[],
    append: async (messageInput: MessageUserInput) => {
      currMediaUrl.current = messageInput.mediaUrl;

      // Start persisting the prompt and publish its pending id before the
      // request starts, so this stream's onFinish saves to the right message.
      const messageIdPromise = dbPostCreateMessage(id, messageInput).then(
        resp => resp.id,
      );
      // Attach a no-op handler so that a failed insert with no stream left to
      // await it is not reported as unhandled. Awaiting the promise (below,
      // and in onFinish) still observes the rejection.
      messageIdPromise.catch(() => undefined);
      pendingMessageId.current = messageIdPromise;

      // Built as a variable rather than an inline literal, so the extra
      // `mediaUrl` field (forwarded because sendExtraMessageFields is true)
      // type-checks without @ts-ignore.
      const userMessage = {
        // NOTE(review): this reuses the chat id as the message id, so every
        // user message shares one id. Confirm the API relies on this.
        id,
        role: 'user' as const,
        content: messageInput.prompt,
        mediaUrl: messageInput.mediaUrl,
      };
      void append(userMessage);

      await messageIdPromise;
    },
    reload,
    isLoading,
  };
};

export default useVisionAgent;