|
|
|
|
|
|
|
/**
 * Host-side driver for the compiled PNDM scheduler kernels.
 *
 * Every per-step constant is uploaded once as a 0-d NDArray on `device` and
 * detached from the current TVM scope, so it stays alive until `dispose()`.
 */
class TVMPNDMScheduler {
  /**
   * @param {Object} schedulerConsts Parsed scheduler constants; the arrays
   *   "timesteps", "sample_coeff", "alpha_diff" and "model_output_denom_coeff"
   *   are read, one entry per denoising step.
   * @param {number[]} latentShape Shape of the model-output history buffers.
   * @param tvm tvmjs runtime instance.
   * @param device Device on which constants and buffers are allocated.
   * @param vm Virtual machine exposing "pndm_scheduler_step_0" ..
   *   "pndm_scheduler_step_4".
   */
  constructor(schedulerConsts, latentShape, tvm, device, vm) {
    this.timestep = [];
    this.sampleCoeff = [];
    this.alphaDiff = [];
    this.modelOutputDenomCoeff = [];
    // Rolling history of the four most recent model outputs, oldest first.
    this.ets = [];
    this.schedulerFunc = [];
    // Sample stashed at step 0 and consumed (then freed) at step 1.
    this.currSample = undefined;
    this.tvm = tvm;

    // Upload each scalar of `input` as its own 0-d NDArray that outlives the
    // current scope.
    const loadConsts = (output, dtype, input) => {
      for (const value of input) {
        output.push(
          tvm.detachFromCurrentScope(
            tvm.empty([], dtype, device).copyFrom([value])
          )
        );
      }
    };
    loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]);
    loadConsts(this.sampleCoeff, "float32", schedulerConsts["sample_coeff"]);
    loadConsts(this.alphaDiff, "float32", schedulerConsts["alpha_diff"]);
    loadConsts(
      this.modelOutputDenomCoeff, "float32",
      schedulerConsts["model_output_denom_coeff"]);

    // Pre-allocate the history slots so every kernel call has four inputs.
    for (let i = 0; i < 4; ++i) {
      this.ets.push(
        tvm.detachFromCurrentScope(
          tvm.empty(latentShape, "float32", device)
        )
      );
    }

    // One compiled kernel per warm-up step (0..3) plus the steady-state one (4).
    for (let i = 0; i < 5; ++i) {
      this.schedulerFunc.push(
        tvm.detachFromCurrentScope(
          vm.getFunction("pndm_scheduler_step_" + i.toString())
        )
      );
    }
  }

  /** Release every NDArray and function handle owned by this scheduler. */
  dispose() {
    // Dispose each table on its own rather than assuming equal lengths.
    const constTables = [
      this.timestep, this.sampleCoeff, this.alphaDiff, this.modelOutputDenomCoeff,
    ];
    for (const table of constTables) {
      for (const value of table) {
        value.dispose();
      }
    }
    for (const func of this.schedulerFunc) {
      func.dispose();
    }
    if (this.currSample) {
      this.currSample.dispose();
      this.currSample = undefined;
    }
    for (const buffer of this.ets) {
      buffer.dispose();
    }
  }

  /**
   * Run one scheduler step.
   *
   * @param modelOutput UNet prediction for this step (owned by the caller's scope).
   * @param sample Current latents.
   * @param {number} counter Zero-based step index.
   * @returns The previous-step latents produced by the compiled kernel.
   */
  step(modelOutput, sample, counter) {
    // The history does not advance at step 1; every other step rotates it.
    if (counter !== 1) {
      // The evicted entry was detached from all scopes, so free it here;
      // otherwise one latent-sized buffer leaks per step.
      const oldest = this.ets.shift();
      oldest.dispose();
      this.ets.push(this.tvm.detachFromCurrentScope(modelOutput));
    }

    let stepSample = sample;
    if (counter === 0) {
      // Keep the step-0 sample alive for reuse at step 1.
      this.currSample = this.tvm.detachFromCurrentScope(sample);
    } else if (counter === 1) {
      // Hand the stashed sample back to the current scope so it gets freed.
      stepSample = this.tvm.attachToCurrentScope(this.currSample);
      this.currSample = undefined;
    }

    const funcIndex = Math.min(counter, 4);
    return this.schedulerFunc[funcIndex](
      stepSample,
      modelOutput,
      this.sampleCoeff[counter],
      this.alphaDiff[counter],
      this.modelOutputDenomCoeff[counter],
      this.ets[0],
      this.ets[1],
      this.ets[2],
      this.ets[3]
    );
  }
}
|
|
|
|
|
|
|
|
|
/**
 * Host-side driver for the compiled DPM-Solver multistep scheduler kernels.
 *
 * Per-step constants are uploaded once as 0-d NDArrays on `device` and
 * detached from the current TVM scope; they live until `dispose()`.
 */
