/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the Chameleon License found in the
* LICENSE file in the root directory of this source tree.
*/
import { useEffect, useState, useRef } from "react";
import { LexicalComposer } from "@lexical/react/LexicalComposer";
import { ContentEditable } from "@lexical/react/LexicalContentEditable";
import { HistoryPlugin } from "@lexical/react/LexicalHistoryPlugin";
import { RichTextPlugin } from "@lexical/react/LexicalRichTextPlugin";
import { OnChangePlugin } from "@lexical/react/LexicalOnChangePlugin";
import DragDropPaste from "../lexical/DragDropPastePlugin";
import { ImagesPlugin } from "../lexical/ImagesPlugin";
import { ImageNode } from "../lexical/ImageNode";
import { ReplaceContentPlugin } from "../lexical/ReplaceContentPlugin";
import LexicalErrorBoundary from "@lexical/react/LexicalErrorBoundary";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { z } from "zod";
import JsonView from "react18-json-view";
import { InputRange } from "../inputs/InputRange";
import { Config } from "../../Config";
import axios from "axios";
import { useHotkeys } from "react-hotkeys-hook";
import {
COMPLETE,
FULL_OUTPUT,
FrontendMultimodalSequencePair,
GENERATE_MULTIMODAL,
IMAGE,
PARTIAL_OUTPUT,
QUEUE_STATUS,
TEXT,
WSContent,
WSMultimodalMessage,
WSOptions,
ZWSMultimodalMessage,
mergeTextContent,
readableWsState,
} from "../../DataTypes";
import { StatusBadge, StatusCategory } from "../output/StatusBadge";
import {
SettingsAdjust,
Close,
Idea,
} from "@carbon/icons-react";
import { useAdvancedMode } from "../hooks/useAdvancedMode";
import { InputShowHide } from "../inputs/InputShowHide";
import { InputToggle } from "../inputs/InputToggle";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import { EOT_TOKEN } from "../../DataTypes";
import { ImageResult } from "../output/ImageResult";
/**
 * UI-level lifecycle of the generation websocket. It is tracked separately
 * from the socket's raw ready state and drives the run button and the status
 * badge.
 */
enum GenerationSocketState {
  /** Socket is not usable: the initial value, also set on close or error. */
  NotReady = "NOT_READY",
  /** Socket is open and idle; set on open, on COMPLETE, and after an abort. */
  UserWriting = "USER_WRITING",
  /** A generation request has been started and has not completed yet. */
  Generating = "GENERATING",
}
/**
 * Builds a random alphanumeric identifier drawn from A-Z, a-z and 0-9.
 *
 * Characters are picked with `Math.random`, so the result is not
 * cryptographically secure.
 *
 * @param length Number of characters to generate. Zero, negative or NaN
 *   values yield an empty string.
 * @returns The generated identifier.
 */
function makeid(length: number): string {
  const alphabet =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
  const picked: string[] = [];
  for (let i = 0; i < length; i += 1) {
    picked.push(alphabet.charAt(Math.floor(Math.random() * alphabet.length)));
  }
  return picked.join("");
}
/**
 * Prepends an arbitrary text prompt to an existing list of contents.
 *
 * @param toPrepend Text to place in front of `contents`.
 * @param contents Contents that should follow the prompt.
 * @returns A new array that starts with a TEXT entry holding `toPrepend`, or
 *   `contents` itself (same array instance) when `toPrepend` is empty.
 */
export function prependTextPrompt(
  toPrepend: string,
  contents: WSContent[],
): WSContent[] {
  const nothingToPrepend = toPrepend.length === 0;
  if (nothingToPrepend) {
    return contents;
  }
  const head: WSContent = { content: toPrepend, content_type: TEXT };
  return [head, ...contents];
}
/**
 * Walks a serialized editor node depth-first and collects its text and image
 * descendants into one flat list, in document order.
 *
 * Only children whose `type` is "text" or "image" produce an entry of their
 * own; every child, whatever its type, is also searched recursively.
 *
 * @param obj Serialized node with an optional `children` list, such as the
 *   `root` of the editor state JSON.
 * @returns The collected contents; empty when `obj` is missing or has no
 *   children.
 */
