diff --git a/packages/local-llm/src/engine.ts b/packages/local-llm/src/engine.ts index 86d9aa384..a11d90aea 100644 --- a/packages/local-llm/src/engine.ts +++ b/packages/local-llm/src/engine.ts @@ -243,18 +243,27 @@ export class LocalLLMEngine { const start = performance.now(); - // Apply Gemma's chat template via the processor's tokenizer wrapper. - // `add_generation_prompt: true` appends the tokens that tell the - // model "now generate an assistant turn". `return_dict: true` makes - // it return { input_ids, attention_mask } so we can spread it into - // model.generate(). NOTE: do NOT pass `return_tensor: 'pt'` — that - // is the Python `transformers` convention; transformers.js's - // equivalent option is just `return_tensor: true`, which is the - // default anyway. Passing the string broke nothing in older - // versions but made input shape detection unreliable. - const inputs = await this.processor.apply_chat_template(options.messages, { + // Two-step input prep, matching the Gemma 4 model-card example: + // 1. Apply the chat template with tokenize:false to get the + // formatted prompt as a plain string (no tokens, no tensor). + // 2. Run the string through the processor's tokenizer to get a + // proper { input_ids, attention_mask } pair backed by + // transformers.js Tensor objects. + // + // We previously asked apply_chat_template to do everything in one + // shot via `return_dict: true`, but for Gemma4ForConditionalGeneration + // that path returned a malformed shape (no .dims on input_ids), and + // model.generate() then crashed deep inside the forward pass with + // "Cannot read properties of null (reading 'dims')" — surfacing as + // an opaque chat error. The two-step path is what every transformers.js + // example for multimodal-capable processors uses. 
+ const promptText: string = this.processor.apply_chat_template(options.messages, { add_generation_prompt: true, - return_dict: true, + tokenize: false, + }); + + const inputs = this.processor.tokenizer(promptText, { + return_tensor: true, }); const promptTokenCount = this.tensorLength(inputs?.input_ids);