class TVMDPMSolverMultistepScheduler {
  /**
   * @param {Object} schedulerConsts Parsed scheduler constants; the arrays
   *   "timesteps", "alpha", "sigma", "c0", "c1" and "c2" are read, one entry
   *   per denoising step.
   * @param {number[]} latentShape Shape of the stored previous model output.
   * @param tvm tvmjs runtime instance.
   * @param device Device on which constants and buffers are allocated.
   * @param vm Virtual machine exposing the
   *   "dpm_solver_multistep_scheduler_convert_model_output" and
   *   "dpm_solver_multistep_scheduler_step" functions.
   */
  constructor(schedulerConsts, latentShape, tvm, device, vm) {
    this.timestep = [];
    this.alpha = [];
    this.sigma = [];
    this.c0 = [];
    this.c1 = [];
    this.c2 = [];
    // Converted model output of the previous step (detached; owned here).
    this.lastModelOutput = undefined;
    this.convertModelOutputFunc = undefined;
    this.stepFunc = undefined;
    this.tvm = tvm;

    // Upload each scalar of `input` as its own 0-d NDArray that outlives the
    // current scope.
    const loadConsts = (output, dtype, input) => {
      for (const value of input) {
        output.push(
          tvm.detachFromCurrentScope(
            tvm.empty([], dtype, device).copyFrom([value])
          )
        );
      }
    };
    loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]);
    loadConsts(this.alpha, "float32", schedulerConsts["alpha"]);
    loadConsts(this.sigma, "float32", schedulerConsts["sigma"]);
    loadConsts(this.c0, "float32", schedulerConsts["c0"]);
    loadConsts(this.c1, "float32", schedulerConsts["c1"]);
    loadConsts(this.c2, "float32", schedulerConsts["c2"]);

    // Placeholder so the first step has a "previous output" to pass in.
    this.lastModelOutput = tvm.detachFromCurrentScope(
      tvm.empty(latentShape, "float32", device)
    );
    this.convertModelOutputFunc = tvm.detachFromCurrentScope(
      vm.getFunction("dpm_solver_multistep_scheduler_convert_model_output")
    );
    this.stepFunc = tvm.detachFromCurrentScope(
      vm.getFunction("dpm_solver_multistep_scheduler_step")
    );
  }

  /** Release every NDArray and function handle owned by this scheduler. */
  dispose() {
    const constTables = [this.timestep, this.alpha, this.sigma, this.c0, this.c1, this.c2];
    for (const table of constTables) {
      for (const value of table) {
        value.dispose();
      }
    }
    this.lastModelOutput.dispose();
    this.convertModelOutputFunc.dispose();
    this.stepFunc.dispose();
  }

  /**
   * Run one scheduler step.
   *
   * @param modelOutput UNet prediction for this step.
   * @param sample Current latents.
   * @param {number} counter Zero-based step index.
   * @returns The previous-step latents produced by the compiled step kernel.
   */
  step(modelOutput, sample, counter) {
    const convertedOutput = this.convertModelOutputFunc(
      sample, modelOutput, this.alpha[counter], this.sigma[counter]);
    const prevLatents = this.stepFunc(
      sample,
      convertedOutput,
      this.lastModelOutput,
      this.c0[counter],
      this.c1[counter],
      this.c2[counter],
    );
    // The old buffer was detached from every scope, so it has to be freed
    // explicitly before it is replaced; otherwise it leaks once per step.
    this.lastModelOutput.dispose();
    this.lastModelOutput = this.tvm.detachFromCurrentScope(convertedOutput);
    return prevLatents;
  }
}
|
|
|
/**
 * Runs the CLIP -> UNet/scheduler -> VAE pipeline using functions and
 * parameters loaded into a tvmjs virtual machine, and draws the result to the
 * page's "canvas" element.
 */
