mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 21:01:08 +02:00
feat(local-llm): Phase 3 — move inference into a Web Worker
The browser tier of @mana/local-llm was running entirely on the main
JS thread. With Gemma 4 E2B that meant ~50-200 ms of synchronous
tensor work per forward pass, times ~150 forward passes per
generation, so the UI froze for 10-30 seconds during a single chat
reply. Scrolling, clicks, and animations all stopped.
Move the actual inference into a Dedicated Web Worker. The main
thread keeps a thin LocalLLMEngine proxy with the same public API
(load / unload / generate / prompt / extractJson / classify /
onStatusChange / isSupported), so existing callers — the /llm-test
page, the playground module, @mana/shared-llm's BrowserBackend, the
Svelte 5 reactive bindings — need NO changes.
File layout after the split:

  src/engine.ts       — main-thread proxy (lazy worker init,
                        postMessage protocol, pending request map,
                        status broadcast handling, convenience
                        wrappers for prompt/extractJson/classify)
  src/worker.ts       — Web Worker entry point (typed message
                        protocol, single LocalLLMEngineImpl instance,
                        forwards status changes back to main thread)
  src/engine-impl.ts  — the actual transformers.js engine (renamed
                        from the previous engine.ts contents). NOT
                        exported from index.ts — only the worker
                        imports it. Same two-step tokenization,
                        aggregated progress reporting, streaming
                        token handling as before; just running in
                        a different thread now.
Worker construction uses Vite's documented `new Worker(new URL(
'./worker.ts', import.meta.url), { type: 'module' })` pattern, which
makes Vite split worker.ts (and its transformers.js dep) into its
own bundle chunk at build time. The proxy is lazy-init: the Worker
constructor is never touched at module-import time, so SSR stays
clean (Worker doesn't exist on Node).
Message protocol (typed end-to-end):

  Main → Worker:
    { id, type: 'load', modelKey: ModelKey }
    { id, type: 'unload' }
    { id, type: 'generate', opts: SerializableGenerateOptions }
    { id, type: 'isReady' }

  Worker → Main:
    { id, type: 'result', data?: unknown }
    { id, type: 'error', message: string }
    { id, type: 'token', token: string }        — streaming chunk
    { type: 'status', status: LoadingStatus }   — broadcast
The proxy assigns a unique id per request, stores the resolve/reject
+ optional onToken callback in a Map<id, PendingRequest>, and routes
incoming responses by id. Status messages have no id and fire every
registered status listener — same UX as before, just one extra hop.
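The id-based routing can be sketched end-to-end with a synchronous stub standing in for the real Worker. Everything here (EngineProxy, the trimmed-down message shapes, the stub reply function) is illustrative, not the actual engine.ts/worker.ts types:

```typescript
// Trimmed-down stand-ins for the real WorkerRequest/WorkerResponse unions.
type Request = { id: string; type: 'generate'; prompt: string };
type Response =
	| { id: string; type: 'result'; data?: unknown }
	| { id: string; type: 'error'; message: string }
	| { id: string; type: 'token'; token: string };

interface Pending {
	resolve: (data: unknown) => void;
	reject: (err: Error) => void;
	onToken?: (token: string) => void;
}

class EngineProxy {
	private pending = new Map<string, Pending>();
	private nextId = 0;

	// `post` plays the role of worker.postMessage; `reply` delivers responses.
	constructor(private post: (req: Request, reply: (res: Response) => void) => void) {}

	// Route an incoming worker message to the matching pending request.
	private handle(msg: Response): void {
		if (msg.type === 'token') {
			// Streaming chunk: invoke the request's onToken, keep it pending.
			this.pending.get(msg.id)?.onToken?.(msg.token);
			return;
		}
		const p = this.pending.get(msg.id);
		if (!p) return;
		this.pending.delete(msg.id);
		if (msg.type === 'result') p.resolve(msg.data);
		else p.reject(new Error(msg.message));
	}

	generate(prompt: string, onToken?: (t: string) => void): Promise<unknown> {
		const id = `${++this.nextId}`;
		return new Promise((resolve, reject) => {
			this.pending.set(id, { resolve, reject, onToken });
			this.post({ id, type: 'generate', prompt }, (res) => this.handle(res));
		});
	}
}

// Stub "worker": streams two tokens, then posts the final result.
const proxy = new EngineProxy((req, reply) => {
	reply({ id: req.id, type: 'token', token: 'Hel' });
	reply({ id: req.id, type: 'token', token: 'lo' });
	reply({ id: req.id, type: 'result', data: 'Hello' });
});

const tokens: string[] = [];
proxy.generate('hi', (t) => tokens.push(t)).then((data) => {
	console.log(tokens.join(''), data); // logs "Hello Hello"
});
```

The real proxy differs only in plumbing: `post` is `worker.postMessage`, `handle` is the `message` event listener, and status broadcasts (which carry no id) bypass the pending map entirely.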
Streaming: the worker re-attaches the streaming callback on its
side. Each emitted token gets posted back as `{ id, type: 'token',
token }` and the proxy invokes the original `onToken` callback. The
final `result` arrives as a normal response and resolves the
Promise. From the caller's perspective generate() still feels
identical — same async iterable feel via onToken, same return value.
Worker termination on unload: transformers.js doesn't expose a
dispose API, so we terminate the worker after unload and create a
fresh one on the next load. This is the only reliable way to
release VRAM between model swaps.
CSP: no header changes needed. The worker is loaded from a
same-origin URL (Vite emits it as
/_app/immutable/workers/worker.[hash].js), so 'self' in script-src
already covers it. The blob: + cdn.jsdelivr.net + wasm-unsafe-eval
allowlists we added during the original WebLLM/transformers.js
bring-up still apply because the worker still runs the same ONNX
runtime that needed them.
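For reference, the allowlist this paragraph describes looks roughly like the following (an illustrative sketch, not the literal header from the app's server config):

```http
Content-Security-Policy: script-src 'self' blob: https://cdn.jsdelivr.net 'wasm-unsafe-eval'
```

Per the CSP3 spec, a missing worker-src directive falls back to child-src and then script-src, which is why the same-origin worker URL is already covered by 'self' with no new directive.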
DistributiveOmit type helper: TS's plain `Omit<Union, K>` collapses
discriminated unions to an intersection in some configurations,
which broke the type narrowing at the postRequest call sites for
each request variant. Adding a tiny `DistributiveOmit<T, K>` helper
fixes the type-check without restructuring the protocol.
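A standalone illustration of the failure mode and the fix (the `Req` union and `describe` helper are hypothetical, standing in for the real WorkerRequest union and postRequest call sites):

```typescript
// A two-member discriminated union, like the worker request protocol.
type Req =
	| { id: string; type: 'load'; modelKey: string }
	| { id: string; type: 'generate'; prompt: string };

// Plain `Omit<Req, 'id'>` only keeps keys common to all members
// (just `type`), so `modelKey` and `prompt` vanish and narrowing breaks.
// Distributing the Omit over each member preserves the discriminant:
type DistributiveOmit<T, K extends keyof T> = T extends unknown ? Omit<T, K> : never;

type Payload = DistributiveOmit<Req, 'id'>;
// Payload = { type: 'load'; modelKey: string } | { type: 'generate'; prompt: string }

function describe(req: Payload): string {
	if (req.type === 'load') return `load ${req.modelKey}`; // narrowed: modelKey visible
	return `generate ${req.prompt.length} chars`; // narrowed: prompt visible
}

console.log(describe({ type: 'load', modelKey: 'gemma' })); // "load gemma"
console.log(describe({ type: 'generate', prompt: 'hello' })); // "generate 5 chars"
```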
What this commit deliberately does NOT do:
- Change the public API surface. The whole point is that callers
remain untouched.
- Add multi-tab worker coordination via SharedWorker or
BroadcastChannel. Each tab still spawns its own dedicated worker
with its own copy of the model in VRAM. Multi-tab dedup is
Phase 2.5/Phase 4 work — see the design doc summary in the
previous Phase 1 commit message.
- Add a persistent task queue. Fire-and-forget background tasks
are Phase 4.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 919fcca4b7
commit 45f368f471
3 changed files with 509 additions and 296 deletions
270	packages/local-llm/src/engine-impl.ts	(new file)
@@ -0,0 +1,270 @@

/**
 * LocalLLMEngineImpl — the actual transformers.js engine.
 *
 * This file is intentionally NOT exported from the package's index.ts.
 * It's loaded only inside the Web Worker (worker.ts), where it owns
 * the model + processor + WebGPU device + tensor work. The main thread
 * never instantiates it directly — instead it talks to a thin
 * `LocalLLMEngine` proxy in engine.ts that postMessages over to the
 * worker.
 *
 * Why the split: model.generate() with a 2B-parameter LLM does heavy
 * synchronous tensor work that blocks the JS thread for 50-200 ms per
 * forward pass. With ~150 forward passes per generation, the main
 * thread would freeze for ~10-30 seconds during a single chat reply.
 * Web Workers run on their own thread, so the main UI stays responsive
 * for scrolling, clicks, and animations while inference is happening.
 *
 * The implementation is otherwise identical to the previous in-thread
 * engine — same two-step tokenization, same aggregated progress, same
 * streaming-or-fallback token counting, same convention quirks for
 * transformers.js v4. See the comment headers in each method for the
 * detailed reasoning behind each non-obvious decision.
 */

import type { GenerateOptions, GenerateResult, LoadingStatus, ModelConfig } from './types';
import { MODELS, DEFAULT_MODEL, type ModelKey } from './models';

type TransformersModule = typeof import('@huggingface/transformers');

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AnyModel = any;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AnyProcessor = any;

export class LocalLLMEngineImpl {
	private model: AnyModel = null;
	private processor: AnyProcessor = null;
	private transformers: TransformersModule | null = null;
	private loadPromise: Promise<void> | null = null;
	private currentModel: ModelKey | null = null;
	private _status: LoadingStatus = { state: 'idle' };
	private statusListeners: Set<(status: LoadingStatus) => void> = new Set();

	get status(): LoadingStatus {
		return this._status;
	}

	get isReady(): boolean {
		return this._status.state === 'ready';
	}

	get modelConfig(): ModelConfig | null {
		return this.currentModel ? MODELS[this.currentModel] : null;
	}

	onStatusChange(listener: (status: LoadingStatus) => void): () => void {
		this.statusListeners.add(listener);
		return () => this.statusListeners.delete(listener);
	}

	private setStatus(status: LoadingStatus) {
		this._status = status;
		for (const listener of this.statusListeners) {
			listener(status);
		}
	}

	static isSupported(): boolean {
		return typeof navigator !== 'undefined' && 'gpu' in navigator;
	}

	async load(model: ModelKey = DEFAULT_MODEL): Promise<void> {
		if (this.model && this.currentModel === model) return;
		if (this.loadPromise && this.currentModel === model) return this.loadPromise;
		if (this.model && this.currentModel !== model) {
			await this.unload();
		}
		this.currentModel = model;
		this.loadPromise = this._load(model);
		return this.loadPromise;
	}

	private async _load(model: ModelKey): Promise<void> {
		if (!LocalLLMEngineImpl.isSupported()) {
			this.setStatus({ state: 'error', error: 'WebGPU not supported in this browser' });
			throw new Error('WebGPU not supported');
		}

		this.setStatus({ state: 'checking' });

		try {
			if (!this.transformers) {
				this.transformers = await import('@huggingface/transformers');
			}
			const config = MODELS[model];

			// Aggregated per-file progress reporting — see the long comment
			// in the previous engine.ts for the rationale.
			const fileProgress = new Map<string, { loaded: number; total: number }>();

			const formatBytes = (bytes: number): string => {
				if (bytes < 1024) return `${bytes} B`;
				if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)} KB`;
				if (bytes < 1024 * 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(0)} MB`;
				return `${(bytes / (1024 * 1024 * 1024)).toFixed(2)} GB`;
			};

			const emitAggregate = () => {
				let totalLoaded = 0;
				let totalSize = 0;
				for (const { loaded, total } of fileProgress.values()) {
					totalLoaded += loaded;
					totalSize += total;
				}
				const pct = totalSize > 0 ? totalLoaded / totalSize : 0;
				this.setStatus({
					state: 'downloading',
					progress: pct,
					text:
						totalSize > 0
							? `Downloading model (${(pct * 100).toFixed(0)}%, ${formatBytes(totalLoaded)} / ${formatBytes(totalSize)}, ${fileProgress.size} files)`
							: `Downloading model (${fileProgress.size} files queued)`,
				});
			};

			const progressCallback = (report: {
				status: string;
				file?: string;
				name?: string;
				progress?: number;
				loaded?: number;
				total?: number;
			}) => {
				const file = report.file ?? report.name ?? '_unknown';
				if (report.status === 'initiate') {
					if (!fileProgress.has(file)) fileProgress.set(file, { loaded: 0, total: 0 });
					emitAggregate();
				} else if (report.status === 'download' || report.status === 'progress') {
					fileProgress.set(file, {
						loaded: report.loaded ?? 0,
						total: report.total ?? fileProgress.get(file)?.total ?? 0,
					});
					emitAggregate();
				} else if (report.status === 'done') {
					const existing = fileProgress.get(file);
					if (existing && existing.total > 0) {
						fileProgress.set(file, { loaded: existing.total, total: existing.total });
					}
					emitAggregate();
				}
			};

			const { AutoProcessor, Gemma4ForConditionalGeneration } = this.transformers as unknown as {
				AutoProcessor: { from_pretrained(id: string, opts?: unknown): Promise<AnyProcessor> };
				Gemma4ForConditionalGeneration: {
					from_pretrained(id: string, opts?: unknown): Promise<AnyModel>;
				};
			};

			this.processor = await AutoProcessor.from_pretrained(config.modelId, {
				progress_callback: progressCallback,
			});
			this.model = await Gemma4ForConditionalGeneration.from_pretrained(config.modelId, {
				dtype: config.dtype,
				device: 'webgpu',
				progress_callback: progressCallback,
			});

			this.setStatus({ state: 'ready' });
		} catch (err) {
			const message = err instanceof Error ? err.message : String(err);
			this.setStatus({ state: 'error', error: message });
			this.loadPromise = null;
			throw err;
		}
	}

	async unload(): Promise<void> {
		this.model = null;
		this.processor = null;
		this.currentModel = null;
		this.loadPromise = null;
		this.setStatus({ state: 'idle' });
	}

	async generate(options: GenerateOptions): Promise<GenerateResult> {
		if (!this.model || !this.processor) {
			await this.load();
		}

		const start = performance.now();

		// Two-step input prep — see previous engine.ts comment for why we
		// can't use apply_chat_template's all-in-one return_dict mode for
		// Gemma4ForConditionalGeneration.
		const promptText: string = this.processor.apply_chat_template(options.messages, {
			add_generation_prompt: true,
			tokenize: false,
		});

		const inputs = this.processor.tokenizer(promptText, {
			return_tensors: 'pt',
		});

		const promptTokenCount = this.tensorLength(inputs?.input_ids);

		let collectedText = '';
		const transformers = this.transformers as TransformersModule;
		// eslint-disable-next-line @typescript-eslint/no-explicit-any
		const TextStreamer = (transformers as any).TextStreamer;
		const streamer = new TextStreamer(this.processor.tokenizer, {
			skip_prompt: true,
			skip_special_tokens: true,
			callback_function: (text: string) => {
				collectedText += text;
				options.onToken?.(text);
			},
		});

		// eslint-disable-next-line @typescript-eslint/no-explicit-any
		let generated: any = null;
		try {
			generated = await this.model.generate({
				...inputs,
				max_new_tokens: options.maxTokens ?? 1024,
				temperature: options.temperature ?? 0.7,
				do_sample: (options.temperature ?? 0.7) > 0,
				streamer,
			});
		} catch (err) {
			if (!collectedText) throw err;
		}

		let completionTokenCount = 0;
		try {
			if (generated && generated.dims) {
				const fullSequence = this.tensorRow(generated, 0);
				completionTokenCount = Math.max(0, fullSequence.length - promptTokenCount);
			}
		} catch {
			// fall through to estimate
		}
		if (completionTokenCount === 0 && collectedText) {
			completionTokenCount = Math.ceil(collectedText.length / 4);
		}

		return {
			content: collectedText,
			usage: {
				prompt_tokens: promptTokenCount,
				completion_tokens: completionTokenCount,
				total_tokens: promptTokenCount + completionTokenCount,
			},
			latencyMs: Math.round(performance.now() - start),
		};
	}

	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	private tensorLength(tensor: any): number {
		if (!tensor || !tensor.dims) return 0;
		return tensor.dims[tensor.dims.length - 1];
	}

	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	private tensorRow(tensor: any, row: number): number[] {
		const seqLen = tensor.dims[tensor.dims.length - 1];
		const start = row * seqLen;
		return Array.from(tensor.data.slice(start, start + seqLen)) as number[];
	}
}
@@ -1,35 +1,57 @@
 /**
- * LocalLLMEngine — transformers.js wrapper for client-side inference.
+ * LocalLLMEngine — main-thread proxy for the worker-hosted engine.
  *
- * Lazy-loads a HuggingFace ONNX model on first use, caches weights in the
- * browser's Cache API, and runs inference on the WebGPU backend.
+ * Public API is intentionally identical to the previous in-thread
+ * implementation so existing callers (the /llm-test page, the
+ * playground module, @mana/shared-llm's BrowserBackend, the Svelte 5
+ * reactive bindings in svelte.svelte.ts) need no changes. Internally
+ * every call now goes through a Web Worker — see worker.ts for the
+ * other side of the protocol.
  *
- * The default model is Google's Gemma 4 E2B (`onnx-community/gemma-4-E2B-it-ONNX`,
- * q4f16). The external API of this class is intentionally identical to the
- * previous WebLLM implementation so callers (Svelte stores, /llm-test page,
- * playground module) need no changes when the underlying engine swaps.
+ * Why a worker: a 2B-parameter LLM does heavy synchronous tensor work
+ * for ~50-200 ms per forward pass. With ~150 forward passes per
+ * generation, the main thread would freeze for ~10-30 seconds during
+ * a single chat reply. Web Workers run on their own thread, so the
+ * main UI stays responsive throughout.
+ *
+ * The proxy is constructed lazily — the Worker is only instantiated
+ * on first method call. This matters for SSR: importing this module
+ * during a server render must NOT touch the Worker constructor (which
+ * doesn't exist in Node), and lazy construction is the cleanest way
+ * to keep import-time side effects to zero.
  */

-import type { ChatMessage, GenerateOptions, GenerateResult, LoadingStatus } from './types';
-import type { ModelConfig } from './types';
 import { MODELS, DEFAULT_MODEL, type ModelKey } from './models';
+import type {
+	ChatMessage,
+	GenerateOptions,
+	GenerateResult,
+	LoadingStatus,
+	ModelConfig,
+} from './types';
+import type { SerializableGenerateOptions, WorkerRequest, WorkerResponse } from './worker';

-// transformers.js types are minimal here on purpose. The library does not
-// publish first-class TS types for every model class, and we never expose
-// these objects past this file — the public surface (LocalLLMEngine methods)
-// is fully typed via our own GenerateResult / LoadingStatus etc.
-type TransformersModule = typeof import('@huggingface/transformers');
+/** Tracking entry for an in-flight worker request. */
+interface PendingRequest {
+	resolve: (data: unknown) => void;
+	reject: (err: Error) => void;
+	onToken?: (token: string) => void;
+}

-// eslint-disable-next-line @typescript-eslint/no-explicit-any
-type AnyModel = any;
-// eslint-disable-next-line @typescript-eslint/no-explicit-any
-type AnyProcessor = any;
+/**
+ * Distributive Omit — preserves the discriminated union when stripping
+ * a key. Plain `Omit<Union, K>` collapses to an intersection in many
+ * TS versions and loses the type narrowing on `req.type`. This helper
+ * distributes the Omit across each member of the union so postRequest
+ * still type-checks at the call sites.
+ */
+type DistributiveOmit<T, K extends keyof T> = T extends unknown ? Omit<T, K> : never;
+type WorkerRequestPayload = DistributiveOmit<WorkerRequest, 'id'>;

 export class LocalLLMEngine {
-	private model: AnyModel = null;
-	private processor: AnyProcessor = null;
-	private transformers: TransformersModule | null = null;
-	private loadPromise: Promise<void> | null = null;
+	private worker: Worker | null = null;
+	private pending = new Map<string, PendingRequest>();
+	private nextId = 0;
 	private currentModel: ModelKey | null = null;
 	private _status: LoadingStatus = { state: 'idle' };
 	private statusListeners: Set<(status: LoadingStatus) => void> = new Set();
@@ -46,9 +68,6 @@ export class LocalLLMEngine {
 		return this.currentModel ? MODELS[this.currentModel] : null;
 	}

-	/**
-	 * Subscribe to status changes (for non-Svelte usage).
-	 */
 	onStatusChange(listener: (status: LoadingStatus) => void): () => void {
 		this.statusListeners.add(listener);
 		return () => this.statusListeners.delete(listener);
@@ -61,297 +80,118 @@ export class LocalLLMEngine {
 		}
 	}

-	/**
-	 * Check if WebGPU is available in this browser.
-	 */
+	/** Check if WebGPU is available. Synchronous and SSR-safe. */
 	static isSupported(): boolean {
 		return typeof navigator !== 'undefined' && 'gpu' in navigator;
 	}

-	/**
-	 * Load a model. Idempotent — returns immediately if already loaded.
-	 * Model weights are cached in the browser Cache API for instant reload.
-	 */
-	async load(model: ModelKey = DEFAULT_MODEL): Promise<void> {
-		// Already loaded with this model
-		if (this.model && this.currentModel === model) return;
+	// ─── Worker management ──────────────────────────────────

-		// Already loading
-		if (this.loadPromise && this.currentModel === model) return this.loadPromise;
+	private getWorker(): Worker {
+		if (this.worker) return this.worker;

-		// Unload previous model if switching
-		if (this.model && this.currentModel !== model) {
-			await this.unload();
+		if (typeof Worker === 'undefined') {
+			throw new Error('@mana/local-llm requires a browser environment (Worker is not defined)');
 		}

-		this.currentModel = model;
-		this.loadPromise = this._load(model);
-		return this.loadPromise;
-	}
+		// `new URL('./worker.ts', import.meta.url)` is Vite's documented
+		// pattern for declaring a Web Worker entry. Vite picks this up at
+		// build time, splits worker.ts (and its transformers.js dep) into
+		// its own chunk, and rewrites the URL to the chunk's hashed path.
+		// Outside Vite (raw esbuild, plain Node, etc.) this would fail —
+		// but the only consumer of this package is the SvelteKit web app
+		// where Vite handles the bundling.
+		this.worker = new Worker(new URL('./worker.ts', import.meta.url), {
+			type: 'module',
+			name: 'mana-local-llm',
+		});

-	private async _load(model: ModelKey): Promise<void> {
-		if (!LocalLLMEngine.isSupported()) {
-			this.setStatus({ state: 'error', error: 'WebGPU not supported in this browser' });
-			throw new Error('WebGPU not supported');
-		}
-
-		this.setStatus({ state: 'checking' });
-
-		try {
-			if (!this.transformers) {
-				this.transformers = await import('@huggingface/transformers');
+		this.worker.addEventListener('message', this.handleWorkerMessage);
+		this.worker.addEventListener('error', (e) => {
+			// Worker-level fatal error — reject all pending requests.
+			const message = e.message || 'Worker crashed';
+			for (const [id, p] of this.pending) {
+				p.reject(new Error(`Worker error: ${message}`));
+				this.pending.delete(id);
 			}
-			const config = MODELS[model];
-
-			// transformers.js progress callback shape:
-			//   { status: 'initiate'|'download'|'progress'|'done'|'ready',
-			//     name?: string, file?: string, progress?: number,
-			//     loaded?: number, total?: number }
-			//
-			// The callback fires per-file, and the library downloads many
-			// shards in parallel (config.json, tokenizer.json, several
-			// onnx weight files, …). If we naively report the latest event
-			// the bar bounces wildly between files. Instead we keep a
-			// per-file byte-accounting map and emit an aggregated total
-			// every time anything moves. The denominator can grow as new
-			// files are discovered (causing brief dips), but both
-			// numerator and denominator are individually monotonic, so the
-			// dips are small and brief — much smoother than per-file.
-			const fileProgress = new Map<string, { loaded: number; total: number }>();
-
-			const formatBytes = (bytes: number): string => {
-				if (bytes < 1024) return `${bytes} B`;
-				if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)} KB`;
-				if (bytes < 1024 * 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(0)} MB`;
-				return `${(bytes / (1024 * 1024 * 1024)).toFixed(2)} GB`;
-			};
-
-			const emitAggregate = () => {
-				let totalLoaded = 0;
-				let totalSize = 0;
-				for (const { loaded, total } of fileProgress.values()) {
-					totalLoaded += loaded;
-					totalSize += total;
-				}
-				const pct = totalSize > 0 ? totalLoaded / totalSize : 0;
-				this.setStatus({
-					state: 'downloading',
-					progress: pct,
-					text:
-						totalSize > 0
-							? `Downloading model (${(pct * 100).toFixed(0)}%, ${formatBytes(totalLoaded)} / ${formatBytes(totalSize)}, ${fileProgress.size} files)`
-							: `Downloading model (${fileProgress.size} files queued)`,
-				});
-			};
-
-			const progressCallback = (report: {
-				status: string;
-				file?: string;
-				name?: string;
-				progress?: number;
-				loaded?: number;
-				total?: number;
-			}) => {
-				const file = report.file ?? report.name ?? '_unknown';
-				if (report.status === 'initiate') {
-					if (!fileProgress.has(file)) fileProgress.set(file, { loaded: 0, total: 0 });
-					emitAggregate();
-				} else if (report.status === 'download' || report.status === 'progress') {
-					fileProgress.set(file, {
-						loaded: report.loaded ?? 0,
-						total: report.total ?? fileProgress.get(file)?.total ?? 0,
-					});
-					emitAggregate();
-				} else if (report.status === 'done') {
-					// Pin the file to 100% so a final emit shows it complete
-					const existing = fileProgress.get(file);
-					if (existing && existing.total > 0) {
-						fileProgress.set(file, { loaded: existing.total, total: existing.total });
-					}
-					emitAggregate();
-				}
-				// 'ready' is handled below after both processor + model finish
-			};
-
-			// AutoProcessor wraps tokenizer + image/audio preprocessors. For
-			// our text-only chat path we use the wrapped tokenizer's
-			// apply_chat_template, but loading the full processor is the
-			// path the model card documents and avoids architecture-specific
-			// special-casing.
-			const { AutoProcessor, Gemma4ForConditionalGeneration } = this.transformers as unknown as {
-				AutoProcessor: { from_pretrained(id: string, opts?: unknown): Promise<AnyProcessor> };
-				Gemma4ForConditionalGeneration: {
-					from_pretrained(id: string, opts?: unknown): Promise<AnyModel>;
-				};
-			};
-
-			this.processor = await AutoProcessor.from_pretrained(config.modelId, {
-				progress_callback: progressCallback,
-			});
-			this.model = await Gemma4ForConditionalGeneration.from_pretrained(config.modelId, {
-				dtype: config.dtype,
-				device: 'webgpu',
-				progress_callback: progressCallback,
-			});
-
-			this.setStatus({ state: 'ready' });
-		} catch (err) {
-			const message = err instanceof Error ? err.message : String(err);
-			this.setStatus({ state: 'error', error: message });
-			this.loadPromise = null;
-			throw err;
-		}
-	}
+		});

-	/**
-	 * Unload the model and free GPU memory.
-	 */
-	async unload(): Promise<void> {
-		// transformers.js doesn't expose an explicit dispose() yet — dropping
-		// the references and letting the runtime/GC clean up is the
-		// recommended path. The WebGPU buffers are tied to the model object
-		// and get released when it's no longer reachable.
-		this.model = null;
-		this.processor = null;
-		this.currentModel = null;
-		this.loadPromise = null;
-		this.setStatus({ state: 'idle' });
-	}
+		return this.worker;
+	}

-	/**
-	 * Generate a response. Auto-loads the model if not yet loaded.
-	 *
-	 * Implementation notes for the transformers.js v4 backend:
-	 *
-	 * - We always attach a TextStreamer (regardless of whether the caller
-	 *   passed an `onToken`), because the streamer is the *only* documented
-	 *   stable way to read generated text out of model.generate(). The
-	 *   tensor return value of generate() varies between transformers.js
-	 *   versions and is sometimes null when a streamer is in play, which
-	 *   used to crash this method with "Cannot read properties of null
-	 *   (reading 'dims')" the moment a chat message was sent.
-	 *
-	 * - Token counts are computed from the tensor return value when
-	 *   available, and fall back to a chars/4 estimate when it isn't —
-	 *   so /llm-test still shows roughly meaningful prompt/completion
-	 *   counts even on versions where generate() returns nothing.
-	 */
-	async generate(options: GenerateOptions): Promise<GenerateResult> {
-		if (!this.model || !this.processor) {
-			await this.load();
+	private handleWorkerMessage = (event: MessageEvent<WorkerResponse>) => {
+		const msg = event.data;
+
+		// Status broadcasts have no id and target every listener.
+		if (msg.type === 'status') {
+			this.setStatus(msg.status);
+			return;
 		}

-		const start = performance.now();
+		// Streaming token: route to the matching request's onToken
+		// callback if one was registered.
+		if (msg.type === 'token') {
+			const pending = this.pending.get(msg.id);
+			pending?.onToken?.(msg.token);
+			return;
+		}

-		// Two-step input prep, matching the Gemma 4 model-card example:
-		//  1. Apply the chat template with tokenize:false to get the
-		//     formatted prompt as a plain string (no tokens, no tensor).
-		//  2. Run the string through the processor's tokenizer with
-		//     return_tensors:'pt' to get a proper { input_ids, attention_mask }
-		//     pair backed by transformers.js Tensor objects.
-		//
-		// We previously asked apply_chat_template to do everything in one
-		// shot via `return_dict: true`, but for Gemma4ForConditionalGeneration
-		// that path returned a malformed shape (no .dims on input_ids), and
-		// model.generate() then crashed deep inside the forward pass with
-		// "Cannot read properties of null (reading 'dims')" — surfacing as
-		// an opaque chat error. The two-step path is what every transformers.js
-		// example for multimodal-capable processors uses.
-		const promptText: string = this.processor.apply_chat_template(options.messages, {
-			add_generation_prompt: true,
-			tokenize: false,
-		});
-
-		const inputs = this.processor.tokenizer(promptText, {
-			return_tensors: 'pt',
-		});
+		// Result/error: resolve or reject the matching pending Promise.
+		const pending = this.pending.get(msg.id);
+		if (!pending) return;
+		this.pending.delete(msg.id);

-		const promptTokenCount = this.tensorLength(inputs?.input_ids);
+		if (msg.type === 'result') {
+			pending.resolve(msg.data);
+		} else {
+			pending.reject(new Error(msg.message));
+		}
+	};

-		// Always attach a streamer — it's our reliable text channel.
-		let collectedText = '';
-		const transformers = this.transformers as TransformersModule;
-		// eslint-disable-next-line @typescript-eslint/no-explicit-any
-		const TextStreamer = (transformers as any).TextStreamer;
-		const streamer = new TextStreamer(this.processor.tokenizer, {
-			skip_prompt: true,
-			skip_special_tokens: true,
-			callback_function: (text: string) => {
-				collectedText += text;
-				options.onToken?.(text);
-			},
-		});
+	private postRequest<T>(req: WorkerRequestPayload, onToken?: (token: string) => void): Promise<T> {
+		const id = `${++this.nextId}`;
+		const worker = this.getWorker();

-		// eslint-disable-next-line @typescript-eslint/no-explicit-any
-		let generated: any = null;
-		try {
-			generated = await this.model.generate({
-				...inputs,
-				max_new_tokens: options.maxTokens ?? 1024,
-				temperature: options.temperature ?? 0.7,
-				do_sample: (options.temperature ?? 0.7) > 0,
-				streamer,
-			});
-		} catch (err) {
-			// Some transformers.js versions throw at the end of streaming
-			// even though the streamer successfully delivered all tokens.
-			// Only re-throw if we genuinely have nothing to return.
-			if (!collectedText) throw err;
-		}
-
-		// Token counts: prefer the tensor return value, fall back to a
-		// rough estimate from the collected text length so the UI still
-		// shows non-zero numbers even on versions where generate() returns
-		// null when a streamer is attached.
-		let completionTokenCount = 0;
-		try {
-			if (generated && generated.dims) {
-				const fullSequence = this.tensorRow(generated, 0);
-				completionTokenCount = Math.max(0, fullSequence.length - promptTokenCount);
-			}
-		} catch {
-			// fall through to estimate
-		}
-		if (completionTokenCount === 0 && collectedText) {
-			// Gemma's BPE averages ~4 chars per token in English/German,
-			// good enough for a UI hint, not for billing.
-			completionTokenCount = Math.ceil(collectedText.length / 4);
-		}
-
-		return {
-			content: collectedText,
-			usage: {
-				prompt_tokens: promptTokenCount,
-				completion_tokens: completionTokenCount,
-				total_tokens: promptTokenCount + completionTokenCount,
-			},
-			latencyMs: Math.round(performance.now() - start),
-		};
-	}
+		return new Promise<T>((resolve, reject) => {
+			this.pending.set(id, {
+				resolve: (data) => resolve(data as T),
+				reject,
+				onToken,
+			});

-	/**
-	 * Helper: extract the seq-length of a transformers.js Tensor.
-	 * The tensors expose `.dims` ([batch, seq_len]) and `.data` (TypedArray).
-	 */
-	// eslint-disable-next-line @typescript-eslint/no-explicit-any
-	private tensorLength(tensor: any): number {
-		if (!tensor || !tensor.dims) return 0;
-		return tensor.dims[tensor.dims.length - 1];
+			worker.postMessage({ ...req, id } as WorkerRequest);
+		});
+	}
+
+	// ─── Public API ──────────────────────────────────────────
+
+	async load(model: ModelKey = DEFAULT_MODEL): Promise<void> {
+		if (this.currentModel === model && this.isReady) return;
+		this.currentModel = model;
+		await this.postRequest<void>({ type: 'load', modelKey: model });
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper: extract row N of a 2D tensor as a number array.
|
||||
*/
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
private tensorRow(tensor: any, row: number): number[] {
|
||||
const seqLen = tensor.dims[tensor.dims.length - 1];
|
||||
const start = row * seqLen;
|
||||
return Array.from(tensor.data.slice(start, start + seqLen)) as number[];
|
||||
async unload(): Promise<void> {
|
||||
if (!this.worker) return; // never loaded, nothing to do
|
||||
await this.postRequest<void>({ type: 'unload' });
|
||||
this.currentModel = null;
|
||||
// Tear down the worker so a future load() starts a fresh one
|
||||
// with cleared GPU buffers. transformers.js doesn't expose an
|
||||
// explicit dispose, so terminating the worker is the only way
|
||||
// to reliably reclaim VRAM.
|
||||
this.worker.terminate();
|
||||
this.worker = null;
|
||||
this.pending.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience: single prompt → response.
|
||||
*/
|
||||
async generate(options: GenerateOptions): Promise<GenerateResult> {
|
||||
const { onToken, ...rest } = options;
|
||||
const opts: SerializableGenerateOptions = rest;
|
||||
return this.postRequest<GenerateResult>({ type: 'generate', opts }, onToken);
|
||||
}
|
||||
|
||||
// ─── Convenience wrappers (main thread, build on top of generate) ──
|
||||
|
||||
async prompt(
|
||||
text: string,
|
||||
opts?: { systemPrompt?: string; temperature?: number; maxTokens?: number }
|
||||
|
|
@ -370,9 +210,6 @@ export class LocalLLMEngine {
|
|||
return result.content;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience: extract structured JSON from text.
|
||||
*/
|
||||
async extractJson<T = unknown>(
|
||||
text: string,
|
||||
instruction: string,
|
||||
|
|
@ -397,9 +234,6 @@ export class LocalLLMEngine {
|
|||
return JSON.parse(result.content) as T;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience: classify text into categories.
|
||||
*/
|
||||
async classify(text: string, categories: string[], opts?: { context?: string }): Promise<string> {
|
||||
const categoryList = categories.map((c) => `"${c}"`).join(', ');
|
||||
const result = await this.generate({
|
||||
|
|
@ -415,7 +249,6 @@ export class LocalLLMEngine {
|
|||
});
|
||||
|
||||
const normalized = result.content.trim().replace(/^["']|["']$/g, '');
|
||||
// Return the closest matching category
|
||||
const match = categories.find((c) => c.toLowerCase() === normalized.toLowerCase());
|
||||
return match ?? normalized;
|
||||
}
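The proxy's request/response bookkeeping reduces to an id-keyed pending map: postRequest parks the Promise's resolve/reject handlers under a fresh id, and the onmessage handler settles exactly one Promise per response. A stripped-down, self-contained sketch of that pattern — the class and method names here are illustrative, not the package's actual exports:

```typescript
// Hypothetical, minimal version of the proxy's pending-request map.
// In the real engine.ts, `send` is worker.postMessage({ ...req, id }).
type Pending = { resolve: (data: unknown) => void; reject: (err: Error) => void };

class RequestRouter {
  private nextId = 0;
  private pending = new Map<string, Pending>();

  // Issue a request: park the Promise handlers under a fresh id.
  postRequest<T>(send: (id: string) => void): Promise<T> {
    const id = `${++this.nextId}`;
    return new Promise<T>((resolve, reject) => {
      this.pending.set(id, { resolve: (d) => resolve(d as T), reject });
      send(id);
    });
  }

  // Mirrors the onmessage branch above: route result/error by id,
  // drop responses whose request is no longer pending.
  handleResponse(msg: { id: string; type: 'result' | 'error'; data?: unknown; message?: string }): void {
    const pending = this.pending.get(msg.id);
    if (!pending) return;
    this.pending.delete(msg.id);
    if (msg.type === 'result') pending.resolve(msg.data);
    else pending.reject(new Error(msg.message));
  }
}
```

Because each Promise is removed from the map before it settles, a duplicate or late response for the same id is silently ignored rather than double-resolving.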
110
packages/local-llm/src/worker.ts
Normal file
@ -0,0 +1,110 @@
/**
 * Web Worker entry point for @mana/local-llm.
 *
 * Runs in a Dedicated Worker context, owns a single LocalLLMEngineImpl
 * instance, and exchanges messages with the main thread proxy
 * (engine.ts) over postMessage. The protocol is small and typed:
 *
 * Main → Worker (WorkerRequest):
 *   { id, type: 'load', modelKey: ModelKey }
 *   { id, type: 'unload' }
 *   { id, type: 'generate', opts: SerializableGenerateOptions }
 *   { id, type: 'isReady' } — probe; resolves with a boolean
 *
 * Worker → Main (WorkerResponse):
 *   { id, type: 'result', data?: unknown }    — request fulfilled
 *   { id, type: 'error', message: string }    — request rejected
 *   { id, type: 'token', token: string }      — streaming token chunk
 *   { type: 'status', status: LoadingStatus } — broadcast, no id
 *
 * Each request has a unique id chosen by the proxy. The worker echoes
 * the id on its result/error/token responses so the proxy can route
 * them back to the right pending Promise + onToken callback. Status
 * messages are broadcast (no id) and trigger every registered status
 * listener on the proxy.
 *
 * Note: this file does NOT import @mana/local-llm's index — it imports
 * engine-impl directly. The package's public surface is the proxy in
 * engine.ts; this file is the worker side of that proxy and lives in
 * its own bundle chunk.
 */

import { LocalLLMEngineImpl } from './engine-impl';
import type { GenerateOptions, LoadingStatus } from './types';
import type { ModelKey } from './models';

// ─── Protocol types (mirrored in engine.ts) ────────────────────

export type SerializableGenerateOptions = Omit<GenerateOptions, 'onToken'>;

export type WorkerRequest =
	| { id: string; type: 'load'; modelKey: ModelKey }
	| { id: string; type: 'unload' }
	| { id: string; type: 'generate'; opts: SerializableGenerateOptions }
	| { id: string; type: 'isReady' };

export type WorkerResponse =
	| { id: string; type: 'result'; data?: unknown }
	| { id: string; type: 'error'; message: string }
	| { id: string; type: 'token'; token: string }
	| { type: 'status'; status: LoadingStatus };

// ─── Worker setup ──────────────────────────────────────────────

const engine = new LocalLLMEngineImpl();

// Forward all status changes to the main thread as broadcast messages.
engine.onStatusChange((status) => {
	postMessage({ type: 'status', status } satisfies WorkerResponse);
});

self.addEventListener('message', async (event: MessageEvent<WorkerRequest>) => {
	const req = event.data;

	try {
		switch (req.type) {
			case 'load': {
				await engine.load(req.modelKey);
				postMessage({ id: req.id, type: 'result' } satisfies WorkerResponse);
				break;
			}

			case 'unload': {
				await engine.unload();
				postMessage({ id: req.id, type: 'result' } satisfies WorkerResponse);
				break;
			}

			case 'isReady': {
				postMessage({
					id: req.id,
					type: 'result',
					data: engine.isReady,
				} satisfies WorkerResponse);
				break;
			}

			case 'generate': {
				// Re-attach the streaming callback on the worker side. Each
				// emitted token gets posted back to the main thread tagged
				// with the originating request id, so the proxy can route
				// it to the right onToken callback.
				const result = await engine.generate({
					...req.opts,
					onToken: (token) => {
						postMessage({ id: req.id, type: 'token', token } satisfies WorkerResponse);
					},
				});
				postMessage({
					id: req.id,
					type: 'result',
					data: result,
				} satisfies WorkerResponse);
				break;
			}
		}
	} catch (err) {
		const message = err instanceof Error ? err.message : String(err);
		postMessage({ id: req.id, type: 'error', message } satisfies WorkerResponse);
	}
});