From e757470cb0c5c7a60c1943964213b1c1cef47888 Mon Sep 17 00:00:00 2001 From: Till JS Date: Mon, 20 Apr 2026 15:22:48 +0200 Subject: [PATCH] feat(mana-llm): add OpenAI-style tools + tool_calls passthrough MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the chat-completions surface so callers can ask any provider to call named functions and get structured tool_calls back. Wired through all three provider adapters so the planner and companion can switch off the fragile JSON-parsing pathway. - Request: tools[], tool_choice, assistant tool_calls, tool-role messages with tool_call_id. - Response: MessageResponse.tool_calls, Choice.finish_reason adds "tool_calls", DeltaContent streams tool_calls. - Google provider: Tool(function_declarations=...) build, result normalised (args dict → JSON string), function_response parts on a user turn for tool-role messages. - OpenAI-compat: 1:1 passthrough of the OpenAI spec. - Ollama: /api/chat passthrough; model-level capability check via a TOOL_CAPABLE_OLLAMA_PATTERNS whitelist (llama3.1+, qwen2.5+, mistral, command-r, …) — unsupported models rejected rather than silently falling back to prose. - Router: model_supports_tools() check upfront for both streaming and non-streaming paths; ProviderCapabilityError bubbles as 400. No silent downgrade. Missing tool support = explicit error. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- services/mana-llm/src/models/__init__.py | 17 +- services/mana-llm/src/models/requests.py | 63 ++++- services/mana-llm/src/models/responses.py | 27 +- services/mana-llm/src/providers/base.py | 15 ++ services/mana-llm/src/providers/google.py | 235 +++++++++++++++--- services/mana-llm/src/providers/ollama.py | 158 ++++++++++-- .../mana-llm/src/providers/openai_compat.py | 112 ++++++++- services/mana-llm/src/providers/router.py | 22 ++ 8 files changed, 568 insertions(+), 81 deletions(-) diff --git a/services/mana-llm/src/models/__init__.py b/services/mana-llm/src/models/__init__.py index 1c1785e30..ac99636cd 100644 --- a/services/mana-llm/src/models/__init__.py +++ b/services/mana-llm/src/models/__init__.py @@ -1,6 +1,13 @@ """Pydantic models for OpenAI-compatible API.""" -from .requests import ChatCompletionRequest, EmbeddingRequest +from .requests import ( + ChatCompletionRequest, + EmbeddingRequest, + FunctionSpec, + Message, + ToolChoice, + ToolSpec, +) from .responses import ( ChatCompletionResponse, ChatCompletionStreamResponse, @@ -12,6 +19,8 @@ from .responses import ( ModelInfo, ModelsResponse, StreamChoice, + ToolCall, + ToolCallFunction, Usage, ) @@ -24,9 +33,15 @@ __all__ = [ "EmbeddingData", "EmbeddingRequest", "EmbeddingResponse", + "FunctionSpec", + "Message", "MessageResponse", "ModelInfo", "ModelsResponse", "StreamChoice", + "ToolCall", + "ToolCallFunction", + "ToolChoice", + "ToolSpec", "Usage", ] diff --git a/services/mana-llm/src/models/requests.py b/services/mana-llm/src/models/requests.py index c470696cd..4c444124a 100644 --- a/services/mana-llm/src/models/requests.py +++ b/services/mana-llm/src/models/requests.py @@ -28,11 +28,64 @@ class ImageContent(BaseModel): MessageContent = str | list[TextContent | ImageContent] -class Message(BaseModel): - """A single message in the conversation.""" +class ToolCallFunction(BaseModel): + """The function portion of a tool_call on an assistant message.""" - 
role: Literal["system", "user", "assistant"] - content: MessageContent + name: str + # Arguments are passed as a JSON string (OpenAI spec). Providers may + # emit structured args natively; the adapter serialises them here. + arguments: str + + +class ToolCall(BaseModel): + """A tool invocation the assistant decided to make.""" + + id: str + type: Literal["function"] = "function" + function: ToolCallFunction + + +class Message(BaseModel): + """A single message in the conversation. + + `tool` messages carry the result of a previously-requested tool call + back into the context; they must reference the originating call via + ``tool_call_id``. Assistant messages may contain either plain + ``content`` or ``tool_calls`` (or both, though providers typically + only emit one at a time). + """ + + role: Literal["system", "user", "assistant", "tool"] + content: MessageContent | None = None + tool_call_id: str | None = None + tool_calls: list[ToolCall] | None = None + + +class FunctionSpec(BaseModel): + """OpenAI-style function declaration.""" + + name: str + description: str + # JSON Schema for the function's parameters. Kept as a dict here to + # stay loose; providers translate it to their native shape. + parameters: dict[str, Any] = Field(default_factory=dict) + + +class ToolSpec(BaseModel): + """A tool the model may call. 
Only ``function`` tools are supported.""" + + type: Literal["function"] = "function" + function: FunctionSpec + + +class ToolChoiceFunction(BaseModel): + """Force-pick a specific function for ``tool_choice``.""" + + type: Literal["function"] = "function" + function: dict[str, str] # {"name": "tool_name"} + + +ToolChoice = Literal["auto", "none", "required"] | ToolChoiceFunction class ResponseFormat(BaseModel): @@ -63,6 +116,8 @@ class ChatCompletionRequest(BaseModel): presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0) stop: str | list[str] | None = None response_format: ResponseFormat | None = None + tools: list[ToolSpec] | None = None + tool_choice: ToolChoice | None = None class EmbeddingRequest(BaseModel): diff --git a/services/mana-llm/src/models/responses.py b/services/mana-llm/src/models/responses.py index 44a545d24..37ac4af2e 100644 --- a/services/mana-llm/src/models/responses.py +++ b/services/mana-llm/src/models/responses.py @@ -15,11 +15,27 @@ class Usage(BaseModel): total_tokens: int = 0 +class ToolCallFunction(BaseModel): + """Function portion of a tool_call the assistant produced.""" + + name: str + arguments: str # JSON-encoded (OpenAI spec) + + +class ToolCall(BaseModel): + """A tool invocation the assistant decided to make.""" + + id: str + type: Literal["function"] = "function" + function: ToolCallFunction + + class MessageResponse(BaseModel): """Response message from the model.""" role: Literal["assistant"] = "assistant" - content: str + content: str | None = None + tool_calls: list[ToolCall] | None = None class Choice(BaseModel): @@ -27,7 +43,9 @@ class Choice(BaseModel): index: int = 0 message: MessageResponse - finish_reason: Literal["stop", "length", "content_filter"] | None = "stop" + finish_reason: ( + Literal["stop", "length", "content_filter", "tool_calls"] | None + ) = "stop" class ChatCompletionResponse(BaseModel): @@ -46,6 +64,7 @@ class DeltaContent(BaseModel): role: Literal["assistant"] | None = None content: str | 
None = None + tool_calls: list[ToolCall] | None = None class StreamChoice(BaseModel): @@ -53,7 +72,9 @@ class StreamChoice(BaseModel): index: int = 0 delta: DeltaContent - finish_reason: Literal["stop", "length", "content_filter"] | None = None + finish_reason: ( + Literal["stop", "length", "content_filter", "tool_calls"] | None + ) = None class ChatCompletionStreamResponse(BaseModel): diff --git a/services/mana-llm/src/providers/base.py b/services/mana-llm/src/providers/base.py index 6f163bbac..a3909e5ae 100644 --- a/services/mana-llm/src/providers/base.py +++ b/services/mana-llm/src/providers/base.py @@ -19,6 +19,21 @@ class LLMProvider(ABC): name: str = "base" + # Set to True if the provider supports OpenAI-style `tools` + `tool_calls` + # for chat completions. The router rejects tool-bearing requests routed + # to providers without support rather than silently dropping the tools. + # Provider adapters may further narrow this per-model if needed. + supports_tools: bool = False + + def model_supports_tools(self, model: str) -> bool: + """Check if a specific model within this provider supports tools. + + Default: falls back to the provider-wide flag. Providers with a + mixed capability surface (e.g. Ollama — depends on the local + model) override this. 
+ """ + return self.supports_tools + @abstractmethod async def chat_completion( self, diff --git a/services/mana-llm/src/providers/google.py b/services/mana-llm/src/providers/google.py index 1b204b2ee..bdf768a74 100644 --- a/services/mana-llm/src/providers/google.py +++ b/services/mana-llm/src/providers/google.py @@ -1,6 +1,8 @@ """Google Gemini provider for mana-llm (fallback when Ollama is unavailable).""" +import json import logging +import uuid from collections.abc import AsyncIterator from typing import Any @@ -20,6 +22,9 @@ from src.models import ( MessageResponse, ModelInfo, StreamChoice, + ToolCall, + ToolCallFunction, + ToolSpec, Usage, ) @@ -35,14 +40,66 @@ from .errors import ( logger = logging.getLogger(__name__) -def _unwrap_gemini_response(response: Any, gemini_model: str) -> str: - """Validate a non-streaming Gemini response and return its text. +def _build_gemini_tools(tools: list[ToolSpec] | None) -> list[types.Tool] | None: + """Translate our ToolSpec list into Gemini ``types.Tool`` declarations.""" + if not tools: + return None + declarations: list[types.FunctionDeclaration] = [] + for t in tools: + declarations.append( + types.FunctionDeclaration( + name=t.function.name, + description=t.function.description, + parameters=t.function.parameters or None, + ) + ) + return [types.Tool(function_declarations=declarations)] - Raises a structured ProviderError if the response was blocked, - truncated, or otherwise produced no usable text. The SDK's - ``response.text`` accessor silently returns an empty string in all - of those cases, which downstream consumers (e.g. the planner - parser) cannot distinguish from a well-formed empty completion. + +def _extract_tool_calls(candidate: Any) -> list[ToolCall] | None: + """Pull any ``function_call`` parts off a candidate into ToolCalls. + + Gemini emits tool calls as ``content.parts[i].function_call`` with a + ``FunctionCall(name, args)`` where ``args`` is a dict (not a JSON + string). 
We normalise to OpenAI shape: arguments are JSON-encoded + strings so downstream handlers can treat all providers the same. + """ + if candidate is None: + return None + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + calls: list[ToolCall] = [] + for part in parts: + fc = getattr(part, "function_call", None) + if fc is None: + continue + name = getattr(fc, "name", None) + if not name: + continue + args = getattr(fc, "args", None) or {} + # Gemini's args are already dict-shaped; serialise to JSON string. + try: + arguments_json = json.dumps(dict(args), ensure_ascii=False) + except (TypeError, ValueError): + arguments_json = json.dumps({}, ensure_ascii=False) + calls.append( + ToolCall( + id=f"call_{uuid.uuid4().hex[:12]}", + function=ToolCallFunction(name=name, arguments=arguments_json), + ) + ) + return calls or None + + +def _unwrap_gemini_response( + response: Any, gemini_model: str +) -> tuple[str, list[ToolCall] | None, str]: + """Validate a non-streaming Gemini response. + + Returns ``(text, tool_calls, finish_reason)``. Raises a structured + ``ProviderError`` if the response was blocked, truncated, or + otherwise produced no usable payload. ``response.text`` silently + returns ``""`` on blocked responses — we refuse to propagate that. """ candidates = getattr(response, "candidates", None) or [] candidate = candidates[0] if candidates else None @@ -56,6 +113,7 @@ def _unwrap_gemini_response(response: Any, gemini_model: str) -> str: finish_name = finish_name.rsplit(".", 1)[-1] text = response.text or "" + tool_calls = _extract_tool_calls(candidate) if finish_name in {"SAFETY", "RECITATION", "PROHIBITED_CONTENT", "SPII", "BLOCKLIST"}: # Pull the first safety rating that actually blocked if present. 
@@ -80,13 +138,37 @@ def _unwrap_gemini_response(response: Any, gemini_model: str) -> str: ) raise ProviderTruncatedError(partial_text=text or None) - if not text and finish_name not in (None, "STOP"): - # Unknown finish reason, empty text — surface instead of silent "". + if not text and not tool_calls and finish_name not in (None, "STOP"): + # Unknown finish reason, nothing to return — surface instead of "". raise ProviderError( f"Gemini returned no content (finish_reason={finish_name})" ) - return text + # Normalise the finish_reason to our OpenAI-compatible vocabulary. + openai_finish = "stop" + if tool_calls: + openai_finish = "tool_calls" + elif finish_name == "MAX_TOKENS": + openai_finish = "length" + return text, tool_calls, openai_finish + + +def _lookup_tool_name(messages: list[Any], tool_call_id: str | None) -> str | None: + """Find the tool name for a given ``tool_call_id``. + + OpenAI's tool-message carries only the id; the matching name lives + on the preceding assistant message's ``tool_calls[]``. We scan from + the end (most recent assistant turn) backwards. + """ + if not tool_call_id: + return None + for m in reversed(messages): + if m.role != "assistant" or not m.tool_calls: + continue + for call in m.tool_calls: + if call.id == tool_call_id: + return call.function.name + return None def _wrap_gemini_call_error(err: Exception, gemini_model: str) -> ProviderError: @@ -123,6 +205,8 @@ class GoogleProvider(LLMProvider): """Google Gemini API provider.""" name = "google" + # Gemini 2.x supports OpenAI-style function calling across all listed models. + supports_tools = True def __init__(self, api_key: str, default_model: str = "gemini-2.5-flash"): self.api_key = api_key @@ -138,43 +222,98 @@ class GoogleProvider(LLMProvider): ) -> tuple[str | None, list[types.Content]]: """Convert OpenAI-format messages to Google Gemini format. - Returns (system_instruction, contents). + Returns (system_instruction, contents). 
Handles text, multimodal + image content, assistant messages carrying ``tool_calls``, and + ``tool`` result messages (mapped to Gemini ``function_response`` + parts — Gemini has no ``tool`` role, function responses ride on + a ``user`` turn). """ system_instruction: str | None = None contents: list[types.Content] = [] for msg in request.messages: if msg.role == "system": - # Gemini uses system_instruction separately if isinstance(msg.content, str): system_instruction = msg.content continue - role = "user" if msg.role == "user" else "model" - - if isinstance(msg.content, str): - contents.append(types.Content(role=role, parts=[types.Part.from_text(text=msg.content)])) - else: - # Multimodal content - parts: list[types.Part] = [] - for part in msg.content: - if part.type == "text": - parts.append(types.Part.from_text(text=part.text)) - elif part.type == "image_url" and part.image_url: - url = part.image_url.url - if url.startswith("data:"): - # Parse data URI: data:image/jpeg;base64, - header, b64_data = url.split(",", 1) - mime_type = header.split(":")[1].split(";")[0] - import base64 - - image_bytes = base64.b64decode(b64_data) - parts.append( - types.Part.from_bytes(data=image_bytes, mime_type=mime_type) + # Tool result message → function_response Part on a user turn. + if msg.role == "tool": + # The content is the stringified tool result. We also need + # a tool name — the OpenAI spec carries it on the matching + # assistant tool_call, keyed by tool_call_id. We don't + # track that back-reference here, so we pull the name + # from the preceding assistant message's tool_calls. 
+ name = _lookup_tool_name(request.messages, msg.tool_call_id) + if not name: + continue # orphan tool message — skip silently + payload: Any + if isinstance(msg.content, str): + try: + payload = json.loads(msg.content) + if not isinstance(payload, dict): + payload = {"result": payload} + except (TypeError, ValueError): + payload = {"result": msg.content} + else: + payload = {"result": ""} + contents.append( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=name, response=payload ) - else: - # URL-based image - use as URI - parts.append(types.Part.from_uri(file_uri=url, mime_type="image/jpeg")) + ], + ) + ) + continue + + role = "user" if msg.role == "user" else "model" + parts: list[types.Part] = [] + + if msg.role == "assistant" and msg.tool_calls: + for call in msg.tool_calls: + try: + args_obj = json.loads(call.function.arguments or "{}") + except (TypeError, ValueError): + args_obj = {} + parts.append( + types.Part( + function_call=types.FunctionCall( + name=call.function.name, args=args_obj + ) + ) + ) + + if msg.content is not None: + if isinstance(msg.content, str): + parts.append(types.Part.from_text(text=msg.content)) + else: + for part in msg.content: + if part.type == "text": + parts.append(types.Part.from_text(text=part.text)) + elif part.type == "image_url" and part.image_url: + url = part.image_url.url + if url.startswith("data:"): + header, b64_data = url.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + import base64 + + image_bytes = base64.b64decode(b64_data) + parts.append( + types.Part.from_bytes( + data=image_bytes, mime_type=mime_type + ) + ) + else: + parts.append( + types.Part.from_uri( + file_uri=url, mime_type="image/jpeg" + ) + ) + + if parts: contents.append(types.Content(role=role, parts=parts)) return system_instruction, contents @@ -199,6 +338,10 @@ class GoogleProvider(LLMProvider): stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop] 
config["stop_sequences"] = stop_seqs + gemini_tools = _build_gemini_tools(request.tools) + if gemini_tools: + config["tools"] = gemini_tools + gen_config = types.GenerateContentConfig( system_instruction=system_instruction, **config, @@ -217,7 +360,9 @@ class GoogleProvider(LLMProvider): except Exception as err: raise _wrap_gemini_call_error(err, gemini_model) from err - content = _unwrap_gemini_response(response, gemini_model) + content, tool_calls, finish_reason = _unwrap_gemini_response( + response, gemini_model + ) usage_meta = response.usage_metadata return ChatCompletionResponse( @@ -225,8 +370,11 @@ class GoogleProvider(LLMProvider): choices=[ Choice( index=0, - message=MessageResponse(content=content), - finish_reason="stop", + message=MessageResponse( + content=content or None, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, ) ], usage=Usage( @@ -253,6 +401,10 @@ class GoogleProvider(LLMProvider): if request.top_p is not None: config["top_p"] = request.top_p + gemini_tools = _build_gemini_tools(request.tools) + if gemini_tools: + config["tools"] = gemini_tools + gen_config = types.GenerateContentConfig( system_instruction=system_instruction, **config, @@ -297,7 +449,10 @@ class GoogleProvider(LLMProvider): # Post-stream check: if the stream ended without emitting any text, # surface the structured reason instead of quietly closing with an - # empty "stop". Matches _unwrap_gemini_response semantics. + # empty "stop". Matches _unwrap_gemini_response semantics. We + # discard the return value here — the streaming path produces its + # content chunk-by-chunk above, so we only need this call for its + # side-effect of raising on SAFETY / RECITATION / MAX_TOKENS. 
if not emitted_any_text and last_chunk is not None: _unwrap_gemini_response(last_chunk, gemini_model) diff --git a/services/mana-llm/src/providers/ollama.py b/services/mana-llm/src/providers/ollama.py index 5990d1237..0cc22fe5f 100644 --- a/services/mana-llm/src/providers/ollama.py +++ b/services/mana-llm/src/providers/ollama.py @@ -20,14 +20,76 @@ from src.models import ( MessageResponse, ModelInfo, StreamChoice, + ToolCall, + ToolCallFunction, Usage, ) from .base import LLMProvider +# Ollama emits tool_calls under /api/chat as: +# {"message":{"content":"", "tool_calls":[{"function":{"name":"x","arguments":{...}}}]}} +# Arguments arrive as a dict — we normalise to a JSON string to match +# the OpenAI-spec shape our downstream code expects. + +# Ollama models known to support tool-calling reliably (as of 0.3+). +# Ollama itself accepts `tools` for any model and ignores them on +# incompatible ones (the reply is then plain text), so requests for models +# outside this list are rejected upfront via model_supports_tools(). +TOOL_CAPABLE_OLLAMA_PATTERNS: tuple[str, ...] 
= ( + "llama3.1", + "llama3.2", + "llama3.3", + "qwen2.5", + "qwen3", + "mistral", + "mixtral", + "command-r", + "firefunction", +) + logger = logging.getLogger(__name__) +def _safe_parse_args(raw: str | None) -> dict[str, Any]: + """Best-effort parse of a JSON-encoded arguments string.""" + if not raw: + return {} + try: + parsed = json.loads(raw) + except (TypeError, ValueError): + return {} + return parsed if isinstance(parsed, dict) else {} + + +def _tool_calls_from_ollama(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None: + """Normalise Ollama's tool_calls into our ToolCall model.""" + if not raw: + return None + import uuid + + calls: list[ToolCall] = [] + for idx, c in enumerate(raw): + fn = c.get("function") or {} + name = fn.get("name") + if not name: + continue + args = fn.get("arguments") + if isinstance(args, dict): + arguments_json = json.dumps(args, ensure_ascii=False) + elif isinstance(args, str): + arguments_json = args + else: + arguments_json = "{}" + calls.append( + ToolCall( + id=c.get("id") or f"call_{uuid.uuid4().hex[:12]}", + function=ToolCallFunction(name=name, arguments=arguments_json), + ) + ) + return calls or None + + def _strip_json_fences(content: str) -> str: """Strip ```json ... ``` markdown fences from a string if present. @@ -53,6 +115,17 @@ class OllamaProvider(LLMProvider): """Ollama LLM provider.""" name = "ollama" + supports_tools = True + + def model_supports_tools(self, model: str) -> bool: + """Narrow tool capability to models trained for function calling. + + Ollama accepts the ``tools`` payload even for incompatible models + but silently drops it — the assistant just replies in prose. We + reject those requests upfront so callers get a clear error. 
+ """ + lower = model.lower() + return any(pattern in lower for pattern in TOOL_CAPABLE_OLLAMA_PATTERNS) def __init__(self, base_url: str | None = None, timeout: int | None = None): self.base_url = (base_url or settings.ollama_url).rstrip("/") @@ -70,37 +143,56 @@ class OllamaProvider(LLMProvider): return self._client def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]: - """Convert OpenAI message format to Ollama format.""" - messages = [] + """Convert OpenAI message format to Ollama format. + + Ollama's ``/api/chat`` uses the OpenAI shape for assistant + ``tool_calls`` and ``tool`` result messages (with args as objects + rather than JSON strings on output — input still accepts JSON + strings), so we largely pass them through. + """ + messages: list[dict[str, Any]] = [] for msg in request.messages: - if isinstance(msg.content, str): - messages.append({"role": msg.role, "content": msg.content}) + out: dict[str, Any] = {"role": msg.role} + + if msg.tool_calls: + out["tool_calls"] = [ + { + "function": { + "name": c.function.name, + # Ollama accepts either stringified JSON or + # already-parsed objects; send parsed for + # better compatibility with older models. 
+ "arguments": _safe_parse_args(c.function.arguments), + } + } + for c in msg.tool_calls + ] + + if msg.content is None: + out["content"] = "" + elif isinstance(msg.content, str): + out["content"] = msg.content else: - # Handle multimodal content (vision) - text_parts = [] - images = [] + text_parts: list[str] = [] + images: list[str] = [] for part in msg.content: if part.type == "text": text_parts.append(part.text) elif part.type == "image_url": url = part.image_url.url - # Extract base64 data from data URL if url.startswith("data:"): - # Format: data:image/png;base64, base64_data = url.split(",", 1)[1] if "," in url else url images.append(base64_data) else: - # HTTP URL - Ollama expects base64, so we'd need to fetch - # For now, log warning and skip - logger.warning(f"HTTP image URLs not supported, skipping: {url[:50]}...") - - message_data: dict[str, Any] = { - "role": msg.role, - "content": " ".join(text_parts), - } + logger.warning( + "HTTP image URLs not supported, skipping: %s...", + url[:50], + ) + out["content"] = " ".join(text_parts) if images: - message_data["images"] = images - messages.append(message_data) + out["images"] = images + + messages.append(out) return messages async def chat_completion( @@ -152,23 +244,39 @@ class OllamaProvider(LLMProvider): if options: payload["options"] = options + if request.tools: + payload["tools"] = [ + {"type": "function", "function": t.function.model_dump()} + for t in request.tools + ] + logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}") response = await self.client.post("/api/chat", json=payload) response.raise_for_status() data = response.json() + message = data.get("message", {}) + tool_calls = _tool_calls_from_ollama(message.get("tool_calls")) # Defensive fence-stripping: even with `format` set, some older # Ollama versions still emit ```json ... ``` wrappers for vision # models. Strip them so strict downstream parsers see clean JSON. 
- content = _strip_json_fences(data["message"]["content"]) + raw_content = message.get("content", "") or "" + content = _strip_json_fences(raw_content) if raw_content else "" + + finish_reason = "tool_calls" if tool_calls else ( + "stop" if data.get("done") else None + ) return ChatCompletionResponse( model=f"ollama/{model}", choices=[ Choice( - message=MessageResponse(content=content), - finish_reason="stop" if data.get("done") else None, + message=MessageResponse( + content=content or None, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, ) ], usage=Usage( @@ -204,6 +312,12 @@ class OllamaProvider(LLMProvider): if options: payload["options"] = options + if request.tools: + payload["tools"] = [ + {"type": "function", "function": t.function.model_dump()} + for t in request.tools + ] + logger.debug(f"Ollama streaming request: {model}") response_id = f"chatcmpl-{model[:8]}" diff --git a/services/mana-llm/src/providers/openai_compat.py b/services/mana-llm/src/providers/openai_compat.py index e5565757e..a720fe7f2 100644 --- a/services/mana-llm/src/providers/openai_compat.py +++ b/services/mana-llm/src/providers/openai_compat.py @@ -19,6 +19,8 @@ from src.models import ( MessageResponse, ModelInfo, StreamChoice, + ToolCall, + ToolCallFunction, Usage, ) @@ -27,9 +29,37 @@ from .base import LLMProvider logger = logging.getLogger(__name__) +def _tool_calls_from_openai(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None: + """Normalise an OpenAI-spec ``tool_calls`` array into our model shape.""" + if not raw: + return None + calls: list[ToolCall] = [] + for c in raw: + fn = c.get("function") or {} + name = fn.get("name") + if not name: + continue + calls.append( + ToolCall( + id=c.get("id") or "", + function=ToolCallFunction( + name=name, + arguments=fn.get("arguments") or "{}", + ), + ) + ) + return calls or None + + class OpenAICompatProvider(LLMProvider): """OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.).""" + # 
OpenRouter/Groq/Together all expose tool_calls per the OpenAI spec; + # individual models within those services may or may not support it, + # but the request shape is uniform. The upstream returns a proper + # error for unsupported models — no silent downgrade here. + supports_tools = True + def __init__( self, name: str, @@ -60,23 +90,53 @@ class OpenAICompatProvider(LLMProvider): return self._client def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]: - """Convert internal message format to OpenAI format.""" - messages = [] + """Convert internal message format to OpenAI format. + + The OpenAI chat-completions endpoint is the source of truth for + this shape, so most fields pass through verbatim — including + ``role='tool'`` messages with ``tool_call_id`` + ``content`` and + assistant messages carrying ``tool_calls[]``. + """ + messages: list[dict[str, Any]] = [] for msg in request.messages: - if isinstance(msg.content, str): - messages.append({"role": msg.role, "content": msg.content}) + out: dict[str, Any] = {"role": msg.role} + + if msg.tool_call_id: + out["tool_call_id"] = msg.tool_call_id + + if msg.tool_calls: + out["tool_calls"] = [ + { + "id": c.id, + "type": c.type, + "function": { + "name": c.function.name, + "arguments": c.function.arguments, + }, + } + for c in msg.tool_calls + ] + + if msg.content is None: + # Assistant tool-call messages have null content per spec. 
+ out["content"] = None + elif isinstance(msg.content, str): + out["content"] = msg.content else: - # Handle multimodal content content_parts = [] for part in msg.content: if part.type == "text": content_parts.append({"type": "text", "text": part.text}) elif part.type == "image_url": - content_parts.append({ - "type": "image_url", - "image_url": {"url": part.image_url.url}, - }) - messages.append({"role": msg.role, "content": content_parts}) + content_parts.append( + { + "type": "image_url", + "image_url": {"url": part.image_url.url}, + } + ) + out["content"] = content_parts + + messages.append(out) return messages async def chat_completion( @@ -104,6 +164,17 @@ class OpenAICompatProvider(LLMProvider): payload["presence_penalty"] = request.presence_penalty if request.stop: payload["stop"] = request.stop + if request.tools: + payload["tools"] = [ + {"type": "function", "function": t.function.model_dump()} + for t in request.tools + ] + if request.tool_choice is not None: + payload["tool_choice"] = ( + request.tool_choice + if isinstance(request.tool_choice, str) + else request.tool_choice.model_dump() + ) logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}") @@ -117,7 +188,12 @@ class OpenAICompatProvider(LLMProvider): choices=[ Choice( index=choice.get("index", 0), - message=MessageResponse(content=choice["message"]["content"]), + message=MessageResponse( + content=choice["message"].get("content"), + tool_calls=_tool_calls_from_openai( + choice["message"].get("tool_calls") + ), + ), finish_reason=choice.get("finish_reason", "stop"), ) for choice in data.get("choices", []) @@ -154,6 +230,17 @@ class OpenAICompatProvider(LLMProvider): payload["presence_penalty"] = request.presence_penalty if request.stop: payload["stop"] = request.stop + if request.tools: + payload["tools"] = [ + {"type": "function", "function": t.function.model_dump()} + for t in request.tools + ] + if request.tool_choice is not None: + payload["tool_choice"] = ( + 
request.tool_choice + if isinstance(request.tool_choice, str) + else request.tool_choice.model_dump() + ) logger.debug(f"{self.name} streaming request: {model}") @@ -190,6 +277,9 @@ class OpenAICompatProvider(LLMProvider): delta=DeltaContent( role=delta.get("role"), content=delta.get("content"), + tool_calls=_tool_calls_from_openai( + delta.get("tool_calls") + ), ), finish_reason=choice.get("finish_reason"), ) diff --git a/services/mana-llm/src/providers/router.py b/services/mana-llm/src/providers/router.py index 3111b7a93..dfec1ac07 100644 --- a/services/mana-llm/src/providers/router.py +++ b/services/mana-llm/src/providers/router.py @@ -17,6 +17,7 @@ from src.models import ( ) from .base import LLMProvider +from .errors import ProviderCapabilityError from .ollama import OllamaProvider from .openai_compat import OpenAICompatProvider @@ -154,6 +155,23 @@ class ProviderRouter: ) return await google.chat_completion(request, gemini_model) + def _check_tool_capability( + self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest + ) -> None: + """Refuse tool-bearing requests for providers/models without tool support. + + Silent downgrade (dropping the `tools` payload) is more dangerous + than an explicit error — the caller would get plain text back and + have no way to tell the tools never reached the model. + """ + if not request.tools: + return + if not provider.model_supports_tools(model_name): + raise ProviderCapabilityError( + f"{provider.name}/{model_name} does not support tool calling. " + "Choose a tool-capable model (e.g. 
gemini-2.5-flash, llama3.1:*)" + ) + async def chat_completion( self, request: ChatCompletionRequest, @@ -164,6 +182,7 @@ class ProviderRouter: # Non-Ollama providers: direct routing, no fallback if provider_name != "ollama": provider = self._get_provider(provider_name) + self._check_tool_capability(provider, model_name, request) logger.info(f"Routing chat completion to {provider_name}/{model_name}") return await provider.chat_completion(request, model_name) @@ -176,6 +195,7 @@ class ProviderRouter: # Try Ollama first provider = self._get_provider("ollama") + self._check_tool_capability(provider, model_name, request) logger.info(f"Routing chat completion to ollama/{model_name}") self._ollama_concurrent += 1 @@ -199,6 +219,7 @@ class ProviderRouter: # Non-Ollama: direct if provider_name != "ollama": provider = self._get_provider(provider_name) + self._check_tool_capability(provider, model_name, request) logger.info(f"Routing streaming to {provider_name}/{model_name}") async for chunk in provider.chat_completion_stream(request, model_name): yield chunk @@ -219,6 +240,7 @@ class ProviderRouter: return provider = self._get_provider("ollama") + self._check_tool_capability(provider, model_name, request) logger.info(f"Routing streaming to ollama/{model_name}") self._ollama_concurrent += 1