kiln_ai.adapters.model_adapters.litellm_adapter

  1import copy
  2import json
  3import logging
  4from dataclasses import dataclass
  5from typing import Any, Dict, List, Tuple, TypeAlias, Union
  6
  7import litellm
  8from litellm.types.utils import (
  9    ChatCompletionMessageToolCall,
 10    ChoiceLogprobs,
 11    Choices,
 12    ModelResponse,
 13)
 14from litellm.types.utils import Message as LiteLLMMessage
 15from litellm.types.utils import Usage as LiteLlmUsage
 16from openai.types.chat.chat_completion_message_tool_call_param import (
 17    ChatCompletionMessageToolCallParam,
 18)
 19
 20import kiln_ai.datamodel as datamodel
 21from kiln_ai.adapters.ml_model_list import (
 22    KilnModelProvider,
 23    ModelProviderName,
 24    StructuredOutputMode,
 25)
 26from kiln_ai.adapters.model_adapters.base_adapter import (
 27    AdapterConfig,
 28    BaseAdapter,
 29    RunOutput,
 30    Usage,
 31)
 32from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 33from kiln_ai.datamodel.datamodel_enums import InputType
 34from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 35from kiln_ai.tools.base_tool import (
 36    KilnToolInterface,
 37    ToolCallContext,
 38    ToolCallDefinition,
 39)
 40from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 41from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 42from kiln_ai.utils.litellm import get_litellm_provider_info
 43from kiln_ai.utils.open_ai_types import (
 44    ChatCompletionAssistantMessageParamWrapper,
 45    ChatCompletionMessageParam,
 46    ChatCompletionToolMessageParamWrapper,
 47)
 48
 49MAX_CALLS_PER_TURN = 10
 50MAX_TOOL_CALLS_PER_TURN = 30
 51
 52logger = logging.getLogger(__name__)
 53
 54ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
 55    ChatCompletionMessageParam, LiteLLMMessage
 56]
 57
 58
 59@dataclass
 60class ModelTurnResult:
 61    assistant_message: str
 62    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
 63    model_response: ModelResponse | None
 64    model_choice: Choices | None
 65    usage: Usage
 66
 67
 68class LiteLlmAdapter(BaseAdapter):
 69    def __init__(
 70        self,
 71        config: LiteLlmConfig,
 72        kiln_task: datamodel.Task,
 73        base_adapter_config: AdapterConfig | None = None,
 74    ):
 75        self.config = config
 76        self._additional_body_options = config.additional_body_options
 77        self._api_base = config.base_url
 78        self._headers = config.default_headers
 79        self._litellm_model_id: str | None = None
 80        self._cached_available_tools: list[KilnToolInterface] | None = None
 81
 82        super().__init__(
 83            task=kiln_task,
 84            run_config=config.run_config_properties,
 85            config=base_adapter_config,
 86        )
 87
 88    async def _run_model_turn(
 89        self,
 90        provider: KilnModelProvider,
 91        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
 92        top_logprobs: int | None,
 93        skip_response_format: bool,
 94    ) -> ModelTurnResult:
 95        """
 96        Call the model for a single top level turn: from user message to agent message.
 97
 98        It may make handle iterations of tool calls between the user/agent message if needed.
 99        It may handle multiple iterations of tool calls between the user and agent messages if needed.
100
101        usage = Usage()
102        messages = list(prior_messages)
103        tool_calls_count = 0
104
105        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
106            # Build kwargs for the completion call (including any tools)
107            completion_kwargs = await self.build_completion_kwargs(
108                provider,
109                # Pass a copy, as acompletion mutates objects and breaks types.
110                copy.deepcopy(messages),
111                top_logprobs,
112                skip_response_format,
113            )
114
115            # Make the completion call
116            model_response, response_choice = await self.acompletion_checking_response(
117                **completion_kwargs
118            )
119
120            # count the usage
121            usage += self.usage_from_response(model_response)
122
123            # Extract content and tool calls
124            if not hasattr(response_choice, "message"):
125                raise ValueError("Response choice has no message")
126            content = response_choice.message.content
127            tool_calls = response_choice.message.tool_calls
128            if not content and not tool_calls:
129                raise ValueError(
130                    "Model returned an assistant message, but no content or tool calls. This is not supported."
131                )
132
133            # Add message to messages, so it can be used in the next turn
134            messages.append(response_choice.message)
135
136            # Process tool calls if any
137            if tool_calls and len(tool_calls) > 0:
138                (
139                    assistant_message_from_toolcall,
140                    tool_call_messages,
141                ) = await self.process_tool_calls(tool_calls)
142
143                # Add tool call results to messages
144                messages.extend(tool_call_messages)
145
146                # If task_response tool was called, we're done
147                if assistant_message_from_toolcall is not None:
148                    return ModelTurnResult(
149                        assistant_message=assistant_message_from_toolcall,
150                        all_messages=messages,
151                        model_response=model_response,
152                        model_choice=response_choice,
153                        usage=usage,
154                    )
155
156                # If there were tool calls, increment counter and continue
157                if tool_call_messages:
158                    tool_calls_count += 1
159                    continue
160
161            # If no tool calls, return the content as final output
162            if content:
163                return ModelTurnResult(
164                    assistant_message=content,
165                    all_messages=messages,
166                    model_response=model_response,
167                    model_choice=response_choice,
168                    usage=usage,
169                )
170
171            # If we get here with no content and no tool calls, raise
172            raise RuntimeError(
173                "Model returned neither content nor tool calls. It must return at least one of these."
174            )
175
176        raise RuntimeError(
177            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
178        )
179
180    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
181        usage = Usage()
182
183        provider = self.model_provider()
184        if not provider.model_id:
185            raise ValueError("Model ID is required for OpenAI compatible models")
186
187        chat_formatter = self.build_chat_formatter(input)
188        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
189
190        prior_output: str | None = None
191        final_choice: Choices | None = None
192        turns = 0
193
194        while True:
195            turns += 1
196            if turns > MAX_CALLS_PER_TURN:
197                raise RuntimeError(
198                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
199                )
200
201            turn = chat_formatter.next_turn(prior_output)
202            if turn is None:
203                # No next turn, we're done
204                break
205
206            # Add messages from the turn to chat history
207            for message in turn.messages:
208                if message.content is None:
209                    raise ValueError("Empty message content isn't allowed")
210                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
211                messages.append({"role": message.role, "content": message.content})  # type: ignore
212
213            skip_response_format = not turn.final_call
214            turn_result = await self._run_model_turn(
215                provider,
216                messages,
217                self.base_adapter_config.top_logprobs if turn.final_call else None,
218                skip_response_format,
219            )
220
221            usage += turn_result.usage
222
223            prior_output = turn_result.assistant_message
224            messages = turn_result.all_messages
225            final_choice = turn_result.model_choice
226
227            if not prior_output:
228                raise RuntimeError("No assistant message/output returned from model")
229
230        logprobs = self._extract_and_validate_logprobs(final_choice)
231
232        # Save CoT/reasoning if it exists. It may be in a message, or may be parsed by LiteLLM (or OpenRouter, or anything else upstream)
233        intermediate_outputs = chat_formatter.intermediate_outputs()
234        self._extract_reasoning_to_intermediate_outputs(
235            final_choice, intermediate_outputs
236        )
237
238        if not isinstance(prior_output, str):
239            raise RuntimeError(f"assistant message is not a string: {prior_output}")
240
241        trace = self.all_messages_to_trace(messages)
242        output = RunOutput(
243            output=prior_output,
244            intermediate_outputs=intermediate_outputs,
245            output_logprobs=logprobs,
246            trace=trace,
247        )
248
249        return output, usage
250
251    def _extract_and_validate_logprobs(
252        self, final_choice: Choices | None
253    ) -> ChoiceLogprobs | None:
254        """
255        Extract logprobs from the final choice and validate they exist if required.
256        """
257        logprobs = None
258        if (
259            final_choice is not None
260            and hasattr(final_choice, "logprobs")
261            and isinstance(final_choice.logprobs, ChoiceLogprobs)
262        ):
263            logprobs = final_choice.logprobs
264
265        # Check logprobs worked, if required
266        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
267            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
268
269        return logprobs
270
271    def _extract_reasoning_to_intermediate_outputs(
272        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
273    ) -> None:
274        """Extract reasoning content from model choice and add to intermediate outputs if present."""
275        if (
276            final_choice is not None
277            and hasattr(final_choice, "message")
278            and hasattr(final_choice.message, "reasoning_content")
279        ):
280            reasoning_content = final_choice.message.reasoning_content
281            if reasoning_content is not None:
282                stripped_reasoning_content = reasoning_content.strip()
283                if len(stripped_reasoning_content) > 0:
284                    intermediate_outputs["reasoning"] = stripped_reasoning_content
285
286    async def acompletion_checking_response(
287        self, **kwargs
288    ) -> Tuple[ModelResponse, Choices]:
289        response = await litellm.acompletion(**kwargs)
290        if (
291            not isinstance(response, ModelResponse)
292            or not response.choices
293            or len(response.choices) == 0
294            or not isinstance(response.choices[0], Choices)
295        ):
296            raise RuntimeError(
297                f"Expected ModelResponse with Choices, got {type(response)}."
298            )
299        return response, response.choices[0]
300
301    def adapter_name(self) -> str:
302        return "kiln_openai_compatible_adapter"
303
304    async def response_format_options(self) -> dict[str, Any]:
305        # Unstructured if task isn't structured
306        if not self.has_structured_output():
307            return {}
308
309        structured_output_mode = self.run_config.structured_output_mode
310
311        match structured_output_mode:
312            case StructuredOutputMode.json_mode:
313                return {"response_format": {"type": "json_object"}}
314            case StructuredOutputMode.json_schema:
315                return self.json_schema_response_format()
316            case StructuredOutputMode.function_calling_weak:
317                return self.tool_call_params(strict=False)
318            case StructuredOutputMode.function_calling:
319                return self.tool_call_params(strict=True)
320            case StructuredOutputMode.json_instructions:
321                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
322                return {}
323            case StructuredOutputMode.json_custom_instructions:
324                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
325                return {}
326            case StructuredOutputMode.json_instruction_and_object:
327                # We set response_format to json_object and also set json instructions in the prompt
328                return {"response_format": {"type": "json_object"}}
329            case StructuredOutputMode.default:
330                provider_name = self.run_config.model_provider_name
331                if provider_name == ModelProviderName.ollama:
332                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
333                    return self.json_schema_response_format()
334                elif provider_name == ModelProviderName.docker_model_runner:
335                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
336                    return self.json_schema_response_format()
337                else:
338                    # Default to function calling -- it's older than the other modes. Higher compatibility.
339                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
340                    strict = provider_name == ModelProviderName.openai
341                    return self.tool_call_params(strict=strict)
342            case StructuredOutputMode.unknown:
343                # See above, but this case should never happen.
344                raise ValueError("Structured output mode is unknown.")
345            case _:
346                raise_exhaustive_enum_error(structured_output_mode)
347
348    def json_schema_response_format(self) -> dict[str, Any]:
349        output_schema = self.task.output_schema()
350        return {
351            "response_format": {
352                "type": "json_schema",
353                "json_schema": {
354                    "name": "task_response",
355                    "schema": output_schema,
356                },
357            }
358        }
359
360    def tool_call_params(self, strict: bool) -> dict[str, Any]:
361        # Add additional_properties: false to the schema (OpenAI requires this for some models)
362        output_schema = self.task.output_schema()
363        if not isinstance(output_schema, dict):
364            raise ValueError(
365                "Invalid output schema for this task. Can not use tool calls."
366            )
367        output_schema["additionalProperties"] = False
368
369        function_params = {
370            "name": "task_response",
371            "parameters": output_schema,
372        }
373        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
374        if strict:
375            function_params["strict"] = True
376
377        return {
378            "tools": [
379                {
380                    "type": "function",
381                    "function": function_params,
382                }
383            ],
384            "tool_choice": {
385                "type": "function",
386                "function": {"name": "task_response"},
387            },
388        }
389
390    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
391        # Don't love having this logic here. But it's worth the usability improvement
392        # so better to keep it than exclude it. Should figure out how I want to isolate
393        # this sort of logic so it's config driven and can be overridden
394
395        extra_body = {}
396        provider_options = {}
397
398        if provider.thinking_level is not None:
399            extra_body["reasoning_effort"] = provider.thinking_level
400
401        if provider.require_openrouter_reasoning:
402            # https://openrouter.ai/docs/use-cases/reasoning-tokens
403            extra_body["reasoning"] = {
404                "exclude": False,
405            }
406
407        if provider.gemini_reasoning_enabled:
408            extra_body["reasoning"] = {
409                "enabled": True,
410            }
411
412        if provider.name == ModelProviderName.openrouter:
413            # Ask OpenRouter to include usage in the response (cost)
414            extra_body["usage"] = {"include": True}
415
416            # Set a default provider order for more deterministic routing.
417            # OpenRouter will ignore providers that don't support the model.
418            # Special cases below (like R1) can override this order.
419            # allow_fallbacks is true by default, but we can override it here.
420            provider_options["order"] = [
421                "fireworks",
422                "parasail",
423                "together",
424                "deepinfra",
425                "novita",
426                "groq",
427                "amazon-bedrock",
428                "azure",
429                "nebius",
430            ]
431
432        if provider.anthropic_extended_thinking:
433            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
434
435        if provider.r1_openrouter_options:
436            # Require providers that support the reasoning parameter
437            provider_options["require_parameters"] = True
438            # Prefer R1 providers with reasonable perf/quants
439            provider_options["order"] = ["fireworks", "together"]
440            # R1 providers with unreasonable quants
441            provider_options["ignore"] = ["deepinfra"]
442
443        # Only set this if the goal of the request is to get logprobs.
444        if (
445            provider.logprobs_openrouter_options
446            and self.base_adapter_config.top_logprobs is not None
447        ):
448            # Don't let OpenRouter choose a provider that doesn't support logprobs.
449            provider_options["require_parameters"] = True
450            # DeepInfra silently fails to return logprobs consistently.
451            provider_options["ignore"] = ["deepinfra"]
452
453        if provider.openrouter_skip_required_parameters:
454            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
455            provider_options["require_parameters"] = False
456
457        # Siliconflow uses a bool flag for thinking, for some models
458        if provider.siliconflow_enable_thinking is not None:
459            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
460
461        if len(provider_options) > 0:
462            extra_body["provider"] = provider_options
463
464        return extra_body
465
466    def litellm_model_id(self) -> str:
467        # The model ID is an interesting combination of format and URL endpoint.
468        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
469        if self._litellm_model_id:
470            return self._litellm_model_id
471
472        litellm_provider_info = get_litellm_provider_info(self.model_provider())
473        if litellm_provider_info.is_custom and self._api_base is None:
474            raise ValueError(
475                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
476            )
477
478        self._litellm_model_id = litellm_provider_info.litellm_model_id
479        return self._litellm_model_id
480
481    async def build_completion_kwargs(
482        self,
483        provider: KilnModelProvider,
484        messages: list[ChatCompletionMessageIncludingLiteLLM],
485        top_logprobs: int | None,
486        skip_response_format: bool = False,
487    ) -> dict[str, Any]:
488        extra_body = self.build_extra_body(provider)
489
490        # Merge all parameters into a single kwargs dict for litellm
491        completion_kwargs = {
492            "model": self.litellm_model_id(),
493            "messages": messages,
494            "api_base": self._api_base,
495            "headers": self._headers,
496            "temperature": self.run_config.temperature,
497            "top_p": self.run_config.top_p,
498            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
499            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
500            # Better to ignore them than to fail the model call.
501            # https://docs.litellm.ai/docs/completion/input
502            "drop_params": True,
503            **extra_body,
504            **self._additional_body_options,
505        }
506
507        tool_calls = await self.litellm_tools()
508        has_tools = len(tool_calls) > 0
509        if has_tools:
510            completion_kwargs["tools"] = tool_calls
511            completion_kwargs["tool_choice"] = "auto"
512
513        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
514        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
515        if provider.temp_top_p_exclusive:
516            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
517                del completion_kwargs["top_p"]
518            if (
519                "temperature" in completion_kwargs
520                and completion_kwargs["temperature"] == 1.0
521            ):
522                del completion_kwargs["temperature"]
523            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
524                raise ValueError(
525                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
526                )
527
528        if not skip_response_format:
529            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
530            response_format_options = await self.response_format_options()
531
532            # Check for a conflict between tools and response format using tools
533            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good at tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
534            if has_tools and "tools" in response_format_options:
535                raise ValueError(
536                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
537                )
538
539            completion_kwargs.update(response_format_options)
540
541        if top_logprobs is not None:
542            completion_kwargs["logprobs"] = True
543            completion_kwargs["top_logprobs"] = top_logprobs
544
545        return completion_kwargs
546
547    def usage_from_response(self, response: ModelResponse) -> Usage:
548        litellm_usage = response.get("usage", None)
549
550        # LiteLLM isn't consistent in how it returns the cost.
551        cost = response._hidden_params.get("response_cost", None)
552        if cost is None and litellm_usage:
553            cost = litellm_usage.get("cost", None)
554
555        usage = Usage()
556
557        if not litellm_usage and not cost:
558            return usage
559
560        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
561            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
562            usage.output_tokens = litellm_usage.get("completion_tokens", None)
563            usage.total_tokens = litellm_usage.get("total_tokens", None)
564        else:
565            logger.warning(
566                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
567            )
568
569        if isinstance(cost, float):
570            usage.cost = cost
571        elif cost is not None:
572            # None is allowed, but no other types are expected
573            logger.warning(
574                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
575            )
576
577        return usage
578
579    async def cached_available_tools(self) -> list[KilnToolInterface]:
580        if self._cached_available_tools is None:
581            self._cached_available_tools = await self.available_tools()
582        return self._cached_available_tools
583
584    async def litellm_tools(self) -> list[ToolCallDefinition]:
585        available_tools = await self.cached_available_tools()
586
587        # LiteLLM takes the standard OpenAI-compatible tool call format
588        return [await tool.toolcall_definition() for tool in available_tools]
589
590    async def process_tool_calls(
591        self, tool_calls: list[ChatCompletionMessageToolCall] | None
592    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
593        if tool_calls is None:
594            return None, []
595
596        assistant_output_from_toolcall: str | None = None
597        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
598
599        for tool_call in tool_calls:
600            # Kiln "task_response" tool is used for returning structured output via tool calls.
601            # Load the output from the tool call.
602            if tool_call.function.name == "task_response":
603                assistant_output_from_toolcall = tool_call.function.arguments
604                continue
605
606            # Process normal tool calls (not the "task_response" tool)
607            tool_name = tool_call.function.name
608            tool = None
609            for tool_option in await self.cached_available_tools():
610                if await tool_option.name() == tool_name:
611                    tool = tool_option
612                    break
613            if not tool:
614                raise RuntimeError(
615                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
616                )
617
618            # Parse the arguments and validate them against the tool's schema
619            try:
620                parsed_args = json.loads(tool_call.function.arguments)
621            except json.JSONDecodeError:
622                raise RuntimeError(
623                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
624                )
625            try:
626                tool_call_definition = await tool.toolcall_definition()
627                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
628                validate_schema_with_value_error(parsed_args, json_schema)
629            except Exception as e:
630                raise RuntimeError(
631                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
632                ) from e
633
634            # Create context with the calling task's allow_saving setting
635            context = ToolCallContext(
636                allow_saving=self.base_adapter_config.allow_saving
637            )
638            result = await tool.run(context, **parsed_args)
639
640            tool_call_response_messages.append(
641                ChatCompletionToolMessageParamWrapper(
642                    role="tool",
643                    tool_call_id=tool_call.id,
644                    content=result.output,
645                    kiln_task_tool_data=result.kiln_task_tool_data
646                    if isinstance(result, KilnTaskToolResult)
647                    else None,
648                )
649            )
650
651        if (
652            assistant_output_from_toolcall is not None
653            and len(tool_call_response_messages) > 0
654        ):
655            raise RuntimeError(
656                "Model asked for an impossible combination: a task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
657            )
658
659        return assistant_output_from_toolcall, tool_call_response_messages
660
661    def litellm_message_to_trace_message(
662        self, raw_message: LiteLLMMessage
663    ) -> ChatCompletionAssistantMessageParamWrapper:
664        """
665        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
666        """
667        message: ChatCompletionAssistantMessageParamWrapper = {
668            "role": "assistant",
669        }
670        if raw_message.role != "assistant":
671            raise ValueError(
672                "Model returned a message with a role other than assistant. This is not supported."
673            )
674
675        if hasattr(raw_message, "content"):
676            message["content"] = raw_message.content
677        if hasattr(raw_message, "reasoning_content"):
678            message["reasoning_content"] = raw_message.reasoning_content
679        if hasattr(raw_message, "tool_calls"):
680            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
681            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
682            for litellm_tool_call in raw_message.tool_calls or []:
683                # Optional in the SDK for streaming responses, but should never be None at this point.
684                if litellm_tool_call.function.name is None:
685                    raise ValueError(
686                        "The model requested a tool call, without providing a function name (required)."
687                    )
688                open_ai_tool_calls.append(
689                    ChatCompletionMessageToolCallParam(
690                        id=litellm_tool_call.id,
691                        type="function",
692                        function={
693                            "name": litellm_tool_call.function.name,
694                            "arguments": litellm_tool_call.function.arguments,
695                        },
696                    )
697                )
698            if len(open_ai_tool_calls) > 0:
699                message["tool_calls"] = open_ai_tool_calls
700
701        if not message.get("content") and not message.get("tool_calls"):
702            raise ValueError(
703                "Model returned an assistant message, but no content or tool calls. This is not supported."
704            )
705
706        return message
707
708    def all_messages_to_trace(
709        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
710    ) -> list[ChatCompletionMessageParam]:
711        """
712        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
713        """
714        trace: list[ChatCompletionMessageParam] = []
715        for message in messages:
716            if isinstance(message, LiteLLMMessage):
717                trace.append(self.litellm_message_to_trace_message(message))
718            else:
719                trace.append(message)
720        return trace
MAX_CALLS_PER_TURN = 10

MAX_TOOL_CALLS_PER_TURN = 30

ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ChatCompletionMessageParam, LiteLLMMessage]

ModelTurnResult(assistant_message: str, all_messages: list[ChatCompletionMessageIncludingLiteLLM], model_response: ModelResponse | None, model_choice: Choices | None, usage: Usage)

LiteLlmAdapter(config: LiteLlmConfig, kiln_task: datamodel.Task, base_adapter_config: AdapterConfig | None = None)
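
For orientation, here is a minimal, self-contained sketch of the structured-output technique the adapter assembles in tool_call_params() and sends through litellm.acompletion(): forcing a "task_response" function call whose arguments must match the task's output schema. The model id, schema, and prompt below are illustrative placeholders (and API credentials are assumed to be configured in the environment); this is a sketch of the underlying pattern, not part of the Kiln module above.

import asyncio
import json

import litellm

# Example schema standing in for task.output_schema(); additionalProperties is set to
# False, mirroring tool_call_params() above.
OUTPUT_SCHEMA = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    "additionalProperties": False,
}


async def main() -> None:
    # Force a "task_response" tool call, the same shape tool_call_params(strict=True) returns.
    response = await litellm.acompletion(
        model="openai/gpt-4o-mini",  # illustrative model id, not one chosen by Kiln
        messages=[{"role": "user", "content": "Answer briefly: what is the capital of France?"}],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "task_response",
                    "parameters": OUTPUT_SCHEMA,
                    "strict": True,
                },
            }
        ],
        tool_choice={"type": "function", "function": {"name": "task_response"}},
        drop_params=True,  # ignore unsupported OpenAI params instead of failing, as above
    )

    # The structured output arrives as the forced tool call's JSON arguments.
    tool_call = response.choices[0].message.tool_calls[0]
    print(json.loads(tool_call.function.arguments))


if __name__ == "__main__":
    asyncio.run(main())

The json_schema and json_object branches of response_format_options() follow the same pattern, except they pass a response_format argument instead of tools/tool_choice.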
 69class LiteLlmAdapter(BaseAdapter):
 70    def __init__(
 71        self,
 72        config: LiteLlmConfig,
 73        kiln_task: datamodel.Task,
 74        base_adapter_config: AdapterConfig | None = None,
 75    ):
 76        self.config = config
 77        self._additional_body_options = config.additional_body_options
 78        self._api_base = config.base_url
 79        self._headers = config.default_headers
 80        self._litellm_model_id: str | None = None
 81        self._cached_available_tools: list[KilnToolInterface] | None = None
 82
 83        super().__init__(
 84            task=kiln_task,
 85            run_config=config.run_config_properties,
 86            config=base_adapter_config,
 87        )
 88
 89    async def _run_model_turn(
 90        self,
 91        provider: KilnModelProvider,
 92        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
 93        top_logprobs: int | None,
 94        skip_response_format: bool,
 95    ) -> ModelTurnResult:
 96        """
 97        Call the model for a single top level turn: from user message to agent message.
 98
 99        It may make handle iterations of tool calls between the user/agent message if needed.
100        """
101
102        usage = Usage()
103        messages = list(prior_messages)
104        tool_calls_count = 0
105
106        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
107            # Build completion kwargs for tool calls
108            completion_kwargs = await self.build_completion_kwargs(
109                provider,
110                # Pass a copy, as acompletion mutates objects and breaks types.
111                copy.deepcopy(messages),
112                top_logprobs,
113                skip_response_format,
114            )
115
116            # Make the completion call
117            model_response, response_choice = await self.acompletion_checking_response(
118                **completion_kwargs
119            )
120
121            # count the usage
122            usage += self.usage_from_response(model_response)
123
124            # Extract content and tool calls
125            if not hasattr(response_choice, "message"):
126                raise ValueError("Response choice has no message")
127            content = response_choice.message.content
128            tool_calls = response_choice.message.tool_calls
129            if not content and not tool_calls:
130                raise ValueError(
131                    "Model returned an assistant message, but no content or tool calls. This is not supported."
132                )
133
134            # Add message to messages, so it can be used in the next turn
135            messages.append(response_choice.message)
136
137            # Process tool calls if any
138            if tool_calls and len(tool_calls) > 0:
139                (
140                    assistant_message_from_toolcall,
141                    tool_call_messages,
142                ) = await self.process_tool_calls(tool_calls)
143
144                # Add tool call results to messages
145                messages.extend(tool_call_messages)
146
147                # If task_response tool was called, we're done
148                if assistant_message_from_toolcall is not None:
149                    return ModelTurnResult(
150                        assistant_message=assistant_message_from_toolcall,
151                        all_messages=messages,
152                        model_response=model_response,
153                        model_choice=response_choice,
154                        usage=usage,
155                    )
156
157                # If there were tool calls, increment counter and continue
158                if tool_call_messages:
159                    tool_calls_count += 1
160                    continue
161
162            # If no tool calls, return the content as final output
163            if content:
164                return ModelTurnResult(
165                    assistant_message=content,
166                    all_messages=messages,
167                    model_response=model_response,
168                    model_choice=response_choice,
169                    usage=usage,
170                )
171
172            # If we get here with no content and no tool calls, break
173            raise RuntimeError(
174                "Model returned neither content nor tool calls. It must return at least one of these."
175            )
176
177        raise RuntimeError(
178            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
179        )
180
181    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
182        usage = Usage()
183
184        provider = self.model_provider()
185        if not provider.model_id:
186            raise ValueError("Model ID is required for OpenAI compatible models")
187
188        chat_formatter = self.build_chat_formatter(input)
189        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
190
191        prior_output: str | None = None
192        final_choice: Choices | None = None
193        turns = 0
194
195        while True:
196            turns += 1
197            if turns > MAX_CALLS_PER_TURN:
198                raise RuntimeError(
199                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
200                )
201
202            turn = chat_formatter.next_turn(prior_output)
203            if turn is None:
204                # No next turn, we're done
205                break
206
207            # Add messages from the turn to chat history
208            for message in turn.messages:
209                if message.content is None:
210                    raise ValueError("Empty message content isn't allowed")
211                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
212                messages.append({"role": message.role, "content": message.content})  # type: ignore
213
214            skip_response_format = not turn.final_call
215            turn_result = await self._run_model_turn(
216                provider,
217                messages,
218                self.base_adapter_config.top_logprobs if turn.final_call else None,
219                skip_response_format,
220            )
221
222            usage += turn_result.usage
223
224            prior_output = turn_result.assistant_message
225            messages = turn_result.all_messages
226            final_choice = turn_result.model_choice
227
228            if not prior_output:
229                raise RuntimeError("No assistant message/output returned from model")
230
231        logprobs = self._extract_and_validate_logprobs(final_choice)
232
233        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
234        intermediate_outputs = chat_formatter.intermediate_outputs()
235        self._extract_reasoning_to_intermediate_outputs(
236            final_choice, intermediate_outputs
237        )
238
239        if not isinstance(prior_output, str):
240            raise RuntimeError(f"assistant message is not a string: {prior_output}")
241
242        trace = self.all_messages_to_trace(messages)
243        output = RunOutput(
244            output=prior_output,
245            intermediate_outputs=intermediate_outputs,
246            output_logprobs=logprobs,
247            trace=trace,
248        )
249
250        return output, usage
251
252    def _extract_and_validate_logprobs(
253        self, final_choice: Choices | None
254    ) -> ChoiceLogprobs | None:
255        """
256        Extract logprobs from the final choice and validate they exist if required.
257        """
258        logprobs = None
259        if (
260            final_choice is not None
261            and hasattr(final_choice, "logprobs")
262            and isinstance(final_choice.logprobs, ChoiceLogprobs)
263        ):
264            logprobs = final_choice.logprobs
265
266        # Check logprobs worked, if required
267        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
268            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
269
270        return logprobs
271
272    def _extract_reasoning_to_intermediate_outputs(
273        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
274    ) -> None:
275        """Extract reasoning content from model choice and add to intermediate outputs if present."""
276        if (
277            final_choice is not None
278            and hasattr(final_choice, "message")
279            and hasattr(final_choice.message, "reasoning_content")
280        ):
281            reasoning_content = final_choice.message.reasoning_content
282            if reasoning_content is not None:
283                stripped_reasoning_content = reasoning_content.strip()
284                if len(stripped_reasoning_content) > 0:
285                    intermediate_outputs["reasoning"] = stripped_reasoning_content
286
287    async def acompletion_checking_response(
288        self, **kwargs
289    ) -> Tuple[ModelResponse, Choices]:
290        response = await litellm.acompletion(**kwargs)
291        if (
292            not isinstance(response, ModelResponse)
293            or not response.choices
294            or len(response.choices) == 0
295            or not isinstance(response.choices[0], Choices)
296        ):
297            raise RuntimeError(
298                f"Expected ModelResponse with Choices, got {type(response)}."
299            )
300        return response, response.choices[0]
301
302    def adapter_name(self) -> str:
303        return "kiln_openai_compatible_adapter"
304
305    async def response_format_options(self) -> dict[str, Any]:
306        # Unstructured if task isn't structured
307        if not self.has_structured_output():
308            return {}
309
310        structured_output_mode = self.run_config.structured_output_mode
311
312        match structured_output_mode:
313            case StructuredOutputMode.json_mode:
314                return {"response_format": {"type": "json_object"}}
315            case StructuredOutputMode.json_schema:
316                return self.json_schema_response_format()
317            case StructuredOutputMode.function_calling_weak:
318                return self.tool_call_params(strict=False)
319            case StructuredOutputMode.function_calling:
320                return self.tool_call_params(strict=True)
321            case StructuredOutputMode.json_instructions:
322                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
323                return {}
324            case StructuredOutputMode.json_custom_instructions:
325                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
326                return {}
327            case StructuredOutputMode.json_instruction_and_object:
328                # We set response_format to json_object and also set json instructions in the prompt
329                return {"response_format": {"type": "json_object"}}
330            case StructuredOutputMode.default:
331                provider_name = self.run_config.model_provider_name
332                if provider_name == ModelProviderName.ollama:
333                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
334                    return self.json_schema_response_format()
335                elif provider_name == ModelProviderName.docker_model_runner:
336                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
337                    return self.json_schema_response_format()
338                else:
339                    # Default to function calling -- it's older than the other modes. Higher compatibility.
340                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
341                    strict = provider_name == ModelProviderName.openai
342                    return self.tool_call_params(strict=strict)
343            case StructuredOutputMode.unknown:
344                # See above, but this case should never happen.
345                raise ValueError("Structured output mode is unknown.")
346            case _:
347                raise_exhaustive_enum_error(structured_output_mode)
348
349    def json_schema_response_format(self) -> dict[str, Any]:
350        output_schema = self.task.output_schema()
351        return {
352            "response_format": {
353                "type": "json_schema",
354                "json_schema": {
355                    "name": "task_response",
356                    "schema": output_schema,
357                },
358            }
359        }
360
361    def tool_call_params(self, strict: bool) -> dict[str, Any]:
362        # Add additional_properties: false to the schema (OpenAI requires this for some models)
363        output_schema = self.task.output_schema()
364        if not isinstance(output_schema, dict):
365            raise ValueError(
366                "Invalid output schema for this task. Can not use tool calls."
367            )
368        output_schema["additionalProperties"] = False
369
370        function_params = {
371            "name": "task_response",
372            "parameters": output_schema,
373        }
374        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
375        if strict:
376            function_params["strict"] = True
377
378        return {
379            "tools": [
380                {
381                    "type": "function",
382                    "function": function_params,
383                }
384            ],
385            "tool_choice": {
386                "type": "function",
387                "function": {"name": "task_response"},
388            },
389        }
390
391    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
392        # Don't love having this logic here. But it's worth the usability improvement
393        # so better to keep it than exclude it. Should figure out how I want to isolate
394        # this sort of logic so it's config driven and can be overridden
395
396        extra_body = {}
397        provider_options = {}
398
399        if provider.thinking_level is not None:
400            extra_body["reasoning_effort"] = provider.thinking_level
401
402        if provider.require_openrouter_reasoning:
403            # https://openrouter.ai/docs/use-cases/reasoning-tokens
404            extra_body["reasoning"] = {
405                "exclude": False,
406            }
407
408        if provider.gemini_reasoning_enabled:
409            extra_body["reasoning"] = {
410                "enabled": True,
411            }
412
413        if provider.name == ModelProviderName.openrouter:
414            # Ask OpenRouter to include usage in the response (cost)
415            extra_body["usage"] = {"include": True}
416
417            # Set a default provider order for more deterministic routing.
418            # OpenRouter will ignore providers that don't support the model.
419            # Special cases below (like R1) can override this order.
420            # allow_fallbacks is true by default, but we can override it here.
421            provider_options["order"] = [
422                "fireworks",
423                "parasail",
424                "together",
425                "deepinfra",
426                "novita",
427                "groq",
428                "amazon-bedrock",
429                "azure",
430                "nebius",
431            ]
432
433        if provider.anthropic_extended_thinking:
434            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
435
436        if provider.r1_openrouter_options:
437            # Require providers that support the reasoning parameter
438            provider_options["require_parameters"] = True
439            # Prefer R1 providers with reasonable perf/quants
440            provider_options["order"] = ["fireworks", "together"]
441            # R1 providers with unreasonable quants
442            provider_options["ignore"] = ["deepinfra"]
443
444        # Only set of this request is to get logprobs.
445        if (
446            provider.logprobs_openrouter_options
447            and self.base_adapter_config.top_logprobs is not None
448        ):
449            # Don't let OpenRouter choose a provider that doesn't support logprobs.
450            provider_options["require_parameters"] = True
451            # DeepInfra silently fails to return logprobs consistently.
452            provider_options["ignore"] = ["deepinfra"]
453
454        if provider.openrouter_skip_required_parameters:
455            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
456            provider_options["require_parameters"] = False
457
458        # Siliconflow uses a bool flag for thinking, for some models
459        if provider.siliconflow_enable_thinking is not None:
460            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
461
462        if len(provider_options) > 0:
463            extra_body["provider"] = provider_options
464
465        return extra_body
466
467    def litellm_model_id(self) -> str:
468        # The model ID is an interesting combination of format and url endpoint.
469        # It specifics the provider URL/host, but this is overridden if you manually set an api url
470        if self._litellm_model_id:
471            return self._litellm_model_id
472
473        litellm_provider_info = get_litellm_provider_info(self.model_provider())
474        if litellm_provider_info.is_custom and self._api_base is None:
475            raise ValueError(
476                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
477            )
478
479        self._litellm_model_id = litellm_provider_info.litellm_model_id
480        return self._litellm_model_id
481
482    async def build_completion_kwargs(
483        self,
484        provider: KilnModelProvider,
485        messages: list[ChatCompletionMessageIncludingLiteLLM],
486        top_logprobs: int | None,
487        skip_response_format: bool = False,
488    ) -> dict[str, Any]:
489        extra_body = self.build_extra_body(provider)
490
491        # Merge all parameters into a single kwargs dict for litellm
492        completion_kwargs = {
493            "model": self.litellm_model_id(),
494            "messages": messages,
495            "api_base": self._api_base,
496            "headers": self._headers,
497            "temperature": self.run_config.temperature,
498            "top_p": self.run_config.top_p,
499            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
500            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
501            # Better to ignore them than to fail the model call.
502            # https://docs.litellm.ai/docs/completion/input
503            "drop_params": True,
504            **extra_body,
505            **self._additional_body_options,
506        }
507
508        tool_calls = await self.litellm_tools()
509        has_tools = len(tool_calls) > 0
510        if has_tools:
511            completion_kwargs["tools"] = tool_calls
512            completion_kwargs["tool_choice"] = "auto"
513
514        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
515        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
516        if provider.temp_top_p_exclusive:
517            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
518                del completion_kwargs["top_p"]
519            if (
520                "temperature" in completion_kwargs
521                and completion_kwargs["temperature"] == 1.0
522            ):
523                del completion_kwargs["temperature"]
524            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
525                raise ValueError(
526                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
527                )
528
529        if not skip_response_format:
530            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
531            response_format_options = await self.response_format_options()
532
533            # Check for a conflict between tools and response format using tools
534            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
535            if has_tools and "tools" in response_format_options:
536                raise ValueError(
537                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
538                )
539
540            completion_kwargs.update(response_format_options)
541
542        if top_logprobs is not None:
543            completion_kwargs["logprobs"] = True
544            completion_kwargs["top_logprobs"] = top_logprobs
545
546        return completion_kwargs
547
548    def usage_from_response(self, response: ModelResponse) -> Usage:
549        litellm_usage = response.get("usage", None)
550
551        # LiteLLM isn't consistent in how it returns the cost.
552        cost = response._hidden_params.get("response_cost", None)
553        if cost is None and litellm_usage:
554            cost = litellm_usage.get("cost", None)
555
556        usage = Usage()
557
558        if not litellm_usage and not cost:
559            return usage
560
561        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
562            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
563            usage.output_tokens = litellm_usage.get("completion_tokens", None)
564            usage.total_tokens = litellm_usage.get("total_tokens", None)
565        else:
566            logger.warning(
567                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
568            )
569
570        if isinstance(cost, float):
571            usage.cost = cost
572        elif cost is not None:
573            # None is allowed, but no other types are expected
574            logger.warning(
575                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
576            )
577
578        return usage
579
580    async def cached_available_tools(self) -> list[KilnToolInterface]:
581        if self._cached_available_tools is None:
582            self._cached_available_tools = await self.available_tools()
583        return self._cached_available_tools
584
585    async def litellm_tools(self) -> list[ToolCallDefinition]:
586        available_tools = await self.cached_available_tools()
587
588        # LiteLLM takes the standard OpenAI-compatible tool call format
589        return [await tool.toolcall_definition() for tool in available_tools]
590
591    async def process_tool_calls(
592        self, tool_calls: list[ChatCompletionMessageToolCall] | None
593    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
594        if tool_calls is None:
595            return None, []
596
597        assistant_output_from_toolcall: str | None = None
598        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
599
600        for tool_call in tool_calls:
601            # Kiln's "task_response" tool is used for returning structured output via tool calls.
602            # Load the output from the tool call and skip normal tool processing for it.
603            if tool_call.function.name == "task_response":
604                assistant_output_from_toolcall = tool_call.function.arguments
605                continue
606
607            # Process normal tool calls (not the "task_response" tool)
608            tool_name = tool_call.function.name
609            tool = None
610            for tool_option in await self.cached_available_tools():
611                if await tool_option.name() == tool_name:
612                    tool = tool_option
613                    break
614            if not tool:
615                raise RuntimeError(
616                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
617                )
618
619            # Parse the arguments and validate them against the tool's schema
620            try:
621                parsed_args = json.loads(tool_call.function.arguments)
622            except json.JSONDecodeError:
623                raise RuntimeError(
624                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
625                )
626            try:
627                tool_call_definition = await tool.toolcall_definition()
628                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
629                validate_schema_with_value_error(parsed_args, json_schema)
630            except Exception as e:
631                raise RuntimeError(
632                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
633                ) from e
634
635            # Create context with the calling task's allow_saving setting
636            context = ToolCallContext(
637                allow_saving=self.base_adapter_config.allow_saving
638            )
639            result = await tool.run(context, **parsed_args)
640
641            tool_call_response_messages.append(
642                ChatCompletionToolMessageParamWrapper(
643                    role="tool",
644                    tool_call_id=tool_call.id,
645                    content=result.output,
646                    kiln_task_tool_data=result.kiln_task_tool_data
647                    if isinstance(result, KilnTaskToolResult)
648                    else None,
649                )
650            )
651
652        if (
653            assistant_output_from_toolcall is not None
654            and len(tool_call_response_messages) > 0
655        ):
656            raise RuntimeError(
657                "Model asked for an impossible combination: a task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
658            )
659
660        return assistant_output_from_toolcall, tool_call_response_messages
661
662    def litellm_message_to_trace_message(
663        self, raw_message: LiteLLMMessage
664    ) -> ChatCompletionAssistantMessageParamWrapper:
665        """
666        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
667        """
668        message: ChatCompletionAssistantMessageParamWrapper = {
669            "role": "assistant",
670        }
671        if raw_message.role != "assistant":
672            raise ValueError(
673                "Model returned a message with a role other than assistant. This is not supported."
674            )
675
676        if hasattr(raw_message, "content"):
677            message["content"] = raw_message.content
678        if hasattr(raw_message, "reasoning_content"):
679            message["reasoning_content"] = raw_message.reasoning_content
680        if hasattr(raw_message, "tool_calls"):
681            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
682            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
683            for litellm_tool_call in raw_message.tool_calls or []:
684                # Optional in the SDK for streaming responses, but should never be None at this point.
685                if litellm_tool_call.function.name is None:
686                    raise ValueError(
687                        "The model requested a tool call, without providing a function name (required)."
688                    )
689                open_ai_tool_calls.append(
690                    ChatCompletionMessageToolCallParam(
691                        id=litellm_tool_call.id,
692                        type="function",
693                        function={
694                            "name": litellm_tool_call.function.name,
695                            "arguments": litellm_tool_call.function.arguments,
696                        },
697                    )
698                )
699            if len(open_ai_tool_calls) > 0:
700                message["tool_calls"] = open_ai_tool_calls
701
702        if not message.get("content") and not message.get("tool_calls"):
703            raise ValueError(
704                "Model returned an assistant message, but no content or tool calls. This is not supported."
705            )
706
707        return message
708
709    def all_messages_to_trace(
710        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
711    ) -> list[ChatCompletionMessageParam]:
712        """
713        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
714        """
715        trace: list[ChatCompletionMessageParam] = []
716        for message in messages:
717            if isinstance(message, LiteLLMMessage):
718                trace.append(self.litellm_message_to_trace_message(message))
719            else:
720                trace.append(message)
721        return trace

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
70    def __init__(
71        self,
72        config: LiteLlmConfig,
73        kiln_task: datamodel.Task,
74        base_adapter_config: AdapterConfig | None = None,
75    ):
76        self.config = config
77        self._additional_body_options = config.additional_body_options
78        self._api_base = config.base_url
79        self._headers = config.default_headers
80        self._litellm_model_id: str | None = None
81        self._cached_available_tools: list[KilnToolInterface] | None = None
82
83        super().__init__(
84            task=kiln_task,
85            run_config=config.run_config_properties,
86            config=base_adapter_config,
87        )
config
async def acompletion_checking_response( self, **kwargs) -> Tuple[litellm.types.utils.ModelResponse, litellm.types.utils.Choices]:
287    async def acompletion_checking_response(
288        self, **kwargs
289    ) -> Tuple[ModelResponse, Choices]:
290        response = await litellm.acompletion(**kwargs)
291        if (
292            not isinstance(response, ModelResponse)
293            or not response.choices
294            or len(response.choices) == 0
295            or not isinstance(response.choices[0], Choices)
296        ):
297            raise RuntimeError(
298                f"Expected ModelResponse with Choices, got {type(response)}."
299            )
300        return response, response.choices[0]
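
A minimal usage sketch, assuming adapter is an already-constructed LiteLlmAdapter and completion_kwargs was produced by build_completion_kwargs (both names are illustrative):

response, choice = await adapter.acompletion_checking_response(**completion_kwargs)
print(choice.message.content)  # choice is the first (validated) entry in response.choices
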
def adapter_name(self) -> str:
302    def adapter_name(self) -> str:
303        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
305    async def response_format_options(self) -> dict[str, Any]:
306        # Unstructured if task isn't structured
307        if not self.has_structured_output():
308            return {}
309
310        structured_output_mode = self.run_config.structured_output_mode
311
312        match structured_output_mode:
313            case StructuredOutputMode.json_mode:
314                return {"response_format": {"type": "json_object"}}
315            case StructuredOutputMode.json_schema:
316                return self.json_schema_response_format()
317            case StructuredOutputMode.function_calling_weak:
318                return self.tool_call_params(strict=False)
319            case StructuredOutputMode.function_calling:
320                return self.tool_call_params(strict=True)
321            case StructuredOutputMode.json_instructions:
322                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
323                return {}
324            case StructuredOutputMode.json_custom_instructions:
325                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
326                return {}
327            case StructuredOutputMode.json_instruction_and_object:
328                # We set response_format to json_object and also set json instructions in the prompt
329                return {"response_format": {"type": "json_object"}}
330            case StructuredOutputMode.default:
331                provider_name = self.run_config.model_provider_name
332                if provider_name == ModelProviderName.ollama:
333                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
334                    return self.json_schema_response_format()
335                elif provider_name == ModelProviderName.docker_model_runner:
336                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
337                    return self.json_schema_response_format()
338                else:
339                    # Default to function calling -- it's older than the other modes. Higher compatibility.
340                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
341                    strict = provider_name == ModelProviderName.openai
342                    return self.tool_call_params(strict=strict)
343            case StructuredOutputMode.unknown:
344                # See above, but this case should never happen.
345                raise ValueError("Structured output mode is unknown.")
346            case _:
347                raise_exhaustive_enum_error(structured_output_mode)
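
A sketch of how the returned options might be inspected, assuming adapter is a configured LiteLlmAdapter for a structured task (illustrative only; the exact keys depend on the structured output mode):

options = await adapter.response_format_options()
if "response_format" in options:
    # json_mode, json_schema, json_instruction_and_object, or a provider default
    print("response_format:", options["response_format"]["type"])
elif "tools" in options:
    # function_calling / function_calling_weak: a forced "task_response" tool call
    print("forced tool:", options["tool_choice"]["function"]["name"])
else:
    # json_instructions / json_custom_instructions: schema enforced via the prompt
    print("no API-level response format")
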
def json_schema_response_format(self) -> dict[str, typing.Any]:
349    def json_schema_response_format(self) -> dict[str, Any]:
350        output_schema = self.task.output_schema()
351        return {
352            "response_format": {
353                "type": "json_schema",
354                "json_schema": {
355                    "name": "task_response",
356                    "schema": output_schema,
357                },
358            }
359        }
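
For illustration, the payload this produces for a hypothetical single-field output schema looks like the following (hand-written sketch; the real schema comes from task.output_schema()):

example_response_format = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "task_response",
            "schema": {  # hypothetical {"answer": string} schema
                "type": "object",
                "properties": {"answer": {"type": "string"}},
                "required": ["answer"],
            },
        },
    }
}
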
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
361    def tool_call_params(self, strict: bool) -> dict[str, Any]:
362        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
363        output_schema = self.task.output_schema()
364        if not isinstance(output_schema, dict):
365            raise ValueError(
366                "Invalid output schema for this task. Can not use tool calls."
367            )
368        output_schema["additionalProperties"] = False
369
370        function_params = {
371            "name": "task_response",
372            "parameters": output_schema,
373        }
374        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
375        if strict:
376            function_params["strict"] = True
377
378        return {
379            "tools": [
380                {
381                    "type": "function",
382                    "function": function_params,
383                }
384            ],
385            "tool_choice": {
386                "type": "function",
387                "function": {"name": "task_response"},
388            },
389        }
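
The equivalent hand-written sketch for the tool-calling modes, again with a hypothetical single-field schema; note the forced tool_choice so the model must answer via the task_response tool:

example_tool_call_params = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                    "additionalProperties": False,
                },
                "strict": True,  # only included when strict=True
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}
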
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
391    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
392        # Don't love having this logic here, but it's worth the usability improvement,
393        # so better to keep it than exclude it. We should figure out how to isolate
394        # this sort of logic so it's config-driven and can be overridden.
395
396        extra_body = {}
397        provider_options = {}
398
399        if provider.thinking_level is not None:
400            extra_body["reasoning_effort"] = provider.thinking_level
401
402        if provider.require_openrouter_reasoning:
403            # https://openrouter.ai/docs/use-cases/reasoning-tokens
404            extra_body["reasoning"] = {
405                "exclude": False,
406            }
407
408        if provider.gemini_reasoning_enabled:
409            extra_body["reasoning"] = {
410                "enabled": True,
411            }
412
413        if provider.name == ModelProviderName.openrouter:
414            # Ask OpenRouter to include usage in the response (cost)
415            extra_body["usage"] = {"include": True}
416
417            # Set a default provider order for more deterministic routing.
418            # OpenRouter will ignore providers that don't support the model.
419            # Special cases below (like R1) can override this order.
420            # allow_fallbacks is true by default, but we can override it here.
421            provider_options["order"] = [
422                "fireworks",
423                "parasail",
424                "together",
425                "deepinfra",
426                "novita",
427                "groq",
428                "amazon-bedrock",
429                "azure",
430                "nebius",
431            ]
432
433        if provider.anthropic_extended_thinking:
434            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
435
436        if provider.r1_openrouter_options:
437            # Require providers that support the reasoning parameter
438            provider_options["require_parameters"] = True
439            # Prefer R1 providers with reasonable perf/quants
440            provider_options["order"] = ["fireworks", "together"]
441            # R1 providers with unreasonable quants
442            provider_options["ignore"] = ["deepinfra"]
443
444        # Only set these options if the goal of this request is to get logprobs.
445        if (
446            provider.logprobs_openrouter_options
447            and self.base_adapter_config.top_logprobs is not None
448        ):
449            # Don't let OpenRouter choose a provider that doesn't support logprobs.
450            provider_options["require_parameters"] = True
451            # DeepInfra silently fails to return logprobs consistently.
452            provider_options["ignore"] = ["deepinfra"]
453
454        if provider.openrouter_skip_required_parameters:
455            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
456            provider_options["require_parameters"] = False
457
458        # Siliconflow uses a bool flag for thinking, for some models
459        if provider.siliconflow_enable_thinking is not None:
460            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
461
462        if len(provider_options) > 0:
463            extra_body["provider"] = provider_options
464
465        return extra_body
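
A hand-written sketch of the extra body produced for an OpenRouter provider with reasoning required (illustrative; which keys appear depends entirely on the provider's flags):

example_extra_body = {
    "reasoning": {"exclude": False},  # require_openrouter_reasoning
    "usage": {"include": True},       # ask OpenRouter to report cost
    "provider": {                     # default provider routing order
        "order": [
            "fireworks", "parasail", "together", "deepinfra", "novita",
            "groq", "amazon-bedrock", "azure", "nebius",
        ],
    },
}
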
def litellm_model_id(self) -> str:
467    def litellm_model_id(self) -> str:
468        # The model ID is an interesting combination of format and URL endpoint.
469        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
470        if self._litellm_model_id:
471            return self._litellm_model_id
472
473        litellm_provider_info = get_litellm_provider_info(self.model_provider())
474        if litellm_provider_info.is_custom and self._api_base is None:
475            raise ValueError(
476                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
477            )
478
479        self._litellm_model_id = litellm_provider_info.litellm_model_id
480        return self._litellm_model_id
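
Usage is a single call; the returned string is the provider-prefixed ID LiteLLM expects (the exact value depends on the configured provider and model):

model_id = adapter.litellm_model_id()  # e.g. roughly "openrouter/<model-slug>" for OpenRouter
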
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
482    async def build_completion_kwargs(
483        self,
484        provider: KilnModelProvider,
485        messages: list[ChatCompletionMessageIncludingLiteLLM],
486        top_logprobs: int | None,
487        skip_response_format: bool = False,
488    ) -> dict[str, Any]:
489        extra_body = self.build_extra_body(provider)
490
491        # Merge all parameters into a single kwargs dict for litellm
492        completion_kwargs = {
493            "model": self.litellm_model_id(),
494            "messages": messages,
495            "api_base": self._api_base,
496            "headers": self._headers,
497            "temperature": self.run_config.temperature,
498            "top_p": self.run_config.top_p,
499            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
500            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
501            # Better to ignore them than to fail the model call.
502            # https://docs.litellm.ai/docs/completion/input
503            "drop_params": True,
504            **extra_body,
505            **self._additional_body_options,
506        }
507
508        tool_calls = await self.litellm_tools()
509        has_tools = len(tool_calls) > 0
510        if has_tools:
511            completion_kwargs["tools"] = tool_calls
512            completion_kwargs["tool_choice"] = "auto"
513
514        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temperature, not both.
515        # Remove default values (1.0), keeping anything the user customized, then raise a helpful error if both are custom.
516        if provider.temp_top_p_exclusive:
517            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
518                del completion_kwargs["top_p"]
519            if (
520                "temperature" in completion_kwargs
521                and completion_kwargs["temperature"] == 1.0
522            ):
523                del completion_kwargs["temperature"]
524            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
525                raise ValueError(
526                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
527                )
528
529        if not skip_response_format:
530            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
531            response_format_options = await self.response_format_options()
532
533            # Check for a conflict between tools and a response format that also uses tools
534            # We could reconsider this: the model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good at tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
535            if has_tools and "tools" in response_format_options:
536                raise ValueError(
537                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
538                )
539
540            completion_kwargs.update(response_format_options)
541
542        if top_logprobs is not None:
543            completion_kwargs["logprobs"] = True
544            completion_kwargs["top_logprobs"] = top_logprobs
545
546        return completion_kwargs
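
A minimal end-to-end sketch, assuming adapter, provider, and messages already exist (names are illustrative):

kwargs = await adapter.build_completion_kwargs(
    provider=provider,
    messages=messages,
    top_logprobs=None,  # set an int to also request logprobs
)
response, choice = await adapter.acompletion_checking_response(**kwargs)
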
def usage_from_response( self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage:
548    def usage_from_response(self, response: ModelResponse) -> Usage:
549        litellm_usage = response.get("usage", None)
550
551        # LiteLLM isn't consistent in how it returns the cost.
552        cost = response._hidden_params.get("response_cost", None)
553        if cost is None and litellm_usage:
554            cost = litellm_usage.get("cost", None)
555
556        usage = Usage()
557
558        if not litellm_usage and not cost:
559            return usage
560
561        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
562            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
563            usage.output_tokens = litellm_usage.get("completion_tokens", None)
564            usage.total_tokens = litellm_usage.get("total_tokens", None)
565        else:
566            logger.warning(
567                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
568            )
569
570        if isinstance(cost, float):
571            usage.cost = cost
572        elif cost is not None:
573            # None is allowed, but no other types are expected
574            logger.warning(
575                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
576            )
577
578        return usage
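
A usage sketch, assuming response is the ModelResponse returned by a completed call:

usage = adapter.usage_from_response(response)
print(usage.input_tokens, usage.output_tokens, usage.total_tokens, usage.cost)
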
async def cached_available_tools(self) -> list[kiln_ai.tools.KilnToolInterface]:
580    async def cached_available_tools(self) -> list[KilnToolInterface]:
581        if self._cached_available_tools is None:
582            self._cached_available_tools = await self.available_tools()
583        return self._cached_available_tools
async def litellm_tools(self) -> list[kiln_ai.tools.base_tool.ToolCallDefinition]:
585    async def litellm_tools(self) -> list[ToolCallDefinition]:
586        available_tools = await self.cached_available_tools()
587
588        # LiteLLM takes the standard OpenAI-compatible tool call format
589        return [await tool.toolcall_definition() for tool in available_tools]
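
Each entry follows the standard OpenAI tool definition shape. A hand-written example for a hypothetical weather tool (name, description, and parameters are invented for illustration):

example_tool_definition = {
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool name
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
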
async def process_tool_calls( self, tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None) -> tuple[str | None, list[kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper]]:
591    async def process_tool_calls(
592        self, tool_calls: list[ChatCompletionMessageToolCall] | None
593    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
594        if tool_calls is None:
595            return None, []
596
597        assistant_output_from_toolcall: str | None = None
598        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
599
600        for tool_call in tool_calls:
601            # Kiln's "task_response" tool is used for returning structured output via tool calls.
602            # Load the output from the tool call and skip normal tool processing for it.
603            if tool_call.function.name == "task_response":
604                assistant_output_from_toolcall = tool_call.function.arguments
605                continue
606
607            # Process normal tool calls (not the "task_response" tool)
608            tool_name = tool_call.function.name
609            tool = None
610            for tool_option in await self.cached_available_tools():
611                if await tool_option.name() == tool_name:
612                    tool = tool_option
613                    break
614            if not tool:
615                raise RuntimeError(
616                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
617                )
618
619            # Parse the arguments and validate them against the tool's schema
620            try:
621                parsed_args = json.loads(tool_call.function.arguments)
622            except json.JSONDecodeError:
623                raise RuntimeError(
624                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
625                )
626            try:
627                tool_call_definition = await tool.toolcall_definition()
628                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
629                validate_schema_with_value_error(parsed_args, json_schema)
630            except Exception as e:
631                raise RuntimeError(
632                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
633                ) from e
634
635            # Create context with the calling task's allow_saving setting
636            context = ToolCallContext(
637                allow_saving=self.base_adapter_config.allow_saving
638            )
639            result = await tool.run(context, **parsed_args)
640
641            tool_call_response_messages.append(
642                ChatCompletionToolMessageParamWrapper(
643                    role="tool",
644                    tool_call_id=tool_call.id,
645                    content=result.output,
646                    kiln_task_tool_data=result.kiln_task_tool_data
647                    if isinstance(result, KilnTaskToolResult)
648                    else None,
649                )
650            )
651
652        if (
653            assistant_output_from_toolcall is not None
654            and len(tool_call_response_messages) > 0
655        ):
656            raise RuntimeError(
657                "Model asked for an impossible combination: a task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
658            )
659
660        return assistant_output_from_toolcall, tool_call_response_messages
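
A sketch of the tool-call round trip, assuming adapter and messages already exist and choice is the litellm Choices object from the previous model call:

task_output, tool_messages = await adapter.process_tool_calls(choice.message.tool_calls)
if task_output is not None:
    # The model answered via the task_response tool; the turn is over.
    print("structured output:", task_output)
else:
    # Send the tool results back to the model on the next iteration.
    messages.extend(tool_messages)
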
def litellm_message_to_trace_message( self, raw_message: litellm.types.utils.Message) -> kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper:
662    def litellm_message_to_trace_message(
663        self, raw_message: LiteLLMMessage
664    ) -> ChatCompletionAssistantMessageParamWrapper:
665        """
666        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
667        """
668        message: ChatCompletionAssistantMessageParamWrapper = {
669            "role": "assistant",
670        }
671        if raw_message.role != "assistant":
672            raise ValueError(
673                "Model returned a message with a role other than assistant. This is not supported."
674            )
675
676        if hasattr(raw_message, "content"):
677            message["content"] = raw_message.content
678        if hasattr(raw_message, "reasoning_content"):
679            message["reasoning_content"] = raw_message.reasoning_content
680        if hasattr(raw_message, "tool_calls"):
681            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
682            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
683            for litellm_tool_call in raw_message.tool_calls or []:
684                # Optional in the SDK for streaming responses, but should never be None at this point.
685                if litellm_tool_call.function.name is None:
686                    raise ValueError(
687                        "The model requested a tool call, without providing a function name (required)."
688                    )
689                open_ai_tool_calls.append(
690                    ChatCompletionMessageToolCallParam(
691                        id=litellm_tool_call.id,
692                        type="function",
693                        function={
694                            "name": litellm_tool_call.function.name,
695                            "arguments": litellm_tool_call.function.arguments,
696                        },
697                    )
698                )
699            if len(open_ai_tool_calls) > 0:
700                message["tool_calls"] = open_ai_tool_calls
701
702        if not message.get("content") and not message.get("tool_calls"):
703            raise ValueError(
704                "Model returned an assistant message, but no content or tool calls. This is not supported."
705            )
706
707        return message

Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
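
A small sketch of the conversion, assuming adapter already exists (the Message constructor arguments shown are the common content/role fields):

from litellm.types.utils import Message as LiteLLMMessage

raw = LiteLLMMessage(role="assistant", content="Hello!")
trace_message = adapter.litellm_message_to_trace_message(raw)
# -> {"role": "assistant", "content": "Hello!", ...}
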

def all_messages_to_trace( self, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]]:
709    def all_messages_to_trace(
710        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
711    ) -> list[ChatCompletionMessageParam]:
712        """
713        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
714        """
715        trace: list[ChatCompletionMessageParam] = []
716        for message in messages:
717            if isinstance(message, LiteLLMMessage):
718                trace.append(self.litellm_message_to_trace_message(message))
719            else:
720                trace.append(message)
721        return trace

Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
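
A closing sketch: normalizing a mixed message list (OpenAI-style dicts plus LiteLLM Message objects) before persisting a trace; adapter is assumed to exist:

from litellm.types.utils import Message as LiteLLMMessage

mixed = [
    {"role": "user", "content": "Say hello."},
    LiteLLMMessage(role="assistant", content="Hello!"),
]
trace = adapter.all_messages_to_trace(mixed)
assert all(isinstance(m, dict) for m in trace)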