diff --git a/packages/local-llm/src/engine-impl.ts b/packages/local-llm/src/engine-impl.ts
new file mode 100644
index 000000000..e5066f60a
--- /dev/null
+++ b/packages/local-llm/src/engine-impl.ts
@@ -0,0 +1,329 @@
+/**
+ * LocalLLMEngineImpl — the actual transformers.js engine.
+ *
+ * This file is intentionally NOT exported from the package's index.ts.
+ * It's loaded only inside the Web Worker (worker.ts), where it owns
+ * the model + processor + WebGPU device + tensor work. The main thread
+ * never instantiates it directly — instead it talks to a thin
+ * `LocalLLMEngine` proxy in engine.ts that postMessages over to the
+ * worker.
+ *
+ * Why the split: model.generate() with a 2B-parameter LLM does heavy
+ * synchronous tensor work that blocks the JS thread for 50-200 ms per
+ * forward pass. With ~150 forward passes per generation, the main
+ * thread would freeze for ~10-30 seconds during a single chat reply.
+ * Web Workers run on their own thread, so the main UI stays responsive
+ * for scrolling, clicks, and animations while inference is happening.
+ *
+ * The implementation otherwise mirrors the previous in-thread engine —
+ * same two-step tokenization, same aggregated progress, same
+ * streaming-or-fallback token counting, same convention quirks for
+ * transformers.js v4. The one addition is a load token that discards a
+ * load superseded by a later load()/unload(). See the comment on each
+ * method for the reasoning behind each non-obvious decision.
+ */
+
+import type { GenerateOptions, GenerateResult, LoadingStatus, ModelConfig } from './types';
+import { MODELS, DEFAULT_MODEL, type ModelKey } from './models';
+
+type TransformersModule = typeof import('@huggingface/transformers');
+
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+type AnyModel = any;
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+type AnyProcessor = any;
+
+export class LocalLLMEngineImpl {
+  private model: AnyModel = null;
+  private processor: AnyProcessor = null;
+  private transformers: TransformersModule | null = null;
+  private loadPromise: Promise<void> | null = null;
+  private currentModel: ModelKey | null = null;
+  // Bumped by every _load() and unload(). A load whose token no longer
+  // matches has been superseded and must not publish its results.
+  private loadToken = 0;
+  private _status: LoadingStatus = { state: 'idle' };
+  private statusListeners: Set<(status: LoadingStatus) => void> = new Set();
+
+  /** Most recent status passed to setStatus(). */
+  get status(): LoadingStatus {
+    return this._status;
+  }
+
+  /** True while the current status is 'ready'. */
+  get isReady(): boolean {
+    return this._status.state === 'ready';
+  }
+
+  /** Config of the currently selected model, or null when none is selected. */
+  get modelConfig(): ModelConfig | null {
+    return this.currentModel ? MODELS[this.currentModel] : null;
+  }
+
+  /** Subscribe to status changes; the returned function unsubscribes. */
+  onStatusChange(listener: (status: LoadingStatus) => void): () => void {
+    this.statusListeners.add(listener);
+    return () => this.statusListeners.delete(listener);
+  }
+
+  /** Record the new status and notify every registered listener. */
+  private setStatus(status: LoadingStatus) {
+    this._status = status;
+    for (const listener of this.statusListeners) {
+      listener(status);
+    }
+  }
+
+  /** Check if WebGPU is available (`navigator.gpu` is present). */
+  static isSupported(): boolean {
+    return typeof navigator !== 'undefined' && 'gpu' in navigator;
+  }
+
+  /**
+   * Load a model. Returns immediately if it is already loaded, and
+   * joins the in-flight promise if the same model is already loading.
+   *
+   * If another model is loaded — or still loading — it is unloaded
+   * first. unload() bumps the load token, so the older load cannot
+   * overwrite the newer one when it eventually finishes.
+   */
+  async load(model: ModelKey = DEFAULT_MODEL): Promise<void> {
+    if (this.model && this.currentModel === model) return;
+    if (this.loadPromise && this.currentModel === model) return this.loadPromise;
+    if (this.model || this.loadPromise) {
+      await this.unload();
+    }
+    this.currentModel = model;
+    this.loadPromise = this._load(model);
+    return this.loadPromise;
+  }
+
+  private async _load(model: ModelKey): Promise<void> {
+    if (!LocalLLMEngineImpl.isSupported()) {
+      this.setStatus({ state: 'error', error: 'WebGPU not supported in this browser' });
+      throw new Error('WebGPU not supported');
+    }
+
+    // Claim a token for this attempt. If unload() or a newer _load()
+    // bumps the counter while we are awaiting, this attempt is stale.
+    const token = ++this.loadToken;
+    const isStale = () => token !== this.loadToken;
+    const assertCurrent = () => {
+      if (isStale()) throw new Error('Model load superseded by a newer load() or unload()');
+    };
+
+    this.setStatus({ state: 'checking' });
+
+    try {
+      if (!this.transformers) {
+        this.transformers = await import('@huggingface/transformers');
+        assertCurrent();
+      }
+      const config = MODELS[model];
+
+      // Aggregated per-file progress reporting. transformers.js fires
+      // the callback per file and downloads several files in parallel,
+      // so reporting only the latest event would make the bar bounce.
+      // Instead we keep per-file byte counts and emit the summed total.
+      const fileProgress = new Map<string, { loaded: number; total: number }>();
+
+      const formatBytes = (bytes: number): string => {
+        if (bytes < 1024) return `${bytes} B`;
+        if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)} KB`;
+        if (bytes < 1024 * 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(0)} MB`;
+        return `${(bytes / (1024 * 1024 * 1024)).toFixed(2)} GB`;
+      };
+
+      const emitAggregate = () => {
+        // A superseded load must not overwrite its successor's status.
+        if (isStale()) return;
+        let totalLoaded = 0;
+        let totalSize = 0;
+        for (const { loaded, total } of fileProgress.values()) {
+          totalLoaded += loaded;
+          totalSize += total;
+        }
+        const pct = totalSize > 0 ? totalLoaded / totalSize : 0;
+        this.setStatus({
+          state: 'downloading',
+          progress: pct,
+          text:
+            totalSize > 0
+              ? `Downloading model (${(pct * 100).toFixed(0)}%, ${formatBytes(totalLoaded)} / ${formatBytes(totalSize)}, ${fileProgress.size} files)`
+              : `Downloading model (${fileProgress.size} files queued)`,
+        });
+      };
+
+      const progressCallback = (report: {
+        status: string;
+        file?: string;
+        name?: string;
+        progress?: number;
+        loaded?: number;
+        total?: number;
+      }) => {
+        const file = report.file ?? report.name ?? '_unknown';
+        if (report.status === 'initiate') {
+          if (!fileProgress.has(file)) fileProgress.set(file, { loaded: 0, total: 0 });
+          emitAggregate();
+        } else if (report.status === 'download' || report.status === 'progress') {
+          fileProgress.set(file, {
+            loaded: report.loaded ?? 0,
+            total: report.total ?? fileProgress.get(file)?.total ?? 0,
+          });
+          emitAggregate();
+        } else if (report.status === 'done') {
+          // Pin the file to 100% so the final emit shows it complete.
+          const existing = fileProgress.get(file);
+          if (existing && existing.total > 0) {
+            fileProgress.set(file, { loaded: existing.total, total: existing.total });
+          }
+          emitAggregate();
+        }
+      };
+
+      const { AutoProcessor, Gemma4ForConditionalGeneration } = this.transformers as unknown as {
+        AutoProcessor: { from_pretrained(id: string, opts?: unknown): Promise<AnyProcessor> };
+        Gemma4ForConditionalGeneration: {
+          from_pretrained(id: string, opts?: unknown): Promise<AnyModel>;
+        };
+      };
+
+      const processor = await AutoProcessor.from_pretrained(config.modelId, {
+        progress_callback: progressCallback,
+      });
+      assertCurrent();
+      const loadedModel = await Gemma4ForConditionalGeneration.from_pretrained(config.modelId, {
+        dtype: config.dtype,
+        device: 'webgpu',
+        progress_callback: progressCallback,
+      });
+      assertCurrent();
+
+      // Publish only once both pieces are ready and still wanted.
+      this.processor = processor;
+      this.model = loadedModel;
+      this.setStatus({ state: 'ready' });
+    } catch (err) {
+      // A stale attempt leaves status and loadPromise to its successor.
+      if (isStale()) throw err;
+      const message = err instanceof Error ? err.message : String(err);
+      this.setStatus({ state: 'error', error: message });
+      this.loadPromise = null;
+      throw err;
+    }
+  }
+
+  /**
+   * Drop the model and processor references and return to 'idle'.
+   * Bumping the load token also invalidates any load still in flight.
+   */
+  async unload(): Promise<void> {
+    this.loadToken++;
+    this.model = null;
+    this.processor = null;
+    this.currentModel = null;
+    this.loadPromise = null;
+    this.setStatus({ state: 'idle' });
+  }
+
+  /**
+   * Generate a response. Auto-loads the model if not yet loaded.
+   *
+   * A TextStreamer is always attached: the streamed text is what gets
+   * returned as `content`, and the tensor returned by model.generate()
+   * is only consulted (when present) for the completion token count.
+   */
+  async generate(options: GenerateOptions): Promise<GenerateResult> {
+    if (!this.model || !this.processor) {
+      // Finish the load that was already requested, if any, instead of
+      // unconditionally switching to DEFAULT_MODEL.
+      await this.load(this.currentModel ?? DEFAULT_MODEL);
+    }
+
+    const start = performance.now();
+
+    // Two-step input prep: render the chat template to a plain string
+    // (tokenize: false), then tokenize that string with
+    // return_tensors: 'pt' to get the input tensors.
+    const promptText: string = this.processor.apply_chat_template(options.messages, {
+      add_generation_prompt: true,
+      tokenize: false,
+    });
+
+    const inputs = this.processor.tokenizer(promptText, {
+      return_tensors: 'pt',
+    });
+
+    const promptTokenCount = this.tensorLength(inputs?.input_ids);
+
+    let collectedText = '';
+    const transformers = this.transformers as TransformersModule;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const TextStreamer = (transformers as any).TextStreamer;
+    const streamer = new TextStreamer(this.processor.tokenizer, {
+      skip_prompt: true,
+      skip_special_tokens: true,
+      callback_function: (text: string) => {
+        collectedText += text;
+        options.onToken?.(text);
+      },
+    });
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    let generated: any = null;
+    try {
+      generated = await this.model.generate({
+        ...inputs,
+        max_new_tokens: options.maxTokens ?? 1024,
+        temperature: options.temperature ?? 0.7,
+        do_sample: (options.temperature ?? 0.7) > 0,
+        streamer,
+      });
+    } catch (err) {
+      // Deliberate best-effort: if text was already streamed, return it
+      // rather than failing the whole call. Re-throw only when empty.
+      if (!collectedText) throw err;
+    }
+
+    // Prefer the tensor for the token count; fall back to a rough
+    // chars/4 estimate when generate() returned nothing usable.
+    let completionTokenCount = 0;
+    try {
+      if (generated && generated.dims) {
+        const fullSequence = this.tensorRow(generated, 0);
+        completionTokenCount = Math.max(0, fullSequence.length - promptTokenCount);
+      }
+    } catch {
+      // fall through to estimate
+    }
+    if (completionTokenCount === 0 && collectedText) {
+      completionTokenCount = Math.ceil(collectedText.length / 4);
+    }
+
+    return {
+      content: collectedText,
+      usage: {
+        prompt_tokens: promptTokenCount,
+        completion_tokens: completionTokenCount,
+        total_tokens: promptTokenCount + completionTokenCount,
+      },
+      latencyMs: Math.round(performance.now() - start),
+    };
+  }
+
+  /** Size of the tensor's last dimension, or 0 if there is no tensor. */
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  private tensorLength(tensor: any): number {
+    if (!tensor || !tensor.dims) return 0;
+    return tensor.dims[tensor.dims.length - 1];
+  }
+
+  /** Copy row `row` of the tensor's data, using the last dim as row length. */
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  private tensorRow(tensor: any, row: number): number[] {
+    const seqLen = tensor.dims[tensor.dims.length - 1];
+    const start = row * seqLen;
+    return Array.from(tensor.data.slice(start, start + seqLen)) as number[];
+  }
+}
diff --git a/packages/local-llm/src/engine.ts b/packages/local-llm/src/engine.ts
index a11d90aea..0785df377 100644
--- a/packages/local-llm/src/engine.ts
+++ b/packages/local-llm/src/engine.ts
@@ -1,35 +1,57 @@
 /**
- * LocalLLMEngine — transformers.js wrapper for client-side inference.
+ * LocalLLMEngine — main-thread proxy for the worker-hosted engine.
  *
- * Lazy-loads a HuggingFace ONNX model on first use, caches weights in the
- * browser's Cache API, and runs inference on the WebGPU backend.
+ * Public API is intentionally identical to the previous in-thread
+ * implementation so existing callers (the /llm-test page, the
+ * playground module, @mana/shared-llm's BrowserBackend, the Svelte 5
+ * reactive bindings in svelte.svelte.ts) need no changes.
+ * Internally every call now goes through a Web Worker — see worker.ts
+ * for the other side of the protocol.
  *
- * The default model is Google's Gemma 4 E2B (`onnx-community/gemma-4-E2B-it-ONNX`,
- * q4f16). The external API of this class is intentionally identical to the
- * previous WebLLM implementation so callers (Svelte stores, /llm-test page,
- * playground module) need no changes when the underlying engine swaps.
+ * Why a worker: a 2B-parameter LLM does heavy synchronous tensor work
+ * for ~50-200 ms per forward pass. With ~150 forward passes per
+ * generation, the main thread would freeze for ~10-30 seconds during
+ * a single chat reply. Web Workers run on their own thread, so the
+ * main UI stays responsive throughout.
+ *
+ * The proxy is constructed lazily — the Worker is only instantiated
+ * on first method call. This matters for SSR: importing this module
+ * during a server render must NOT touch the Worker constructor (which
+ * doesn't exist in Node), and lazy construction is the cleanest way
+ * to keep import-time side effects to zero.
  */
 
