feat(mana-llm): add OpenAI-style tools + tool_calls passthrough

Extends the chat-completions surface so callers can ask any provider
to call named functions and get structured tool_calls back. Wired
through all three provider adapters so the planner and companion can
switch off the fragile JSON-parsing pathway.

- Request: tools[], tool_choice, assistant tool_calls, tool-role
  messages with tool_call_id.
- Response: MessageResponse.tool_calls, Choice.finish_reason adds
  "tool_calls", DeltaContent streams tool_calls.
- Google provider: Tool(function_declarations=...) build, result
  normalised (args dict → JSON string), function_response parts on
  a user turn for tool-role messages.
- OpenAI-compat: 1:1 passthrough of the OpenAI spec.
- Ollama: /api/chat passthrough; model-level capability check via a
  TOOL_CAPABLE_OLLAMA_PATTERNS whitelist (llama3.1+, qwen2.5+,
  mistral, command-r, …) — unsupported models rejected rather than
  silently falling back to prose.
- Router: model_supports_tools() check upfront for both streaming
  and non-streaming paths; ProviderCapabilityError bubbles as 400.

No silent downgrade. Missing tool support = explicit error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Till JS 2026-04-20 15:22:48 +02:00
parent c612a22371
commit e757470cb0
8 changed files with 568 additions and 81 deletions

View file