export function flattenContents(obj): WSContent[] {
  const collected: WSContent[] = [];
  const children = obj?.children;
  if (!children || children.length === 0) return collected;

  for (const node of children) {
    switch (node.type) {
      case "text":
        collected.push({ content: node.text, content_type: TEXT });
        break;
      case "image":
        // TODO: Convert the src from URL to base64 image
        collected.push({ content: node.src, content_type: IMAGE });
        break;
      default:
        // Other node types contribute nothing themselves.
        break;
    }
    // Descendants are appended right after their parent's own entry.
    for (const nested of flattenContents(node)) {
      collected.push(nested);
    }
  }
  return collected;
}
// Render one content item, branching on its content_type: one branch for
// TEXT, one for IMAGE, and a fallback labelled "Unknown content type".
// It is used as the `.map` callback over the merged model output, so `index`
// receives the array position — presumably used as the React key; confirm.
// NOTE(review): the element markup inside the return statements appears to be
// missing from this copy of the file (e.g. `return ;` in the IMAGE branch).
// Restore it from the original source before changing this function.
export function contentToHtml(content: WSContent, index?: number) {
if (content.content_type == TEXT) {
return (
{content.content}
//
// {content.content}
//
);
} else if (content.content_type == IMAGE) {
return ;
} else {
return
Unknown content type
;
}
}
// Page component for mixed-modal (text + image) generation over a websocket.
export function GenerateMixedModal() {
// Inner component holding all editor, socket and generation state.
// NOTE(review): it is declared inside GenerateMixedModal, so a new component
// function is created on every render of the parent.
function Editor() {
// Random client id; it is part of the websocket URL below, so replacing it
// (see abortGeneration) points the socket hook at a new URL.
const [clientId, setClientId] = useState(makeid(8));
// UI lifecycle of the socket; starts as NotReady until the socket opens.
const [generationState, setGenerationState] =
useState(GenerationSocketState.NotReady);
// Flat list of text/image contents derived from the editor state in onChange.
const [contents, setContents] = useState([]);
// Content of the most recent partial image message; reset to "" once the
// full image arrives.
const [partialImage, setPartialImage] = useState("");
// Model hyperparams
const [temp, setTemp] = useState(0.7);
const [topP, setTopP] = useState(0.9);
const [cfgImageWeight, setCfgImageWeight] = useState(1.2);
const [cfgTextWeight, setCfgTextWeight] = useState(3.0);
const [yieldEveryN, setYieldEveryN] = useState(32);
const [seed, setSeed] = useState(Config.default_seed);
const [maxGenTokens, setMaxGenTokens] = useState(4096);
const [repetitionPenalty, setRepetitionPenalty] = useState(1.2);
const [showSeed, setShowSeed] = useState(true);
// Number parsed from QUEUE_STATUS messages; set back to undefined when
// partial output arrives.
const [numberInQueue, setNumberInQueue] = useState();
const socketUrl = `${Config.ws_address}/ws/chameleon/v2/${clientId}`;
// Array of text string or html string (i.e., an image)
// NOTE(review): `useState>([])` looks like a type argument that lost its
// opening part (probably an array of WSContent) — restore from the original.
const [modelOutput, setModelOutput] = useState>([]);
// Socket connection; the open/close/error callbacks keep generationState in
// sync with whether the socket can be used.
const { readyState, sendJsonMessage, lastJsonMessage, getWebSocket } =
useWebSocket(socketUrl, {
onOpen: () => {
console.log("WS Opened");
setGenerationState(GenerationSocketState.UserWriting);
},
onClose: (e) => {
console.log("WS Closed", e);
setGenerationState(GenerationSocketState.NotReady);
},
onError: (e) => {
console.log("WS Error", e);
setGenerationState(GenerationSocketState.NotReady);
},
// TODO: Inspect error a bit
// Always reconnect, regardless of the close event.
shouldReconnect: (closeEvent) => true,
heartbeat: false,
});
/**
 * Cancels the current run: closes the active websocket (if there is one),
 * discards the output received so far, marks the UI as ready for user input
 * and generates a new client id, which changes the socket URL.
 */
function abortGeneration() {
  const activeSocket = getWebSocket();
  if (activeSocket != null) {
    activeSocket.close();
  }
  setModelOutput([]);
  setGenerationState(GenerationSocketState.UserWriting);
  const freshClientId = makeid(8);
  setClientId(freshClientId);
}
// React to every message pushed by the server over the websocket.
//  - PARTIAL_OUTPUT: text is appended to the output as it streams in; an
//    image only updates the partial-image preview. The queue number is cleared.
//  - FULL_OUTPUT: only images are handled (text has already been streamed);
//    the finished image is appended and the preview is cleared.
//  - COMPLETE: the run is over, the user may write again.
//  - QUEUE_STATUS: the first number found in the payload is stored.
// Messages that fail schema validation are ignored.
useEffect(() => {
  if (lastJsonMessage == null) {
    console.log("Null message");
    return;
  }
  const maybeMessage = ZWSMultimodalMessage.safeParse(lastJsonMessage);
  console.log("Message", lastJsonMessage, "Parsed", maybeMessage.success);
  if (!maybeMessage.success) {
    return;
  }
  const message = maybeMessage.data;
  if (message.content.length != 1 && message.message_type != COMPLETE) {
    console.error("Too few or too many content");
  }
  console.log("parsed message", maybeMessage);

  if (message.message_type == COMPLETE) {
    setGenerationState(GenerationSocketState.UserWriting);
    return;
  }

  // Currently, the backend only sends one content piece at a time. All
  // remaining message types read that first piece, so bail out instead of
  // throwing when the content list is empty.
  const firstContent = message.content[0];
  if (firstContent == null) {
    return;
  }

  if (message.message_type == PARTIAL_OUTPUT) {
    if (firstContent.content_type == IMAGE) {
      setPartialImage(firstContent.content);
    } else if (firstContent.content_type == TEXT) {
      setModelOutput((prev) => prev.concat(message.content));
    }
    setNumberInQueue(undefined);
  } else if (message.message_type == FULL_OUTPUT) {
    // Only image gives full output, text is rendered as it comes.
    if (firstContent.content_type == IMAGE) {
      setPartialImage("");
      setModelOutput((prev) => {
        console.log("Set model image output");
        return prev.concat(message.content);
      });
    }
  } else if (message.message_type == QUEUE_STATUS) {
    console.log("Queue Status Message", maybeMessage);
    // expects payload to be n_requests=
    // Take the first run of digits explicitly. Coercing the whole match
    // result with Number() gives 0 when nothing matches and NaN when the
    // payload holds more than one number.
    const digits = firstContent.content.match(/\d+/);
    setNumberInQueue(digits ? Number(digits[0]) : undefined);
  }
}, [lastJsonMessage, setModelOutput]);
// Error callback referenced by the editor configuration below; errors are
// only written to the console.
function onError(error) {
  console.error(error);
}
// Editor configuration: namespace, the class names applied to h1 headings,
// the custom ImageNode registration and the error callback.
// Presumably handed to LexicalComposer as its initial config — confirm.
const initialConfig = {
  namespace: "MyEditor",
  theme: {
    heading: { h1: "text-24 text-red-500" },
  },
  nodes: [ImageNode],
  onError,
};
// Hint text telling the user that the input accepts text as well as
// dragged or pasted images.
// NOTE(review): the fragment's markup looks damaged in this copy (it opens
// with `<>` but ends with a bare `>`), and any wrapping elements are missing;
// restore from the original source before editing.
function Placeholder() {
return (
<>
You can edit text and drag/paste images in the input above.
It's just like writing a mini document.
>
);
}
/**
 * Editor change handler: re-derives the flat content list from the new
 * editor state and resets the example prompt to null.
 */
function onChange(editorState) {
  // toJSON() gives a plain, serialization-safe snapshot of the editor state.
  const snapshot = editorState.toJSON();
  const flattened = flattenContents(snapshot?.root);
  setContents(flattened);
  setExamplePrompt(null);
}
/**
 * Starts a generation run with the current editor contents.
 *
 * Does nothing while the run button is disabled or when there are no
 * contents. Otherwise it clears the previous output, converts remote images
 * to data URLs and sends a GENERATE_MULTIMODAL message carrying the current
 * hyperparameters over the websocket.
 *
 * If preparing the contents fails, the error is logged and the state goes
 * back from Generating to UserWriting so the run button is usable again.
 */
function onRunModelClick() {
  if (runButtonDisabled) return;

  // Text passes through untouched. An image whose content starts with "http"
  // is downloaded and re-encoded as a data URL; any other image content is
  // sent as it is.
  async function prepareContent(content: WSContent): Promise<WSContent> {
    if (content.content_type == TEXT) {
      return content;
    }
    if (content.content_type != IMAGE) {
      console.error("Unknown content type");
      return content;
    }
    if (!content.content.startsWith("http")) {
      return content;
    }

    const response = await fetch(content.content);
    if (!response.ok) {
      // Do not encode an error page and send it as if it were the image.
      throw new Error(`Failed to fetch image (HTTP ${response.status})`);
    }
    const blob = await response.blob();
    return new Promise<WSContent>((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (event) => {
        const result = event.target?.result;
        resolve(
          typeof result === "string"
            ? { ...content, content: result }
            : content,
        );
      };
      // Without this handler a failed read would leave the promise pending
      // forever and the UI stuck in the Generating state.
      reader.onerror = () => {
        reject(reader.error ?? new Error("Failed to read image data"));
      };
      reader.readAsDataURL(blob);
    });
  }

  async function prepareAndRun() {
    if (contents.length == 0) {
      return;
    }
    setModelOutput([]);
    setGenerationState(GenerationSocketState.Generating);

    let processedContents: WSContent[];
    try {
      processedContents = await Promise.all(contents.map(prepareContent));
    } catch (error) {
      // Nothing was sent, so no COMPLETE message will arrive to reset the
      // state. Only undo our own transition; the socket may have closed in
      // the meantime and set a different state.
      setGenerationState((current) =>
        current === GenerationSocketState.Generating
          ? GenerationSocketState.UserWriting
          : current,
      );
      throw error;
    }

    const options: WSOptions = {
      message_type: GENERATE_MULTIMODAL,
      temp: temp,
      top_p: topP,
      cfg_image_weight: cfgImageWeight,
      cfg_text_weight: cfgTextWeight,
      repetition_penalty: repetitionPenalty,
      yield_every_n: yieldEveryN,
      max_gen_tokens: maxGenTokens,
      suffix_tokens: [EOT_TOKEN],
      seed: seed,
    };
    const message: WSMultimodalMessage = {
      message_type: GENERATE_MULTIMODAL,
      content: processedContents,
      options: options,
      debug_info: {},
    };
    // Keep the converted contents so images are not downloaded again on the
    // next run.
    setContents(processedContents);
    sendJsonMessage(message);
  }

  prepareAndRun().catch(console.error);
}
// Keyboard shortcut for the run button.
useHotkeys("ctrl+enter, cmd+enter", () => {
  console.log("Run Model by hotkey");
  onRunModelClick();
});
// Map the socket's readable state onto a status badge category.
const readableSocketState = readableWsState(readyState);
let socketStatus: StatusCategory;
switch (readableSocketState) {
  case "Open":
    socketStatus = "success";
    break;
  case "Connecting":
    socketStatus = "warning";
    break;
  default:
    // "Closed" and every other state are shown as an error.
    socketStatus = "error";
    break;
}
// The run button is usable only while the socket is open and no generation
// is in progress (state is UserWriting).
const runButtonDisabled =
readyState !== ReadyState.OPEN ||
generationState != GenerationSocketState.UserWriting;
// Button label: one variant while disabled, "Run Model" otherwise.
// NOTE(review): the markup of both branches appears to be missing from this
// copy of the file (the disabled branch is empty); restore before editing.
const runButtonText = runButtonDisabled ? (
) : (
Run Model
{/* Use the following label when hot-key is implemented
+ENTER
*/}
);
// Class names for the button: dimmed neutral when disabled, success otherwise.
const runButtonColor = runButtonDisabled
? "btn-neutral opacity-60"
: "btn-success";
// Map the generation state onto a status badge category; anything
// unrecognised stays "neutral".
let uiStatus: StatusCategory = "neutral";
switch (generationState) {
  case GenerationSocketState.UserWriting:
    uiStatus = "success";
    break;
  case GenerationSocketState.Generating:
    uiStatus = "info";
    break;
  case GenerationSocketState.NotReady:
    uiStatus = "error";
    break;
  default:
    break;
}
// Whether the advanced settings panel is shown.
const [advancedMode, setAdvancedMode] = useAdvancedMode();
const [tutorialBanner, setTutorialBanner] = useState(true);
// Reset to null whenever the editor content changes (see onChange).
const [examplePrompt, setExamplePrompt] = useState(null);
// Ref to the element that is scrolled into view as output arrives.
const chatRef = useRef(null);
// Scroll smoothly to the referenced element each time the model output
// changes; nothing happens while the ref is not attached.
useEffect(() => {
  const anchor = chatRef.current;
  if (anchor == null) {
    return;
  }
  anchor.scrollIntoView({
    behavior: "smooth",
    block: "end",
    inline: "end",
  });
}, [modelOutput]);
// Render the page. Judging by the remaining section comments and labels it
// consists of the input editor, the output area (queue notice plus the merged
// model output) and an advanced-settings side panel with an input preview.
// NOTE(review): nearly all element markup is missing from this copy of the
// file, so the structure below cannot be read as JSX; restore it from the
// original source before editing.
// NOTE(review): `numberInQueue && numberInQueue > 0 && (...)` evaluates to the
// number itself when numberInQueue is 0 or NaN, which React would render as
// text; consider a strictly boolean condition once the markup is restored.
return (
<>
Input
setAdvancedMode(!advancedMode)}
size={24}
className="hover:fill-primary cursor-pointer"
/>
{/* Toolbar on top, if needed */}
{/* */}
}
placeholder={
}
ErrorBoundary={LexicalErrorBoundary}
/>
{!tutorialBanner && (
)}
{/* Results */}
Output
{numberInQueue && numberInQueue > 0 && (
There are {numberInQueue} other users in the queue for
generation.
)}
{mergeTextContent(modelOutput).map(contentToHtml)}
{/* Side panel */}
Advanced settings
setAdvancedMode(false)}
/>
{
setShowSeed(checked);
}}
/>
{showSeed && seed != null && (
)}
{/* Input preview */}
indexOrName !== "data" && depth > 3
}
/>
>
);
}
// NOTE(review): the element returned by GenerateMixedModal is missing from
// this copy of the file — presumably it renders Editor; confirm.
return ;
}