-import type { ChatMessage, GenerateOptions, GenerateResult, LoadingStatus } from './types';
-import type { ModelConfig } from './types';
 import { MODELS, DEFAULT_MODEL, type ModelKey } from './models';
+import type {
+  ChatMessage,
+  GenerateOptions,
+  GenerateResult,
+  LoadingStatus,
+  ModelConfig,
+} from './types';
+import type { SerializableGenerateOptions, WorkerRequest, WorkerResponse } from './worker';
 
-// transformers.js types are minimal here on purpose. The library does not
-// publish first-class TS types for every model class, and we never expose
-// these objects past this file — the public surface (LocalLLMEngine methods)
-// is fully typed via our own GenerateResult / LoadingStatus etc.
-type TransformersModule = typeof import('@huggingface/transformers');
+/** Tracking entry for an in-flight worker request. */
+interface PendingRequest {
+  resolve: (data: unknown) => void;
+  reject: (err: Error) => void;
+  onToken?: (token: string) => void;
+}
 
-// eslint-disable-next-line @typescript-eslint/no-explicit-any
-type AnyModel = any;
-// eslint-disable-next-line @typescript-eslint/no-explicit-any
-type AnyProcessor = any;
+/**
+ * Distributive Omit — preserves the discriminated union when stripping
+ * a key. Plain `Omit<Union, K>` collapses to an intersection in many
+ * TS versions and loses the type narrowing on `req.type`. This helper
+ * distributes the Omit across each member of the union so postRequest
+ * still type-checks at the call sites.
+ */
+type DistributiveOmit<T, K extends PropertyKey> = T extends unknown ? Omit<T, K> : never;
+type WorkerRequestPayload = DistributiveOmit<WorkerRequest, 'id'>;
 
 export class LocalLLMEngine {
-  private model: AnyModel = null;
-  private processor: AnyProcessor = null;
-  private transformers: TransformersModule | null = null;
-  private loadPromise: Promise<void> | null = null;
+  private worker: Worker | null = null;
+  private pending = new Map<string, PendingRequest>();
+  private nextId = 0;
   private currentModel: ModelKey | null = null;
   private _status: LoadingStatus = { state: 'idle' };
   private statusListeners: Set<(status: LoadingStatus) => void> = new Set();
@@ -46,9 +68,6 @@ export class LocalLLMEngine {
     return this.currentModel ? MODELS[this.currentModel] : null;
   }
 
-  /**
-   * Subscribe to status changes (for non-Svelte usage).
-   */
   onStatusChange(listener: (status: LoadingStatus) => void): () => void {
     this.statusListeners.add(listener);
     return () => this.statusListeners.delete(listener);
@@ -61,297 +80,148 @@ export class LocalLLMEngine {
     }
   }
 
-  /**
-   * Check if WebGPU is available in this browser.
-   */
+  /** Check if WebGPU is available. Synchronous and SSR-safe. */
   static isSupported(): boolean {
     return typeof navigator !== 'undefined' && 'gpu' in navigator;
   }
 
-  /**
-   * Load a model. Idempotent — returns immediately if already loaded.
-   * Model weights are cached in the browser Cache API for instant reload.
-   */
-  async load(model: ModelKey = DEFAULT_MODEL): Promise<void> {
-    // Already loaded with this model
-    if (this.model && this.currentModel === model) return;
+  // ─── Worker management ──────────────────────────────────
 
-    // Already loading
-    if (this.loadPromise && this.currentModel === model) return this.loadPromise;
+  /** Lazily create the worker and wire up its message/error listeners. */
+  private getWorker(): Worker {
+    if (this.worker) return this.worker;
 
-    // Unload previous model if switching
-    if (this.model && this.currentModel !== model) {
-      await this.unload();
+    if (typeof Worker === 'undefined') {
+      throw new Error('@mana/local-llm requires a browser environment (Worker is not defined)');
     }
 
-    this.currentModel = model;
-    this.loadPromise = this._load(model);
-    return this.loadPromise;
-  }
+    // `new URL('./worker.ts', import.meta.url)` is Vite's documented
+    // pattern for declaring a Web Worker entry. Vite picks this up at
+    // build time, splits worker.ts (and its transformers.js dep) into
+    // its own chunk, and rewrites the URL to the chunk's hashed path.
+    // Outside Vite (raw esbuild, plain Node, etc.) this would fail —
+    // but the only consumer of this package is the SvelteKit web app
+    // where Vite handles the bundling.
+    this.worker = new Worker(new URL('./worker.ts', import.meta.url), {
+      type: 'module',
+      name: 'mana-local-llm',
+    });
 
-  private async _load(model: ModelKey): Promise<void> {
-    if (!LocalLLMEngine.isSupported()) {
-      this.setStatus({ state: 'error', error: 'WebGPU not supported in this browser' });
-      throw new Error('WebGPU not supported');
-    }
-
-    this.setStatus({ state: 'checking' });
-
-    try {
-      if (!this.transformers) {
-        this.transformers = await import('@huggingface/transformers');
-      }
-      const config = MODELS[model];
-
-      // transformers.js progress callback shape:
-      //   { status: 'initiate'|'download'|'progress'|'done'|'ready',
-      //     name?: string, file?: string, progress?: number,
-      //     loaded?: number, total?: number }
-      //
-      // The callback fires per-file, and the library downloads many
-      // shards in parallel (config.json, tokenizer.json, several
-      // onnx weight files, …). If we naively report the latest event
-      // the bar bounces wildly between files. Instead we keep a
-      // per-file byte-accounting map and emit an aggregated total
-      // every time anything moves. The denominator can grow as new
-      // files are discovered (causing brief dips), but both
-      // numerator and denominator are individually monotonic, so the
-      // dips are small and brief — much smoother than per-file.
-      const fileProgress = new Map<string, { loaded: number; total: number }>();
-
-      const formatBytes = (bytes: number): string => {
-        if (bytes < 1024) return `${bytes} B`;
-        if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)} KB`;
-        if (bytes < 1024 * 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(0)} MB`;
-        return `${(bytes / (1024 * 1024 * 1024)).toFixed(2)} GB`;
-      };
-
-      const emitAggregate = () => {
-        let totalLoaded = 0;
-        let totalSize = 0;
-        for (const { loaded, total } of fileProgress.values()) {
-          totalLoaded += loaded;
-          totalSize += total;
-        }
-        const pct = totalSize > 0 ? totalLoaded / totalSize : 0;
-        this.setStatus({
-          state: 'downloading',
-          progress: pct,
-          text:
-            totalSize > 0
-              ? `Downloading model (${(pct * 100).toFixed(0)}%, ${formatBytes(totalLoaded)} / ${formatBytes(totalSize)}, ${fileProgress.size} files)`
-              : `Downloading model (${fileProgress.size} files queued)`,
-        });
-      };
-
-      const progressCallback = (report: {
-        status: string;
-        file?: string;
-        name?: string;
-        progress?: number;
-        loaded?: number;
-        total?: number;
-      }) => {
-        const file = report.file ?? report.name ?? '_unknown';
-        if (report.status === 'initiate') {
-          if (!fileProgress.has(file)) fileProgress.set(file, { loaded: 0, total: 0 });
-          emitAggregate();
-        } else if (report.status === 'download' || report.status === 'progress') {
-          fileProgress.set(file, {
-            loaded: report.loaded ?? 0,
-            total: report.total ?? fileProgress.get(file)?.total ?? 0,
-          });
-          emitAggregate();
-        } else if (report.status === 'done') {
-          // Pin the file to 100% so a final emit shows it complete
-          const existing = fileProgress.get(file);
-          if (existing && existing.total > 0) {
-            fileProgress.set(file, { loaded: existing.total, total: existing.total });
-          }
-          emitAggregate();
-        }
-        // 'ready' is handled below after both processor + model finish
-      };
-
-      // AutoProcessor wraps tokenizer + image/audio preprocessors. For
-      // our text-only chat path we use the wrapped tokenizer's
-      // apply_chat_template, but loading the full processor is the
-      // path the model card documents and avoids architecture-specific
-      // special-casing.
-      const { AutoProcessor, Gemma4ForConditionalGeneration } = this.transformers as unknown as {
-        AutoProcessor: { from_pretrained(id: string, opts?: unknown): Promise<AnyProcessor> };
-        Gemma4ForConditionalGeneration: {
-          from_pretrained(id: string, opts?: unknown): Promise<AnyModel>;
-        };
-      };
-
-      this.processor = await AutoProcessor.from_pretrained(config.modelId, {
-        progress_callback: progressCallback,
-      });
-      this.model = await Gemma4ForConditionalGeneration.from_pretrained(config.modelId, {
-        dtype: config.dtype,
-        device: 'webgpu',
-        progress_callback: progressCallback,
-      });
-
-      this.setStatus({ state: 'ready' });
-    } catch (err) {
-      const message = err instanceof Error ? err.message : String(err);
+    this.worker.addEventListener('message', this.handleWorkerMessage);
+    this.worker.addEventListener('error', (e) => {
+      // Worker-level fatal error — fail everything in flight, then drop
+      // the worker so the next call starts a fresh one instead of
+      // posting into a dead worker and waiting forever.
+      const message = e.message || 'Worker crashed';
+      this.rejectAllPending(`Worker error: ${message}`);
       this.setStatus({ state: 'error', error: message });
-      this.loadPromise = null;
-      throw err;
-    }
+      this.worker?.terminate();
+      this.worker = null;
+    });
+
+    return this.worker;
   }
 
-  /**
-   * Unload the model and free GPU memory.
-   */
-  async unload(): Promise<void> {
-    // transformers.js doesn't expose an explicit dispose() yet — dropping
-    // the references and letting the runtime/GC clean up is the
-    // recommended path. The WebGPU buffers are tied to the model object
-    // and get released when it's no longer reachable.
-    this.model = null;
-    this.processor = null;
-    this.currentModel = null;
-    this.loadPromise = null;
-    this.setStatus({ state: 'idle' });
-  }
+  /** Reject every in-flight request with `reason`, then forget them all. */
+  private rejectAllPending(reason: string): void {
+    for (const p of this.pending.values()) {
+      p.reject(new Error(reason));
+    }
+    this.pending.clear();
+  }
+
+  /** Route a worker message to the status listeners or the matching request. */
+  private handleWorkerMessage = (event: MessageEvent<WorkerResponse>) => {
+    const msg = event.data;
 
-  /**
-   * Generate a response. Auto-loads the model if not yet loaded.
-   *
-   * Implementation notes for the transformers.js v4 backend:
-   *
-   * - We always attach a TextStreamer (regardless of whether the caller
-   *   passed an `onToken`), because the streamer is the *only* documented
-   *   stable way to read generated text out of model.generate(). The
-   *   tensor return value of generate() varies between transformers.js
-   *   versions and is sometimes null when a streamer is in play, which
-   *   used to crash this method with "Cannot read properties of null
-   *   (reading 'dims')" the moment a chat message was sent.
-   *
-   * - Token counts are computed from the tensor return value when
-   *   available, and fall back to a chars/4 estimate when it isn't —
-   *   so /llm-test still shows roughly meaningful prompt/completion
-   *   counts even on versions where generate() returns nothing.
-   */
-  async generate(options: GenerateOptions): Promise<GenerateResult> {
-    if (!this.model || !this.processor) {
-      await this.load();
+    // Status broadcasts have no id and target every listener.
+    if (msg.type === 'status') {
+      this.setStatus(msg.status);
+      return;
     }
 
-    const start = performance.now();
+    // Streaming token: route to the matching request's onToken
+    // callback if one was registered.
+    if (msg.type === 'token') {
+      const pending = this.pending.get(msg.id);
+      pending?.onToken?.(msg.token);
+      return;
+    }
 
-    // Two-step input prep, matching the Gemma 4 model-card example:
-    //   1. Apply the chat template with tokenize:false to get the
-    //      formatted prompt as a plain string (no tokens, no tensor).
-    //   2. Run the string through the processor's tokenizer with
-    //      return_tensors:'pt' to get a proper { input_ids, attention_mask }
-    //      pair backed by transformers.js Tensor objects.
-    //
-    // We previously asked apply_chat_template to do everything in one
-    // shot via `return_dict: true`, but for Gemma4ForConditionalGeneration
-    // that path returned a malformed shape (no .dims on input_ids), and
-    // model.generate() then crashed deep inside the forward pass with
-    // "Cannot read properties of null (reading 'dims')" — surfacing as
-    // an opaque chat error. The two-step path is what every transformers.js
-    // example for multimodal-capable processors uses.
-    const promptText: string = this.processor.apply_chat_template(options.messages, {
-      add_generation_prompt: true,
-      tokenize: false,
-    });
+    // Result/error: resolve or reject the matching pending Promise.
+    const pending = this.pending.get(msg.id);
+    if (!pending) return;
+    this.pending.delete(msg.id);
 
-    const inputs = this.processor.tokenizer(promptText, {
-      return_tensors: 'pt',
-    });
+    if (msg.type === 'result') {
+      pending.resolve(msg.data);
+    } else {
+      pending.reject(new Error(msg.message));
+    }
+  };
 
-    const promptTokenCount = this.tensorLength(inputs?.input_ids);
+  /**
+   * Send a request to the worker and resolve with its `result` payload
+   * (or reject with its `error` message). `onToken` receives streamed tokens.
+   */
+  private postRequest<T>(req: WorkerRequestPayload, onToken?: (token: string) => void): Promise<T> {
+    const id = `${++this.nextId}`;
+    const worker = this.getWorker();
 
-    // Always attach a streamer — it's our reliable text channel.
-    let collectedText = '';
-    const transformers = this.transformers as TransformersModule;
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    const TextStreamer = (transformers as any).TextStreamer;
-    const streamer = new TextStreamer(this.processor.tokenizer, {
-      skip_prompt: true,
-      skip_special_tokens: true,
-      callback_function: (text: string) => {
-        collectedText += text;
-        options.onToken?.(text);
-      },
-    });
-
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    let generated: any = null;
-    try {
-      generated = await this.model.generate({
-        ...inputs,
-        max_new_tokens: options.maxTokens ?? 1024,
-        temperature: options.temperature ?? 0.7,
-        do_sample: (options.temperature ?? 0.7) > 0,
-        streamer,
+    return new Promise<T>((resolve, reject) => {
+      this.pending.set(id, {
+        resolve: (data) => resolve(data as T),
+        reject,
+        onToken,
       });
-    } catch (err) {
-      // Some transformers.js versions throw at the end of streaming
-      // even though the streamer successfully delivered all tokens.
-      // Only re-throw if we genuinely have nothing to return.
-      if (!collectedText) throw err;
-    }
-
-    // Token counts: prefer the tensor return value, fall back to a
-    // rough estimate from the collected text length so the UI still
-    // shows non-zero numbers even on versions where generate() returns
-    // null when a streamer is attached.
-    let completionTokenCount = 0;
-    try {
-      if (generated && generated.dims) {
-        const fullSequence = this.tensorRow(generated, 0);
-        completionTokenCount = Math.max(0, fullSequence.length - promptTokenCount);
-      }
-    } catch {
-      // fall through to estimate
-    }
-    if (completionTokenCount === 0 && collectedText) {
-      // Gemma's BPE averages ~4 chars per token in English/German,
-      // good enough for a UI hint, not for billing.
-      completionTokenCount = Math.ceil(collectedText.length / 4);
-    }
-
-    return {
-      content: collectedText,
-      usage: {
-        prompt_tokens: promptTokenCount,
-        completion_tokens: completionTokenCount,
-        total_tokens: promptTokenCount + completionTokenCount,
-      },
-      latencyMs: Math.round(performance.now() - start),
-    };
+      try {
+        worker.postMessage({ ...req, id } as WorkerRequest);
+      } catch (err) {
+        // The message never left; don't keep an entry that can never settle.
+        this.pending.delete(id);
+        reject(err instanceof Error ? err : new Error(String(err)));
+      }
+    });
   }
 
-  /**
-   * Helper: extract the seq-length of a transformers.js Tensor.
-   * The tensors expose `.dims` ([batch, seq_len]) and `.data` (TypedArray).
-   */
-  // eslint-disable-next-line @typescript-eslint/no-explicit-any
-  private tensorLength(tensor: any): number {
-    if (!tensor || !tensor.dims) return 0;
-    return tensor.dims[tensor.dims.length - 1];
+  // ─── Public API ──────────────────────────────────────────
+
+  /** Load `model` in the worker; a no-op if it is already loaded and ready. */
+  async load(model: ModelKey = DEFAULT_MODEL): Promise<void> {
+    if (this.currentModel === model && this.isReady) return;
+    this.currentModel = model;
+    await this.postRequest<void>({ type: 'load', modelKey: model });
   }
 
-  /**
-   * Helper: extract row N of a 2D tensor as a number array.
-   */
-  // eslint-disable-next-line @typescript-eslint/no-explicit-any
-  private tensorRow(tensor: any, row: number): number[] {
-    const seqLen = tensor.dims[tensor.dims.length - 1];
-    const start = row * seqLen;
-    return Array.from(tensor.data.slice(start, start + seqLen)) as number[];
+  /** Unload the model and terminate the worker; a no-op if none was started. */
+  async unload(): Promise<void> {
+    const worker = this.worker;
+    if (!worker) return; // never loaded, nothing to do
+    try {
+      await this.postRequest<void>({ type: 'unload' });
+    } finally {
+      this.currentModel = null;
+      // Tear down the worker so a future load() starts a fresh one
+      // with cleared GPU buffers. transformers.js doesn't expose an
+      // explicit dispose, so terminating the worker is the only way
+      // to reliably reclaim VRAM.
+      worker.terminate();
+      if (this.worker === worker) this.worker = null;
+      // Requests still in flight can never be answered now; reject
+      // them instead of silently dropping their promises.
+      this.rejectAllPending('Worker terminated by unload()');
+    }
   }
 
-  /**
-   * Convenience: single prompt → response.
-   */
+  /** Run generation in the worker, relaying streamed tokens to `options.onToken`. */
+  async generate(options: GenerateOptions): Promise<GenerateResult> {
+    const { onToken, ...rest } = options;
+    const opts: SerializableGenerateOptions = rest;
+    return this.postRequest<GenerateResult>({ type: 'generate', opts }, onToken);
+  }
+
+  // ─── Convenience wrappers (main thread, build on top of generate) ──
+
   async prompt(
     text: string,
     opts?: { systemPrompt?: string; temperature?: number; maxTokens?: number }
@@ -370,9 +240,6 @@ export class LocalLLMEngine {
     return result.content;
   }
 
-  /**
-   * Convenience: extract structured JSON from text.
-   */
   async extractJson<T>(
     text: string,
     instruction: string,
@@ -397,9 +264,6 @@ export class LocalLLMEngine {
     return JSON.parse(result.content) as T;
   }
 
-  /**
-   * Convenience: classify text into categories.
-   */
   async classify(text: string, categories: string[], opts?: { context?: string }): Promise<string> {
     const categoryList = categories.map((c) => `"${c}"`).join(', ');
     const result = await this.generate({
@@ -415,7 +279,6 @@ export class LocalLLMEngine {
     });
 
     const normalized = result.content.trim().replace(/^["']|["']$/g, '');
-    // Return the closest matching category
     const match = categories.find((c) => c.toLowerCase() === normalized.toLowerCase());
     return match ?? normalized;
   }
diff --git a/packages/local-llm/src/worker.ts b/packages/local-llm/src/worker.ts
new file mode 100644
index 000000000..ef3aa0db3
--- /dev/null
+++ b/packages/local-llm/src/worker.ts
@@ -0,0 +1,110 @@
+/**
+ * Web Worker entry point for @mana/local-llm.
+ *
+ * Runs in a Dedicated Worker context, owns a single LocalLLMEngineImpl
+ * instance, and exchanges messages with the main thread proxy
+ * (engine.ts) over postMessage. The protocol is small and typed:
+ *
+ * Main → Worker (WorkerRequest):
+ *   { id, type: 'load', modelKey: ModelKey }
+ *   { id, type: 'unload' }
+ *   { id, type: 'generate', opts: SerializableGenerateOptions }
+ *   { id, type: 'isReady' } — synchronous probe; resolves with bool
+ *
+ * Worker → Main (WorkerResponse):
+ *   { id, type: 'result', data?: unknown } — request fulfilled
+ *   { id, type: 'error', message: string } — request rejected
+ *   { id, type: 'token', token: string } — streaming token chunk
+ *   { type: 'status', status: LoadingStatus } — broadcast, no id
+ *
+ * Each request has a unique id chosen by the proxy. The worker echoes
+ * the id on its result/error/token responses so the proxy can route
+ * them back to the right pending Promise + onToken callback. Status
+ * messages are broadcast (no id) and trigger every registered status
+ * listener on the proxy.
+ *
+ * Note: this file does NOT import @mana/local-llm's index — it imports
+ * engine-impl directly. The package's public surface is the proxy in
+ * engine.ts; this file is the worker side of that proxy and lives in
+ * its own bundle chunk.
+ */
+
+import { LocalLLMEngineImpl } from './engine-impl';
+import type { GenerateOptions, LoadingStatus } from './types';
+import type { ModelKey } from './models';
+
+// ─── Protocol types (mirrored in engine.ts) ────────────────────
+
+export type SerializableGenerateOptions = Omit<GenerateOptions, 'onToken'>;
+
+export type WorkerRequest =
+  | { id: string; type: 'load'; modelKey: ModelKey }
+  | { id: string; type: 'unload' }
+  | { id: string; type: 'generate'; opts: SerializableGenerateOptions }
+  | { id: string; type: 'isReady' };
+
+export type WorkerResponse =
+  | { id: string; type: 'result'; data?: unknown }
+  | { id: string; type: 'error'; message: string }
+  | { id: string; type: 'token'; token: string }
+  | { type: 'status'; status: LoadingStatus };
+
+// ─── Worker setup ──────────────────────────────────────────────
+
+const engine = new LocalLLMEngineImpl();
+
+// Forward all status changes to the main thread as broadcast messages.
+engine.onStatusChange((status) => {
+  postMessage({ type: 'status', status } satisfies WorkerResponse);
+});
+
+self.addEventListener('message', async (event: MessageEvent<WorkerRequest>) => {
+  const req = event.data;
+
+  try {
+    switch (req.type) {
+      case 'load': {
+        await engine.load(req.modelKey);
+        postMessage({ id: req.id, type: 'result' } satisfies WorkerResponse);
+        break;
+      }
+
+      case 'unload': {
+        await engine.unload();
+        postMessage({ id: req.id, type: 'result' } satisfies WorkerResponse);
+        break;
+      }
+
+      case 'isReady': {
+        postMessage({ id: req.id, type: 'result', data: engine.isReady } satisfies WorkerResponse);
+        break;
+      }
+
+      case 'generate': {
+        // Re-attach the streaming callback on the worker side. Each
+        // emitted token gets posted back to the main thread tagged
+        // with the originating request id, so the proxy can route
+        // it to the right onToken callback.
+        const result = await engine.generate({
+          ...req.opts,
+          onToken: (token) => {
+            postMessage({ id: req.id, type: 'token', token } satisfies WorkerResponse);
+          },
+        });
+        postMessage({ id: req.id, type: 'result', data: result } satisfies WorkerResponse);
+        break;
+      }
+
+      default: {
+        // Exhaustiveness check: a new WorkerRequest variant fails to
+        // compile here, and an unknown type at runtime is answered with
+        // an error (via the catch below) instead of no reply at all.
+        const unhandled: never = req;
+        throw new Error(`Unknown request type: ${JSON.stringify(unhandled)}`);
+      }
+    }
+  } catch (err) {
+    const message = err instanceof Error ? err.message : String(err);
+    postMessage({ id: req.id, type: 'error', message } satisfies WorkerResponse);
+  }
+});