import { Template } from "@huggingface/jinja";
import {
  type DownloadModelConfig,
  type SamplingConfig,
  Wllama,
  type WllamaConfig,
} from "@wllama/wllama/esm";
import type CacheManager from "@wllama/wllama/esm/cache-manager";
import type {
  CacheEntry,
  CacheEntryMetadata,
} from "@wllama/wllama/esm/cache-manager";
import multiThreadWllamaJsUrl from "@wllama/wllama/esm/multi-thread/wllama.js?url";
import multiThreadWllamaWasmUrl from "@wllama/wllama/esm/multi-thread/wllama.wasm?url";
import multiThreadWllamaWorkerMjsUrl from "@wllama/wllama/esm/multi-thread/wllama.worker.mjs?url";
import singleThreadWllamaJsUrl from "@wllama/wllama/esm/single-thread/wllama.js?url";
import singleThreadWllamaWasmUrl from "@wllama/wllama/esm/single-thread/wllama.wasm?url";
import { addLogEntry } from "./logEntries";
import { defaultSettings } from "./settings";
import { getSystemPrompt } from "./systemPrompt";

/**
 * Creates a Wllama instance wired to the bundled single-/multi-thread WASM
 * assets, replaces its cache manager with a Cache Storage–backed
 * {@link CustomCacheManager}, and loads the model.
 *
 * @param modelUrl - Model URL, or an array of URLs, forwarded as-is to
 *   `wllama.loadModelFromUrl`.
 * @param config - Optional Wllama runtime options (`wllama`) and model
 *   download/load options (`model`).
 * @returns The Wllama instance with the model loaded.
 */
export async function initializeWllama(
  modelUrl: string | string[],
  config?: {
    wllama?: WllamaConfig;
    model?: DownloadModelConfig;
  },
): Promise<Wllama> {
  addLogEntry("Initializing Wllama");

  const wllama = new Wllama(
    {
      "single-thread/wllama.js": singleThreadWllamaJsUrl,
      "single-thread/wllama.wasm": singleThreadWllamaWasmUrl,
      "multi-thread/wllama.js": multiThreadWllamaJsUrl,
      "multi-thread/wllama.wasm": multiThreadWllamaWasmUrl,
      "multi-thread/wllama.worker.mjs": multiThreadWllamaWorkerMjsUrl,
    },
    config?.wllama,
  );

  // Store downloaded model files in our own Cache Storage bucket.
  wllama.cacheManager = new CustomCacheManager("wllama-cache");

  await wllama.loadModelFromUrl(modelUrl, config?.model);

  addLogEntry("Wllama initialized successfully");

  return wllama;
}

/** Static description of a model selectable for in-browser inference. */
export interface WllamaModel {
  /** Human-readable name shown to the user. */
  label: string;
  /** Model URL, or an array of URLs, passed to `loadModelFromUrl`. */
  url: string | string[];
  cacheType: "f16" | "q8_0" | "q4_0";
  contextSize: number;
  fileSizeInMegabytes: number;
  sampling: SamplingConfig;
  shouldIncludeUrlsOnPrompt: boolean;
  /** Builds the full prompt text from the user query and search results. */
  buildPrompt: (
    wllama: Wllama,
    query: string,
    searchResults: string,
  ) => Promise<string>;
  stopStrings?: string[];
  stopTokens?: number[];
}

/**
 * Settings shared by every model in {@link wllamaModels}; each entry spreads
 * this object and then overrides individual fields.
 *
 * NOTE(review): the spread is shallow, so all models share the same
 * `sampling` object instance — mutate a copy, never the shared object.
 */
const defaultModelConfig: Omit<
  WllamaModel,
  "label" | "url" | "fileSizeInMegabytes"
