/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the Chameleon License found in the * LICENSE file in the root directory of this source tree. */ import { useEffect, useState, useRef } from "react"; import { LexicalComposer } from "@lexical/react/LexicalComposer"; import { ContentEditable } from "@lexical/react/LexicalContentEditable"; import { HistoryPlugin } from "@lexical/react/LexicalHistoryPlugin"; import { RichTextPlugin } from "@lexical/react/LexicalRichTextPlugin"; import { OnChangePlugin } from "@lexical/react/LexicalOnChangePlugin"; import DragDropPaste from "../lexical/DragDropPastePlugin"; import { ImagesPlugin } from "../lexical/ImagesPlugin"; import { ImageNode } from "../lexical/ImageNode"; import { ReplaceContentPlugin } from "../lexical/ReplaceContentPlugin"; import LexicalErrorBoundary from "@lexical/react/LexicalErrorBoundary"; import useWebSocket, { ReadyState } from "react-use-websocket"; import { z } from "zod"; import JsonView from "react18-json-view"; import { InputRange } from "../inputs/InputRange"; import { Config } from "../../Config"; import axios from "axios"; import { useHotkeys } from "react-hotkeys-hook"; import { COMPLETE, FULL_OUTPUT, FrontendMultimodalSequencePair, GENERATE_MULTIMODAL, IMAGE, PARTIAL_OUTPUT, QUEUE_STATUS, TEXT, WSContent, WSMultimodalMessage, WSOptions, ZWSMultimodalMessage, mergeTextContent, readableWsState, } from "../../DataTypes"; import { StatusBadge, StatusCategory } from "../output/StatusBadge"; import { SettingsAdjust, Close, Idea, } from "@carbon/icons-react"; import { useAdvancedMode } from "../hooks/useAdvancedMode"; import { InputShowHide } from "../inputs/InputShowHide"; import { InputToggle } from "../inputs/InputToggle"; import Markdown from "react-markdown"; import remarkGfm from "remark-gfm"; import { EOT_TOKEN } from "../../DataTypes"; import { ImageResult } from "../output/ImageResult"; /** Client-side state of the generation session, kept in generationState: Generating after a request has been started, UserWriting once the socket has opened or a COMPLETE message arrived, NotReady initially and after the socket closes or errors. */ enum GenerationSocketState { Generating = "GENERATING", 
UserWriting = "USER_WRITING", NotReady = "NOT_READY", } /** Builds a random identifier of `length` characters drawn from A-Z, a-z and 0-9 using Math.random(), so it is not suitable where cryptographic randomness is required. Used below to create the clientId that is embedded in the socket URL. */ function makeid(length) { let result = ""; const characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; const charactersLength = characters.length; let counter = 0; while (counter < length) { result += characters.charAt(Math.floor(Math.random() * charactersLength)); counter += 1; } return result; } // Prepend an arbitrary text prompt to an existing list of contents /* Returns `contents` itself when `toPrepend` is empty; otherwise returns a new array whose first element is a TEXT content holding `toPrepend`, followed by the original items. */ export function prependTextPrompt( toPrepend: string, contents: WSContent[], ): WSContent[] { if (toPrepend.length == 0) { return contents; } const promptContent: WSContent = { content: toPrepend, content_type: TEXT, }; return [promptContent].concat(contents); } // Extract a flat list of text and image contents from the editor state /* Walks `obj.children` depth-first: a child of type "text" yields a TEXT entry taken from child.text, a child of type "image" yields an IMAGE entry taken from child.src, and every child is then recursed into. Returns an empty array when `obj` is missing or has no children. */ export function flattenContents(obj): WSContent[] { let result: WSContent[] = []; if (!obj || !obj.children || obj.children.length === 0) return result; for (const child of obj.children) { // Only take text and image contents if (child.type === "text") { result.push({ content: child.text, content_type: TEXT }); } else if (child.type === "image") { result.push({ // TODO: Convert the src from URL to base64 image content: child.src, content_type: IMAGE, }); } const grandChildren = flattenContents(child); result = result.concat(grandChildren); } return result; } /* Chooses the rendering of one content item by its content_type: TEXT, IMAGE, or an "Unknown content type" fallback. It is passed directly to Array.map in the output section further down, which supplies `index`. */ export function contentToHtml(content: WSContent, index?: number) { if (content.content_type == TEXT) { return ( {content.content} // // {content.content} // ); } else if (content.content_type == IMAGE) { return ; } else { return

Unknown content type

