diff --git a/packages/local-llm/src/engine.ts b/packages/local-llm/src/engine.ts
index c2c9021e3..86d9aa384 100644
--- a/packages/local-llm/src/engine.ts
+++ b/packages/local-llm/src/engine.ts
@@ -220,6 +220,21 @@ export class LocalLLMEngine {
 
   /**
    * Generate a response. Auto-loads the model if not yet loaded.
+   *
+   * Implementation notes for the transformers.js v4 backend:
+   *
+   * - We always attach a TextStreamer (regardless of whether the caller
+   *   passed an `onToken`), because the streamer is the *only* documented
+   *   stable way to read generated text out of model.generate(). The
+   *   tensor return value of generate() varies between transformers.js
+   *   versions and is sometimes null when a streamer is in play, which
+   *   used to crash this method with "Cannot read properties of null
+   *   (reading 'dims')" the moment a chat message was sent.
+   *
+   * - Token counts are computed from the tensor return value when
+   *   available, and fall back to a chars/4 estimate when it isn't —
+   *   so /llm-test still shows roughly meaningful prompt/completion
+   *   counts even on versions where generate() returns nothing.
    */
   async generate(options: GenerateOptions): Promise {
     if (!this.model || !this.processor) {
@@ -229,51 +244,73 @@
     const start = performance.now();
 
     // Apply Gemma's chat template via the processor's tokenizer wrapper.
-    // `add_generation_prompt: true` appends the tokens that tell the model
-    // "now generate an assistant turn".
+    // `add_generation_prompt: true` appends the tokens that tell the
+    // model "now generate an assistant turn". `return_dict: true` makes
+    // it return { input_ids, attention_mask } so we can spread it into
+    // model.generate(). NOTE: do NOT pass `return_tensor: 'pt'` — that
+    // is the Python `transformers` convention; transformers.js's
+    // equivalent option is just `return_tensor: true`, which is the
+    // default anyway. Passing the string broke nothing in older
+    // versions but made input shape detection unreliable.
     const inputs = await this.processor.apply_chat_template(options.messages, {
       add_generation_prompt: true,
       return_dict: true,
-      return_tensor: 'pt',
     });
 
-    const promptTokenCount = this.tensorLength(inputs.input_ids);
+    const promptTokenCount = this.tensorLength(inputs?.input_ids);
 
-    // Streaming via TextStreamer if requested
-    let streamer: unknown = undefined;
-    if (options.onToken) {
-      const transformers = this.transformers as TransformersModule;
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      const TextStreamer = (transformers as any).TextStreamer;
-      streamer = new TextStreamer(this.processor.tokenizer, {
-        skip_prompt: true,
-        skip_special_tokens: true,
-        callback_function: (text: string) => {
-          options.onToken!(text);
-        },
+    // Always attach a streamer — it's our reliable text channel.
+    let collectedText = '';
+    const transformers = this.transformers as TransformersModule;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const TextStreamer = (transformers as any).TextStreamer;
+    const streamer = new TextStreamer(this.processor.tokenizer, {
+      skip_prompt: true,
+      skip_special_tokens: true,
+      callback_function: (text: string) => {
+        collectedText += text;
+        options.onToken?.(text);
+      },
+    });
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    let generated: any = null;
+    try {
+      generated = await this.model.generate({
+        ...inputs,
+        max_new_tokens: options.maxTokens ?? 1024,
+        temperature: options.temperature ?? 0.7,
+        do_sample: (options.temperature ?? 0.7) > 0,
+        streamer,
       });
+    } catch (err) {
+      // Some transformers.js versions throw at the end of streaming
+      // even though the streamer successfully delivered all tokens.
+      // Only re-throw if we genuinely have nothing to return.
+      if (!collectedText) throw err;
     }
 
-    const generated = await this.model.generate({
-      ...inputs,
-      max_new_tokens: options.maxTokens ?? 1024,
-      temperature: options.temperature ?? 0.7,
-      do_sample: (options.temperature ?? 0.7) > 0,
-      streamer,
-    });
-
-    // `generated` is a tensor with shape [batch, seq_len_with_prompt].
-    // We slice off the prompt portion to get just the new tokens.
-    const fullSequence = this.tensorRow(generated, 0);
-    const newTokens = fullSequence.slice(promptTokenCount);
-    const completionTokenCount = newTokens.length;
-
-    const content: string = this.processor.tokenizer.decode(newTokens, {
-      skip_special_tokens: true,
-    });
+    // Token counts: prefer the tensor return value, fall back to a
+    // rough estimate from the collected text length so the UI still
+    // shows non-zero numbers even on versions where generate() returns
+    // null when a streamer is attached.
+    let completionTokenCount = 0;
+    try {
+      if (generated && generated.dims) {
+        const fullSequence = this.tensorRow(generated, 0);
+        completionTokenCount = Math.max(0, fullSequence.length - promptTokenCount);
+      }
+    } catch {
+      // fall through to estimate
+    }
+    if (completionTokenCount === 0 && collectedText) {
+      // Gemma's BPE averages ~4 chars per token in English/German,
+      // good enough for a UI hint, not for billing.
+      completionTokenCount = Math.ceil(collectedText.length / 4);
+    }
 
     return {
-      content,
+      content: collectedText,
       usage: {
         prompt_tokens: promptTokenCount,
         completion_tokens: completionTokenCount,