|
"use client" |
|
|
|
import React, { useEffect, useRef, useState, useTransition } from 'react' |
|
|
|
import { |
|
Card, |
|
CardContent, |
|
CardDescription, |
|
CardFooter, |
|
CardHeader, |
|
CardTitle |
|
} from '@/components/ui/card' |
|
import { Button } from '@/components/ui/button' |
|
import { InputField } from '@/components/form/input-field' |
|
import { SelectField } from '@/components/form/select-field' |
|
import { SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select' |
|
import { Checkbox } from '@/components/ui/checkbox' |
|
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs' |
|
import { SliderField } from '@/components/form/slider-field' |
|
import { Toaster } from '@/components/ui/sonner' |
|
import { cn } from '@/lib/utils' |
|
import { CharacterExpression, CharacterGender, CharacterPose, MIN_AGE, DEFAULT_AGE, MAX_AGE, DEFAULT_WEIGHT, GenerationMode, HairColor, HairStyle, PhotoshootMode, characterExpressions, characterGenders, characterPoses, hairColors, hairStyles, photoshootModes } from '../lib/config' |
|
import { getStableDiffusionParams } from '@/lib/getStableDiffusionParams' |
|
import { getStableCascadeParams } from '@/lib/getStableCascadeParams' |
|
import { stableDiffusion } from './server/stableDiffusion' |
|
import { stableCascade } from './server/stableCascade' |
|
import { generateSeed } from '@/lib/generateSeed' |
|
import { BackgroundRemovalParams, GenerationStatus, UpscalingParams } from '@/types' |
|
import { removeBackground } from './server/background' |
|
import { upscale } from './server/upscale' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Coerces a raw slider value to a number and clamps it to `[min, max]`.
 *
 * Non-finite input (NaN, ±Infinity, anything `Number()` cannot parse) is
 * replaced by 0 before clamping, so the result is always within range.
 */
function clampSliderValue(raw: unknown, min: number, max: number): number {
  const parsed = Number(raw)
  const safe = Number.isFinite(parsed) ? parsed : 0
  return Math.min(max, Math.max(min, safe))
}

/**
 * Main screen of the illustrator: a settings card on the left (character or
 * free-form asset prompt) and the resulting image on the right.
 *
 * Pressing the main button runs `onDraw`, which chains three server calls:
 * `stableCascade` (base image) -> `upscale` -> `removeBackground`. The
 * `status` state tracks the step currently running and drives `isBusy`,
 * which disables the buttons while a generation is in flight.
 */
export function Main() {
  const [_isPending, startTransition] = useTransition()
  const [status, setStatus] = useState<GenerationStatus>("idle" as GenerationStatus)

  const [generationMode, setGenerationMode] = useState<GenerationMode>('characters')

  // character settings ("Characters" tab)
  const [characterName, setCharacterName] = useState('Ragnarr, viking in armor')
  const [characterAge, setCharacterAge] = useState(DEFAULT_AGE)
  const [characterGender, setCharacterGender] = useState<CharacterGender>(
    "male"
  )
  const [characterWeight, setCharacterWeight] = useState(DEFAULT_WEIGHT)
  const [characterHairStyle, setCharacterHairStyle] = useState<HairStyle>(
    "Braids"
  )
  const [characterHairColor, setCharacterHairColor] = useState<HairColor>(
    "chesnut"
  )
  const [characterExpression, setCharacterExpression] = useState<CharacterExpression>("neutral")
  const [characterPose, setCharacterPose] = useState<CharacterPose>("side-on pose")
  const [photoshootMode, setPhotoshootMode] = useState<PhotoshootMode>(photoshootModes[0])

  // free-form settings ("Assets" tab) and generation parameters
  const [prompt, setPrompt] = useState('llamacookie')
  const [negativePrompt, setNegativePrompt] = useState('')
  const [inferenceSteps, setInferenceSteps] = useState(20)
  const [nbPriorInferenceSteps, setNbPriorInferenceSteps] = useState(20)
  const [nbDecoderInferenceSteps, setNbDecoderInferenceSteps] = useState(10)
  const [guidanceScale, setGuidanceScale] = useState(4)
  const [seed, setSeed] = useState('')
  const [strength, setStrength] = useState(0.8)
  const [runVaeOnEachStep, setRunVaeOnEachStep] = useState(false)
  const [maxTokens, setMaxTokens] = useState(0)
  const [totalTokens, setTotalTokens] = useState(0)

  // one URL per pipeline step; only the cropped one is displayed below
  const [baseImageUrl, setBaseImageUrl] = useState("")
  const [upscaledImageUrl, setUpscaledImageUrl] = useState("")
  const [croppedImageUrl, setCroppedImageUrl] = useState("")

  const showAdvancedSettings = true

  // Recomputed on every render from the current form state.
  // NOTE(review): photoshootMode, guidanceScale and inferenceSteps are not
  // forwarded here — confirm whether that is intentional.
  const stableCascadeParams = getStableCascadeParams({
    prompt,
    generationMode,
    negativePrompt,
    characterAge,
    characterExpression,
    characterGender,
    characterHairColor,
    characterHairStyle,
    characterName,
    characterPose
  })

  /**
   * Runs the full pipeline (generate -> upscale -> remove background).
   *
   * Each step logs its failure and leaves its result empty; an empty result
   * sets `status` to "error" and aborts the remaining steps.
   */
  const onDraw = async () => {
    console.log("onRender")

    // Local results are named differently from the state variables so they
    // do not shadow them.
    let generatedUrl = ""
    let upscaledUrl = ""
    let croppedUrl = ""

    // clear the previous results before starting a new run
    setStatus("generating")
    setBaseImageUrl("")
    setUpscaledImageUrl("")
    setCroppedImageUrl("")

    try {
      generatedUrl = await stableCascade(stableCascadeParams)
    } catch (err) {
      console.error(`failed to generate:`, err)
    }

    if (!generatedUrl) {
      setStatus("error")
      return
    }

    setBaseImageUrl(generatedUrl)

    setStatus("upscaling")

    try {
      upscaledUrl = await upscale({
        imageAsBase64: generatedUrl,
        prompt: stableCascadeParams.prompt,
        negativePrompt: stableCascadeParams.negativePrompt,
        scaleFactor: 2,
        seed: generateSeed(),
        nbSteps: 20,
      })
    } catch (err) {
      console.error(`failed to upscale:`, err)
    }

    if (!upscaledUrl) {
      setStatus("error")
      return
    }

    setUpscaledImageUrl(upscaledUrl)

    // `isBusy` already accounts for this status, but it was never set:
    // without it the UI kept reporting "upscaling" during background removal.
    setStatus("cropping")

    try {
      croppedUrl = await removeBackground({
        imageAsBase64: upscaledUrl,
      })
    } catch (err) {
      console.error(`failed to crop the upscaled image:`, err)
    }

    if (!croppedUrl) {
      setStatus("error")
      return
    }

    setCroppedImageUrl(croppedUrl)
    setStatus("finished")
  }

  const isBusy = status === "generating" || status === "cropping" || status === "upscaling"

  console.log("debug:", {
    status,
    isBusy
  })

  return (
    <div className={cn(
      `fixed dark font-mono`,
      // `bg-zinc-800`,
      `bg-gradient-to-r from-stone-800 to-stone-500`,
      `w-screen h-screen overflow-hidden`
    )}>
      <div className="flex flex-col w-full">
        <div className="flex flex-row w-full p-10">
          <div className={cn(
            `flex flex-col w-[512px]`,
            `transition-all duration-300 ease-in-out`,
          )}>
            <Card className="shadow-xl z-30 rounded-xl">
              <CardHeader>
                <div className="flex flex-col justify-start">
                  <div className="flex flex-row items-center font-mono p-3 bg-zinc-900 rounded-lg">
                    <div className="px-1 pt-1 p-0.5 rounded-sm mr-2 bg-yellow-400 text-xl text-zinc-900 font-bold">HFi</div>
                    <div className="text-yellow-400 text-lg">Illustrateur</div>
                  </div>
                  <p className="text-neutral-400 text-xs pt-4">Based on Stable Cascade, so currently for non-commercial usage only.</p>
                </div>
              </CardHeader>
              <CardContent className="flex flex-col">
                <Tabs
                  defaultValue={generationMode}
                  className="w-full"
                  value={generationMode}
                  onValueChange={(tab: any) => setGenerationMode(tab as GenerationMode)}
                >
                  <TabsList className="grid w-full grid-cols-2 mb-4">
                    <TabsTrigger value="characters">Characters</TabsTrigger>
                    <TabsTrigger value="assets">Assets</TabsTrigger>
                  </TabsList>
                  <TabsContent value="characters" className="flex flex-col space-y-4 mt-0">
                    <SelectField
                      label="Photoshoot mode"
                      value={photoshootMode}
                      onValueChange={(value: string) => {
                        setPhotoshootMode(value as PhotoshootMode)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Photoshoot mode">{photoshootMode}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {photoshootModes.map(mode =>
                          <SelectItem
                            key={`${mode || ""}`}
                            value={mode}
                          >{mode}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>

                    <InputField
                      label="Character name"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setCharacterName(e.target.value)}
                      value={characterName}
                    />

                    <SliderField
                      label={`${characterAge} years old`}
                      // disabled={modelState != 'ready'}
                      min={MIN_AGE}
                      max={MAX_AGE}
                      step={1}
                      onValueChange={(value: any) => {
                        setCharacterAge(clampSliderValue(value[0], MIN_AGE, MAX_AGE))
                      }}
                      defaultValue={[characterAge]}
                      value={[characterAge]}
                    />

                    <SelectField
                      label="Gender"
                      value={characterGender}
                      onValueChange={(value: string) => {
                        setCharacterGender(value as CharacterGender)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Gender">{characterGender}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {characterGenders.map(gender =>
                          <SelectItem
                            key={`${gender || ""}`}
                            value={gender}
                          >{gender}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>

                    <SelectField
                      label="Hair style"
                      value={characterHairStyle}
                      onValueChange={(value: string) => {
                        setCharacterHairStyle(value as HairStyle)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Hair style">{characterHairStyle}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {hairStyles.map(hairStyle =>
                          <SelectItem
                            key={`${hairStyle || ""}`}
                            value={hairStyle}
                          >{hairStyle}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>

                    <SelectField
                      label="Hair color"
                      value={characterHairColor}
                      onValueChange={(value: string) => {
                        // was cast to HairStyle, which is the wrong type for this state
                        setCharacterHairColor(value as HairColor)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Hair color">{characterHairColor}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {hairColors.map(hairColor =>
                          <SelectItem
                            key={`${hairColor || ""}`}
                            value={hairColor}
                          >{hairColor}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>

                    <SelectField
                      label="Expression"
                      value={characterExpression}
                      onValueChange={(value: string) => {
                        setCharacterExpression(value as CharacterExpression)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Expression">{characterExpression}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {characterExpressions.map(expression =>
                          <SelectItem
                            key={`${expression || ""}`}
                            value={expression}
                          >{expression}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>

                    <SelectField
                      label="Pose"
                      value={characterPose}
                      onValueChange={(value: string) => {
                        setCharacterPose(value as CharacterPose)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Pose">{characterPose}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {characterPoses.map(pose =>
                          <SelectItem
                            key={`${pose || ""}`}
                            value={pose}
                          >{pose}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <InputField
                      label="Guidance Scale"
                      type='number'
                      min={1}
                      max={15}
                      step={0.1}
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setGuidanceScale(parseFloat(e.target.value))}
                      value={guidanceScale}
                      className={cn({
                        hidden: true, // !showAdvancedSettings
                      })}
                    />
                    <SliderField
                      label={`${inferenceSteps} steps`}
                      // disabled={modelState != 'ready'}
                      min={1}
                      max={10}
                      step={1}
                      onValueChange={(value: any) => {
                        setInferenceSteps(clampSliderValue(value[0], 1, 10))
                      }}
                      defaultValue={[inferenceSteps]}
                      value={[inferenceSteps]}
                      className={cn({
                        hidden: true // !showAdvancedSettings
                      })}
                    />

                    {/*
                    <InputField
                      label="Seed (optional)"
                      type="number"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setSeed(e.target.value)}
                      value={seed}
                    />
                    */}
                  </TabsContent>

                  <TabsContent
                    value="assets"
                    className="flex flex-col space-y-4 mt-0">
                    <InputField
                      label="Prompt"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setPrompt(e.target.value)}
                      value={prompt}
                    />

                    {maxTokens > 0 && (<span style={{ color: totalTokens > maxTokens ? "#EC5578" : "rgba(255, 255, 255, 0.5)", fontSize: "0.8em" }}>{totalTokens}/{maxTokens} Tokens</span>)}
                    <InputField
                      label="Negative Prompt"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setNegativePrompt(e.target.value)}
                      value={negativePrompt}
                      className={cn({
                        hidden: !showAdvancedSettings
                      })}
                    />

                    <SliderField
                      label={`${inferenceSteps} steps`}
                      // disabled={modelState != 'ready'}
                      min={1}
                      max={10}
                      step={1}
                      onValueChange={(value: any) => {
                        // the previous inline clamp had min/max swapped and always produced 1
                        setInferenceSteps(clampSliderValue(value[0], 1, 10))
                      }}
                      defaultValue={[inferenceSteps]}
                      value={[inferenceSteps]}
                      className={cn({
                        hidden: true, // !showAdvancedSettings
                      })}
                    />

                    <InputField
                      label="Guidance Scale"
                      type='number'
                      min={1}
                      max={20}
                      step={0.5}
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setGuidanceScale(parseFloat(e.target.value))}
                      value={guidanceScale}
                      className={cn({
                        hidden: true, // !showAdvancedSettings
                      })}
                    />
                    {/*
                    <InputField
                      label="Seed (will be random by default)"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setSeed(e.target.value)}
                      value={seed}
                    />
                    */}

                  </TabsContent>
                </Tabs>
                <div className="flex-flex-row space-y-3 pt-4">
                  <div className="flex flex-row justify-end items-center">
                    <div className="flex flex-row justify-between items-center space-x-3">
                      {croppedImageUrl && <Button
                        className={cn(
                          `bg-zinc-950 text-zinc-200`,
                          `font-bold text-xs`,
                          `hover:bg-zinc-800 hover:text-zinc-100`,
                          `border border-zinc-200`
                        )}
                        disabled={isBusy}
                        onClick={() => {
                          // TODO: read the doc for Konva on how to download an image
                          // triggers a download through a temporary anchor element
                          const link = document.createElement('a');
                          link.download = 'export.png';
                          link.href = croppedImageUrl
                          link.click();
                        }}>💾 PNG</Button>}

                      <Button
                        onClick={onDraw}
                        disabled={isBusy}
                        // variant="ghost"
                        className={cn(
                          `bg-zinc-50/90 text-zinc-900/80`,
                          `font-bold`,
                          `hover:bg-zinc-50 hover:text-zinc-900/90`,
                        )}
                      >
                        {generationMode === "characters"
                          ? '🦸 Cast'
                          : '🖍️ Draw'
                        }
                      </Button>
                    </div>

                  </div>
                  <div className="text-zinc-400 text-xs w-full text-right">Status: {status}</div>
                </div>
              </CardContent>

            </Card>
          </div>
          <div className={cn(
            `flex flex-col items-center justify-center`,
            `absolute right-0 top-0`,
            `w-[100vh] h-screen`,
            // `transition-all duration-300 ease-in-out`
          )}>
            {/* render only once a result exists: an <img> with src="" triggers
                a React warning and may make the browser re-request the page */}
            {croppedImageUrl && <img src={croppedImageUrl} alt="Generated illustration" />}
          </div>
        </div>
      </div>
      <Toaster />
    </div>
  );
}
|
|