/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import { Content, GenerativeContentBlob, Part } from "@google/generative-ai";
import { EventEmitter } from "eventemitter3";
import { difference } from "lodash";
import {
  ClientContentMessage,
  isInterrupted,
  isModelTurn,
  isServerContentMessage,
  isSetupCompleteMessage,
  isToolCallCancellationMessage,
  isToolCallMessage,
  isTurnComplete,
  LiveIncomingMessage,
  ModelTurn,
  RealtimeInputMessage,
  ServerContent,
  SetupMessage,
  StreamingLog,
  ToolCall,
  ToolCallCancellation,
  ToolResponseMessage,
  type LiveConfig,
} from "../multimodal-live-types";
import { blobToJSON, base64ToArrayBuffer } from "./utils";
/**
 * The events emitted by the client, keyed by event name, with the listener
 * signature each event expects.
 */
interface MultimodalLiveClientEventTypes {
  /** The WebSocket connection has opened. */
  open: () => void;
  /** A log entry was recorded for a client or server action. */
  log: (log: StreamingLog) => void;
  /** The WebSocket closed; receives the browser's `CloseEvent`. */
  close: (event: CloseEvent) => void;
  /** A chunk of model audio, already decoded from base64 to raw bytes. */
  audio: (data: ArrayBuffer) => void;
  /** Model-turn content with the audio parts stripped out. */
  content: (data: ServerContent) => void;
  /** The server reported that the current turn was interrupted. */
  interrupted: () => void;
  /** The server acknowledged the setup message. */
  setupcomplete: () => void;
  /** The server reported that the current turn is complete. */
  turncomplete: () => void;
  /** The server requested one or more tool (function) calls. */
  toolcall: (toolCall: ToolCall) => void;
  /** The server cancelled previously requested tool calls. */
  toolcallcancellation: (toolcallCancellation: ToolCallCancellation) => void;
}
/**
 * Connection options accepted by the `MultimodalLiveClient` constructor.
 * Both fields are optional.
 */
export type MultimodalLiveAPIClientConnection = {
  /** WebSocket endpoint to connect to. */
  url?: string;
  /** API key for the service. */
  apiKey?: string;
};
/**
 * An event-emitting class that manages the WebSocket connection and relays
 * server messages to the rest of the application as typed events.
 * It has no React dependency, so it can be used on its own.
 */
export class MultimodalLiveClient extends EventEmitter<MultimodalLiveClientEventTypes> {
  /** The active socket, or `null` while disconnected. */
  public ws: WebSocket | null = null;
  /** The config passed to the most recent `connect` call. */
  protected config: LiveConfig | null = null;
  /** The WebSocket endpoint this client connects to. */
  public url: string;

  /**
   * @param connection Optional connection settings. When `url` is omitted (or
   *   empty) the endpoint defaults to `/ws` on the current page's host, using
   *   `wss:` on https pages and `ws:` otherwise. `apiKey` is accepted for
   *   interface compatibility but is not read by this class.
   */
  constructor({ url }: MultimodalLiveAPIClientConnection = {}) {
    super();
    this.url =
      url ||
      `${window.location.protocol === "https:" ? "wss:" : "ws:"}//${window.location.host}/ws`;
    // Bound so `send` can be handed around as a bare callback.
    this.send = this.send.bind(this);
  }

  /**
   * Emits a timestamped `log` event.
   *
   * @param type Dotted category such as `client.send` or `server.close`.
   * @param message The payload or human-readable text to record.
   */
  log(type: string, message: StreamingLog["message"]): void {
    const entry: StreamingLog = {
      date: new Date(),
      type,
      message,
    };
    this.emit("log", entry);
  }

  /**
   * Opens a WebSocket to `this.url` and sends `config` as the setup message.
   *
   * @param config The session configuration sent to the server on open.
   * @returns A promise that resolves to `true` once the socket is open and the
   *   setup message has been sent.
   * @throws (as a rejection) `Error` when the socket errors before opening or
   *   when no config is available at open time.
   */
  connect(config: LiveConfig): Promise<boolean> {
    this.config = config;
    const ws = new WebSocket(this.url);

    ws.addEventListener("message", async (evt: MessageEvent) => {
      if (!(evt.data instanceof Blob)) {
        console.log("non blob message", evt);
        return;
      }
      try {
        await this.receive(evt.data);
      } catch (err: unknown) {
        // Without this, a malformed payload would surface as an unhandled
        // promise rejection.
        console.error("failed to process incoming message", err);
      }
    });

    return new Promise<boolean>((resolve, reject) => {
      const onError = (ev: Event) => {
        this.disconnect(ws);
        const message = `Could not connect to "${this.url}"`;
        this.log(`server.${ev.type}`, message);
        reject(new Error(message));
      };
      ws.addEventListener("error", onError);

      ws.addEventListener("open", (ev: Event) => {
        if (!this.config) {
          // Nothing to set up the session with: do not leave the socket open.
          ws.close();
          reject(new Error("Invalid config sent to `connect(config)`"));
          return;
        }
        this.log(`client.${ev.type}`, `connected to socket`);

        // Publish the socket and send the setup message BEFORE announcing
        // "open", so listeners may call `send()` immediately and their
        // messages cannot precede the setup message.
        this.ws = ws;
        const setupMessage: SetupMessage = {
          setup: this.config,
        };
        this._sendDirect(setupMessage);
        this.log("client.send", "setup");

        // From here on, failures are reported through the "close" event.
        ws.removeEventListener("error", onError);
        ws.addEventListener("close", (closeEvent: CloseEvent) => {
          console.log(closeEvent);
          this.disconnect(ws);
          let reason = closeEvent.reason || "";
          if (reason.toLowerCase().includes("error")) {
            const prelude = "ERROR]";
            const preludeIndex = reason.indexOf(prelude);
            if (preludeIndex !== -1) {
              // Keep only the text after the marker and the character that
              // follows it.
              reason = reason.slice(preludeIndex + prelude.length + 1);
            }
          }
          this.log(
            `server.${closeEvent.type}`,
            `disconnected ${reason ? `with reason: ${reason}` : ``}`,
          );
          this.emit("close", closeEvent);
        });

        this.emit("open");
        resolve(true);
      });
    });
  }

