Spaces:
Runtime error
Runtime error
/* | |
* Copyright (c) Meta Platforms, Inc. and affiliates. | |
* | |
* This source code is licensed under the Chameleon License found in the | |
* LICENSE file in the root directory of this source tree. | |
*/ | |
import { useEffect, useState, useRef } from "react"; | |
import { LexicalComposer } from "@lexical/react/LexicalComposer"; | |
import { ContentEditable } from "@lexical/react/LexicalContentEditable"; | |
import { HistoryPlugin } from "@lexical/react/LexicalHistoryPlugin"; | |
import { RichTextPlugin } from "@lexical/react/LexicalRichTextPlugin"; | |
import { OnChangePlugin } from "@lexical/react/LexicalOnChangePlugin"; | |
import DragDropPaste from "../lexical/DragDropPastePlugin"; | |
import { ImagesPlugin } from "../lexical/ImagesPlugin"; | |
import { ImageNode } from "../lexical/ImageNode"; | |
import { ReplaceContentPlugin } from "../lexical/ReplaceContentPlugin"; | |
import LexicalErrorBoundary from "@lexical/react/LexicalErrorBoundary"; | |
import useWebSocket, { ReadyState } from "react-use-websocket"; | |
import { z } from "zod"; | |
import JsonView from "react18-json-view"; | |
import { InputRange } from "../inputs/InputRange"; | |
import { Config } from "../../Config"; | |
import axios from "axios"; | |
import { useHotkeys } from "react-hotkeys-hook"; | |
import { | |
COMPLETE, | |
FULL_OUTPUT, | |
FrontendMultimodalSequencePair, | |
GENERATE_MULTIMODAL, | |
IMAGE, | |
PARTIAL_OUTPUT, | |
QUEUE_STATUS, | |
TEXT, | |
WSContent, | |
WSMultimodalMessage, | |
WSOptions, | |
ZWSMultimodalMessage, | |
mergeTextContent, | |
readableWsState, | |
} from "../../DataTypes"; | |
import { StatusBadge, StatusCategory } from "../output/StatusBadge"; | |
import { | |
SettingsAdjust, | |
Close, | |
Idea, | |
} from "@carbon/icons-react"; | |
import { useAdvancedMode } from "../hooks/useAdvancedMode"; | |
import { InputShowHide } from "../inputs/InputShowHide"; | |
import { InputToggle } from "../inputs/InputToggle"; | |
import Markdown from "react-markdown"; | |
import remarkGfm from "remark-gfm"; | |
import { EOT_TOKEN } from "../../DataTypes"; | |
import { ImageResult } from "../output/ImageResult"; | |
/**
 * High-level UI state of the generation session.
 *
 * The string values are shown verbatim in the "UI" status badge.
 */
enum GenerationSocketState {
  /** A request has been sent and output is still expected. */
  Generating = "GENERATING",
  /** The socket is open and the user may edit the prompt or start a run. */
  UserWriting = "USER_WRITING",
  /** The socket is not open (initial state, closed, or errored). */
  NotReady = "NOT_READY",
}
/**
 * Build a random alphanumeric identifier.
 *
 * Uses `Math.random`, so the result is NOT cryptographically secure; it is
 * only meant to tell client sessions apart.
 *
 * @param length Number of characters to generate; values <= 0 yield "".
 * @returns A string of `length` characters drawn from [A-Za-z0-9].
 */
function makeid(length: number): string {
  const alphabet =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
  let id = "";
  for (let i = 0; i < length; i += 1) {
    id += alphabet.charAt(Math.floor(Math.random() * alphabet.length));
  }
  return id;
}
/**
 * Prepend an arbitrary text prompt to an existing list of contents.
 *
 * @param toPrepend Text to place in front of the list. An empty string means
 *   "nothing to prepend".
 * @param contents Existing contents; this array is never mutated.
 * @returns `contents` itself when `toPrepend` is empty, otherwise a new array
 *   whose first element is a TEXT content holding `toPrepend`.
 */