class StableDiffusionPipeline {
  /**
   * @param tvm tvmjs runtime instance; its WebGPU device is used for all work.
   * @param tokenizer Tokenizer exposing `encode(text, addSpecialTokens).input_ids`.
   * @param {Object[]} schedulerConsts Index 0 holds the DPM-Solver constants,
   *   index 1 the PNDM constants (see `generate`).
   * @param {Object} cacheMetadata NDArray cache metadata; `clipParamSize`,
   *   `unetParamSize` and `vaeParamSize` are read.
   * @throws {Error} If cacheMetadata is missing.
   */
  constructor(tvm, tokenizer, schedulerConsts, cacheMetadata) {
    if (cacheMetadata == null) {
      throw Error("Expect cacheMetadata");
    }
    this.tvm = tvm;
    this.tokenizer = tokenizer;
    this.maxTokenLength = 77;

    this.device = this.tvm.webgpu();
    this.tvm.bindCanvas(document.getElementById("canvas"));

    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );

    this.schedulerConsts = schedulerConsts;
    this.clipToTextEmbeddings = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("clip")
    );
    this.clipParams = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("clip", cacheMetadata.clipParamSize)
    );
    this.unetLatentsToNoisePred = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("unet")
    );
    this.unetParams = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("unet", cacheMetadata.unetParamSize)
    );
    this.vaeToImage = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("vae")
    );
    this.vaeParams = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("vae", cacheMetadata.vaeParamSize)
    );
    this.imageToRGBA = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("image_to_rgba")
    );
    this.concatEmbeddings = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("concat_embeddings")
    );
  }

  /** Release every function handle, parameter set and the VM. */
  dispose() {
    this.concatEmbeddings.dispose();
    this.imageToRGBA.dispose();
    this.vaeParams.dispose();
    this.vaeToImage.dispose();
    this.unetParams.dispose();
    this.unetLatentsToNoisePred.dispose();
    this.clipParams.dispose();
    this.clipToTextEmbeddings.dispose();
    this.vm.dispose();
  }

  /**
   * Tokenize `prompt` into a [1, maxTokenLength] int32 NDArray on the device.
   * Shorter encodings are padded by repeating their last token; longer ones
   * are truncated to maxTokenLength.
   *
   * @param {string} prompt Text to encode.
   * @returns NDArray of token ids (allocated in the current scope).
   */
  tokenize(prompt) {
    const encoded = this.tokenizer.encode(prompt, true).input_ids;
    const inputIDs = new Int32Array(this.maxTokenLength);

    if (encoded.length < this.maxTokenLength) {
      inputIDs.set(encoded);
      const lastTok = encoded[encoded.length - 1];
      inputIDs.fill(lastTok, encoded.length, inputIDs.length);
    } else {
      inputIDs.set(encoded.slice(0, this.maxTokenLength));
    }
    return this.tvm.empty([1, this.maxTokenLength], "int32", this.device).copyFrom(inputIDs);
  }

  /** Asynchronously compile/load the WebGPU pipelines of the VM module. */
  async asyncLoadWebGPUPiplines() {
    await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
  }

  /**
   * Decode `latents` with the VAE, draw the image on the bound canvas and
   * wait for the device to finish. Does not take ownership of `latents`.
   */
  async #renderLatents(latents) {
    this.tvm.withNewScope(() => {
      const image = this.vaeToImage(latents, this.vaeParams);
      this.tvm.showImage(this.imageToRGBA(image));
    });
    await this.device.sync();
  }

  /**
   * Generate an image for `prompt` and draw it on the canvas.
   *
   * @param {string} prompt Text prompt.
   * @param {string} negPrompt Negative prompt ("" by default).
   * @param {Function} progressCallback Optional `(stage, counter, numSteps, totalNumSteps)`.
   * @param {number|string} schedulerId 0 selects DPM-Solver, anything else PNDM.
   *   Strings (e.g. from form inputs) are converted with Number().
   * @param {number|string} vaeCycle Render an intermediate image every
   *   `vaeCycle` UNet steps; -1 disables intermediate rendering.
   * @param {number} beginRenderVae First step index at which intermediate
   *   renders are allowed.
   */
  async generate(
    prompt,
    negPrompt = "",
    progressCallback = undefined,
    schedulerId = 0,
    vaeCycle = -1,
    beginRenderVae = 10
  ) {
    const latentShape = [1, 4, 64, 64];
    const report = (...args) => {
      if (progressCallback !== undefined) {
        progressCallback(...args);
      }
    };

    const useDPMSolver = Number(schedulerId) === 0;
    const activeConsts = this.schedulerConsts[useDPMSolver ? 0 : 1];
    const unetNumSteps = activeConsts["num_steps"];
    // CLIP and VAE each count as one extra step for progress reporting.
    const totalNumSteps = unetNumSteps + 2;
    let renderCycle = Number(vaeCycle);

    let scheduler;
    let embeddings;
    let latents;
    try {
      this.tvm.beginScope();
      try {
        // `scheduler` must be declared: class bodies are strict mode, so an
        // assignment to an undeclared name throws a ReferenceError.
        scheduler = useDPMSolver
          ? new TVMDPMSolverMultistepScheduler(
            activeConsts, latentShape, this.tvm, this.device, this.vm)
          : new TVMPNDMScheduler(
            activeConsts, latentShape, this.tvm, this.device, this.vm);

        report("clip", 0, 1, totalNumSteps);
        embeddings = this.tvm.withNewScope(() => {
          const posInputIDs = this.tokenize(prompt);
          const negInputIDs = this.tokenize(negPrompt);
          const posEmbeddings = this.clipToTextEmbeddings(posInputIDs, this.clipParams);
          const negEmbeddings = this.clipToTextEmbeddings(negInputIDs, this.clipParams);
          // Negative-prompt embeddings go first; the compiled unet presumably
          // relies on this order (classifier-free guidance) — keep it.
          return this.tvm.detachFromCurrentScope(
            this.concatEmbeddings(negEmbeddings, posEmbeddings)
          );
        });

        // U(-a, a) has variance a^2 / 3, so a = sqrt(12) / 2 = sqrt(3) gives
        // unit-variance initial noise.
        const scale = Math.sqrt(12) / 2;
        latents = this.tvm.detachFromCurrentScope(
          this.tvm.uniform(latentShape, -scale, scale, this.device)
        );
      } finally {
        this.tvm.endScope();
      }

      if (renderCycle !== -1) {
        await this.#renderLatents(latents);
      } else {
        renderCycle = unetNumSteps;
      }

      let lastSync;
      for (let counter = 0; counter < unetNumSteps; ++counter) {
        report("unet", counter, unetNumSteps, totalNumSteps);
        const timestep = scheduler.timestep[counter];
        const currLatents = latents;
        latents = this.tvm.withNewScope(() => {
          // Let this scope free the previous latents once the step is done.
          this.tvm.attachToCurrentScope(currLatents);
          const noisePred = this.unetLatentsToNoisePred(
            currLatents, timestep, embeddings, this.unetParams);
          return this.tvm.detachFromCurrentScope(
            scheduler.step(noisePred, currLatents, counter)
          );
        });

        // Keep at most one step in flight: wait for the previous step's sync
        // before issuing a new one.
        if (lastSync !== undefined) {
          await lastSync;
        }
        lastSync = this.device.sync();

        if ((counter + 1) % renderCycle === 0 &&
          counter + 1 !== unetNumSteps &&
          counter >= beginRenderVae) {
          await this.#renderLatents(latents);
        }
      }
      // Do not leave the last sync promise floating.
      if (lastSync !== undefined) {
        await lastSync;
      }
    } finally {
      // Scheduler state and text embeddings are only needed by the UNet loop.
      if (scheduler !== undefined) {
        scheduler.dispose();
      }
      if (embeddings !== undefined) {
        embeddings.dispose();
      }
    }

    report("vae", 0, 1, totalNumSteps);
    await this.#renderLatents(latents);
    latents.dispose();
    report("vae", 1, 1, totalNumSteps);
  }

  /** Clear the bound canvas. */
  clearCanvas() {
    this.tvm.clearCanvas();
  }
}
|
|
|
|
|
|
|
|
|
/**
 * Page-level controller: lazily loads the TVM runtime, the model weights and
 * a StableDiffusionPipeline, then runs generation from the page's inputs.
 */
