Spaces:
Running
Running
File size: 2,682 Bytes
a8e1cb0 cc43e3c 9882676 cc43e3c a8e1cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import { useChat, type Message } from 'ai/react';
import { useAtom } from 'jotai';
import { toast } from 'react-hot-toast';
import { datasetAtom } from '../../state';
import { useEffect, useState } from 'react';
import { MessageWithSelectedDataset } from '../types';
/** Frames cycled through by the assistant "typing" placeholder. */
const LOADING_DOT_FRAMES: readonly string[] = ['', '.', '..', '...'];

/** Delay between placeholder frames, in milliseconds. */
const LOADING_DOT_INTERVAL_MS = 500;

/**
 * Returns the frame that follows `previous` in `LOADING_DOT_FRAMES`,
 * wrapping from `'...'` back to `''`. An unrecognised frame restarts at `''`.
 */
const nextLoadingDots = (previous: string): string => {
  // indexOf yields -1 for an unknown frame, so (-1 + 1) restarts the cycle.
  const index = LOADING_DOT_FRAMES.indexOf(previous);
  return LOADING_DOT_FRAMES[(index + 1) % LOADING_DOT_FRAMES.length] ?? '';
};

/**
 * Animates `'' -> '.' -> '..' -> '...'` while `isLoading` is true.
 *
 * The animation restarts from `''` each time loading begins, so a frame left
 * over from a previous request does not carry into the next one.
 */
const useLoadingDots = (isLoading: boolean): string => {
  const [loadingDots, setLoadingDots] = useState('');

  useEffect(() => {
    if (!isLoading) {
      return undefined;
    }
    setLoadingDots('');
    const intervalId = setInterval(() => {
      setLoadingDots(nextLoadingDots);
    }, LOADING_DOT_INTERVAL_MS);
    return () => clearInterval(intervalId);
  }, [isLoading]);

  return loadingDots;
};

/**
 * Wraps `useChat` so every appended message carries the dataset the user is
 * working with, and shows an animated assistant placeholder while a reply is
 * pending.
 *
 * - If any dataset entity is selected, only the selected entities are
 *   attached to outgoing messages; otherwise the entire dataset is attached.
 * - `sendExtraMessageFields: true` asks `useChat` to forward extra message
 *   fields (here `dataset`) with the request.
 * - While loading, and the last message is not from the assistant, a
 *   placeholder message with id `'loading'` is appended to `messages`.
 * - Non-OK HTTP responses are reported with an error toast.
 *
 * @returns Chat state and actions from `useChat`. `append` is replaced by a
 *   dataset-aware wrapper and `messages` includes the loading placeholder.
 */
const useChatWithDataset = () => {
  const [dataset] = useAtom(datasetAtom);
  const { messages, append, reload, stop, isLoading, input, setInput } =
    useChat({
      sendExtraMessageFields: true,
      onResponse(response) {
        if (!response.ok) {
          // statusText is empty for HTTP/2 responses; fall back to the code
          // so the user never sees a blank toast.
          toast.error(
            response.statusText ||
              `Request failed with status ${response.status}`,
          );
        }
      },
    });

  const loadingDots = useLoadingDots(isLoading);

  const lastMessage = messages[messages.length - 1];
  const showLoadingPlaceholder =
    isLoading && lastMessage !== undefined && lastMessage.role !== 'assistant';
  const assistantLoadingMessage: (typeof messages)[number] = {
    id: 'loading',
    content: loadingDots,
    role: 'assistant',
  };
  const messageWithLoading = showLoadingPlaceholder
    ? [...messages, assistantLoadingMessage]
    : messages;

  // No explicit selection means the whole dataset is in scope.
  const selectedEntities = dataset.filter(entity => entity.selected);
  const selectedDataset =
    selectedEntities.length > 0 ? selectedEntities : dataset;

  // Same contract as `append`, but attaches `selectedDataset` to the message.
  // Any extra arguments (request options) are forwarded unchanged.
  const appendWithDataset: typeof append = (message, ...rest) =>
    append(
      {
        ...message,
        // `dataset` is an extra field not declared on the SDK message type.
        // @ts-ignore this is extra fields
        dataset: selectedDataset,
      } satisfies MessageWithSelectedDataset,
      ...rest,
    );

  return {
    // The cast also covers the loading placeholder, which has no `dataset`.
    messages: messageWithLoading as MessageWithSelectedDataset[],
    append: appendWithDataset,
    reload,
    stop,
    isLoading,
    input,
    setInput,
  };
};

export default useChatWithDataset;
|