import { useChat } from 'ai/react';
import { toast } from 'react-hot-toast';
import { useEffect, useRef, useState } from 'react';
import { ChatWithMessages, MessageUI, MessageUserInput } from '../types';
import {
  dbPostCreateMessage,
  dbPostUpdateMessageResponse,
} from '../db/functions';
import {
  convertAssistantUIMessageToDBMessageResponse,
  convertDBMessageToUIMessage,
} from '../utils/message';
import { useSetAtom } from 'jotai';
import { selectedMessageId } from '@/state/chat';
import { Message } from '@prisma/client';

/**
 * Drives the vision-agent conversation for one stored chat.
 *
 * Wraps `useChat` pointed at `/api/vision-agent` (text stream mode) and seeds it
 * with the chat's stored messages. Each finished assistant response is persisted
 * via `dbPostUpdateMessageResponse`, and that message id is then written to the
 * `selectedMessageId` atom.
 *
 * @param chat - The chat with its stored messages. Its `id` and `mediaUrl` are
 *   sent in the request body.
 * @returns The UI messages, an `append` that sends a stored user message,
 *   `reload`, and `isLoading`.
 */
const useVisionAgent = (chat: ChatWithMessages) => {
  const { messages: dbMessages, id, mediaUrl } = chat;
  const latestDbMessage = dbMessages[dbMessages.length - 1];
  const setMessageId = useSetAtom(selectedMessageId);

  // Temporary workaround: the media URL is not part of the chat messages, so it
  // has to be sent separately in the request body. It is kept in a ref so that
  // `append` can switch it between requests.
  const currMediaUrl = useRef(mediaUrl);
  // Id of the stored message whose assistant response is currently in flight.
  const currMessageId = useRef(latestDbMessage?.id);

  const { messages, append, isLoading, reload } = useChat({
    api: '/api/vision-agent',
    streamMode: 'text',
    onResponse(response) {
      if (response.status !== 200) {
        toast.error(response.statusText);
      }
    },
    onFinish: async message => {
      // Capture the id before awaiting, so the selection matches the message
      // that was actually saved even if another message is appended meanwhile.
      const messageId = currMessageId.current;
      try {
        await dbPostUpdateMessageResponse(
          messageId,
          convertAssistantUIMessageToDBMessageResponse(message),
        );
        setMessageId(messageId);
      } catch (err) {
        // Surface persistence failures instead of leaving an unhandled rejection.
        toast.error(
          err instanceof Error ? err.message : 'Failed to save the response',
        );
      }
    },
    initialMessages: convertDBMessageToUIMessage(dbMessages),
    // Evaluated on every render, so it reflects the refs as of the last render.
    // `append` below overrides `mediaUrl` per request for that reason.
    body: {
      mediaUrl: currMediaUrl.current,
      id,
    },
    onError: err => {
      err && toast.error(err.message);
    },
  });

  /**
   * In case the user navigated here with only the initial user message (no
   * response yet), reload once so the agent produces the first response.
   */
  const once = useRef(true);
  useEffect(() => {
    if (
      !isLoading &&
      messages.length === 1 &&
      messages[0].role === 'user' &&
      once.current
    ) {
      once.current = false;
      reload();
    }
  }, [isLoading, messages, reload]);

  return {
    messages: messages as MessageUI[],
    /**
     * Sends a stored user message to the agent. Updates the refs so that
     * `onFinish` saves the response against `message.id` and later requests
     * (e.g. `reload`) use `message.mediaUrl`.
     */
    append: (message: Message) => {
      currMediaUrl.current = message.mediaUrl;
      currMessageId.current = message.id;
      append(
        {
          // NOTE(review): this is the chat id, not `message.id`, so every
          // appended user message shares the same id — confirm this is intended.
          id,
          role: 'user',
          content: message.prompt,
        },
        // The `body` option above still holds the media URL from the last
        // render, so pass the new one explicitly for this request.
        { options: { body: { mediaUrl: message.mediaUrl, id } } },
      );
    },
    reload,
    isLoading,
  };
};

export default useVisionAgent;