class StableDiffusionInstance {
  constructor() {
    this.tvm = undefined;
    this.pipeline = undefined;
    this.config = undefined;
    // True while generate() runs; overlapping requests are ignored.
    this.requestInProgress = false;
    this.logger = console.log;
  }

  /**
   * Instantiate the TVM wasm runtime, bind it to a WebGPU device and fetch
   * the NDArray cache. No-op if a runtime is already present.
   *
   * @param {string} wasmUrl URL of the compiled wasm module.
   * @param {string} cacheUrl NDArray cache URL; non-"http" URLs are resolved
   *   against the page URL.
   * @throws {Error} If WebGPU is unsupported or fails to initialize.
   */
  async #asyncInitTVM(wasmUrl, cacheUrl) {
    if (this.tvm !== undefined) {
      return;
    }

    // getElementById returns null (never undefined) for a missing element.
    const logElement = document.getElementById("log");
    if (logElement !== null) {
      this.logger = (message) => {
        console.log(message);
        const entry = document.createElement("div");
        // textContent: log messages (e.g. error strings) are not trusted HTML.
        entry.textContent = message;
        logElement.appendChild(entry);
      };
    }

    const wasmSource = await (await fetch(wasmUrl)).arrayBuffer();
    const tvm = await tvmjs.instantiate(
      new Uint8Array(wasmSource),
      new EmccWASI(),
      this.logger
    );