  /**
   * Closes the current socket.
   *
   * @param ws When given, the socket is only closed if it is still the active
   *   one; this guards against a stale socket tearing down a newer connection.
   * @returns `true` if a socket was closed, `false` otherwise.
   */
  disconnect(ws?: WebSocket): boolean {
    if (!this.ws || (ws && this.ws !== ws)) {
      return false;
    }
    this.ws.close();
    this.ws = null;
    this.log("client.close", `Disconnected`);
    return true;
  }

  /**
   * Parses one incoming blob as JSON and emits the matching event(s).
   *
   * @param blob The raw WebSocket message payload.
   */
  protected async receive(blob: Blob): Promise<void> {
    const response = (await blobToJSON(blob)) as LiveIncomingMessage;

    if (isToolCallMessage(response)) {
      this.log("server.toolCall", response);
      this.emit("toolcall", response.toolCall);
      return;
    }
    if (isToolCallCancellationMessage(response)) {
      this.log("receive.toolCallCancellation", response);
      this.emit("toolcallcancellation", response.toolCallCancellation);
      return;
    }
    if (isSetupCompleteMessage(response)) {
      this.log("server.send", "setupComplete");
      this.emit("setupcomplete");
      return;
    }
    if (!isServerContentMessage(response)) {
      console.log("received unmatched message", response);
      return;
    }

    // Server content may be an interruption notice, a turn-complete notice,
    // a model turn, or a turn-complete notice combined with a model turn.
    const { serverContent } = response;
    if (isInterrupted(serverContent)) {
      this.log("receive.serverContent", "interrupted");
      this.emit("interrupted");
      return;
    }
    if (isTurnComplete(serverContent)) {
      this.log("server.send", "turnComplete");
      this.emit("turncomplete");
      // Plausibly there is more to the message, so keep going.
    }
    if (isModelTurn(serverContent)) {
      const parts: Part[] = serverContent.modelTurn.parts;

      // Audio parts are emitted separately as decoded buffers.
      const audioParts = parts.filter(
        (p) => p.inlineData && p.inlineData.mimeType.startsWith("audio/pcm"),
      );
      const otherParts = difference(parts, audioParts);

      for (const part of audioParts) {
        const b64 = part.inlineData?.data;
        if (b64) {
          const data = base64ToArrayBuffer(b64);
          this.emit("audio", data);
          this.log(`server.audio`, `buffer (${data.byteLength})`);
        }
      }

      if (!otherParts.length) {
        return;
      }
      const content: ModelTurn = { modelTurn: { parts: otherParts } };
      this.emit("content", content);
      this.log(`server.content`, response);
    }
  }

  /**
   * Sends realtime input: base64 chunks of "audio/pcm" and/or "image/jpg".
   *
   * @param chunks The media chunks to forward to the server.
   * @throws Error when the socket is not connected.
   */
  sendRealtimeInput(chunks: GenerativeContentBlob[]): void {
    const hasAudio = chunks.some((chunk) => chunk.mimeType.includes("audio"));
    const hasVideo = chunks.some((chunk) => chunk.mimeType.includes("image"));

    let message = "unknown";
    if (hasAudio && hasVideo) {
      message = "audio + video";
    } else if (hasAudio) {
      message = "audio";
    } else if (hasVideo) {
      message = "video";
    }

    const data: RealtimeInputMessage = {
      realtimeInput: {
        mediaChunks: chunks,
      },
    };
    this._sendDirect(data);
    this.log(`client.realtimeInput`, message);
  }

  /**
   * Sends the response to a function call, including the ids of the function
   * calls being answered.
   *
   * @param toolResponse The tool response payload.
   * @throws Error when the socket is not connected.
   */
  sendToolResponse(toolResponse: ToolResponseMessage["toolResponse"]): void {
    const message: ToolResponseMessage = {
      toolResponse,
    };
    this._sendDirect(message);
    this.log(`client.toolResponse`, message);
  }

  /**
   * Sends normal content parts such as `{ text }` as a single user turn.
   *
   * @param parts One part or a list of parts.
   * @param turnComplete Whether this message completes the user's turn.
   * @throws Error when the socket is not connected.
   */
  send(parts: Part | Part[], turnComplete: boolean = true): void {
    const partList = Array.isArray(parts) ? parts : [parts];
    const content: Content = {
      role: "user",
      parts: partList,
    };
    const clientContentRequest: ClientContentMessage = {
      clientContent: {
        turns: [content],
        turnComplete,
      },
    };
    this._sendDirect(clientContentRequest);
    this.log(`client.send`, clientContentRequest);
  }

  /**
   * Used internally to send all messages. Do not call directly unless sending
   * a message type this class does not otherwise support.
   *
   * @param request The object to serialise as JSON and send.
   * @throws Error when the socket is not connected.
   */
  _sendDirect(request: object): void {
    if (!this.ws) {
      throw new Error("WebSocket is not connected");
    }
    this.ws.send(JSON.stringify(request));
  }
}