; } } /** Exported component. All state, socket handling and event handlers visible in this file live in the nested Editor function. NOTE(review): because Editor is declared inside GenerateMixedModal it is a new function on every render of the outer component; if it is rendered as a component, an outer re-render would remount it and drop its state. The outer function holds no state in the visible code, so this looks harmless - confirm. */ export function GenerateMixedModal() { function Editor() { const [clientId, setClientId] = useState(makeid(8)); const [generationState, setGenerationState] = useState(GenerationSocketState.NotReady); const [contents, setContents] = useState([]); const [partialImage, setPartialImage] = useState(""); // Model hyperparams const [temp, setTemp] = useState(0.7); const [topP, setTopP] = useState(0.9); const [cfgImageWeight, setCfgImageWeight] = useState(1.2); const [cfgTextWeight, setCfgTextWeight] = useState(3.0); const [yieldEveryN, setYieldEveryN] = useState(32); const [seed, setSeed] = useState(Config.default_seed); const [maxGenTokens, setMaxGenTokens] = useState(4096); const [repetitionPenalty, setRepetitionPenalty] = useState(1.2); const [showSeed, setShowSeed] = useState(true); const [numberInQueue, setNumberInQueue] = useState(); const socketUrl = `${Config.ws_address}/ws/chameleon/v2/${clientId}`; // Array of text string or html string (i.e., an image) const [modelOutput, setModelOutput] = useState>([]); const { readyState, sendJsonMessage, lastJsonMessage, getWebSocket } = useWebSocket(socketUrl, { onOpen: () => { console.log("WS Opened"); setGenerationState(GenerationSocketState.UserWriting); }, onClose: (e) => { console.log("WS Closed", e); setGenerationState(GenerationSocketState.NotReady); }, onError: (e) => { console.log("WS Error", e); setGenerationState(GenerationSocketState.NotReady); }, // TODO: Inspect error a bit shouldReconnect: (closeEvent) => true, heartbeat: false, }); /* Closes the current socket, clears the collected output, returns to UserWriting and picks a new clientId, which changes socketUrl. */ function abortGeneration() { getWebSocket()?.close(); setModelOutput([]); setGenerationState(GenerationSocketState.UserWriting); setClientId(makeid(8)); } /* Handles each incoming socket message. PARTIAL_OUTPUT appends text to modelOutput or updates partialImage and clears the queue counter; FULL_OUTPUT appends a finished image and clears partialImage; COMPLETE switches back to UserWriting; QUEUE_STATUS updates numberInQueue. Messages that fail ZWSMultimodalMessage validation are only logged. */ useEffect(() => { if (lastJsonMessage != null) { const maybeMessage = ZWSMultimodalMessage.safeParse(lastJsonMessage); console.log("Message", lastJsonMessage, "Parsed", maybeMessage.success); if (maybeMessage.success) { if ( maybeMessage.data.content.length != 1 && maybeMessage.data.message_type != COMPLETE ) { 
console.error("Too few or too many content"); } /* NOTE(review): the length check above only logs; the branches below still read content[0], which is undefined for an empty content array, so the property access would throw - confirm the schema rules that case out. */ console.log("parsed message", maybeMessage); if (maybeMessage.data.message_type == PARTIAL_OUTPUT) { // Currently, the backend only sends one content piece at a time const content = maybeMessage.data.content[0]; if (content.content_type == IMAGE) { setPartialImage(content.content); } else if (content.content_type == TEXT) { setModelOutput((prev) => { return prev.concat(maybeMessage.data.content); }); } setNumberInQueue(undefined); } else if (maybeMessage.data.message_type == FULL_OUTPUT) { // Only image gives full output, text is rendered as it // comes. const content = maybeMessage.data.content[0]; if (content.content_type == IMAGE) { setPartialImage(""); setModelOutput((prev) => { console.log("Set model image output"); return prev.concat(maybeMessage.data.content); }); } } else if (maybeMessage.data.message_type == COMPLETE) { setGenerationState(GenerationSocketState.UserWriting); } else if (maybeMessage.data.message_type == QUEUE_STATUS) { console.log("Queue Status Message", maybeMessage); // expects payload to be n_requests= /* NOTE(review): match() with the g flag returns an array of every digit run, or null when there is none; Number() of that value is 0 for null and NaN when more than one run is present. */ setNumberInQueue( Number(maybeMessage.data.content[0].content.match(/\d+/g)), ); } } } else { console.log("Null message"); } }, [lastJsonMessage, setModelOutput]); const initialConfig = { namespace: "MyEditor", theme: { heading: { h1: "text-24 text-red-500", }, }, onError, nodes: [ImageNode], }; function onError(error) { console.error(error); } function Placeholder() { return ( <>
You can edit text and drag/paste images in the input above.
It's just like writing a mini document.
); } /* Stores the flattened text and image contents of the given editor state in `contents` and clears examplePrompt. */ function onChange(editorState) { // Call toJSON on the EditorState object, which produces a serialization safe string const editorStateJSON = editorState.toJSON(); setContents(flattenContents(editorStateJSON?.root)); setExamplePrompt(null); } /* Sends the current contents as a GENERATE_MULTIMODAL message together with the sampling options, unless the run button is disabled or there is nothing to send. Image contents that start with "http" are first fetched and converted to data URLs; all other contents are sent unchanged. */ function onRunModelClick() { if (runButtonDisabled) return; async function prepareContent(content: WSContent): Promise { if (content.content_type == TEXT) { return content; } else if (content.content_type == IMAGE) { if (content.content.startsWith("http")) { const response = await fetch(content.content); const blob = await response.blob(); const reader = new FileReader(); return new Promise((resolve) => { reader.onload = (event) => { const result = event.target?.result; if (typeof result === "string") { resolve({ ...content, content: result }); } else { resolve(content); } }; reader.readAsDataURL(blob); }); } else { return content; } } else { console.error("Unknown content type"); return content; } } /* NOTE(review): generationState is set to Generating before the awaits below; if fetch() rejects, the error is only logged by the catch at the end and the state is not reset here. The FileReader promise above also has no onerror handler, so a failed read would never settle. */ async function prepareAndRun() { if (contents.length != 0) { setModelOutput([]); setGenerationState(GenerationSocketState.Generating); const currentContent = await Promise.all( contents.map(prepareContent), ); let processedContents = currentContent; const suffix_tokens: Array = [EOT_TOKEN]; const options: WSOptions = { message_type: GENERATE_MULTIMODAL, temp: temp, top_p: topP, cfg_image_weight: cfgImageWeight, cfg_text_weight: cfgTextWeight, repetition_penalty: repetitionPenalty, yield_every_n: yieldEveryN, max_gen_tokens: maxGenTokens, suffix_tokens: suffix_tokens, seed: seed, }; const message: WSMultimodalMessage = { message_type: GENERATE_MULTIMODAL, content: processedContents, options: options, debug_info: {}, }; setContents(processedContents); sendJsonMessage(message); } } prepareAndRun().catch(console.error); } useHotkeys("ctrl+enter, cmd+enter", () => { console.log("Run Model by hotkey"); onRunModelClick(); }); const readableSocketState = readableWsState(readyState); let socketStatus: StatusCategory = "neutral"; if 
(readableSocketState == "Open") { socketStatus = "success"; } else if (readableSocketState == "Closed") { socketStatus = "error"; } else if (readableSocketState == "Connecting") { socketStatus = "warning"; } else { socketStatus = "error"; } /* Running is allowed only while the socket is open and the state is UserWriting. */ const runButtonDisabled = readyState !== ReadyState.OPEN || generationState != GenerationSocketState.UserWriting; const runButtonText = runButtonDisabled ? (
) : (
Run Model {/* Use the following label when hot-key is implemented +ENTER */}
); const runButtonColor = runButtonDisabled ? "btn-neutral opacity-60" : "btn-success"; let uiStatus: StatusCategory = "neutral"; if (generationState == "USER_WRITING") { uiStatus = "success"; } else if (generationState == "GENERATING") { uiStatus = "info"; } else if (generationState == "NOT_READY") { uiStatus = "error"; } const [advancedMode, setAdvancedMode] = useAdvancedMode(); const [tutorialBanner, setTutorialBanner] = useState(true); const [examplePrompt, setExamplePrompt] = useState(null); const chatRef = useRef(null); /* Scrolls the element held by chatRef into view whenever modelOutput changes. */ useEffect(() => { chatRef?.current?.scrollIntoView({ behavior: "smooth", block: "end", inline: "end", }); }, [modelOutput]); return ( <>

Input

setAdvancedMode(!advancedMode)} size={24} className="hover:fill-primary cursor-pointer" />
{/* Toolbar on top, if needed */} {/* */}
} placeholder={} ErrorBoundary={LexicalErrorBoundary} />
{!tutorialBanner && ( )}
{/* Results */}

Output

{/* NOTE(review): when numberInQueue is 0 this expression evaluates to 0, which React renders as text; testing `numberInQueue != null && numberInQueue > 0` would avoid that. */ numberInQueue && numberInQueue > 0 && (
There are {numberInQueue} other users in the queue for generation.
)}
{mergeTextContent(modelOutput).map(contentToHtml)}
{/* Side panel */}

Advanced settings

setAdvancedMode(false)} />
{ setShowSeed(checked); }} /> {showSeed && seed != null && ( )} {/* Input preview */}
indexOrName !== "data" && depth > 3 } />
); } return ; }