Spaces:
Build error
Build error
import React from "react"; | |
import { io, Socket } from "socket.io-client"; | |
import { useQueryClient } from "@tanstack/react-query"; | |
import EventLogger from "#/utils/event-logger"; | |
import { handleAssistantMessage } from "#/services/actions"; | |
import { showChatError, trackError } from "#/utils/error-handler"; | |
import { useRate } from "#/hooks/use-rate"; | |
import { OpenHandsParsedEvent } from "#/types/core"; | |
import { | |
AssistantMessageAction, | |
CommandAction, | |
FileEditAction, | |
FileWriteAction, | |
OpenHandsAction, | |
UserMessageAction, | |
} from "#/types/core/actions"; | |
import { Conversation } from "#/api/open-hands.types"; | |
import { useUserProviders } from "#/hooks/use-user-providers"; | |
import { useActiveConversation } from "#/hooks/query/use-active-conversation"; | |
import { OpenHandsObservation } from "#/types/core/observations"; | |
import { | |
isAgentStateChangeObservation, | |
isErrorObservation, | |
isOpenHandsAction, | |
isOpenHandsObservation, | |
isStatusUpdate, | |
isUserMessage, | |
} from "#/types/core/guards"; | |
import { useOptimisticUserMessage } from "#/hooks/use-optimistic-user-message"; | |
import { useWSErrorMessage } from "#/hooks/use-ws-error-message"; | |
/**
 * Type guard: `true` when `obj` is a non-null object that has a `message`
 * property whose value is a string (an empty string still counts).
 *
 * @param obj - Any value, e.g. an error payload received from the socket.
 * @returns Whether `obj.message` can safely be read as a string.
 */
const hasValidMessageProperty = (obj: unknown): obj is { message: string } =>
  typeof obj === "object" &&
  obj !== null &&
  "message" in obj &&
  typeof obj.message === "string";
/**
 * Type guard for events coming off the socket.
 *
 * Only the *presence* of the `id`, `source`, `message` and `timestamp` keys is
 * checked; their values are not validated.
 *
 * @param event - Any value received from the socket.
 * @returns `true` when `event` is a non-null object with all four keys.
 */
const isOpenHandsEvent = (event: unknown): event is OpenHandsParsedEvent =>
  typeof event === "object" &&
  event !== null &&
  "id" in event &&
  "source" in event &&
  "message" in event &&
  "timestamp" in event;
/**
 * Type guard: the event has an `action` key whose value is `"write"`.
 *
 * @param event - A parsed event.
 * @returns `true` when `event.action === "write"`.
 */
const isFileWriteAction = (
  event: OpenHandsParsedEvent,
): event is FileWriteAction => "action" in event && event.action === "write";
/**
 * Type guard: the event has an `action` key whose value is `"edit"`.
 *
 * @param event - A parsed event.
 * @returns `true` when `event.action === "edit"`.
 */
const isFileEditAction = (
  event: OpenHandsParsedEvent,
): event is FileEditAction => "action" in event && event.action === "edit";
/**
 * Type guard: the event has an `action` key whose value is `"run"`.
 *
 * @param event - A parsed event.
 * @returns `true` when `event.action === "run"`.
 */
const isCommandAction = (event: OpenHandsParsedEvent): event is CommandAction =>
  "action" in event && event.action === "run";
/**
 * Type guard: the event has `source === "agent"` and `type === "message"`.
 *
 * @param event - A parsed event.
 * @returns `true` when both keys are present with exactly those values.
 */
const isAssistantMessage = (
  event: OpenHandsParsedEvent,
): event is AssistantMessageAction =>
  "source" in event &&
  "type" in event &&
  event.source === "agent" &&
  event.type === "message";
/**
 * Type guard: the event is accepted by either `isUserMessage` or
 * `isAssistantMessage`.
 *
 * @param event - A parsed event.
 * @returns `true` when either of the two guards matches.
 */
const isMessageAction = (
  event: OpenHandsParsedEvent,
): event is UserMessageAction | AssistantMessageAction =>
  isUserMessage(event) || isAssistantMessage(event);
/**
 * Connection state published by `WsClientProvider`.
 *
 * This is a numeric enum, so the members evaluate to `0`, `1` and `2` in
 * declaration order. Do not reorder them: any code comparing against the raw
 * numbers would silently change meaning.
 */
