kiln_ai.adapters.model_adapters.litellm_adapter

import copy
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, TypeAlias, Union

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
)

MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)

ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    ChatCompletionMessageParam, LiteLLMMessage
]

@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage

class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top-level turn: from user message to agent message.

        It may handle multiple iterations of tool calls between the user and agent messages if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for this call
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # No content and no usable tool results at this point: fail rather than loop.
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                messages.append({"role": message.role, "content": message.content})  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

            if not prior_output:
                raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from model choice and add to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        structured_output_mode = self.run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = self.run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task.output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

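    # Illustrative sketch (not part of the original source): assuming the task's
    # output_schema() returns {"type": "object", "properties": {"answer": {"type": "string"}}},
    # json_schema_response_format() would produce:
    #
    #   {
    #       "response_format": {
    #           "type": "json_schema",
    #           "json_schema": {
    #               "name": "task_response",
    #               "schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
    #           },
    #       }
    #   }
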
    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # Strict should normally be on, but we allow function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

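    # Illustrative sketch (not part of the original source): with the same assumed output
    # schema and strict=True, tool_call_params() forces a call to the "task_response" tool:
    #
    #   {
    #       "tools": [
    #           {
    #               "type": "function",
    #               "function": {
    #                   "name": "task_response",
    #                   "parameters": {..., "additionalProperties": False},
    #                   "strict": True,
    #               },
    #           }
    #       ],
    #       "tool_choice": {"type": "function", "function": {"name": "task_response"}},
    #   }
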
    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # Don't love having this logic here. But it's worth the usability improvement,
        # so better to keep it than exclude it. Should figure out how to isolate
        # this sort of logic so it's config-driven and can be overridden.

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["Fireworks", "Together"]
            # Skip R1 providers with unreasonable quants
            provider_options["ignore"] = ["DeepInfra"]

        # Only set these options when the goal of this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["DeepInfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

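    # Illustrative sketch (not part of the original source): for a hypothetical OpenRouter
    # provider with require_openrouter_reasoning and r1_openrouter_options set,
    # build_extra_body() would return something like:
    #
    #   {
    #       "reasoning": {"exclude": False},
    #       "usage": {"include": True},
    #       "provider": {
    #           "require_parameters": True,
    #           "order": ["Fireworks", "Together"],
    #           "ignore": ["DeepInfra"],
    #       },
    #   }
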
    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and URL endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": self.run_config.temperature,
            "top_p": self.run_config.top_p,
            # This drops params that are not supported by the model. Only OpenAI params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all OpenAI params (for example, o3 doesn't support top_p).
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and a response format that also uses tools.
            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good at tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

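    # Illustrative sketch (not part of the original source): the merged kwargs passed to
    # litellm.acompletion might look like the following for an OpenRouter-hosted model with
    # a JSON schema response format and no tools. The model ID and values are assumptions,
    # not real defaults:
    #
    #   {
    #       "model": "openrouter/openai/gpt-4o-mini",
    #       "messages": [...],
    #       "api_base": None,
    #       "headers": None,
    #       "temperature": 1.0,
    #       "top_p": 1.0,
    #       "drop_params": True,
    #       "usage": {"include": True},
    #       "response_format": {"type": "json_schema", "json_schema": {...}},
    #   }
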
    def usage_from_response(self, response: ModelResponse) -> Usage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        available_tools = await self.cached_available_tools()

        # LiteLLM takes the standard OpenAI-compatible tool call format
        return [await tool.toolcall_definition() for tool in available_tools]

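    # Illustrative sketch (not part of the original source): each ToolCallDefinition is the
    # standard OpenAI tool format. For a hypothetical "add_numbers" tool it might look like:
    #
    #   {
    #       "type": "function",
    #       "function": {
    #           "name": "add_numbers",
    #           "description": "Add two numbers.",
    #           "parameters": {
    #               "type": "object",
    #               "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
    #               "required": ["a", "b"],
    #           },
    #       },
    #   }
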
    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []

        for tool_call in tool_calls:
            # The Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call's arguments.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )
            result = await tool.run(context, **parsed_args)
            if isinstance(result, KilnTaskToolResult):
                content = result.output
                kiln_task_tool_data = result.kiln_task_tool_data
            else:
                content = result
                kiln_task_tool_data = None

            tool_call_response_messages.append(
                ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tool_call.id,
                    content=content,
                    kiln_task_tool_data=kiln_task_tool_data,
                )
            )

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

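    # Illustrative sketch (not part of the original source): given a single "task_response"
    # call with arguments '{"answer": "42"}', process_tool_calls returns
    # ('{"answer": "42"}', []). Given a single normal tool call, it returns
    # (None, [tool_message]) where tool_message is a ChatCompletionToolMessageParamWrapper
    # with role="tool", the originating tool_call_id, and the tool's output as content.
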
    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI-compatible message (our ChatCompletionAssistantMessageParamWrapper).
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but the trace needs OpenAI-compatible types, so replace any LiteLLM Message objects with their OpenAI-compatible equivalents.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
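
# Illustrative sketch (not part of the original source): all_messages_to_trace passes
# OpenAI-style dict messages through unchanged and converts LiteLLM Message objects via
# litellm_message_to_trace_message. Assuming "adapter" is a constructed LiteLlmAdapter and
# "litellm_assistant_message" is a litellm.types.utils.Message from a prior call:
#
#   trace = adapter.all_messages_to_trace(
#       [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello"},
#           litellm_assistant_message,
#       ]
#   )
#   # -> the first two entries are returned as-is; the last becomes an OpenAI-compatible
#   #    assistant dict with "content" and, if present, "tool_calls"/"reasoning_content".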
MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30
ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]
@dataclass
class ModelTurnResult:
59@dataclass
60class ModelTurnResult:
61    assistant_message: str
62    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
63    model_response: ModelResponse | None
64    model_choice: Choices | None
65    usage: Usage
ModelTurnResult( assistant_message: str, all_messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]], model_response: litellm.types.utils.ModelResponse | None, model_choice: litellm.types.utils.Choices | None, usage: kiln_ai.datamodel.Usage)
assistant_message: str
all_messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]
model_response: litellm.types.utils.ModelResponse | None
model_choice: litellm.types.utils.Choices | None
 68class LiteLlmAdapter(BaseAdapter):
 69    def __init__(
 70        self,
 71        config: LiteLlmConfig,
 72        kiln_task: datamodel.Task,
 73        base_adapter_config: AdapterConfig | None = None,
 74    ):
 75        self.config = config
 76        self._additional_body_options = config.additional_body_options
 77        self._api_base = config.base_url
 78        self._headers = config.default_headers
 79        self._litellm_model_id: str | None = None
 80        self._cached_available_tools: list[KilnToolInterface] | None = None
 81
 82        super().__init__(
 83            task=kiln_task,
 84            run_config=config.run_config_properties,
 85            config=base_adapter_config,
 86        )
 87
 88    async def _run_model_turn(
 89        self,
 90        provider: KilnModelProvider,
 91        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
 92        top_logprobs: int | None,
 93        skip_response_format: bool,
 94    ) -> ModelTurnResult:
 95        """
 96        Call the model for a single top level turn: from user message to agent message.
 97
 98        It may make handle iterations of tool calls between the user/agent message if needed.
 99        """
100
101        usage = Usage()
102        messages = list(prior_messages)
103        tool_calls_count = 0
104
105        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
106            # Build completion kwargs for tool calls
107            completion_kwargs = await self.build_completion_kwargs(
108                provider,
109                # Pass a copy, as acompletion mutates objects and breaks types.
110                copy.deepcopy(messages),
111                top_logprobs,
112                skip_response_format,
113            )
114
115            # Make the completion call
116            model_response, response_choice = await self.acompletion_checking_response(
117                **completion_kwargs
118            )
119
120            # count the usage
121            usage += self.usage_from_response(model_response)
122
123            # Extract content and tool calls
124            if not hasattr(response_choice, "message"):
125                raise ValueError("Response choice has no message")
126            content = response_choice.message.content
127            tool_calls = response_choice.message.tool_calls
128            if not content and not tool_calls:
129                raise ValueError(
130                    "Model returned an assistant message, but no content or tool calls. This is not supported."
131                )
132
133            # Add message to messages, so it can be used in the next turn
134            messages.append(response_choice.message)
135
136            # Process tool calls if any
137            if tool_calls and len(tool_calls) > 0:
138                (
139                    assistant_message_from_toolcall,
140                    tool_call_messages,
141                ) = await self.process_tool_calls(tool_calls)
142
143                # Add tool call results to messages
144                messages.extend(tool_call_messages)
145
146                # If task_response tool was called, we're done
147                if assistant_message_from_toolcall is not None:
148                    return ModelTurnResult(
149                        assistant_message=assistant_message_from_toolcall,
150                        all_messages=messages,
151                        model_response=model_response,
152                        model_choice=response_choice,
153                        usage=usage,
154                    )
155
156                # If there were tool calls, increment counter and continue
157                if tool_call_messages:
158                    tool_calls_count += 1
159                    continue
160
161            # If no tool calls, return the content as final output
162            if content:
163                return ModelTurnResult(
164                    assistant_message=content,
165                    all_messages=messages,
166                    model_response=model_response,
167                    model_choice=response_choice,
168                    usage=usage,
169                )
170
171            # If we get here with no content and no tool calls, break
172            raise RuntimeError(
173                "Model returned neither content nor tool calls. It must return at least one of these."
174            )
175
176        raise RuntimeError(
177            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
178        )
179
180    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
181        usage = Usage()
182
183        provider = self.model_provider()
184        if not provider.model_id:
185            raise ValueError("Model ID is required for OpenAI compatible models")
186
187        chat_formatter = self.build_chat_formatter(input)
188        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
189
190        prior_output: str | None = None
191        final_choice: Choices | None = None
192        turns = 0
193
194        while True:
195            turns += 1
196            if turns > MAX_CALLS_PER_TURN:
197                raise RuntimeError(
198                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
199                )
200
201            turn = chat_formatter.next_turn(prior_output)
202            if turn is None:
203                # No next turn, we're done
204                break
205
206            # Add messages from the turn to chat history
207            for message in turn.messages:
208                if message.content is None:
209                    raise ValueError("Empty message content isn't allowed")
210                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
211                messages.append({"role": message.role, "content": message.content})  # type: ignore
212
213            skip_response_format = not turn.final_call
214            turn_result = await self._run_model_turn(
215                provider,
216                messages,
217                self.base_adapter_config.top_logprobs if turn.final_call else None,
218                skip_response_format,
219            )
220
221            usage += turn_result.usage
222
223            prior_output = turn_result.assistant_message
224            messages = turn_result.all_messages
225            final_choice = turn_result.model_choice
226
227            if not prior_output:
228                raise RuntimeError("No assistant message/output returned from model")
229
230        logprobs = self._extract_and_validate_logprobs(final_choice)
231
232        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
233        intermediate_outputs = chat_formatter.intermediate_outputs()
234        self._extract_reasoning_to_intermediate_outputs(
235            final_choice, intermediate_outputs
236        )
237
238        if not isinstance(prior_output, str):
239            raise RuntimeError(f"assistant message is not a string: {prior_output}")
240
241        trace = self.all_messages_to_trace(messages)
242        output = RunOutput(
243            output=prior_output,
244            intermediate_outputs=intermediate_outputs,
245            output_logprobs=logprobs,
246            trace=trace,
247        )
248
249        return output, usage
250
251    def _extract_and_validate_logprobs(
252        self, final_choice: Choices | None
253    ) -> ChoiceLogprobs | None:
254        """
255        Extract logprobs from the final choice and validate they exist if required.
256        """
257        logprobs = None
258        if (
259            final_choice is not None
260            and hasattr(final_choice, "logprobs")
261            and isinstance(final_choice.logprobs, ChoiceLogprobs)
262        ):
263            logprobs = final_choice.logprobs
264
265        # Check logprobs worked, if required
266        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
267            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
268
269        return logprobs
270
271    def _extract_reasoning_to_intermediate_outputs(
272        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
273    ) -> None:
274        """Extract reasoning content from model choice and add to intermediate outputs if present."""
275        if (
276            final_choice is not None
277            and hasattr(final_choice, "message")
278            and hasattr(final_choice.message, "reasoning_content")
279        ):
280            reasoning_content = final_choice.message.reasoning_content
281            if reasoning_content is not None:
282                stripped_reasoning_content = reasoning_content.strip()
283                if len(stripped_reasoning_content) > 0:
284                    intermediate_outputs["reasoning"] = stripped_reasoning_content
285
286    async def acompletion_checking_response(
287        self, **kwargs
288    ) -> Tuple[ModelResponse, Choices]:
289        response = await litellm.acompletion(**kwargs)
290        if (
291            not isinstance(response, ModelResponse)
292            or not response.choices
293            or len(response.choices) == 0
294            or not isinstance(response.choices[0], Choices)
295        ):
296            raise RuntimeError(
297                f"Expected ModelResponse with Choices, got {type(response)}."
298            )
299        return response, response.choices[0]
300
301    def adapter_name(self) -> str:
302        return "kiln_openai_compatible_adapter"
303
304    async def response_format_options(self) -> dict[str, Any]:
305        # Unstructured if task isn't structured
306        if not self.has_structured_output():
307            return {}
308
309        structured_output_mode = self.run_config.structured_output_mode
310
311        match structured_output_mode:
312            case StructuredOutputMode.json_mode:
313                return {"response_format": {"type": "json_object"}}
314            case StructuredOutputMode.json_schema:
315                return self.json_schema_response_format()
316            case StructuredOutputMode.function_calling_weak:
317                return self.tool_call_params(strict=False)
318            case StructuredOutputMode.function_calling:
319                return self.tool_call_params(strict=True)
320            case StructuredOutputMode.json_instructions:
321                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
322                return {}
323            case StructuredOutputMode.json_custom_instructions:
324                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
325                return {}
326            case StructuredOutputMode.json_instruction_and_object:
327                # We set response_format to json_object and also set json instructions in the prompt
328                return {"response_format": {"type": "json_object"}}
329            case StructuredOutputMode.default:
330                provider_name = self.run_config.model_provider_name
331                if provider_name == ModelProviderName.ollama:
332                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
333                    return self.json_schema_response_format()
334                elif provider_name == ModelProviderName.docker_model_runner:
335                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
336                    return self.json_schema_response_format()
337                else:
338                    # Default to function calling -- it's older than the other modes. Higher compatibility.
339                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
340                    strict = provider_name == ModelProviderName.openai
341                    return self.tool_call_params(strict=strict)
342            case StructuredOutputMode.unknown:
343                # See above, but this case should never happen.
344                raise ValueError("Structured output mode is unknown.")
345            case _:
346                raise_exhaustive_enum_error(structured_output_mode)
347
348    def json_schema_response_format(self) -> dict[str, Any]:
349        output_schema = self.task.output_schema()
350        return {
351            "response_format": {
352                "type": "json_schema",
353                "json_schema": {
354                    "name": "task_response",
355                    "schema": output_schema,
356                },
357            }
358        }
359
360    def tool_call_params(self, strict: bool) -> dict[str, Any]:
361        # Add additional_properties: false to the schema (OpenAI requires this for some models)
362        output_schema = self.task.output_schema()
363        if not isinstance(output_schema, dict):
364            raise ValueError(
365                "Invalid output schema for this task. Can not use tool calls."
366            )
367        output_schema["additionalProperties"] = False
368
369        function_params = {
370            "name": "task_response",
371            "parameters": output_schema,
372        }
373        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
374        if strict:
375            function_params["strict"] = True
376
377        return {
378            "tools": [
379                {
380                    "type": "function",
381                    "function": function_params,
382                }
383            ],
384            "tool_choice": {
385                "type": "function",
386                "function": {"name": "task_response"},
387            },
388        }
389
390    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
391        # Don't love having this logic here. But it's worth the usability improvement
392        # so better to keep it than exclude it. Should figure out how I want to isolate
393        # this sort of logic so it's config driven and can be overridden
394
395        extra_body = {}
396        provider_options = {}
397
398        if provider.thinking_level is not None:
399            extra_body["reasoning_effort"] = provider.thinking_level
400
401        if provider.require_openrouter_reasoning:
402            # https://openrouter.ai/docs/use-cases/reasoning-tokens
403            extra_body["reasoning"] = {
404                "exclude": False,
405            }
406
407        if provider.gemini_reasoning_enabled:
408            extra_body["reasoning"] = {
409                "enabled": True,
410            }
411
412        if provider.name == ModelProviderName.openrouter:
413            # Ask OpenRouter to include usage in the response (cost)
414            extra_body["usage"] = {"include": True}
415
416        if provider.anthropic_extended_thinking:
417            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
418
419        if provider.r1_openrouter_options:
420            # Require providers that support the reasoning parameter
421            provider_options["require_parameters"] = True
422            # Prefer R1 providers with reasonable perf/quants
423            provider_options["order"] = ["Fireworks", "Together"]
424            # R1 providers with unreasonable quants
425            provider_options["ignore"] = ["DeepInfra"]
426
427        # Only set of this request is to get logprobs.
428        if (
429            provider.logprobs_openrouter_options
430            and self.base_adapter_config.top_logprobs is not None
431        ):
432            # Don't let OpenRouter choose a provider that doesn't support logprobs.
433            provider_options["require_parameters"] = True
434            # DeepInfra silently fails to return logprobs consistently.
435            provider_options["ignore"] = ["DeepInfra"]
436
437        if provider.openrouter_skip_required_parameters:
438            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
439            provider_options["require_parameters"] = False
440
441        # Siliconflow uses a bool flag for thinking, for some models
442        if provider.siliconflow_enable_thinking is not None:
443            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
444
445        if len(provider_options) > 0:
446            extra_body["provider"] = provider_options
447
448        return extra_body
449
450    def litellm_model_id(self) -> str:
451        # The model ID is an interesting combination of format and url endpoint.
452        # It specifics the provider URL/host, but this is overridden if you manually set an api url
453        if self._litellm_model_id:
454            return self._litellm_model_id
455
456        litellm_provider_info = get_litellm_provider_info(self.model_provider())
457        if litellm_provider_info.is_custom and self._api_base is None:
458            raise ValueError(
459                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
460            )
461
462        self._litellm_model_id = litellm_provider_info.litellm_model_id
463        return self._litellm_model_id
464
465    async def build_completion_kwargs(
466        self,
467        provider: KilnModelProvider,
468        messages: list[ChatCompletionMessageIncludingLiteLLM],
469        top_logprobs: int | None,
470        skip_response_format: bool = False,
471    ) -> dict[str, Any]:
472        extra_body = self.build_extra_body(provider)
473
474        # Merge all parameters into a single kwargs dict for litellm
475        completion_kwargs = {
476            "model": self.litellm_model_id(),
477            "messages": messages,
478            "api_base": self._api_base,
479            "headers": self._headers,
480            "temperature": self.run_config.temperature,
481            "top_p": self.run_config.top_p,
482            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
483            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
484            # Better to ignore them than to fail the model call.
485            # https://docs.litellm.ai/docs/completion/input
486            "drop_params": True,
487            **extra_body,
488            **self._additional_body_options,
489        }
490
491        tool_calls = await self.litellm_tools()
492        has_tools = len(tool_calls) > 0
493        if has_tools:
494            completion_kwargs["tools"] = tool_calls
495            completion_kwargs["tool_choice"] = "auto"
496
497        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
498        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
499        if provider.temp_top_p_exclusive:
500            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
501                del completion_kwargs["top_p"]
502            if (
503                "temperature" in completion_kwargs
504                and completion_kwargs["temperature"] == 1.0
505            ):
506                del completion_kwargs["temperature"]
507            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
508                raise ValueError(
509                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
510                )
511
512        if not skip_response_format:
513            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
514            response_format_options = await self.response_format_options()
515
516            # Check for a conflict between tools and response format using tools
517            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
518            if has_tools and "tools" in response_format_options:
519                raise ValueError(
520                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
521                )
522
523            completion_kwargs.update(response_format_options)
524
525        if top_logprobs is not None:
526            completion_kwargs["logprobs"] = True
527            completion_kwargs["top_logprobs"] = top_logprobs
528
529        return completion_kwargs
530
531    def usage_from_response(self, response: ModelResponse) -> Usage:
532        litellm_usage = response.get("usage", None)
533
534        # LiteLLM isn't consistent in how it returns the cost.
535        cost = response._hidden_params.get("response_cost", None)
536        if cost is None and litellm_usage:
537            cost = litellm_usage.get("cost", None)
538
539        usage = Usage()
540
541        if not litellm_usage and not cost:
542            return usage
543
544        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
545            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
546            usage.output_tokens = litellm_usage.get("completion_tokens", None)
547            usage.total_tokens = litellm_usage.get("total_tokens", None)
548        else:
549            logger.warning(
550                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
551            )
552
553        if isinstance(cost, float):
554            usage.cost = cost
555        elif cost is not None:
556            # None is allowed, but no other types are expected
557            logger.warning(
558                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
559            )
560
561        return usage
562
563    async def cached_available_tools(self) -> list[KilnToolInterface]:
564        if self._cached_available_tools is None:
565            self._cached_available_tools = await self.available_tools()
566        return self._cached_available_tools
567
568    async def litellm_tools(self) -> list[ToolCallDefinition]:
569        available_tools = await self.cached_available_tools()
570
571        # LiteLLM takes the standard OpenAI-compatible tool call format
572        return [await tool.toolcall_definition() for tool in available_tools]
573
574    async def process_tool_calls(
575        self, tool_calls: list[ChatCompletionMessageToolCall] | None
576    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
577        if tool_calls is None:
578            return None, []
579
580        assistant_output_from_toolcall: str | None = None
581        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
582
583        for tool_call in tool_calls:
584            # Kiln "task_response" tool is used for returning structured output via tool calls.
585            # Load the structured output from the tool call arguments.
586            if tool_call.function.name == "task_response":
587                assistant_output_from_toolcall = tool_call.function.arguments
588                continue
589
590            # Process normal tool calls (not the "task_response" tool)
591            tool_name = tool_call.function.name
592            tool = None
593            for tool_option in await self.cached_available_tools():
594                if await tool_option.name() == tool_name:
595                    tool = tool_option
596                    break
597            if not tool:
598                raise RuntimeError(
599                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
600                )
601
602            # Parse the arguments and validate them against the tool's schema
603            try:
604                parsed_args = json.loads(tool_call.function.arguments)
605            except json.JSONDecodeError:
606                raise RuntimeError(
607                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
608                )
609            try:
610                tool_call_definition = await tool.toolcall_definition()
611                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
612                validate_schema_with_value_error(parsed_args, json_schema)
613            except Exception as e:
614                raise RuntimeError(
615                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
616                ) from e
617
618            # Create context with the calling task's allow_saving setting
619            context = ToolCallContext(
620                allow_saving=self.base_adapter_config.allow_saving
621            )
622            result = await tool.run(context, **parsed_args)
623            if isinstance(result, KilnTaskToolResult):
624                content = result.output
625                kiln_task_tool_data = result.kiln_task_tool_data
626            else:
627                content = result
628                kiln_task_tool_data = None
629
630            tool_call_response_messages.append(
631                ChatCompletionToolMessageParamWrapper(
632                    role="tool",
633                    tool_call_id=tool_call.id,
634                    content=content,
635                    kiln_task_tool_data=kiln_task_tool_data,
636                )
637            )
638
639        if (
640            assistant_output_from_toolcall is not None
641            and len(tool_call_response_messages) > 0
642        ):
643            raise RuntimeError(
644                "Model asked for an impossible combination: the task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
645            )
646
647        return assistant_output_from_toolcall, tool_call_response_messages
648
649    def litellm_message_to_trace_message(
650        self, raw_message: LiteLLMMessage
651    ) -> ChatCompletionAssistantMessageParamWrapper:
652        """
653        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
654        """
655        message: ChatCompletionAssistantMessageParamWrapper = {
656            "role": "assistant",
657        }
658        if raw_message.role != "assistant":
659            raise ValueError(
660                "Model returned a message with a role other than assistant. This is not supported."
661            )
662
663        if hasattr(raw_message, "content"):
664            message["content"] = raw_message.content
665        if hasattr(raw_message, "reasoning_content"):
666            message["reasoning_content"] = raw_message.reasoning_content
667        if hasattr(raw_message, "tool_calls"):
668            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
669            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
670            for litellm_tool_call in raw_message.tool_calls or []:
671                # Optional in the SDK for streaming responses, but should never be None at this point.
672                if litellm_tool_call.function.name is None:
673                    raise ValueError(
674                        "The model requested a tool call, without providing a function name (required)."
675                    )
676                open_ai_tool_calls.append(
677                    ChatCompletionMessageToolCallParam(
678                        id=litellm_tool_call.id,
679                        type="function",
680                        function={
681                            "name": litellm_tool_call.function.name,
682                            "arguments": litellm_tool_call.function.arguments,
683                        },
684                    )
685                )
686            if len(open_ai_tool_calls) > 0:
687                message["tool_calls"] = open_ai_tool_calls
688
689        if not message.get("content") and not message.get("tool_calls"):
690            raise ValueError(
691                "Model returned an assistant message, but no content or tool calls. This is not supported."
692            )
693
694        return message
695
696    def all_messages_to_trace(
697        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
698    ) -> list[ChatCompletionMessageParam]:
699        """
700        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
701        """
702        trace: list[ChatCompletionMessageParam] = []
703        for message in messages:
704            if isinstance(message, LiteLLMMessage):
705                trace.append(self.litellm_message_to_trace_message(message))
706            else:
707                trace.append(message)
708        return trace

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
69    def __init__(
70        self,
71        config: LiteLlmConfig,
72        kiln_task: datamodel.Task,
73        base_adapter_config: AdapterConfig | None = None,
74    ):
75        self.config = config
76        self._additional_body_options = config.additional_body_options
77        self._api_base = config.base_url
78        self._headers = config.default_headers
79        self._litellm_model_id: str | None = None
80        self._cached_available_tools: list[KilnToolInterface] | None = None
81
82        super().__init__(
83            task=kiln_task,
84            run_config=config.run_config_properties,
85            config=base_adapter_config,
86        )
config
async def acompletion_checking_response( self, **kwargs) -> Tuple[litellm.types.utils.ModelResponse, litellm.types.utils.Choices]:
286    async def acompletion_checking_response(
287        self, **kwargs
288    ) -> Tuple[ModelResponse, Choices]:
289        response = await litellm.acompletion(**kwargs)
290        if (
291            not isinstance(response, ModelResponse)
292            or not response.choices
293            or len(response.choices) == 0
294            or not isinstance(response.choices[0], Choices)
295        ):
296            raise RuntimeError(
297                f"Expected ModelResponse with Choices, got {type(response)}."
298            )
299        return response, response.choices[0]
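A minimal sketch of the same defensive pattern outside the adapter, assuming LiteLLM credentials are configured; the model id and prompt below are placeholders, not values used by Kiln.

    import litellm
    from litellm.types.utils import Choices, ModelResponse


    async def checked_completion() -> str:
        # Placeholder model id; swap in any model you have credentials for.
        response = await litellm.acompletion(
            model="openai/gpt-4o-mini",
            messages=[{"role": "user", "content": "Say hello"}],
        )
        # Same shape check as acompletion_checking_response: litellm is typed
        # loosely, so verify we really got a ModelResponse with Choices.
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(f"Expected ModelResponse with Choices, got {type(response)}")
        return response.choices[0].message.content or ""

    # run with: asyncio.run(checked_completion())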
def adapter_name(self) -> str:
301    def adapter_name(self) -> str:
302        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
304    async def response_format_options(self) -> dict[str, Any]:
305        # Unstructured if task isn't structured
306        if not self.has_structured_output():
307            return {}
308
309        structured_output_mode = self.run_config.structured_output_mode
310
311        match structured_output_mode:
312            case StructuredOutputMode.json_mode:
313                return {"response_format": {"type": "json_object"}}
314            case StructuredOutputMode.json_schema:
315                return self.json_schema_response_format()
316            case StructuredOutputMode.function_calling_weak:
317                return self.tool_call_params(strict=False)
318            case StructuredOutputMode.function_calling:
319                return self.tool_call_params(strict=True)
320            case StructuredOutputMode.json_instructions:
321                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
322                return {}
323            case StructuredOutputMode.json_custom_instructions:
324                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
325                return {}
326            case StructuredOutputMode.json_instruction_and_object:
327                # We set response_format to json_object and also set json instructions in the prompt
328                return {"response_format": {"type": "json_object"}}
329            case StructuredOutputMode.default:
330                provider_name = self.run_config.model_provider_name
331                if provider_name == ModelProviderName.ollama:
332                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
333                    return self.json_schema_response_format()
334                elif provider_name == ModelProviderName.docker_model_runner:
335                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
336                    return self.json_schema_response_format()
337                else:
338                    # Default to function calling -- it's older than the other modes. Higher compatibility.
339                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
340                    strict = provider_name == ModelProviderName.openai
341                    return self.tool_call_params(strict=strict)
342            case StructuredOutputMode.unknown:
343                # See above, but this case should never happen.
344                raise ValueError("Structured output mode is unknown.")
345            case _:
346                raise_exhaustive_enum_error(structured_output_mode)
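For reference, these branches produce plain OpenAI-style kwargs. A hand-written illustration (not live output) of the simplest cases; the schema-based and tool-based payloads are shown after json_schema_response_format and tool_call_params below.

    # json_mode and json_instruction_and_object only request a JSON object:
    json_object_options = {"response_format": {"type": "json_object"}}

    # json_instructions and json_custom_instructions return {} on purpose:
    # the JSON requirement lives in the prompt, not in the API parameters.
    prompt_only_options: dict = {}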
def json_schema_response_format(self) -> dict[str, typing.Any]:
348    def json_schema_response_format(self) -> dict[str, Any]:
349        output_schema = self.task.output_schema()
350        return {
351            "response_format": {
352                "type": "json_schema",
353                "json_schema": {
354                    "name": "task_response",
355                    "schema": output_schema,
356                },
357            }
358        }
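A hand-written example of the payload this builds, assuming a hypothetical task whose output schema is a single required "answer" string (the schema is illustrative, not a real Kiln task):

    example_output_schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    expected_response_format = {
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "task_response",
                "schema": example_output_schema,
            },
        }
    }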
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
360    def tool_call_params(self, strict: bool) -> dict[str, Any]:
361        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
362        output_schema = self.task.output_schema()
363        if not isinstance(output_schema, dict):
364            raise ValueError(
365                "Invalid output schema for this task. Can not use tool calls."
366            )
367        output_schema["additionalProperties"] = False
368
369        function_params = {
370            "name": "task_response",
371            "parameters": output_schema,
372        }
373        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
374        if strict:
375            function_params["strict"] = True
376
377        return {
378            "tools": [
379                {
380                    "type": "function",
381                    "function": function_params,
382                }
383            ],
384            "tool_choice": {
385                "type": "function",
386                "function": {"name": "task_response"},
387            },
388        }
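Using the same hypothetical "answer" schema, tool_call_params(strict=True) would produce kwargs along these lines, with tool_choice forcing the synthetic task_response function:

    forced_task_response_kwargs = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "task_response",
                    "parameters": {
                        "type": "object",
                        "properties": {"answer": {"type": "string"}},
                        "required": ["answer"],
                        "additionalProperties": False,
                    },
                    "strict": True,  # omitted when strict=False (function_calling_weak)
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": "task_response"}},
    }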
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
390    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
391        # Don't love having this logic here, but it's worth the usability improvement,
392        # so better to keep it than exclude it. Should figure out how I want to isolate
393        # this sort of logic so it's config-driven and can be overridden.
394
395        extra_body = {}
396        provider_options = {}
397
398        if provider.thinking_level is not None:
399            extra_body["reasoning_effort"] = provider.thinking_level
400
401        if provider.require_openrouter_reasoning:
402            # https://openrouter.ai/docs/use-cases/reasoning-tokens
403            extra_body["reasoning"] = {
404                "exclude": False,
405            }
406
407        if provider.gemini_reasoning_enabled:
408            extra_body["reasoning"] = {
409                "enabled": True,
410            }
411
412        if provider.name == ModelProviderName.openrouter:
413            # Ask OpenRouter to include usage in the response (cost)
414            extra_body["usage"] = {"include": True}
415
416        if provider.anthropic_extended_thinking:
417            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
418
419        if provider.r1_openrouter_options:
420            # Require providers that support the reasoning parameter
421            provider_options["require_parameters"] = True
422            # Prefer R1 providers with reasonable perf/quants
423            provider_options["order"] = ["Fireworks", "Together"]
424            # Ignore R1 providers with unreasonable quants
425            provider_options["ignore"] = ["DeepInfra"]
426
427        # Only set these options when the goal of this request is to get logprobs.
428        if (
429            provider.logprobs_openrouter_options
430            and self.base_adapter_config.top_logprobs is not None
431        ):
432            # Don't let OpenRouter choose a provider that doesn't support logprobs.
433            provider_options["require_parameters"] = True
434            # DeepInfra silently fails to return logprobs consistently.
435            provider_options["ignore"] = ["DeepInfra"]
436
437        if provider.openrouter_skip_required_parameters:
438            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
439            provider_options["require_parameters"] = False
440
441        # Siliconflow uses a bool flag for thinking, for some models
442        if provider.siliconflow_enable_thinking is not None:
443            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
444
445        if len(provider_options) > 0:
446            extra_body["provider"] = provider_options
447
448        return extra_body
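As an illustration, a hypothetical OpenRouter-hosted reasoning model with r1_openrouter_options enabled would yield an extra_body like the following (flags here are hand-picked, not a real entry from ml_model_list):

    example_extra_body = {
        "reasoning": {"exclude": False},      # require_openrouter_reasoning
        "usage": {"include": True},           # ask OpenRouter to report cost
        "provider": {
            "require_parameters": True,       # only providers supporting reasoning params
            "order": ["Fireworks", "Together"],
            "ignore": ["DeepInfra"],
        },
    }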
def litellm_model_id(self) -> str:
450    def litellm_model_id(self) -> str:
451        # The model ID is a combination of the model name/format and the provider endpoint.
452        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
453        if self._litellm_model_id:
454            return self._litellm_model_id
455
456        litellm_provider_info = get_litellm_provider_info(self.model_provider())
457        if litellm_provider_info.is_custom and self._api_base is None:
458            raise ValueError(
459                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
460            )
461
462        self._litellm_model_id = litellm_provider_info.litellm_model_id
463        return self._litellm_model_id
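The resulting ids follow LiteLLM's "<provider>/<model>" convention; the strings below are illustrative examples rather than values guaranteed by get_litellm_provider_info:

    illustrative_model_ids = [
        "openai/gpt-4o",                    # hosted provider routed by LiteLLM
        "openrouter/deepseek/deepseek-r1",  # OpenRouter keeps its own provider prefix
        "openai/my-local-model",            # OpenAI-compatible/custom endpoints also need api_base
    ]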
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
465    async def build_completion_kwargs(
466        self,
467        provider: KilnModelProvider,
468        messages: list[ChatCompletionMessageIncludingLiteLLM],
469        top_logprobs: int | None,
470        skip_response_format: bool = False,
471    ) -> dict[str, Any]:
472        extra_body = self.build_extra_body(provider)
473
474        # Merge all parameters into a single kwargs dict for litellm
475        completion_kwargs = {
476            "model": self.litellm_model_id(),
477            "messages": messages,
478            "api_base": self._api_base,
479            "headers": self._headers,
480            "temperature": self.run_config.temperature,
481            "top_p": self.run_config.top_p,
482            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
483            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
484            # Better to ignore them than to fail the model call.
485            # https://docs.litellm.ai/docs/completion/input
486            "drop_params": True,
487            **extra_body,
488            **self._additional_body_options,
489        }
490
491        tool_calls = await self.litellm_tools()
492        has_tools = len(tool_calls) > 0
493        if has_tools:
494            completion_kwargs["tools"] = tool_calls
495            completion_kwargs["tool_choice"] = "auto"
496
497        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
498        # Remove default values (1.0), prioritizing anything the user customized, then raise a helpful error if both are custom.
499        if provider.temp_top_p_exclusive:
500            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
501                del completion_kwargs["top_p"]
502            if (
503                "temperature" in completion_kwargs
504                and completion_kwargs["temperature"] == 1.0
505            ):
506                del completion_kwargs["temperature"]
507            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
508                raise ValueError(
509                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
510                )
511
512        if not skip_response_format:
513            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
514            response_format_options = await self.response_format_options()
515
516            # Check for a conflict between tools and a response format that also uses tools
517            # We could reconsider this: the model could choose between a final answer or a tool call on any turn. However, models that are good at tool calling tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced tool_choice when merging (forced only for the final answer, auto for the merged set).
518            if has_tools and "tools" in response_format_options:
519                raise ValueError(
520                    "Function calling/tools can't be used as the structured output format when tools are also enabled. Please select a different structured output mode."
521                )
522
523            completion_kwargs.update(response_format_options)
524
525        if top_logprobs is not None:
526            completion_kwargs["logprobs"] = True
527            completion_kwargs["top_logprobs"] = top_logprobs
528
529        return completion_kwargs
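The temp/top_p exclusivity rule above can be read as a standalone transform; this hypothetical helper (not part of the adapter) shows the same resolution on a bare sampling-params dict:

    def resolve_temp_top_p_exclusive(kwargs: dict) -> dict:
        # Drop defaults (1.0) first, keeping anything the user customized.
        resolved = dict(kwargs)
        if resolved.get("top_p") == 1.0:
            del resolved["top_p"]
        if resolved.get("temperature") == 1.0:
            del resolved["temperature"]
        # If both survive, both were customized, which the provider rejects.
        if "top_p" in resolved and "temperature" in resolved:
            raise ValueError("Set a custom value for only one of top_p / temperature.")
        return resolved

    resolve_temp_top_p_exclusive({"temperature": 1.0, "top_p": 0.9})   # {'top_p': 0.9}
    resolve_temp_top_p_exclusive({"temperature": 0.2, "top_p": 1.0})   # {'temperature': 0.2}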
def usage_from_response( self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage:
531    def usage_from_response(self, response: ModelResponse) -> Usage:
532        litellm_usage = response.get("usage", None)
533
534        # LiteLLM isn't consistent in how it returns the cost.
535        cost = response._hidden_params.get("response_cost", None)
536        if cost is None and litellm_usage:
537            cost = litellm_usage.get("cost", None)
538
539        usage = Usage()
540
541        if not litellm_usage and not cost:
542            return usage
543
544        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
545            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
546            usage.output_tokens = litellm_usage.get("completion_tokens", None)
547            usage.total_tokens = litellm_usage.get("total_tokens", None)
548        else:
549            logger.warning(
550                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
551            )
552
553        if isinstance(cost, float):
554            usage.cost = cost
555        elif cost is not None:
556            # None is allowed, but no other types are expected
557            logger.warning(
558                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
559            )
560
561        return usage
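A hypothetical helper showing the same defensive reads on a LiteLLM ModelResponse; it only touches fields the method above already uses, and the dict it returns is purely for illustration:

    from litellm.types.utils import ModelResponse


    def summarize_usage(response: ModelResponse) -> dict:
        litellm_usage = response.get("usage", None)
        cost = response._hidden_params.get("response_cost", None)
        return {
            "input_tokens": litellm_usage.get("prompt_tokens", None) if litellm_usage else None,
            "output_tokens": litellm_usage.get("completion_tokens", None) if litellm_usage else None,
            "total_tokens": litellm_usage.get("total_tokens", None) if litellm_usage else None,
            "cost": cost if isinstance(cost, float) else None,
        }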
async def cached_available_tools(self) -> list[kiln_ai.tools.KilnToolInterface]:
563    async def cached_available_tools(self) -> list[KilnToolInterface]:
564        if self._cached_available_tools is None:
565            self._cached_available_tools = await self.available_tools()
566        return self._cached_available_tools
async def litellm_tools(self) -> list[kiln_ai.tools.base_tool.ToolCallDefinition]:
568    async def litellm_tools(self) -> list[ToolCallDefinition]:
569        available_tools = await self.cached_available_tools()
570
571        # LiteLLM takes the standard OpenAI-compatible tool call format
572        return [await tool.toolcall_definition() for tool in available_tools]
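Each entry follows the standard OpenAI tool definition shape; a hand-written example for a hypothetical weather tool (not a tool that ships with Kiln):

    example_tool_definition = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }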
async def process_tool_calls( self, tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None) -> tuple[str | None, list[kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper]]:
574    async def process_tool_calls(
575        self, tool_calls: list[ChatCompletionMessageToolCall] | None
576    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
577        if tool_calls is None:
578            return None, []
579
580        assistant_output_from_toolcall: str | None = None
581        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
582
583        for tool_call in tool_calls:
584            # Kiln "task_response" tool is used for returning structured output via tool calls.
585            # Load the structured output from the tool call arguments.
586            if tool_call.function.name == "task_response":
587                assistant_output_from_toolcall = tool_call.function.arguments
588                continue
589
590            # Process normal tool calls (not the "task_response" tool)
591            tool_name = tool_call.function.name
592            tool = None
593            for tool_option in await self.cached_available_tools():
594                if await tool_option.name() == tool_name:
595                    tool = tool_option
596                    break
597            if not tool:
598                raise RuntimeError(
599                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
600                )
601
602            # Parse the arguments and validate them against the tool's schema
603            try:
604                parsed_args = json.loads(tool_call.function.arguments)
605            except json.JSONDecodeError:
606                raise RuntimeError(
607                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
608                )
609            try:
610                tool_call_definition = await tool.toolcall_definition()
611                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
612                validate_schema_with_value_error(parsed_args, json_schema)
613            except Exception as e:
614                raise RuntimeError(
615                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
616                ) from e
617
618            # Create context with the calling task's allow_saving setting
619            context = ToolCallContext(
620                allow_saving=self.base_adapter_config.allow_saving
621            )
622            result = await tool.run(context, **parsed_args)
623            if isinstance(result, KilnTaskToolResult):
624                content = result.output
625                kiln_task_tool_data = result.kiln_task_tool_data
626            else:
627                content = result
628                kiln_task_tool_data = None
629
630            tool_call_response_messages.append(
631                ChatCompletionToolMessageParamWrapper(
632                    role="tool",
633                    tool_call_id=tool_call.id,
634                    content=content,
635                    kiln_task_tool_data=kiln_task_tool_data,
636                )
637            )
638
639        if (
640            assistant_output_from_toolcall is not None
641            and len(tool_call_response_messages) > 0
642        ):
643            raise RuntimeError(
644                "Model asked for an impossible combination: the task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode, like JSON schema, where this is impossible."
645            )
646
647        return assistant_output_from_toolcall, tool_call_response_messages
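Each successful (non-task_response) call appends a tool message in the wrapper format below; the values are illustrative, and kiln_task_tool_data is only populated when the tool returns a KilnTaskToolResult:

    example_tool_response_message = {
        "role": "tool",
        "tool_call_id": "call_abc123",             # id echoed from the model's tool call
        "content": '{"temperature_c": 18}',        # tool output returned to the model
        "kiln_task_tool_data": None,               # set only for KilnTaskToolResult results
    }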
def litellm_message_to_trace_message( self, raw_message: litellm.types.utils.Message) -> kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper:
649    def litellm_message_to_trace_message(
650        self, raw_message: LiteLLMMessage
651    ) -> ChatCompletionAssistantMessageParamWrapper:
652        """
653        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
654        """
655        message: ChatCompletionAssistantMessageParamWrapper = {
656            "role": "assistant",
657        }
658        if raw_message.role != "assistant":
659            raise ValueError(
660                "Model returned a message with a role other than assistant. This is not supported."
661            )
662
663        if hasattr(raw_message, "content"):
664            message["content"] = raw_message.content
665        if hasattr(raw_message, "reasoning_content"):
666            message["reasoning_content"] = raw_message.reasoning_content
667        if hasattr(raw_message, "tool_calls"):
668            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
669            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
670            for litellm_tool_call in raw_message.tool_calls or []:
671                # Optional in the SDK for streaming responses, but should never be None at this point.
672                if litellm_tool_call.function.name is None:
673                    raise ValueError(
674                        "The model requested a tool call, without providing a function name (required)."
675                    )
676                open_ai_tool_calls.append(
677                    ChatCompletionMessageToolCallParam(
678                        id=litellm_tool_call.id,
679                        type="function",
680                        function={
681                            "name": litellm_tool_call.function.name,
682                            "arguments": litellm_tool_call.function.arguments,
683                        },
684                    )
685                )
686            if len(open_ai_tool_calls) > 0:
687                message["tool_calls"] = open_ai_tool_calls
688
689        if not message.get("content") and not message.get("tool_calls"):
690            raise ValueError(
691                "Model returned an assistant message, but no content or tool calls. This is not supported."
692            )
693
694        return message

Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
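For a turn where the model requested a single tool call, the converted trace message would look roughly like this (ids, names, and reasoning text are made up):

    example_trace_assistant_message = {
        "role": "assistant",
        "content": None,
        "reasoning_content": "Need the live weather, so call the tool first.",
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }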

def all_messages_to_trace( self, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]]:
696    def all_messages_to_trace(
697        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
698    ) -> list[ChatCompletionMessageParam]:
699        """
700        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
701        """
702        trace: list[ChatCompletionMessageParam] = []
703        for message in messages:
704            if isinstance(message, LiteLLMMessage):
705                trace.append(self.litellm_message_to_trace_message(message))
706            else:
707                trace.append(message)
708        return trace

Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
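Conceptually, the resulting trace is the plain OpenAI-style message list for the whole run, with any LiteLLM Message objects normalized; a hand-written example of a single tool-using exchange:

    example_trace = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Paris?"},
        {   # converted from a LiteLLM Message by litellm_message_to_trace_message
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_abc123", "content": '{"temperature_c": 18}'},
        {"role": "assistant", "content": "It's 18°C in Paris right now."},
    ]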