<!DOCTYPE html>
<html>
<head>
<title>Stable Diffusion Turbo</title>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous" />
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
integrity="sha384-HwwvtgBNo3bZJJLYd8oVXjrBZt8cqVSpeBNS5n7C8IVInixGAoxmnlMuBnhbgrkm" crossorigin="anonymous">
</script>
<script type="module">
import { env, AutoTokenizer } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/transformers.js';
window.AutoTokenizer = AutoTokenizer;
env.localModelPath = 'models';
</script>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js">
</script>
<style>
.onerow {
display: flex;
}
</style>
</head>
<body data-bs-theme="dark">
<div class="container">
<div class="row pt-3">
<div class="col-md-9 col-12">
<h2>Stable Diffusion Turbo</h2>
</div>
</div>
<div class="container p-2 card" id="input-area">
<div class="input-group">
<textarea class="form-control" id="user-input" placeholder="Type your prompt here..."></textarea>
<button id="send-button" class="btn btn-primary">Send</button>
</div>
</div>
<!--<div id="image_area">
<div class="onerow">
<div id="img_div_0" style="margin-right: 4px;">
<canvas id="img_canvas_0"></canvas>
</div>
<div id="img_div_1" style="margin-right: 4px;">
<canvas id="img_canvas_1"></canvas>
</div>
<div id="img_div_2" style="margin-right: 4px;">
<canvas id="img_canvas_2"></canvas>
</div>
<div id="img_div_3" style="margin-right: 4px;">
<canvas id="img_canvas_3"></canvas>
</div>
</div>
</div>-->
<div class="right">
<canvas class='canvas' id='canvas' width='512' height='512'>
Canvas is not supported. You'll need to try a newer browser version or another browser.
</canvas>
</div>
<div class="text-lg-start">
<div id="status" style="font: 1em consolas;"></div>
</div>
</div>
<script>
var deviceWebgpu = null;
var queueWebgpu = null;
var textEncoderOutputsBuffer = null;
var textEncoderOutputsTensor = null;
var textEncoderOutputs = {};
var latentData = null;
var latentBuffer = null;
var unetSampleInputsBuffer = null;
var unetSampleInputsTensor = null;
var unetOutSampleBuffer = null;
var unetOutSampleTensor = null;
var prescaleLatentSpacePipeline = null;
var prescaleLatentSpaceBindGroup = null;
var stepLatentSpacePipeline = null;
var stepLatentSpaceBindGroup = null;
var decodedOutputsBuffer = null;
var decodedOutputsTensor = null;
const pixelHeight = 512;
const pixelWidth = 512;
var renderContext = null;
var renderPipeline = null;
var renderBindGroup = null;
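// Compute shader: scales the initial noise latents by 1 / sqrt(sigma^2 + 1)
// (the Euler scheduler's input scaling; sqrt(14.6146^2 + 1) ~= 14.64877) to
// produce the unet "sample" input. Each invocation handles one vec4<f32>.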
const PRESCALE_LATENT_SPACE_SHADER = `
@binding(0) @group(0) var<storage, read_write> result: array<vec4<f32>>;
@binding(1) @group(0) var<storage, read> latentData: array<vec4<f32>>;
@compute @workgroup_size(128, 1, 1)
fn _start(@builtin(global_invocation_id) GlobalId : vec3<u32>) {
let index = GlobalId.x;
let value = latentData[index] / 14.64877241136608; // sqrt(sigma^2 + 1) with sigma = 14.6146
result[index] = value;
}
`;
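// Compute shader: applies the single Euler step for sd-turbo in place (the same
// math as the JS step() helper below) and divides by the VAE scaling factor so
// the result can be fed straight into the vae_decoder.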
const STEP_LATENT_SPACE_SHADER = `
@binding(0) @group(0) var<storage, read_write> result: array<vec4<f32>>;
@binding(1) @group(0) var<storage, read> latentData: array<vec4<f32>>;
@compute @workgroup_size(128, 1, 1)
fn _start(@builtin(global_invocation_id) GlobalId : vec3<u32>) {
let index = GlobalId.x;
let sigma_hat = 14.6146;
let latentVal = latentData[index];
let outputSampleVal = result[index];
let pred_original_sample = latentVal - sigma_hat * outputSampleVal;
let derivative = (latentVal - pred_original_sample) / sigma_hat;
let dt = -sigma_hat;
result[index] = (latentVal + derivative * dt) / 0.18215; // 0.18215 = VAE scaling factor
}
`;
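// Vertex shader: emits a full-screen quad as two triangles with UVs in [0, 1].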
const VERTEX_SHADER = `
struct VertexOutput {
@builtin(position) Position : vec4<f32>,
@location(0) fragUV : vec2<f32>,
}
@vertex
fn main(@builtin(vertex_index) VertexIndex : u32) -> VertexOutput {
var pos = array<vec2<f32>, 6>(
vec2<f32>( 1.0, 1.0),
vec2<f32>( 1.0, -1.0),
vec2<f32>(-1.0, -1.0),
vec2<f32>( 1.0, 1.0),
vec2<f32>(-1.0, -1.0),
vec2<f32>(-1.0, 1.0)
);
var uv = array<vec2<f32>, 6>(
vec2<f32>(1.0, 0.0),
vec2<f32>(1.0, 1.0),
vec2<f32>(0.0, 1.0),
vec2<f32>(1.0, 0.0),
vec2<f32>(0.0, 1.0),
vec2<f32>(0.0, 0.0)
);
var output : VertexOutput;
output.Position = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
output.fragUV = uv[VertexIndex];
return output;
}
`;
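// Fragment shader: reads the vae_decoder output, a planar CHW float32 buffer
// with 512 * 512 = 262144 values per channel, and maps [-1, 1] to [0, 1] RGBA.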
const PIXEL_SHADER = `
@group(0) @binding(1) var<storage, read> buf : array<f32>;
@fragment
fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
// Shift the sample x-coordinate by half the width, swapping the left and
// right halves of the image (wrap-around in x).
var coord = vec2(0.0, 0.0);
if (fragUV.x < 0.5) {
coord = vec2(fragUV.x + 0.5, fragUV.y);
} else {
coord = vec2(fragUV.x - 0.5, fragUV.y);
}
let redInputOffset = 0;
let greenInputOffset = 262144; // 512 * 512
let blueInputOffset = 524288; // 2 * 512 * 512
let index = i32(coord.x * f32(512)) + i32(coord.y * f32(512) * f32(512)); // pixelWidth = pixelHeight = 512
// Map the VAE output range [-1, 1] to [0, 1] and clamp.
let r = clamp(buf[redInputOffset + index] / 2 + 0.5, 0.0, 1.0);
let g = clamp(buf[greenInputOffset + index] / 2 + 0.5, 0.0, 1.0);
let b = clamp(buf[blueInputOffset + index] / 2 + 0.5, 0.0, 1.0);
let a = 1.0;
var out_color = vec4<f32>(r, g, b, a);
return out_color;
}
`
function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }
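// Reads overrides from the query string, e.g. index.html?provider=webnn&device=gpu.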
function getConfig() {
const query = window.location.search.substring(1);
var config = {
// model: "models/onnx-sd-turbo-fp16",
//model: "https://huggingface.co/schmuell/sd-turbo-ort-web/resolve/main",
model: "models",
provider: "webgpu",
device: "gpu",
threads: "1",
images: "1",
};
let vars = query.split("&");
for (var i = 0; i < vars.length; i++) {
let pair = vars[i].split("=");
if (pair[0] in config) {
config[pair[0]] = decodeURIComponent(pair[1]);
} else if (pair[0].length > 0) {
throw new Error("unknown argument: " + pair[0]);
}
}
config.threads = parseInt(config.threads);
config.images = parseInt(config.images);
return config;
}
function randn_latents(shape, noise_sigma) {
function randn() {
// Use the Box-Muller transform
let u = 1 - Math.random(); // Math.random() is in [0, 1); avoid log(0)
let v = Math.random();
let z = Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v);
return z;
}
let size = 1;
shape.forEach(element => {
size *= element;
});
let data = new Float32Array(size);
// Fill every element with scaled Gaussian noise
for (let i = 0; i < size; i++) {
data[i] = randn() * noise_sigma;
}
return data;
}
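// Fetches a model file through the Cache Storage API (cache name "onnx") so
// repeat visits avoid the network; falls back to a plain fetch on cache errors.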
async function fetchAndCache(base_url, model_path) {
const url = `${base_url}/${model_path}`;
try {
const cache = await caches.open("onnx");
let cachedResponse = await cache.match(url);
if (cachedResponse == undefined) {
await cache.add(url);
cachedResponse = await cache.match(url);
log(`${model_path} (network)`);
} else {
log(`${model_path} (cached)`);
}
const data = await cachedResponse.arrayBuffer();
return data;
} catch (error) {
log(`${model_path} (network)`);
return await fetch(url).then(response => response.arrayBuffer());
}
}
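// Copies a typed array into a GPU storage buffer via a mapped staging buffer.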
function uploadToGPU(buffer, values, type) {
const stagingBuffer = deviceWebgpu.createBuffer({
usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC,
size: values.buffer.byteLength,
mappedAtCreation: true
});
const arrayBuffer = stagingBuffer.getMappedRange();
if (type === 'float32') {
new Float32Array(arrayBuffer).set(values);
} else if (type === 'int32') {
new Int32Array(arrayBuffer).set(values);
}
stagingBuffer.unmap();
const encoder = deviceWebgpu.createCommandEncoder();
encoder.copyBufferToBuffer(stagingBuffer, 0, buffer, 0, values.byteLength);
deviceWebgpu.queue.submit([encoder.finish()]);
stagingBuffer.destroy();
}
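// Dispatches 32 workgroups of 128 threads = 4096 invocations, one vec4<f32>
// (4 floats) each, which covers the 1x4x64x64 latent exactly.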
function submitComputeTask(pipeline, bindGroup) {
let commandEncoderWebgpu = deviceWebgpu.createCommandEncoder();
let computePassEncoder = commandEncoderWebgpu.beginComputePass();
computePassEncoder.setPipeline(pipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(32, 1, 1);
computePassEncoder.end();
computePassEncoder = null;
queueWebgpu.submit([commandEncoderWebgpu.finish()]);
}
async function load_models(models) {
const cache = await caches.open("onnx");
let missing = 0;
for (const [name, model] of Object.entries(models)) {
const url = `${config.model}/${model.url}`;
let cachedResponse = await cache.match(url);
if (cachedResponse === undefined) {
missing += model.size;
}
}
if (missing > 0) {
log(`downloading ${missing} MB from network ... it might take a while`);
} else {
log("loading...");
}
let loadedCount = 0;
for (const [name, model] of Object.entries(models)) {
try {
if (loadedCount === 1) {
// ort.env.webgpu.device is only available once the first session exists.
webgpuResourceInitialize();
}
const start = performance.now();
const model_bytes = await fetchAndCache(config.model, model.url);
const sess_opt = { ...opt, ...model.opt };
// profiling
//ort.env.webgpu.profiling = { mode: "default" };
models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
const stop = performance.now();
loadedCount++;
log(`${model.url} in ${(stop - start).toFixed(1)}ms`);
} catch (e) {
log(`${model.url} failed, ${e}`);
}
}
const latent_shape = [1, 4, 64, 64];
latentData = randn_latents(latent_shape, sigma);
uploadToGPU(latentBuffer, latentData, "float32");
submitComputeTask(prescaleLatentSpacePipeline, prescaleLatentSpaceBindGroup);
log("ready.");
}
const config = getConfig();
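// size is the approximate download size in MB, used for the progress message.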
const models = {
"unet": {
url: "unet/model.onnx", size: 640,
// should have 'steps: 1' but will fail to create the session
opt: { freeDimensionOverrides: { batch_size: 1, num_channels: 4, height: 64, width: 64, sequence_length: 77, } }
},
"text_encoder": {
url: "text_encoder/model.onnx", size: 1700,
// should have 'sequence_length: 77' but produces a bad image
opt: { freeDimensionOverrides: { batch_size: 1, } },
},
"vae_decoder": {
url: "vae_decoder/model.onnx", size: 95,
opt: { freeDimensionOverrides: { batch_size: 1, num_channels_latent: 4, height_latent: 64, width_latent: 64 } }
}
}
const text = document.getElementById("user-input");
let tokenizer;
let loading;
const sigma = 14.6146; // initial noise sigma of the one-step schedule
const gamma = 0;
const vae_scaling_factor = 0.18215;
// text.value = "A cinematic shot of a baby racoon wearing an intricate italian priest robe.";
text.value = "Paris with the river in the background";
if (config.provider == "webgpu") {
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = true;
} else {
ort.env.wasm.numThreads = config.threads;
ort.env.wasm.simd = true;
}
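// Session options shared by all three models; the extra flags are intended to
// keep initializers in device memory and use the model bytes in place.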
const opt = {
executionProviders: [config.provider],
enableMemPattern: false,
enableCpuMemArena: false,
extra: {
session: {
disable_prepacking: "1",
use_device_allocator_for_initializers: "1",
use_ort_model_bytes_directly: "1",
use_ort_model_bytes_for_initializers: "1"
}
},
};
switch (config.provider) {
case "webgpu":
if (!("gpu" in navigator)) {
throw new Error("webgpu is NOT supported");
}
opt.preferredOutputLocation = { last_hidden_state: "gpu-buffer" };
break;
case "webnn":
if (!("ml" in navigator)) {
throw new Error("webnn is NOT supported");
}
opt.executionProviders = [{
name: "webnn",
deviceType: config.device,
powerPreference: 'default'
}];
break;
}
// Generate an image, then pre-seed fresh latents for the next run.
function generateImage() {
run();
const latent_shape = [1, 4, 64, 64];
latentData = randn_latents(latent_shape, sigma);
uploadToGPU(latentBuffer, latentData, "float32");
submitComputeTask(prescaleLatentSpacePipeline, prescaleLatentSpaceBindGroup);
}
// Event listener for Ctrl + Enter or Cmd + Enter
document.getElementById('user-input').addEventListener('keydown', function (e) {
if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
generateImage();
}
});
document.getElementById('send-button').addEventListener('click', function (e) {
generateImage();
});
function init_latents(t) {
const d = t.data;
for (let i = 0; i < d.length; i++) {
d[i] = d[i] * sigma;
}
return t;
}
function scale_model_inputs(t) {
const d_i = t.data;
const d_o = new Float32Array(d_i.length);
const divi = (sigma ** 2 + 1) ** 0.5;
for (let i = 0; i < d_i.length; i++) {
d_o[i] = d_i[i] / divi;
}
return new ort.Tensor(d_o, t.dims);
}
function step(model_output, sample) {
// Poor man's EulerA. Since this is just an example for sd-turbo, only implement
// the absolute minimum needed to create an image.
const d_o = new Float32Array(model_output.data.length);
const prev_sample = new ort.Tensor(d_o, model_output.dims);
const sigma_hat = sigma * (gamma + 1);
for (let i = 0; i < model_output.data.length; i++) {
const pred_original_sample = sample.data[i] - sigma_hat * model_output.data[i];
const derivative = (sample.data[i] - pred_original_sample) / sigma_hat;
const dt = 0 - sigma_hat;
d_o[i] = (sample.data[i] + derivative * dt) / vae_scaling_factor;
}
return prev_sample;
}
function draw_image(t, image_nr) {
let pix = t.data;
for (var i = 0; i < pix.length; i++) {
let x = pix[i];
// Map the VAE output range [-1, 1] to [0, 1] and clamp.
x = x / 2 + 0.5;
if (x < 0.) x = 0.;
if (x > 1.) x = 1.;
pix[i] = x;
}
const imageData = t.toImageData({ tensorLayout: 'NCWH', format: 'RGB' });
const canvas = document.getElementById(`img_canvas_${image_nr}`);
canvas.width = imageData.width;
canvas.height = imageData.height;
canvas.getContext('2d').putImageData(imageData, 0, 0);
const div = document.getElementById(`img_div_${image_nr}`);
div.style.opacity = 1.
}
async function downloadToCPU(buffer) {
const stagingBuffer = deviceWebgpu.createBuffer({
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
size: buffer.size
});
const encoder = deviceWebgpu.createCommandEncoder();
encoder.copyBufferToBuffer(buffer, 0, stagingBuffer, 0, buffer.size);
deviceWebgpu.queue.submit([encoder.finish()]);
await stagingBuffer.mapAsync(GPUMapMode.READ);
const arrayBuffer = stagingBuffer.getMappedRange().slice(0, buffer.size); // slice() takes byte offsets; copy the whole buffer before destroy
stagingBuffer.destroy();
return new Float32Array(arrayBuffer);
};
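// One generation: tokenize, run text_encoder, unet (single step at timestep 999),
// the Euler-step compute shader, vae_decoder, then render to the canvas. All
// intermediate tensors stay in GPU buffers (IO binding).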
async function run() {
try {
document.getElementById('status').innerText = "generating ...";
if (tokenizer === undefined) {
tokenizer = await AutoTokenizer.from_pretrained('tokenizer', { quantized: false, local_files_only: false });
tokenizer.pad_token_id = 0;
}
let canvases = [];
await loading;
const { input_ids } = await tokenizer(text.value, { padding: true, max_length: 77, truncation: true, return_tensor: false });
// text-encoder
let start = performance.now();
const executionStart = performance.now();
textEncoderOutputs['last_hidden_state'] = textEncoderOutputsTensor;
await models.text_encoder.sess.run({ "input_ids": new ort.Tensor("int32", input_ids, [1, input_ids.length]) }, textEncoderOutputs);
let perf_info = [`text_encoder: ${(performance.now() - start).toFixed(1)}ms`];
for (let j = 0; j < config.images; j++) {
start = performance.now();
let feed = {
"sample": unetSampleInputsTensor,
"timestep": new ort.Tensor("int64", [999n], [1]),
"encoder_hidden_states": textEncoderOutputsTensor,
};
var unetOutSampleOutputs = {};
unetOutSampleOutputs['out_sample'] = unetOutSampleTensor;
let { out_sample } = await models.unet.sess.run(feed, unetOutSampleOutputs);
perf_info.push(`unet: ${(performance.now() - start).toFixed(1)}ms`);
submitComputeTask(stepLatentSpacePipeline, stepLatentSpaceBindGroup);
start = performance.now();
var vaeDecodeInputs = {};
vaeDecodeInputs['latent_sample'] = unetOutSampleTensor;
const decodedOutputs = {};
decodedOutputs['sample'] = decodedOutputsTensor;
await models.vae_decoder.sess.run(vaeDecodeInputs, decodedOutputs);
// profiling
// ort.env.webgpu.profiling = { mode: "" };
const commandEncoder = deviceWebgpu.createCommandEncoder();
const textureView = renderContext.getCurrentTexture().createView();
const renderPassDescriptor = {
colorAttachments: [
{
view: textureView,
clearValue: { r: 1.0, g: 0.0, b: 0.0, a: 1.0 },
loadOp: 'clear',
storeOp: 'store',
},
],
};
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
passEncoder.setPipeline(renderPipeline);
passEncoder.setBindGroup(0, renderBindGroup);
passEncoder.draw(6, 1, 0, 0);
passEncoder.end();
deviceWebgpu.queue.submit([commandEncoder.finish()]);
await deviceWebgpu.queue.onSubmittedWorkDone();
const executionEnd = performance.now();
perf_info.push(`vae_decoder: ${(executionEnd - start).toFixed(1)}ms`);
perf_info.push(`execution time: ${(executionEnd - executionStart).toFixed(1)}ms`);
log(perf_info.join(", "))
perf_info = [];
}
//last_hidden_state.dispose();
log("done");
} catch (e) {
log(e);
}
}
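// Creates the shared GPU buffers, the ort.Tensor views over them (IO binding),
// the two compute pipelines, and the canvas render pipeline. Must run after the
// first InferenceSession exists so ort.env.webgpu.device is set.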
function webgpuResourceInitialize() {
deviceWebgpu = ort.env.webgpu.device;
queueWebgpu = deviceWebgpu.queue;
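// Buffers are sized for fp32 tensors and rounded up to 16-byte alignment.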
textEncoderOutputsBuffer = deviceWebgpu.createBuffer({
size: Math.ceil((1 * 77 * 1024 * 4) / 16) * 16,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
textEncoderOutputsTensor = ort.Tensor.fromGpuBuffer(textEncoderOutputsBuffer, {
dataType: 'float32', dims: [1, 77, 1024],
dispose: () => textEncoderOutputsBuffer.destroy()
});
unetOutSampleBuffer = deviceWebgpu.createBuffer({
size: Math.ceil((1 * 4 * 64 * 64 * 4) / 16) * 16,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
unetOutSampleTensor = ort.Tensor.fromGpuBuffer(unetOutSampleBuffer, {
dataType: 'float32', dims: [1, 4, 64, 64],
dispose: () => unetOutSampleBuffer.destroy()
});
latentBuffer = deviceWebgpu.createBuffer({
size: Math.ceil((1 * 4 * 64 * 64 * 4) / 16) * 16,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
unetSampleInputsBuffer = deviceWebgpu.createBuffer({
size: Math.ceil((1 * 4 * 64 * 64 * 4) / 16) * 16,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
unetSampleInputsTensor = ort.Tensor.fromGpuBuffer(unetSampleInputsBuffer, {
dataType: 'float32', dims: [1, 4, 64, 64],
dispose: () => unetSampleInputsBuffer.destroy()
});
decodedOutputsBuffer = deviceWebgpu.createBuffer({
size: Math.ceil((1 * 3 * pixelHeight * pixelWidth * 4) / 16) * 16,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
decodedOutputsTensor = ort.Tensor.fromGpuBuffer(decodedOutputsBuffer, {
dataType: 'float32', dims: [1, 3, pixelHeight, pixelWidth],
dispose: () => decodedOutputsBuffer.destroy()
});
prescaleLatentSpacePipeline = deviceWebgpu.createComputePipeline({
layout: 'auto',
compute: {
module: deviceWebgpu.createShaderModule({
code: PRESCALE_LATENT_SPACE_SHADER,
}),
entryPoint: '_start',
},
});
prescaleLatentSpaceBindGroup = deviceWebgpu.createBindGroup({
layout: prescaleLatentSpacePipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer: unetSampleInputsBuffer,
},
},
{
binding: 1,
resource: {
buffer: latentBuffer,
},
}
],
});
stepLatentSpacePipeline = deviceWebgpu.createComputePipeline({
layout: 'auto',
compute: {
module: deviceWebgpu.createShaderModule({
code: STEP_LATENT_SPACE_SHADER,
}),
entryPoint: '_start',
},
});
stepLatentSpaceBindGroup = deviceWebgpu.createBindGroup({
layout: stepLatentSpacePipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer: unetOutSampleBuffer,
},
},
{
binding: 1,
resource: {
buffer: latentBuffer,
},
}
],
});
const canvas = document.getElementById(`canvas`);
canvas.width = pixelWidth;
canvas.height = pixelHeight;
renderContext = canvas.getContext('webgpu');
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
const presentationSize = [pixelWidth, pixelHeight];
renderContext.configure({
device: deviceWebgpu,
size: presentationSize,
format: presentationFormat,
alphaMode: 'opaque',
});
renderPipeline = deviceWebgpu.createRenderPipeline({
layout: 'auto',
vertex: {
module: deviceWebgpu.createShaderModule({
code: VERTEX_SHADER,
}),
entryPoint: 'main',
},
fragment: {
module: deviceWebgpu.createShaderModule({
code: PIXEL_SHADER,
}),
entryPoint: 'main',
targets: [
{
format: presentationFormat,
},
],
},
primitive: {
topology: 'triangle-list',
},
});
renderBindGroup = deviceWebgpu.createBindGroup({
layout: renderPipeline.getBindGroupLayout(0),
entries: [
{
binding: 1,
resource: {
buffer: decodedOutputsBuffer,
},
}
],
});
}
async function main() {
loading = load_models(models);
}
window.onload = () => {
main();
}
</script>
</body>

</html>