Yang Gu
Format the model folder names
1322fe0
raw
history blame
10.4 kB
import { env, pipeline } from './transformers/transformers.js';
// Inline SVG markup for the clipboard icon shown on each response's copy button.
const clipboardIcon = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-clipboard" viewBox="0 0 16 16">
<path d="M4 1.5H3a2 2 0 0 0-2 2V14a2 2 0 0 0 2 2h10a2 2 0 0 0 2-2V3.5a2 2 0 0 0-2-2h-1v1h1a1 1 0 0 1 1 1V14a1 1 0 0 1-1 1H3a1 1 0 0 1-1-1V3.5a1 1 0 0 1 1-1h1v-1z"/>
<path d="M9.5 1a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-3a.5.5 0 0 1-.5-.5v-1a.5.5 0 0 1 .5-.5h3zm-3-1A1.5 1.5 0 0 0 5 1.5v1A1.5 1.5 0 0 0 6.5 4h3A1.5 1.5 0 0 0 11 2.5v-1A1.5 1.5 0 0 0 9.5 0h-3z"/>
</svg>`;
// Global options for the `marked` markdown renderer used on model output:
// turn off email-address mangling and automatic heading ids.
// NOTE(review): presumably set explicitly to silence marked's deprecation
// warnings for these defaults — confirm against the bundled marked version.
const markedOptions = { mangle: false, headerIds: false };
marked.use(markedOptions);
// Writes a message to the console and appends it as a new line of the #status element.
function log(i) {
  console.log(i);
  const statusElement = document.getElementById('status');
  statusElement.innerText += `\n${i}`;
}
// The send button; submitRequest relabels it "Stop" while a reply is generating.
const sendButton = document.querySelector('#send-button');
// Pads the bottom of #scroll-wrapper by the current height of #input-area plus
// 15px, so the last message is not hidden behind the input box.
function adjustPadding() {
  const { offsetHeight } = document.getElementById('input-area');
  const wrapperStyle = document.getElementById('scroll-wrapper').style;
  wrapperStyle.paddingBottom = `${offsetHeight + 15}px`;
}
// Re-run adjustPadding() every time #input-area changes size.
const autoResizePadding = new ResizeObserver(() => adjustPadding());
autoResizePadding.observe(document.getElementById('input-area'));
// Auto-scroll state shared by every request. One observer and one on/off flag
// suffice, so they live at module level instead of being recreated per submit.
const scrollWrapper = document.getElementById('scroll-wrapper');
let isAutoScrollOn = true;
// When an observed element resizes (e.g. a response grows), smoothly scroll the
// bottom of the wrapper into view, unless auto-scroll is currently switched off.
const autoScroller = new ResizeObserver(() => {
  if (!isAutoScrollOn) {
    return;
  }
  scrollWrapper.scrollIntoView({ behavior: "smooth", block: "end" });
});
// Toggles auto-scroll from the user's scrolling: scrolling up turns it off, and
// scrolling down to within 30px of the page bottom turns it back on. `ticking`
// allows only one pending flag change (applied on the next animation frame).
let lastKnownScrollPosition = 0;
let ticking = false;
document.addEventListener("scroll", () => {
  const currentY = window.scrollY;
  // Apply the new auto-scroll state on the next frame; block further changes until then.
  const scheduleAutoScroll = (enabled) => {
    window.requestAnimationFrame(() => {
      isAutoScrollOn = enabled;
      ticking = false;
    });
    ticking = true;
  };
  if (!ticking) {
    if (isAutoScrollOn) {
      if (currentY < lastKnownScrollPosition) {
        scheduleAutoScroll(false);
      }
    } else if (currentY > lastKnownScrollPosition) {
      // Only read layout when actually scrolling down with auto-scroll off.
      const nearBottomY = document.documentElement.scrollHeight - window.innerHeight - 30;
      if (currentY >= nearBottomY) {
        scheduleAutoScroll(true);
      }
    }
  }
  lastKnownScrollPosition = window.scrollY;
});
/**
 * Makes a click copy the text of `responseDiv` to the clipboard.
 *
 * @param {HTMLElement} responseDiv - message element whose text is copied. Its
 *   `hidden_text` property is used when truthy, otherwise its `innerText`; both
 *   are read at click time.
 * @param {boolean} [with_button] - when truthy, a clipboard-icon button is
 *   created, given the click handler and appended to `responseDiv`; otherwise
 *   `responseDiv` itself receives the click handler.
 */
function copyTextToClipboard(responseDiv, with_button) {
  let clickTarget = responseDiv;
  if (with_button) {
    const copyButton = document.createElement('button');
    copyButton.className = 'btn btn-secondary copy-button';
    copyButton.innerHTML = clipboardIcon;
    clickTarget = copyButton;
  }
  clickTarget.onclick = async () => {
    const text = responseDiv.hidden_text || responseDiv.innerText;
    try {
      // navigator.clipboard is undefined outside secure contexts; awaiting inside
      // the try turns that TypeError into a logged failure instead of an
      // uncaught exception in the click handler.
      await navigator.clipboard.writeText(text);
      console.log('Text copied to clipboard');
    } catch (err) {
      console.error('Failed to copy text:', err);
    }
  };
  if (with_button) {
    responseDiv.appendChild(clickTarget);
  }
}
// AbortController for the generation currently in flight; undefined while idle.
let interrupt;

/**
 * Handles a submit from the send button or a keyboard shortcut.
 *
 * - While a reply is generating (button reads "Stop"), aborts it and returns.
 * - With an empty input box, clears the chat history and its stored context.
 * - Otherwise appends the user's message and a response bubble with a spinner,
 *   streams the reply from Query() into the bubble as sanitized markdown, and
 *   stores the reply's raw text as the context prefixed to the next prompt.
 *
 * Generation errors are logged to the console (a user abort is not). The send
 * button label, spinner and `interrupt` are reset in every case.
 */
async function submitRequest() {
  if (sendButton.innerHTML === "Stop" && interrupt) {
    console.log("Stop");
    interrupt.abort('Stop button pressed');
    return;
  }
  document.getElementById('chat-container').style.display = 'block';
  const userInput = document.getElementById('user-input');
  const input = userInput.value;
  const chatHistory = document.getElementById('chat-history');
  if (input.length === 0) {
    chatHistory.context = "";
    while (chatHistory.firstChild) {
      chatHistory.firstChild.remove();
    }
    return;
  }
  const context = chatHistory.context ?? "";

  // Create user message element and append to chat history
  const userMessageDiv = document.createElement('div');
  userMessageDiv.className = 'mb-2 user-message';
  userMessageDiv.innerText = input;
  chatHistory.appendChild(userMessageDiv);
  copyTextToClipboard(userMessageDiv);

  // Create response container; the spinner shows until the first text arrives
  const responseDiv = document.createElement('div');
  responseDiv.className = 'response-message mb-2 text-start';
  responseDiv.style.minHeight = '3em'; // make sure div does not shrink if we cancel the request when no text has been generated yet
  const spinner = document.createElement('div');
  spinner.className = 'spinner-border text-light';
  spinner.setAttribute('role', 'status');
  responseDiv.appendChild(spinner);
  chatHistory.appendChild(responseDiv);

  // The send button doubles as the stop button while generating
  sendButton.innerHTML = "Stop";
  const controller = new AbortController();
  interrupt = controller;
  // change autoScroller to keep track of our new responseDiv
  autoScroller.observe(responseDiv);
  // Clear user input right away rather than after generation finishes
  userInput.value = '';

  let responseText = "";
  try {
    await Query(context + " " + input, (text) => {
      // Query takes no abort signal, so Stop takes effect by throwing from the
      // streaming callback at the next generation step.
      // NOTE(review): assumes the pipeline lets a callback exception propagate
      // and releases its per-run state afterwards — confirm.
      controller.signal.throwIfAborted();
      responseText = text;
      responseDiv.innerHTML = DOMPurify.sanitize(marked.parse(text));
    });
    // Keep the raw reply as context, not the rendered HTML (which would be the
    // spinner markup if no text was generated).
    chatHistory.context = responseText;
    copyTextToClipboard(responseDiv, true);
  } catch (error) {
    if (!controller.signal.aborted) {
      console.error(error);
    }
  } finally {
    sendButton.innerHTML = "Send";
    spinner.remove();
    interrupt = undefined;
  }
}
// Canned demo prompts, filled in and submitted with Ctrl+1 … Ctrl+4.
const preCannedQueries = {
  "1": "Tell me about the lighthouse of Alexandria",
  "2": "Did the lighthouse of Alexandria existed at the same time the library of Alexandria existed?",
  "3": "How did the Pharos lighthouse impact ancient maritime trade?",
  "4": "Tell me about Constantinople?",
};
// Keyboard shortcuts on the input box: Ctrl+Enter or Cmd+Enter submits the
// current text; Ctrl+<key> with a canned query for that key fills it in and submits.
document.getElementById('user-input').addEventListener('keydown', (e) => {
  if (e.key === 'Enter' && (e.ctrlKey || e.metaKey)) {
    submitRequest();
    return;
  }
  // hasOwn so inherited Object.prototype names can never match a key.
  if (!e.ctrlKey || !Object.hasOwn(preCannedQueries, e.key)) {
    return;
  }
  document.getElementById('user-input').value = preCannedQueries[e.key];
  submitRequest();
});
/**
 * Returns the part of `text` after the first occurrence of `marker`.
 *
 * The chat-template prompt ends in an "<|assistant|>"-style tag, so with the
 * default marker this keeps only the model's reply. If the marker is absent
 * (e.g. the phi2 path uses the raw query as its prompt), `text` is returned
 * unchanged rather than losing its leading characters.
 *
 * @param {string} text - decoded generation output.
 * @param {string} [marker='assistant|>'] - delimiter that precedes the reply.
 * @returns {string}
 */
function cleanup_text(text, marker = 'assistant|>') {
  const markerIndex = text.indexOf(marker);
  if (markerIndex === -1) {
    return text;
  }
  return text.slice(markerIndex + marker.length);
}
/**
 * Builds the app configuration from URL query parameters layered over defaults.
 *
 * Only keys present in the defaults are accepted. Each value is URI-decoded and
 * converted to the type of its default: booleans are true only for "true", and
 * numbers must parse as finite numbers. Empty keys (e.g. from "&&") are ignored.
 * Afterwards, a model name containing "phi2" forces `isPhi2 = true` and
 * `layers = 32`, and `needsExternalData` is always set from whether the model
 * name contains "-fp16".
 *
 * @param {string} [search=window.location.search] - query string, with or
 *   without a leading "?".
 * @returns {{model: string, provider: string, isPhi2: boolean,
 *   needsExternalData: boolean, layers: number}}
 * @throws {Error} on an unknown parameter or a non-numeric value for a numeric one.
 * @throws {URIError} on a malformed percent-encoded value.
 */
function getConfig(search = window.location.search) {
  const config = {
    model: "schmuell/TinyLlama-1.1B-Chat-v1.0-int4",
    provider: "webgpu",
    isPhi2: false,
    needsExternalData: false,
    layers: 22,
  };
  const query = search.startsWith("?") ? search.slice(1) : search;
  for (const pair of query.split("&")) {
    const eqIndex = pair.indexOf("=");
    const key = eqIndex === -1 ? pair : pair.slice(0, eqIndex);
    if (key.length === 0) {
      continue;
    }
    // hasOwn (not `in`) so names like "toString" are rejected, not assigned.
    if (!Object.hasOwn(config, key)) {
      throw new Error(`unknown argument: ${key}`);
    }
    const raw = eqIndex === -1 ? "" : decodeURIComponent(pair.slice(eqIndex + 1));
    const defaultValue = config[key];
    if (typeof defaultValue === "boolean") {
      // Without this, "?isPhi2=false" would store the truthy string "false".
      config[key] = raw === "true";
    } else if (typeof defaultValue === "number") {
      const num = Number(raw);
      if (raw.trim() === "" || !Number.isFinite(num)) {
        throw new Error(`invalid number for ${key}: ${raw}`);
      }
      config[key] = num;
    } else {
      config[key] = raw;
    }
  }
  if (config.model.includes("phi2")) {
    config.isPhi2 = true;
    config.layers = 32;
  }
  config.needsExternalData = config.model.includes("-fp16");
  return config;
}
// Configuration parsed once from the page URL at module load. getConfig() throws
// on an unknown query parameter, which stops evaluation of the rest of this module.
const config = getConfig();
// The transformers.js text-generation pipeline; assigned by LoadModel() and used by Query().
let pipe;
/**
 * Runs the loaded pipeline on `query` and streams the reply through `cb`.
 *
 * phi2 models receive the raw query as the prompt; all other models get the
 * tokenizer's chat template with a fixed system message. Generation samples
 * with temperature 0.7 and top-k 15, up to 256 new tokens, and the elapsed time
 * is logged to the console.
 *
 * @param {string} query - the full user turn to answer.
 * @param {(text: string) => void} cb - called on each generation step with the
 *   decoded output so far, passed through cleanup_text(). It is not wrapped
 *   here; an exception it throws presumably propagates out of pipe() — TODO confirm.
 * @returns {Promise<void>} resolves when generation finishes; output is
 *   delivered only through `cb`.
 * @throws {Error} if called before LoadModel() has created the pipeline.
 */
async function Query(query, cb) {
  // The Ctrl+Enter shortcut is registered at module load, before LoadModel()
  // finishes, so fail with a clear message instead of a TypeError on `pipe`.
  if (!pipe) {
    throw new Error('Model is not loaded yet');
  }
  const messages = [
    { role: "system", content: "You are a friendly assistant." },
    { role: "user", content: query },
  ];
  const prompt = config.isPhi2
    ? query
    : pipe.tokenizer.apply_chat_template(messages, {
      tokenize: false,
      add_generation_prompt: true,
    });

  const start = performance.now();
  await pipe(prompt, {
    max_new_tokens: 256,
    temperature: 0.7,
    do_sample: true,
    top_k: 15,
    // Decode the first beam's tokens each step so the UI can render partial output.
    callback_function: (beams) => {
      const decodedText = pipe.tokenizer.decode(beams[0].output_token_ids, { skip_special_tokens: true });
      cb(cleanup_text(decodedText));
    },
  });
  const stop = performance.now();
  console.log(`took ${((stop - start) / 1000).toFixed(1)}sec`);
}
/**
 * Creates the global text-generation pipeline `pipe` for `config.model`.
 *
 * Configures the ONNX WASM backend (single thread, binaries from
 * 'transformers/') and allows remote models. Session options:
 * - execution provider is `config.provider`;
 * - quantized weights are requested only for non-phi2 models on "wasm";
 * - for non-phi2 models on "webgpu", every layer's `present.N.key/value`
 *   output is placed in a GPU buffer;
 * - models flagged `needsExternalData` get their external weight file mapped.
 * Progress and elapsed load time are written via log().
 */
async function LoadModel() {
  env.backends.onnx.wasm.numThreads = 1;
  env.allowRemoteModels = true;
  env.backends.onnx.wasm.wasmPaths = 'transformers/';

  const options = {
    // phi2 always loads unquantized weights; other models are quantized only on wasm.
    quantized: !config.isPhi2 && config.provider === "wasm",
    session_options: {
      executionProviders: [config.provider],
      preferredOutputLocation: {},
    },
  };
  // Keep the per-layer present.* outputs (presumably the KV cache) on the GPU;
  // skipped for phi2, which uses a slightly different setup.
  if (!config.isPhi2 && config.provider === "webgpu") {
    const outputLocations = options.session_options.preferredOutputLocation;
    for (let i = 0; i < config.layers; ++i) {
      outputLocations[`present.${i}.key`] = 'gpu-buffer';
      outputLocations[`present.${i}.value`] = 'gpu-buffer';
    }
  }
  if (config.needsExternalData) {
    options.session_options.externalData = [
      {
        data: 'onnx/decoder_model_merged.onnx.data',
        path: 'decoder_model_merged.onnx.data',
      },
    ];
  }

  const start = performance.now();
  log("Loading model ... ");
  pipe = await pipeline('text-generation', config.model, options);
  log(`done, ${((performance.now() - start) / 1000).toFixed(1)}sec`);
}
/**
 * Resolves to whether a WebGPU adapter with the 'shader-f16' feature is available.
 * Never rejects: a missing navigator.gpu, a null adapter, or a failed adapter
 * request all resolve to false.
 */
async function hasFp16() {
  let supported;
  try {
    const gpuAdapter = await navigator.gpu.requestAdapter();
    supported = gpuAdapter.features.has('shader-f16');
  } catch (_err) {
    supported = false;
  }
  return supported;
}
// Page entry point: requires WebGPU with shader-f16, loads the model, then
// enables the send button, hides the status area and focuses the input box.
window.onload = async () => {
  const fp16 = await hasFp16(); // hasFp16 never rejects
  if (!fp16) {
    log("Your GPU or Browser doesn't support webgpu/f16");
    return;
  }
  try {
    await LoadModel();
  } catch (error) {
    // Without this the status would stay on "Loading model ..." and the failure
    // would surface only as an unhandled promise rejection.
    console.error(error);
    log(`Failed to load model: ${error?.message ?? error}`);
    return;
  }
  adjustPadding();
  sendButton.addEventListener('click', submitRequest);
  document.getElementById("status").style.display = "none";
  document.getElementById('user-input').focus();
};