// demos/llm-inference/index.js
// Copyright 2024 The MediaPipe Authors.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------------------- //
import {FilesetResolver, LlmInference} from 'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai';
const input = document.getElementById('input');
const output = document.getElementById('output');
const submit = document.getElementById('submit');
const status = document.getElementById('status');
// Name of the model file. It is fetched relative to this page and cached in
// the Origin Private File System. Update the file name to match your model.
const modelFileName = 'gemma-2b-it-gpu-int4.bin';
let startTime;
/**
 * Displays newly generated partial results in the output text box. Once
 * generation is complete, re-enables the submit button and reports simple
 * throughput statistics.
 */
function displayPartialResults(partialResults, complete) {
  output.textContent += partialResults;

  if (complete) {
    if (!output.textContent) {
      output.textContent = 'Result is empty';
    }
    submit.disabled = false;
    // Split on any whitespace so newlines are counted as word boundaries.
    const wordCount = output.textContent.split(/\s+/).length;
    // Math.round() ignores a second argument; use toFixed() to format the
    // values to two decimal places instead.
    const seconds = (performance.now() - startTime) / 1000;
    const wordsPerSecond = wordCount / seconds;
    status.innerHTML = `${wordCount} words in ${seconds.toFixed(2)} seconds, ${wordsPerSecond.toFixed(2)} words per second`;
  }
}
/**
 * Loads the model via the Origin Private File System (OPFS), which requires
 * a secure context (HTTPS or localhost). The model is downloaded from `url`
 * on first use (or whenever `updateModel` is true) and cached in OPFS so
 * later visits can read it locally instead of re-downloading it.
 */
async function getModelOPFS(name, url, updateModel) {
  const root = await navigator.storage.getDirectory();
  let fileHandle;

  async function updateFile() {
    const response = await fetch(url);
    const buffer = await readResponse(response);
    fileHandle = await root.getFileHandle(name, {create: true});
    const writable = await fileHandle.createWritable();
    await writable.write(buffer);
    await writable.close();
    return buffer;
  }

  if (updateModel) {
    return await updateFile();
  }

  try {
    // Serve the cached copy if it exists.
    fileHandle = await root.getFileHandle(name);
    const blob = await fileHandle.getFile();
    return await blob.arrayBuffer();
  } catch (e) {
    // Cache miss: download the model and cache it for next time.
    return await updateFile();
  }
}
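// Usage sketch (hypothetical call, not part of the demo flow): pass `true`
// as the third argument to force a re-download and overwrite the cached
// copy, e.g. after replacing the model file on the server:
//   const buffer = await getModelOPFS(modelFileName, modelFileName, true);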
/**
 * Reads the full response body into a Uint8Array, growing the buffer if the
 * server did not send a Content-Length header (or sent one that is too small).
 */
async function readResponse(response) {
  const contentLength = response.headers.get('Content-Length');
  let total = parseInt(contentLength ?? '0', 10);
  let buffer = new Uint8Array(total);
  let loaded = 0;

  const reader = response.body.getReader();
  async function read() {
    const {done, value} = await reader.read();
    if (done) return;
    const newLoaded = loaded + value.length;
    if (newLoaded > total) {
      // The chunk overruns the preallocated buffer; grow it to fit.
      total = newLoaded;
      const newBuffer = new Uint8Array(total);
      newBuffer.set(buffer);
      buffer = newBuffer;
    }
    buffer.set(value, loaded);
    loaded = newLoaded;
    return read();
  }

  await read();
  return buffer;
}
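// A minimal sketch of how download progress could be surfaced (an assumed
// extension, not part of the original demo): `loaded` and `total` inside
// read() already carry the needed numbers, so a status update such as the
// following could be issued after each chunk when Content-Length is known:
//   status.innerHTML = `Downloading model: ${Math.round(100 * loaded / total)}%`;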
/**
 * Main entry point: loads the WASM fileset and the model, then wires up the
 * submit button to run LLM inference on the prompt.
 */
async function runDemo() {
  const genaiFileset = await FilesetResolver.forGenAiTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm');
  let llmInference;

  // The model is fetched from the page's own directory (the file name doubles
  // as a relative URL) and cached in OPFS.
  const modelBuffer =
      new Int8Array(await getModelOPFS(modelFileName, modelFileName, false));

  submit.onclick = () => {
    startTime = performance.now();
    output.textContent = '';
    status.innerHTML = '';
    submit.disabled = true;
    llmInference.generateResponse(input.value, displayPartialResults);
  };

  // Keep the button disabled until the model finishes loading, so a click
  // cannot reach generateResponse() while llmInference is still undefined.
  submit.disabled = true;
  submit.value = 'Loading the model...';
  LlmInference
      .createFromModelBuffer(genaiFileset, modelBuffer)
      .then(llm => {
        llmInference = llm;
        submit.disabled = false;
        submit.value = 'Get Response';
      })
      .catch((e) => {
        console.error(e);
        alert('Failed to initialize the task.');
      });
}

runDemo();
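// For reference, this script assumes a companion HTML page exposing the four
// elements queried above. A minimal sketch (assumed markup, not necessarily
// the demo's actual page):
//   <textarea id="input"></textarea>
//   <pre id="output"></pre>
//   <input type="button" id="submit" value="Get Response" disabled>
//   <div id="status"></div>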