kiln_ai.adapters.model_adapters.litellm_adapter

  1import asyncio
  2import copy
  3import json
  4import logging
  5from dataclasses import dataclass
  6from typing import Any, Dict, List, Tuple
  7
  8import litellm
  9from litellm.types.utils import (
 10    ChatCompletionMessageToolCall,
 11    ChoiceLogprobs,
 12    Choices,
 13    ModelResponse,
 14)
 15from litellm.types.utils import Message as LiteLLMMessage
 16from litellm.types.utils import Usage as LiteLlmUsage
 17from openai.types.chat.chat_completion_message_tool_call_param import (
 18    ChatCompletionMessageToolCallParam,
 19)
 20
 21import kiln_ai.datamodel as datamodel
 22from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM
 23from kiln_ai.adapters.chat.chat_formatter import ToolResponseMessage
 24from kiln_ai.adapters.ml_model_list import (
 25    KilnModelProvider,
 26    ModelProviderName,
 27    StructuredOutputMode,
 28)
 29from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream
 30from kiln_ai.adapters.model_adapters.base_adapter import (
 31    AdapterConfig,
 32    BaseAdapter,
 33    RunOutput,
 34    Usage,
 35)
 36from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 37from kiln_ai.datamodel.datamodel_enums import InputType
 38from kiln_ai.datamodel.json_schema import (
 39    close_object_schemas,
 40    validate_schema_with_value_error,
 41)
 42from kiln_ai.datamodel.run_config import (
 43    KilnAgentRunConfigProperties,
 44    as_kiln_agent_run_config,
 45)
 46from kiln_ai.tools.base_tool import (
 47    KilnToolInterface,
 48    ToolCallContext,
 49    ToolCallDefinition,
 50)
 51from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 52from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 53from kiln_ai.utils.litellm import get_litellm_provider_info
 54from kiln_ai.utils.open_ai_types import (
 55    ChatCompletionAssistantMessageParamWrapper,
 56    ChatCompletionMessageParam,
 57    ChatCompletionToolMessageParamWrapper,
 58)
 59
 60MAX_CALLS_PER_TURN = 10
 61MAX_TOOL_CALLS_PER_TURN = 30
 62
 63logger = logging.getLogger(__name__)
 64
 65
 66def _validate_unmanaged_tools(tools: list[KilnToolInterface]) -> None:
 67    for i, tool in enumerate(tools):
 68        if not isinstance(tool, KilnToolInterface):
 69            raise TypeError(
 70                f"unmanaged_tools[{i}] must be a KilnToolInterface instance, got {type(tool).__name__}"
 71            )
 72
 73
 74@dataclass
 75class ModelTurnResult:
 76    assistant_message: str
 77    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
 78    model_response: ModelResponse | None
 79    model_choice: Choices | None
 80    usage: Usage
 81    interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None
 82
 83
 84class LiteLlmAdapter(BaseAdapter):
 85    def __init__(
 86        self,
 87        config: LiteLlmConfig,
 88        kiln_task: datamodel.Task,
 89        base_adapter_config: AdapterConfig | None = None,
 90    ):
 91        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
 92            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
 93        self.config = config
 94        self._additional_body_options = config.additional_body_options
 95        self._api_base = config.base_url
 96        self._headers = config.default_headers
 97        self._litellm_model_id: str | None = None
 98        self._cached_available_tools: list[KilnToolInterface] | None = None
 99
100        super().__init__(
101            task=kiln_task,
102            run_config=config.run_config_properties,
103            config=base_adapter_config,
104        )
105
106        unmanaged_tools = self.base_adapter_config.unmanaged_tools
107        if unmanaged_tools:
108            _validate_unmanaged_tools(unmanaged_tools)
109
110    async def _run_model_turn(
111        self,
112        provider: KilnModelProvider,
113        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
114        top_logprobs: int | None,
115        skip_response_format: bool,
116    ) -> ModelTurnResult:
117        """
118        Call the model for a single top level turn: from user message to agent message.
119
120        It may handle multiple iterations of tool calls between the user and agent messages if needed.
121        """
122
123        usage = Usage()
124        messages = list(prior_messages)
125        tool_calls_count = 0
126
127        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
128            # Build kwargs for the completion call (including any tools)
129            completion_kwargs = await self.build_completion_kwargs(
130                provider,
131                # Pass a copy, as acompletion mutates objects and breaks types.
132                copy.deepcopy(messages),
133                top_logprobs,
134                skip_response_format,
135            )
136
137            # Make the completion call
138            model_response, response_choice = await self.acompletion_checking_response(
139                **completion_kwargs
140            )
141
142            # count the usage
143            usage += self.usage_from_response(model_response)
144
145            # Extract content and tool calls
146            if not hasattr(response_choice, "message"):
147                raise ValueError("Response choice has no message")
148            content = response_choice.message.content
149            tool_calls = response_choice.message.tool_calls
150            if not content and not tool_calls:
151                raise ValueError(
152                    "Model returned an assistant message, but no content or tool calls. This is not supported."
153                )
154
155            # Add message to messages, so it can be used in the next turn
156            messages.append(response_choice.message)
157
158            # Process tool calls if any
159            if tool_calls and len(tool_calls) > 0:
160                # check if we should return control to caller
161                if self.base_adapter_config.return_on_tool_call:
162                    # filter out task_response tool (task_response tools are internal)
163                    standard_tool_calls = [
164                        tc for tc in tool_calls if tc.function.name != "task_response"
165                    ]
166                    has_task_response = any(
167                        tc.function.name == "task_response" for tc in tool_calls
168                    )
169                    if standard_tool_calls and not has_task_response:
170                        return ModelTurnResult(
171                            # No content yet: we're waiting for tool call output to come back from the client
172                            assistant_message="",
173                            all_messages=messages,
174                            model_response=model_response,
175                            model_choice=response_choice,
176                            usage=usage,
177                            interrupted_by_tool_calls=standard_tool_calls,
178                        )
179
180                # otherwise: process tool calls internally until final output
181                (
182                    assistant_message_from_toolcall,
183                    tool_call_messages,
184                ) = await self.process_tool_calls(tool_calls)
185
186                # Add tool call results to messages
187                messages.extend(tool_call_messages)
188
189                # If task_response tool was called, we're done
190                if assistant_message_from_toolcall is not None:
191                    return ModelTurnResult(
192                        assistant_message=assistant_message_from_toolcall,
193                        all_messages=messages,
194                        model_response=model_response,
195                        model_choice=response_choice,
196                        usage=usage,
197                    )
198
199                # If there were tool calls, increment counter and continue
200                if tool_call_messages:
201                    tool_calls_count += 1
202                    continue
203
204            # If no tool calls, return the content as final output
205            if content:
206                return ModelTurnResult(
207                    assistant_message=content,
208                    all_messages=messages,
209                    model_response=model_response,
210                    model_choice=response_choice,
211                    usage=usage,
212                )
213
214            # If we get here with no content and no tool calls, raise
215            raise RuntimeError(
216                "Model returned neither content nor tool calls. It must return at least one of these."
217            )
218
219        raise RuntimeError(
220            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
221        )
222
223    async def _run(
224        self,
225        input: InputType,
226        prior_trace: list[ChatCompletionMessageParam] | None = None,
227    ) -> tuple[RunOutput, Usage | None]:
228        usage = Usage()
229
230        provider = self.model_provider()
231        if not provider.model_id:
232            raise ValueError("Model ID is required for OpenAI compatible models")
233
234        # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter
235        chat_formatter = self.build_chat_formatter(input, prior_trace)
236        messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
237            chat_formatter.initial_messages()
238        )
239
240        prior_output: str | None = None
241        final_choice: Choices | None = None
242        turns = 0
243
244        # Same loop for both fresh runs and prior_trace continuation.
245        # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues).
246        while True:
247            turns += 1
248            if turns > MAX_CALLS_PER_TURN:
249                raise RuntimeError(
250                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
251                )
252
253            turn = chat_formatter.next_turn(prior_output)
254            if turn is None:
255                # No next turn, we're done
256                break
257
258            # Add messages from the turn to chat history
259            for message in turn.messages:
260                if message.content is None:
261                    raise ValueError("Empty message content isn't allowed")
262                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
263                msg_dict: dict = {"role": message.role, "content": message.content}
264                if isinstance(message, ToolResponseMessage):
265                    msg_dict["tool_call_id"] = message.tool_call_id
266                messages.append(msg_dict)  # type: ignore
267
268            skip_response_format = not turn.final_call
269            turn_result = await self._run_model_turn(
270                provider,
271                messages,
272                self.base_adapter_config.top_logprobs if turn.final_call else None,
273                skip_response_format,
274            )
275
276            usage += turn_result.usage
277
278            prior_output = turn_result.assistant_message
279            messages = turn_result.all_messages
280            final_choice = turn_result.model_choice
281
282            # Check if we were interrupted by tool calls
283            if turn_result.interrupted_by_tool_calls:
284                trace = self.all_messages_to_trace(messages)
285                intermediate_outputs = chat_formatter.intermediate_outputs()
286                output = RunOutput(
287                    output=prior_output or "",
288                    intermediate_outputs=intermediate_outputs,
289                    output_logprobs=None,
290                    trace=trace,
291                )
292                return output, usage
293
294            if not prior_output:
295                raise RuntimeError("No assistant message/output returned from model")
296
297        logprobs = self._extract_and_validate_logprobs(final_choice)
298
299        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
300        intermediate_outputs = chat_formatter.intermediate_outputs()
301        self._extract_reasoning_to_intermediate_outputs(
302            final_choice, intermediate_outputs
303        )
304
305        if not isinstance(prior_output, str):
306            raise RuntimeError(f"assistant message is not a string: {prior_output}")
307
308        trace = self.all_messages_to_trace(messages)
309        output = RunOutput(
310            output=prior_output,
311            intermediate_outputs=intermediate_outputs,
312            output_logprobs=logprobs,
313            trace=trace,
314        )
315
316        return output, usage
317
318    def _create_run_stream(
319        self,
320        input: InputType,
321        prior_trace: list[ChatCompletionMessageParam] | None = None,
322    ) -> AdapterStream:
323        provider = self.model_provider()
324        if not provider.model_id:
325            raise ValueError("Model ID is required for OpenAI compatible models")
326
327        chat_formatter = self.build_chat_formatter(input, prior_trace)
328        initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
329            chat_formatter.initial_messages()
330        )
331
332        return AdapterStream(
333            adapter=self,
334            provider=provider,
335            chat_formatter=chat_formatter,
336            initial_messages=initial_messages,
337            top_logprobs=self.base_adapter_config.top_logprobs,
338        )
339
340    def _extract_and_validate_logprobs(
341        self, final_choice: Choices | None
342    ) -> ChoiceLogprobs | None:
343        """
344        Extract logprobs from the final choice and validate they exist if required.
345        """
346        logprobs = None
347        if (
348            final_choice is not None
349            and hasattr(final_choice, "logprobs")
350            and isinstance(final_choice.logprobs, ChoiceLogprobs)
351        ):
352            logprobs = final_choice.logprobs
353
354        # Check logprobs worked, if required
355        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
356            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
357
358        return logprobs
359
360    def _extract_reasoning_to_intermediate_outputs(
361        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
362    ) -> None:
363        """Extract reasoning content from model choice and add to intermediate outputs if present."""
364        if (
365            final_choice is not None
366            and hasattr(final_choice, "message")
367            and hasattr(final_choice.message, "reasoning_content")
368        ):
369            reasoning_content = final_choice.message.reasoning_content
370            if reasoning_content is not None:
371                stripped_reasoning_content = reasoning_content.strip()
372                if len(stripped_reasoning_content) > 0:
373                    intermediate_outputs["reasoning"] = stripped_reasoning_content
374
375    async def acompletion_checking_response(
376        self, **kwargs: Any
377    ) -> Tuple[ModelResponse, Choices]:
378        response = await litellm.acompletion(**kwargs)
379
380        if (
381            not isinstance(response, ModelResponse)
382            or not response.choices
383            or len(response.choices) == 0
384            or not isinstance(response.choices[0], Choices)
385        ):
386            raise RuntimeError(
387                f"Expected ModelResponse with Choices, got {type(response)}."
388            )
389        return response, response.choices[0]
390
391    def adapter_name(self) -> str:
392        return "kiln_openai_compatible_adapter"
393
394    async def response_format_options(self) -> dict[str, Any]:
395        # Unstructured if task isn't structured
396        if not self.has_structured_output():
397            return {}
398
399        run_config = as_kiln_agent_run_config(self.run_config)
400        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode
401
402        match structured_output_mode:
403            case StructuredOutputMode.json_mode:
404                return {"response_format": {"type": "json_object"}}
405            case StructuredOutputMode.json_schema:
406                return self.json_schema_response_format()
407            case StructuredOutputMode.function_calling_weak:
408                return self.tool_call_params(strict=False)
409            case StructuredOutputMode.function_calling:
410                return self.tool_call_params(strict=True)
411            case StructuredOutputMode.json_instructions:
412                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
413                return {}
414            case StructuredOutputMode.json_custom_instructions:
415                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
416                return {}
417            case StructuredOutputMode.json_instruction_and_object:
418                # We set response_format to json_object and also set json instructions in the prompt
419                return {"response_format": {"type": "json_object"}}
420            case StructuredOutputMode.default:
421                provider_name = run_config.model_provider_name
422                if provider_name == ModelProviderName.ollama:
423                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
424                    return self.json_schema_response_format()
425                elif provider_name == ModelProviderName.docker_model_runner:
426                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
427                    return self.json_schema_response_format()
428                else:
429                    # Default to function calling -- it's older than the other modes. Higher compatibility.
430                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
431                    strict = provider_name == ModelProviderName.openai
432                    return self.tool_call_params(strict=strict)
433            case StructuredOutputMode.unknown:
434                # See above, but this case should never happen.
435                raise ValueError("Structured output mode is unknown.")
436            case _:
437                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]
438
439    def json_schema_response_format(self) -> dict[str, Any]:
440        output_schema = self.task.output_schema()
441        if output_schema is None:
442            raise ValueError(
443                "Invalid output schema for this task. Cannot use JSON schema response format."
444            )
445        output_schema = close_object_schemas(output_schema, strict=True)
446        return {
447            "response_format": {
448                "type": "json_schema",
449                "json_schema": {
450                    "name": "task_response",
451                    "schema": output_schema,
452                },
453            }
454        }
455
456    def tool_call_params(self, strict: bool) -> dict[str, Any]:
457        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
458        output_schema = self.task.output_schema()
459        if not isinstance(output_schema, dict):
460            raise ValueError(
461                "Invalid output schema for this task. Can not use tool calls."
462            )
463        output_schema = close_object_schemas(output_schema, strict=strict)
464
465        function_params = {
466            "name": "task_response",
467            "parameters": output_schema,
468        }
469        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
470        if strict:
471            function_params["strict"] = True
472
473        return {
474            "tools": [
475                {
476                    "type": "function",
477                    "function": function_params,
478                }
479            ],
480            "tool_choice": {
481                "type": "function",
482                "function": {"name": "task_response"},
483            },
484        }
485
486    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
487        # Don't love having this logic here. But it's worth the usability improvement
488        # so better to keep it than exclude it. Should figure out how I want to isolate
489        # this sort of logic so it's config driven and can be overridden
490        extra_body: dict[str, Any] = {}
491        provider_options = {}
492
493        run_config = as_kiln_agent_run_config(self.run_config)
494        # For legacy configs 'thinking_level' is not set; default to the provider's default
495        if "thinking_level" in run_config.model_fields_set:
496            thinking_level = run_config.thinking_level
497        else:
498            thinking_level = provider.default_thinking_level
499
500        # Set the reasoning_effort
501        if thinking_level is not None:
502            # Anthropic models on OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
503            if (
504                provider.name == ModelProviderName.openrouter
505                and provider.openrouter_reasoning_object
506            ):
507                extra_body["reasoning"] = {"effort": thinking_level}
508            else:
509                extra_body["reasoning_effort"] = thinking_level
510
511        if provider.require_openrouter_reasoning:
512            # https://openrouter.ai/docs/use-cases/reasoning-tokens
513            extra_body["reasoning"] = {
514                "exclude": False,
515            }
516
517        if provider.gemini_reasoning_enabled:
518            extra_body["reasoning"] = {
519                "enabled": True,
520            }
521
522        if provider.name == ModelProviderName.openrouter:
523            # Ask OpenRouter to include usage in the response (cost)
524            extra_body["usage"] = {"include": True}
525
526            # Set a default provider order for more deterministic routing.
527            # OpenRouter will ignore providers that don't support the model.
528            # Special cases below (like R1) can override this order.
529            # allow_fallbacks is true by default, but we can override it here.
530            provider_options["order"] = [
531                "fireworks",
532                "parasail",
533                "together",
534                "deepinfra",
535                "novita",
536                "groq",
537                "amazon-bedrock",
538                "azure",
539                "nebius",
540            ]
541
542        if provider.anthropic_extended_thinking:
543            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
544
545        if provider.r1_openrouter_options:
546            # Require providers that support the reasoning parameter
547            provider_options["require_parameters"] = True
548            # Prefer R1 providers with reasonable perf/quants
549            provider_options["order"] = ["fireworks", "together"]
550            # R1 providers with unreasonable quants
551            provider_options["ignore"] = ["deepinfra"]
552
553        # Only set these options when the point of this request is to get logprobs.
554        if (
555            provider.logprobs_openrouter_options
556            and self.base_adapter_config.top_logprobs is not None
557        ):
558            # Don't let OpenRouter choose a provider that doesn't support logprobs.
559            provider_options["require_parameters"] = True
560            # DeepInfra silently fails to return logprobs consistently.
561            provider_options["ignore"] = ["deepinfra"]
562
563        if provider.openrouter_skip_required_parameters:
564            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
565            provider_options["require_parameters"] = False
566
567        # Siliconflow uses a bool flag for thinking, for some models
568        if provider.siliconflow_enable_thinking is not None:
569            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
570
571        if len(provider_options) > 0:
572            extra_body["provider"] = provider_options
573
574        return extra_body
575
576    def litellm_model_id(self) -> str:
577        # The model ID is an interesting combination of format and url endpoint.
578        # It specifies the provider URL/host, but this is overridden if you manually set an API URL
579        if self._litellm_model_id:
580            return self._litellm_model_id
581
582        litellm_provider_info = get_litellm_provider_info(self.model_provider())
583        if litellm_provider_info.is_custom and self._api_base is None:
584            raise ValueError(
585                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
586            )
587
588        self._litellm_model_id = litellm_provider_info.litellm_model_id
589        return self._litellm_model_id
590
591    def _allowed_openai_params_for_completion_kwargs(
592        self, completion_kwargs: dict[str, Any]
593    ) -> list[str]:
594        """
595        LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong
596        and we know it is supported, so we whitelist them here and pass that as an allowed_openai_params parameter.
597        """
598        # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that
599        explicit_allowed_params: Any | list = completion_kwargs.get(
600            "allowed_openai_params", []
601        )
602        if not isinstance(explicit_allowed_params, list):
603            raise ValueError(
604                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}"
605            )
606        explicit_allowed_params_validated = [
607            param for param in explicit_allowed_params if isinstance(param, str)
608        ]
609        invalid_count = len(explicit_allowed_params) - len(
610            explicit_allowed_params_validated
611        )
612        if invalid_count > 0:
613            raise ValueError(
614                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings"
615            )
616
617        # these are our own logic
618        automatic_allowed_params: list[str] = []
619        if "tools" in completion_kwargs:
620            automatic_allowed_params.append("tools")
621        if "tool_choice" in completion_kwargs:
622            automatic_allowed_params.append("tool_choice")
623
624        return list(set(explicit_allowed_params_validated + automatic_allowed_params))
625
626    async def build_completion_kwargs(
627        self,
628        provider: KilnModelProvider,
629        messages: list[ChatCompletionMessageIncludingLiteLLM],
630        top_logprobs: int | None,
631        skip_response_format: bool = False,
632    ) -> dict[str, Any]:
633        run_config = as_kiln_agent_run_config(self.run_config)
634        extra_body = self.build_extra_body(provider)
635
636        # Merge all parameters into a single kwargs dict for litellm
637        completion_kwargs = {
638            "model": self.litellm_model_id(),
639            "messages": messages,
640            "api_base": self._api_base,
641            "headers": self._headers,
642            "temperature": run_config.temperature,
643            "top_p": run_config.top_p,
644            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
645            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
646            # Better to ignore them than to fail the model call.
647            # https://docs.litellm.ai/docs/completion/input
648            "drop_params": True,
649            **extra_body,
650            **self._additional_body_options,
651        }
652
653        if self.base_adapter_config.automatic_prompt_caching:
654            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
655            # handles provider-specific injection. Providers auto-cache matching prefixes,
656            # so marking the last message is sufficient for multi-turn conversations.
657            completion_kwargs["cache_control_injection_points"] = [
658                {"location": "message", "index": -1}
659            ]
660
661        tool_calls = await self.litellm_tools()
662        has_tools = len(tool_calls) > 0
663        if has_tools:
664            completion_kwargs["tools"] = tool_calls
665            completion_kwargs["tool_choice"] = "auto"
666
667        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
668        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
669        if provider.temp_top_p_exclusive:
670            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
671                del completion_kwargs["top_p"]
672            if (
673                "temperature" in completion_kwargs
674                and completion_kwargs["temperature"] == 1.0
675            ):
676                del completion_kwargs["temperature"]
677            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
678                raise ValueError(
679                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
680                )
681
682        if not skip_response_format:
683            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
684            response_format_options = await self.response_format_options()
685
686            # Check for a conflict between tools and response format using tools
687            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good with tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
688            if has_tools and "tools" in response_format_options:
689                raise ValueError(
690                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
691                )
692
693            completion_kwargs.update(response_format_options)
694
695        if top_logprobs is not None:
696            completion_kwargs["logprobs"] = True
697            completion_kwargs["top_logprobs"] = top_logprobs
698
699        # Any params in this list will be passed to the model regardless of LiteLLM's own validation
700        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
701            completion_kwargs
702        )
703        if len(allowed_openai_params) > 0:
704            completion_kwargs["allowed_openai_params"] = allowed_openai_params
705
706        return completion_kwargs
707
708    def usage_from_response(self, response: ModelResponse) -> Usage:
709        litellm_usage = response.get("usage", None)
710
711        # LiteLLM isn't consistent in how it returns the cost.
712        cost = response._hidden_params.get("response_cost", None)
713        if cost is None and litellm_usage:
714            cost = litellm_usage.get("cost", None)
715
716        usage = Usage()
717
718        if not litellm_usage and not cost:
719            return usage
720
721        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
722            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
723            usage.output_tokens = litellm_usage.get("completion_tokens", None)
724            usage.total_tokens = litellm_usage.get("total_tokens", None)
725            prompt_details = litellm_usage.get("prompt_tokens_details", None)
726            if prompt_details and hasattr(prompt_details, "cached_tokens"):
727                usage.cached_tokens = prompt_details.cached_tokens
728            elif prompt_details:
729                logger.warning(
730                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
731                )
732        else:
733            logger.warning(
734                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
735            )
736
737        if isinstance(cost, float):
738            usage.cost = cost
739        elif cost is not None:
740            # None is allowed, but no other types are expected
741            logger.warning(
742                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
743            )
744
745        return usage
746
747    async def cached_available_tools(self) -> list[KilnToolInterface]:
748        if self._cached_available_tools is None:
749            self._cached_available_tools = await self.available_tools()
750        return self._cached_available_tools
751
752    async def _tools_for_execution(self) -> list[KilnToolInterface]:
753        """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``)."""
754        registry = await self.cached_available_tools()
755        unmanaged = self.base_adapter_config.unmanaged_tools or []
756        return registry + unmanaged
757
758    async def litellm_tools(self) -> list[ToolCallDefinition]:
759        available_tools = await self.cached_available_tools()
760
761        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
762        unmanaged = self.base_adapter_config.unmanaged_tools
763        unmanaged_defs = (
764            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
765        )
766
767        merged = registry_defs + unmanaged_defs
768        seen_names: set[str] = set()
769        for d in merged:
770            name = d["function"]["name"]
771            if name in seen_names:
772                raise ValueError(
773                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
774                )
775            seen_names.add(name)
776
777        return merged
778
779    async def process_tool_calls(
780        self, tool_calls: list[ChatCompletionMessageToolCall] | None
781    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
782        if tool_calls is None:
783            return None, []
784
785        assistant_output_from_toolcall: str | None = None
786        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
787        tool_run_coroutines = []
788
789        for tool_call in tool_calls:
790            # Kiln "task_response" tool is used for returning structured output via tool calls.
791            # Load the output from the tool call.
792            if tool_call.function.name == "task_response":
793                assistant_output_from_toolcall = tool_call.function.arguments
794                continue
795
796            # Process normal tool calls (not the "task_response" tool)
797            tool_name = tool_call.function.name
798            tool = None
799            for tool_option in await self._tools_for_execution():
800                if await tool_option.name() == tool_name:
801                    tool = tool_option
802                    break
803            if not tool:
804                raise RuntimeError(
805                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
806                )
807
808            # Parse the arguments and validate them against the tool's schema
809            try:
810                parsed_args = json.loads(tool_call.function.arguments)
811            except json.JSONDecodeError:
812                raise RuntimeError(
813                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
814                )
815            try:
816                tool_call_definition = await tool.toolcall_definition()
817                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
818                validate_schema_with_value_error(parsed_args, json_schema)
819            except Exception as e:
820                raise RuntimeError(
821                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
822                ) from e
823
824            # Create context with the calling task's allow_saving setting
825            context = ToolCallContext(
826                allow_saving=self.base_adapter_config.allow_saving
827            )
828
829            async def run_tool_and_format(
830                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
831            ):
832                result = await t.run(c, **args)
833                return ChatCompletionToolMessageParamWrapper(
834                    role="tool",
835                    tool_call_id=tc_id,
836                    content=result.output,
837                    kiln_task_tool_data=result.kiln_task_tool_data
838                    if isinstance(result, KilnTaskToolResult)
839                    else None,
840                    is_error=result.is_error if result.is_error else None,
841                    error_message=result.error_message
842                    if result.error_message
843                    else None,
844                )
845
846            tool_run_coroutines.append(run_tool_and_format())
847
848        if tool_run_coroutines:
849            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)
850
851        if (
852            assistant_output_from_toolcall is not None
853            and len(tool_call_response_messages) > 0
854        ):
855            raise RuntimeError(
856                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
857            )
858
859        return assistant_output_from_toolcall, tool_call_response_messages
860
861    def litellm_message_to_trace_message(
862        self, raw_message: LiteLLMMessage
863    ) -> ChatCompletionAssistantMessageParamWrapper:
864        """
865        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
866        """
867        message: ChatCompletionAssistantMessageParamWrapper = {
868            "role": "assistant",
869        }
870        if raw_message.role != "assistant":
871            raise ValueError(
872                "Model returned a message with a role other than assistant. This is not supported."
873            )
874
875        if hasattr(raw_message, "content"):
876            message["content"] = raw_message.content
877        if hasattr(raw_message, "reasoning_content"):
878            message["reasoning_content"] = raw_message.reasoning_content
879        if hasattr(raw_message, "tool_calls"):
880            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
881            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
882            for litellm_tool_call in raw_message.tool_calls or []:
883                # Optional in the SDK for streaming responses, but should never be None at this point.
884                if litellm_tool_call.function.name is None:
885                    raise ValueError(
886                        "The model requested a tool call, without providing a function name (required)."
887                    )
888                open_ai_tool_calls.append(
889                    ChatCompletionMessageToolCallParam(
890                        id=litellm_tool_call.id,
891                        type="function",
892                        function={
893                            "name": litellm_tool_call.function.name,
894                            "arguments": litellm_tool_call.function.arguments,
895                        },
896                    )
897                )
898            if len(open_ai_tool_calls) > 0:
899                message["tool_calls"] = open_ai_tool_calls
900
901        if not message.get("content") and not message.get("tool_calls"):
902            raise ValueError(
903                "Model returned an assistant message, but no content or tool calls. This is not supported."
904            )
905
906        return message
907
908    def all_messages_to_trace(
909        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
910    ) -> list[ChatCompletionMessageParam]:
911        """
912        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
913        """
914        trace: list[ChatCompletionMessageParam] = []
915        for message in messages:
916            if isinstance(message, LiteLLMMessage):
917                trace.append(self.litellm_message_to_trace_message(message))
918            else:
919                trace.append(message)
920        return trace
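
Usage sketch (illustrative). The adapter requires a LiteLlmConfig whose run_config_properties is a KilnAgentRunConfigProperties (enforced in __init__), plus the Kiln Task to run. Only the fields referenced in this module (run_config_properties, additional_body_options, base_url, default_headers, structured_output_mode, model_provider_name, temperature, top_p, thinking_level) are known from the code above; the remaining field names and the Task constructor arguments below are assumptions for illustration.

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties

# Hypothetical task and run config values; field names not referenced in this
# module are assumptions for illustration only.
task = datamodel.Task(
    name="extract_entities",
    instruction="Extract the people and places mentioned in the input.",
)

run_config = KilnAgentRunConfigProperties(
    model_name="gpt-4o-mini",  # assumed field name
    model_provider_name=ModelProviderName.openai,
    structured_output_mode=StructuredOutputMode.json_schema,
    temperature=0.2,
    top_p=1.0,
)

config = LiteLlmConfig(
    run_config_properties=run_config,
    additional_body_options={},  # extra kwargs passed straight to litellm.acompletion
    base_url=None,               # only required for custom / OpenAI-compatible endpoints
    default_headers=None,
)

adapter = LiteLlmAdapter(config=config, kiln_task=task)
# The public run/stream entry points live on BaseAdapter (not shown in this module);
# internally they drive the _run() and _create_run_stream() methods defined above.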
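For reference, the structured output payloads built by json_schema_response_format() and tool_call_params(strict=True) have the shapes below. The sample output schema is hypothetical; the surrounding keys mirror the methods as written above (close_object_schemas additionally tightens object schemas, e.g. by disallowing extra properties).

# A hypothetical task output schema, after close_object_schemas(..., strict=True):
output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    "additionalProperties": False,
}

# json_schema_response_format() merges this into the litellm.acompletion kwargs:
json_schema_format = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": output_schema},
    }
}

# tool_call_params(strict=True) instead forces a single "task_response" function call:
function_calling_format = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": output_schema,
                "strict": True,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}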
427                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
428                    return self.json_schema_response_format()
429                else:
430                    # Default to function calling -- it's older than the other modes. Higher compatibility.
431                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
432                    strict = provider_name == ModelProviderName.openai
433                    return self.tool_call_params(strict=strict)
434            case StructuredOutputMode.unknown:
435                # See above, but this case should never happen.
436                raise ValueError("Structured output mode is unknown.")
437            case _:
438                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]
439
440    def json_schema_response_format(self) -> dict[str, Any]:
441        output_schema = self.task.output_schema()
442        if output_schema is None:
443            raise ValueError(
444                "Invalid output schema for this task. Cannot use JSON schema response format."
445            )
446        output_schema = close_object_schemas(output_schema, strict=True)
447        return {
448            "response_format": {
449                "type": "json_schema",
450                "json_schema": {
451                    "name": "task_response",
452                    "schema": output_schema,
453                },
454            }
455        }
456
457    def tool_call_params(self, strict: bool) -> dict[str, Any]:
458        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
459        output_schema = self.task.output_schema()
460        if not isinstance(output_schema, dict):
461            raise ValueError(
462                "Invalid output schema for this task. Cannot use tool calls."
463            )
464        output_schema = close_object_schemas(output_schema, strict=strict)
465
466        function_params = {
467            "name": "task_response",
468            "parameters": output_schema,
469        }
470        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
471        if strict:
472            function_params["strict"] = True
473
474        return {
475            "tools": [
476                {
477                    "type": "function",
478                    "function": function_params,
479                }
480            ],
481            "tool_choice": {
482                "type": "function",
483                "function": {"name": "task_response"},
484            },
485        }
486
487    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
488        # Don't love having this logic here. But it's worth the usability improvement
489        # so better to keep it than exclude it. Should figure out how I want to isolate
490        # this sort of logic so it's config driven and can be overridden
491        extra_body: dict[str, Any] = {}
492        provider_options = {}
493
494        run_config = as_kiln_agent_run_config(self.run_config)
495        # For legacy config 'thinking_level' is not set, default to provider's default
496        if "thinking_level" in run_config.model_fields_set:
497            thinking_level = run_config.thinking_level
498        else:
499            thinking_level = provider.default_thinking_level
500
501        # Set the reasoning_effort
502        if thinking_level is not None:
503            # Anthropic models on OpenRouter use the reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
504            if (
505                provider.name == ModelProviderName.openrouter
506                and provider.openrouter_reasoning_object
507            ):
508                extra_body["reasoning"] = {"effort": thinking_level}
509            else:
510                extra_body["reasoning_effort"] = thinking_level
511
512        if provider.require_openrouter_reasoning:
513            # https://openrouter.ai/docs/use-cases/reasoning-tokens
514            extra_body["reasoning"] = {
515                "exclude": False,
516            }
517
518        if provider.gemini_reasoning_enabled:
519            extra_body["reasoning"] = {
520                "enabled": True,
521            }
522
523        if provider.name == ModelProviderName.openrouter:
524            # Ask OpenRouter to include usage in the response (cost)
525            extra_body["usage"] = {"include": True}
526
527            # Set a default provider order for more deterministic routing.
528            # OpenRouter will ignore providers that don't support the model.
529            # Special cases below (like R1) can override this order.
530            # allow_fallbacks is true by default, but we can override it here.
531            provider_options["order"] = [
532                "fireworks",
533                "parasail",
534                "together",
535                "deepinfra",
536                "novita",
537                "groq",
538                "amazon-bedrock",
539                "azure",
540                "nebius",
541            ]
542
543        if provider.anthropic_extended_thinking:
544            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
545
546        if provider.r1_openrouter_options:
547            # Require providers that support the reasoning parameter
548            provider_options["require_parameters"] = True
549            # Prefer R1 providers with reasonable perf/quants
550            provider_options["order"] = ["fireworks", "together"]
551            # R1 providers with unreasonable quants
552            provider_options["ignore"] = ["deepinfra"]
553
554        # Only set if this request is to get logprobs.
555        if (
556            provider.logprobs_openrouter_options
557            and self.base_adapter_config.top_logprobs is not None
558        ):
559            # Don't let OpenRouter choose a provider that doesn't support logprobs.
560            provider_options["require_parameters"] = True
561            # DeepInfra silently fails to return logprobs consistently.
562            provider_options["ignore"] = ["deepinfra"]
563
564        if provider.openrouter_skip_required_parameters:
565            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
566            provider_options["require_parameters"] = False
567
568        # Siliconflow uses a bool flag for thinking, for some models
569        if provider.siliconflow_enable_thinking is not None:
570            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
571
572        if len(provider_options) > 0:
573            extra_body["provider"] = provider_options
574
575        return extra_body
576
577    def litellm_model_id(self) -> str:
578        # The model ID is a combination of the model format and the URL endpoint.
579        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
580        if self._litellm_model_id:
581            return self._litellm_model_id
582
583        litellm_provider_info = get_litellm_provider_info(self.model_provider())
584        if litellm_provider_info.is_custom and self._api_base is None:
585            raise ValueError(
586                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
587            )
588
589        self._litellm_model_id = litellm_provider_info.litellm_model_id
590        return self._litellm_model_id
591
592    def _allowed_openai_params_for_completion_kwargs(
593        self, completion_kwargs: dict[str, Any]
594    ) -> list[str]:
595        """
596        LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong
597        and we know a param is supported, so we whitelist those params here and pass them via the allowed_openai_params parameter.
598        """
599        # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that
600        explicit_allowed_params: Any | list = completion_kwargs.get(
601            "allowed_openai_params", []
602        )
603        if not isinstance(explicit_allowed_params, list):
604            raise ValueError(
605                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}"
606            )
607        explicit_allowed_params_validated = [
608            param for param in explicit_allowed_params if isinstance(param, str)
609        ]
610        invalid_count = len(explicit_allowed_params) - len(
611            explicit_allowed_params_validated
612        )
613        if invalid_count > 0:
614            raise ValueError(
615                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings"
616            )
617
618        # these are our own logic
619        automatic_allowed_params: list[str] = []
620        if "tools" in completion_kwargs:
621            automatic_allowed_params.append("tools")
622        if "tool_choice" in completion_kwargs:
623            automatic_allowed_params.append("tool_choice")
624
625        return list(set(explicit_allowed_params_validated + automatic_allowed_params))
626
627    async def build_completion_kwargs(
628        self,
629        provider: KilnModelProvider,
630        messages: list[ChatCompletionMessageIncludingLiteLLM],
631        top_logprobs: int | None,
632        skip_response_format: bool = False,
633    ) -> dict[str, Any]:
634        run_config = as_kiln_agent_run_config(self.run_config)
635        extra_body = self.build_extra_body(provider)
636
637        # Merge all parameters into a single kwargs dict for litellm
638        completion_kwargs = {
639            "model": self.litellm_model_id(),
640            "messages": messages,
641            "api_base": self._api_base,
642            "headers": self._headers,
643            "temperature": run_config.temperature,
644            "top_p": run_config.top_p,
645            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
646            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
647            # Better to ignore them than to fail the model call.
648            # https://docs.litellm.ai/docs/completion/input
649            "drop_params": True,
650            **extra_body,
651            **self._additional_body_options,
652        }
653
654        if self.base_adapter_config.automatic_prompt_caching:
655            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
656            # handles provider-specific injection. Providers auto-cache matching prefixes,
657            # so marking the last message is sufficient for multi-turn conversations.
658            completion_kwargs["cache_control_injection_points"] = [
659                {"location": "message", "index": -1}
660            ]
661
662        tool_calls = await self.litellm_tools()
663        has_tools = len(tool_calls) > 0
664        if has_tools:
665            completion_kwargs["tools"] = tool_calls
666            completion_kwargs["tool_choice"] = "auto"
667
668        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
669        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
670        if provider.temp_top_p_exclusive:
671            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
672                del completion_kwargs["top_p"]
673            if (
674                "temperature" in completion_kwargs
675                and completion_kwargs["temperature"] == 1.0
676            ):
677                del completion_kwargs["temperature"]
678            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
679                raise ValueError(
680                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
681                )
682
683        if not skip_response_format:
684            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
685            response_format_options = await self.response_format_options()
686
687            # Check for a conflict between tools and response format using tools
688            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good with tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
689            if has_tools and "tools" in response_format_options:
690                raise ValueError(
691                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
692                )
693
694            completion_kwargs.update(response_format_options)
695
696        if top_logprobs is not None:
697            completion_kwargs["logprobs"] = True
698            completion_kwargs["top_logprobs"] = top_logprobs
699
700        # Any params in this list will be passed to the model regardless of LiteLLM's own validation.
701        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
702            completion_kwargs
703        )
704        if len(allowed_openai_params) > 0:
705            completion_kwargs["allowed_openai_params"] = allowed_openai_params
706
707        return completion_kwargs
708
709    def usage_from_response(self, response: ModelResponse) -> Usage:
710        litellm_usage = response.get("usage", None)
711
712        # LiteLLM isn't consistent in how it returns the cost.
713        cost = response._hidden_params.get("response_cost", None)
714        if cost is None and litellm_usage:
715            cost = litellm_usage.get("cost", None)
716
717        usage = Usage()
718
719        if not litellm_usage and not cost:
720            return usage
721
722        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
723            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
724            usage.output_tokens = litellm_usage.get("completion_tokens", None)
725            usage.total_tokens = litellm_usage.get("total_tokens", None)
726            prompt_details = litellm_usage.get("prompt_tokens_details", None)
727            if prompt_details and hasattr(prompt_details, "cached_tokens"):
728                usage.cached_tokens = prompt_details.cached_tokens
729            elif prompt_details:
730                logger.warning(
731                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
732                )
733        else:
734            logger.warning(
735                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
736            )
737
738        if isinstance(cost, float):
739            usage.cost = cost
740        elif cost is not None:
741            # None is allowed, but no other types are expected
742            logger.warning(
743                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
744            )
745
746        return usage
747
748    async def cached_available_tools(self) -> list[KilnToolInterface]:
749        if self._cached_available_tools is None:
750            self._cached_available_tools = await self.available_tools()
751        return self._cached_available_tools
752
753    async def _tools_for_execution(self) -> list[KilnToolInterface]:
754        """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``)."""
755        registry = await self.cached_available_tools()
756        unmanaged = self.base_adapter_config.unmanaged_tools or []
757        return registry + unmanaged
758
759    async def litellm_tools(self) -> list[ToolCallDefinition]:
760        available_tools = await self.cached_available_tools()
761
762        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
763        unmanaged = self.base_adapter_config.unmanaged_tools
764        unmanaged_defs = (
765            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
766        )
767
768        merged = registry_defs + unmanaged_defs
769        seen_names: set[str] = set()
770        for d in merged:
771            name = d["function"]["name"]
772            if name in seen_names:
773                raise ValueError(
774                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
775                )
776            seen_names.add(name)
777
778        return merged
779
780    async def process_tool_calls(
781        self, tool_calls: list[ChatCompletionMessageToolCall] | None
782    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
783        if tool_calls is None:
784            return None, []
785
786        assistant_output_from_toolcall: str | None = None
787        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
788        tool_run_coroutines = []
789
790        for tool_call in tool_calls:
791            # Kiln "task_response" tool is used for returning structured output via tool calls.
792            # Load the output from the tool call.
793            if tool_call.function.name == "task_response":
794                assistant_output_from_toolcall = tool_call.function.arguments
795                continue
796
797            # Process normal tool calls (not the "task_response" tool)
798            tool_name = tool_call.function.name
799            tool = None
800            for tool_option in await self._tools_for_execution():
801                if await tool_option.name() == tool_name:
802                    tool = tool_option
803                    break
804            if not tool:
805                raise RuntimeError(
806                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
807                )
808
809            # Parse the arguments and validate them against the tool's schema
810            try:
811                parsed_args = json.loads(tool_call.function.arguments)
812            except json.JSONDecodeError:
813                raise RuntimeError(
814                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
815                )
816            try:
817                tool_call_definition = await tool.toolcall_definition()
818                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
819                validate_schema_with_value_error(parsed_args, json_schema)
820            except Exception as e:
821                raise RuntimeError(
822                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
823                ) from e
824
825            # Create context with the calling task's allow_saving setting
826            context = ToolCallContext(
827                allow_saving=self.base_adapter_config.allow_saving
828            )
829
830            async def run_tool_and_format(
831                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
832            ):
833                result = await t.run(c, **args)
834                return ChatCompletionToolMessageParamWrapper(
835                    role="tool",
836                    tool_call_id=tc_id,
837                    content=result.output,
838                    kiln_task_tool_data=result.kiln_task_tool_data
839                    if isinstance(result, KilnTaskToolResult)
840                    else None,
841                    is_error=result.is_error if result.is_error else None,
842                    error_message=result.error_message
843                    if result.error_message
844                    else None,
845                )
846
847            tool_run_coroutines.append(run_tool_and_format())
848
849        if tool_run_coroutines:
850            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)
851
852        if (
853            assistant_output_from_toolcall is not None
854            and len(tool_call_response_messages) > 0
855        ):
856            raise RuntimeError(
857                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
858            )
859
860        return assistant_output_from_toolcall, tool_call_response_messages
861
862    def litellm_message_to_trace_message(
863        self, raw_message: LiteLLMMessage
864    ) -> ChatCompletionAssistantMessageParamWrapper:
865        """
866        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
867        """
868        message: ChatCompletionAssistantMessageParamWrapper = {
869            "role": "assistant",
870        }
871        if raw_message.role != "assistant":
872            raise ValueError(
873                "Model returned a message with a role other than assistant. This is not supported."
874            )
875
876        if hasattr(raw_message, "content"):
877            message["content"] = raw_message.content
878        if hasattr(raw_message, "reasoning_content"):
879            message["reasoning_content"] = raw_message.reasoning_content
880        if hasattr(raw_message, "tool_calls"):
881            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
882            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
883            for litellm_tool_call in raw_message.tool_calls or []:
884                # Optional in the SDK for streaming responses, but should never be None at this point.
885                if litellm_tool_call.function.name is None:
886                    raise ValueError(
887                        "The model requested a tool call, without providing a function name (required)."
888                    )
889                open_ai_tool_calls.append(
890                    ChatCompletionMessageToolCallParam(
891                        id=litellm_tool_call.id,
892                        type="function",
893                        function={
894                            "name": litellm_tool_call.function.name,
895                            "arguments": litellm_tool_call.function.arguments,
896                        },
897                    )
898                )
899            if len(open_ai_tool_calls) > 0:
900                message["tool_calls"] = open_ai_tool_calls
901
902        if not message.get("content") and not message.get("tool_calls"):
903            raise ValueError(
904                "Model returned an assistant message, but no content or tool calls. This is not supported."
905            )
906
907        return message
908
909    def all_messages_to_trace(
910        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
911    ) -> list[ChatCompletionMessageParam]:
912        """
913        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
914        """
915        trace: list[ChatCompletionMessageParam] = []
916        for message in messages:
917            if isinstance(message, LiteLLMMessage):
918                trace.append(self.litellm_message_to_trace_message(message))
919            else:
920                trace.append(message)
921        return trace

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
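
A minimal construction sketch follows. It is hedged: the required fields of KilnAgentRunConfigProperties and LiteLlmConfig are assumptions inferred from the attributes this adapter reads, and my_task stands in for an existing kiln_ai.datamodel.Task.

    from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
    from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
    from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
    from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties

    # Field names mirror what the adapter reads; exact required fields may differ.
    run_config_properties = KilnAgentRunConfigProperties(
        model_name="gpt-4o-mini",                            # assumed field name
        model_provider_name=ModelProviderName.openai,
        structured_output_mode=StructuredOutputMode.json_schema,
        temperature=0.2,
        top_p=1.0,
    )
    # base_url, default_headers, and additional_body_options are assumed optional here;
    # a base_url is required for custom OpenAI-compatible endpoints.
    config = LiteLlmConfig(run_config_properties=run_config_properties)
    adapter = LiteLlmAdapter(config=config, kiln_task=my_task)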

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
 86    def __init__(
 87        self,
 88        config: LiteLlmConfig,
 89        kiln_task: datamodel.Task,
 90        base_adapter_config: AdapterConfig | None = None,
 91    ):
 92        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
 93            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
 94        self.config = config
 95        self._additional_body_options = config.additional_body_options
 96        self._api_base = config.base_url
 97        self._headers = config.default_headers
 98        self._litellm_model_id: str | None = None
 99        self._cached_available_tools: list[KilnToolInterface] | None = None
100
101        super().__init__(
102            task=kiln_task,
103            run_config=config.run_config_properties,
104            config=base_adapter_config,
105        )
106
107        unmanaged_tools = self.base_adapter_config.unmanaged_tools
108        if unmanaged_tools:
109            _validate_unmanaged_tools(unmanaged_tools)
config
async def acompletion_checking_response( self, **kwargs: Any) -> Tuple[litellm.types.utils.ModelResponse, litellm.types.utils.Choices]:
376    async def acompletion_checking_response(
377        self, **kwargs: Any
378    ) -> Tuple[ModelResponse, Choices]:
379        response = await litellm.acompletion(**kwargs)
380
381        if (
382            not isinstance(response, ModelResponse)
383            or not response.choices
384            or len(response.choices) == 0
385            or not isinstance(response.choices[0], Choices)
386        ):
387            raise RuntimeError(
388                f"Expected ModelResponse with Choices, got {type(response)}."
389            )
390        return response, response.choices[0]
def adapter_name(self) -> str:
392    def adapter_name(self) -> str:
393        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
395    async def response_format_options(self) -> dict[str, Any]:
396        # Unstructured if task isn't structured
397        if not self.has_structured_output():
398            return {}
399
400        run_config = as_kiln_agent_run_config(self.run_config)
401        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode
402
403        match structured_output_mode:
404            case StructuredOutputMode.json_mode:
405                return {"response_format": {"type": "json_object"}}
406            case StructuredOutputMode.json_schema:
407                return self.json_schema_response_format()
408            case StructuredOutputMode.function_calling_weak:
409                return self.tool_call_params(strict=False)
410            case StructuredOutputMode.function_calling:
411                return self.tool_call_params(strict=True)
412            case StructuredOutputMode.json_instructions:
413                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
414                return {}
415            case StructuredOutputMode.json_custom_instructions:
416                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
417                return {}
418            case StructuredOutputMode.json_instruction_and_object:
419                # We set response_format to json_object and also set json instructions in the prompt
420                return {"response_format": {"type": "json_object"}}
421            case StructuredOutputMode.default:
422                provider_name = run_config.model_provider_name
423                if provider_name == ModelProviderName.ollama:
424                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
425                    return self.json_schema_response_format()
426                elif provider_name == ModelProviderName.docker_model_runner:
427                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
428                    return self.json_schema_response_format()
429                else:
430                    # Default to function calling -- it's older than the other modes. Higher compatibility.
431                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
432                    strict = provider_name == ModelProviderName.openai
433                    return self.tool_call_params(strict=strict)
434            case StructuredOutputMode.unknown:
435                # See above, but this case should never happen.
436                raise ValueError("Structured output mode is unknown.")
437            case _:
438                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]
def json_schema_response_format(self) -> dict[str, typing.Any]:
440    def json_schema_response_format(self) -> dict[str, Any]:
441        output_schema = self.task.output_schema()
442        if output_schema is None:
443            raise ValueError(
444                "Invalid output schema for this task. Cannot use JSON schema response format."
445            )
446        output_schema = close_object_schemas(output_schema, strict=True)
447        return {
448            "response_format": {
449                "type": "json_schema",
450                "json_schema": {
451                    "name": "task_response",
452                    "schema": output_schema,
453                },
454            }
455        }
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
457    def tool_call_params(self, strict: bool) -> dict[str, Any]:
458        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
459        output_schema = self.task.output_schema()
460        if not isinstance(output_schema, dict):
461            raise ValueError(
462                "Invalid output schema for this task. Cannot use tool calls."
463            )
464        output_schema = close_object_schemas(output_schema, strict=strict)
465
466        function_params = {
467            "name": "task_response",
468            "parameters": output_schema,
469        }
470        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
471        if strict:
472            function_params["strict"] = True
473
474        return {
475            "tools": [
476                {
477                    "type": "function",
478                    "function": function_params,
479                }
480            ],
481            "tool_choice": {
482                "type": "function",
483                "function": {"name": "task_response"},
484            },
485        }
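
For reference, with a toy output schema the payload returned by tool_call_params(strict=True) has roughly this shape. It is illustrative only, and assumes close_object_schemas sets additionalProperties to false and, in strict mode, marks all properties required:

    {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "task_response",
                    "parameters": {
                        "type": "object",
                        "properties": {"answer": {"type": "string"}},
                        "required": ["answer"],
                        "additionalProperties": False,
                    },
                    "strict": True,
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": "task_response"}},
    }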
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
487    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
488        # Don't love having this logic here. But it's worth the usability improvement
489        # so better to keep it than exclude it. Should figure out how I want to isolate
490        # this sort of logic so it's config driven and can be overridden
491        extra_body: dict[str, Any] = {}
492        provider_options = {}
493
494        run_config = as_kiln_agent_run_config(self.run_config)
495        # For legacy config 'thinking_level' is not set, default to provider's default
496        if "thinking_level" in run_config.model_fields_set:
497            thinking_level = run_config.thinking_level
498        else:
499            thinking_level = provider.default_thinking_level
500
501        # Set the reasoning_effort
502        if thinking_level is not None:
503            # Anthropic models on OpenRouter use the reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
504            if (
505                provider.name == ModelProviderName.openrouter
506                and provider.openrouter_reasoning_object
507            ):
508                extra_body["reasoning"] = {"effort": thinking_level}
509            else:
510                extra_body["reasoning_effort"] = thinking_level
511
512        if provider.require_openrouter_reasoning:
513            # https://openrouter.ai/docs/use-cases/reasoning-tokens
514            extra_body["reasoning"] = {
515                "exclude": False,
516            }
517
518        if provider.gemini_reasoning_enabled:
519            extra_body["reasoning"] = {
520                "enabled": True,
521            }
522
523        if provider.name == ModelProviderName.openrouter:
524            # Ask OpenRouter to include usage in the response (cost)
525            extra_body["usage"] = {"include": True}
526
527            # Set a default provider order for more deterministic routing.
528            # OpenRouter will ignore providers that don't support the model.
529            # Special cases below (like R1) can override this order.
530            # allow_fallbacks is true by default, but we can override it here.
531            provider_options["order"] = [
532                "fireworks",
533                "parasail",
534                "together",
535                "deepinfra",
536                "novita",
537                "groq",
538                "amazon-bedrock",
539                "azure",
540                "nebius",
541            ]
542
543        if provider.anthropic_extended_thinking:
544            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
545
546        if provider.r1_openrouter_options:
547            # Require providers that support the reasoning parameter
548            provider_options["require_parameters"] = True
549            # Prefer R1 providers with reasonable perf/quants
550            provider_options["order"] = ["fireworks", "together"]
551            # R1 providers with unreasonable quants
552            provider_options["ignore"] = ["deepinfra"]
553
554        # Only set if this request is to get logprobs.
555        if (
556            provider.logprobs_openrouter_options
557            and self.base_adapter_config.top_logprobs is not None
558        ):
559            # Don't let OpenRouter choose a provider that doesn't support logprobs.
560            provider_options["require_parameters"] = True
561            # DeepInfra silently fails to return logprobs consistently.
562            provider_options["ignore"] = ["deepinfra"]
563
564        if provider.openrouter_skip_required_parameters:
565            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
566            provider_options["require_parameters"] = False
567
568        # Siliconflow uses a bool flag for thinking, for some models
569        if provider.siliconflow_enable_thinking is not None:
570            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
571
572        if len(provider_options) > 0:
573            extra_body["provider"] = provider_options
574
575        return extra_body
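
As a sketch, for an OpenRouter provider with a thinking level set (and openrouter_reasoning_object false), the returned extra body would look roughly like this, assuming thinking_level maps to an OpenAI-style reasoning_effort string:

    {
        "reasoning_effort": "high",
        "usage": {"include": True},
        "provider": {
            "order": [
                "fireworks", "parasail", "together", "deepinfra", "novita",
                "groq", "amazon-bedrock", "azure", "nebius",
            ],
        },
    }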
def litellm_model_id(self) -> str:
577    def litellm_model_id(self) -> str:
578        # The model ID is a combination of the model format and the URL endpoint.
579        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
580        if self._litellm_model_id:
581            return self._litellm_model_id
582
583        litellm_provider_info = get_litellm_provider_info(self.model_provider())
584        if litellm_provider_info.is_custom and self._api_base is None:
585            raise ValueError(
586                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
587            )
588
589        self._litellm_model_id = litellm_provider_info.litellm_model_id
590        return self._litellm_model_id
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
627    async def build_completion_kwargs(
628        self,
629        provider: KilnModelProvider,
630        messages: list[ChatCompletionMessageIncludingLiteLLM],
631        top_logprobs: int | None,
632        skip_response_format: bool = False,
633    ) -> dict[str, Any]:
634        run_config = as_kiln_agent_run_config(self.run_config)
635        extra_body = self.build_extra_body(provider)
636
637        # Merge all parameters into a single kwargs dict for litellm
638        completion_kwargs = {
639            "model": self.litellm_model_id(),
640            "messages": messages,
641            "api_base": self._api_base,
642            "headers": self._headers,
643            "temperature": run_config.temperature,
644            "top_p": run_config.top_p,
645            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
646            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
647            # Better to ignore them than to fail the model call.
648            # https://docs.litellm.ai/docs/completion/input
649            "drop_params": True,
650            **extra_body,
651            **self._additional_body_options,
652        }
653
654        if self.base_adapter_config.automatic_prompt_caching:
655            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
656            # handles provider-specific injection. Providers auto-cache matching prefixes,
657            # so marking the last message is sufficient for multi-turn conversations.
658            completion_kwargs["cache_control_injection_points"] = [
659                {"location": "message", "index": -1}
660            ]
661
662        tool_calls = await self.litellm_tools()
663        has_tools = len(tool_calls) > 0
664        if has_tools:
665            completion_kwargs["tools"] = tool_calls
666            completion_kwargs["tool_choice"] = "auto"
667
668        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
669        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
670        if provider.temp_top_p_exclusive:
671            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
672                del completion_kwargs["top_p"]
673            if (
674                "temperature" in completion_kwargs
675                and completion_kwargs["temperature"] == 1.0
676            ):
677                del completion_kwargs["temperature"]
678            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
679                raise ValueError(
680                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
681                )
682
683        if not skip_response_format:
684            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
685            response_format_options = await self.response_format_options()
686
687            # Check for a conflict between tools and response format using tools
688            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good with tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
689            if has_tools and "tools" in response_format_options:
690                raise ValueError(
691                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
692                )
693
694            completion_kwargs.update(response_format_options)
695
696        if top_logprobs is not None:
697            completion_kwargs["logprobs"] = True
698            completion_kwargs["top_logprobs"] = top_logprobs
699
700        # Any params in this list will be passed to the model regardless of LiteLLM's own validation.
701        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
702            completion_kwargs
703        )
704        if len(allowed_openai_params) > 0:
705            completion_kwargs["allowed_openai_params"] = allowed_openai_params
706
707        return completion_kwargs
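
Putting it together, the kwargs handed to litellm.acompletion for a structured task in json_schema mode with logprobs requested look roughly like this (values illustrative; the model string comes from litellm_model_id()):

    {
        "model": "openrouter/openai/gpt-4o-mini",   # illustrative litellm model id
        "messages": [...],                          # built by the chat formatter
        "api_base": None,
        "headers": None,
        "temperature": 0.2,
        "top_p": 1.0,
        "drop_params": True,
        "response_format": {"type": "json_schema", "json_schema": {...}},
        "logprobs": True,
        "top_logprobs": 5,
    }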
def usage_from_response( self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage:
709    def usage_from_response(self, response: ModelResponse) -> Usage:
710        litellm_usage = response.get("usage", None)
711
712        # LiteLLM isn't consistent in how it returns the cost.
713        cost = response._hidden_params.get("response_cost", None)
714        if cost is None and litellm_usage:
715            cost = litellm_usage.get("cost", None)
716
717        usage = Usage()
718
719        if not litellm_usage and not cost:
720            return usage
721
722        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
723            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
724            usage.output_tokens = litellm_usage.get("completion_tokens", None)
725            usage.total_tokens = litellm_usage.get("total_tokens", None)
726            prompt_details = litellm_usage.get("prompt_tokens_details", None)
727            if prompt_details and hasattr(prompt_details, "cached_tokens"):
728                usage.cached_tokens = prompt_details.cached_tokens
729            elif prompt_details:
730                logger.warning(
731                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
732                )
733        else:
734            logger.warning(
735                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
736            )
737
738        if isinstance(cost, float):
739            usage.cost = cost
740        elif cost is not None:
741            # None is allowed, but no other types are expected
742            logger.warning(
743                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
744            )
745
746        return usage
async def cached_available_tools(self) -> list[kiln_ai.tools.KilnToolInterface]:
748    async def cached_available_tools(self) -> list[KilnToolInterface]:
749        if self._cached_available_tools is None:
750            self._cached_available_tools = await self.available_tools()
751        return self._cached_available_tools
async def litellm_tools(self) -> list[kiln_ai.tools.base_tool.ToolCallDefinition]:
759    async def litellm_tools(self) -> list[ToolCallDefinition]:
760        available_tools = await self.cached_available_tools()
761
762        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
763        unmanaged = self.base_adapter_config.unmanaged_tools
764        unmanaged_defs = (
765            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
766        )
767
768        merged = registry_defs + unmanaged_defs
769        seen_names: set[str] = set()
770        for d in merged:
771            name = d["function"]["name"]
772            if name in seen_names:
773                raise ValueError(
774                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
775                )
776            seen_names.add(name)
777
778        return merged
async def process_tool_calls( self, tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None) -> tuple[str | None, list[kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper]]:
780    async def process_tool_calls(
781        self, tool_calls: list[ChatCompletionMessageToolCall] | None
782    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
783        if tool_calls is None:
784            return None, []
785
786        assistant_output_from_toolcall: str | None = None
787        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
788        tool_run_coroutines = []
789
790        for tool_call in tool_calls:
791            # Kiln "task_response" tool is used for returning structured output via tool calls.
792            # Load the output from the tool call.
793            if tool_call.function.name == "task_response":
794                assistant_output_from_toolcall = tool_call.function.arguments
795                continue
796
797            # Process normal tool calls (not the "task_response" tool)
798            tool_name = tool_call.function.name
799            tool = None
800            for tool_option in await self._tools_for_execution():
801                if await tool_option.name() == tool_name:
802                    tool = tool_option
803                    break
804            if not tool:
805                raise RuntimeError(
806                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
807                )
808
809            # Parse the arguments and validate them against the tool's schema
810            try:
811                parsed_args = json.loads(tool_call.function.arguments)
812            except json.JSONDecodeError:
813                raise RuntimeError(
814                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
815                )
816            try:
817                tool_call_definition = await tool.toolcall_definition()
818                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
819                validate_schema_with_value_error(parsed_args, json_schema)
820            except Exception as e:
821                raise RuntimeError(
822                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
823                ) from e
824
825            # Create context with the calling task's allow_saving setting
826            context = ToolCallContext(
827                allow_saving=self.base_adapter_config.allow_saving
828            )
829
830            async def run_tool_and_format(
831                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
832            ):
833                result = await t.run(c, **args)
834                return ChatCompletionToolMessageParamWrapper(
835                    role="tool",
836                    tool_call_id=tc_id,
837                    content=result.output,
838                    kiln_task_tool_data=result.kiln_task_tool_data
839                    if isinstance(result, KilnTaskToolResult)
840                    else None,
841                    is_error=result.is_error if result.is_error else None,
842                    error_message=result.error_message
843                    if result.error_message
844                    else None,
845                )
846
847            tool_run_coroutines.append(run_tool_and_format())
848
849        if tool_run_coroutines:
850            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)
851
852        if (
853            assistant_output_from_toolcall is not None
854            and len(tool_call_response_messages) > 0
855        ):
856            raise RuntimeError(
857                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
858            )
859
860        return assistant_output_from_toolcall, tool_call_response_messages
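
A minimal sketch of a tool that process_tool_calls could dispatch to. It only assumes the methods this adapter actually calls on a tool (name(), toolcall_definition(), run()); the concrete result type returned by run() is not shown, since only its output/is_error/error_message attributes are read here:

    from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext

    class AddNumbersTool(KilnToolInterface):
        """Hypothetical tool: adds two numbers."""

        async def name(self) -> str:
            return "add_numbers"

        async def toolcall_definition(self):
            # OpenAI-style function definition; arguments are validated against
            # ["function"]["parameters"] before run() is invoked.
            return {
                "type": "function",
                "function": {
                    "name": "add_numbers",
                    "description": "Add two numbers.",
                    "parameters": {
                        "type": "object",
                        "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                        "required": ["a", "b"],
                        "additionalProperties": False,
                    },
                },
            }

        async def run(self, context: ToolCallContext, a: float, b: float):
            ...  # return a result object exposing .output, .is_error, .error_message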
def litellm_message_to_trace_message( self, raw_message: litellm.types.utils.Message) -> kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper:
862    def litellm_message_to_trace_message(
863        self, raw_message: LiteLLMMessage
864    ) -> ChatCompletionAssistantMessageParamWrapper:
865        """
866        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
867        """
868        message: ChatCompletionAssistantMessageParamWrapper = {
869            "role": "assistant",
870        }
871        if raw_message.role != "assistant":
872            raise ValueError(
873                "Model returned a message with a role other than assistant. This is not supported."
874            )
875
876        if hasattr(raw_message, "content"):
877            message["content"] = raw_message.content
878        if hasattr(raw_message, "reasoning_content"):
879            message["reasoning_content"] = raw_message.reasoning_content
880        if hasattr(raw_message, "tool_calls"):
881            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
882            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
883            for litellm_tool_call in raw_message.tool_calls or []:
884                # Optional in the SDK for streaming responses, but should never be None at this point.
885                if litellm_tool_call.function.name is None:
886                    raise ValueError(
887                        "The model requested a tool call, without providing a function name (required)."
888                    )
889                open_ai_tool_calls.append(
890                    ChatCompletionMessageToolCallParam(
891                        id=litellm_tool_call.id,
892                        type="function",
893                        function={
894                            "name": litellm_tool_call.function.name,
895                            "arguments": litellm_tool_call.function.arguments,
896                        },
897                    )
898                )
899            if len(open_ai_tool_calls) > 0:
900                message["tool_calls"] = open_ai_tool_calls
901
902        if not message.get("content") and not message.get("tool_calls"):
903            raise ValueError(
904                "Model returned an assistant message, but no content or tool calls. This is not supported."
905            )
906
907        return message

Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
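
For example, an assistant message carrying one tool call converts to a trace entry of roughly this shape (ids and arguments illustrative):

    {
        "role": "assistant",
        "content": None,
        "reasoning_content": None,
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "add_numbers", "arguments": '{"a": 1, "b": 2}'},
            }
        ],
    }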

def all_messages_to_trace( self, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]]:
909    def all_messages_to_trace(
910        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
911    ) -> list[ChatCompletionMessageParam]:
912        """
913        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
914        """
915        trace: list[ChatCompletionMessageParam] = []
916        for message in messages:
917            if isinstance(message, LiteLLMMessage):
918                trace.append(self.litellm_message_to_trace_message(message))
919            else:
920                trace.append(message)
921        return trace

Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.