> = {
  buildPrompt: async (wllama, query, searchResults) => {
    return formatChat(wllama, [
      { id: 1, role: "user", content: getSystemPrompt(searchResults) },
      { id: 2, role: "assistant", content: "Ok!" },
      { id: 3, role: "user", content: query },
    ]);
  },
  cacheType: "f16",
  contextSize: 2048,
  shouldIncludeUrlsOnPrompt: true,
  sampling: {
    top_p: defaultSettings.inferenceTopP,
    temp: defaultSettings.inferenceTemperature,
    penalty_freq: defaultSettings.inferenceFrequencyPenalty,
    penalty_present: defaultSettings.inferencePresencePenalty,
    penalty_repeat: defaultSettings.inferenceRepeatPenalty,
    // NOTE(review): `sampling_seq` is presumably llama.cpp's sampler order
    // (p = top_p, t = temperature, x = XTC, d = DRY) — confirm upstream.
    // TypeScript reports only the first excess property of an object
    // literal, so one expect-error directive covers the three fields below.
    // @ts-expect-error Wllama still doesn't have the following properties defined, although they are supported by the llama.cpp.
    xtc_probability: 0.5,
    dry_multiplier: 0.8,
    sampling_seq: "ptxd",
  },
};

/**
 * Selectable models, keyed by model id.
 *
 * NOTE(review): every `url` in this revision is an empty string, so loading
 * any of these models will presumably fail until real URLs are filled in.
 */
export const wllamaModels: Record<string, WllamaModel> = {
  "smollm2-135m": {
    ...defaultModelConfig,
    label: "SmolLM 2 135M",
    url: "",
    fileSizeInMegabytes: 145,
  },
  "smollm2-360m": {
    ...defaultModelConfig,
    label: "SmolLM 2 360M",
    url: "",
    fileSizeInMegabytes: 386,
  },
  "qwen-2.5-0.5b": {
    ...defaultModelConfig,
    label: "Qwen 2.5 0.5B",
    url: "",
    fileSizeInMegabytes: 531,
  },
  "danube-3-500m": {
    ...defaultModelConfig,
    label: "Danube 3 500M",
    url: "",
    fileSizeInMegabytes: 547,
  },
  "amd-olmo-1b": {
    ...defaultModelConfig,
    label: "AMD OLMo 1B",
    url: "",
    fileSizeInMegabytes: 872,
  },
  "granite-3.0-1b": {
    ...defaultModelConfig,
    label: "Granite 3.0 1B [400M]",
    url: "",
    fileSizeInMegabytes: 969,
    buildPrompt: async (_, query, searchResults) =>
      buildGranitePrompt(query, searchResults),
  },
  "llama-3.2-1b": {
    ...defaultModelConfig,
    label: "Llama 3.2 1B",
    url: "",
    fileSizeInMegabytes: 975,
  },
  "pythia-1.4b": {
    ...defaultModelConfig,
    label: "Pythia 1.4B",
    url: "",
    fileSizeInMegabytes: 1060,
  },
  "pints-1.5b": {
    ...defaultModelConfig,
    label: "Pints 1.5B",
    url: "",
    fileSizeInMegabytes: 1150,
  },
  "smollm2-1.7b": {
    ...defaultModelConfig,
    label: "SmolLM 2 1.7B",
    url: "",
    fileSizeInMegabytes: 1230,
  },
  "arcee-lite": {
    ...defaultModelConfig,
    label: "Arcee Lite 1.5B",
    url: "",
    fileSizeInMegabytes: 1430,
  },
  "danube2-1.8b": {
    ...defaultModelConfig,
    label: "Danube 2 1.8B",
    url: "",
    fileSizeInMegabytes: 1300,
  },
  "granite-3.0-2b": {
    ...defaultModelConfig,
    label: "Granite 3.0 2B",
    url: "",
    fileSizeInMegabytes: 1870,
    buildPrompt: async (_, query, searchResults) =>
      buildGranitePrompt(query, searchResults),
  },
  "gemma-2-2b": {
    ...defaultModelConfig,
    label: "Gemma 2 2B",
    url: "",
    fileSizeInMegabytes: 1920,
  },
  "llama-3.2-3b": {
    ...defaultModelConfig,
    label: "Llama 3.2 3B",
    url: "",
    fileSizeInMegabytes: 2420,
  },
  "granite-3.0-3b": {
    ...defaultModelConfig,
    label: "Granite 3.0 3B [800M]",
    url: "",
    fileSizeInMegabytes: 2450,
    buildPrompt: async (_, query, searchResults) =>
      buildGranitePrompt(query, searchResults),
  },
  "minicpm3-4b": {
    ...defaultModelConfig,
    label: "MiniCPM 3 4B",
    url: "",
    fileSizeInMegabytes: 2470,
    contextSize: 1920,
  },
  "phi-3.5-mini-3.8b": {
    ...defaultModelConfig,
    label: "Phi 3.5 Mini 3.8B",
    url: "",
    fileSizeInMegabytes: 2820,
  },
  "magpielm-4b": {
    ...defaultModelConfig,
    label: "MagpieLM 4B",
    url: "",
    fileSizeInMegabytes: 3230,
  },
  "nemotron-mini-4b": {
    ...defaultModelConfig,
    label: "Nemotron Mini 4B",
    url: "",
    fileSizeInMegabytes: 3550,
  },
  "olmoe-1b-7b": {
    ...defaultModelConfig,
    label: "OLMoE 7B [1B]",
    url: "",
    fileSizeInMegabytes: 3700,
  },
};

