// Author: jbilcke-hf (HF staff) — commit c72e989: "clarify the licence"
"use client"
import React, { useEffect, useRef, useState, useTransition } from 'react'
import {
Card,
CardContent,
CardDescription,
CardFooter,
CardHeader,
CardTitle
} from '@/components/ui/card'
import { Button } from '@/components/ui/button'
import { InputField } from '@/components/form/input-field'
import { SelectField } from '@/components/form/select-field'
import { SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'
import { Checkbox } from '@/components/ui/checkbox'
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { SliderField } from '@/components/form/slider-field'
import { Toaster } from '@/components/ui/sonner'
import { cn } from '@/lib/utils'
import { CharacterExpression, CharacterGender, CharacterPose, MIN_AGE, DEFAULT_AGE, MAX_AGE, DEFAULT_WEIGHT, GenerationMode, HairColor, HairStyle, PhotoshootMode, characterExpressions, characterGenders, characterPoses, hairColors, hairStyles, photoshootModes } from '../lib/config'
import { getStableDiffusionParams } from '@/lib/getStableDiffusionParams'
import { getStableCascadeParams } from '@/lib/getStableCascadeParams'
import { stableDiffusion } from './server/stableDiffusion'
import { stableCascade } from './server/stableCascade'
import { generateSeed } from '@/lib/generateSeed'
import { BackgroundRemovalParams, GenerationStatus, UpscalingParams } from '@/types'
import { removeBackground } from './server/background'
import { upscale } from './server/upscale'
/*
Maybe we should use the standard BMI (body mass index, kg/m²)
classification for the character weight instead:
Underweight     < 18.5
Normal weight   18.5–24.9
Overweight      25.0–29.9
Obese           ≥ 30.0
*/
/**
 * Reads the first value reported by a slider and clamps it to `[min, max]`.
 *
 * Non-numeric or non-finite input is treated as `0` before clamping, which
 * is what the inline slider handlers of `Main` used to do.
 *
 * @param values - the array passed to a slider's `onValueChange`
 * @param min - lowest accepted value (inclusive)
 * @param max - highest accepted value (inclusive)
 * @returns the parsed value, clamped to the range
 */
function clampSliderValue(values: ReadonlyArray<unknown>, min: number, max: number): number {
  const parsed = Number(values[0])
  const safe = Number.isFinite(parsed) ? parsed : 0
  return Math.min(max, Math.max(min, safe))
}

/**
 * Main page of the "Illustrateur" app.
 *
 * Renders a settings card (character or asset prompt). On "Cast" / "Draw" it
 * runs a three-step pipeline using the helpers imported from `./server/*`:
 *   1. generate a base image with Stable Cascade,
 *   2. upscale that image (scale factor 2),
 *   3. remove the background of the upscaled image.
 * The background-free result is shown on the right and can be saved as a PNG.
 */
export function Main() {
  const [_isPending, startTransition] = useTransition()
  const [status, setStatus] = useState<GenerationStatus>("idle" as GenerationStatus)
  const [generationMode, setGenerationMode] = useState<GenerationMode>('characters')
  const [characterName, setCharacterName] = useState('Ragnarr, viking in armor')
  const [characterAge, setCharacterAge] = useState(DEFAULT_AGE)
  const [characterGender, setCharacterGender] = useState<CharacterGender>(
    "male" // pick<CharacterGender>(characterGenders)
  )
  const [characterWeight, setCharacterWeight] = useState(DEFAULT_WEIGHT)
  const [characterHairStyle, setCharacterHairStyle] = useState<HairStyle>(
    "Braids" // pick<HairStyle>(hairStyles)
  )
  const [characterHairColor, setCharacterHairColor] = useState<HairColor>(
    "chesnut" // pick<HairColor>(hairColors)
  )
  const [characterExpression, setCharacterExpression] = useState<CharacterExpression>("neutral")
  const [characterPose, setCharacterPose] = useState<CharacterPose>("side-on pose")
  const [photoshootMode, setPhotoshootMode] = useState<PhotoshootMode>(photoshootModes[0])
  const [prompt, setPrompt] = useState('llamacookie')
  const [negativePrompt, setNegativePrompt] = useState('')
  const [inferenceSteps, setInferenceSteps] = useState(20)
  const [nbPriorInferenceSteps, setNbPriorInferenceSteps] = useState(20)
  const [nbDecoderInferenceSteps, setNbDecoderInferenceSteps] = useState(10)
  const [guidanceScale, setGuidanceScale] = useState(4)
  const [seed, setSeed] = useState('')
  const [strength, setStrength] = useState(0.8)
  const [runVaeOnEachStep, setRunVaeOnEachStep] = useState(false)
  const [maxTokens, setMaxTokens] = useState(0)
  const [totalTokens, setTotalTokens] = useState(0)

  // the base image, uncropped
  const [baseImageUrl, setBaseImageUrl] = useState("")
  const [upscaledImageUrl, setUpscaledImageUrl] = useState("")
  const [croppedImageUrl, setCroppedImageUrl] = useState("")

  const showAdvancedSettings = true

  // Recomputed on every render so that onDraw always uses the current form values.
  const stableCascadeParams = getStableCascadeParams({
    prompt,
    generationMode,
    negativePrompt,
    characterAge,
    characterExpression,
    characterGender,
    characterHairColor,
    characterHairStyle,
    characterName,
    characterPose
  })

  /**
   * Runs generate -> upscale -> remove background.
   * On any step failing (thrown error or empty result) the status becomes
   * "error" and the pipeline stops; otherwise it ends in "finished".
   */
  const onDraw = async () => {
    console.log("onDraw")

    // Local copies of each step's output. State setters do not update the
    // state variables synchronously, so later steps read these instead.
    let nextBaseImageUrl = ""
    let nextUpscaledImageUrl = ""
    let nextCroppedImageUrl = ""

    // Clear any previous result before starting a new run.
    setStatus("generating")
    setBaseImageUrl(nextBaseImageUrl)
    setUpscaledImageUrl(nextUpscaledImageUrl)
    setCroppedImageUrl(nextCroppedImageUrl)

    try {
      // nextBaseImageUrl = await stableDiffusion(stableDiffusionParams)
      nextBaseImageUrl = await stableCascade(stableCascadeParams)
    } catch (err) {
      console.error(`failed to generate:`, err)
    }
    if (!nextBaseImageUrl) {
      setStatus("error")
      return
    }
    setBaseImageUrl(nextBaseImageUrl)

    /*
    try {
      nextCroppedImageUrl = await removeBackground({
        imageAsBase64: nextBaseImageUrl,
      })
    } catch (err) {
      console.error(`failed to crop the low-resolution image:`, err)
    }
    if (!nextCroppedImageUrl) {
      setStatus("error")
      return
    }
    setCroppedImageUrl(nextCroppedImageUrl)
    */

    setStatus("upscaling")
    try {
      nextUpscaledImageUrl = await upscale({
        imageAsBase64: nextBaseImageUrl,
        prompt: stableCascadeParams.prompt,
        negativePrompt: stableCascadeParams.negativePrompt,
        scaleFactor: 2,
        seed: generateSeed(),
        // for a single image we can afford a higher rate, such as 25
        nbSteps: 20,
      })
    } catch (err) {
      console.error(`failed to upscale:`, err)
    }
    if (!nextUpscaledImageUrl) {
      setStatus("error")
      return
    }
    setUpscaledImageUrl(nextUpscaledImageUrl)

    // Report the background-removal step accurately instead of leaving the
    // status on "upscaling" while it runs.
    setStatus("cropping")
    try {
      nextCroppedImageUrl = await removeBackground({
        imageAsBase64: nextUpscaledImageUrl,
      })
    } catch (err) {
      console.error(`failed to crop the upscaled image:`, err)
    }
    if (!nextCroppedImageUrl) {
      setStatus("error")
      return
    }
    setCroppedImageUrl(nextCroppedImageUrl)

    setStatus("finished")
  }

  const isBusy = status === "generating" || status === "cropping" || status === "upscaling"

  return (
    <div className={cn(
      `fixed dark font-mono`,
      // `bg-zinc-800`,
      `bg-gradient-to-r from-stone-800 to-stone-500`,
      `w-screen h-screen overflow-hidden`
    )}>
      <div className="flex flex-col w-full">
        <div className="flex flex-row w-full p-10">
          <div className={cn(
            `flex flex-col w-[512px]`,
            `transition-all duration-300 ease-in-out`,
          )}>
            <Card className="shadow-xl z-30 rounded-xl">
              <CardHeader>
                <div className="flex flex-col justify-start">
                  <div className="flex flex-row items-center font-mono p-3 bg-zinc-900 rounded-lg">
                    <div className="px-1 pt-1 p-0.5 rounded-sm mr-2 bg-yellow-400 text-xl text-zinc-900 font-bold">HFi</div>
                    <div className="text-yellow-400 text-lg">Illustrateur</div>
                  </div>
                  <p className="text-neutral-400 text-xs pt-4">Based on Stable Cascade, so currently for non-commercial usage only.</p>
                </div>
              </CardHeader>
              <CardContent className="flex flex-col">
                <Tabs
                  defaultValue={generationMode}
                  className="w-full"
                  value={generationMode}
                  onValueChange={(tab: string) => setGenerationMode(tab as GenerationMode)}
                >
                  <TabsList className="grid w-full grid-cols-2 mb-4">
                    <TabsTrigger value="characters">Characters</TabsTrigger>
                    <TabsTrigger value="assets">Assets</TabsTrigger>
                  </TabsList>
                  <TabsContent value="characters" className="flex flex-col space-y-4 mt-0">
                    <SelectField
                      label="Photoshoot mode"
                      value={photoshootMode}
                      onValueChange={(value: string) => {
                        setPhotoshootMode(value as PhotoshootMode)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Photoshoot mode">{photoshootMode}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {photoshootModes.map(mode =>
                          <SelectItem
                            key={`${mode || ""}`}
                            value={mode}
                          >{mode}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <InputField
                      label="Character name"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setCharacterName(e.target.value)}
                      value={characterName}
                    />
                    <SliderField
                      label={`${characterAge} years old`}
                      // disabled={modelState != 'ready'}
                      min={MIN_AGE}
                      max={MAX_AGE}
                      step={1}
                      onValueChange={(value: number[]) => {
                        setCharacterAge(clampSliderValue(value, MIN_AGE, MAX_AGE))
                      }}
                      defaultValue={[characterAge]}
                      value={[characterAge]}
                    />
                    <SelectField
                      label="Gender"
                      value={characterGender}
                      onValueChange={(value: string) => {
                        setCharacterGender(value as CharacterGender)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Gender">{characterGender}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {characterGenders.map(gender =>
                          <SelectItem
                            key={`${gender || ""}`}
                            value={gender}
                          >{gender}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <SelectField
                      label="Hair style"
                      value={characterHairStyle}
                      onValueChange={(value: string) => {
                        setCharacterHairStyle(value as HairStyle)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Hair style">{characterHairStyle}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {hairStyles.map(hairStyle =>
                          <SelectItem
                            key={`${hairStyle || ""}`}
                            value={hairStyle}
                          >{hairStyle}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <SelectField
                      label="Hair color"
                      value={characterHairColor}
                      onValueChange={(value: string) => {
                        // a hair *color*, not a hair style
                        setCharacterHairColor(value as HairColor)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Hair color">{characterHairColor}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {hairColors.map(hairColor =>
                          <SelectItem
                            key={`${hairColor || ""}`}
                            value={hairColor}
                          >{hairColor}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <SelectField
                      label="Expression"
                      value={characterExpression}
                      onValueChange={(value: string) => {
                        setCharacterExpression(value as CharacterExpression)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Expression">{characterExpression}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {characterExpressions.map(expression =>
                          <SelectItem
                            key={`${expression || ""}`}
                            value={expression}
                          >{expression}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <SelectField
                      label="Pose"
                      value={characterPose}
                      onValueChange={(value: string) => {
                        setCharacterPose(value as CharacterPose)
                      }}
                    >
                      <SelectTrigger className="text-xs">
                        <SelectValue placeholder="Pose">{characterPose}</SelectValue>
                      </SelectTrigger>
                      <SelectContent>
                        {characterPoses.map(pose =>
                          <SelectItem
                            key={`${pose || ""}`}
                            value={pose}
                          >{pose}</SelectItem>
                        )}
                      </SelectContent>
                    </SelectField>
                    <InputField
                      label="Guidance Scale"
                      type='number'
                      min={1}
                      max={15}
                      step={0.1}
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setGuidanceScale(parseFloat(e.target.value))}
                      value={guidanceScale}
                      className={cn({
                        hidden: true, // !showAdvancedSettings
                      })}
                    />
                    <SliderField
                      label={`${inferenceSteps} steps`}
                      // disabled={modelState != 'ready'}
                      min={1}
                      max={10}
                      step={1}
                      onValueChange={(value: number[]) => {
                        setInferenceSteps(clampSliderValue(value, 1, 10))
                      }}
                      defaultValue={[inferenceSteps]}
                      value={[inferenceSteps]}
                      className={cn({
                        hidden: true // !showAdvancedSettings
                      })}
                    />
                    {/*
                    <InputField
                      label="Seed (optional)"
                      type="number"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setSeed(e.target.value)}
                      value={seed}
                    />
                    */}
                  </TabsContent>
                  <TabsContent
                    value="assets"
                    className="flex flex-col space-y-4 mt-0">
                    <InputField
                      label="Prompt"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setPrompt(e.target.value)}
                      value={prompt}
                    />
                    {maxTokens > 0 && (<span style={{ color: totalTokens > maxTokens ? "#EC5578" : "rgba(255, 255, 255, 0.5)", fontSize: "0.8em" }}>{totalTokens}/{maxTokens} Tokens</span>)}
                    <InputField
                      label="Negative Prompt"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setNegativePrompt(e.target.value)}
                      value={negativePrompt}
                      className={cn({
                        hidden: !showAdvancedSettings
                      })}
                    />
                    <SliderField
                      label={`${inferenceSteps} steps`}
                      // disabled={modelState != 'ready'}
                      min={1}
                      max={10}
                      step={1}
                      onValueChange={(value: number[]) => {
                        // previously clamped with min/max swapped, which always gave 1
                        setInferenceSteps(clampSliderValue(value, 1, 10))
                      }}
                      defaultValue={[inferenceSteps]}
                      value={[inferenceSteps]}
                      className={cn({
                        hidden: true, // !showAdvancedSettings
                      })}
                    />
                    <InputField
                      label="Guidance Scale"
                      type='number'
                      min={1}
                      max={20}
                      step={0.5}
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setGuidanceScale(parseFloat(e.target.value))}
                      value={guidanceScale}
                      className={cn({
                        hidden: true, // !showAdvancedSettings
                      })}
                    />
                    {/*
                    <InputField
                      label="Seed (will be random by default)"
                      // disabled={modelState != 'ready'}
                      onChange={(e) => setSeed(e.target.value)}
                      value={seed}
                    />
                    */}
                  </TabsContent>
                </Tabs>
                <div className="flex flex-col space-y-3 pt-4">
                  <div className="flex flex-row justify-end items-center">
                    <div className="flex flex-row justify-between items-center space-x-3">
                      {croppedImageUrl && <Button
                        className={cn(
                          `bg-zinc-950 text-zinc-200`,
                          `font-bold text-xs`,
                          `hover:bg-zinc-800 hover:text-zinc-100`,
                          `border border-zinc-200`
                        )}
                        disabled={isBusy}
                        onClick={() => {
                          // Trigger a browser download of the final image through a
                          // temporary <a download> element.
                          // NOTE(review): croppedImageUrl is presumably a data URL
                          // (it is passed around as "imageAsBase64") — confirm.
                          const link = document.createElement('a')
                          link.download = 'export.png'
                          link.href = croppedImageUrl
                          link.click()
                        }}>💾 PNG</Button>}
                      <Button
                        onClick={onDraw}
                        disabled={isBusy}
                        // variant="ghost"
                        className={cn(
                          `bg-zinc-50/90 text-zinc-900/80`,
                          `font-bold`,
                          `hover:bg-zinc-50 hover:text-zinc-900/90`,
                        )}
                      >
                        {generationMode === "characters"
                          ? '🦸 Cast'
                          : '🖍️ Draw'
                        }
                      </Button>
                    </div>
                  </div>
                  <div className="text-zinc-400 text-xs w-full text-right">Status: {status}</div>
                </div>
              </CardContent>
            </Card>
          </div>
          <div className={cn(
            `flex flex-col items-center justify-center`,
            `absolute right-0 top-0`,
            `w-[100vh] h-screen`,
            // `transition-all duration-300 ease-in-out`
          )}>
            {/* only render once there is a result: an empty src is invalid */}
            {croppedImageUrl && (
              <img
                src={croppedImageUrl}
                alt="Generated illustration"
              />
            )}
          </div>
        </div>
      </div>
      <Toaster />
    </div>
  )
}