Yang Gu committed on
Commit 78e34bc · 1 Parent(s): 397aeb0

Refine SD turbo

demo/sd-turbo/index.html CHANGED
@@ -9,8 +9,8 @@
     window.AutoTokenizer = AutoTokenizer;
     env.localModelPath = 'models';
   </script>
-  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js">
-  </script>
+  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js"></script>
+  <script src="../../util.js"></script>
 
   <head>
     <title>Stable Diffusion Turbo</title>
@@ -57,6 +57,7 @@
       </canvas>
     </div>
     <p class="text-lg-start">
+      <text id="model-progress">Downloading model</text><br />
       <div id="status" style="font: 1em consolas;"></div>
     </p>
   </div>
@@ -226,26 +227,6 @@ fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
       return data;
     }
 
-    async function fetchAndCache(base_url, model_path) {
-      const url = `${base_url}/${model_path}`;
-      try {
-        const cache = await caches.open("onnx");
-        let cachedResponse = await cache.match(url);
-        if (cachedResponse == undefined) {
-          await cache.add(url);
-          cachedResponse = await cache.match(url);
-          log(`${model_path} (network)`);
-        } else {
-          log(`${model_path} (cached)`);
-        }
-        const data = await cachedResponse.arrayBuffer();
-        return data;
-      } catch (error) {
-        log(`${model_path} (network)`);
-        return await fetch(url).then(response => response.arrayBuffer());
-      }
-    }
-
     function uploadToGPU(buffer, values, type) {
 
       const stagingBuffer = deviceWebgpu.createBuffer({
@@ -278,62 +259,37 @@ fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
     }
 
     async function load_models(models) {
-      const cache = await caches.open("onnx");
-      let missing = 0;
-      for (const [name, model] of Object.entries(models)) {
-        const url = `${config.model}/${model.url}`;
-        let cachedResponse = await cache.match(url);
-        if (cachedResponse === undefined) {
-          missing += model.size;
-        }
-      }
-      if (missing > 0) {
-        log(`downloading ${missing} MB from network ... it might take a while`);
-      } else {
-        log("loading...");
-      }
       let loadedCount = 0;
       for (const [name, model] of Object.entries(models)) {
-        try {
-          if (loadedCount === 1) {
-            webgpuResourceInitialize();
-          }
-          const start = performance.now();
-          const model_bytes = await fetchAndCache(config.model, model.url);
-          const sess_opt = { ...opt, ...model.opt };
-          // profiling
-          //ort.env.webgpu.profiling = { mode: "default" };
-          models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
-          const stop = performance.now();
-          loadedCount++;
-          log(`${model.url} in ${(stop - start).toFixed(1)}ms`);
-        } catch (e) {
-          log(`${model.url} failed, ${e}`);
-        }
+        const model_bytes = await getModelOPFS(model.url, `models/${model.url}`, false);
+        const sess_opt = { ...opt, ...model.opt };
+        // profiling
+        //ort.env.webgpu.profiling = { mode: "default" };
+        models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
       }
+      document.getElementById('model-progress').innerHTML = 'Model download finished. You may now click "Send" button to generate a new image.';
+      webgpuResourceInitialize();
       const latent_shape = [1, 4, 64, 64];
       latentData = randn_latents(latent_shape, sigma);
       uploadToGPU(latentBuffer, latentData, "float32");
       submitComputeTask(prescaleLatentSpacePipeline, prescaleLatentSpaceBindGroup);
-
-      log("ready.");
     }
 
     const config = getConfig();
 
     const models = {
       "unet": {
-        url: "unet/model.onnx", size: 640,
+        url: "unet.onnx", size: 640,
        // should have 'steps: 1' but will fail to create the session
         opt: { freeDimensionOverrides: { batch_size: 1, num_channels: 4, height: 64, width: 64, sequence_length: 77, } }
       },
       "text_encoder": {
-        url: "text_encoder/model.onnx", size: 1700,
+        url: "text_encoder.onnx", size: 1700,
         // should have 'sequence_length: 77' but produces a bad image
         opt: { freeDimensionOverrides: { batch_size: 1, } },
       },
       "vae_decoder": {
-        url: "vae_decoder/model.onnx", size: 95,
+        url: "vae_decoder.onnx", size: 95,
         opt: { freeDimensionOverrides: { batch_size: 1, num_channels_latent: 4, height_latent: 64, width_latent: 64 } }
       }
     }
@@ -714,4 +670,4 @@ fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
     }
   </script>
 
-</html>
+</html>
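
Note on this change: model loading moves from the page-local Cache API helper fetchAndCache (deleted above) to the OPFS-backed getModelOPFS shared via util.js, and status reporting moves into the new #model-progress element. The refactor also drops the per-model timing and try/catch the old loop had; the sketch below shows how they could be layered back on top of the new loader. It assumes the page's own log, opt, and models globals and is illustrative only, not part of this commit:

async function load_models(models) {
  for (const [name, model] of Object.entries(models)) {
    try {
      const start = performance.now();
      // getModelOPFS(name, url, updateModel) is defined in util.js below.
      const model_bytes = await getModelOPFS(model.url, `models/${model.url}`, false);
      const sess_opt = { ...opt, ...model.opt };
      models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
      log(`${model.url} loaded in ${(performance.now() - start).toFixed(1)}ms`);
    } catch (e) {
      log(`${model.url} failed, ${e}`);
    }
  }
}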
demo/sd-turbo/models/{text_encoder/model.onnx → text_encoder.onnx} RENAMED
File without changes
demo/sd-turbo/models/{unet/model.onnx → unet.onnx} RENAMED
File without changes
demo/sd-turbo/models/{vae_decoder/model.onnx → vae_decoder.onnx} RENAMED
File without changes
main.js CHANGED
@@ -49,21 +49,23 @@ function createElem(tag, attrs = {}, children = []) {
   return elem;
 }
 
+// todo: Musicgen
+
 const pageCategories = [
   {
     title: `Computer Vision`,
     description: `Computer Vision`,
     demos: {
-      gemma: {
-        name: 'Stable Diffusion Turbo',
-        description: `Stable Diffusion Turbo from https://github.com/guschmue/ort-webgpu/tree/master/sd-turbo`,
-        filename: "sd-turbo",
-      },
       sam: {
         name: 'Segment Anything',
         description: `Segment Anything from https://github.com/guschmue/ort-webgpu/tree/master/segment-anything`,
         filename: "sam",
       },
+      sdturbo: {
+        name: 'Stable Diffusion Turbo',
+        description: `Stable Diffusion Turbo from https://github.com/guschmue/ort-webgpu/tree/master/sd-turbo`,
+        filename: "sd-turbo",
+      },
       yolo: {
         name: 'Yolo',
         description: `Yolo V9 from https://github.com/guschmue/ort-webgpu/tree/master/yolov9`,
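
Each demos entry pairs a display name and description with the demo's directory name. Below is a hedged sketch of how such an entry might be rendered with the createElem(tag, attrs, children) helper shown in the hunk context; the real rendering code is outside this diff, so the href layout and attribute handling here are assumptions:

for (const category of pageCategories) {
  for (const demo of Object.values(category.demos)) {
    // Assumed link target: each demo lives under demo/<filename>/index.html.
    const link = createElem('a', { href: `demo/${demo.filename}/index.html` });
    link.textContent = `${demo.name}: ${demo.description}`;
    document.body.appendChild(link);
  }
}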
util.js CHANGED
@@ -1,7 +1,7 @@
-function updateGetModelProgress(loaded, total) {
+function updateGetModelProgress(name, loaded, total) {
   const progressElement = document.getElementById('model-progress');
   if (total === 0) {
-    progressElement.innerHTML = "Model is already in local cache";
+    progressElement.innerHTML = `Model ${name} is already in local cache`;
     return;
   }
   const percent = ((loaded / total) * 100).toFixed(2);
@@ -11,7 +11,7 @@ function updateGetModelProgress(loaded, total) {
   } else {
     progress = "Downloading";
   }
-  progressElement.innerHTML = `${progress} model: Total ${total} bytes, loaded ${loaded} bytes, ${percent}%`;
+  progressElement.innerHTML = `${progress} model ${name}: Total ${total} bytes, loaded ${loaded} bytes, ${percent}%`;
 }
 
 // Get model via Origin Private File System
@@ -21,7 +21,7 @@ async function getModelOPFS(name, url, updateModel) {
 
   async function updateFile() {
     const response = await fetch(url);
-    const buffer = await readResponse(response, updateGetModelProgress);
+    const buffer = await readResponse(name, response, updateGetModelProgress);
     fileHandle = await root.getFileHandle(name, {create: true});
     const writable = await fileHandle.createWritable();
     await writable.write(buffer);
@@ -36,14 +36,14 @@ async function getModelOPFS(name, url, updateModel) {
   try {
     fileHandle = await root.getFileHandle(name);
     const blob = await fileHandle.getFile();
-    updateGetModelProgress(0, 0);
+    updateGetModelProgress(name, 0, 0);
     return await blob.arrayBuffer();
   } catch (e) {
     return await updateFile();
   }
 }
 
-async function readResponse(response, progressCallback) {
+async function readResponse(name, response, progressCallback) {
   const contentLength = response.headers.get('Content-Length');
   let total = parseInt(contentLength ?? '0');
   let buffer = new Uint8Array(total);
@@ -65,7 +65,7 @@ async function readResponse(response, progressCallback) {
   loaded = newLoaded;
 
   if (progressCallback) {
-    progressCallback(loaded, total);
+    progressCallback(name, loaded, total);
   }
 
   return read();
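
getModelOPFS persists each downloaded model in the Origin Private File System, so subsequent visits read the bytes from local storage instead of the network. A minimal sketch of the underlying pattern, assuming the root handle comes from navigator.storage.getDirectory() (that part of util.js is outside this diff), without the streaming progress reporting:

async function cacheInOPFS(name, url) {
  const root = await navigator.storage.getDirectory();
  try {
    // Cache hit: the file was persisted on a previous visit.
    const fileHandle = await root.getFileHandle(name);
    const blob = await fileHandle.getFile();
    return await blob.arrayBuffer();
  } catch (e) {
    // Cache miss: download, persist, then return the bytes.
    const response = await fetch(url);
    const buffer = await response.arrayBuffer();
    const fileHandle = await root.getFileHandle(name, { create: true });
    const writable = await fileHandle.createWritable();
    await writable.write(buffer);
    await writable.close();
    return buffer;
  }
}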