import React from "react";
import { io, Socket } from "socket.io-client";
import { useQueryClient } from "@tanstack/react-query";
import EventLogger from "#/utils/event-logger";
import { handleAssistantMessage } from "#/services/actions";
import { showChatError, trackError } from "#/utils/error-handler";
import { useRate } from "#/hooks/use-rate";
import { OpenHandsParsedEvent } from "#/types/core";
import {
  AssistantMessageAction,
  CommandAction,
  FileEditAction,
  FileWriteAction,
  OpenHandsAction,
  UserMessageAction,
} from "#/types/core/actions";
import { Conversation } from "#/api/open-hands.types";
import { useUserProviders } from "#/hooks/use-user-providers";
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
import { OpenHandsObservation } from "#/types/core/observations";
import {
  isAgentStateChangeObservation,
  isErrorObservation,
  isOpenHandsAction,
  isOpenHandsObservation,
  isStatusUpdate,
  isUserMessage,
} from "#/types/core/guards";
import { useOptimisticUserMessage } from "#/hooks/use-optimistic-user-message";
import { useWSErrorMessage } from "#/hooks/use-ws-error-message";

/**
 * Value sent as `latest_event_id` when no event with a numeric id has been
 * received yet. Used for both the initial handshake and reconnections so the
 * server always receives the same "nothing seen yet" marker.
 */
const NO_EVENT_ID = -1;

/** Type guard: `obj` is a non-null object whose `message` field is a string. */
const hasValidMessageProperty = (obj: unknown): obj is { message: string } =>
  typeof obj === "object" &&
  obj !== null &&
  "message" in obj &&
  typeof obj.message === "string";

/**
 * Type guard for server events: a non-null object that has `id`, `source`,
 * `message` and `timestamp` keys. Only key presence is checked, not the
 * types of the values.
 */
const isOpenHandsEvent = (event: unknown): event is OpenHandsParsedEvent =>
  typeof event === "object" &&
  event !== null &&
  "id" in event &&
  "source" in event &&
  "message" in event &&
  "timestamp" in event;

/** True for events whose `action` is `"write"`. */
const isFileWriteAction = (
  event: OpenHandsParsedEvent,
): event is FileWriteAction => "action" in event && event.action === "write";

/** True for events whose `action` is `"edit"`. */
const isFileEditAction = (
  event: OpenHandsParsedEvent,
): event is FileEditAction => "action" in event && event.action === "edit";

/** True for events whose `action` is `"run"`. */
const isCommandAction = (event: OpenHandsParsedEvent): event is CommandAction =>
  "action" in event && event.action === "run";

/** True for `type: "message"` events whose `source` is `"agent"`. */
const isAssistantMessage = (
  event: OpenHandsParsedEvent,
): event is AssistantMessageAction =>
  "source" in event &&
  "type" in event &&
  event.source === "agent" &&
  event.type === "message";

/** True for chat messages from either the user or the agent. */
const isMessageAction = (
  event: OpenHandsParsedEvent,
): event is UserMessageAction | AssistantMessageAction =>
  isUserMessage(event) || isAssistantMessage(event);

/** Connection state of the provider's socket. */
export enum WsClientProviderStatus {
  CONNECTED,
  DISCONNECTED,
  CONNECTING,
}

/** Shape of the value exposed through `useWsClient()`. */
interface UseWsClient {
  status: WsClientProviderStatus;
  /** Mirrors `useRate(...).isUnderThreshold` for incoming chat messages. */
  isLoadingMessages: boolean;
  /** Every event received on `oh_event`, except error events that return early. */
  events: Record<string, unknown>[];
  /** Received events that passed the action/observation guards. */
  parsedEvents: (OpenHandsAction | OpenHandsObservation)[];
  /** Emits `oh_user_action` on the current socket. */
  send: (event: Record<string, unknown>) => void;
}

const WsClientContext = React.createContext<UseWsClient>({
  status: WsClientProviderStatus.DISCONNECTED,
  isLoadingMessages: true,
  events: [],
  parsedEvents: [],
  send: () => {
    throw new Error("not connected");
  },
});

interface WsClientProviderProps {
  conversationId: string;
}

interface ErrorArg {
  message?: string;
  data?: ErrorArgData | unknown;
}

interface ErrorArgData {
  msg_id: string;
}

/**
 * Reports a websocket error payload to the chat via `showChatError`.
 *
 * Payloads without a string `message` are ignored, as are the
 * "websocket error" and "timeout" messages. When `data.data` is an object it is
 * forwarded as metadata, and its `msg_id` (if a string) is forwarded as `msgId`.
 * Despite its name, this function does not modify any connection status.
 *
 * @param data - Argument received from a socket error/disconnect event.
 */
