<html>
<head>
  <title>Example</title>
</head>
<body>
  <!-- <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.webgpu.min.js"></script> -->
  <script src="https://wp-27.sh.intel.com/workspace/project/onnxruntime/js/web/dist/ort.webgpu.min.js"></script>
  <script type="module">
    import { AutoTokenizer, env } from '../../transformers/transformers.js';

    function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }
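
    // Registry of runnable models: maps the short name given in the query string
    // to a Hugging Face repo (or local folder) path; "externaldata" marks models
    // whose weights ship in a separate .onnx.data file.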
    const MODELS = {
      "tinyllama": { name: "tinyllama", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-int4" },
      "tinyllama_fp16": { name: "tinyllama-fp16", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-fp16", externaldata: true },
      "phi2": { name: "phi2", path: "phi2-int4" },
      "phi2-mb": { name: "phi2-mb", path: "schmuell/phi2-mb", externaldata: true },
      "stablelm": { name: "stablelm", path: "schmuell/stablelm-2-zephyr-1_6b-int4" },
    }
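
    // Parse configuration from the URL query string, e.g. ?model=tinyllama&max_tokens=128.
    // Keys with numeric defaults are parsed as integers; unknown keys raise an error.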
    function getConfig() {
      const query = window.location.search.substring(1);
      var config = {
        model: "phi2",
        provider: "webgpu",
        profiler: 0,
        verbose: 0,
        threads: 1,
        trace: 0,
        csv: 0,
        max_tokens: 256,
        local: 1,
      }
      let vars = query.split("&");
      for (var i = 0; i < vars.length; i++) {
        let pair = vars[i].split("=");
        if (pair[0] in config) {
          const key = pair[0];
          const value = decodeURIComponent(pair[1]);
          if (typeof config[key] == "number") {
            config[key] = parseInt(value);
          } else {
            config[key] = value;
          }
        } else if (pair[0].length > 0) {
          throw new Error("unknown argument: " + pair[0]);
        }
      }
      if (MODELS[config.model] !== undefined) {
        config.model = MODELS[config.model];
      }
      return config;
    }
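
    // Minimal wrapper around an ONNX Runtime Web inference session for a
    // decoder-only transformer: owns the KV-cache feed and runs greedy decoding.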
    class LLM {
      sess = undefined;
      profiler = false;
      trace = false;
      feed = {};
      output_tokens = [];
      eos = 2;
      need_position_ids = true;
      stop = false;
      kv_dims = [];
      dtype = "float16";

      constructor() {
      }

      async load(model, options) {
        const provider = options.provider || "webgpu";
        const verbose = options.verbose;
        const local = options.local;
        this.profiler = options.profiler;
        this.trace = options.trace;

        const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main";
        log(`loading... ${model.name}, ${provider}`);

        const json_bytes = await fetchAndCache(model_path + "/config.json");
        let textDecoder = new TextDecoder();
        const model_config = JSON.parse(textDecoder.decode(json_bytes));

        // Assumes every repo stores the decoder at onnx/decoder_model_merged.onnx,
        // matching the external-data path used below.
        const model_bytes = await fetchAndCache(model_path + "/onnx/decoder_model_merged.onnx");
        const externaldata = (model.externaldata) ? await fetchAndCache(model_path + '/onnx/decoder_model_merged.onnx.data') : false;
        let modelSize = model_bytes.byteLength;
        if (externaldata) {
          modelSize += externaldata.byteLength;
        }
        log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);

        const opt = {
          executionProviders: [provider],
          preferredOutputLocation: {},
        };

        switch (provider) {
          case "webgpu":
            if (!("gpu" in navigator)) {
              throw new Error("webgpu is NOT supported");
            }
            // Keep the KV-cache outputs on the GPU so they can be fed back into
            // the next run without a round trip through CPU memory.
            for (let i = 0; i < model_config.num_hidden_layers; ++i) {
              opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
              opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
            }
            break;
          case "webnn":
            if (!("ml" in navigator)) {
              throw new Error("webnn is NOT supported");
            }
            break;
        }

        // externaldata is an ArrayBuffer when present, false otherwise.
        if (externaldata) {
          opt.externalData = [
            {
              data: externaldata,
              path: 'decoder_model_merged.onnx.data'
            },
          ]
        }

        if (verbose) {
          opt.logSeverityLevel = 0;
          opt.logVerbosityLevel = 0;
          ort.env.logLevel = "verbose";
          ort.env.debug = true;
        }

        ort.env.webgpu.profiling = {};
        if (this.profiler) {
          opt.enableProfiling = true;
          ort.env.webgpu.profilingMode = 'default';
          ort.env.webgpu.profiling.mode = 'default';
        }

        this.sess = await ort.InferenceSession.create(model_bytes, opt);

        if (this.trace) {
          ort.env.trace = true;
          ort.env.webgpu.profiling.ondata = (version, inputsMetadata, outputsMetadata, kernelId, kernelType,
            kernelName, programName, startTime, endTime) => { };
        }

        this.eos = model_config.eos_token_id;
        this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads];
        this.dtype = model.dtype || "float16";
        this.num_layers = model_config.num_hidden_layers;
        this.initialize_feed();
      }

      // Reset the feed to empty KV-cache tensors for every layer.
      initialize_feed() {
        this.feed = {};
        const empty = (this.dtype === "float16") ? new Uint16Array() : [];
        for (let i = 0; i < this.num_layers; ++i) {
          this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims);
          this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims);
        }
        this.output_tokens = [];
      }
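
      // Greedy sampling: return the index of the largest logit at the last
      // sequence position; logits are laid out as [batch, seqlen, vocab].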
      argmax(t) {
        const arr = t.data;
        const start = t.dims[2] * (t.dims[1] - 1);
        let max = arr[start];
        let maxidx = 0;
        for (let i = 0; i < t.dims[2]; i++) {
          const val = arr[i + start];
          if (!isFinite(val)) {
            throw new Error("found non-finite value in logits");
          }
          if (val > max) {
            max = val;
            maxidx = i;
          }
        }
        return maxidx;
      }
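
      // Move this step's "present.*" outputs into the next step's
      // "past_key_values.*" inputs, disposing the GPU buffers they replace.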
      update_kv_cache(feed, outputs) {
        for (const name in outputs) {
          if (name.startsWith('present')) {
            let newName = name.replace('present', 'past_key_values');
            // free old gpu buffer
            const t = feed[newName];
            if (t.location === 'gpu-buffer') {
              t.dispose();
            }
            feed[newName] = outputs[name];
          }
        }
      }

      abort() {
        this.stop = true;
      }
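
      // Greedy-decode up to max_tokens tokens, invoking callback with the running
      // token list after each step. With keep_cache the existing KV cache is kept
      // and only the new tokens are fed.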
      async generate(tokens, callback, options) {
        const keep_cache = options.keep_cache;
        const max_tokens = options.max_tokens || 256;
        const feed = this.feed;
        const input_ids = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
        feed['input_ids'] = input_ids;
        this.stop = false;

        if (keep_cache) {
          this.output_tokens.push(...input_ids.data);
        } else {
          this.initialize_feed();
          this.output_tokens = Array.from(feed['input_ids'].data);
        }

        let last_token = 0n;
        let seqlen = this.output_tokens.length;
        if (this.need_position_ids) {
          if (keep_cache) {
            // Only the new tokens are fed, so their positions start right after
            // the tokens already held in the cache.
            feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: tokens.length }, (_, i) => BigInt(seqlen - tokens.length + i)), [1, tokens.length]);
          } else {
            feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, (_, i) => BigInt(i)), [1, seqlen]);
          }
        }

        while (last_token != this.eos && seqlen < max_tokens && !this.stop) {
          seqlen = this.output_tokens.length;
          feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]);
          let outputs;
          if (this.trace) {
            console.timeStamp("RUN-BEGIN");
            outputs = await this.sess.run(feed);
            console.timeStamp("RUN-END");
          } else {
            outputs = await this.sess.run(feed);
          }
          last_token = BigInt(this.argmax(outputs.logits));
          this.output_tokens.push(last_token);
          if (callback && !this.profiler) {
            callback(this.output_tokens);
          }
          this.update_kv_cache(feed, outputs);
          feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]);
          if (this.need_position_ids) {
            feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]);
          }
        }

        if (this.profiler) {
          this.sess.endProfiling();
        }
        return this.output_tokens;
      }
    }
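
    // Page setup: parse the query string and configure the ort and
    // transformers.js environments before anything is loaded.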
    const config = getConfig();
    env.localModelPath = 'models';
    env.allowRemoteModels = config.local == 0;
    env.allowLocalModels = config.local == 1;
    ort.env.wasm.numThreads = config.threads;
    ort.env.wasm.simd = true;

    const cons_log = [];
    // With profiler=2, capture console output so it can be offered as a download.
    if (config.profiler === 2) {
      console.log = function (message) {
        if (!message.includes('_fence_')) {
          cons_log.push(message);
        }
      };
    }

    const tokenizer = await AutoTokenizer.from_pretrained(config.model.path);
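
    // Turn the captured profiler log into a downloadable data: URL link.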
    function create_download_link(cons_log) {
      if (cons_log.length > 0) {
        let link = document.getElementById('download').childNodes[0];
        if (link === undefined) {
          link = document.createElement("a");
          link.download = "profiler.log";
          link.innerText = "Download";
          document.getElementById('download').appendChild(link);
        }
        const base64 = btoa(cons_log.join('\n'));
        link.href = `data:application/json;base64,${base64}`;
      }
    }
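
    // Fetch a URL through the Cache Storage API so model files are only
    // downloaded once; fall back to a plain fetch if caching is unavailable.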
    async function fetchAndCache(url) {
      try {
        const cache = await caches.open("onnx");
        let cachedResponse = await cache.match(url);
        if (cachedResponse == undefined) {
          await cache.add(url);
          cachedResponse = await cache.match(url);
          log(`${url} (network)`);
        } else {
          log(`${url} (cached)`);
        }
        const data = await cachedResponse.arrayBuffer();
        return data;
      } catch (error) {
        log(`${url} (network)`);
        return await fetch(url).then(response => response.arrayBuffer());
      }
    }
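
    // Decode generated token ids back to text, skipping the prompt tokens.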
    function token_to_text(tokenizer, tokens, startidx) {
      const txt = tokenizer.decode(tokens.slice(startidx), { skip_special_tokens: true, });
      return txt;
    }

    const llm = new LLM();
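
    // Load the model, tokenize a fixed prompt, stream the generation into the
    // page, and report tokens/sec when done.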
    async function main() {
      const model = config.model;
      await llm.load(model, {
        provider: config.provider,
        verbose: config.verbose,
        profiler: config.profiler,
        trace: config.trace,
        local: config.local,
      });
      document.getElementById('status').innerText = "";

      const query = "Tell me about Constantinople.";
      let prompt;
      if (model.name.includes('phi2')) {
        prompt = `User:${query}\nAssistant:`;
      } else {
        prompt = `<|system|>\nYou are a friendly assistant.</s>\n<|user|>\n${query}</s>\n<|assistant|>\n`;
      }

      const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true });
      const start_timer = performance.now();
      const output_tokens = await llm.generate(input_ids, (output_tokens) => {
        document.getElementById('result').innerText = token_to_text(tokenizer, output_tokens, input_ids.length);
      }, {});
      const took = (performance.now() - start_timer) / 1000;

      const txt = token_to_text(tokenizer, output_tokens, input_ids.length);
      const seqlen = output_tokens.length;
      document.getElementById('result').innerText = txt;
      const perf = `${seqlen} tokens in ${took.toFixed(1)}sec, ${(seqlen / took).toFixed(2)} tokens/sec`;
      console.log(perf + " @@1");
      document.getElementById('perf').innerText = perf;
      if (config.csv) {
        log(`${model.name},${took.toFixed(2)},${(seqlen / took).toFixed(3)},${seqlen},@@2`);
      }
    }
    try {
      await main();
    } catch (error) {
      console.error(error);
      document.getElementById('result').innerText = error.message;
    } finally {
      create_download_link(cons_log);
    }
  </script>

  <div id="status"></div>
  <br />
  <div id="result"></div>
  <br />
  <div id="perf"></div>
  <br />
  <div id="download"></div>
  <br />
</body>
</html>