kiln_ai.adapters.model_adapters.litellm_adapter

  1import asyncio
  2import copy
  3import json
  4import logging
  5import time
  6from dataclasses import dataclass
  7from typing import Any, Dict, List, Tuple
  8
  9import litellm
 10from litellm.types.utils import (
 11    ChatCompletionMessageToolCall,
 12    ChoiceLogprobs,
 13    Choices,
 14    ModelResponse,
 15)
 16from litellm.types.utils import Message as LiteLLMMessage
 17from litellm.types.utils import Usage as LiteLlmUsage
 18from openai.types.chat.chat_completion_message_tool_call_param import (
 19    ChatCompletionMessageToolCallParam,
 20)
 21
 22import kiln_ai.datamodel as datamodel
 23from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM
 24from kiln_ai.adapters.chat.chat_formatter import chat_message_to_dict
 25from kiln_ai.adapters.ml_model_list import (
 26    KilnModelProvider,
 27    ModelProviderName,
 28    StructuredOutputMode,
 29)
 30from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream
 31from kiln_ai.adapters.model_adapters.base_adapter import (
 32    AdapterConfig,
 33    BaseAdapter,
 34    MessageUsage,
 35    RunOutput,
 36    Usage,
 37)
 38from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 39from kiln_ai.datamodel.datamodel_enums import InputType
 40from kiln_ai.datamodel.json_schema import (
 41    close_object_schemas,
 42    validate_schema_with_value_error,
 43)
 44from kiln_ai.datamodel.run_config import (
 45    KilnAgentRunConfigProperties,
 46    as_kiln_agent_run_config,
 47)
 48from kiln_ai.tools.base_tool import (
 49    KilnToolInterface,
 50    ToolCallContext,
 51    ToolCallDefinition,
 52)
 53from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 54from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 55from kiln_ai.utils.litellm import get_litellm_provider_info
 56from kiln_ai.utils.open_ai_types import (
 57    ChatCompletionAssistantMessageParamWrapper,
 58    ChatCompletionMessageParam,
 59    ChatCompletionToolMessageParamWrapper,
 60    sanitize_messages_for_provider,
 61)
 62
 63MAX_CALLS_PER_TURN = 10
 64MAX_TOOL_CALLS_PER_TURN = 30
 65
 66logger = logging.getLogger(__name__)
 67
 68
 69def _validate_unmanaged_tools(tools: list[KilnToolInterface]) -> None:
 70    for i, tool in enumerate(tools):
 71        if not isinstance(tool, KilnToolInterface):
 72            raise TypeError(
 73                f"unmanaged_tools[{i}] must be a KilnToolInterface instance, got {type(tool).__name__}"
 74            )
 75
 76
 77@dataclass
 78class ModelTurnResult:
 79    assistant_message: str
 80    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
 81    model_response: ModelResponse | None
 82    model_choice: Choices | None
 83    usage: Usage
 84    interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None
 85    message_latency: dict[int, int] | None = None
 86    message_usage: dict[int, MessageUsage] | None = None
 87
 88
 89class LiteLlmAdapter(BaseAdapter):
 90    def __init__(
 91        self,
 92        config: LiteLlmConfig,
 93        kiln_task: datamodel.Task,
 94        base_adapter_config: AdapterConfig | None = None,
 95    ):
 96        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
 97            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
 98        self.config = config
 99        self._additional_body_options = config.additional_body_options
100        self._api_base = config.base_url
101        self._headers = config.default_headers
102        self._litellm_model_id: str | None = None
103        self._cached_available_tools: list[KilnToolInterface] | None = None
104
105        super().__init__(
106            task=kiln_task,
107            run_config=config.run_config_properties,
108            config=base_adapter_config,
109        )
110
111        unmanaged_tools = self.base_adapter_config.unmanaged_tools
112        if unmanaged_tools:
113            _validate_unmanaged_tools(unmanaged_tools)
114
115    async def _run_model_turn(
116        self,
117        provider: KilnModelProvider,
118        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
119        top_logprobs: int | None,
120        skip_response_format: bool,
121    ) -> ModelTurnResult:
122        """
123        Call the model for a single top level turn: from user message to agent message.
124
125        It may handle multiple iterations of tool calls between the user and agent messages if needed.
126        """
127
128        usage = Usage()
129        messages = list(prior_messages)
130        tool_calls_count = 0
131        # Per-LLM-call latency / usage, keyed by index in the messages list.
132        # Kept separate because we don't own the LiteLLM message objects.
133        message_latency: dict[int, int] = {}
134        message_usage: dict[int, MessageUsage] = {}
135
136        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
137            # Build completion kwargs for this model call
138            completion_kwargs = await self.build_completion_kwargs(
139                provider,
140                # Pass a copy, as acompletion mutates objects and breaks types.
141                copy.deepcopy(messages),
142                top_logprobs,
143                skip_response_format,
144            )
145
146            # Make the completion call (timed)
147            start = time.monotonic()
148            model_response, response_choice = await self.acompletion_checking_response(
149                **completion_kwargs
150            )
151            call_latency_ms = int((time.monotonic() - start) * 1000)
152
153            # count the usage
154            call_usage = self.usage_from_response(model_response)
155            usage += call_usage
156            usage.total_llm_latency_ms = (
157                usage.total_llm_latency_ms or 0
158            ) + call_latency_ms
159
160            # Extract content and tool calls
161            if not hasattr(response_choice, "message"):
162                raise ValueError("Response choice has no message")
163            content = response_choice.message.content
164            tool_calls = response_choice.message.tool_calls
165            if not content and not tool_calls:
166                raise ValueError(
167                    "Model returned an assistant message, but no content or tool calls. This is not supported."
168                )
169
170            # Add message to messages, so it can be used in the next turn
171            messages.append(response_choice.message)
172            message_latency[len(messages) - 1] = call_latency_ms
173            message_usage[len(messages) - 1] = call_usage
174
175            # Process tool calls if any
176            if tool_calls and len(tool_calls) > 0:
177                # check if we should return control to caller
178                if self.base_adapter_config.return_on_tool_call:
179                    # filter out task_response tool (task_response tools are internal)
180                    standard_tool_calls = [
181                        tc for tc in tool_calls if tc.function.name != "task_response"
182                    ]
183                    has_task_response = any(
184                        tc.function.name == "task_response" for tc in tool_calls
185                    )
186                    if standard_tool_calls and not has_task_response:
187                        return ModelTurnResult(
188                            # we don't have any content, we are waiting for toolcall output to come back from client
189                            assistant_message="",
190                            all_messages=messages,
191                            model_response=model_response,
192                            model_choice=response_choice,
193                            usage=usage,
194                            interrupted_by_tool_calls=standard_tool_calls,
195                            message_latency=message_latency,
196                            message_usage=message_usage,
197                        )
198
199                # otherwise: process tool calls internally until final output
200                (
201                    assistant_message_from_toolcall,
202                    tool_call_messages,
203                ) = await self.process_tool_calls(tool_calls)
204
205                # Add tool call results to messages
206                messages.extend(tool_call_messages)
207
208                # If task_response tool was called, we're done
209                if assistant_message_from_toolcall is not None:
210                    return ModelTurnResult(
211                        assistant_message=assistant_message_from_toolcall,
212                        all_messages=messages,
213                        model_response=model_response,
214                        model_choice=response_choice,
215                        usage=usage,
216                        message_latency=message_latency,
217                        message_usage=message_usage,
218                    )
219
220                # If there were tool calls, increment counter and continue
221                if tool_call_messages:
222                    tool_calls_count += 1
223                    continue
224
225            # If no tool calls, return the content as final output
226            if content:
227                return ModelTurnResult(
228                    assistant_message=content,
229                    all_messages=messages,
230                    model_response=model_response,
231                    model_choice=response_choice,
232                    usage=usage,
233                    message_latency=message_latency,
234                    message_usage=message_usage,
235                )
236
237            # If we get here with no content and no tool calls, raise an error
238            raise RuntimeError(
239                "Model returned neither content nor tool calls. It must return at least one of these."
240            )
241
242        raise RuntimeError(
243            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
244        )
245
246    async def _run(
247        self,
248        input: InputType,
249        prior_trace: list[ChatCompletionMessageParam] | None = None,
250    ) -> tuple[RunOutput, Usage | None]:
251        usage = Usage()
252
253        provider = self.model_provider()
254        if not provider.model_id:
255            raise ValueError("Model ID is required for OpenAI compatible models")
256
257        # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter
258        chat_formatter = self.build_chat_formatter(input, prior_trace)
259        messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
260            chat_formatter.initial_messages()
261        )
262
263        prior_output: str | None = None
264        final_choice: Choices | None = None
265        turns = 0
266        message_latency: dict[int, int] = {}
267        message_usage: dict[int, MessageUsage] = {}
268
269        # Same loop for both fresh runs and prior_trace continuation.
270        # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues).
271        while True:
272            turns += 1
273            if turns > MAX_CALLS_PER_TURN:
274                raise RuntimeError(
275                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
276                )
277
278            turn = chat_formatter.next_turn(prior_output)
279            if turn is None:
280                # No next turn, we're done
281                break
282
283            # Add messages from the turn to chat history
284            for message in turn.messages:
285                if message.content is None:
286                    raise ValueError("Empty message content isn't allowed")
287                messages.append(chat_message_to_dict(message))  # type: ignore
288
289            skip_response_format = not turn.final_call
290            turn_result = await self._run_model_turn(
291                provider,
292                messages,
293                self.base_adapter_config.top_logprobs if turn.final_call else None,
294                skip_response_format,
295            )
296
297            usage += turn_result.usage
298            if turn_result.message_latency:
299                message_latency.update(turn_result.message_latency)
300            if turn_result.message_usage:
301                message_usage.update(turn_result.message_usage)
302
303            prior_output = turn_result.assistant_message
304            messages = turn_result.all_messages
305            final_choice = turn_result.model_choice
306
307            # Check if we were interrupted by tool calls
308            if turn_result.interrupted_by_tool_calls:
309                trace = self.all_messages_to_trace(
310                    messages, message_latency, message_usage
311                )
312                intermediate_outputs = chat_formatter.intermediate_outputs()
313                output = RunOutput(
314                    output=prior_output or "",
315                    intermediate_outputs=intermediate_outputs,
316                    output_logprobs=None,
317                    trace=trace,
318                )
319                return output, usage
320
321            if not prior_output:
322                raise RuntimeError("No assistant message/output returned from model")
323
324        logprobs = self._extract_and_validate_logprobs(final_choice)
325
326        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
327        intermediate_outputs = chat_formatter.intermediate_outputs()
328        self._extract_reasoning_to_intermediate_outputs(
329            final_choice, intermediate_outputs
330        )
331
332        if not isinstance(prior_output, str):
333            raise RuntimeError(f"assistant message is not a string: {prior_output}")
334
335        trace = self.all_messages_to_trace(messages, message_latency, message_usage)
336        output = RunOutput(
337            output=prior_output,
338            intermediate_outputs=intermediate_outputs,
339            output_logprobs=logprobs,
340            trace=trace,
341        )
342
343        return output, usage
344
345    def _create_run_stream(
346        self,
347        input: InputType,
348        prior_trace: list[ChatCompletionMessageParam] | None = None,
349    ) -> AdapterStream:
350        provider = self.model_provider()
351        if not provider.model_id:
352            raise ValueError("Model ID is required for OpenAI compatible models")
353
354        chat_formatter = self.build_chat_formatter(input, prior_trace)
355        initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
356            chat_formatter.initial_messages()
357        )
358
359        return AdapterStream(
360            adapter=self,
361            provider=provider,
362            chat_formatter=chat_formatter,
363            initial_messages=initial_messages,
364            top_logprobs=self.base_adapter_config.top_logprobs,
365        )
366
367    def _extract_and_validate_logprobs(
368        self, final_choice: Choices | None
369    ) -> ChoiceLogprobs | None:
370        """
371        Extract logprobs from the final choice and validate they exist if required.
372        """
373        logprobs = None
374        if (
375            final_choice is not None
376            and hasattr(final_choice, "logprobs")
377            and isinstance(final_choice.logprobs, ChoiceLogprobs)
378        ):
379            logprobs = final_choice.logprobs
380
381        # Check logprobs worked, if required
382        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
383            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
384
385        return logprobs
386
387    def _extract_reasoning_to_intermediate_outputs(
388        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
389    ) -> None:
390        """Extract reasoning content from model choice and add to intermediate outputs if present."""
391        if (
392            final_choice is not None
393            and hasattr(final_choice, "message")
394            and hasattr(final_choice.message, "reasoning_content")
395        ):
396            reasoning_content = final_choice.message.reasoning_content
397            if reasoning_content is not None:
398                stripped_reasoning_content = reasoning_content.strip()
399                if len(stripped_reasoning_content) > 0:
400                    intermediate_outputs["reasoning"] = stripped_reasoning_content
401
402    async def acompletion_checking_response(
403        self, **kwargs: Any
404    ) -> Tuple[ModelResponse, Choices]:
405        response = await litellm.acompletion(**kwargs)
406
407        if (
408            not isinstance(response, ModelResponse)
409            or not response.choices
410            or len(response.choices) == 0
411            or not isinstance(response.choices[0], Choices)
412        ):
413            raise RuntimeError(
414                f"Expected ModelResponse with Choices, got {type(response)}."
415            )
416        return response, response.choices[0]
417
418    def adapter_name(self) -> str:
419        return "kiln_openai_compatible_adapter"
420
421    async def response_format_options(self) -> dict[str, Any]:
422        # Unstructured if task isn't structured
423        if not self.has_structured_output():
424            return {}
425
426        run_config = as_kiln_agent_run_config(self.run_config)
427        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode
428
429        match structured_output_mode:
430            case StructuredOutputMode.json_mode:
431                return {"response_format": {"type": "json_object"}}
432            case StructuredOutputMode.json_schema:
433                return self.json_schema_response_format()
434            case StructuredOutputMode.function_calling_weak:
435                return self.tool_call_params(strict=False)
436            case StructuredOutputMode.function_calling:
437                return self.tool_call_params(strict=True)
438            case StructuredOutputMode.json_instructions:
439                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
440                return {}
441            case StructuredOutputMode.json_custom_instructions:
442                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
443                return {}
444            case StructuredOutputMode.json_instruction_and_object:
445                # We set response_format to json_object and also set json instructions in the prompt
446                return {"response_format": {"type": "json_object"}}
447            case StructuredOutputMode.default:
448                provider_name = run_config.model_provider_name
449                if provider_name == ModelProviderName.ollama:
450                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
451                    return self.json_schema_response_format()
452                elif provider_name == ModelProviderName.docker_model_runner:
453                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
454                    return self.json_schema_response_format()
455                else:
456                    # Default to function calling -- it's older than the other modes. Higher compatibility.
457                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
458                    strict = provider_name == ModelProviderName.openai
459                    return self.tool_call_params(strict=strict)
460            case StructuredOutputMode.unknown:
461                # See above, but this case should never happen.
462                raise ValueError("Structured output mode is unknown.")
463            case _:
464                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]
465
466    def json_schema_response_format(self) -> dict[str, Any]:
467        output_schema = self.task.output_schema()
468        if output_schema is None:
469            raise ValueError(
470                "Invalid output schema for this task. Cannot use JSON schema response format."
471            )
472        output_schema = close_object_schemas(output_schema, strict=True)
473        return {
474            "response_format": {
475                "type": "json_schema",
476                "json_schema": {
477                    "name": "task_response",
478                    "schema": output_schema,
479                },
480            }
481        }
482
483    def tool_call_params(self, strict: bool) -> dict[str, Any]:
484        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
485        output_schema = self.task.output_schema()
486        if not isinstance(output_schema, dict):
487            raise ValueError(
488                "Invalid output schema for this task. Can not use tool calls."
489            )
490        output_schema = close_object_schemas(output_schema, strict=strict)
491
492        function_params = {
493            "name": "task_response",
494            "parameters": output_schema,
495        }
496        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
497        if strict:
498            function_params["strict"] = True
499
500        return {
501            "tools": [
502                {
503                    "type": "function",
504                    "function": function_params,
505                }
506            ],
507            "tool_choice": {
508                "type": "function",
509                "function": {"name": "task_response"},
510            },
511        }
512
513    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
514        # Don't love having this logic here. But it's worth the usability improvement
515        # so better to keep it than exclude it. Should figure out how I want to isolate
516        # this sort of logic so it's config driven and can be overridden
517        extra_body: dict[str, Any] = {}
518        provider_options = {}
519
520        run_config = as_kiln_agent_run_config(self.run_config)
521        # For legacy configs 'thinking_level' is not set; default to the provider's default
522        if "thinking_level" in run_config.model_fields_set:
523            thinking_level = run_config.thinking_level
524        else:
525            thinking_level = provider.default_thinking_level
526
527        # Set the reasoning_effort
528        if thinking_level is not None:
529            # Anthropic models on OpenRouter use the reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
530            if (
531                provider.name == ModelProviderName.openrouter
532                and provider.openrouter_reasoning_object
533            ):
534                extra_body["reasoning"] = {"effort": thinking_level}
535            else:
536                extra_body["reasoning_effort"] = thinking_level
537
538        if provider.require_openrouter_reasoning:
539            # https://openrouter.ai/docs/use-cases/reasoning-tokens
540            extra_body["reasoning"] = {
541                "exclude": False,
542            }
543
544        if provider.gemini_reasoning_enabled:
545            extra_body["reasoning"] = {
546                "enabled": True,
547            }
548
549        if provider.name == ModelProviderName.openrouter:
550            # Ask OpenRouter to include usage in the response (cost)
551            extra_body["usage"] = {"include": True}
552
553            # Set a default provider order for more deterministic routing.
554            # OpenRouter will ignore providers that don't support the model.
555            # Special cases below (like R1) can override this order.
556            # allow_fallbacks is true by default, but we can override it here.
557            provider_options["order"] = [
558                "fireworks",
559                "parasail",
560                "together",
561                "deepinfra",
562                "novita",
563                "groq",
564                "amazon-bedrock",
565                "azure",
566                "nebius",
567            ]
568
569        if provider.anthropic_extended_thinking:
570            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
571
572        if provider.r1_openrouter_options:
573            # Require providers that support the reasoning parameter
574            provider_options["require_parameters"] = True
575            # Prefer R1 providers with reasonable perf/quants
576            provider_options["order"] = ["fireworks", "together"]
577            # R1 providers with unreasonable quants
578            provider_options["ignore"] = ["deepinfra"]
579
580        # Only set these options if this request needs logprobs.
581        if (
582            provider.logprobs_openrouter_options
583            and self.base_adapter_config.top_logprobs is not None
584        ):
585            # Don't let OpenRouter choose a provider that doesn't support logprobs.
586            provider_options["require_parameters"] = True
587            # DeepInfra silently fails to return logprobs consistently.
588            provider_options["ignore"] = ["deepinfra"]
589
590        if provider.openrouter_skip_required_parameters:
591            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
592            provider_options["require_parameters"] = False
593
594        # Siliconflow uses a bool flag for thinking, for some models
595        if provider.siliconflow_enable_thinking is not None:
596            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
597
598        if len(provider_options) > 0:
599            extra_body["provider"] = provider_options
600
601        return extra_body
602
603    def litellm_model_id(self) -> str:
604        # The model ID is an interesting combination of format and URL endpoint.
605        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL
606        if self._litellm_model_id:
607            return self._litellm_model_id
608
609        litellm_provider_info = get_litellm_provider_info(self.model_provider())
610        if litellm_provider_info.is_custom and self._api_base is None:
611            raise ValueError(
612                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
613            )
614
615        self._litellm_model_id = litellm_provider_info.litellm_model_id
616        return self._litellm_model_id
617
618    def _allowed_openai_params_for_completion_kwargs(
619        self, completion_kwargs: dict[str, Any]
620    ) -> list[str]:
621        """
622        LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong
623        and we know a param is supported, so we whitelist it here and pass the whitelist as the allowed_openai_params parameter.
624        """
625        # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that
626        explicit_allowed_params: Any | list = completion_kwargs.get(
627            "allowed_openai_params", []
628        )
629        if not isinstance(explicit_allowed_params, list):
630            raise ValueError(
631                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}"
632            )
633        explicit_allowed_params_validated = [
634            param for param in explicit_allowed_params if isinstance(param, str)
635        ]
636        invalid_count = len(explicit_allowed_params) - len(
637            explicit_allowed_params_validated
638        )
639        if invalid_count > 0:
640            raise ValueError(
641                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings"
642            )
643
644        # these are our own logic
645        automatic_allowed_params: list[str] = []
646        if "tools" in completion_kwargs:
647            automatic_allowed_params.append("tools")
648        if "tool_choice" in completion_kwargs:
649            automatic_allowed_params.append("tool_choice")
650
651        return list(set(explicit_allowed_params_validated + automatic_allowed_params))
652
653    async def build_completion_kwargs(
654        self,
655        provider: KilnModelProvider,
656        messages: list[ChatCompletionMessageIncludingLiteLLM],
657        top_logprobs: int | None,
658        skip_response_format: bool = False,
659    ) -> dict[str, Any]:
660        run_config = as_kiln_agent_run_config(self.run_config)
661        extra_body = self.build_extra_body(provider)
662
663        # Merge all parameters into a single kwargs dict for litellm
664        completion_kwargs = {
665            "model": self.litellm_model_id(),
666            "messages": messages,
667            "api_base": self._api_base,
668            "headers": self._headers,
669            "temperature": run_config.temperature,
670            "top_p": run_config.top_p,
671            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
672            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
673            # Better to ignore them than to fail the model call.
674            # https://docs.litellm.ai/docs/completion/input
675            "drop_params": True,
676            **extra_body,
677            **self._additional_body_options,
678        }
679
680        if self.base_adapter_config.automatic_prompt_caching:
681            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
682            # handles provider-specific injection. Providers auto-cache matching prefixes,
683            # so marking the last message is sufficient for multi-turn conversations.
684            completion_kwargs["cache_control_injection_points"] = [
685                {"location": "message", "index": -1}
686            ]
687
688        tool_calls = await self.litellm_tools()
689        has_tools = len(tool_calls) > 0
690        if has_tools:
691            completion_kwargs["tools"] = tool_calls
692            completion_kwargs["tool_choice"] = "auto"
693
694        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
695        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
696        if provider.temp_top_p_exclusive:
697            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
698                del completion_kwargs["top_p"]
699            if (
700                "temperature" in completion_kwargs
701                and completion_kwargs["temperature"] == 1.0
702            ):
703                del completion_kwargs["temperature"]
704            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
705                raise ValueError(
706                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
707                )
708
709        if not skip_response_format:
710            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
711            response_format_options = await self.response_format_options()
712
713            # Check for a conflict between tools and a response format that itself uses tools
714            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good with tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
715            if has_tools and "tools" in response_format_options:
716                raise ValueError(
717                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
718                )
719
720            completion_kwargs.update(response_format_options)
721
722        if top_logprobs is not None:
723            completion_kwargs["logprobs"] = True
724            completion_kwargs["top_logprobs"] = top_logprobs
725
726        # Any params in this list are passed to the model regardless of LiteLLM's own validation
727        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
728            completion_kwargs
729        )
730        if len(allowed_openai_params) > 0:
731            completion_kwargs["allowed_openai_params"] = allowed_openai_params
732
733        completion_kwargs["messages"] = sanitize_messages_for_provider(messages)
734
735        return completion_kwargs
736
737    def usage_from_response(self, response: ModelResponse) -> MessageUsage:
738        litellm_usage = response.get("usage", None)
739
740        # LiteLLM isn't consistent in how it returns the cost.
741        cost = response._hidden_params.get("response_cost", None)
742        if cost is None and litellm_usage:
743            cost = litellm_usage.get("cost", None)
744
745        usage = MessageUsage()
746
747        if not litellm_usage and not cost:
748            return usage
749
750        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
751            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
752            usage.output_tokens = litellm_usage.get("completion_tokens", None)
753            usage.total_tokens = litellm_usage.get("total_tokens", None)
754            prompt_details = litellm_usage.get("prompt_tokens_details", None)
755            if prompt_details and hasattr(prompt_details, "cached_tokens"):
756                usage.cached_tokens = prompt_details.cached_tokens
757            elif prompt_details:
758                logger.warning(
759                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
760                )
761        else:
762            logger.warning(
763                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
764            )
765
766        if isinstance(cost, float):
767            usage.cost = cost
768        elif cost is not None:
769            # None is allowed, but no other types are expected
770            logger.warning(
771                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
772            )
773
774        return usage
775
776    async def cached_available_tools(self) -> list[KilnToolInterface]:
777        if self._cached_available_tools is None:
778            self._cached_available_tools = await self.available_tools()
779        return self._cached_available_tools
780
781    async def _tools_for_execution(self) -> list[KilnToolInterface]:
782        """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``)."""
783        registry = await self.cached_available_tools()
784        unmanaged = self.base_adapter_config.unmanaged_tools or []
785        return registry + unmanaged
786
787    async def litellm_tools(self) -> list[ToolCallDefinition]:
788        available_tools = await self.cached_available_tools()
789
790        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
791        unmanaged = self.base_adapter_config.unmanaged_tools
792        unmanaged_defs = (
793            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
794        )
795
796        merged = registry_defs + unmanaged_defs
797        seen_names: set[str] = set()
798        for d in merged:
799            name = d["function"]["name"]
800            if name in seen_names:
801                raise ValueError(
802                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
803                )
804            seen_names.add(name)
805
806        return merged
807
808    async def process_tool_calls(
809        self, tool_calls: list[ChatCompletionMessageToolCall] | None
810    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
811        if tool_calls is None:
812            return None, []
813
814        assistant_output_from_toolcall: str | None = None
815        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
816        tool_run_coroutines = []
817
818        for tool_call in tool_calls:
819            # Kiln "task_response" tool is used for returning structured output via tool calls.
820            # Load the output from the tool call.
821            if tool_call.function.name == "task_response":
822                assistant_output_from_toolcall = tool_call.function.arguments
823                continue
824
825            # Process normal tool calls (not the "task_response" tool)
826            tool_name = tool_call.function.name
827            tool = None
828            for tool_option in await self._tools_for_execution():
829                if await tool_option.name() == tool_name:
830                    tool = tool_option
831                    break
832            if not tool:
833                raise RuntimeError(
834                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
835                )
836
837            # Parse the arguments and validate them against the tool's schema
838            try:
839                parsed_args = json.loads(tool_call.function.arguments)
840            except json.JSONDecodeError:
841                raise RuntimeError(
842                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
843                )
844            try:
845                tool_call_definition = await tool.toolcall_definition()
846                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
847                validate_schema_with_value_error(parsed_args, json_schema)
848            except Exception as e:
849                raise RuntimeError(
850                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
851                ) from e
852
853            # Create context with the calling task's allow_saving setting
854            context = ToolCallContext(
855                allow_saving=self.base_adapter_config.allow_saving
856            )
857
858            async def run_tool_and_format(
859                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
860            ):
861                result = await t.run(c, **args)
862                return ChatCompletionToolMessageParamWrapper(
863                    role="tool",
864                    tool_call_id=tc_id,
865                    content=result.output,
866                    kiln_task_tool_data=result.kiln_task_tool_data
867                    if isinstance(result, KilnTaskToolResult)
868                    else None,
869                    is_error=result.is_error if result.is_error else None,
870                    error_message=result.error_message
871                    if result.error_message
872                    else None,
873                )
874
875            tool_run_coroutines.append(run_tool_and_format())
876
877        if tool_run_coroutines:
878            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)
879
880        if (
881            assistant_output_from_toolcall is not None
882            and len(tool_call_response_messages) > 0
883        ):
884            raise RuntimeError(
885                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
886            )
887
888        return assistant_output_from_toolcall, tool_call_response_messages
889
890    def litellm_message_to_trace_message(
891        self,
892        raw_message: LiteLLMMessage,
893        latency_ms: int | None = None,
894        usage: MessageUsage | None = None,
895    ) -> ChatCompletionAssistantMessageParamWrapper:
896        """
897        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
898        """
899        message: ChatCompletionAssistantMessageParamWrapper = {
900            "role": "assistant",
901        }
902        if raw_message.role != "assistant":
903            raise ValueError(
904                "Model returned a message with a role other than assistant. This is not supported."
905            )
906
907        if hasattr(raw_message, "content"):
908            message["content"] = raw_message.content
909        if hasattr(raw_message, "reasoning_content"):
910            message["reasoning_content"] = raw_message.reasoning_content
911        if hasattr(raw_message, "tool_calls"):
912            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
913            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
914            for litellm_tool_call in raw_message.tool_calls or []:
915                # Optional in the SDK for streaming responses, but should never be None at this point.
916                if litellm_tool_call.function.name is None:
917                    raise ValueError(
918                        "The model requested a tool call, without providing a function name (required)."
919                    )
920                open_ai_tool_calls.append(
921                    ChatCompletionMessageToolCallParam(
922                        id=litellm_tool_call.id,
923                        type="function",
924                        function={
925                            "name": litellm_tool_call.function.name,
926                            "arguments": litellm_tool_call.function.arguments,
927                        },
928                    )
929                )
930            if len(open_ai_tool_calls) > 0:
931                message["tool_calls"] = open_ai_tool_calls
932
933        if latency_ms is not None:
934            message["latency_ms"] = latency_ms
935
936        if usage is not None:
937            message["usage"] = usage
938
939        if not message.get("content") and not message.get("tool_calls"):
940            raise ValueError(
941                "Model returned an assistant message, but no content or tool calls. This is not supported."
942            )
943
944        return message
945
946    def all_messages_to_trace(
947        self,
948        messages: list[ChatCompletionMessageIncludingLiteLLM],
949        message_latency: dict[int, int] | None = None,
950        message_usage: dict[int, MessageUsage] | None = None,
951    ) -> list[ChatCompletionMessageParam]:
952        """
953        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
954
955        Non-LiteLLM dict messages pass through unchanged. Any per-message
956        ``usage``/``latency_ms`` already attached to those dicts (e.g. from a
957        seeded prior trace) is preserved.
958        """
959        trace: list[ChatCompletionMessageParam] = []
960        for i, message in enumerate(messages):
961            if isinstance(message, LiteLLMMessage):
962                latency_ms = message_latency.get(i) if message_latency else None
963                usage = message_usage.get(i) if message_usage else None
964                trace.append(
965                    self.litellm_message_to_trace_message(message, latency_ms, usage)
966                )
967            else:
968                trace.append(message)
969        return trace
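
The structured output plumbing above (response_format_options, json_schema_response_format, tool_call_params) produces one of two payload shapes that get merged into the LiteLLM completion kwargs. A minimal sketch of both shapes, using a made-up output schema (the schema and variable names below are illustrative, not part of the module):

# Illustrative only: a toy task output schema, not part of this module.
example_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    # Strict mode closes object schemas (per the comment in tool_call_params).
    "additionalProperties": False,
}

# Shape returned by json_schema_response_format() (StructuredOutputMode.json_schema):
json_schema_kwargs = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": example_schema},
    }
}

# Shape returned by tool_call_params(strict=True) (function calling modes): the model is
# forced to "call" a task_response tool whose arguments become the structured output.
function_calling_kwargs = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": example_schema,
                "strict": True,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}
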
MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30
@dataclass
class ModelTurnResult:
ModelTurnResult(
    assistant_message: str,
    all_messages: list[ChatCompletionMessageIncludingLiteLLM],
    model_response: litellm.types.utils.ModelResponse | None,
    model_choice: litellm.types.utils.Choices | None,
    usage: kiln_ai.datamodel.Usage,
    interrupted_by_tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None = None,
    message_latency: dict[int, int] | None = None,
    message_usage: dict[int, kiln_ai.datamodel.MessageUsage] | None = None,
)

assistant_message: str
all_messages: list[ChatCompletionMessageIncludingLiteLLM]
model_response: litellm.types.utils.ModelResponse | None
model_choice: litellm.types.utils.Choices | None
usage: kiln_ai.datamodel.Usage
interrupted_by_tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None = None
message_latency: dict[int, int] | None = None
message_usage: dict[int, kiln_ai.datamodel.MessageUsage] | None = None
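
An illustrative construction of ModelTurnResult (all values invented) showing how the per-message maps relate to all_messages: keys are indices into that list for messages produced by LLM calls within the turn, matching how _run_model_turn records them. Plain dict messages are used here for brevity; in practice the assistant entry is the LiteLLM Message returned by the call.

# Hypothetical values for illustration only.
call_usage = MessageUsage()
call_usage.input_tokens = 25
call_usage.output_tokens = 6

turn = ModelTurnResult(
    assistant_message="4",
    all_messages=[
        {"role": "system", "content": "You are a calculator."},  # index 0
        {"role": "user", "content": "What is 2 + 2?"},           # index 1
        {"role": "assistant", "content": "4"},                   # index 2: the model's reply
    ],
    model_response=None,  # None for brevity; normally the LiteLLM ModelResponse
    model_choice=None,    # None for brevity; normally response.choices[0]
    usage=Usage(),
    # Keyed by index into all_messages: only index 2 came from an LLM call.
    message_latency={2: 840},
    message_usage={2: call_usage},
)
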
class LiteLlmAdapter(BaseAdapter):
383        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
384            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
385
386        return logprobs
387
388    def _extract_reasoning_to_intermediate_outputs(
389        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
390    ) -> None:
391        """Extract reasoning content from model choice and add to intermediate outputs if present."""
392        if (
393            final_choice is not None
394            and hasattr(final_choice, "message")
395            and hasattr(final_choice.message, "reasoning_content")
396        ):
397            reasoning_content = final_choice.message.reasoning_content
398            if reasoning_content is not None:
399                stripped_reasoning_content = reasoning_content.strip()
400                if len(stripped_reasoning_content) > 0:
401                    intermediate_outputs["reasoning"] = stripped_reasoning_content
402
403    async def acompletion_checking_response(
404        self, **kwargs: Any
405    ) -> Tuple[ModelResponse, Choices]:
406        response = await litellm.acompletion(**kwargs)
407
408        if (
409            not isinstance(response, ModelResponse)
410            or not response.choices
411            or len(response.choices) == 0
412            or not isinstance(response.choices[0], Choices)
413        ):
414            raise RuntimeError(
415                f"Expected ModelResponse with Choices, got {type(response)}."
416            )
417        return response, response.choices[0]
418
419    def adapter_name(self) -> str:
420        return "kiln_openai_compatible_adapter"
421
422    async def response_format_options(self) -> dict[str, Any]:
423        # Unstructured if task isn't structured
424        if not self.has_structured_output():
425            return {}
426
427        run_config = as_kiln_agent_run_config(self.run_config)
428        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode
429
430        match structured_output_mode:
431            case StructuredOutputMode.json_mode:
432                return {"response_format": {"type": "json_object"}}
433            case StructuredOutputMode.json_schema:
434                return self.json_schema_response_format()
435            case StructuredOutputMode.function_calling_weak:
436                return self.tool_call_params(strict=False)
437            case StructuredOutputMode.function_calling:
438                return self.tool_call_params(strict=True)
439            case StructuredOutputMode.json_instructions:
440                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
441                return {}
442            case StructuredOutputMode.json_custom_instructions:
443                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
444                return {}
445            case StructuredOutputMode.json_instruction_and_object:
446                # We set response_format to json_object and also set json instructions in the prompt
447                return {"response_format": {"type": "json_object"}}
448            case StructuredOutputMode.default:
449                provider_name = run_config.model_provider_name
450                if provider_name == ModelProviderName.ollama:
451                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
452                    return self.json_schema_response_format()
453                elif provider_name == ModelProviderName.docker_model_runner:
454                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
455                    return self.json_schema_response_format()
456                else:
457                    # Default to function calling -- it's older than the other modes. Higher compatibility.
458                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
459                    strict = provider_name == ModelProviderName.openai
460                    return self.tool_call_params(strict=strict)
461            case StructuredOutputMode.unknown:
462                # See above, but this case should never happen.
463                raise ValueError("Structured output mode is unknown.")
464            case _:
465                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]
466
467    def json_schema_response_format(self) -> dict[str, Any]:
468        output_schema = self.task.output_schema()
469        if output_schema is None:
470            raise ValueError(
471                "Invalid output schema for this task. Cannot use JSON schema response format."
472            )
473        output_schema = close_object_schemas(output_schema, strict=True)
474        return {
475            "response_format": {
476                "type": "json_schema",
477                "json_schema": {
478                    "name": "task_response",
479                    "schema": output_schema,
480                },
481            }
482        }
483
484    def tool_call_params(self, strict: bool) -> dict[str, Any]:
485        # Add additional_properties: false to the schema (OpenAI requires this for some models)
486        output_schema = self.task.output_schema()
487        if not isinstance(output_schema, dict):
488            raise ValueError(
489                "Invalid output schema for this task. Cannot use tool calls."
490            )
491        output_schema = close_object_schemas(output_schema, strict=strict)
492
493        function_params = {
494            "name": "task_response",
495            "parameters": output_schema,
496        }
497        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
498        if strict:
499            function_params["strict"] = True
500
501        return {
502            "tools": [
503                {
504                    "type": "function",
505                    "function": function_params,
506                }
507            ],
508            "tool_choice": {
509                "type": "function",
510                "function": {"name": "task_response"},
511            },
512        }
513
514    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
515        # Don't love having this logic here. But it's worth the usability improvement
516        # so better to keep it than exclude it. Should figure out how I want to isolate
517        # this sort of logic so it's config driven and can be overridden
518        extra_body: dict[str, Any] = {}
519        provider_options = {}
520
521        run_config = as_kiln_agent_run_config(self.run_config)
522        # For legacy configs 'thinking_level' is not set; default to the provider's default
523        if "thinking_level" in run_config.model_fields_set:
524            thinking_level = run_config.thinking_level
525        else:
526            thinking_level = provider.default_thinking_level
527
528        # Set the reasoning_effort
529        if thinking_level is not None:
530            # Anthropic models on OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
531            if (
532                provider.name == ModelProviderName.openrouter
533                and provider.openrouter_reasoning_object
534            ):
535                extra_body["reasoning"] = {"effort": thinking_level}
536            else:
537                extra_body["reasoning_effort"] = thinking_level
538
539        if provider.require_openrouter_reasoning:
540            # https://openrouter.ai/docs/use-cases/reasoning-tokens
541            extra_body["reasoning"] = {
542                "exclude": False,
543            }
544
545        if provider.gemini_reasoning_enabled:
546            extra_body["reasoning"] = {
547                "enabled": True,
548            }
549
550        if provider.name == ModelProviderName.openrouter:
551            # Ask OpenRouter to include usage in the response (cost)
552            extra_body["usage"] = {"include": True}
553
554            # Set a default provider order for more deterministic routing.
555            # OpenRouter will ignore providers that don't support the model.
556            # Special cases below (like R1) can override this order.
557            # allow_fallbacks is true by default, but we can override it here.
558            provider_options["order"] = [
559                "fireworks",
560                "parasail",
561                "together",
562                "deepinfra",
563                "novita",
564                "groq",
565                "amazon-bedrock",
566                "azure",
567                "nebius",
568            ]
569
570        if provider.anthropic_extended_thinking:
571            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
572
573        if provider.r1_openrouter_options:
574            # Require providers that support the reasoning parameter
575            provider_options["require_parameters"] = True
576            # Prefer R1 providers with reasonable perf/quants
577            provider_options["order"] = ["fireworks", "together"]
578            # R1 providers with unreasonable quants
579            provider_options["ignore"] = ["deepinfra"]
580
581        # Only set these options if the goal of this request is to get logprobs.
582        if (
583            provider.logprobs_openrouter_options
584            and self.base_adapter_config.top_logprobs is not None
585        ):
586            # Don't let OpenRouter choose a provider that doesn't support logprobs.
587            provider_options["require_parameters"] = True
588            # DeepInfra silently fails to return logprobs consistently.
589            provider_options["ignore"] = ["deepinfra"]
590
591        if provider.openrouter_skip_required_parameters:
592            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
593            provider_options["require_parameters"] = False
594
595        # Siliconflow uses a bool flag for thinking, for some models
596        if provider.siliconflow_enable_thinking is not None:
597            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
598
599        if len(provider_options) > 0:
600            extra_body["provider"] = provider_options
601
602        return extra_body
603
604    def litellm_model_id(self) -> str:
605        # The model ID is an interesting combination of format and url endpoint.
606        # It specifies the provider URL/host, but this is overridden if you manually set an API URL
607        if self._litellm_model_id:
608            return self._litellm_model_id
609
610        litellm_provider_info = get_litellm_provider_info(self.model_provider())
611        if litellm_provider_info.is_custom and self._api_base is None:
612            raise ValueError(
613                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
614            )
615
616        self._litellm_model_id = litellm_provider_info.litellm_model_id
617        return self._litellm_model_id
618
619    def _allowed_openai_params_for_completion_kwargs(
620        self, completion_kwargs: dict[str, Any]
621    ) -> list[str]:
622        """
623        LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong
624        and we know it is supported, so we whitelist them here and pass that as an allowed_openai_params parameter.
625        """
626        # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that
627        explicit_allowed_params: Any | list = completion_kwargs.get(
628            "allowed_openai_params", []
629        )
630        if not isinstance(explicit_allowed_params, list):
631            raise ValueError(
632                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}"
633            )
634        explicit_allowed_params_validated = [
635            param for param in explicit_allowed_params if isinstance(param, str)
636        ]
637        invalid_count = len(explicit_allowed_params) - len(
638            explicit_allowed_params_validated
639        )
640        if invalid_count > 0:
641            raise ValueError(
642                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings"
643            )
644
645        # These are added automatically by our own logic
646        automatic_allowed_params: list[str] = []
647        if "tools" in completion_kwargs:
648            automatic_allowed_params.append("tools")
649        if "tool_choice" in completion_kwargs:
650            automatic_allowed_params.append("tool_choice")
651
652        return list(set(explicit_allowed_params_validated + automatic_allowed_params))
653
654    async def build_completion_kwargs(
655        self,
656        provider: KilnModelProvider,
657        messages: list[ChatCompletionMessageIncludingLiteLLM],
658        top_logprobs: int | None,
659        skip_response_format: bool = False,
660    ) -> dict[str, Any]:
661        run_config = as_kiln_agent_run_config(self.run_config)
662        extra_body = self.build_extra_body(provider)
663
664        # Merge all parameters into a single kwargs dict for litellm
665        completion_kwargs = {
666            "model": self.litellm_model_id(),
667            "messages": messages,
668            "api_base": self._api_base,
669            "headers": self._headers,
670            "temperature": run_config.temperature,
671            "top_p": run_config.top_p,
672            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
673            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
674            # Better to ignore them than to fail the model call.
675            # https://docs.litellm.ai/docs/completion/input
676            "drop_params": True,
677            **extra_body,
678            **self._additional_body_options,
679        }
680
681        if self.base_adapter_config.automatic_prompt_caching:
682            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
683            # handles provider-specific injection. Providers auto-cache matching prefixes,
684            # so marking the last message is sufficient for multi-turn conversations.
685            completion_kwargs["cache_control_injection_points"] = [
686                {"location": "message", "index": -1}
687            ]
688
689        tool_calls = await self.litellm_tools()
690        has_tools = len(tool_calls) > 0
691        if has_tools:
692            completion_kwargs["tools"] = tool_calls
693            completion_kwargs["tool_choice"] = "auto"
694
695        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
696        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
697        if provider.temp_top_p_exclusive:
698            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
699                del completion_kwargs["top_p"]
700            if (
701                "temperature" in completion_kwargs
702                and completion_kwargs["temperature"] == 1.0
703            ):
704                del completion_kwargs["temperature"]
705            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
706                raise ValueError(
707                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
708                )
709
710        if not skip_response_format:
711            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
712            response_format_options = await self.response_format_options()
713
714            # Check for a conflict between tools and response format using tools
715            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
716            if has_tools and "tools" in response_format_options:
717                raise ValueError(
718                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
719                )
720
721            completion_kwargs.update(response_format_options)
722
723        if top_logprobs is not None:
724            completion_kwargs["logprobs"] = True
725            completion_kwargs["top_logprobs"] = top_logprobs
726
727        # any params listed in this list will be passed to the model regardless of LiteLLM's own validation
728        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
729            completion_kwargs
730        )
731        if len(allowed_openai_params) > 0:
732            completion_kwargs["allowed_openai_params"] = allowed_openai_params
733
734        completion_kwargs["messages"] = sanitize_messages_for_provider(messages)
735
736        return completion_kwargs
737
738    def usage_from_response(self, response: ModelResponse) -> MessageUsage:
739        litellm_usage = response.get("usage", None)
740
741        # LiteLLM isn't consistent in how it returns the cost.
742        cost = response._hidden_params.get("response_cost", None)
743        if cost is None and litellm_usage:
744            cost = litellm_usage.get("cost", None)
745
746        usage = MessageUsage()
747
748        if not litellm_usage and not cost:
749            return usage
750
751        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
752            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
753            usage.output_tokens = litellm_usage.get("completion_tokens", None)
754            usage.total_tokens = litellm_usage.get("total_tokens", None)
755            prompt_details = litellm_usage.get("prompt_tokens_details", None)
756            if prompt_details and hasattr(prompt_details, "cached_tokens"):
757                usage.cached_tokens = prompt_details.cached_tokens
758            elif prompt_details:
759                logger.warning(
760                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
761                )
762        else:
763            logger.warning(
764                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
765            )
766
767        if isinstance(cost, float):
768            usage.cost = cost
769        elif cost is not None:
770            # None is allowed, but no other types are expected
771            logger.warning(
772                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
773            )
774
775        return usage
776
777    async def cached_available_tools(self) -> list[KilnToolInterface]:
778        if self._cached_available_tools is None:
779            self._cached_available_tools = await self.available_tools()
780        return self._cached_available_tools
781
782    async def _tools_for_execution(self) -> list[KilnToolInterface]:
783        """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``)."""
784        registry = await self.cached_available_tools()
785        unmanaged = self.base_adapter_config.unmanaged_tools or []
786        return registry + unmanaged
787
788    async def litellm_tools(self) -> list[ToolCallDefinition]:
789        available_tools = await self.cached_available_tools()
790
791        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
792        unmanaged = self.base_adapter_config.unmanaged_tools
793        unmanaged_defs = (
794            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
795        )
796
797        merged = registry_defs + unmanaged_defs
798        seen_names: set[str] = set()
799        for d in merged:
800            name = d["function"]["name"]
801            if name in seen_names:
802                raise ValueError(
803                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
804                )
805            seen_names.add(name)
806
807        return merged
808
809    async def process_tool_calls(
810        self, tool_calls: list[ChatCompletionMessageToolCall] | None
811    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
812        if tool_calls is None:
813            return None, []
814
815        assistant_output_from_toolcall: str | None = None
816        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
817        tool_run_coroutines = []
818
819        for tool_call in tool_calls:
820            # Kiln "task_response" tool is used for returning structured output via tool calls.
821            # Load the output from the tool call.
822            if tool_call.function.name == "task_response":
823                assistant_output_from_toolcall = tool_call.function.arguments
824                continue
825
826            # Process normal tool calls (not the "task_response" tool)
827            tool_name = tool_call.function.name
828            tool = None
829            for tool_option in await self._tools_for_execution():
830                if await tool_option.name() == tool_name:
831                    tool = tool_option
832                    break
833            if not tool:
834                raise RuntimeError(
835                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
836                )
837
838            # Parse the arguments and validate them against the tool's schema
839            try:
840                parsed_args = json.loads(tool_call.function.arguments)
841            except json.JSONDecodeError:
842                raise RuntimeError(
843                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
844                )
845            try:
846                tool_call_definition = await tool.toolcall_definition()
847                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
848                validate_schema_with_value_error(parsed_args, json_schema)
849            except Exception as e:
850                raise RuntimeError(
851                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
852                ) from e
853
854            # Create context with the calling task's allow_saving setting
855            context = ToolCallContext(
856                allow_saving=self.base_adapter_config.allow_saving
857            )
858
859            async def run_tool_and_format(
860                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
861            ):
862                result = await t.run(c, **args)
863                return ChatCompletionToolMessageParamWrapper(
864                    role="tool",
865                    tool_call_id=tc_id,
866                    content=result.output,
867                    kiln_task_tool_data=result.kiln_task_tool_data
868                    if isinstance(result, KilnTaskToolResult)
869                    else None,
870                    is_error=result.is_error if result.is_error else None,
871                    error_message=result.error_message
872                    if result.error_message
873                    else None,
874                )
875
876            tool_run_coroutines.append(run_tool_and_format())
877
878        if tool_run_coroutines:
879            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)
880
881        if (
882            assistant_output_from_toolcall is not None
883            and len(tool_call_response_messages) > 0
884        ):
885            raise RuntimeError(
886                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
887            )
888
889        return assistant_output_from_toolcall, tool_call_response_messages
890
891    def litellm_message_to_trace_message(
892        self,
893        raw_message: LiteLLMMessage,
894        latency_ms: int | None = None,
895        usage: MessageUsage | None = None,
896    ) -> ChatCompletionAssistantMessageParamWrapper:
897        """
898        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
899        """
900        message: ChatCompletionAssistantMessageParamWrapper = {
901            "role": "assistant",
902        }
903        if raw_message.role != "assistant":
904            raise ValueError(
905                "Model returned a message with a role other than assistant. This is not supported."
906            )
907
908        if hasattr(raw_message, "content"):
909            message["content"] = raw_message.content
910        if hasattr(raw_message, "reasoning_content"):
911            message["reasoning_content"] = raw_message.reasoning_content
912        if hasattr(raw_message, "tool_calls"):
913            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
914            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
915            for litellm_tool_call in raw_message.tool_calls or []:
916                # Optional in the SDK for streaming responses, but should never be None at this point.
917                if litellm_tool_call.function.name is None:
918                    raise ValueError(
919                        "The model requested a tool call, without providing a function name (required)."
920                    )
921                open_ai_tool_calls.append(
922                    ChatCompletionMessageToolCallParam(
923                        id=litellm_tool_call.id,
924                        type="function",
925                        function={
926                            "name": litellm_tool_call.function.name,
927                            "arguments": litellm_tool_call.function.arguments,
928                        },
929                    )
930                )
931            if len(open_ai_tool_calls) > 0:
932                message["tool_calls"] = open_ai_tool_calls
933
934        if latency_ms is not None:
935            message["latency_ms"] = latency_ms
936
937        if usage is not None:
938            message["usage"] = usage
939
940        if not message.get("content") and not message.get("tool_calls"):
941            raise ValueError(
942                "Model returned an assistant message, but no content or tool calls. This is not supported."
943            )
944
945        return message
946
947    def all_messages_to_trace(
948        self,
949        messages: list[ChatCompletionMessageIncludingLiteLLM],
950        message_latency: dict[int, int] | None = None,
951        message_usage: dict[int, MessageUsage] | None = None,
952    ) -> list[ChatCompletionMessageParam]:
953        """
954        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
955
956        Non-LiteLLM dict messages pass through unchanged. Any per-message
957        ``usage``/``latency_ms`` already attached to those dicts (e.g. from a
958        seeded prior trace) is preserved.
959        """
960        trace: list[ChatCompletionMessageParam] = []
961        for i, message in enumerate(messages):
962            if isinstance(message, LiteLLMMessage):
963                latency_ms = message_latency.get(i) if message_latency else None
964                usage = message_usage.get(i) if message_usage else None
965                trace.append(
966                    self.litellm_message_to_trace_message(message, latency_ms, usage)
967                )
968            else:
969                trace.append(message)
970        return trace

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
 91    def __init__(
 92        self,
 93        config: LiteLlmConfig,
 94        kiln_task: datamodel.Task,
 95        base_adapter_config: AdapterConfig | None = None,
 96    ):
 97        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
 98            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
 99        self.config = config
100        self._additional_body_options = config.additional_body_options
101        self._api_base = config.base_url
102        self._headers = config.default_headers
103        self._litellm_model_id: str | None = None
104        self._cached_available_tools: list[KilnToolInterface] | None = None
105
106        super().__init__(
107            task=kiln_task,
108            run_config=config.run_config_properties,
109            config=base_adapter_config,
110        )
111
112        unmanaged_tools = self.base_adapter_config.unmanaged_tools
113        if unmanaged_tools:
114            _validate_unmanaged_tools(unmanaged_tools)
config
async def acompletion_checking_response( self, **kwargs: Any) -> Tuple[litellm.types.utils.ModelResponse, litellm.types.utils.Choices]:
403    async def acompletion_checking_response(
404        self, **kwargs: Any
405    ) -> Tuple[ModelResponse, Choices]:
406        response = await litellm.acompletion(**kwargs)
407
408        if (
409            not isinstance(response, ModelResponse)
410            or not response.choices
411            or len(response.choices) == 0
412            or not isinstance(response.choices[0], Choices)
413        ):
414            raise RuntimeError(
415                f"Expected ModelResponse with Choices, got {type(response)}."
416            )
417        return response, response.choices[0]
def adapter_name(self) -> str:
419    def adapter_name(self) -> str:
420        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
422    async def response_format_options(self) -> dict[str, Any]:
423        # Unstructured if task isn't structured
424        if not self.has_structured_output():
425            return {}
426
427        run_config = as_kiln_agent_run_config(self.run_config)
428        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode
429
430        match structured_output_mode:
431            case StructuredOutputMode.json_mode:
432                return {"response_format": {"type": "json_object"}}
433            case StructuredOutputMode.json_schema:
434                return self.json_schema_response_format()
435            case StructuredOutputMode.function_calling_weak:
436                return self.tool_call_params(strict=False)
437            case StructuredOutputMode.function_calling:
438                return self.tool_call_params(strict=True)
439            case StructuredOutputMode.json_instructions:
440                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
441                return {}
442            case StructuredOutputMode.json_custom_instructions:
443                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
444                return {}
445            case StructuredOutputMode.json_instruction_and_object:
446                # We set response_format to json_object and also set json instructions in the prompt
447                return {"response_format": {"type": "json_object"}}
448            case StructuredOutputMode.default:
449                provider_name = run_config.model_provider_name
450                if provider_name == ModelProviderName.ollama:
451                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
452                    return self.json_schema_response_format()
453                elif provider_name == ModelProviderName.docker_model_runner:
454                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
455                    return self.json_schema_response_format()
456                else:
457                    # Default to function calling -- it's older than the other modes. Higher compatibility.
458                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
459                    strict = provider_name == ModelProviderName.openai
460                    return self.tool_call_params(strict=strict)
461            case StructuredOutputMode.unknown:
462                # See above, but this case should never happen.
463                raise ValueError("Structured output mode is unknown.")
464            case _:
465                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]
def json_schema_response_format(self) -> dict[str, typing.Any]:
467    def json_schema_response_format(self) -> dict[str, Any]:
468        output_schema = self.task.output_schema()
469        if output_schema is None:
470            raise ValueError(
471                "Invalid output schema for this task. Cannot use JSON schema response format."
472            )
473        output_schema = close_object_schemas(output_schema, strict=True)
474        return {
475            "response_format": {
476                "type": "json_schema",
477                "json_schema": {
478                    "name": "task_response",
479                    "schema": output_schema,
480                },
481            }
482        }
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
484    def tool_call_params(self, strict: bool) -> dict[str, Any]:
485        # Add additional_properties: false to the schema (OpenAI requires this for some models)
486        output_schema = self.task.output_schema()
487        if not isinstance(output_schema, dict):
488            raise ValueError(
489                "Invalid output schema for this task. Cannot use tool calls."
490            )
491        output_schema = close_object_schemas(output_schema, strict=strict)
492
493        function_params = {
494            "name": "task_response",
495            "parameters": output_schema,
496        }
497        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
498        if strict:
499            function_params["strict"] = True
500
501        return {
502            "tools": [
503                {
504                    "type": "function",
505                    "function": function_params,
506                }
507            ],
508            "tool_choice": {
509                "type": "function",
510                "function": {"name": "task_response"},
511            },
512        }
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
514    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
515        # Don't love having this logic here. But it's worth the usability improvement
516        # so better to keep it than exclude it. Should figure out how I want to isolate
517        # this sort of logic so it's config driven and can be overridden
518        extra_body: dict[str, Any] = {}
519        provider_options = {}
520
521        run_config = as_kiln_agent_run_config(self.run_config)
522        # For legacy configs 'thinking_level' is not set; default to the provider's default
523        if "thinking_level" in run_config.model_fields_set:
524            thinking_level = run_config.thinking_level
525        else:
526            thinking_level = provider.default_thinking_level
527
528        # Set the reasoning_effort
529        if thinking_level is not None:
530            # Anthropic models on OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
531            if (
532                provider.name == ModelProviderName.openrouter
533                and provider.openrouter_reasoning_object
534            ):
535                extra_body["reasoning"] = {"effort": thinking_level}
536            else:
537                extra_body["reasoning_effort"] = thinking_level
538
539        if provider.require_openrouter_reasoning:
540            # https://openrouter.ai/docs/use-cases/reasoning-tokens
541            extra_body["reasoning"] = {
542                "exclude": False,
543            }
544
545        if provider.gemini_reasoning_enabled:
546            extra_body["reasoning"] = {
547                "enabled": True,
548            }
549
550        if provider.name == ModelProviderName.openrouter:
551            # Ask OpenRouter to include usage in the response (cost)
552            extra_body["usage"] = {"include": True}
553
554            # Set a default provider order for more deterministic routing.
555            # OpenRouter will ignore providers that don't support the model.
556            # Special cases below (like R1) can override this order.
557            # allow_fallbacks is true by default, but we can override it here.
558            provider_options["order"] = [
559                "fireworks",
560                "parasail",
561                "together",
562                "deepinfra",
563                "novita",
564                "groq",
565                "amazon-bedrock",
566                "azure",
567                "nebius",
568            ]
569
570        if provider.anthropic_extended_thinking:
571            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
572
573        if provider.r1_openrouter_options:
574            # Require providers that support the reasoning parameter
575            provider_options["require_parameters"] = True
576            # Prefer R1 providers with reasonable perf/quants
577            provider_options["order"] = ["fireworks", "together"]
578            # R1 providers with unreasonable quants
579            provider_options["ignore"] = ["deepinfra"]
580
581        # Only set these options if the goal of this request is to get logprobs.
582        if (
583            provider.logprobs_openrouter_options
584            and self.base_adapter_config.top_logprobs is not None
585        ):
586            # Don't let OpenRouter choose a provider that doesn't support logprobs.
587            provider_options["require_parameters"] = True
588            # DeepInfra silently fails to return logprobs consistently.
589            provider_options["ignore"] = ["deepinfra"]
590
591        if provider.openrouter_skip_required_parameters:
592            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
593            provider_options["require_parameters"] = False
594
595        # Siliconflow uses a bool flag for thinking, for some models
596        if provider.siliconflow_enable_thinking is not None:
597            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
598
599        if len(provider_options) > 0:
600            extra_body["provider"] = provider_options
601
602        return extra_body
def litellm_model_id(self) -> str:
604    def litellm_model_id(self) -> str:
605        # The model ID is an interesting combination of format and url endpoint.
606        # It specifies the provider URL/host, but this is overridden if you manually set an API URL
607        if self._litellm_model_id:
608            return self._litellm_model_id
609
610        litellm_provider_info = get_litellm_provider_info(self.model_provider())
611        if litellm_provider_info.is_custom and self._api_base is None:
612            raise ValueError(
613                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
614            )
615
616        self._litellm_model_id = litellm_provider_info.litellm_model_id
617        return self._litellm_model_id
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
654    async def build_completion_kwargs(
655        self,
656        provider: KilnModelProvider,
657        messages: list[ChatCompletionMessageIncludingLiteLLM],
658        top_logprobs: int | None,
659        skip_response_format: bool = False,
660    ) -> dict[str, Any]:
661        run_config = as_kiln_agent_run_config(self.run_config)
662        extra_body = self.build_extra_body(provider)
663
664        # Merge all parameters into a single kwargs dict for litellm
665        completion_kwargs = {
666            "model": self.litellm_model_id(),
667            "messages": messages,
668            "api_base": self._api_base,
669            "headers": self._headers,
670            "temperature": run_config.temperature,
671            "top_p": run_config.top_p,
672            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
673            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
674            # Better to ignore them than to fail the model call.
675            # https://docs.litellm.ai/docs/completion/input
676            "drop_params": True,
677            **extra_body,
678            **self._additional_body_options,
679        }
680
681        if self.base_adapter_config.automatic_prompt_caching:
682            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
683            # handles provider-specific injection. Providers auto-cache matching prefixes,
684            # so marking the last message is sufficient for multi-turn conversations.
685            completion_kwargs["cache_control_injection_points"] = [
686                {"location": "message", "index": -1}
687            ]
688
689        tool_calls = await self.litellm_tools()
690        has_tools = len(tool_calls) > 0
691        if has_tools:
692            completion_kwargs["tools"] = tool_calls
693            completion_kwargs["tool_choice"] = "auto"
694
695        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
696        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
697        if provider.temp_top_p_exclusive:
698            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
699                del completion_kwargs["top_p"]
700            if (
701                "temperature" in completion_kwargs
702                and completion_kwargs["temperature"] == 1.0
703            ):
704                del completion_kwargs["temperature"]
705            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
706                raise ValueError(
707                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
708                )
709
710        if not skip_response_format:
711            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
712            response_format_options = await self.response_format_options()
713
714            # Check for a conflict between tools and response format using tools
715            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
716            if has_tools and "tools" in response_format_options:
717                raise ValueError(
718                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
719                )
720
721            completion_kwargs.update(response_format_options)
722
723        if top_logprobs is not None:
724            completion_kwargs["logprobs"] = True
725            completion_kwargs["top_logprobs"] = top_logprobs
726
727        # any params listed in this list will be passed to the model regardless of LiteLLM's own validation
728        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
729            completion_kwargs
730        )
731        if len(allowed_openai_params) > 0:
732            completion_kwargs["allowed_openai_params"] = allowed_openai_params
733
734        completion_kwargs["messages"] = sanitize_messages_for_provider(messages)
735
736        return completion_kwargs
def usage_from_response( self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.MessageUsage:
738    def usage_from_response(self, response: ModelResponse) -> MessageUsage:
739        litellm_usage = response.get("usage", None)
740
741        # LiteLLM isn't consistent in how it returns the cost.
742        cost = response._hidden_params.get("response_cost", None)
743        if cost is None and litellm_usage:
744            cost = litellm_usage.get("cost", None)
745
746        usage = MessageUsage()
747
748        if not litellm_usage and not cost:
749            return usage
750
751        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
752            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
753            usage.output_tokens = litellm_usage.get("completion_tokens", None)
754            usage.total_tokens = litellm_usage.get("total_tokens", None)
755            prompt_details = litellm_usage.get("prompt_tokens_details", None)
756            if prompt_details and hasattr(prompt_details, "cached_tokens"):
757                usage.cached_tokens = prompt_details.cached_tokens
758            elif prompt_details:
759                logger.warning(
760                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
761                )
762        else:
763            logger.warning(
764                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
765            )
766
767        if isinstance(cost, float):
768            usage.cost = cost
769        elif cost is not None:
770            # None is allowed, but no other types are expected
771            logger.warning(
772                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
773            )
774
775        return usage
async def cached_available_tools(self) -> list[kiln_ai.tools.KilnToolInterface]:
777    async def cached_available_tools(self) -> list[KilnToolInterface]:
778        if self._cached_available_tools is None:
779            self._cached_available_tools = await self.available_tools()
780        return self._cached_available_tools
async def litellm_tools(self) -> list[kiln_ai.tools.base_tool.ToolCallDefinition]:
788    async def litellm_tools(self) -> list[ToolCallDefinition]:
789        available_tools = await self.cached_available_tools()
790
791        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
792        unmanaged = self.base_adapter_config.unmanaged_tools
793        unmanaged_defs = (
794            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
795        )
796
797        merged = registry_defs + unmanaged_defs
798        seen_names: set[str] = set()
799        for d in merged:
800            name = d["function"]["name"]
801            if name in seen_names:
802                raise ValueError(
803                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
804                )
805            seen_names.add(name)
806
807        return merged
async def process_tool_calls( self, tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None) -> tuple[str | None, list[kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper]]:
809    async def process_tool_calls(
810        self, tool_calls: list[ChatCompletionMessageToolCall] | None
811    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
812        if tool_calls is None:
813            return None, []
814
815        assistant_output_from_toolcall: str | None = None
816        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
817        tool_run_coroutines = []
818
819        for tool_call in tool_calls:
820            # Kiln "task_response" tool is used for returning structured output via tool calls.
821            # Load the output from the tool call.
822            if tool_call.function.name == "task_response":
823                assistant_output_from_toolcall = tool_call.function.arguments
824                continue
825
826            # Process normal tool calls (not the "task_response" tool)
827            tool_name = tool_call.function.name
828            tool = None
829            for tool_option in await self._tools_for_execution():
830                if await tool_option.name() == tool_name:
831                    tool = tool_option
832                    break
833            if not tool:
834                raise RuntimeError(
835                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
836                )
837
838            # Parse the arguments and validate them against the tool's schema
839            try:
840                parsed_args = json.loads(tool_call.function.arguments)
841            except json.JSONDecodeError:
842                raise RuntimeError(
843                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
844                )
845            try:
846                tool_call_definition = await tool.toolcall_definition()
847                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
848                validate_schema_with_value_error(parsed_args, json_schema)
849            except Exception as e:
850                raise RuntimeError(
851                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
852                ) from e
853
854            # Create context with the calling task's allow_saving setting
855            context = ToolCallContext(
856                allow_saving=self.base_adapter_config.allow_saving
857            )
858
859            async def run_tool_and_format(
860                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
861            ):
862                result = await t.run(c, **args)
863                return ChatCompletionToolMessageParamWrapper(
864                    role="tool",
865                    tool_call_id=tc_id,
866                    content=result.output,
867                    kiln_task_tool_data=result.kiln_task_tool_data
868                    if isinstance(result, KilnTaskToolResult)
869                    else None,
870                    is_error=result.is_error if result.is_error else None,
871                    error_message=result.error_message
872                    if result.error_message
873                    else None,
874                )
875
876            tool_run_coroutines.append(run_tool_and_format())
877
878        if tool_run_coroutines:
879            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)
880
881        if (
882            assistant_output_from_toolcall is not None
883            and len(tool_call_response_messages) > 0
884        ):
885            raise RuntimeError(
886                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
887            )
888
889        return assistant_output_from_toolcall, tool_call_response_messages
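
As a usage sketch (not from the source), assuming `adapter` is an already-constructed LiteLlmAdapter, the two return shapes look roughly like this. The tool-call objects use litellm's ChatCompletionMessageToolCall/Function types; the id and arguments are made up:

    from litellm.types.utils import ChatCompletionMessageToolCall, Function

    # Case 1: the model "answers" via the task_response tool. No tool is executed; the
    # arguments string becomes the assistant output and no tool messages are produced.
    task_response_call = ChatCompletionMessageToolCall(
        id="call_1",
        function=Function(name="task_response", arguments='{"answer": "42"}'),
    )
    # output, tool_msgs = await adapter.process_tool_calls([task_response_call])
    # output == '{"answer": "42"}'; tool_msgs == []

    # Case 2: a normal tool call. The named tool is looked up, its arguments are JSON-parsed
    # and schema-validated, the tool runs, and one role="tool" message is returned per call,
    # ready to be appended to the conversation for the next model turn. output stays None.
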
def litellm_message_to_trace_message(
    self,
    raw_message: litellm.types.utils.Message,
    latency_ms: int | None = None,
    usage: kiln_ai.datamodel.MessageUsage | None = None,
) -> kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper:
891    def litellm_message_to_trace_message(
892        self,
893        raw_message: LiteLLMMessage,
894        latency_ms: int | None = None,
895        usage: MessageUsage | None = None,
896    ) -> ChatCompletionAssistantMessageParamWrapper:
897        """
898        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
899        """
900        message: ChatCompletionAssistantMessageParamWrapper = {
901            "role": "assistant",
902        }
903        if raw_message.role != "assistant":
904            raise ValueError(
905                "Model returned a message with a role other than assistant. This is not supported."
906            )
907
908        if hasattr(raw_message, "content"):
909            message["content"] = raw_message.content
910        if hasattr(raw_message, "reasoning_content"):
911            message["reasoning_content"] = raw_message.reasoning_content
912        if hasattr(raw_message, "tool_calls"):
913            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
914            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
915            for litellm_tool_call in raw_message.tool_calls or []:
916                # Optional in the SDK for streaming responses, but should never be None at this point.
917                if litellm_tool_call.function.name is None:
918                    raise ValueError(
919                        "The model requested a tool call, without providing a function name (required)."
920                    )
921                open_ai_tool_calls.append(
922                    ChatCompletionMessageToolCallParam(
923                        id=litellm_tool_call.id,
924                        type="function",
925                        function={
926                            "name": litellm_tool_call.function.name,
927                            "arguments": litellm_tool_call.function.arguments,
928                        },
929                    )
930                )
931            if len(open_ai_tool_calls) > 0:
932                message["tool_calls"] = open_ai_tool_calls
933
934        if latency_ms is not None:
935            message["latency_ms"] = latency_ms
936
937        if usage is not None:
938            message["usage"] = usage
939
940        if not message.get("content") and not message.get("tool_calls"):
941            raise ValueError(
942                "Model returned an assistant message, but no content or tool calls. This is not supported."
943            )
944
945        return message

Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper

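A small sketch of the conversion (not from the source), assuming `adapter` is an already-constructed LiteLlmAdapter; the content and latency values are made up:

    from litellm.types.utils import Message as LiteLLMMessage

    raw = LiteLLMMessage(role="assistant", content="Paris is the capital of France.")
    # trace_msg = adapter.litellm_message_to_trace_message(raw, latency_ms=840)
    # trace_msg is a plain dict (ChatCompletionAssistantMessageParamWrapper), roughly:
    #   {"role": "assistant", "content": "Paris is the capital of France.", "latency_ms": 840, ...}
    # reasoning_content and tool_calls are copied over when the raw message carries them,
    # and a message with neither content nor tool calls raises ValueError.
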
def all_messages_to_trace(
    self,
    messages: list[typing.Union[
        openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam,
        openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam,
        openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam,
        kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper,
        kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper,
        openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam,
        litellm.types.utils.Message,
    ]],
    message_latency: dict[int, int] | None = None,
    message_usage: dict[int, kiln_ai.datamodel.MessageUsage] | None = None,
) -> list[typing.Union[
    openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam,
    openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam,
    openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam,
    kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper,
    kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper,
    openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam,
]]:
947    def all_messages_to_trace(
948        self,
949        messages: list[ChatCompletionMessageIncludingLiteLLM],
950        message_latency: dict[int, int] | None = None,
951        message_usage: dict[int, MessageUsage] | None = None,
952    ) -> list[ChatCompletionMessageParam]:
953        """
954        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
955
956        Non-LiteLLM dict messages pass through unchanged. Any per-message
957        ``usage``/``latency_ms`` already attached to those dicts (e.g. from a
958        seeded prior trace) is preserved.
959        """
960        trace: list[ChatCompletionMessageParam] = []
961        for i, message in enumerate(messages):
962            if isinstance(message, LiteLLMMessage):
963                latency_ms = message_latency.get(i) if message_latency else None
964                usage = message_usage.get(i) if message_usage else None
965                trace.append(
966                    self.litellm_message_to_trace_message(message, latency_ms, usage)
967                )
968            else:
969                trace.append(message)
970        return trace

Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.

Non-LiteLLM dict messages pass through unchanged. Any per-message usage/latency_ms already attached to those dicts (e.g. from a seeded prior trace) is preserved.
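
A short sketch of the trace conversion (not from the source), assuming `adapter` is an already-constructed LiteLlmAdapter; the message contents are made up:

    from litellm.types.utils import Message as LiteLLMMessage

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},  # passes through unchanged
        {"role": "user", "content": "Say hello."},                      # passes through unchanged
        LiteLLMMessage(role="assistant", content="Hello!"),             # converted to an OpenAI-style dict
    ]
    # Latency/usage are keyed by message index and applied while converting LiteLLM messages:
    # trace = adapter.all_messages_to_trace(messages, message_latency={2: 420})
    # trace[2] -> {"role": "assistant", "content": "Hello!", "latency_ms": 420}
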