mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-19 20:41:25 +02:00
feat(mana-llm): add OpenAI-style tools + tool_calls passthrough
Extends the chat-completions surface so callers can ask any provider to call named functions and get structured tool_calls back. Wired through all three provider adapters so the planner and companion can switch off the fragile JSON-parsing pathway.

- Request: tools[], tool_choice, assistant tool_calls, tool-role messages with tool_call_id.
- Response: MessageResponse.tool_calls, Choice.finish_reason adds "tool_calls", DeltaContent streams tool_calls.
- Google provider: Tool(function_declarations=...) build, result normalised (args dict → JSON string), function_response parts on a user turn for tool-role messages.
- OpenAI-compat: 1:1 passthrough of the OpenAI spec.
- Ollama: /api/chat passthrough; model-level capability check via a TOOL_CAPABLE_OLLAMA_PATTERNS whitelist (llama3.1+, qwen2.5+, mistral, command-r, …) — unsupported models rejected rather than silently falling back to prose.
- Router: model_supports_tools() check upfront for both streaming and non-streaming paths; ProviderCapabilityError bubbles as 400.

No silent downgrade. Missing tool support = explicit error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c612a22371
commit
e757470cb0
8 changed files with 568 additions and 81 deletions
|
|
@ -1,6 +1,13 @@
|
|||
"""Pydantic models for OpenAI-compatible API."""
|
||||
|
||||
from .requests import ChatCompletionRequest, EmbeddingRequest
|
||||
from .requests import (
|
||||
ChatCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
FunctionSpec,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolSpec,
|
||||
)
|
||||
from .responses import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
|
|
@ -12,6 +19,8 @@ from .responses import (
|
|||
ModelInfo,
|
||||
ModelsResponse,
|
||||
StreamChoice,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
|
@ -24,9 +33,15 @@ __all__ = [
|
|||
"EmbeddingData",
|
||||
"EmbeddingRequest",
|
||||
"EmbeddingResponse",
|
||||
"FunctionSpec",
|
||||
"Message",
|
||||
"MessageResponse",
|
||||
"ModelInfo",
|
||||
"ModelsResponse",
|
||||
"StreamChoice",
|
||||
"ToolCall",
|
||||
"ToolCallFunction",
|
||||
"ToolChoice",
|
||||
"ToolSpec",
|
||||
"Usage",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -28,11 +28,64 @@ class ImageContent(BaseModel):
|
|||
MessageContent = str | list[TextContent | ImageContent]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""A single message in the conversation."""
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""The function portion of a tool_call on an assistant message."""
|
||||
|
||||
role: Literal["system", "user", "assistant"]
|
||||
content: MessageContent
|
||||
name: str
|
||||
# Arguments are passed as a JSON string (OpenAI spec). Providers may
|
||||
# emit structured args natively; the adapter serialises them here.
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """A tool invocation the assistant decided to make.

    Carried in the ``tool_calls`` list of an assistant ``Message``.
    """

    # Identifier of this call; a ``tool`` message points back to it through
    # its ``tool_call_id`` field.
    id: str
    # Only function tools are modelled, so the discriminator is fixed.
    type: Literal["function"] = "function"
    # Name and JSON-encoded arguments of the function being invoked.
    function: ToolCallFunction
|
||||
|
||||
|
||||
class Message(BaseModel):
    """A single message in the conversation.

    `tool` messages carry the result of a previously-requested tool call
    back into the context; they must reference the originating call via
    ``tool_call_id``. Assistant messages may contain either plain
    ``content`` or ``tool_calls`` (or both, though providers typically
    only emit one at a time).

    NOTE(review): the ``tool_call_id`` requirement is a convention only;
    no validator on this model enforces it.
    """

    role: Literal["system", "user", "assistant", "tool"]
    # Optional so that an assistant message can carry only ``tool_calls``.
    content: MessageContent | None = None
    # For ``tool`` messages: the ``id`` of the ToolCall being answered.
    tool_call_id: str | None = None
    # For assistant messages that request one or more tool invocations.
    tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class FunctionSpec(BaseModel):
    """OpenAI-style function declaration.

    Describes one callable function that a ``ToolSpec`` exposes to the model.
    """

    # Name of the function.
    name: str
    # Free-text description of the function. Has no default, so requests
    # that omit it fail validation.
    # NOTE(review): confirm that requiring a description is intended.
    description: str
    # JSON Schema for the function's parameters. Kept as a dict here to
    # stay loose; providers translate it to their native shape.
    parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolSpec(BaseModel):
    """A tool the model may call. Only ``function`` tools are supported."""

    # Fixed discriminator; any value other than "function" fails validation.
    type: Literal["function"] = "function"
    # Declaration (name, description, parameter schema) of the function.
    function: FunctionSpec
|
||||
|
||||
|
||||
class ToolChoiceFunction(BaseModel):
    """Force-pick a specific function for ``tool_choice``."""

    # Fixed discriminator for the object form of ``tool_choice``.
    type: Literal["function"] = "function"
    # Loosely typed string-to-string mapping; the expected shape is shown
    # in the trailing comment.
    function: dict[str, str]  # {"name": "tool_name"}


# ``tool_choice`` is either one of the three mode strings or an object
# naming the function to force.
ToolChoice = Literal["auto", "none", "required"] | ToolChoiceFunction
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
|
|
@ -63,6 +116,8 @@ class ChatCompletionRequest(BaseModel):
|
|||
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
|
||||
stop: str | list[str] | None = None
|
||||
response_format: ResponseFormat | None = None
|
||||
tools: list[ToolSpec] | None = None
|
||||
tool_choice: ToolChoice | None = None
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
|
|
|
|||
|
|
@ -15,11 +15,27 @@ class Usage(BaseModel):
|
|||
total_tokens: int = 0
|
||||
|
||||
|
||||
class ToolCallFunction(BaseModel):
    """Function portion of a tool_call the assistant produced."""

    # Name of the function the model wants to call.
    name: str
    # Always a string; adapters whose provider returns structured arguments
    # serialise them before building this model.
    arguments: str  # JSON-encoded (OpenAI spec)
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """A tool invocation the assistant decided to make.

    Returned in ``MessageResponse.tool_calls`` and streamed through
    ``DeltaContent.tool_calls``.
    """

    # Identifier of this call.
    id: str
    # Only function tools are modelled, so the discriminator is fixed.
    type: Literal["function"] = "function"
    # Name and JSON-encoded arguments of the function being invoked.
    function: ToolCallFunction
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Response message from the model."""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: str
|
||||
content: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
|
|
@ -27,7 +43,9 @@ class Choice(BaseModel):
|
|||
|
||||
index: int = 0
|
||||
message: MessageResponse
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = "stop"
|
||||
finish_reason: (
|
||||
Literal["stop", "length", "content_filter", "tool_calls"] | None
|
||||
) = "stop"
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
|
|
@ -46,6 +64,7 @@ class DeltaContent(BaseModel):
|
|||
|
||||
role: Literal["assistant"] | None = None
|
||||
content: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class StreamChoice(BaseModel):
|
||||
|
|
@ -53,7 +72,9 @@ class StreamChoice(BaseModel):
|
|||
|
||||
index: int = 0
|
||||
delta: DeltaContent
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
finish_reason: (
|
||||
Literal["stop", "length", "content_filter", "tool_calls"] | None
|
||||
) = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,21 @@ class LLMProvider(ABC):
|
|||
|
||||
name: str = "base"
|
||||
|
||||
# Set to True if the provider supports OpenAI-style `tools` + `tool_calls`
|
||||
# for chat completions. The router rejects tool-bearing requests routed
|
||||
# to providers without support rather than silently dropping the tools.
|
||||
# Provider adapters may further narrow this per-model if needed.
|
||||
supports_tools: bool = False
|
||||
|
||||
def model_supports_tools(self, model: str) -> bool:
    """Check if a specific model within this provider supports tools.

    Default: falls back to the provider-wide flag. Providers with a
    mixed capability surface (e.g. Ollama — depends on the local
    model) override this.

    Args:
        model: Model name to check. This default implementation ignores it.

    Returns:
        The value of ``self.supports_tools``.
    """
    return self.supports_tools
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""Google Gemini provider for mana-llm (fallback when Ollama is unavailable)."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -20,6 +22,9 @@ from src.models import (
|
|||
MessageResponse,
|
||||
ModelInfo,
|
||||
StreamChoice,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
ToolSpec,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
|
@ -35,14 +40,66 @@ from .errors import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _unwrap_gemini_response(response: Any, gemini_model: str) -> str:
|
||||
"""Validate a non-streaming Gemini response and return its text.
|
||||
def _build_gemini_tools(tools: list[ToolSpec] | None) -> list[types.Tool] | None:
    """Convert OpenAI-style tool specs into a single Gemini ``types.Tool``.

    Returns ``None`` when ``tools`` is ``None`` or empty. Otherwise returns a
    one-element list whose ``Tool`` holds one ``FunctionDeclaration`` per
    spec, in input order. An empty ``parameters`` dict is passed as ``None``.
    """
    if not tools:
        return None
    function_declarations: list[types.FunctionDeclaration] = [
        types.FunctionDeclaration(
            name=spec.function.name,
            description=spec.function.description,
            parameters=spec.function.parameters or None,
        )
        for spec in tools
    ]
    return [types.Tool(function_declarations=function_declarations)]
|
||||
|
||||
Raises a structured ProviderError if the response was blocked,
|
||||
truncated, or otherwise produced no usable text. The SDK's
|
||||
``response.text`` accessor silently returns an empty string in all
|
||||
of those cases, which downstream consumers (e.g. the planner
|
||||
parser) cannot distinguish from a well-formed empty completion.
|
||||
|
||||
def _extract_tool_calls(candidate: Any) -> list[ToolCall] | None:
    """Pull any ``function_call`` parts off a candidate into ToolCalls.

    Gemini emits tool calls as ``content.parts[i].function_call`` with a
    ``FunctionCall(name, args)`` where ``args`` is a dict (not a JSON
    string). We normalise to OpenAI shape: arguments are JSON-encoded
    strings so downstream handlers can treat all providers the same.

    Args:
        candidate: A Gemini response candidate, or ``None``. It is read
            only through ``getattr``, so missing attributes are tolerated.

    Returns:
        One ``ToolCall`` per named ``function_call`` part, in part order,
        each with a freshly generated ``call_<12 hex chars>`` id; ``None``
        when the candidate is ``None`` or carries no named function call.
    """
    if candidate is None:
        return None
    content = getattr(candidate, "content", None)
    parts = getattr(content, "parts", None) or []
    calls: list[ToolCall] = []
    for part in parts:
        fc = getattr(part, "function_call", None)
        if fc is None:
            continue
        name = getattr(fc, "name", None)
        if not name:
            # A call without a name cannot be dispatched; drop it.
            continue
        args = getattr(fc, "args", None) or {}
        # Gemini's args are already dict-shaped; serialise to JSON string.
        try:
            arguments_json = json.dumps(dict(args), ensure_ascii=False)
        except (TypeError, ValueError):
            # Still best-effort, but no longer silent: the caller would
            # otherwise run the tool with empty arguments and no trace of
            # why. Only the tool name is logged, never the argument values.
            logger.warning(
                "Gemini tool call %r has args that cannot be JSON-encoded; "
                "substituting an empty argument object",
                name,
            )
            arguments_json = "{}"
        calls.append(
            ToolCall(
                id=f"call_{uuid.uuid4().hex[:12]}",
                function=ToolCallFunction(name=name, arguments=arguments_json),
            )
        )
    return calls or None
|
||||
|
||||
|
||||
def _unwrap_gemini_response(
|
||||
response: Any, gemini_model: str
|
||||
) -> tuple[str, list[ToolCall] | None, str]:
|
||||
"""Validate a non-streaming Gemini response.
|
||||
|
||||
Returns ``(text, tool_calls, finish_reason)``. Raises a structured
|
||||
``ProviderError`` if the response was blocked, truncated, or
|
||||
otherwise produced no usable payload. ``response.text`` silently
|
||||
returns ``""`` on blocked responses — we refuse to propagate that.
|
||||
"""
|
||||
candidates = getattr(response, "candidates", None) or []
|
||||
candidate = candidates[0] if candidates else None
|
||||
|
|
@ -56,6 +113,7 @@ def _unwrap_gemini_response(response: Any, gemini_model: str) -> str:
|
|||
finish_name = finish_name.rsplit(".", 1)[-1]
|
||||
|
||||
text = response.text or ""
|
||||
tool_calls = _extract_tool_calls(candidate)
|
||||
|
||||
if finish_name in {"SAFETY", "RECITATION", "PROHIBITED_CONTENT", "SPII", "BLOCKLIST"}:
|
||||
# Pull the first safety rating that actually blocked if present.
|
||||
|
|
@ -80,13 +138,37 @@ def _unwrap_gemini_response(response: Any, gemini_model: str) -> str:
|
|||
)
|
||||
raise ProviderTruncatedError(partial_text=text or None)
|
||||
|
||||
if not text and finish_name not in (None, "STOP"):
|
||||
# Unknown finish reason, empty text — surface instead of silent "".
|
||||
if not text and not tool_calls and finish_name not in (None, "STOP"):
|
||||
# Unknown finish reason, nothing to return — surface instead of "".
|
||||
raise ProviderError(
|
||||
f"Gemini returned no content (finish_reason={finish_name})"
|
||||
)
|
||||
|
||||
return text
|
||||
# Normalise the finish_reason to our OpenAI-compatible vocabulary.
|
||||
openai_finish = "stop"
|
||||
if tool_calls:
|
||||
openai_finish = "tool_calls"
|
||||
elif finish_name == "MAX_TOKENS":
|
||||
openai_finish = "length"
|
||||
return text, tool_calls, openai_finish
|
||||
|
||||
|
||||
def _lookup_tool_name(messages: list[Any], tool_call_id: str | None) -> str | None:
|
||||
"""Find the tool name for a given ``tool_call_id``.
|
||||
|
||||
OpenAI's tool-message carries only the id; the matching name lives
|
||||
on the preceding assistant message's ``tool_calls[]``. We scan from
|
||||
the end (most recent assistant turn) backwards.
|
||||
"""
|
||||
if not tool_call_id:
|
||||
return None
|
||||
for m in reversed(messages):
|
||||
if m.role != "assistant" or not m.tool_calls:
|
||||
continue
|
||||
for call in m.tool_calls:
|
||||
if call.id == tool_call_id:
|
||||
return call.function.name
|
||||
return None
|
||||
|
||||
|
||||
def _wrap_gemini_call_error(err: Exception, gemini_model: str) -> ProviderError:
|
||||
|
|
@ -123,6 +205,8 @@ class GoogleProvider(LLMProvider):
|
|||
"""Google Gemini API provider."""
|
||||
|
||||
name = "google"
|
||||
# Gemini 2.x supports OpenAI-style function calling across all listed models.
|
||||
supports_tools = True
|
||||
|
||||
def __init__(self, api_key: str, default_model: str = "gemini-2.5-flash"):
|
||||
self.api_key = api_key
|
||||
|
|
@ -138,43 +222,98 @@ class GoogleProvider(LLMProvider):
|
|||
) -> tuple[str | None, list[types.Content]]:
|
||||
"""Convert OpenAI-format messages to Google Gemini format.
|
||||
|
||||
Returns (system_instruction, contents).
|
||||
Returns (system_instruction, contents). Handles text, multimodal
|
||||
image content, assistant messages carrying ``tool_calls``, and
|
||||
``tool`` result messages (mapped to Gemini ``function_response``
|
||||
parts — Gemini has no ``tool`` role, function responses ride on
|
||||
a ``user`` turn).
|
||||
"""
|
||||
system_instruction: str | None = None
|
||||
contents: list[types.Content] = []
|
||||
|
||||
for msg in request.messages:
|
||||
if msg.role == "system":
|
||||
# Gemini uses system_instruction separately
|
||||
if isinstance(msg.content, str):
|
||||
system_instruction = msg.content
|
||||
continue
|
||||
|
||||
role = "user" if msg.role == "user" else "model"
|
||||
|
||||
if isinstance(msg.content, str):
|
||||
contents.append(types.Content(role=role, parts=[types.Part.from_text(text=msg.content)]))
|
||||
else:
|
||||
# Multimodal content
|
||||
parts: list[types.Part] = []
|
||||
for part in msg.content:
|
||||
if part.type == "text":
|
||||
parts.append(types.Part.from_text(text=part.text))
|
||||
elif part.type == "image_url" and part.image_url:
|
||||
url = part.image_url.url
|
||||
if url.startswith("data:"):
|
||||
# Parse data URI: data:image/jpeg;base64,<data>
|
||||
header, b64_data = url.split(",", 1)
|
||||
mime_type = header.split(":")[1].split(";")[0]
|
||||
import base64
|
||||
|
||||
image_bytes = base64.b64decode(b64_data)
|
||||
parts.append(
|
||||
types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
|
||||
# Tool result message → function_response Part on a user turn.
|
||||
if msg.role == "tool":
|
||||
# The content is the stringified tool result. We also need
|
||||
# a tool name — the OpenAI spec carries it on the matching
|
||||
# assistant tool_call, keyed by tool_call_id. We don't
|
||||
# track that back-reference here, so we pull the name
|
||||
# from the preceding assistant message's tool_calls.
|
||||
name = _lookup_tool_name(request.messages, msg.tool_call_id)
|
||||
if not name:
|
||||
continue # orphan tool message — skip silently
|
||||
payload: Any
|
||||
if isinstance(msg.content, str):
|
||||
try:
|
||||
payload = json.loads(msg.content)
|
||||
if not isinstance(payload, dict):
|
||||
payload = {"result": payload}
|
||||
except (TypeError, ValueError):
|
||||
payload = {"result": msg.content}
|
||||
else:
|
||||
payload = {"result": ""}
|
||||
contents.append(
|
||||
types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part.from_function_response(
|
||||
name=name, response=payload
|
||||
)
|
||||
else:
|
||||
# URL-based image - use as URI
|
||||
parts.append(types.Part.from_uri(file_uri=url, mime_type="image/jpeg"))
|
||||
],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
role = "user" if msg.role == "user" else "model"
|
||||
parts: list[types.Part] = []
|
||||
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
for call in msg.tool_calls:
|
||||
try:
|
||||
args_obj = json.loads(call.function.arguments or "{}")
|
||||
except (TypeError, ValueError):
|
||||
args_obj = {}
|
||||
parts.append(
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
name=call.function.name, args=args_obj
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if msg.content is not None:
|
||||
if isinstance(msg.content, str):
|
||||
parts.append(types.Part.from_text(text=msg.content))
|
||||
else:
|
||||
for part in msg.content:
|
||||
if part.type == "text":
|
||||
parts.append(types.Part.from_text(text=part.text))
|
||||
elif part.type == "image_url" and part.image_url:
|
||||
url = part.image_url.url
|
||||
if url.startswith("data:"):
|
||||
header, b64_data = url.split(",", 1)
|
||||
mime_type = header.split(":")[1].split(";")[0]
|
||||
import base64
|
||||
|
||||
image_bytes = base64.b64decode(b64_data)
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=image_bytes, mime_type=mime_type
|
||||
)
|
||||
)
|
||||
else:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=url, mime_type="image/jpeg"
|
||||
)
|
||||
)
|
||||
|
||||
if parts:
|
||||
contents.append(types.Content(role=role, parts=parts))
|
||||
|
||||
return system_instruction, contents
|
||||
|
|
@ -199,6 +338,10 @@ class GoogleProvider(LLMProvider):
|
|||
stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
config["stop_sequences"] = stop_seqs
|
||||
|
||||
gemini_tools = _build_gemini_tools(request.tools)
|
||||
if gemini_tools:
|
||||
config["tools"] = gemini_tools
|
||||
|
||||
gen_config = types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
**config,
|
||||
|
|
@ -217,7 +360,9 @@ class GoogleProvider(LLMProvider):
|
|||
except Exception as err:
|
||||
raise _wrap_gemini_call_error(err, gemini_model) from err
|
||||
|
||||
content = _unwrap_gemini_response(response, gemini_model)
|
||||
content, tool_calls, finish_reason = _unwrap_gemini_response(
|
||||
response, gemini_model
|
||||
)
|
||||
usage_meta = response.usage_metadata
|
||||
|
||||
return ChatCompletionResponse(
|
||||
|
|
@ -225,8 +370,11 @@ class GoogleProvider(LLMProvider):
|
|||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
message=MessageResponse(content=content),
|
||||
finish_reason="stop",
|
||||
message=MessageResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
|
|
@ -253,6 +401,10 @@ class GoogleProvider(LLMProvider):
|
|||
if request.top_p is not None:
|
||||
config["top_p"] = request.top_p
|
||||
|
||||
gemini_tools = _build_gemini_tools(request.tools)
|
||||
if gemini_tools:
|
||||
config["tools"] = gemini_tools
|
||||
|
||||
gen_config = types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
**config,
|
||||
|
|
@ -297,7 +449,10 @@ class GoogleProvider(LLMProvider):
|
|||
|
||||
# Post-stream check: if the stream ended without emitting any text,
|
||||
# surface the structured reason instead of quietly closing with an
|
||||
# empty "stop". Matches _unwrap_gemini_response semantics.
|
||||
# empty "stop". Matches _unwrap_gemini_response semantics. We
|
||||
# discard the return value here — the streaming path produces its
|
||||
# content chunk-by-chunk above, so we only need this call for its
|
||||
# side-effect of raising on SAFETY / RECITATION / MAX_TOKENS.
|
||||
if not emitted_any_text and last_chunk is not None:
|
||||
_unwrap_gemini_response(last_chunk, gemini_model)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,14 +20,76 @@ from src.models import (
|
|||
MessageResponse,
|
||||
ModelInfo,
|
||||
StreamChoice,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
Usage,
|
||||
)
|
||||
|
||||
from .base import LLMProvider
|
||||
|
||||
# Ollama emits tool_calls under /api/chat as:
|
||||
# {"message":{"content":"", "tool_calls":[{"function":{"name":"x","arguments":{...}}}]}}
|
||||
# Arguments arrive as a dict — we normalise to a JSON string to match
|
||||
# the OpenAI-spec shape our downstream code expects.
|
||||
|
||||
# Ollama models known to support tool-calling reliably (as of 0.3+).
# Ollama itself accepts a `tools` payload for any model and ignores it on
# incompatible ones, in which case the assistant answers in plain text.
# OllamaProvider.model_supports_tools() therefore checks the model name
# against this tuple (case-insensitive substring match) so that tool-bearing
# requests for other models can be rejected before the call is made.
TOOL_CAPABLE_OLLAMA_PATTERNS: tuple[str, ...] = (
    "llama3.1",
    "llama3.2",
    "llama3.3",
    "qwen2.5",
    "qwen3",
    "mistral",
    "mixtral",
    "command-r",
    "firefunction",
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_parse_args(raw: str | None) -> dict[str, Any]:
|
||||
"""Best-effort parse of a JSON-encoded arguments string."""
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except (TypeError, ValueError):
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def _tool_calls_from_ollama(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
    """Normalise Ollama's tool_calls into our ToolCall model.

    Args:
        raw: The ``tool_calls`` list from an Ollama ``/api/chat`` message,
            or ``None``.

    Returns:
        One ``ToolCall`` per entry that names a function, in input order,
        or ``None`` when there is nothing to return. ``arguments`` is always
        a JSON string: dict arguments are serialised, non-empty strings pass
        through unchanged, and anything else (missing, empty string, other
        types) becomes ``"{}"``. Entries without an ``id`` get a generated
        ``call_<12 hex chars>`` id.
    """
    if not raw:
        return None
    import uuid  # local: only needed to mint ids Ollama did not supply

    calls: list[ToolCall] = []
    for entry in raw:
        fn = entry.get("function") or {}
        name = fn.get("name")
        if not name:
            # A call without a function name cannot be dispatched; drop it.
            continue
        args = fn.get("arguments")
        if isinstance(args, dict):
            arguments_json = json.dumps(args, ensure_ascii=False)
        elif isinstance(args, str) and args:
            arguments_json = args
        else:
            # Includes the empty string, which is not valid JSON.
            arguments_json = "{}"
        calls.append(
            ToolCall(
                id=entry.get("id") or f"call_{uuid.uuid4().hex[:12]}",
                function=ToolCallFunction(name=name, arguments=arguments_json),
            )
        )
    return calls or None
|
||||
|
||||
|
||||
def _strip_json_fences(content: str) -> str:
|
||||
"""Strip ```json ... ``` markdown fences from a string if present.
|
||||
|
||||
|
|
@ -53,6 +115,17 @@ class OllamaProvider(LLMProvider):
|
|||
"""Ollama LLM provider."""
|
||||
|
||||
name = "ollama"
|
||||
supports_tools = True
|
||||
|
||||
def model_supports_tools(self, model: str) -> bool:
    """Report whether ``model`` is on the tool-capable whitelist.

    Ollama takes a ``tools`` payload for any model but quietly ignores it
    when the model was not trained for function calling, so the reply comes
    back as prose. Matching the name up front lets such requests be refused
    with a clear error instead.

    The match is a case-insensitive substring test against every entry of
    ``TOOL_CAPABLE_OLLAMA_PATTERNS``.
    """
    model_name = model.lower()
    for pattern in TOOL_CAPABLE_OLLAMA_PATTERNS:
        if pattern in model_name:
            return True
    return False
|
||||
|
||||
def __init__(self, base_url: str | None = None, timeout: int | None = None):
|
||||
self.base_url = (base_url or settings.ollama_url).rstrip("/")
|
||||
|
|
@ -70,37 +143,56 @@ class OllamaProvider(LLMProvider):
|
|||
return self._client
|
||||
|
||||
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI message format to Ollama format."""
|
||||
messages = []
|
||||
"""Convert OpenAI message format to Ollama format.
|
||||
|
||||
Ollama's ``/api/chat`` uses the OpenAI shape for assistant
|
||||
``tool_calls`` and ``tool`` result messages (with args as objects
|
||||
rather than JSON strings on output — input still accepts JSON
|
||||
strings), so we largely pass them through.
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
messages.append({"role": msg.role, "content": msg.content})
|
||||
out: dict[str, Any] = {"role": msg.role}
|
||||
|
||||
if msg.tool_calls:
|
||||
out["tool_calls"] = [
|
||||
{
|
||||
"function": {
|
||||
"name": c.function.name,
|
||||
# Ollama accepts either stringified JSON or
|
||||
# already-parsed objects; send parsed for
|
||||
# better compatibility with older models.
|
||||
"arguments": _safe_parse_args(c.function.arguments),
|
||||
}
|
||||
}
|
||||
for c in msg.tool_calls
|
||||
]
|
||||
|
||||
if msg.content is None:
|
||||
out["content"] = ""
|
||||
elif isinstance(msg.content, str):
|
||||
out["content"] = msg.content
|
||||
else:
|
||||
# Handle multimodal content (vision)
|
||||
text_parts = []
|
||||
images = []
|
||||
text_parts: list[str] = []
|
||||
images: list[str] = []
|
||||
for part in msg.content:
|
||||
if part.type == "text":
|
||||
text_parts.append(part.text)
|
||||
elif part.type == "image_url":
|
||||
url = part.image_url.url
|
||||
# Extract base64 data from data URL
|
||||
if url.startswith("data:"):
|
||||
# Format: data:image/png;base64,<base64_data>
|
||||
base64_data = url.split(",", 1)[1] if "," in url else url
|
||||
images.append(base64_data)
|
||||
else:
|
||||
# HTTP URL - Ollama expects base64, so we'd need to fetch
|
||||
# For now, log warning and skip
|
||||
logger.warning(f"HTTP image URLs not supported, skipping: {url[:50]}...")
|
||||
|
||||
message_data: dict[str, Any] = {
|
||||
"role": msg.role,
|
||||
"content": " ".join(text_parts),
|
||||
}
|
||||
logger.warning(
|
||||
"HTTP image URLs not supported, skipping: %s...",
|
||||
url[:50],
|
||||
)
|
||||
out["content"] = " ".join(text_parts)
|
||||
if images:
|
||||
message_data["images"] = images
|
||||
messages.append(message_data)
|
||||
out["images"] = images
|
||||
|
||||
messages.append(out)
|
||||
return messages
|
||||
|
||||
async def chat_completion(
|
||||
|
|
@ -152,23 +244,39 @@ class OllamaProvider(LLMProvider):
|
|||
if options:
|
||||
payload["options"] = options
|
||||
|
||||
if request.tools:
|
||||
payload["tools"] = [
|
||||
{"type": "function", "function": t.function.model_dump()}
|
||||
for t in request.tools
|
||||
]
|
||||
|
||||
logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}")
|
||||
|
||||
response = await self.client.post("/api/chat", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
message = data.get("message", {})
|
||||
tool_calls = _tool_calls_from_ollama(message.get("tool_calls"))
|
||||
# Defensive fence-stripping: even with `format` set, some older
|
||||
# Ollama versions still emit ```json ... ``` wrappers for vision
|
||||
# models. Strip them so strict downstream parsers see clean JSON.
|
||||
content = _strip_json_fences(data["message"]["content"])
|
||||
raw_content = message.get("content", "") or ""
|
||||
content = _strip_json_fences(raw_content) if raw_content else ""
|
||||
|
||||
finish_reason = "tool_calls" if tool_calls else (
|
||||
"stop" if data.get("done") else None
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
model=f"ollama/{model}",
|
||||
choices=[
|
||||
Choice(
|
||||
message=MessageResponse(content=content),
|
||||
finish_reason="stop" if data.get("done") else None,
|
||||
message=MessageResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
|
|
@ -204,6 +312,12 @@ class OllamaProvider(LLMProvider):
|
|||
if options:
|
||||
payload["options"] = options
|
||||
|
||||
if request.tools:
|
||||
payload["tools"] = [
|
||||
{"type": "function", "function": t.function.model_dump()}
|
||||
for t in request.tools
|
||||
]
|
||||
|
||||
logger.debug(f"Ollama streaming request: {model}")
|
||||
|
||||
response_id = f"chatcmpl-{model[:8]}"
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from src.models import (
|
|||
MessageResponse,
|
||||
ModelInfo,
|
||||
StreamChoice,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
|
@ -27,9 +29,37 @@ from .base import LLMProvider
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _tool_calls_from_openai(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
    """Normalise an OpenAI-spec ``tool_calls`` array into our model shape.

    Args:
        raw: The ``tool_calls`` list from an upstream message, or ``None``.

    Returns:
        One ``ToolCall`` per entry that names a function, in input order,
        or ``None`` when there is nothing to return.

    Upstream services are not uniformly spec-compliant, so two fields are
    repaired rather than passed through blindly:

    * A missing or empty ``id`` is replaced by a generated
      ``call_<12 hex chars>`` id. An empty id cannot be matched by a later
      ``tool`` message and is identical across every call in the array.
    * ``arguments`` that arrive as a parsed object instead of a string are
      JSON-encoded, because ``ToolCallFunction.arguments`` is declared as
      ``str``. Missing or empty arguments become ``"{}"``.
    """
    if not raw:
        return None
    # Local imports keep this helper independent of the module's import block.
    import json
    import uuid

    calls: list[ToolCall] = []
    for entry in raw:
        fn = entry.get("function") or {}
        name = fn.get("name")
        if not name:
            # A call without a function name cannot be dispatched; drop it.
            continue
        args = fn.get("arguments")
        if args is None:
            arguments_json = "{}"
        elif isinstance(args, str):
            arguments_json = args or "{}"
        else:
            arguments_json = json.dumps(args, ensure_ascii=False)
        calls.append(
            ToolCall(
                id=entry.get("id") or f"call_{uuid.uuid4().hex[:12]}",
                function=ToolCallFunction(name=name, arguments=arguments_json),
            )
        )
    return calls or None
|
||||
|
||||
|
||||
class OpenAICompatProvider(LLMProvider):
|
||||
"""OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.)."""
|
||||
|
||||
# OpenRouter/Groq/Together all expose tool_calls per the OpenAI spec;
|
||||
# individual models within those services may or may not support it,
|
||||
# but the request shape is uniform. The upstream returns a proper
|
||||
# error for unsupported models — no silent downgrade here.
|
||||
supports_tools = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
|
|
@ -60,23 +90,53 @@ class OpenAICompatProvider(LLMProvider):
|
|||
return self._client
|
||||
|
||||
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
|
||||
"""Convert internal message format to OpenAI format."""
|
||||
messages = []
|
||||
"""Convert internal message format to OpenAI format.
|
||||
|
||||
The OpenAI chat-completions endpoint is the source of truth for
|
||||
this shape, so most fields pass through verbatim — including
|
||||
``role='tool'`` messages with ``tool_call_id`` + ``content`` and
|
||||
assistant messages carrying ``tool_calls[]``.
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
messages.append({"role": msg.role, "content": msg.content})
|
||||
out: dict[str, Any] = {"role": msg.role}
|
||||
|
||||
if msg.tool_call_id:
|
||||
out["tool_call_id"] = msg.tool_call_id
|
||||
|
||||
if msg.tool_calls:
|
||||
out["tool_calls"] = [
|
||||
{
|
||||
"id": c.id,
|
||||
"type": c.type,
|
||||
"function": {
|
||||
"name": c.function.name,
|
||||
"arguments": c.function.arguments,
|
||||
},
|
||||
}
|
||||
for c in msg.tool_calls
|
||||
]
|
||||
|
||||
if msg.content is None:
|
||||
# Assistant tool-call messages have null content per spec.
|
||||
out["content"] = None
|
||||
elif isinstance(msg.content, str):
|
||||
out["content"] = msg.content
|
||||
else:
|
||||
# Handle multimodal content
|
||||
content_parts = []
|
||||
for part in msg.content:
|
||||
if part.type == "text":
|
||||
content_parts.append({"type": "text", "text": part.text})
|
||||
elif part.type == "image_url":
|
||||
content_parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": part.image_url.url},
|
||||
})
|
||||
messages.append({"role": msg.role, "content": content_parts})
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": part.image_url.url},
|
||||
}
|
||||
)
|
||||
out["content"] = content_parts
|
||||
|
||||
messages.append(out)
|
||||
return messages
|
||||
|
||||
async def chat_completion(
|
||||
|
|
@ -104,6 +164,17 @@ class OpenAICompatProvider(LLMProvider):
|
|||
payload["presence_penalty"] = request.presence_penalty
|
||||
if request.stop:
|
||||
payload["stop"] = request.stop
|
||||
if request.tools:
|
||||
payload["tools"] = [
|
||||
{"type": "function", "function": t.function.model_dump()}
|
||||
for t in request.tools
|
||||
]
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = (
|
||||
request.tool_choice
|
||||
if isinstance(request.tool_choice, str)
|
||||
else request.tool_choice.model_dump()
|
||||
)
|
||||
|
||||
logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}")
|
||||
|
||||
|
|
@ -117,7 +188,12 @@ class OpenAICompatProvider(LLMProvider):
|
|||
choices=[
|
||||
Choice(
|
||||
index=choice.get("index", 0),
|
||||
message=MessageResponse(content=choice["message"]["content"]),
|
||||
message=MessageResponse(
|
||||
content=choice["message"].get("content"),
|
||||
tool_calls=_tool_calls_from_openai(
|
||||
choice["message"].get("tool_calls")
|
||||
),
|
||||
),
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
)
|
||||
for choice in data.get("choices", [])
|
||||
|
|
@ -154,6 +230,17 @@ class OpenAICompatProvider(LLMProvider):
|
|||
payload["presence_penalty"] = request.presence_penalty
|
||||
if request.stop:
|
||||
payload["stop"] = request.stop
|
||||
if request.tools:
|
||||
payload["tools"] = [
|
||||
{"type": "function", "function": t.function.model_dump()}
|
||||
for t in request.tools
|
||||
]
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = (
|
||||
request.tool_choice
|
||||
if isinstance(request.tool_choice, str)
|
||||
else request.tool_choice.model_dump()
|
||||
)
|
||||
|
||||
logger.debug(f"{self.name} streaming request: {model}")
|
||||
|
||||
|
|
@ -190,6 +277,9 @@ class OpenAICompatProvider(LLMProvider):
|
|||
delta=DeltaContent(
|
||||
role=delta.get("role"),
|
||||
content=delta.get("content"),
|
||||
tool_calls=_tool_calls_from_openai(
|
||||
delta.get("tool_calls")
|
||||
),
|
||||
),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from src.models import (
|
|||
)
|
||||
|
||||
from .base import LLMProvider
|
||||
from .errors import ProviderCapabilityError
|
||||
from .ollama import OllamaProvider
|
||||
from .openai_compat import OpenAICompatProvider
|
||||
|
||||
|
|
@ -154,6 +155,23 @@ class ProviderRouter:
|
|||
)
|
||||
return await google.chat_completion(request, gemini_model)
|
||||
|
||||
def _check_tool_capability(
    self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest
) -> None:
    """Refuse tool-bearing requests for providers/models without tool support.

    Silent downgrade (dropping the `tools` payload) is more dangerous
    than an explicit error — the caller would get plain text back and
    have no way to tell the tools never reached the model.

    Args:
        provider: Provider the request is about to be routed to; asked via
            ``model_supports_tools`` whether the model can call tools.
        model_name: Model name passed to the provider's capability check.
        request: Incoming chat request; only ``request.tools`` is inspected.

    Raises:
        ProviderCapabilityError: If the request carries tools and the
            provider reports that ``model_name`` does not support them.
    """
    # Requests without tools never need the capability check.
    if not request.tools:
        return
    if not provider.model_supports_tools(model_name):
        raise ProviderCapabilityError(
            f"{provider.name}/{model_name} does not support tool calling. "
            "Choose a tool-capable model (e.g. gemini-2.5-flash, llama3.1:*)"
        )
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
|
|
@ -164,6 +182,7 @@ class ProviderRouter:
|
|||
# Non-Ollama providers: direct routing, no fallback
|
||||
if provider_name != "ollama":
|
||||
provider = self._get_provider(provider_name)
|
||||
self._check_tool_capability(provider, model_name, request)
|
||||
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
|
||||
return await provider.chat_completion(request, model_name)
|
||||
|
||||
|
|
@ -176,6 +195,7 @@ class ProviderRouter:
|
|||
|
||||
# Try Ollama first
|
||||
provider = self._get_provider("ollama")
|
||||
self._check_tool_capability(provider, model_name, request)
|
||||
logger.info(f"Routing chat completion to ollama/{model_name}")
|
||||
self._ollama_concurrent += 1
|
||||
|
||||
|
|
@ -199,6 +219,7 @@ class ProviderRouter:
|
|||
# Non-Ollama: direct
|
||||
if provider_name != "ollama":
|
||||
provider = self._get_provider(provider_name)
|
||||
self._check_tool_capability(provider, model_name, request)
|
||||
logger.info(f"Routing streaming to {provider_name}/{model_name}")
|
||||
async for chunk in provider.chat_completion_stream(request, model_name):
|
||||
yield chunk
|
||||
|
|
@ -219,6 +240,7 @@ class ProviderRouter:
|
|||
return
|
||||
|
||||
provider = self._get_provider("ollama")
|
||||
self._check_tool_capability(provider, model_name, request)
|
||||
logger.info(f"Routing streaming to ollama/{model_name}")
|
||||
self._ollama_concurrent += 1
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue