/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { Content, GenerativeContentBlob, Part } from "@google/generative-ai";
import { EventEmitter } from "eventemitter3";
import { difference } from "lodash";
import {
  ClientContentMessage,
  isInterrupted,
  isModelTurn,
  isServerContentMessage,
  isSetupCompleteMessage,
  isToolCallCancellationMessage,
  isToolCallMessage,
  isTurnComplete,
  LiveIncomingMessage,
  ModelTurn,
  RealtimeInputMessage,
  ServerContent,
  SetupMessage,
  StreamingLog,
  ToolCall,
  ToolCallCancellation,
  ToolResponseMessage,
  type LiveConfig,
} from "../multimodal-live-types";
import { blobToJSON, base64ToArrayBuffer } from "./utils";

/**
 * the events that this client will emit
 */
interface MultimodalLiveClientEventTypes {
  open: () => void;
  log: (log: StreamingLog) => void;
  close: (event: CloseEvent) => void;
  audio: (data: ArrayBuffer) => void;
  content: (data: ServerContent) => void;
  interrupted: () => void;
  setupcomplete: () => void;
  turncomplete: () => void;
  toolcall: (toolCall: ToolCall) => void;
  toolcallcancellation: (toolcallCancellation: ToolCallCancellation) => void;
}

export type MultimodalLiveAPIClientConnection = {
  url?: string;
  apiKey?: string;
};

/**
 * An event-emitting class that manages the connection to the websocket and emits
 * events to the rest of the application.
 * If you don't want to use React you can still use this.
 */
export class MultimodalLiveClient extends EventEmitter<MultimodalLiveClientEventTypes> {
  public ws: WebSocket | null = null;
  protected config: LiveConfig | null = null;
  public url: string;

  constructor({ url, apiKey }: MultimodalLiveAPIClientConnection = {}) {
    super();
    this.url =
      url ||
      `${window.location.protocol === "https:" ? "wss:" : "ws:"}//${window.location.host}/ws`;
    this.send = this.send.bind(this);
  }

  log(type: string, message: StreamingLog["message"]) {
    const log: StreamingLog = {
      date: new Date(),
      type,
      message,
    };
    this.emit("log", log);
  }

  connect(config: LiveConfig): Promise<boolean> {
    this.config = config;

    const ws = new WebSocket(this.url);

    ws.addEventListener("message", async (evt: MessageEvent) => {
      if (evt.data instanceof Blob) {
        this.receive(evt.data);
      } else {
        console.log("non blob message", evt);
      }
    });

    return new Promise((resolve, reject) => {
      const onError = (ev: Event) => {
        this.disconnect(ws);
        const message = `Could not connect to "${this.url}"`;
        this.log(`server.${ev.type}`, message);
        reject(new Error(message));
      };
      ws.addEventListener("error", onError);
      ws.addEventListener("open", (ev: Event) => {
        if (!this.config) {
          reject("Invalid config sent to `connect(config)`");
          return;
        }
        this.log(`client.${ev.type}`, `connected to socket`);
        this.emit("open");

        this.ws = ws;

        const setupMessage: SetupMessage = {
          setup: this.config,
        };
        this._sendDirect(setupMessage);
        this.log("client.send", "setup");

        ws.removeEventListener("error", onError);
        ws.addEventListener("close", (ev: CloseEvent) => {
          console.log(ev);
          this.disconnect(ws);
          let reason = ev.reason || "";
          if (reason.toLowerCase().includes("error")) {
            const prelude = "ERROR]";
            const preludeIndex = reason.indexOf(prelude);
            if (preludeIndex > 0) {
              reason = reason.slice(
                preludeIndex + prelude.length + 1,
                Infinity,
              );
            }
          }
          this.log(
            `server.${ev.type}`,
            `disconnected ${reason ? `with reason: ${reason}` : ``}`,
          );
          this.emit("close", ev);
        });
        resolve(true);
      });
    });
  }

  disconnect(ws?: WebSocket) {
    // could be that this is an old websocket and there's already a new instance
    // only close it if it's still the correct reference
    if ((!ws || this.ws === ws) && this.ws) {
      this.ws.close();
      this.ws = null;
      this.log("client.close", `Disconnected`);
      return true;
    }
    return false;
  }

  protected async receive(blob: Blob) {
    const response: LiveIncomingMessage = (await blobToJSON(
      blob,
    )) as LiveIncomingMessage;
    if (isToolCallMessage(response)) {
      this.log("server.toolCall", response);
      this.emit("toolcall", response.toolCall);
      return;
    }
    if (isToolCallCancellationMessage(response)) {
      this.log("receive.toolCallCancellation", response);
      this.emit("toolcallcancellation", response.toolCallCancellation);
      return;
    }

    if (isSetupCompleteMessage(response)) {
      this.log("server.send", "setupComplete");
      this.emit("setupcomplete");
      return;
    }

    // this json also might be `contentUpdate { interrupted: true }`
    // or contentUpdate { end_of_turn: true }
    if (isServerContentMessage(response)) {
      const { serverContent } = response;
      if (isInterrupted(serverContent)) {
        this.log("receive.serverContent", "interrupted");
        this.emit("interrupted");
        return;
      }
      if (isTurnComplete(serverContent)) {
        this.log("server.send", "turnComplete");
        this.emit("turncomplete");
        // plausible there's more to the message, continue
      }

      if (isModelTurn(serverContent)) {
        let parts: Part[] = serverContent.modelTurn.parts;

        // when it's audio that is returned for modelTurn
        const audioParts = parts.filter(
          (p) => p.inlineData && p.inlineData.mimeType.startsWith("audio/pcm"),
        );
        const base64s = audioParts.map((p) => p.inlineData?.data);

        // strip the audio parts out of the modelTurn
        const otherParts = difference(parts, audioParts);
        // console.log("otherParts", otherParts);

        base64s.forEach((b64) => {
          if (b64) {
            const data = base64ToArrayBuffer(b64);
            this.emit("audio", data);
            this.log(`server.audio`, `buffer (${data.byteLength})`);
          }
        });
        if (!otherParts.length) {
          return;
        }

        parts = otherParts;

        const content: ModelTurn = { modelTurn: { parts } };
        this.emit("content", content);
        this.log(`server.content`, response);
      }
    } else {
      console.log("received unmatched message", response);
    }
  }

  /**
   * send realtimeInput, this is base64 chunks of "audio/pcm" and/or "image/jpg"
   */
  sendRealtimeInput(chunks: GenerativeContentBlob[]) {
    let hasAudio = false;
    let hasVideo = false;
    for (let i = 0; i < chunks.length; i++) {
      const ch = chunks[i];
      if (ch.mimeType.includes("audio")) {
        hasAudio = true;
      }
      if (ch.mimeType.includes("image")) {
        hasVideo = true;
      }
      if (hasAudio && hasVideo) {
        break;
      }
    }
    const message =
      hasAudio && hasVideo
        ? "audio + video"
        : hasAudio
          ? "audio"
          : hasVideo
            ? "video"
            : "unknown";

    const data: RealtimeInputMessage = {
      realtimeInput: {
        mediaChunks: chunks,
      },
    };
    this._sendDirect(data);
    this.log(`client.realtimeInput`, message);
  }

  /**
   * send a response to a function call and provide the id of the functions you are responding to
   */
  sendToolResponse(toolResponse: ToolResponseMessage["toolResponse"]) {
    const message: ToolResponseMessage = {
      toolResponse,
    };

    this._sendDirect(message);
    this.log(`client.toolResponse`, message);
  }

  /**
   * send normal content parts such as { text }
   */
  send(parts: Part | Part[], turnComplete: boolean = true) {
    parts = Array.isArray(parts) ? parts : [parts];
    const content: Content = {
      role: "user",
      parts,
    };

    const clientContentRequest: ClientContentMessage = {
      clientContent: {
        turns: [content],
        turnComplete,
      },
    };

    this._sendDirect(clientContentRequest);
    this.log(`client.send`, clientContentRequest);
  }

  /**
   * used internally to send all messages
   * don't use directly unless trying to send an unsupported message type
   */
  _sendDirect(request: object) {
    if (!this.ws) {
      throw new Error("WebSocket is not connected");
    }
    const str = JSON.stringify(request);
    this.ws.send(str);
  }
}
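
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the class above): a minimal
// wiring of the client. It assumes the relative "/ws" endpoint built in the
// constructor proxies to the Live API backend, that `LiveConfig` accepts a
// `model` field as defined in ../multimodal-live-types, that the model name
// shown is whatever the backend expects, and that `playPcm` is a hypothetical
// helper that queues the raw PCM buffer for playback.
//
//   const client = new MultimodalLiveClient();
//
//   client.on("open", () => console.log("socket open"));
//   client.on("setupcomplete", () => console.log("setup complete"));
//   client.on("audio", (buffer: ArrayBuffer) => playPcm(buffer));
//   client.on("content", (content) => console.log("model content", content));
//   client.on("close", (ev) => console.log("closed", ev.reason));
//
//   // from inside an async function:
//   await client.connect({ model: "models/gemini-2.0-flash-exp" });
//   client.send({ text: "Hello there" });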