<!DOCTYPE html>
<html>

<head>
    <title>Stable Diffusion Turbo</title>
    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
        integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous" />
    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
        integrity="sha384-HwwvtgBNo3bZJJLYd8oVXjrBZt8cqVSpeBNS5n7C8IVInixGAoxmnlMuBnhbgrkm" crossorigin="anonymous">
    </script>
    <!-- transformers.js supplies the tokenizer; tokenizer files are resolved under the local 'models' folder. -->
    <script type="module">
        import { env, AutoTokenizer } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/transformers.js';
        window.AutoTokenizer = AutoTokenizer;
        env.localModelPath = 'models';
    </script>
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js">
    </script>
    <style>
        .onerow {
            display: flex;
        }
    </style>
</head>
<body data-bs-theme="dark">
    <div class="container">
        <div class="row pt-3">
            <div class="col-md-9 col-12">
                <h2>Stable Diffusion Turbo</h2>
            </div>
        </div>
        <div class="container p-2 card" id="input-area">
            <div class="input-group">
                <textarea class="form-control" id="user-input" placeholder="Type your prompt here..."></textarea>
                <button id="send-button" class="btn btn-primary">Send</button>
            </div>
        </div>
        <!--<div id="image_area">
            <div class="onerow">
                <div id="img_div_0" style="margin-right: 4px;">
                    <canvas id="img_canvas_0"></canvas>
                </div>
                <div id="img_div_1" style="margin-right: 4px;">
                    <canvas id="img_canvas_1"></canvas>
                </div>
                <div id="img_div_2" style="margin-right: 4px;">
                    <canvas id="img_canvas_2"></canvas>
                </div>
                <div id="img_div_3" style="margin-right: 4px;">
                    <canvas id="img_canvas_3"></canvas>
                </div>
            </div>
        </div>-->
<div class="right"> | |
<canvas class='canvas' id='canvas' width='512' height='512'> | |
Canvas is not supported. You'll need to try a newer browser version or another browser. | |
</canvas> | |
</div> | |
<p class="text-lg-start"> | |
<div id="status" style="font: 1em consolas;"></div> | |
</p> | |
</div> | |
</body> | |
<script>
    var deviceWebgpu = null;
    var queueWebgpu = null;
    var textEncoderOutputsBuffer = null;
    var textEncoderOutputsTensor = null;
    var textEncoderOutputs = {};
    var latentData = null;
    var latentBuffer = null;
    var unetSampleInputsBuffer = null;
    var unetSampleInputsTensor = null;
    var unetOutSampleBuffer = null;
    var unetOutSampleTensor = null;
    var prescaleLatentSpacePipeline = null;
    var prescaleLatentSpaceBindGroup = null;
    var stepLatentSpacePipeline = null;
    var stepLatentSpaceBindGroup = null;
    var decodedOutputsBuffer = null;
    var decodedOutputsTensor = null;
    const pixelHeight = 512;
    const pixelWidth = 512;
    var renderContext = null;
    var renderPipeline = null;
    var renderBindGroup = null;
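    // Pre-scales the initial noise latents into the unet input buffer. For the
    // Euler scheduler the model input is latent / sqrt(sigma^2 + 1); with
    // sigma = 14.6146 that is sqrt(14.6146^2 + 1) = 14.64877241136608.
    // Each invocation handles one vec4, i.e. four of the 1*4*64*64 latent values.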
    const PRESCALE_LATENT_SPACE_SHADER = `
        @binding(0) @group(0) var<storage, read_write> result: array<vec4<f32>>;
        @binding(1) @group(0) var<storage, read> latentData: array<vec4<f32>>;

        @compute @workgroup_size(128, 1, 1)
        fn _start(@builtin(global_invocation_id) GlobalId : vec3<u32>) {
            let index = GlobalId.x;
            let value = latentData[index] / 14.64877241136608;
            result[index] = value;
        }
    `;
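    // Applies the single Euler step in place on the unet output (see step() below
    // for the CPU reference). With derivative = (latent - pred_original_sample) / sigma_hat
    // and dt = -sigma_hat, latent + derivative * dt reduces to pred_original_sample,
    // so the shader effectively writes (latent - sigma_hat * unet_out) / 0.18215,
    // already divided by the VAE scaling factor for the decoder.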
    const STEP_LATENT_SPACE_SHADER = `
        @binding(0) @group(0) var<storage, read_write> result: array<vec4<f32>>;
        @binding(1) @group(0) var<storage, read> latentData: array<vec4<f32>>;

        @compute @workgroup_size(128, 1, 1)
        fn _start(@builtin(global_invocation_id) GlobalId : vec3<u32>) {
            let index = GlobalId.x;
            // sigma_hat = sigma * (gamma + 1) with gamma = 0, matching step() below.
            let sigma_hat = 14.6146;
            let latentVal = latentData[index];
            let outputSampleVal = result[index];
            let pred_original_sample = latentVal - sigma_hat * outputSampleVal;
            let derivative = (latentVal - pred_original_sample) / sigma_hat;
            let dt = -sigma_hat;
            result[index] = (latentVal + derivative * dt) / 0.18215;
        }
    `;
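    // Full-screen quad: two triangles covering clip space, with UVs mapping the
    // quad onto [0, 1] x [0, 1] for the fragment shader below.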
    const VERTEX_SHADER = `
        struct VertexOutput {
            @builtin(position) Position : vec4<f32>,
            @location(0) fragUV : vec2<f32>,
        }

        @vertex
        fn main(@builtin(vertex_index) VertexIndex : u32) -> VertexOutput {
            var pos = array<vec2<f32>, 6>(
                vec2<f32>( 1.0,  1.0),
                vec2<f32>( 1.0, -1.0),
                vec2<f32>(-1.0, -1.0),
                vec2<f32>( 1.0,  1.0),
                vec2<f32>(-1.0, -1.0),
                vec2<f32>(-1.0,  1.0)
            );
            var uv = array<vec2<f32>, 6>(
                vec2<f32>(1.0, 0.0),
                vec2<f32>(1.0, 1.0),
                vec2<f32>(0.0, 1.0),
                vec2<f32>(1.0, 0.0),
                vec2<f32>(0.0, 1.0),
                vec2<f32>(0.0, 0.0)
            );
            var output : VertexOutput;
            output.Position = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
            output.fragUV = uv[VertexIndex];
            return output;
        }
    `;
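    // Reads the VAE decoder output (a planar float buffer, [1, 3, 512, 512] in
    // RGB channel order, values in [-1, 1]) and converts one pixel per fragment.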
    const PIXEL_SHADER = `
        @group(0) @binding(1) var<storage, read> buf : array<f32>;

        @fragment
        fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
            // Shift by half the width: the index computation below lands half a
            // row (+256) off at pixel centers, and this wrap-around compensates.
            var coord = vec2(0.0, 0.0);
            if (fragUV.x < 0.5) {
                coord = vec2(fragUV.x + 0.5, fragUV.y);
            } else {
                coord = vec2(fragUV.x - 0.5, fragUV.y);
            }

            // 512 * 512 = 262144 values per color plane.
            let redInputOffset = 0;
            let greenInputOffset = 262144;
            let blueInputOffset = 524288;
            let index = i32(coord.x * f32(512)) + i32(coord.y * f32(512) * f32(512)); // pixelWidth = pixelHeight = 512

            // Map from [-1, 1] to [0, 1].
            let r = clamp(buf[redInputOffset + index] / 2 + 0.5, 0.0, 1.0);
            let g = clamp(buf[greenInputOffset + index] / 2 + 0.5, 0.0, 1.0);
            let b = clamp(buf[blueInputOffset + index] / 2 + 0.5, 0.0, 1.0);
            let a = 1.0;
            var out_color = vec4<f32>(r, g, b, a);
            return out_color;
        }
    `;
    function log(i) {
        console.log(i);
        document.getElementById('status').innerText += `\n${i}`;
    }

    function getConfig() {
        const query = window.location.search.substring(1);
        var config = {
            // model: "models/onnx-sd-turbo-fp16",
            // model: "https://huggingface.co/schmuell/sd-turbo-ort-web/resolve/main",
            model: "models",
            provider: "webgpu",
            device: "gpu",
            threads: "1",
            images: "1",
        };
        let vars = query.split("&");
        for (var i = 0; i < vars.length; i++) {
            let pair = vars[i].split("=");
            if (pair[0] in config) {
                config[pair[0]] = decodeURIComponent(pair[1]);
            } else if (pair[0].length > 0) {
                throw new Error("unknown argument: " + pair[0]);
            }
        }
        config.threads = parseInt(config.threads);
        config.images = parseInt(config.images);
        return config;
    }
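    // Every key in the defaults above can be overridden from the query string,
    // e.g. index.html?provider=webnn&device=gpu&images=4 (unknown keys throw).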
    function randn_latents(shape, noise_sigma) {
        function randn() {
            // Use the Box-Muller transform; 1 - Math.random() keeps u in (0, 1]
            // so Math.log(u) can never be -Infinity.
            let u = 1 - Math.random();
            let v = Math.random();
            let z = Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v);
            return z;
        }
        let size = 1;
        shape.forEach(element => {
            size *= element;
        });
        let data = new Float32Array(size);
        // Fill every element with gaussian noise scaled by noise_sigma.
        for (let i = 0; i < size; i++) {
            data[i] = randn() * noise_sigma;
        }
        return data;
    }
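    // Model downloads go through the Cache API ("onnx" cache) so repeat visits
    // are served locally; if the cache is unavailable (e.g. in a non-secure
    // context), fall back to a plain network fetch.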
    async function fetchAndCache(base_url, model_path) {
        const url = `${base_url}/${model_path}`;
        try {
            const cache = await caches.open("onnx");
            let cachedResponse = await cache.match(url);
            if (cachedResponse == undefined) {
                await cache.add(url);
                cachedResponse = await cache.match(url);
                log(`${model_path} (network)`);
            } else {
                log(`${model_path} (cached)`);
            }
            const data = await cachedResponse.arrayBuffer();
            return data;
        } catch (error) {
            log(`${model_path} (network)`);
            return await fetch(url).then(response => response.arrayBuffer());
        }
    }
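    // Standard WebGPU upload pattern: write into a mapped MAP_WRITE staging
    // buffer, then copy it into the STORAGE buffer the shaders and ORT use.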
    function uploadToGPU(buffer, values, type) {
        const stagingBuffer = deviceWebgpu.createBuffer({
            usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC,
            size: values.buffer.byteLength,
            mappedAtCreation: true
        });
        const arrayBuffer = stagingBuffer.getMappedRange();
        if (type === 'float32') {
            new Float32Array(arrayBuffer).set(values);
        } else if (type === 'int32') {
            new Int32Array(arrayBuffer).set(values);
        }
        stagingBuffer.unmap();
        const encoder = deviceWebgpu.createCommandEncoder();
        encoder.copyBufferToBuffer(stagingBuffer, 0, buffer, 0, values.byteLength);
        deviceWebgpu.queue.submit([encoder.finish()]);
        stagingBuffer.destroy();
    }
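    // Dispatches 32 workgroups of 128 invocations = 4096 vec4s, exactly the
    // 1 * 4 * 64 * 64 = 16384 floats of one latent tensor.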
    function submitComputeTask(pipeline, bindGroup) {
        let commandEncoderWebgpu = deviceWebgpu.createCommandEncoder();
        let computePassEncoder = commandEncoderWebgpu.beginComputePass();
        computePassEncoder.setPipeline(pipeline);
        computePassEncoder.setBindGroup(0, bindGroup);
        computePassEncoder.dispatchWorkgroups(32, 1, 1);
        computePassEncoder.end();
        computePassEncoder = null;
        queueWebgpu.submit([commandEncoderWebgpu.finish()]);
    }
    async function load_models(models) {
        const cache = await caches.open("onnx");
        let missing = 0;
        for (const [name, model] of Object.entries(models)) {
            const url = `${config.model}/${model.url}`;
            let cachedResponse = await cache.match(url);
            if (cachedResponse === undefined) {
                missing += model.size;
            }
        }
        if (missing > 0) {
            log(`downloading ${missing} MB from network ... it might take a while`);
        } else {
            log("loading...");
        }
        let loadedCount = 0;
        for (const [name, model] of Object.entries(models)) {
            try {
                // ort.env.webgpu.device only exists once the first WebGPU session
                // has been created, so GPU resources are initialized after that.
                if (loadedCount === 1) {
                    webgpuResourceInitialize();
                }
                const start = performance.now();
                const model_bytes = await fetchAndCache(config.model, model.url);
                const sess_opt = { ...opt, ...model.opt };
                // profiling
                // ort.env.webgpu.profiling = { mode: "default" };
                models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
                const stop = performance.now();
                loadedCount++;
                log(`${model.url} in ${(stop - start).toFixed(1)}ms`);
            } catch (e) {
                log(`${model.url} failed, ${e}`);
            }
        }
        // Prepare the first set of noise latents so the first run can start immediately.
        regenerateLatents();
        log("ready.");
    }
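    // Pipeline per image: tokenize the prompt -> text_encoder -> one unet step at
    // timestep 999 -> Euler step + VAE scaling (compute shader) -> vae_decoder ->
    // render the decoded buffer to the canvas. All intermediate tensors stay in
    // WebGPU buffers; nothing is read back to the CPU.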
    const config = getConfig();
    // 'size' is the approximate download size in MB, used only for the progress message.
    const models = {
        "unet": {
            url: "unet/model.onnx", size: 640,
            // should have 'steps: 1' but will fail to create the session
            opt: { freeDimensionOverrides: { batch_size: 1, num_channels: 4, height: 64, width: 64, sequence_length: 77, } }
        },
        "text_encoder": {
            url: "text_encoder/model.onnx", size: 1700,
            // should have 'sequence_length: 77' but produces a bad image
            opt: { freeDimensionOverrides: { batch_size: 1, } },
        },
        "vae_decoder": {
            url: "vae_decoder/model.onnx", size: 95,
            opt: { freeDimensionOverrides: { batch_size: 1, num_channels_latent: 4, height_latent: 64, width_latent: 64 } }
        }
    };
    const text = document.getElementById("user-input");
    let tokenizer;
    let loading;
    const sigma = 14.6146;
    const gamma = 0;
    const vae_scaling_factor = 0.18215;
    // text.value = "A cinematic shot of a baby racoon wearing an intricate italian priest robe.";
    text.value = "Paris with the river in the background";
    if (config.provider == "webgpu") {
        ort.env.wasm.numThreads = 1;
        ort.env.wasm.simd = true;
    } else {
        ort.env.wasm.numThreads = config.threads;
        ort.env.wasm.simd = true;
    }
    const opt = {
        executionProviders: [config.provider],
        enableMemPattern: false,
        enableCpuMemArena: false,
        extra: {
            session: {
                disable_prepacking: "1",
                use_device_allocator_for_initializers: "1",
                use_ort_model_bytes_directly: "1",
                use_ort_model_bytes_for_initializers: "1"
            }
        },
    };
    switch (config.provider) {
        case "webgpu":
            if (!("gpu" in navigator)) {
                throw new Error("webgpu is NOT supported");
            }
            // Keep the text encoder output on the GPU so it can be fed to the
            // unet without a CPU round-trip.
            opt.preferredOutputLocation = { last_hidden_state: "gpu-buffer" };
            break;
        case "webnn":
            if (!("ml" in navigator)) {
                throw new Error("webnn is NOT supported");
            }
            opt.executionProviders = [{
                name: "webnn",
                deviceType: config.device,
                powerPreference: 'default'
            }];
            break;
    }
    // Generate fresh noise latents and pre-scale them on the GPU so the next
    // run starts from a new random seed.
    function regenerateLatents() {
        const latent_shape = [1, 4, 64, 64];
        latentData = randn_latents(latent_shape, sigma);
        uploadToGPU(latentBuffer, latentData, "float32");
        submitComputeTask(prescaleLatentSpacePipeline, prescaleLatentSpaceBindGroup);
    }
    // Event listener for Ctrl + Enter or CMD + Enter
    document.getElementById('user-input').addEventListener('keydown', function (e) {
        if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
            run();
            regenerateLatents();
        }
    });
    document.getElementById('send-button').addEventListener('click', function (e) {
        run();
        regenerateLatents();
    });
    // The helpers below are the CPU reference implementations; the current code
    // path does all of this on the GPU via the compute shaders above.
    function init_latents(t) {
        const d = t.data;
        for (let i = 0; i < d.length; i++) {
            d[i] = d[i] * sigma;
        }
        return t;
    }

    function scale_model_inputs(t) {
        const d_i = t.data;
        const d_o = new Float32Array(d_i.length);
        const divi = (sigma ** 2 + 1) ** 0.5;
        for (let i = 0; i < d_i.length; i++) {
            d_o[i] = d_i[i] / divi;
        }
        return new ort.Tensor(d_o, t.dims);
    }

    function step(model_output, sample) {
        // Poor man's EulerA. Since this is just an example for sd-turbo, only
        // implement the absolute minimum needed to create an image.
        const d_o = new Float32Array(model_output.data.length);
        const prev_sample = new ort.Tensor(d_o, model_output.dims);
        const sigma_hat = sigma * (gamma + 1);
        for (let i = 0; i < model_output.data.length; i++) {
            const pred_original_sample = sample.data[i] - sigma_hat * model_output.data[i];
            const derivative = (sample.data[i] - pred_original_sample) / sigma_hat;
            const dt = 0 - sigma_hat;
            d_o[i] = (sample.data[i] + derivative * dt) / vae_scaling_factor;
        }
        return prev_sample;
    }
    // Used by the commented-out multi-canvas layout above.
    function draw_image(t, image_nr) {
        let pix = t.data;
        for (var i = 0; i < pix.length; i++) {
            let x = pix[i];
            x = x / 2 + 0.5;
            if (x < 0.) x = 0.;
            if (x > 1.) x = 1.;
            pix[i] = x;
        }
        const imageData = t.toImageData({ tensorLayout: 'NCWH', format: 'RGB' });
        const canvas = document.getElementById(`img_canvas_${image_nr}`);
        canvas.width = imageData.width;
        canvas.height = imageData.height;
        canvas.getContext('2d').putImageData(imageData, 0, 0);
        const div = document.getElementById(`img_div_${image_nr}`);
        div.style.opacity = 1.;
    }
    async function downloadToCPU(buffer) {
        const stagingBuffer = deviceWebgpu.createBuffer({
            usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
            size: buffer.size
        });
        const encoder = deviceWebgpu.createCommandEncoder();
        encoder.copyBufferToBuffer(buffer, 0, stagingBuffer, 0, buffer.size);
        deviceWebgpu.queue.submit([encoder.finish()]);
        await stagingBuffer.mapAsync(GPUMapMode.READ);
        // slice() copies the whole mapped range (its arguments are byte offsets)
        // so the data survives destroying the staging buffer.
        const arrayBuffer = stagingBuffer.getMappedRange().slice(0);
        stagingBuffer.destroy();
        return new Float32Array(arrayBuffer);
    }
    async function run() {
        try {
            document.getElementById('status').innerText = "generating ...";
            if (tokenizer === undefined) {
                // from_pretrained takes an options object; 'tokenizer' resolves
                // under env.localModelPath ('models').
                tokenizer = await AutoTokenizer.from_pretrained('tokenizer', { quantized: false, local_files_only: false });
                tokenizer.pad_token_id = 0;
            }
            await loading;
            const { input_ids } = await tokenizer(text.value, { padding: true, max_length: 77, truncation: true, return_tensor: false });
            // text-encoder
            let start = performance.now();
            const executionStart = performance.now();
            textEncoderOutputs['last_hidden_state'] = textEncoderOutputsTensor;
            await models.text_encoder.sess.run({ "input_ids": new ort.Tensor("int32", input_ids, [1, input_ids.length]) }, textEncoderOutputs);
            let perf_info = [`text_encoder: ${(performance.now() - start).toFixed(1)}ms`];
            for (let j = 0; j < config.images; j++) {
                start = performance.now();
                let feed = {
                    "sample": unetSampleInputsTensor,
                    // sd-turbo single-step inference runs the unet at the last timestep, 999.
                    "timestep": new ort.Tensor("int64", [999n], [1]),
                    "encoder_hidden_states": textEncoderOutputsTensor,
                };
                var unetOutSampleOutputs = {};
                unetOutSampleOutputs['out_sample'] = unetOutSampleTensor;
                let { out_sample } = await models.unet.sess.run(feed, unetOutSampleOutputs);
                perf_info.push(`unet: ${(performance.now() - start).toFixed(1)}ms`);
                // Euler step + VAE scaling, in place on the unet output buffer.
                submitComputeTask(stepLatentSpacePipeline, stepLatentSpaceBindGroup);
                start = performance.now();
                var vaeDecodeInputs = {};
                vaeDecodeInputs['latent_sample'] = unetOutSampleTensor;
                const decodedOutputs = {};
                decodedOutputs['sample'] = decodedOutputsTensor;
                await models.vae_decoder.sess.run(vaeDecodeInputs, decodedOutputs);
                // profiling
                // ort.env.webgpu.profiling = { mode: "" };
                const commandEncoder = deviceWebgpu.createCommandEncoder();
                const textureView = renderContext.getCurrentTexture().createView();
                const renderPassDescriptor = {
                    colorAttachments: [
                        {
                            view: textureView,
                            clearValue: { r: 1.0, g: 0.0, b: 0.0, a: 1.0 },
                            loadOp: 'clear',
                            storeOp: 'store',
                        },
                    ],
                };
                const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
                passEncoder.setPipeline(renderPipeline);
                passEncoder.setBindGroup(0, renderBindGroup);
                passEncoder.draw(6, 1, 0, 0);
                passEncoder.end();
                deviceWebgpu.queue.submit([commandEncoder.finish()]);
                await deviceWebgpu.queue.onSubmittedWorkDone();
                const executionEnd = performance.now();
                perf_info.push(`vae_decoder: ${(executionEnd - start).toFixed(1)}ms`);
                perf_info.push(`execution time: ${(executionEnd - executionStart).toFixed(1)}ms`);
                log(perf_info.join(", "));
                perf_info = [];
            }
            // last_hidden_state.dispose();
            log("done");
        } catch (e) {
            log(e);
        }
    }
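    // I/O binding setup: all model inputs and outputs are pre-allocated WebGPU
    // buffers wrapped with ort.Tensor.fromGpuBuffer, so ORT reads and writes
    // them in place and the latents never leave the GPU. Buffer sizes are
    // rounded up to a multiple of 16 bytes for alignment.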
    function webgpuResourceInitialize() {
        deviceWebgpu = ort.env.webgpu.device;
        queueWebgpu = deviceWebgpu.queue;
        // text encoder output: [1, 77, 1024] float32 hidden states
        textEncoderOutputsBuffer = deviceWebgpu.createBuffer({
            size: Math.ceil((1 * 77 * 1024 * 4) / 16) * 16,
            usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
        });
        textEncoderOutputsTensor = ort.Tensor.fromGpuBuffer(textEncoderOutputsBuffer, {
            dataType: 'float32', dims: [1, 77, 1024],
            dispose: () => textEncoderOutputsBuffer.destroy()
        });
        // latent-space tensors: [1, 4, 64, 64] float32
        unetOutSampleBuffer = deviceWebgpu.createBuffer({
            size: Math.ceil((1 * 4 * 64 * 64 * 4) / 16) * 16,
            usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
        });
        unetOutSampleTensor = ort.Tensor.fromGpuBuffer(unetOutSampleBuffer, {
            dataType: 'float32', dims: [1, 4, 64, 64],
            dispose: () => unetOutSampleBuffer.destroy()
        });
        latentBuffer = deviceWebgpu.createBuffer({
            size: Math.ceil((1 * 4 * 64 * 64 * 4) / 16) * 16,
            usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
        });
        unetSampleInputsBuffer = deviceWebgpu.createBuffer({
            size: Math.ceil((1 * 4 * 64 * 64 * 4) / 16) * 16,
            usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
        });
        unetSampleInputsTensor = ort.Tensor.fromGpuBuffer(unetSampleInputsBuffer, {
            dataType: 'float32', dims: [1, 4, 64, 64],
            dispose: () => unetSampleInputsBuffer.destroy()
        });
        // decoded image: [1, 3, 512, 512] float32, consumed by the pixel shader
        decodedOutputsBuffer = deviceWebgpu.createBuffer({
            size: Math.ceil((1 * 3 * pixelHeight * pixelWidth * 4) / 16) * 16,
            usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
        });
        decodedOutputsTensor = ort.Tensor.fromGpuBuffer(decodedOutputsBuffer, {
            dataType: 'float32', dims: [1, 3, pixelHeight, pixelWidth],
            dispose: () => decodedOutputsBuffer.destroy()
        });
        prescaleLatentSpacePipeline = deviceWebgpu.createComputePipeline({
            layout: 'auto',
            compute: {
                module: deviceWebgpu.createShaderModule({
                    code: PRESCALE_LATENT_SPACE_SHADER,
                }),
                entryPoint: '_start',
            },
        });
        prescaleLatentSpaceBindGroup = deviceWebgpu.createBindGroup({
            layout: prescaleLatentSpacePipeline.getBindGroupLayout(0),
            entries: [
                {
                    binding: 0,
                    resource: {
                        buffer: unetSampleInputsBuffer,
                    },
                },
                {
                    binding: 1,
                    resource: {
                        buffer: latentBuffer,
                    },
                }
            ],
        });
        stepLatentSpacePipeline = deviceWebgpu.createComputePipeline({
            layout: 'auto',
            compute: {
                module: deviceWebgpu.createShaderModule({
                    code: STEP_LATENT_SPACE_SHADER,
                }),
                entryPoint: '_start',
            },
        });
        stepLatentSpaceBindGroup = deviceWebgpu.createBindGroup({
            layout: stepLatentSpacePipeline.getBindGroupLayout(0),
            entries: [
                {
                    binding: 0,
                    resource: {
                        buffer: unetOutSampleBuffer,
                    },
                },
                {
                    binding: 1,
                    resource: {
                        buffer: latentBuffer,
                    },
                }
            ],
        });
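        // Render setup: the VAE output buffer is bound straight into the
        // fragment shader, so presenting a frame is a single full-screen draw.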
        const canvas = document.getElementById('canvas');
        canvas.width = pixelWidth;
        canvas.height = pixelHeight;
        renderContext = canvas.getContext('webgpu');
        const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
        const presentationSize = [pixelWidth, pixelHeight];
        renderContext.configure({
            device: deviceWebgpu,
            size: presentationSize,
            format: presentationFormat,
            alphaMode: 'opaque',
        });
        renderPipeline = deviceWebgpu.createRenderPipeline({
            layout: 'auto',
            vertex: {
                module: deviceWebgpu.createShaderModule({
                    code: VERTEX_SHADER,
                }),
                entryPoint: 'main',
            },
            fragment: {
                module: deviceWebgpu.createShaderModule({
                    code: PIXEL_SHADER,
                }),
                entryPoint: 'main',
                targets: [
                    {
                        format: presentationFormat,
                    },
                ],
            },
            primitive: {
                topology: 'triangle-list',
            },
        });
        renderBindGroup = deviceWebgpu.createBindGroup({
            layout: renderPipeline.getBindGroupLayout(0),
            entries: [
                {
                    binding: 1,
                    resource: {
                        buffer: decodedOutputsBuffer,
                    },
                }
            ],
        });
    }
    async function main() {
        loading = load_models(models);
    }
    window.onload = () => {
        main();
    }
</script>
</html> |