@ -1,6 +1,13 @@
"""Pydantic models for OpenAI-compatible API."""
from .requests import ChatCompletionRequest, EmbeddingRequest
from .requests import (
ChatCompletionRequest,
EmbeddingRequest,
FunctionSpec,
Message,
ToolChoice,
ToolSpec,
)
from .responses import (
ChatCompletionResponse,
ChatCompletionStreamResponse,
@ -12,6 +19,8 @@ from .responses import (
ModelInfo,
ModelsResponse,
StreamChoice,
ToolCall,
ToolCallFunction,
Usage,
)
@ -24,9 +33,15 @@ __all__ = [
"EmbeddingData",
"EmbeddingRequest",
"EmbeddingResponse",
"FunctionSpec",
"Message",
"MessageResponse",
"ModelInfo",
"ModelsResponse",
"StreamChoice",
"ToolCall",
"ToolCallFunction",
"ToolChoice",
"ToolSpec",
"Usage",
]

View file

@ -28,11 +28,64 @@ class ImageContent(BaseModel):
MessageContent = str | list[TextContent | ImageContent]
class Message(BaseModel):
"""A single message in the conversation."""
class ToolCallFunction(BaseModel):
"""The function portion of a tool_call on an assistant message."""
role: Literal["system", "user", "assistant"]
content: MessageContent
name: str
# Arguments are passed as a JSON string (OpenAI spec). Providers may
# emit structured args natively; the adapter serialises them here.
arguments: str
class ToolCall(BaseModel):
"""A tool invocation the assistant decided to make."""
id: str
type: Literal["function"] = "function"
function: ToolCallFunction
class Message(BaseModel):
"""A single message in the conversation.
`tool` messages carry the result of a previously-requested tool call
back into the context; they must reference the originating call via
``tool_call_id``. Assistant messages may contain either plain
``content`` or ``tool_calls`` (or both, though providers typically
only emit one at a time).
"""
role: Literal["system", "user", "assistant", "tool"]
content: MessageContent | None = None
tool_call_id: str | None = None
tool_calls: list[ToolCall] | None = None
class FunctionSpec(BaseModel):
"""OpenAI-style function declaration."""
name: str
description: str
# JSON Schema for the function's parameters. Kept as a dict here to
# stay loose; providers translate it to their native shape.
parameters: dict[str, Any] = Field(default_factory=dict)
class ToolSpec(BaseModel):
"""A tool the model may call. Only ``function`` tools are supported."""
type: Literal["function"] = "function"
function: FunctionSpec
class ToolChoiceFunction(BaseModel):
"""Force-pick a specific function for ``tool_choice``."""
type: Literal["function"] = "function"
function: dict[str, str] # {"name": "tool_name"}
ToolChoice = Literal["auto", "none", "required"] | ToolChoiceFunction
class ResponseFormat(BaseModel):
@ -63,6 +116,8 @@ class ChatCompletionRequest(BaseModel):
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
stop: str | list[str] | None = None
response_format: ResponseFormat | None = None
tools: list[ToolSpec] | None = None
tool_choice: ToolChoice | None = None
class EmbeddingRequest(BaseModel):

View file

@ -15,11 +15,27 @@ class Usage(BaseModel):
total_tokens: int = 0
class ToolCallFunction(BaseModel):
"""Function portion of a tool_call the assistant produced."""
name: str
arguments: str # JSON-encoded (OpenAI spec)
class ToolCall(BaseModel):
"""A tool invocation the assistant decided to make."""
id: str
type: Literal["function"] = "function"
function: ToolCallFunction
class MessageResponse(BaseModel):
"""Response message from the model."""
role: Literal["assistant"] = "assistant"
content: str
content: str | None = None
tool_calls: list[ToolCall] | None = None
class Choice(BaseModel):
@ -27,7 +43,9 @@ class Choice(BaseModel):
index: int = 0
message: MessageResponse
finish_reason: Literal["stop", "length", "content_filter"] | None = "stop"
finish_reason: (
Literal["stop", "length", "content_filter", "tool_calls"] | None
) = "stop"
class ChatCompletionResponse(BaseModel):
@ -46,6 +64,7 @@ class DeltaContent(BaseModel):
role: Literal["assistant"] | None = None
content: str | None = None
tool_calls: list[ToolCall] | None = None
class StreamChoice(BaseModel):
@ -53,7 +72,9 @@ class StreamChoice(BaseModel):
index: int = 0
delta: DeltaContent
finish_reason: Literal["stop", "length", "content_filter"] | None = None
finish_reason: (
Literal["stop", "length", "content_filter", "tool_calls"] | None
) = None
class ChatCompletionStreamResponse(BaseModel):

View file

@ -19,6 +19,21 @@ class LLMProvider(ABC):
name: str = "base"
# Set to True if the provider supports OpenAI-style `tools` + `tool_calls`
# for chat completions. The router rejects tool-bearing requests routed
# to providers without support rather than silently dropping the tools.
# Provider adapters may further narrow this per-model if needed.
supports_tools: bool = False
def model_supports_tools(self, model: str) -> bool:
"""Check if a specific model within this provider supports tools.
Default: falls back to the provider-wide flag. Providers with a
mixed capability surface (e.g. Ollama depends on the local
model) override this.
"""
return self.supports_tools
@abstractmethod
async def chat_completion(
self,

View file

@ -1,6 +1,8 @@
"""Google Gemini provider for mana-llm (fallback when Ollama is unavailable)."""
import json
import logging
import uuid
from collections.abc import AsyncIterator
from typing import Any
@ -20,6 +22,9 @@ from src.models import (
MessageResponse,
ModelInfo,
StreamChoice,
ToolCall,
ToolCallFunction,
ToolSpec,
Usage,
)
@ -35,14 +40,66 @@ from .errors import (
logger = logging.getLogger(__name__)
def _unwrap_gemini_response(response: Any, gemini_model: str) -> str:
"""Validate a non-streaming Gemini response and return its text.
def _build_gemini_tools(tools: list[ToolSpec] | None) -> list[types.Tool] | None:
"""Translate our ToolSpec list into Gemini ``types.Tool`` declarations."""
if not tools:
return None
declarations: list[types.FunctionDeclaration] = []
for t in tools:
declarations.append(
types.FunctionDeclaration(
name=t.function.name,
description=t.function.description,
parameters=t.function.parameters or None,
)
)
return [types.Tool(function_declarations=declarations)]
Raises a structured ProviderError if the response was blocked,
truncated, or otherwise produced no usable text. The SDK's
``response.text`` accessor silently returns an empty string in all
of those cases, which downstream consumers (e.g. the planner
parser) cannot distinguish from a well-formed empty completion.
def _extract_tool_calls(candidate: Any) -> list[ToolCall] | None:
"""Pull any ``function_call`` parts off a candidate into ToolCalls.
Gemini emits tool calls as ``content.parts[i].function_call`` with a
``FunctionCall(name, args)`` where ``args`` is a dict (not a JSON
string). We normalise to OpenAI shape: arguments are JSON-encoded
strings so downstream handlers can treat all providers the same.
"""
if candidate is None:
return None
content = getattr(candidate, "content", None)
parts = getattr(content, "parts", None) or []
calls: list[ToolCall] = []
for part in parts:
fc = getattr(part, "function_call", None)
if fc is None:
continue
name = getattr(fc, "name", None)
if not name:
continue
args = getattr(fc, "args", None) or {}
# Gemini's args are already dict-shaped; serialise to JSON string.
try:
arguments_json = json.dumps(dict(args), ensure_ascii=False)
except (TypeError, ValueError):
arguments_json = json.dumps({}, ensure_ascii=False)
calls.append(
ToolCall(
id=f"call_{uuid.uuid4().hex[:12]}",
function=ToolCallFunction(name=name, arguments=arguments_json),
)
)
return calls or None
def _unwrap_gemini_response(
response: Any, gemini_model: str
) -> tuple[str, list[ToolCall] | None, str]:
"""Validate a non-streaming Gemini response.
Returns ``(text, tool_calls, finish_reason)``. Raises a structured
``ProviderError`` if the response was blocked, truncated, or
otherwise produced no usable payload. ``response.text`` silently
returns ``""`` on blocked responses we refuse to propagate that.
"""
candidates = getattr(response, "candidates", None) or []
candidate = candidates[0] if candidates else None
@ -56,6 +113,7 @@ def _unwrap_gemini_response(response: Any, gemini_model: str) -> str:
finish_name = finish_name.rsplit(".", 1)[-1]
text = response.text or ""
tool_calls = _extract_tool_calls(candidate)
if finish_name in {"SAFETY", "RECITATION", "PROHIBITED_CONTENT", "SPII", "BLOCKLIST"}:
# Pull the first safety rating that actually blocked if present.
@ -80,13 +138,37 @@ def _unwrap_gemini_response(response: Any, gemini_model: str) -> str:
)
raise ProviderTruncatedError(partial_text=text or None)
if not text and finish_name not in (None, "STOP"):
# Unknown finish reason, empty text — surface instead of silent "".
if not text and not tool_calls and finish_name not in (None, "STOP"):
# Unknown finish reason, nothing to return — surface instead of "".
raise ProviderError(
f"Gemini returned no content (finish_reason={finish_name})"
)
return text
# Normalise the finish_reason to our OpenAI-compatible vocabulary.
openai_finish = "stop"
if tool_calls:
openai_finish = "tool_calls"
elif finish_name == "MAX_TOKENS":
openai_finish = "length"
return text, tool_calls, openai_finish
def _lookup_tool_name(messages: list[Any], tool_call_id: str | None) -> str | None:
"""Find the tool name for a given ``tool_call_id``.
OpenAI's tool-message carries only the id; the matching name lives
on the preceding assistant message's ``tool_calls[]``. We scan from
the end (most recent assistant turn) backwards.
"""
if not tool_call_id:
return None
for m in reversed(messages):
if m.role != "assistant" or not m.tool_calls:
continue
for call in m.tool_calls:
if call.id == tool_call_id:
return call.function.name
return None
def _wrap_gemini_call_error(err: Exception, gemini_model: str) -> ProviderError:
@ -123,6 +205,8 @@ class GoogleProvider(LLMProvider):
"""Google Gemini API provider."""
name = "google"
# Gemini 2.x supports OpenAI-style function calling across all listed models.
supports_tools = True
def __init__(self, api_key: str, default_model: str = "gemini-2.5-flash"):
self.api_key = api_key
@ -138,43 +222,98 @@ class GoogleProvider(LLMProvider):
) -> tuple[str | None, list[types.Content]]:
"""Convert OpenAI-format messages to Google Gemini format.
Returns (system_instruction, contents).
Returns (system_instruction, contents). Handles text, multimodal
image content, assistant messages carrying ``tool_calls``, and
``tool`` result messages (mapped to Gemini ``function_response``
parts Gemini has no ``tool`` role, function responses ride on
a ``user`` turn).
"""
system_instruction: str | None = None
contents: list[types.Content] = []
for msg in request.messages:
if msg.role == "system":
# Gemini uses system_instruction separately
if isinstance(msg.content, str):
system_instruction = msg.content
continue
role = "user" if msg.role == "user" else "model"
if isinstance(msg.content, str):
contents.append(types.Content(role=role, parts=[types.Part.from_text(text=msg.content)]))
else:
# Multimodal content
parts: list[types.Part] = []
for part in msg.content:
if part.type == "text":
parts.append(types.Part.from_text(text=part.text))
elif part.type == "image_url" and part.image_url:
url = part.image_url.url
if url.startswith("data:"):
# Parse data URI: data:image/jpeg;base64,<data>
header, b64_data = url.split(",", 1)
mime_type = header.split(":")[1].split(";")[0]
import base64
image_bytes = base64.b64decode(b64_data)
parts.append(
types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
# Tool result message → function_response Part on a user turn.
if msg.role == "tool":
# The content is the stringified tool result. We also need
# a tool name — the OpenAI spec carries it on the matching
# assistant tool_call, keyed by tool_call_id. We don't
# track that back-reference here, so we pull the name
# from the preceding assistant message's tool_calls.
name = _lookup_tool_name(request.messages, msg.tool_call_id)
if not name:
continue # orphan tool message — skip silently
payload: Any
if isinstance(msg.content, str):
try:
payload = json.loads(msg.content)
if not isinstance(payload, dict):
payload = {"result": payload}
except (TypeError, ValueError):
payload = {"result": msg.content}
else:
payload = {"result": ""}
contents.append(
types.Content(
role="user",
parts=[
types.Part.from_function_response(
name=name, response=payload
)
else:
# URL-based image - use as URI
parts.append(types.Part.from_uri(file_uri=url, mime_type="image/jpeg"))
],
)
)
continue
role = "user" if msg.role == "user" else "model"
parts: list[types.Part] = []
if msg.role == "assistant" and msg.tool_calls:
for call in msg.tool_calls:
try:
args_obj = json.loads(call.function.arguments or "{}")
except (TypeError, ValueError):
args_obj = {}
parts.append(
types.Part(
function_call=types.FunctionCall(
name=call.function.name, args=args_obj
)
)
)
if msg.content is not None:
if isinstance(msg.content, str):
parts.append(types.Part.from_text(text=msg.content))
else:
for part in msg.content:
if part.type == "text":
parts.append(types.Part.from_text(text=part.text))
elif part.type == "image_url" and part.image_url:
url = part.image_url.url
if url.startswith("data:"):
header, b64_data = url.split(",", 1)
mime_type = header.split(":")[1].split(";")[0]
import base64
image_bytes = base64.b64decode(b64_data)
parts.append(
types.Part.from_bytes(
data=image_bytes, mime_type=mime_type
)
)
else:
parts.append(
types.Part.from_uri(
file_uri=url, mime_type="image/jpeg"
)
)
if parts:
contents.append(types.Content(role=role, parts=parts))
return system_instruction, contents
@ -199,6 +338,10 @@ class GoogleProvider(LLMProvider):
stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop]
config["stop_sequences"] = stop_seqs
gemini_tools = _build_gemini_tools(request.tools)
if gemini_tools:
config["tools"] = gemini_tools
gen_config = types.GenerateContentConfig(
system_instruction=system_instruction,
**config,
@ -217,7 +360,9 @@ class GoogleProvider(LLMProvider):
except Exception as err:
raise _wrap_gemini_call_error(err, gemini_model) from err
content = _unwrap_gemini_response(response, gemini_model)
content, tool_calls, finish_reason = _unwrap_gemini_response(
response, gemini_model
)
usage_meta = response.usage_metadata
return ChatCompletionResponse(
@ -225,8 +370,11 @@ class GoogleProvider(LLMProvider):
choices=[
Choice(
index=0,
message=MessageResponse(content=content),
finish_reason="stop",
message=MessageResponse(
content=content or None,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
],
usage=Usage(
@ -253,6 +401,10 @@ class GoogleProvider(LLMProvider):
if request.top_p is not None:
config["top_p"] = request.top_p
gemini_tools = _build_gemini_tools(request.tools)
if gemini_tools:
config["tools"] = gemini_tools
gen_config = types.GenerateContentConfig(
system_instruction=system_instruction,
**config,
@ -297,7 +449,10 @@ class GoogleProvider(LLMProvider):
# Post-stream check: if the stream ended without emitting any text,
# surface the structured reason instead of quietly closing with an
# empty "stop". Matches _unwrap_gemini_response semantics.
# empty "stop". Matches _unwrap_gemini_response semantics. We
# discard the return value here — the streaming path produces its
# content chunk-by-chunk above, so we only need this call for its
# side-effect of raising on SAFETY / RECITATION / MAX_TOKENS.
if not emitted_any_text and last_chunk is not None:
_unwrap_gemini_response(last_chunk, gemini_model)

View file

@ -20,14 +20,76 @@ from src.models import (
MessageResponse,
ModelInfo,
StreamChoice,
ToolCall,
ToolCallFunction,
Usage,
)
from .base import LLMProvider
# Ollama emits tool_calls under /api/chat as:
# {"message":{"content":"", "tool_calls":[{"function":{"name":"x","arguments":{...}}}]}}
# Arguments arrive as a dict — we normalise to a JSON string to match
# the OpenAI-spec shape our downstream code expects.
# Ollama models known to support tool-calling reliably (as of 0.3+).
# Everything else: we still pass `tools` to the API (it ignores them on
# incompatible models), but the assistant response will be plain text.
# A shared-ai-level capability check rejects these cases before the call.
TOOL_CAPABLE_OLLAMA_PATTERNS: tuple[str, ...] = (
"llama3.1",
"llama3.2",
"llama3.3",
"qwen2.5",
"qwen3",
"mistral",
"mixtral",
"command-r",
"firefunction",
)
logger = logging.getLogger(__name__)
def _safe_parse_args(raw: str | None) -> dict[str, Any]:
"""Best-effort parse of a JSON-encoded arguments string."""
if not raw:
return {}
try:
parsed = json.loads(raw)
except (TypeError, ValueError):
return {}
return parsed if isinstance(parsed, dict) else {}
def _tool_calls_from_ollama(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
"""Normalise Ollama's tool_calls into our ToolCall model."""
if not raw:
return None
import uuid
calls: list[ToolCall] = []
for idx, c in enumerate(raw):
fn = c.get("function") or {}
name = fn.get("name")
if not name:
continue
args = fn.get("arguments")
if isinstance(args, dict):
arguments_json = json.dumps(args, ensure_ascii=False)
elif isinstance(args, str):
arguments_json = args
else:
arguments_json = "{}"
calls.append(
ToolCall(
id=c.get("id") or f"call_{uuid.uuid4().hex[:12]}",
function=ToolCallFunction(name=name, arguments=arguments_json),
)
)
return calls or None
def _strip_json_fences(content: str) -> str:
"""Strip ```json ... ``` markdown fences from a string if present.
@ -53,6 +115,17 @@ class OllamaProvider(LLMProvider):
"""Ollama LLM provider."""
name = "ollama"
supports_tools = True
def model_supports_tools(self, model: str) -> bool:
"""Narrow tool capability to models trained for function calling.
Ollama accepts the ``tools`` payload even for incompatible models
but silently drops it the assistant just replies in prose. We
reject those requests upfront so callers get a clear error.
"""
lower = model.lower()
return any(pattern in lower for pattern in TOOL_CAPABLE_OLLAMA_PATTERNS)
def __init__(self, base_url: str | None = None, timeout: int | None = None):
self.base_url = (base_url or settings.ollama_url).rstrip("/")
@ -70,37 +143,56 @@ class OllamaProvider(LLMProvider):
return self._client
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
"""Convert OpenAI message format to Ollama format."""
messages = []
"""Convert OpenAI message format to Ollama format.
Ollama's ``/api/chat`` uses the OpenAI shape for assistant
``tool_calls`` and ``tool`` result messages (with args as objects
rather than JSON strings on output input still accepts JSON
strings), so we largely pass them through.
"""
messages: list[dict[str, Any]] = []
for msg in request.messages:
if isinstance(msg.content, str):
messages.append({"role": msg.role, "content": msg.content})
out: dict[str, Any] = {"role": msg.role}
if msg.tool_calls:
out["tool_calls"] = [
{
"function": {
"name": c.function.name,
# Ollama accepts either stringified JSON or
# already-parsed objects; send parsed for
# better compatibility with older models.
"arguments": _safe_parse_args(c.function.arguments),
}
}
for c in msg.tool_calls
]
if msg.content is None:
out["content"] = ""
elif isinstance(msg.content, str):
out["content"] = msg.content
else:
# Handle multimodal content (vision)
text_parts = []
images = []
text_parts: list[str] = []
images: list[str] = []
for part in msg.content:
if part.type == "text":
text_parts.append(part.text)
elif part.type == "image_url":
url = part.image_url.url
# Extract base64 data from data URL
if url.startswith("data:"):
# Format: data:image/png;base64,<base64_data>
base64_data = url.split(",", 1)[1] if "," in url else url
images.append(base64_data)
else:
# HTTP URL - Ollama expects base64, so we'd need to fetch
# For now, log warning and skip
logger.warning(f"HTTP image URLs not supported, skipping: {url[:50]}...")
message_data: dict[str, Any] = {
"role": msg.role,
"content": " ".join(text_parts),
}
logger.warning(
"HTTP image URLs not supported, skipping: %s...",
url[:50],
)
out["content"] = " ".join(text_parts)
if images:
message_data["images"] = images
messages.append(message_data)
out["images"] = images
messages.append(out)
return messages
async def chat_completion(
@ -152,23 +244,39 @@ class OllamaProvider(LLMProvider):
if options:
payload["options"] = options
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}")
response = await self.client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
message = data.get("message", {})
tool_calls = _tool_calls_from_ollama(message.get("tool_calls"))
# Defensive fence-stripping: even with `format` set, some older
# Ollama versions still emit ```json ... ``` wrappers for vision
# models. Strip them so strict downstream parsers see clean JSON.
content = _strip_json_fences(data["message"]["content"])
raw_content = message.get("content", "") or ""
content = _strip_json_fences(raw_content) if raw_content else ""
finish_reason = "tool_calls" if tool_calls else (
"stop" if data.get("done") else None
)
return ChatCompletionResponse(
model=f"ollama/{model}",
choices=[
Choice(
message=MessageResponse(content=content),
finish_reason="stop" if data.get("done") else None,
message=MessageResponse(
content=content or None,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
],
usage=Usage(
@ -204,6 +312,12 @@ class OllamaProvider(LLMProvider):
if options:
payload["options"] = options
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
logger.debug(f"Ollama streaming request: {model}")
response_id = f"chatcmpl-{model[:8]}"

View file

@ -19,6 +19,8 @@ from src.models import (
MessageResponse,
ModelInfo,
StreamChoice,
ToolCall,
ToolCallFunction,
Usage,
)
@ -27,9 +29,37 @@ from .base import LLMProvider
logger = logging.getLogger(__name__)
def _tool_calls_from_openai(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
"""Normalise an OpenAI-spec ``tool_calls`` array into our model shape."""
if not raw:
return None
calls: list[ToolCall] = []
for c in raw:
fn = c.get("function") or {}
name = fn.get("name")
if not name:
continue
calls.append(
ToolCall(
id=c.get("id") or "",
function=ToolCallFunction(
name=name,
arguments=fn.get("arguments") or "{}",
),
)
)
return calls or None
class OpenAICompatProvider(LLMProvider):
"""OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.)."""
# OpenRouter/Groq/Together all expose tool_calls per the OpenAI spec;
# individual models within those services may or may not support it,
# but the request shape is uniform. The upstream returns a proper
# error for unsupported models — no silent downgrade here.
supports_tools = True
def __init__(
self,
name: str,
@ -60,23 +90,53 @@ class OpenAICompatProvider(LLMProvider):
return self._client
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
"""Convert internal message format to OpenAI format."""
messages = []
"""Convert internal message format to OpenAI format.
The OpenAI chat-completions endpoint is the source of truth for
this shape, so most fields pass through verbatim including
``role='tool'`` messages with ``tool_call_id`` + ``content`` and
assistant messages carrying ``tool_calls[]``.
"""
messages: list[dict[str, Any]] = []
for msg in request.messages:
if isinstance(msg.content, str):
messages.append({"role": msg.role, "content": msg.content})
out: dict[str, Any] = {"role": msg.role}
if msg.tool_call_id:
out["tool_call_id"] = msg.tool_call_id
if msg.tool_calls:
out["tool_calls"] = [
{
"id": c.id,
"type": c.type,
"function": {
"name": c.function.name,
"arguments": c.function.arguments,
},
}
for c in msg.tool_calls
]
if msg.content is None:
# Assistant tool-call messages have null content per spec.
out["content"] = None
elif isinstance(msg.content, str):
out["content"] = msg.content
else:
# Handle multimodal content
content_parts = []
for part in msg.content:
if part.type == "text":
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image_url":
content_parts.append({
"type": "image_url",
"image_url": {"url": part.image_url.url},
})
messages.append({"role": msg.role, "content": content_parts})
content_parts.append(
{
"type": "image_url",
"image_url": {"url": part.image_url.url},
}
)
out["content"] = content_parts
messages.append(out)
return messages
async def chat_completion(
@ -104,6 +164,17 @@ class OpenAICompatProvider(LLMProvider):
payload["presence_penalty"] = request.presence_penalty
if request.stop:
payload["stop"] = request.stop
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
if request.tool_choice is not None:
payload["tool_choice"] = (
request.tool_choice
if isinstance(request.tool_choice, str)
else request.tool_choice.model_dump()
)
logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}")
@ -117,7 +188,12 @@ class OpenAICompatProvider(LLMProvider):
choices=[
Choice(
index=choice.get("index", 0),
message=MessageResponse(content=choice["message"]["content"]),
message=MessageResponse(
content=choice["message"].get("content"),
tool_calls=_tool_calls_from_openai(
choice["message"].get("tool_calls")
),
),
finish_reason=choice.get("finish_reason", "stop"),
)
for choice in data.get("choices", [])
@ -154,6 +230,17 @@ class OpenAICompatProvider(LLMProvider):
payload["presence_penalty"] = request.presence_penalty
if request.stop:
payload["stop"] = request.stop
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
if request.tool_choice is not None:
payload["tool_choice"] = (
request.tool_choice
if isinstance(request.tool_choice, str)
else request.tool_choice.model_dump()
)
logger.debug(f"{self.name} streaming request: {model}")
@ -190,6 +277,9 @@ class OpenAICompatProvider(LLMProvider):
delta=DeltaContent(
role=delta.get("role"),
content=delta.get("content"),
tool_calls=_tool_calls_from_openai(
delta.get("tool_calls")
),
),
finish_reason=choice.get("finish_reason"),
)

View file

@ -17,6 +17,7 @@ from src.models import (
)
from .base import LLMProvider
from .errors import ProviderCapabilityError
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
@ -154,6 +155,23 @@ class ProviderRouter:
)
return await google.chat_completion(request, gemini_model)
def _check_tool_capability(
self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest
) -> None:
"""Refuse tool-bearing requests for providers/models without tool support.
Silent downgrade (dropping the `tools` payload) is more dangerous
than an explicit error the caller would get plain text back and
have no way to tell the tools never reached the model.
"""
if not request.tools:
return
if not provider.model_supports_tools(model_name):
raise ProviderCapabilityError(
f"{provider.name}/{model_name} does not support tool calling. "
"Choose a tool-capable model (e.g. gemini-2.5-flash, llama3.1:*)"
)
async def chat_completion(
self,
request: ChatCompletionRequest,
@ -164,6 +182,7 @@ class ProviderRouter:
# Non-Ollama providers: direct routing, no fallback
if provider_name != "ollama":
provider = self._get_provider(provider_name)
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
return await provider.chat_completion(request, model_name)
@ -176,6 +195,7 @@ class ProviderRouter:
# Try Ollama first
provider = self._get_provider("ollama")
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing chat completion to ollama/{model_name}")
self._ollama_concurrent += 1
@ -199,6 +219,7 @@ class ProviderRouter:
# Non-Ollama: direct
if provider_name != "ollama":
provider = self._get_provider(provider_name)
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing streaming to {provider_name}/{model_name}")
async for chunk in provider.chat_completion_stream(request, model_name):
yield chunk
@ -219,6 +240,7 @@ class ProviderRouter:
return
provider = self._get_provider("ollama")
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing streaming to ollama/{model_name}")
self._ollama_concurrent += 1