export function updateStatusWhenErrorMessagePresent(data: ErrorArg | unknown) {
  const isObject = (val: unknown): val is object =>
    !!val && typeof val === "object";
  const isString = (val: unknown): val is string => typeof val === "string";

  if (!isObject(data) || !("message" in data) || !isString(data.message)) {
    return;
  }
  if (data.message === "websocket error" || data.message === "timeout") {
    return;
  }

  let msgId: string | undefined;
  let metadata: Record<string, unknown> = {};

  if ("data" in data && isObject(data.data)) {
    if ("msg_id" in data.data && isString(data.data.msg_id)) {
      msgId = data.data.msg_id;
    }
    metadata = data.data as Record<string, unknown>;
  }

  showChatError({
    message: data.message,
    source: "websocket",
    metadata,
    msgId,
  });
}

/**
 * Opens a socket.io connection for the active conversation and exposes its
 * status, received events and a `send` function through `WsClientContext`.
 *
 * A connection is only opened once the active conversation is loaded and its
 * status is neither "STOPPED" nor "STARTING".
 *
 * @throws Error (inside an effect) when `conversationId` is empty.
 */
export function WsClientProvider({
  conversationId,
  children,
}: React.PropsWithChildren<WsClientProviderProps>) {
  const { removeOptimisticUserMessage } = useOptimisticUserMessage();
  const { setErrorMessage, removeErrorMessage } = useWSErrorMessage();
  const queryClient = useQueryClient();
  const sioRef = React.useRef<Socket | null>(null);
  const [status, setStatus] = React.useState<WsClientProviderStatus>(
    WsClientProviderStatus.DISCONNECTED,
  );
  const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
  const [parsedEvents, setParsedEvents] = React.useState<
    (OpenHandsAction | OpenHandsObservation)[]
  >([]);
  // Last received event whose `id` parses as an integer. Its id is sent as
  // `latest_event_id` on (re)connect — presumably so the server can resume
  // after it; confirm against the server handler.
  const lastEventRef = React.useRef<Record<string, unknown> | null>(null);
  const { providers } = useUserProviders();
  const messageRateHandler = useRate({ threshold: 250 });
  const { data: conversation, refetch: refetchConversation } =
    useActiveConversation();

  /** Emits a user action on the current socket, or logs if there is none. */
  function send(event: Record<string, unknown>) {
    if (!sioRef.current) {
      EventLogger.error("WebSocket is not connected.");
      return;
    }
    sioRef.current.emit("oh_user_action", event);
  }

  function handleConnect() {
    setStatus(WsClientProviderStatus.CONNECTED);
    removeErrorMessage();
  }

  /** Handles one `oh_event` payload from the server. */
  function handleMessage(event: Record<string, unknown>) {
    handleAssistantMessage(event);

    if (isOpenHandsEvent(event)) {
      const isStatusUpdateError =
        isStatusUpdate(event) && event.type === "error";

      const isAgentStateChangeError =
        isAgentStateChangeObservation(event) &&
        event.extras.agent_state === "error";

      // Error status updates / error agent states are surfaced as the websocket
      // error message and are NOT appended to `events`.
      if (isStatusUpdateError || isAgentStateChangeError) {
        const errorMessage = isStatusUpdate(event)
          ? event.message
          : event.extras.reason || "Unknown error";

        trackError({
          message: errorMessage,
          source: "chat",
          metadata: { msgId: event.id },
        });
        setErrorMessage(errorMessage);

        return;
      }

      if (isOpenHandsAction(event) || isOpenHandsObservation(event)) {
        setParsedEvents((prevEvents) => [...prevEvents, event]);
      }

      if (isErrorObservation(event)) {
        trackError({
          message: event.message,
          source: "chat",
          metadata: { msgId: event.id },
        });
      } else {
        removeErrorMessage();
      }

      if (isUserMessage(event)) {
        removeOptimisticUserMessage();
      }

      if (isMessageAction(event)) {
        messageRateHandler.record(new Date().getTime());
      }

      // File edits, file writes and shell commands may all change files in the
      // workspace, so the conversation's file-change list is invalidated.
      if (
        isFileEditAction(event) ||
        isFileWriteAction(event) ||
        isCommandAction(event)
      ) {
        queryClient.invalidateQueries(
          {
            queryKey: ["file_changes", conversationId],
          },
          // cancelRefetch: false — if a refetch is already in flight, keep it
          // instead of cancelling it and starting another. This avoids a burst
          // of refetches when many such events arrive back to back (e.g. while
          // an existing conversation's history is replayed).
          { cancelRefetch: false },
        );

        // Edits and writes also invalidate the cached diff of the touched file.
        if (!isCommandAction(event)) {
          const cachedConversation = queryClient.getQueryData<Conversation>([
            "user",
            "conversation",
            conversationId,
          ]);
          const clonedRepositoryDirectory =
            cachedConversation?.selected_repository?.split("/").pop();

          // Convert the sandbox path to the key used by the diff query: drop
          // the "/workspace/" prefix and then the cloned repository directory.
          let fileToInvalidate = event.args.path.replace("/workspace/", "");
          if (clonedRepositoryDirectory) {
            fileToInvalidate = fileToInvalidate.replace(
              `${clonedRepositoryDirectory}/`,
              "",
            );
          }

          queryClient.invalidateQueries({
            queryKey: ["file_diff", conversationId, fileToInvalidate],
          });
        }
      }
    }

    setEvents((prevEvents) => [...prevEvents, event]);
    if (!Number.isNaN(parseInt(event.id as string, 10))) {
      lastEventRef.current = event;
    }
  }

  function handleDisconnect(data: unknown) {
    setStatus(WsClientProviderStatus.DISCONNECTED);
    const sio = sioRef.current;
    if (!sio) {
      return;
    }
    // The manager's options are reused for automatic reconnection, so update
    // the query with the latest event id. Fall back to the same sentinel as
    // the initial handshake: an `undefined` value would otherwise be sent as
    // the literal string "undefined".
    sio.io.opts.query = sio.io.opts.query || {};
    sio.io.opts.query.latest_event_id =
      lastEventRef.current?.id ?? NO_EVENT_ID;
    updateStatusWhenErrorMessagePresent(data);

    setErrorMessage(hasValidMessageProperty(data) ? data.message : "");
  }

  function handleError(data: unknown) {
    setStatus(WsClientProviderStatus.DISCONNECTED);
    updateStatusWhenErrorMessagePresent(data);

    setErrorMessage(
      hasValidMessageProperty(data)
        ? data.message
        : "An unknown error occurred on the WebSocket connection.",
    );

    // check if something went wrong with the conversation.
    refetchConversation();
  }

  // Reset all received state when switching to another conversation.
  React.useEffect(() => {
    lastEventRef.current = null;
    setEvents([]);
    setParsedEvents([]);
    setStatus(WsClientProviderStatus.DISCONNECTED);
  }, [conversationId]);

  React.useEffect(() => {
    if (!conversationId) {
      throw new Error("No conversation ID provided");
    }
    if (
      !conversation ||
      ["STOPPED", "STARTING"].includes(conversation.status)
    ) {
      return () => undefined; // conversation not yet loaded
    }

    // Tear down the previous socket even when it is not currently connected:
    // a socket that is still (re)connecting would otherwise keep retrying in
    // the background with all of its listeners already removed.
    sioRef.current?.disconnect();

    const query = {
      latest_event_id: lastEventRef.current?.id ?? NO_EVENT_ID,
      conversation_id: conversationId,
      providers_set: providers,
      // Have to set here because socketio doesn't support custom headers. :(
      session_api_key: conversation.session_api_key,
    };

    const baseUrl =
      conversation.url && !conversation.url.startsWith("/")
        ? new URL(conversation.url).host
        : import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;

    const sio = io(baseUrl, {
      transports: ["websocket"],
      query,
    });
    sio.on("connect", handleConnect);
    sio.on("oh_event", handleMessage);
    sio.on("connect_error", handleError);
    sio.on("connect_failed", handleError);
    sio.on("disconnect", handleDisconnect);

    sioRef.current = sio;

    return () => {
      sio.off("connect", handleConnect);
      sio.off("oh_event", handleMessage);
      sio.off("connect_error", handleError);
      sio.off("connect_failed", handleError);
      sio.off("disconnect", handleDisconnect);
    };
    // NOTE(review): `providers` and `conversation.session_api_key` are read
    // above but are not dependencies, so a change to only those values does
    // not reconnect — confirm this is intended.
  }, [conversationId, conversation?.url, conversation?.status]);

  // Close the socket when the provider unmounts. The "disconnect" handler that
  // is registered belongs to a later render than this closure, so every
  // "disconnect" listener is removed instead of one specific function
  // reference; the intentional close is then not reported as an error.
  React.useEffect(
    () => () => {
      const sio = sioRef.current;
      if (sio) {
        sio.off("disconnect");
        sio.disconnect();
      }
    },
    [],
  );

  const value = React.useMemo<UseWsClient>(
    () => ({
      status,
      isLoadingMessages: messageRateHandler.isUnderThreshold,
      events,
      parsedEvents,
      send,
    }),
    [status, messageRateHandler.isUnderThreshold, events, parsedEvents],
  );

  return React.createElement(WsClientContext.Provider, { value }, children);
}

/** Returns the value provided by the nearest `WsClientProvider`. */
export function useWsClient() {
  const context = React.useContext(WsClientContext);
  return context;
}