    const gpuLabel = document.getElementById("gpu-tracker-label");
    let gpuDetectOutput;
    try {
      gpuDetectOutput = await tvmjs.detectGPUDevice();
      if (gpuDetectOutput !== undefined) {
        const { adapterInfo } = gpuDetectOutput;
        const detail = adapterInfo.description.length !== 0
          ? adapterInfo.description
          : adapterInfo.vendor;
        gpuLabel.innerHTML = "Initialize GPU device: " + "WebGPU - " + detail;
        tvm.initWebGPU(gpuDetectOutput.device);
      }
    } catch (err) {
      gpuLabel.innerHTML = (
        "Find an error initializing the WebGPU device " + err.toString()
      );
      console.log(err.stack);
      this.reset();
      throw Error("Find an error initializing WebGPU: " + err.toString());
    }
    // Handled outside the try so this message is not overwritten and
    // re-wrapped by the generic error handler above.
    if (gpuDetectOutput === undefined) {
      gpuLabel.innerHTML = "This browser env do not support WebGPU";
      this.reset();
      throw Error("This browser env do not support WebGPU");
    }

    this.tvm = tvm;
    tvm.registerInitProgressCallback((report) => {
      document.getElementById("progress-tracker-label").innerHTML = report.text;
      document.getElementById("progress-tracker-progress").value = report.progress * 100;
    });
    const resolvedCacheUrl = cacheUrl.startsWith("http")
      ? cacheUrl
      : new URL(cacheUrl, document.URL).href;
    await tvm.fetchNDArrayCache(resolvedCacheUrl, tvm.webgpu());
  }

  /**
   * Fetch scheduler constants and tokenizer, build the pipeline and load its
   * WebGPU pipelines. No-op if a pipeline already exists.
   *
   * @param {string[]} schedulerConstUrl URLs of the scheduler constant JSON files.
   * @param {string} tokenizerName Name passed to tvmjsGlobalEnv.getTokenizer.
   * @throws {Error} If the TVM runtime has not been initialized.
   */
  async #asyncInitPipeline(schedulerConstUrl, tokenizerName) {
    if (this.tvm == undefined) {
      throw Error("asyncInitTVM is not called");
    }
    if (this.pipeline !== undefined) return;
    // Fetch all constant files concurrently; Promise.all preserves order.
    const schedulerConst = await Promise.all(
      schedulerConstUrl.map(async (url) => (await fetch(url)).json())
    );
    const tokenizer = await tvmjsGlobalEnv.getTokenizer(tokenizerName);
    this.pipeline = this.tvm.withNewScope(() => {
      return new StableDiffusionPipeline(this.tvm, tokenizer, schedulerConst, this.tvm.cacheMetadata);
    });
    await this.pipeline.asyncLoadWebGPUPiplines();
  }

  /** Load "stable-diffusion-config.json" once. */
  async #asyncInitConfig() {
    if (this.config !== undefined) return;
    this.config = await (await fetch("stable-diffusion-config.json")).json();
  }

  /**
   * Build a progress callback that updates the page's progress label/bar,
   * measuring elapsed time from the moment it is created.
   */
  #getProgressCallback() {
    const tstart = performance.now();
    return (stage, counter, numSteps, totalNumSteps) => {
      const timeElapsed = (performance.now() - tstart) / 1000;
      let text = "Generating ... at stage " + stage;
      let completed = counter;
      if (stage === "unet") {
        completed = counter + 1;
        text += " step [" + completed + "/" + numSteps + "]";
      }
      if (stage === "vae") {
        completed = totalNumSteps;
      }
      text += ", " + Math.ceil(timeElapsed) + " secs elapsed.";
      document.getElementById("progress-tracker-label").innerHTML = text;
      document.getElementById("progress-tracker-progress").value = (completed / totalNumSteps) * 100;
    };
  }

  /** Initialize config, runtime and pipeline (each step is a no-op if done). */
  async asyncInit() {
    if (this.pipeline !== undefined) return;
    await this.#asyncInitConfig();
    await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
    await this.#asyncInitPipeline(this.config.schedulerConstUrl, this.config.tokenizer);
  }

  /**
   * Register "generate", "clearCanvas" and "showImage" as async RPC server
   * functions on a runtime supplied by the RPC server.
   *
   * @param tvmInstance tvmjs runtime created by the RPC server.
   * @throws {Error} If this instance already has a runtime.
   */
  async asyncInitOnRPCServerLoad(tvmInstance) {
    if (this.tvm !== undefined) {
      throw Error("Cannot reuse a loaded instance for rpc");
    }
    this.tvm = tvmInstance;

    this.tvm.beginScope();
    this.tvm.registerAsyncServerFunc("generate", async (prompt, schedulerId, vaeCycle) => {
      document.getElementById("inputPrompt").value = prompt;
      const negPrompt = "";
      document.getElementById("negativePrompt").value = "";
      // Nothing on the RPC path creates the pipeline, so build it on first use
      // (this.tvm is already set, so the runtime is not re-instantiated).
      if (this.pipeline === undefined) {
        await this.#asyncInitConfig();
        await this.#asyncInitPipeline(this.config.schedulerConstUrl, this.config.tokenizer);
      }
      await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
    });
    this.tvm.registerAsyncServerFunc("clearCanvas", async () => {
      this.tvm.clearCanvas();
    });
    this.tvm.registerAsyncServerFunc("showImage", async (data) => {
      this.tvm.showImage(data);
    });
    this.tvm.endScope();
  }

  /**
   * Run one generation with the values of the page's inputPrompt,
   * negativePrompt, schedulerId and vaeCycle elements. Errors are logged and
   * reset the instance; concurrent calls are ignored.
   */
  async generate() {
    if (this.requestInProgress) {
      this.logger("Request in progress, generate request ignored");
      return;
    }
    this.requestInProgress = true;
    try {
      await this.asyncInit();
      const prompt = document.getElementById("inputPrompt").value;
      const negPrompt = document.getElementById("negativePrompt").value;
      const schedulerId = document.getElementById("schedulerId").value;
      const vaeCycle = document.getElementById("vaeCycle").value;
      await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
    } catch (err) {
      this.logger("Generate error, " + err.toString());
      console.log(err.stack);
      this.reset();
    } finally {
      this.requestInProgress = false;
    }
  }

  /** Drop the runtime and dispose the pipeline so the next call re-initializes. */
  reset() {
    const pipeline = this.pipeline;
    // Clear state first so a throwing dispose cannot leave a half-reset instance.
    this.tvm = undefined;
    this.pipeline = undefined;
    if (pipeline !== undefined) {
      pipeline.dispose();
    }
  }
}
|
|
|
// Page-level singleton used by tvmjsGlobalEnv.asyncOnGenerate. It is assigned
// on globalThis explicitly: a bare assignment to an undeclared name throws a
// ReferenceError in strict-mode code (e.g. when this file is loaded as a module).
globalThis.localStableDiffusionInst = new StableDiffusionInstance();

tvmjsGlobalEnv.asyncOnGenerate = async function () {
  await globalThis.localStableDiffusionInst.generate();
};

// Each RPC server load gets its own instance bound to the server's runtime.
tvmjsGlobalEnv.asyncOnRPCServerLoad = async function (tvm) {
  const inst = new StableDiffusionInstance();
  await inst.asyncInitOnRPCServerLoad(tvm);
};
|
|