Spaces:
Running
Running
// Copyright 2024 The MediaPipe Authors.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------------------- //
import {FilesetResolver, LlmInference} from 'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai'; | |
// Page elements the demo reads from and writes to, looked up by id.
const [input, output, submit, status] = ['input', 'output', 'submit', 'status'].map(
  (id) => document.getElementById(id),
);

// Model file to load. runDemo() passes it as both the OPFS cache name and the
// (page-relative) download URL. Update the file name here to switch models.
const modelFileName = 'gemma-2b-it-gpu-int4.bin';

// performance.now() timestamp of the latest submission; used to report
// generation speed once the response completes.
let startTime;
/**
 * Display newly generated partial results to the output text box.
 *
 * Used as the progress callback of LlmInference.generateResponse(). When
 * `complete` is true it also re-enables the submit button and writes a
 * word-count / throughput summary to the status element.
 *
 * @param {string} partialResults - Newly generated text to append.
 * @param {boolean} complete - Whether this is the final chunk of the response.
 */
function displayPartialResults(partialResults, complete) {
  output.textContent += partialResults;
  if (!complete) {
    return;
  }
  // Count words before the placeholder below can replace an empty result,
  // so an empty response reports 0 words rather than 3.
  const wordCount = countWords(output.textContent);
  if (!output.textContent) {
    output.textContent = 'Result is empty';
  }
  submit.disabled = false;
  status.textContent = formatGenerationStats(wordCount, performance.now() - startTime);
}

/**
 * Count whitespace-separated words (spaces, tabs, newlines) in a string.
 *
 * @param {string} text - Text to count.
 * @returns {number} Number of words; 0 for empty or whitespace-only text.
 */
function countWords(text) {
  const trimmed = text.trim();
  return trimmed === '' ? 0 : trimmed.split(/\s+/).length;
}

/**
 * Format the throughput summary shown after a response completes.
 *
 * @param {number} wordCount - Number of words generated.
 * @param {number} elapsedMs - Generation time in milliseconds.
 * @returns {string} e.g. "10 words in 4.00 seconds, 2.50 words per second".
 */
function formatGenerationStats(wordCount, elapsedMs) {
  const seconds = elapsedMs / 1000;
  // Math.round() ignores any extra argument, so toFixed(2) is used to show two
  // decimals. Guard the division so a near-instant run never reports Infinity.
  const wordsPerSecond = seconds > 0 ? wordCount / seconds : 0;
  return `${wordCount} words in ${seconds.toFixed(2)} seconds, ${wordsPerSecond.toFixed(2)} words per second`;
}
// Get model via Origin Private File System
/**
 * Load a model file, caching its bytes in the Origin Private File System (OPFS).
 *
 * On a cache hit the stored bytes are returned without touching the network.
 * On a miss (entry absent, unreadable, or zero bytes), or when `updateModel`
 * is true, the model is downloaded from `url`, written to OPFS under `name`,
 * and returned.
 *
 * @param {string} name - File name used for the OPFS cache entry.
 * @param {string} url - URL to download the model from on a cache miss.
 * @param {boolean} updateModel - When true, always re-download and overwrite the cache.
 * @returns {Promise<ArrayBuffer|Uint8Array>} The model bytes: an ArrayBuffer when
 *     read from the cache, a Uint8Array when freshly downloaded.
 * @throws {Error} If the download returns a non-2xx HTTP status, or if
 *     fetching or writing the file fails.
 */