/**
 * Builds a Granite-style prompt by hand (system turn, user turn, then an open
 * assistant turn), one role block per line, instead of going through the
 * model's embedded chat template.
 */
function buildGranitePrompt(query: string, searchResults: string): string {
  return [
    `<|start_of_role|>system<|end_of_role|>${getSystemPrompt(searchResults)}<|end_of_text|>`,
    `<|start_of_role|>user<|end_of_role|>${query}<|end_of_text|>`,
    "<|start_of_role|>assistant<|end_of_role|>",
  ].join("\n");
}

/** A single chat turn passed to {@link formatChat}. */
export interface Message {
  id: number;
  content: string;
  role: "system" | "user" | "assistant";
}

/**
 * Renders `messages` into a prompt string using the model's embedded Jinja
 * chat template, falling back to a ChatML template when the model has none.
 * The BOS/EOS token text is detokenized from the loaded model and exposed to
 * the template as `bos_token` / `eos_token`; a generation prompt is appended.
 */
export const formatChat = async (
  wllama: Wllama,
  messages: Message[],
): Promise<string> => {
  const defaultChatTemplate =
    "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

  const template = new Template(
    wllama.getChatTemplate() ?? defaultChatTemplate,
  );

  const textDecoder = new TextDecoder();

  return template.render({
    messages,
    bos_token: textDecoder.decode(await wllama.detokenize([wllama.getBOS()])),
    eos_token: textDecoder.decode(await wllama.detokenize([wllama.getEOS()])),
    add_generation_prompt: true,
  });
};

/**
 * Wllama cache manager backed by the browser Cache Storage API.
 *
 * Each file is stored as a `Response` under a key derived from its URL; the
 * Wllama metadata for the entry is kept as JSON in the `X-Metadata` header.
 */
class CustomCacheManager implements CacheManager {
  private readonly cacheName: string;

  constructor(cacheName: string) {
    this.cacheName = cacheName;
  }

  /** Opens (creating it if needed) the Cache Storage bucket used by this manager. */
  private openCache(): Promise<Cache> {
    return caches.open(this.cacheName);
  }

  /** Parses the JSON stored in the `X-Metadata` header, or `null` if absent. */
  private static readMetadata(response: Response): CacheEntryMetadata | null {
    const metadata = response.headers.get("X-Metadata");
    return metadata ? (JSON.parse(metadata) as CacheEntryMetadata) : null;
  }

  /**
   * Size of a cached body in bytes: the `Content-Length` header when it is
   * present and non-zero, otherwise the size of the fully read body (which
   * loads the whole body into memory).
   */
  private static async readSize(response: Response): Promise<number> {
    return (
      Number(response.headers.get("Content-Length")) ||
      (await response.blob()).size
    );
  }