export enum WsClientProviderStatus {
  /** Set by the provider's `connect` handler. */
  CONNECTED,
  /** Initial state; also set on disconnect, on error and on conversation change. */
  DISCONNECTED,
  /** Declared here; not assigned anywhere in this file. */
  CONNECTING,
}
/** Value exposed through `WsClientContext` and returned by `useWsClient`. */
interface UseWsClient {
  /** Current connection state of the socket. */
  status: WsClientProviderStatus;
  /** Mirrors the `isUnderThreshold` flag of the provider's message-rate tracker. */
  isLoadingMessages: boolean;
  /** Raw events in arrival order. */
  events: Record<string, unknown>[];
  /** The subset of events recognised as OpenHands actions or observations. */
  parsedEvents: (OpenHandsAction | OpenHandsObservation)[];
  /** Sends a user action to the server over the socket. */
  send: (event: Record<string, unknown>) => void;
}
/**
 * React context carrying the WebSocket client state.
 *
 * The default value is what consumers see when no `WsClientProvider` is
 * mounted above them: disconnected, no events, and a `send` that throws.
 */
const WsClientContext = React.createContext<UseWsClient>({
  status: WsClientProviderStatus.DISCONNECTED,
  isLoadingMessages: true,
  events: [],
  parsedEvents: [],
  send: () => {
    throw new Error("not connected");
  },
});
/** Props accepted by `WsClientProvider` (in addition to `children`). */
interface WsClientProviderProps {
  /** Identifier of the conversation whose event stream should be opened. */
  conversationId: string;
}
/**
 * Shape of an error payload as consumed by
 * `updateStatusWhenErrorMessagePresent`.
 */
interface ErrorArg {
  /** Human-readable error text. */
  message?: string;
  /**
   * Optional extra details. The union with `unknown` collapses to `unknown`
   * for the compiler; `ErrorArgData` is listed to document the expected shape.
   */
  data?: ErrorArgData | unknown;
}

/** Details object that may accompany an error payload. */
interface ErrorArgData {
  /** Message identifier, forwarded to `showChatError` as `msgId`. */
  msg_id: string;
}
/**
 * Forwards a WebSocket error payload to the chat through `showChatError`.
 *
 * Despite its name, this function does not update any status itself. It does
 * nothing when:
 * - `data` is not an object with a string `message`, or
 * - the message is `"websocket error"` or `"timeout"`.
 *
 * Otherwise `showChatError` is called with `source: "websocket"`. When
 * `data.data` is an object it is passed along as `metadata`, and a string
 * `data.data.msg_id` is passed as `msgId`.
 *
 * @param data - The payload handed to a socket error/disconnect handler.
 */
export function updateStatusWhenErrorMessagePresent(data: ErrorArg | unknown) {
  const isObject = (val: unknown): val is object =>
    !!val && typeof val === "object";
  const isString = (val: unknown): val is string => typeof val === "string";

  if (isObject(data) && "message" in data && isString(data.message)) {
    // These two messages are deliberately not shown in the chat.
    if (data.message === "websocket error" || data.message === "timeout") {
      return;
    }

    let msgId: string | undefined;
    let metadata: Record<string, unknown> = {};

    if ("data" in data && isObject(data.data)) {
      if ("msg_id" in data.data && isString(data.data.msg_id)) {
        msgId = data.data.msg_id;
      }
      // The whole details object travels as metadata, msg_id included.
      metadata = data.data as Record<string, unknown>;
    }

    showChatError({
      message: data.message,
      source: "websocket",
      metadata,
      msgId,
    });
  }
}
/**
 * Opens a Socket.IO connection for `conversationId` and publishes its state
 * (connection status, raw events, parsed events and a `send` function) to the
 * subtree through `WsClientContext`.
 *
 * Lifecycle, as implemented below:
 * - When `conversationId` changes, collected events and the resume cursor are
 *   cleared and the status is reset to `DISCONNECTED`.
 * - A socket is created once the active conversation is loaded and its status
 *   is neither `"STOPPED"` nor `"STARTING"`; it is re-created whenever the
 *   conversation id, URL or status changes.
 * - On unmount the current socket is disconnected.
 *
 * @param conversationId - Conversation to connect to; an empty value makes the
 *   connection effect throw.
 * @param children - Subtree that can read the client via `useWsClient`.
 */