async function getModelOPFS(name, url, updateModel) {
  const root = await navigator.storage.getDirectory();

  // Download the model, persist it to OPFS, and return the downloaded bytes.
  async function updateFile() {
    const response = await fetch(url);
    // fetch() resolves (not rejects) on HTTP errors; without this check an
    // error page would be cached as the model and reused on every later load.
    if (!response.ok) {
      throw new Error(`Failed to download model from ${url}: HTTP ${response.status}`);
    }
    const buffer = await readResponse(response);
    const fileHandle = await root.getFileHandle(name, {create: true});
    const writable = await fileHandle.createWritable();
    try {
      await writable.write(buffer);
      await writable.close();
    } catch (err) {
      // Discard the pending write so a partially written model is not
      // committed. abort() may itself reject if the stream is already errored;
      // the original error is the one worth reporting.
      await writable.abort().catch(() => {});
      throw err;
    }
    return buffer;
  }

  if (updateModel) {
    return updateFile();
  }

  try {
    const fileHandle = await root.getFileHandle(name);
    const file = await fileHandle.getFile();
    // The entry is created before its contents are written, so an interrupted
    // download/write can leave a zero-byte file behind; treat that as a miss.
    if (file.size > 0) {
      return await file.arrayBuffer();
    }
  } catch (err) {
    // A missing entry (NotFoundError) is the normal first-run case. Any other
    // failure is logged, and we still fall back to downloading (best effort).
    if (err?.name !== 'NotFoundError') {
      console.warn(`Could not read cached model "${name}"; downloading it again.`, err);
    }
  }
  return updateFile();
}
/**
 * Read an entire fetch Response body into a single Uint8Array.
 *
 * The Content-Length header is only used as an initial capacity hint: it may
 * be missing, malformed, or differ from the decoded body size. When more data
 * arrives than expected the buffer grows geometrically, and the result is
 * trimmed to exactly the number of bytes read.
 *
 * @param {Response} response - A fetch Response whose body has not been consumed.
 * @returns {Promise<Uint8Array>} The body bytes (empty if the response has no body).
 */
async function readResponse(response) {
  if (!response.body) {
    return new Uint8Array(0);
  }
  const declared = Number.parseInt(response.headers.get('Content-Length') ?? '0', 10);
  let buffer = new Uint8Array(Number.isFinite(declared) && declared > 0 ? declared : 0);
  let loaded = 0;

  const reader = response.body.getReader();
  let chunk = await reader.read();
  while (!chunk.done) {
    const {value} = chunk;
    const needed = loaded + value.length;
    if (needed > buffer.length) {
      // Grow geometrically: resizing to the exact size on every chunk copies
      // the whole buffer each time, which is quadratic for multi-GB models.
      const grown = new Uint8Array(Math.max(needed, buffer.length * 2));
      grown.set(buffer.subarray(0, loaded));
      buffer = grown;
    }
    buffer.set(value, loaded);
    loaded = needed;
    chunk = await reader.read();
  }
  // Drop spare capacity (from growth or an overstated Content-Length).
  return loaded === buffer.length ? buffer : buffer.slice(0, loaded);
}
/**
 * Main function to run LLM Inference.
 *
 * Loads the GenAI WASM fileset and the model bytes (via getModelOPFS) in
 * parallel, creates the LlmInference task, and only then wires the submit
 * button to stream responses into the output box. Initialization failures are
 * logged and reported with an alert; generation failures are logged, shown in
 * the status element, and re-enable the submit button.
 */
async function runDemo() {
  // Show the loading state before the (potentially very large) model download.
  submit.value = 'Loading the model...';

  let llmInference;
  try {
    const [genaiFileset, modelData] = await Promise.all([
      FilesetResolver.forGenAiTasks('https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm'),
      getModelOPFS(modelFileName, modelFileName, false),
    ]);
    llmInference = await LlmInference.createFromModelBuffer(genaiFileset, new Int8Array(modelData));
  } catch (err) {
    console.error(err);
    alert('Failed to initialize the task.');
    return;
  }

  submit.onclick = async () => {
    startTime = performance.now();
    output.textContent = '';
    status.textContent = '';
    submit.disabled = true;
    try {
      await llmInference.generateResponse(input.value, displayPartialResults);
    } catch (err) {
      // Without this the button would stay disabled after a failed generation.
      console.error(err);
      status.textContent = 'Failed to generate a response.';
      submit.disabled = false;
    }
  };
  submit.disabled = false;
  submit.value = 'Get Response';
}
// Start the demo. runDemo() is async; attach a handler so any startup failure
// not handled inside it is reported instead of becoming an unhandled promise
// rejection.
runDemo().catch((err) => {
  console.error('Failed to start the LLM Inference demo:', err);
});