  /**
   * Derives the cache key for `url`: the hex SHA-1 of the full URL, an
   * underscore, then the last `/`-separated segment (or `"default"`).
   */
  async getNameFromURL(url: string): Promise<string> {
    const encoder = new TextEncoder();
    const data = encoder.encode(url);
    const hashBuffer = await crypto.subtle.digest("SHA-1", data);
    const hashArray = Array.from(new Uint8Array(hashBuffer));
    const hashHex = hashArray
      .map((b) => b.toString(16).padStart(2, "0"))
      .join("");
    const fileName = url.split("/").pop() || "default";
    return `${hashHex}_${fileName}`;
  }

  /** Stores `stream` under `name`, attaching `metadata` as a JSON header. */
  async write(
    name: string,
    stream: ReadableStream,
    metadata: CacheEntryMetadata,
  ): Promise<void> {
    const cache = await this.openCache();
    const response = new Response(stream, {
      headers: { "X-Metadata": JSON.stringify(metadata) },
    });
    await cache.put(name, response);
  }

  /** Returns the cached body stream for `name`, or `null` if not cached. */
  async open(name: string): Promise<ReadableStream | null> {
    const cache = await this.openCache();
    const response = await cache.match(name);
    return response?.body ?? null;
  }

  /** Returns the cached size of `name` in bytes, or `-1` if not cached. */
  async getSize(name: string): Promise<number> {
    const cache = await this.openCache();
    const response = await cache.match(name);
    if (!response) return -1;
    return CustomCacheManager.readSize(response);
  }

  /** Returns the stored metadata for `name`, or `null` if absent. */
  async getMetadata(name: string): Promise<CacheEntryMetadata | null> {
    const cache = await this.openCache();
    const response = await cache.match(name);
    if (!response) return null;
    return CustomCacheManager.readMetadata(response);
  }

  /**
   * Lists every cached entry. Cache Storage keys are `Request`s, so each
   * `name` is the request URL (resolved to an absolute URL by the Cache API).
   * Entries without metadata get an empty placeholder.
   *
   * @throws Error if a key has no matching response (e.g. it was evicted
   *   between `keys()` and `match()`).
   */
  async list(): Promise<CacheEntry[]> {
    const cache = await this.openCache();
    const requests = await cache.keys();
    return Promise.all(
      requests.map(async (request) => {
        // One lookup per entry; headers and size are read from the same response.
        const response = await cache.match(request);
        if (!response) throw new Error(`No response for ${request.url}`);
        const metadata = CustomCacheManager.readMetadata(response);
        const size = await CustomCacheManager.readSize(response);
        return {
          name: request.url,
          size,
          metadata: metadata ?? { etag: "", originalSize: 0, originalURL: "" },
        };
      }),
    );
  }

  /** Deletes the whole Cache Storage bucket. */
  async clear(): Promise<void> {
    await caches.delete(this.cacheName);
  }

  /**
   * Deletes the entry stored under `nameOrURL`.
   *
   * @throws Error if no entry was deleted.
   */
  async delete(nameOrURL: string): Promise<void> {
    const cache = await this.openCache();
    const success = await cache.delete(nameOrURL);
    if (!success) {
      throw new Error(`Failed to delete cache entry for ${nameOrURL}`);
    }
  }

  /** Deletes every listed entry for which `predicate` returns true. */
  async deleteMany(predicate: (e: CacheEntry) => boolean): Promise<void> {
    const entries = await this.list();
    const cache = await this.openCache();
    const deletionPromises = entries
      .filter(predicate)
      .map((entry) => cache.delete(entry.name));
    await Promise.all(deletionPromises);
  }

  /**
   * Replaces the metadata of an existing entry, keeping its body and all of
   * its other headers.
   *
   * @throws Error if `name` is not cached.
   */
  async writeMetadata(
    name: string,
    metadata: CacheEntryMetadata,
  ): Promise<void> {
    const cache = await this.openCache();
    const response = await cache.match(name);
    if (!response) {
      throw new Error(`Cache entry for ${name} not found`);
    }
    // Spreading a Headers object copies nothing (its entries are not own
    // enumerable properties), so clone it explicitly to keep existing headers.
    const headers = new Headers(response.headers);
    headers.set("X-Metadata", JSON.stringify(metadata));
    const newResponse = new Response(response.body, { headers });
    await cache.put(name, newResponse);
  }
}