Spaces: Running

Yang Gu committed · Commit 78e34bc · 1 Parent(s): 397aeb0

Refine SD turbo

Browse files
demo/sd-turbo/index.html
CHANGED
@@ -9,8 +9,8 @@
         window.AutoTokenizer = AutoTokenizer;
         env.localModelPath = 'models';
     </script>
-    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js">
-
+    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js"></script>
+    <script src="../../util.js"></script>

 <head>
     <title>Stable Diffusion Turbo</title>
@@ -57,6 +57,7 @@
         </canvas>
     </div>
     <p class="text-lg-start">
+        <text id="model-progress">Downloading model</text><br />
         <div id="status" style="font: 1em consolas;"></div>
     </p>
 </div>
@@ -226,26 +227,6 @@ fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
     return data;
 }

-async function fetchAndCache(base_url, model_path) {
-    const url = `${base_url}/${model_path}`;
-    try {
-        const cache = await caches.open("onnx");
-        let cachedResponse = await cache.match(url);
-        if (cachedResponse == undefined) {
-            await cache.add(url);
-            cachedResponse = await cache.match(url);
-            log(`${model_path} (network)`);
-        } else {
-            log(`${model_path} (cached)`);
-        }
-        const data = await cachedResponse.arrayBuffer();
-        return data;
-    } catch (error) {
-        log(`${model_path} (network)`);
-        return await fetch(url).then(response => response.arrayBuffer());
-    }
-}
-
 function uploadToGPU(buffer, values, type) {

     const stagingBuffer = deviceWebgpu.createBuffer({
@@ -278,62 +259,37 @@ fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
 }

 async function load_models(models) {
-    const cache = await caches.open("onnx");
-    let missing = 0;
-    for (const [name, model] of Object.entries(models)) {
-        const url = `${config.model}/${model.url}`;
-        let cachedResponse = await cache.match(url);
-        if (cachedResponse === undefined) {
-            missing += model.size;
-        }
-    }
-    if (missing > 0) {
-        log(`downloading ${missing} MB from network ... it might take a while`);
-    } else {
-        log("loading...");
-    }
     let loadedCount = 0;
     for (const [name, model] of Object.entries(models)) {
-
-
-
-
-
-        const model_bytes = await fetchAndCache(config.model, model.url);
-        const sess_opt = { ...opt, ...model.opt };
-        // profiling
-        //ort.env.webgpu.profiling = { mode: "default" };
-        models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
-        const stop = performance.now();
-        loadedCount++;
-        log(`${model.url} in ${(stop - start).toFixed(1)}ms`);
-        } catch (e) {
-            log(`${model.url} failed, ${e}`);
-        }
+        const model_bytes = await getModelOPFS(model.url, `models/${model.url}`, false);
+        const sess_opt = { ...opt, ...model.opt };
+        // profiling
+        //ort.env.webgpu.profiling = { mode: "default" };
+        models[name].sess = await ort.InferenceSession.create(model_bytes, sess_opt);
     }
+    document.getElementById('model-progress').innerHTML = 'Model download finished. You may now click "Send" button to generate a new image.';
+    webgpuResourceInitialize();
     const latent_shape = [1, 4, 64, 64];
     latentData = randn_latents(latent_shape, sigma);
     uploadToGPU(latentBuffer, latentData, "float32");
     submitComputeTask(prescaleLatentSpacePipeline, prescaleLatentSpaceBindGroup);
-
-    log("ready.");
 }

 const config = getConfig();

 const models = {
     "unet": {
-        url: "unet/model.onnx", size: 640,
+        url: "unet.onnx", size: 640,
         // should have 'steps: 1' but will fail to create the session
         opt: { freeDimensionOverrides: { batch_size: 1, num_channels: 4, height: 64, width: 64, sequence_length: 77, } }
     },
     "text_encoder": {
-        url: "text_encoder/model.onnx", size: 1700,
+        url: "text_encoder.onnx", size: 1700,
         // should have 'sequence_length: 77' but produces a bad image
         opt: { freeDimensionOverrides: { batch_size: 1, } },
     },
     "vae_decoder": {
-        url: "vae_decoder/model.onnx", size: 95,
+        url: "vae_decoder.onnx", size: 95,
         opt: { freeDimensionOverrides: { batch_size: 1, num_channels_latent: 4, height_latent: 64, width_latent: 64 } }
     }
 }
@@ -714,4 +670,4 @@ fn main(@location(0) fragUV : vec2<f32>) -> @location(0) vec4<f32> {
 }
 </script>

-</html>
+</html>
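Note on the change above: the Cache API based fetchAndCache helper is removed in favor of getModelOPFS from the new shared util.js, which persists model bytes in the browser's Origin Private File System so they survive page reloads, and reports download progress. A minimal sketch of that caching pattern, with an illustrative name (getCachedModel is hypothetical; the real helper is getModelOPFS in util.js):

    // Sketch of OPFS-backed model caching; getCachedModel is a hypothetical
    // name, the actual helper in util.js is getModelOPFS(name, url, updateModel).
    async function getCachedModel(name, url) {
        // The OPFS root is origin-scoped and survives page reloads.
        const root = await navigator.storage.getDirectory();
        try {
            // Cache hit: the model was stored by an earlier visit.
            const fileHandle = await root.getFileHandle(name);
            const file = await fileHandle.getFile();
            return await file.arrayBuffer();
        } catch (e) {
            // Cache miss: download, persist for next time, return the bytes.
            const response = await fetch(url);
            const buffer = await response.arrayBuffer();
            const fileHandle = await root.getFileHandle(name, { create: true });
            const writable = await fileHandle.createWritable();
            await writable.write(buffer);
            await writable.close();
            return buffer;
        }
    }

It would be called once per model, much like the diff's getModelOPFS(model.url, `models/${model.url}`, false), e.g. const bytes = await getCachedModel("unet.onnx", "models/unet.onnx");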
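The per-model opt entries pin the models' symbolic ONNX input dimensions through onnxruntime-web's freeDimensionOverrides session option, since the graphs declare free dimensions such as batch_size. A hedged sketch of how one such session is created, under the assumption that the page-level opt carries the WebGPU execution provider (the diff only shows sess_opt = { ...opt, ...model.opt }):

    // Sketch: create the vae_decoder session with its free dimensions pinned.
    // `modelBytes` would come from getModelOPFS; executionProviders is an
    // assumption about what the page-level `opt` contains.
    const sessionOptions = {
        executionProviders: ["webgpu"],
        freeDimensionOverrides: {
            batch_size: 1,
            num_channels_latent: 4,
            height_latent: 64,
            width_latent: 64,
        },
    };
    const session = await ort.InferenceSession.create(modelBytes, sessionOptions);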
demo/sd-turbo/models/{text_encoder/model.onnx → text_encoder.onnx}
RENAMED
File without changes

demo/sd-turbo/models/{unet/model.onnx → unet.onnx}
RENAMED
File without changes

demo/sd-turbo/models/{vae_decoder/model.onnx → vae_decoder.onnx}
RENAMED
File without changes
main.js
CHANGED
@@ -49,21 +49,23 @@ function createElem(tag, attrs = {}, children = []) {
     return elem;
 }

+// todo: Musicgen
+
 const pageCategories = [
   {
     title: `Computer Vision`,
     description: `Computer Vision`,
     demos: {
-      gemma: {
-        name: 'Stable Diffusion Turbo',
-        description: `Stable Diffusion Turbo from https://github.com/guschmue/ort-webgpu/tree/master/sd-turbo`,
-        filename: "sd-turbo",
-      },
       sam: {
         name: 'Segment Anything',
         description: `Segment Anything from https://github.com/guschmue/ort-webgpu/tree/master/segment-anything`,
         filename: "sam",
       },
+      sdturbo: {
+        name: 'Stable Diffusion Turbo',
+        description: `Stable Diffusion Turbo from https://github.com/guschmue/ort-webgpu/tree/master/sd-turbo`,
+        filename: "sd-turbo",
+      },
       yolo: {
         name: 'Yolo',
         description: `Yolo V9 from https://github.com/guschmue/ort-webgpu/tree/master/yolov9`,
util.js
CHANGED
@@ -1,7 +1,7 @@
-function updateGetModelProgress(loaded, total) {
+function updateGetModelProgress(name, loaded, total) {
   const progressElement = document.getElementById('model-progress');
   if (total === 0) {
-    progressElement.innerHTML =
+    progressElement.innerHTML = `Model ${name} is already in local cache`;
     return;
   }
   const percent = ((loaded / total) * 100).toFixed(2);
@@ -11,7 +11,7 @@ function updateGetModelProgress(loaded, total) {
   } else {
     progress = "Downloading";
   }
-  progressElement.innerHTML = `${progress} model: Total ${total} bytes, loaded ${loaded} bytes, ${percent}%`;
+  progressElement.innerHTML = `${progress} model ${name}: Total ${total} bytes, loaded ${loaded} bytes, ${percent}%`;
 }

 // Get model via Origin Private File System
@@ -21,7 +21,7 @@ async function getModelOPFS(name, url, updateModel) {

   async function updateFile() {
     const response = await fetch(url);
-    const buffer = await readResponse(response, updateGetModelProgress);
+    const buffer = await readResponse(name, response, updateGetModelProgress);
     fileHandle = await root.getFileHandle(name, {create: true});
     const writable = await fileHandle.createWritable();
     await writable.write(buffer);
@@ -36,14 +36,14 @@ async function getModelOPFS(name, url, updateModel) {
   try {
     fileHandle = await root.getFileHandle(name);
     const blob = await fileHandle.getFile();
-    updateGetModelProgress(0, 0);
+    updateGetModelProgress(name, 0, 0);
     return await blob.arrayBuffer();
   } catch (e) {
     return await updateFile();
   }
 }

-async function readResponse(response, progressCallback) {
+async function readResponse(name, response, progressCallback) {
   const contentLength = response.headers.get('Content-Length');
   let total = parseInt(contentLength ?? '0');
   let buffer = new Uint8Array(total);
@@ -65,7 +65,7 @@ async function readResponse(response, progressCallback) {
     loaded = newLoaded;

     if (progressCallback) {
-      progressCallback(loaded, total);
+      progressCallback(name, loaded, total);
     }

     return read();
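For context on the util.js changes: readResponse streams the response body chunk by chunk so updateGetModelProgress can render incremental download progress, and this commit threads the model name through that callback chain. A simplified, self-contained sketch of the streaming pattern (the real readResponse also handles a missing Content-Length by growing the buffer, omitted here):

    // Simplified sketch of streaming a fetch response with progress reporting.
    // Assumes the server sends a Content-Length header.
    async function readWithProgress(name, response, progressCallback) {
        const total = parseInt(response.headers.get('Content-Length') ?? '0');
        const buffer = new Uint8Array(total);
        const reader = response.body.getReader();
        let loaded = 0;
        while (true) {
            const { done, value } = await reader.read();
            if (done) break;
            buffer.set(value, loaded);  // append the chunk at the current offset
            loaded += value.length;
            if (progressCallback) progressCallback(name, loaded, total);
        }
        return buffer;
    }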