export function prependTextPrompt(
  toPrepend: string,
  contents: WSContent[],
): WSContent[] {
  if (toPrepend.length === 0) {
    return contents;
  }
  const promptContent: WSContent = {
    content: toPrepend,
    content_type: TEXT,
  };
  return [promptContent, ...contents];
}
/** Loose view of a serialized editor node: only the fields read below. */
interface SerializedNodeLike {
  type?: unknown;
  text?: unknown;
  src?: unknown;
  children?: unknown;
}

/**
 * Extract a flat list of text and image contents from the editor state.
 *
 * Walks the `children` of `obj` depth-first. Each child of type "text"
 * contributes a TEXT content (its `text`), each child of type "image"
 * contributes an IMAGE content (its `src`, passed through unchanged); any
 * other child contributes nothing itself but is still descended into.
 * Entries appear in document (pre-order) order.
 *
 * @param obj Serialized editor node (typically the JSON root). Anything that
 *   is not an object with an array `children` yields an empty list.
 * @returns The collected contents; never null.
 */
export function flattenContents(obj: unknown): WSContent[] {
  const flattened: WSContent[] = [];

  // Single shared accumulator avoids re-copying the result at every level.
  const visit = (node: unknown): void => {
    if (typeof node !== "object" || node === null) return;
    const { children } = node as SerializedNodeLike;
    if (!Array.isArray(children)) return;

    for (const child of children as unknown[]) {
      // Tolerate holes / non-object entries instead of throwing on `.type`.
      if (typeof child !== "object" || child === null) continue;
      const { type, text, src } = child as SerializedNodeLike;

      // Only take text and image contents
      if (type === "text") {
        flattened.push({ content: String(text ?? ""), content_type: TEXT });
      } else if (type === "image") {
        flattened.push({ content: String(src ?? ""), content_type: IMAGE });
      }
      visit(child);
    }
  };

  visit(obj);
  return flattened;
}
/**
 * Render one content item as a React element.
 *
 * TEXT is rendered as GitHub-flavoured Markdown, IMAGE through `ImageResult`,
 * and anything else as an "Unknown content type" paragraph.
 *
 * The signature matches an `Array.prototype.map` callback so it can be used
 * as `contents.map(contentToHtml)`.
 *
 * @param content The item to render.
 * @param index Position in the list, used to build the React `key`. Defaults
 *   to 0 so a direct call without an index still yields a well-formed key.
 */
export function contentToHtml(content: WSContent, index?: number) {
  const position = index ?? 0;
  switch (content.content_type) {
    case TEXT:
      return (
        <Markdown remarkPlugins={[remarkGfm]} key={`t${position}`}>
          {content.content}
        </Markdown>
      );
    case IMAGE:
      return <ImageResult src={content.content} key={`img${position}`} />;
    default:
      return <p key={`p${position}`}>Unknown content type</p>;
  }
}
/**
 * Mixed-modal (interleaved text + image) generation playground.
 *
 * Layout:
 *  - Input pane: a Lexical rich-text editor accepting text and dragged/pasted
 *    images, plus Run / Abort buttons.
 *  - Output pane: streamed model output (text chunks as they arrive, images
 *    once complete, and the in-progress image preview).
 *  - Slide-in side panel with sampling settings and a JSON preview of the
 *    contents that would be sent.
 *  - Footer with connection and UI status badges.
 *
 * Requests and results travel over a WebSocket whose URL embeds a random
 * per-session client id.
 *
 * The hooks live directly in this component (rather than in a component
 * declared inside it) so that a re-render of the parent cannot change the
 * component identity and remount the editor and its socket.
 */