export function WsClientProvider({
  conversationId,
  children,
}: React.PropsWithChildren<WsClientProviderProps>) {
  const { removeOptimisticUserMessage } = useOptimisticUserMessage();
  const { setErrorMessage, removeErrorMessage } = useWSErrorMessage();
  const queryClient = useQueryClient();
  const sioRef = React.useRef<Socket | null>(null);
  const [status, setStatus] = React.useState(
    WsClientProviderStatus.DISCONNECTED,
  );
  const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
  const [parsedEvents, setParsedEvents] = React.useState<
    (OpenHandsAction | OpenHandsObservation)[]
  >([]);
  // Last event whose id parses as an integer; used as the resume cursor.
  const lastEventRef = React.useRef<Record<string, unknown> | null>(null);
  const { providers } = useUserProviders();

  // NOTE(review): the unit of `threshold` is defined by `useRate` — confirm there.
  const messageRateHandler = useRate({ threshold: 250 });
  const { data: conversation, refetch: refetchConversation } =
    useActiveConversation();

  /** Emits `event` as an `oh_user_action`; logs and returns if no socket exists. */
  function send(event: Record<string, unknown>) {
    if (!sioRef.current) {
      EventLogger.error("WebSocket is not connected.");
      return;
    }
    sioRef.current.emit("oh_user_action", event);
  }

  /** Socket `connect` handler: mark connected and clear any shown error. */
  function handleConnect() {
    setStatus(WsClientProviderStatus.CONNECTED);
    removeErrorMessage();
  }

  /** Socket `oh_event` handler: routes one incoming event. */
  function handleMessage(event: Record<string, unknown>) {
    handleAssistantMessage(event);

    if (isOpenHandsEvent(event)) {
      const isStatusUpdateError =
        isStatusUpdate(event) && event.type === "error";

      const isAgentStateChangeError =
        isAgentStateChangeObservation(event) &&
        event.extras.agent_state === "error";

      if (isStatusUpdateError || isAgentStateChangeError) {
        const errorMessage = isStatusUpdate(event)
          ? event.message
          : event.extras.reason || "Unknown error";

        trackError({
          message: errorMessage,
          source: "chat",
          metadata: { msgId: event.id },
        });
        setErrorMessage(errorMessage);

        // Error events stop here: they are not appended to `events` and do
        // not advance the resume cursor.
        return;
      }

      if (isOpenHandsAction(event) || isOpenHandsObservation(event)) {
        setParsedEvents((prevEvents) => [...prevEvents, event]);
      }

      if (isErrorObservation(event)) {
        trackError({
          message: event.message,
          source: "chat",
          metadata: { msgId: event.id },
        });
      } else {
        removeErrorMessage();
      }

      if (isUserMessage(event)) {
        removeOptimisticUserMessage();
      }

      if (isMessageAction(event)) {
        messageRateHandler.record(new Date().getTime());
      }

      // Invalidate the changed-files list when a file is edited or written,
      // or when a command is run.
      if (
        isFileEditAction(event) ||
        isFileWriteAction(event) ||
        isCommandAction(event)
      ) {
        queryClient.invalidateQueries(
          {
            queryKey: ["file_changes", conversationId],
          },
          // With `cancelRefetch: false` an in-flight fetch is left running
          // instead of being cancelled and restarted, which avoids repeated
          // refetches while many events arrive in a burst (e.g. when loading
          // an existing conversation).
          { cancelRefetch: false },
        );

        // Invalidate the per-file diff only for edits and writes; command
        // actions are excluded here.
        if (!isCommandAction(event)) {
          const cachedConversation = queryClient.getQueryData<Conversation>([
            "user",
            "conversation",
            conversationId,
          ]);
          // Last path segment of the selected repository.
          const clonedRepositoryDirectory =
            cachedConversation?.selected_repository?.split("/").pop();

          // Strip the first "/workspace/" and the first "<repo>/" occurrence
          // to build the key used by the `file_diff` query.
          let fileToInvalidate = event.args.path.replace("/workspace/", "");
          if (clonedRepositoryDirectory) {
            fileToInvalidate = fileToInvalidate.replace(
              `${clonedRepositoryDirectory}/`,
              "",
            );
          }

          queryClient.invalidateQueries({
            queryKey: ["file_diff", conversationId, fileToInvalidate],
          });
        }
      }
    }

    setEvents((prevEvents) => [...prevEvents, event]);
    // Only events with an integer-parsable id move the resume cursor.
    if (!Number.isNaN(parseInt(event.id as string, 10))) {
      lastEventRef.current = event;
    }
  }

  /** Socket `disconnect` handler. */
  function handleDisconnect(data: unknown) {
    setStatus(WsClientProviderStatus.DISCONNECTED);
    const sio = sioRef.current;
    if (!sio) {
      return;
    }
    // Store the resume cursor in the manager's query options so a later
    // connection attempt on this socket carries it.
    sio.io.opts.query = sio.io.opts.query || {};
    sio.io.opts.query.latest_event_id = lastEventRef.current?.id;
    updateStatusWhenErrorMessagePresent(data);

    setErrorMessage(hasValidMessageProperty(data) ? data.message : "");
  }

  /** Socket `connect_error` / `connect_failed` handler. */
  function handleError(data: unknown) {
    // set status
    setStatus(WsClientProviderStatus.DISCONNECTED);
    updateStatusWhenErrorMessagePresent(data);

    setErrorMessage(
      hasValidMessageProperty(data)
        ? data.message
        : "An unknown error occurred on the WebSocket connection.",
    );

    // check if something went wrong with the conversation.
    refetchConversation();
  }

  React.useEffect(() => {
    lastEventRef.current = null;

    // reset events when conversationId changes
    setEvents([]);
    setParsedEvents([]);
    setStatus(WsClientProviderStatus.DISCONNECTED);
  }, [conversationId]);

  React.useEffect(() => {
    if (!conversationId) {
      throw new Error("No conversation ID provided");
    }
    if (
      !conversation ||
      ["STOPPED", "STARTING"].includes(conversation.status)
    ) {
      return () => undefined; // conversation not yet loaded
    }

    // Always close the previous socket, not only when it is connected: a
    // socket that is still connecting or auto-reconnecting would otherwise
    // be orphaned (its listeners are removed by the cleanup below) and keep
    // running next to the new one.
    sioRef.current?.disconnect();

    const lastEvent = lastEventRef.current;
    const query = {
      latest_event_id: lastEvent?.id ?? -1,
      conversation_id: conversationId,
      providers_set: providers,
      session_api_key: conversation.session_api_key, // Have to set here because socketio doesn't support custom headers. :(
    };

    // An absolute conversation URL supplies the host; a missing or
    // root-relative URL falls back to the configured backend or the page host.
    const baseUrl =
      conversation.url && !conversation.url.startsWith("/")
        ? new URL(conversation.url).host
        : import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;

    const sio = io(baseUrl, {
      transports: ["websocket"],
      query,
    });

    sio.on("connect", handleConnect);
    sio.on("oh_event", handleMessage);
    sio.on("connect_error", handleError);
    sio.on("connect_failed", handleError);
    sio.on("disconnect", handleDisconnect);

    sioRef.current = sio;

    return () => {
      sio.off("connect", handleConnect);
      sio.off("oh_event", handleMessage);
      sio.off("connect_error", handleError);
      sio.off("connect_failed", handleError);
      sio.off("disconnect", handleDisconnect);
    };
  }, [conversationId, conversation?.url, conversation?.status]);

  React.useEffect(
    () => () => {
      const sio = sioRef.current;
      if (sio) {
        // NOTE(review): this `handleDisconnect` is the first render's closure,
        // so `off` only matches if that exact function was the one
        // registered; the connection effect's own cleanup also removes the
        // listeners it added.
        sio.off("disconnect", handleDisconnect);
        sio.disconnect();
      }
    },
    [],
  );

  // `send` is omitted from the dependency list; it only reads `sioRef`.
  const value = React.useMemo<UseWsClient>(
    () => ({
      status,
      isLoadingMessages: messageRateHandler.isUnderThreshold,
      events,
      parsedEvents,
      send,
    }),
    [status, messageRateHandler.isUnderThreshold, events, parsedEvents],
  );

  return <WsClientContext value={value}>{children}</WsClientContext>;
}
/**
 * Reads the WebSocket client from `WsClientContext`.
 *
 * @returns The value supplied by the nearest `WsClientProvider`, or the
 *   context default (disconnected, with a throwing `send`) when there is none.
 */
export function useWsClient() {
  const context = React.useContext(WsClientContext);
  return context;
}