export function GenerateMixedModal() {
  // Lazy initializer: draw the random id once on mount, not on every render.
  const [clientId, setClientId] = useState<string>(() => makeid(8));
  const [generationState, setGenerationState] =
    useState<GenerationSocketState>(GenerationSocketState.NotReady);
  // Flattened editor contents; this is what gets sent to the backend.
  const [contents, setContents] = useState<WSContent[]>([]);
  // Latest in-progress image preview ("" when none).
  const [partialImage, setPartialImage] = useState<string>("");

  // Model hyperparams
  const [temp, setTemp] = useState<number>(0.7);
  const [topP, setTopP] = useState<number>(0.9);
  const [cfgImageWeight, setCfgImageWeight] = useState<number>(1.2);
  const [cfgTextWeight, setCfgTextWeight] = useState<number>(3.0);
  const [yieldEveryN, setYieldEveryN] = useState<number>(32);
  const [seed, setSeed] = useState<number | null>(Config.default_seed);
  const [maxGenTokens, setMaxGenTokens] = useState<number>(4096);
  const [repetitionPenalty, setRepetitionPenalty] = useState<number>(1.2);
  const [showSeed, setShowSeed] = useState<boolean>(true);

  // Queue size reported by the backend; undefined when unknown / not queued.
  const [numberInQueue, setNumberInQueue] = useState<number>();
  // Accumulated output: text chunks and completed images, in arrival order.
  const [modelOutput, setModelOutput] = useState<Array<WSContent>>([]);

  // UI-only state
  const [advancedMode, setAdvancedMode] = useAdvancedMode();
  const [tutorialBanner, setTutorialBanner] = useState(true);
  const [examplePrompt, setExamplePrompt] = useState<string | null>(null);
  const chatRef = useRef<HTMLDivElement>(null);

  const socketUrl = `${Config.ws_address}/ws/chameleon/v2/${clientId}`;

  const { readyState, sendJsonMessage, lastJsonMessage, getWebSocket } =
    useWebSocket(socketUrl, {
      onOpen: () => {
        console.log("WS Opened");
        setGenerationState(GenerationSocketState.UserWriting);
      },
      onClose: (e) => {
        console.log("WS Closed", e);
        setGenerationState(GenerationSocketState.NotReady);
      },
      onError: (e) => {
        console.log("WS Error", e);
        setGenerationState(GenerationSocketState.NotReady);
      },
      // TODO: Inspect error a bit
      shouldReconnect: () => true,
      heartbeat: false,
    });

  /**
   * Cancel the current run: close the socket, clear everything shown in the
   * output pane, and switch to a new client id. The new id changes
   * `socketUrl`, so the next connection goes to a fresh endpoint.
   */
  function abortGeneration() {
    getWebSocket()?.close();
    setModelOutput([]);
    // Also drop the half-generated image and queue banner so nothing stale
    // from the aborted run remains visible.
    setPartialImage("");
    setNumberInQueue(undefined);
    setGenerationState(GenerationSocketState.UserWriting);
    setClientId(makeid(8));
  }

  // React to each incoming WebSocket message.
  useEffect(() => {
    if (lastJsonMessage == null) {
      console.log("Null message");
      return;
    }

    const maybeMessage = ZWSMultimodalMessage.safeParse(lastJsonMessage);
    console.log("Message", lastJsonMessage, "Parsed", maybeMessage.success);
    if (!maybeMessage.success) {
      return;
    }

    const { message_type: messageType, content: messageContent } =
      maybeMessage.data;
    if (messageContent.length !== 1 && messageType !== COMPLETE) {
      console.error("Too few or too many content");
    }
    console.log("parsed message", maybeMessage);

    if (messageType === COMPLETE) {
      setGenerationState(GenerationSocketState.UserWriting);
      return;
    }

    // Currently, the backend only sends one content piece at a time.
    // Every remaining message type needs it, so bail out if it is missing
    // instead of throwing inside the effect.
    const firstContent = messageContent[0];
    if (firstContent === undefined) {
      return;
    }

    if (messageType === PARTIAL_OUTPUT) {
      if (firstContent.content_type === IMAGE) {
        setPartialImage(firstContent.content);
      } else if (firstContent.content_type === TEXT) {
        setModelOutput((prev) => prev.concat(messageContent));
      }
      // Output is arriving, so this client is no longer waiting in the queue.
      setNumberInQueue(undefined);
    } else if (messageType === FULL_OUTPUT) {
      // Only image gives full output, text is rendered as it comes.
      if (firstContent.content_type === IMAGE) {
        setPartialImage("");
        setModelOutput((prev) => {
          console.log("Set model image output");
          return prev.concat(messageContent);
        });
      }
    } else if (messageType === QUEUE_STATUS) {
      console.log("Queue Status Message", maybeMessage);
      // expects payload to be n_requests=<number>; take the first number and
      // treat "no number" as unknown rather than coercing it to 0 / NaN.
      const firstNumber = firstContent.content.match(/\d+/)?.[0];
      setNumberInQueue(
        firstNumber === undefined ? undefined : Number(firstNumber),
      );
    }
  }, [lastJsonMessage]);

  /** Error sink for the Lexical editor. */
  function onError(error: unknown) {
    console.error(error);
  }

  const initialConfig = {
    namespace: "MyEditor",
    theme: {
      heading: {
        h1: "text-24 text-red-500",
      },
    },
    onError,
    nodes: [ImageNode],
  };

  /** Hint shown inside the editor while it is empty. */
  function Placeholder() {
    return (
      <div className="absolute top-4 left-4 z-0 select-none pointer-events-none opacity-50 prose">
        You can edit text and drag/paste images in the input above.
        <br />
        It's just like writing a mini document.
      </div>
    );
  }

  /** Mirror every editor change into `contents` and clear any example prompt. */
  function onChange(editorState: { toJSON(): { root?: unknown } }) {
    // Call toJSON on the EditorState object, which produces a serialization safe string
    const editorStateJSON = editorState.toJSON();
    setContents(flattenContents(editorStateJSON?.root));
    setExamplePrompt(null);
  }

  const runButtonDisabled =
    readyState !== ReadyState.OPEN ||
    generationState !== GenerationSocketState.UserWriting;

  /**
   * Start a generation run with the current contents and settings.
   * No-op while the run button is disabled or when the editor is empty.
   */
  function onRunModelClick() {
    if (runButtonDisabled) return;

    /**
     * Make one content item self-contained: images referenced by an
     * http(s) URL are downloaded and re-encoded as data URLs; everything
     * else is returned unchanged.
     */
    async function prepareContent(content: WSContent): Promise<WSContent> {
      if (content.content_type === TEXT) {
        return content;
      }
      if (content.content_type !== IMAGE) {
        console.error("Unknown content type");
        return content;
      }
      if (!content.content.startsWith("http")) {
        return content;
      }

      const response = await fetch(content.content);
      if (!response.ok) {
        // Don't encode an HTTP error page and send it off as an image.
        throw new Error(
          `Failed to fetch image (${response.status}): ${content.content}`,
        );
      }
      const blob = await response.blob();

      return new Promise<WSContent>((resolve, reject) => {
        const reader = new FileReader();
        reader.onload = () => {
          const result = reader.result;
          resolve(
            typeof result === "string"
              ? { ...content, content: result }
              : content,
          );
        };
        // Without this the promise would never settle on a read failure.
        reader.onerror = () =>
          reject(reader.error ?? new Error("Failed to read image data"));
        reader.readAsDataURL(blob);
      });
    }

    async function prepareAndRun(): Promise<void> {
      if (contents.length === 0) {
        return;
      }
      setModelOutput([]);
      setGenerationState(GenerationSocketState.Generating);

      const processedContents = await Promise.all(
        contents.map(prepareContent),
      );
      const suffixTokens: Array<string> = [EOT_TOKEN];
      const options: WSOptions = {
        message_type: GENERATE_MULTIMODAL,
        temp: temp,
        top_p: topP,
        cfg_image_weight: cfgImageWeight,
        cfg_text_weight: cfgTextWeight,
        repetition_penalty: repetitionPenalty,
        yield_every_n: yieldEveryN,
        max_gen_tokens: maxGenTokens,
        suffix_tokens: suffixTokens,
        // NOTE(review): the seed is sent even when the "Set seed" toggle is
        // off (the toggle only hides the slider) — confirm this is intended.
        seed: seed,
      };
      const message: WSMultimodalMessage = {
        message_type: GENERATE_MULTIMODAL,
        content: processedContents,
        options: options,
        debug_info: {},
      };
      setContents(processedContents);
      sendJsonMessage(message);
    }

    prepareAndRun().catch((error: unknown) => {
      console.error(error);
      // Nothing was sent, so no COMPLETE message will ever arrive: leave the
      // Generating state here or the run button stays disabled forever.
      // Keep any other state (e.g. NotReady set by a socket close) as is.
      setGenerationState((prev) =>
        prev === GenerationSocketState.Generating
          ? GenerationSocketState.UserWriting
          : prev,
      );
    });
  }

  useHotkeys("ctrl+enter, cmd+enter", () => {
    console.log("Run Model by hotkey");
    onRunModelClick();
  });

  // Connection badge: open -> success, connecting -> warning, rest -> error.
  const readableSocketState = readableWsState(readyState);
  let socketStatus: StatusCategory = "error";
  if (readableSocketState === "Open") {
    socketStatus = "success";
  } else if (readableSocketState === "Connecting") {
    socketStatus = "warning";
  }

  const runButtonText = runButtonDisabled ? (
    <div className="loading loading-infinity loading-lg text-neutral"></div>
  ) : (
    <div className="flex flex-row items-center">Run Model</div>
  );
  const runButtonColor = runButtonDisabled
    ? "btn-neutral opacity-60"
    : "btn-success";

  // UI badge colour per generation state.
  let uiStatus: StatusCategory = "neutral";
  switch (generationState) {
    case GenerationSocketState.UserWriting:
      uiStatus = "success";
      break;
    case GenerationSocketState.Generating:
      uiStatus = "info";
      break;
    case GenerationSocketState.NotReady:
      uiStatus = "error";
      break;
  }

  // NOTE(review): `chatRef` is not attached to any element rendered below,
  // so this scroll is currently a no-op — attach it to an end-of-output
  // element if auto-scroll is wanted.
  useEffect(() => {
    chatRef?.current?.scrollIntoView({
      behavior: "smooth",
      block: "end",
      inline: "end",
    });
  }, [modelOutput]);

  return (
    <div className="flex-1 flex flex-col min-h-[calc(100vh-150px)] max-h-[calc(100vh-150px)]">
      <div className="flex-1 flex flex-col relative overflow-x-hidden mb-10">
        <div
          className={`flex-1 flex flex-row items-stretch gap-4 max-h-[calc(100vh-200px)] ${
            advancedMode ? "ml-[500px]" : "ml-0"
          } transition-all`}
        >
          {/* Input */}
          <div className="flex-1 flex flex-col relative rounded-md px-6 py-4 bg-purple-50 gap-8">
            <div className="flex flex-row items-center justify-between">
              <div className="prose">
                <h4>Input</h4>
              </div>
              <SettingsAdjust
                onClick={() => setAdvancedMode(!advancedMode)}
                size={24}
                className="hover:fill-primary cursor-pointer"
              />
            </div>
            <div className="flex flex-col flex-1 items-stretch overflow-y-scroll h-full">
              <LexicalComposer initialConfig={initialConfig}>
                {/* Toolbar on top, if needed */}
                {/* <ToolbarPlugin /> */}
                <div className="relative flex-1">
                  <RichTextPlugin
                    contentEditable={
                      <ContentEditable
                        className={`relative bg-white ${
                          tutorialBanner ? "rounded-t-md" : "rounded-md"
                        } block p-4 leading-5 text-md h-full`}
                      />
                    }
                    placeholder={<Placeholder />}
                    ErrorBoundary={LexicalErrorBoundary}
                  />
                </div>
                <DragDropPaste />
                <HistoryPlugin />
                <ImagesPlugin />
                <OnChangePlugin onChange={onChange} />
                <ReplaceContentPlugin payload={examplePrompt} />
              </LexicalComposer>
            </div>
            <div className="flex flex-row items-center justify-between my-4 gap-2">
              <div className="flex flex-row items-center gap-2">
                <button
                  onClick={onRunModelClick}
                  disabled={runButtonDisabled}
                  className={`btn ${runButtonColor}`}
                >
                  {runButtonText}
                </button>
                <button onClick={abortGeneration} className="btn btn-ghost">
                  Abort
                </button>
              </div>
              {!tutorialBanner && (
                <button
                  className="btn btn-circle bg-white border-none"
                  onClick={() => setTutorialBanner(true)}
                >
                  <Idea size={24} />
                </button>
              )}
            </div>
          </div>

          {/* Results */}
          <div className="flex-1 flex flex-col bg-gray-50 rounded-md overflow-x-hidden px-6 py-4 max-h-[calc(100vh-200px)] ">
            <div className="prose">
              <h4>Output</h4>
            </div>
            <div className="mt-6 overflow-scroll flex-1 leading-relaxed markdown">
              {/* Explicit boolean test: `0 && …` / `NaN && …` would make
                  React print the number itself. */}
              {numberInQueue !== undefined && numberInQueue > 0 && (
                <div
                  role="alert"
                  className="p-4 mb-4 text-med rounded-lg bg-purple-50"
                >
                  There are {numberInQueue} other users in the queue for
                  generation.
                </div>
              )}
              <div className="prose leading-snug">
                {mergeTextContent(modelOutput).map(contentToHtml)}
              </div>
              <ImageResult src={partialImage} completed={false} />
            </div>
          </div>
        </div>

        {/* Side panel */}
        <div
          className={`absolute top-0 bottom-11 w-[490px] max-h-[calc(100vh-200px)] rounded-md px-6 py-4 overflow-y-scroll ${
            advancedMode ? "left-0" : "left-[-500px]"
          } bg-gray-100 transition-all`}
        >
          <div className="prose flex flex-row items-center justify-between">
            <h3>Advanced settings</h3>
            <Close
              size={32}
              className="cursor-pointer hover:fill-primary"
              onClick={() => setAdvancedMode(false)}
            />
          </div>
          <InputRange
            value={temp}
            onValueChange={setTemp}
            label="Temperature"
            min={0.01}
            step={0.01}
            max={1}
          />
          <InputRange
            value={topP}
            onValueChange={setTopP}
            label="Top P"
            min={0.01}
            step={0.01}
            max={1}
          />
          <InputRange
            value={maxGenTokens}
            onValueChange={setMaxGenTokens}
            label="Max Gen Tokens"
            integerOnly
            step={1}
            min={1}
            max={4096}
          />
          <InputRange
            value={repetitionPenalty}
            onValueChange={setRepetitionPenalty}
            label="Text Repetition Penalty"
            min={0}
            max={10}
          />
          <InputRange
            value={cfgImageWeight}
            onValueChange={setCfgImageWeight}
            label="CFG Image Weight"
            min={0.01}
            max={10}
          />
          <InputRange
            value={cfgTextWeight}
            onValueChange={setCfgTextWeight}
            label="CFG Text Weight"
            min={0.01}
            max={10}
          />
          <InputToggle
            label="Set seed"
            value={showSeed}
            onValueChange={(checked) => {
              setShowSeed(checked);
            }}
          />
          {showSeed && seed != null && (
            <InputRange
              value={seed}
              step={1}
              integerOnly={true}
              onValueChange={setSeed}
              label="Seed"
              min={1}
              max={1000}
            />
          )}

          {/* Input preview */}
          <InputShowHide
            labelShow="Show input data"
            labelHide="Hide input data"
          >
            <div className="overflow-auto bg-white text-xs font-mono p-4 rounded-md mt-4">
              <JsonView
                src={contents}
                collapsed={({ indexOrName, depth }) =>
                  indexOrName !== "data" && depth > 3
                }
              />
            </div>
          </InputShowHide>
        </div>
      </div>

      {/* Status footer */}
      <div className="absolute bottom-0 left-20 right-20 bg-white flex flex-row items-center gap-4 text-xs h-10">
        <StatusBadge
          label="Connection"
          category={socketStatus}
          status={readableSocketState}
        />
        <StatusBadge
          label="UI"
          category={uiStatus}
          status={generationState}
        />
      </div>
    </div>
  );
}