kiln_ai.adapters.model_adapters.litellm_adapter

  1import copy
  2import json
  3import logging
  4from dataclasses import dataclass
  5from typing import Any, Dict, List, Tuple, TypeAlias, Union
  6
  7import litellm
  8from litellm.types.utils import (
  9    ChatCompletionMessageToolCall,
 10    ChoiceLogprobs,
 11    Choices,
 12    ModelResponse,
 13)
 14from litellm.types.utils import (
 15    Message as LiteLLMMessage,
 16)
 17from litellm.types.utils import Usage as LiteLlmUsage
 18from openai.types.chat import (
 19    ChatCompletionToolMessageParam,
 20)
 21from openai.types.chat.chat_completion_message_tool_call_param import (
 22    ChatCompletionMessageToolCallParam,
 23)
 24
 25import kiln_ai.datamodel as datamodel
 26from kiln_ai.adapters.ml_model_list import (
 27    KilnModelProvider,
 28    ModelProviderName,
 29    StructuredOutputMode,
 30)
 31from kiln_ai.adapters.model_adapters.base_adapter import (
 32    AdapterConfig,
 33    BaseAdapter,
 34    RunOutput,
 35    Usage,
 36)
 37from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 38from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 39from kiln_ai.tools.base_tool import KilnToolInterface
 40from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 41from kiln_ai.utils.open_ai_types import (
 42    ChatCompletionAssistantMessageParamWrapper,
 43    ChatCompletionMessageParam,
 44)
 45
 46MAX_CALLS_PER_TURN = 10
 47MAX_TOOL_CALLS_PER_TURN = 30
 48
 49logger = logging.getLogger(__name__)
 50
 51ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
 52    ChatCompletionMessageParam, LiteLLMMessage
 53]
 54
 55
 56@dataclass
 57class ModelTurnResult:
 58    assistant_message: str
 59    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
 60    model_response: ModelResponse | None
 61    model_choice: Choices | None
 62    usage: Usage
 63
 64
 65class LiteLlmAdapter(BaseAdapter):
 66    def __init__(
 67        self,
 68        config: LiteLlmConfig,
 69        kiln_task: datamodel.Task,
 70        base_adapter_config: AdapterConfig | None = None,
 71    ):
 72        self.config = config
 73        self._additional_body_options = config.additional_body_options
 74        self._api_base = config.base_url
 75        self._headers = config.default_headers
 76        self._litellm_model_id: str | None = None
 77        self._cached_available_tools: list[KilnToolInterface] | None = None
 78
 79        super().__init__(
 80            task=kiln_task,
 81            run_config=config.run_config_properties,
 82            config=base_adapter_config,
 83        )
 84
 85    async def _run_model_turn(
 86        self,
 87        provider: KilnModelProvider,
 88        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
 89        top_logprobs: int | None,
 90        skip_response_format: bool,
 91    ) -> ModelTurnResult:
 92        """
 93        Call the model for a single top level turn: from user message to agent message.
 94
 95        It may handle multiple iterations of tool calls between the user and agent messages if needed.
 96        """
 97
 98        usage = Usage()
 99        messages = list(prior_messages)
100        tool_calls_count = 0
101
102        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
103            # Build completion kwargs for tool calls
104            completion_kwargs = await self.build_completion_kwargs(
105                provider,
106                # Pass a copy, as acompletion mutates objects and breaks types.
107                copy.deepcopy(messages),
108                top_logprobs,
109                skip_response_format,
110            )
111
112            # Make the completion call
113            model_response, response_choice = await self.acompletion_checking_response(
114                **completion_kwargs
115            )
116
117            # count the usage
118            usage += self.usage_from_response(model_response)
119
120            # Extract content and tool calls
121            if not hasattr(response_choice, "message"):
122                raise ValueError("Response choice has no message")
123            content = response_choice.message.content
124            tool_calls = response_choice.message.tool_calls
125            if not content and not tool_calls:
126                raise ValueError(
127                    "Model returned an assistant message, but no content or tool calls. This is not supported."
128                )
129
130            # Add message to messages, so it can be used in the next turn
131            messages.append(response_choice.message)
132
133            # Process tool calls if any
134            if tool_calls and len(tool_calls) > 0:
135                (
136                    assistant_message_from_toolcall,
137                    tool_call_messages,
138                ) = await self.process_tool_calls(tool_calls)
139
140                # Add tool call results to messages
141                messages.extend(tool_call_messages)
142
143                # If task_response tool was called, we're done
144                if assistant_message_from_toolcall is not None:
145                    return ModelTurnResult(
146                        assistant_message=assistant_message_from_toolcall,
147                        all_messages=messages,
148                        model_response=model_response,
149                        model_choice=response_choice,
150                        usage=usage,
151                    )
152
153                # If there were tool calls, increment counter and continue
154                if tool_call_messages:
155                    tool_calls_count += 1
156                    continue
157
158            # If no tool calls, return the content as final output
159            if content:
160                return ModelTurnResult(
161                    assistant_message=content,
162                    all_messages=messages,
163                    model_response=model_response,
164                    model_choice=response_choice,
165                    usage=usage,
166                )
167
168            # If we get here with no content and no tool calls, raise
169            raise RuntimeError(
170                "Model returned neither content nor tool calls. It must return at least one of these."
171            )
172
173        raise RuntimeError(
174            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
175        )
176
177    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
178        usage = Usage()
179
180        provider = self.model_provider()
181        if not provider.model_id:
182            raise ValueError("Model ID is required for OpenAI compatible models")
183
184        chat_formatter = self.build_chat_formatter(input)
185        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
186
187        prior_output: str | None = None
188        final_choice: Choices | None = None
189        turns = 0
190
191        while True:
192            turns += 1
193            if turns > MAX_CALLS_PER_TURN:
194                raise RuntimeError(
195                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
196                )
197
198            turn = chat_formatter.next_turn(prior_output)
199            if turn is None:
200                # No next turn, we're done
201                break
202
203            # Add messages from the turn to chat history
204            for message in turn.messages:
205                if message.content is None:
206                    raise ValueError("Empty message content isn't allowed")
207                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
208                messages.append({"role": message.role, "content": message.content})  # type: ignore
209
210            skip_response_format = not turn.final_call
211            turn_result = await self._run_model_turn(
212                provider,
213                messages,
214                self.base_adapter_config.top_logprobs if turn.final_call else None,
215                skip_response_format,
216            )
217
218            usage += turn_result.usage
219
220            prior_output = turn_result.assistant_message
221            messages = turn_result.all_messages
222            final_choice = turn_result.model_choice
223
224            if not prior_output:
225                raise RuntimeError("No assistant message/output returned from model")
226
227        logprobs = self._extract_and_validate_logprobs(final_choice)
228
229        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
230        intermediate_outputs = chat_formatter.intermediate_outputs()
231        self._extract_reasoning_to_intermediate_outputs(
232            final_choice, intermediate_outputs
233        )
234
235        if not isinstance(prior_output, str):
236            raise RuntimeError(f"assistant message is not a string: {prior_output}")
237
238        trace = self.all_messages_to_trace(messages)
239        output = RunOutput(
240            output=prior_output,
241            intermediate_outputs=intermediate_outputs,
242            output_logprobs=logprobs,
243            trace=trace,
244        )
245
246        return output, usage
247
248    def _extract_and_validate_logprobs(
249        self, final_choice: Choices | None
250    ) -> ChoiceLogprobs | None:
251        """
252        Extract logprobs from the final choice and validate they exist if required.
253        """
254        logprobs = None
255        if (
256            final_choice is not None
257            and hasattr(final_choice, "logprobs")
258            and isinstance(final_choice.logprobs, ChoiceLogprobs)
259        ):
260            logprobs = final_choice.logprobs
261
262        # Check logprobs worked, if required
263        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
264            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
265
266        return logprobs
267
268    def _extract_reasoning_to_intermediate_outputs(
269        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
270    ) -> None:
271        """Extract reasoning content from model choice and add to intermediate outputs if present."""
272        if (
273            final_choice is not None
274            and hasattr(final_choice, "message")
275            and hasattr(final_choice.message, "reasoning_content")
276        ):
277            reasoning_content = final_choice.message.reasoning_content
278            if reasoning_content is not None:
279                stripped_reasoning_content = reasoning_content.strip()
280                if len(stripped_reasoning_content) > 0:
281                    intermediate_outputs["reasoning"] = stripped_reasoning_content
282
283    async def acompletion_checking_response(
284        self, **kwargs
285    ) -> Tuple[ModelResponse, Choices]:
286        response = await litellm.acompletion(**kwargs)
287        if (
288            not isinstance(response, ModelResponse)
289            or not response.choices
290            or len(response.choices) == 0
291            or not isinstance(response.choices[0], Choices)
292        ):
293            raise RuntimeError(
294                f"Expected ModelResponse with Choices, got {type(response)}."
295            )
296        return response, response.choices[0]
297
298    def adapter_name(self) -> str:
299        return "kiln_openai_compatible_adapter"
300
301    async def response_format_options(self) -> dict[str, Any]:
302        # Unstructured if task isn't structured
303        if not self.has_structured_output():
304            return {}
305
306        structured_output_mode = self.run_config.structured_output_mode
307
308        match structured_output_mode:
309            case StructuredOutputMode.json_mode:
310                return {"response_format": {"type": "json_object"}}
311            case StructuredOutputMode.json_schema:
312                return self.json_schema_response_format()
313            case StructuredOutputMode.function_calling_weak:
314                return self.tool_call_params(strict=False)
315            case StructuredOutputMode.function_calling:
316                return self.tool_call_params(strict=True)
317            case StructuredOutputMode.json_instructions:
318                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
319                return {}
320            case StructuredOutputMode.json_custom_instructions:
321                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
322                return {}
323            case StructuredOutputMode.json_instruction_and_object:
324                # We set response_format to json_object and also set json instructions in the prompt
325                return {"response_format": {"type": "json_object"}}
326            case StructuredOutputMode.default:
327                provider_name = self.run_config.model_provider_name
328                if provider_name == ModelProviderName.ollama:
329                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
330                    return self.json_schema_response_format()
331                elif provider_name == ModelProviderName.docker_model_runner:
332                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
333                    return self.json_schema_response_format()
334                else:
335                    # Default to function calling -- it's older than the other modes. Higher compatibility.
336                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
337                    strict = provider_name == ModelProviderName.openai
338                    return self.tool_call_params(strict=strict)
339            case StructuredOutputMode.unknown:
340                # See above, but this case should never happen.
341                raise ValueError("Structured output mode is unknown.")
342            case _:
343                raise_exhaustive_enum_error(structured_output_mode)
344
345    def json_schema_response_format(self) -> dict[str, Any]:
346        output_schema = self.task.output_schema()
347        return {
348            "response_format": {
349                "type": "json_schema",
350                "json_schema": {
351                    "name": "task_response",
352                    "schema": output_schema,
353                },
354            }
355        }
356
357    def tool_call_params(self, strict: bool) -> dict[str, Any]:
358        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
359        output_schema = self.task.output_schema()
360        if not isinstance(output_schema, dict):
361            raise ValueError(
362                "Invalid output schema for this task. Can not use tool calls."
363            )
364        output_schema["additionalProperties"] = False
365
366        function_params = {
367            "name": "task_response",
368            "parameters": output_schema,
369        }
370        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
371        if strict:
372            function_params["strict"] = True
373
374        return {
375            "tools": [
376                {
377                    "type": "function",
378                    "function": function_params,
379                }
380            ],
381            "tool_choice": {
382                "type": "function",
383                "function": {"name": "task_response"},
384            },
385        }
386
387    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
388        # Don't love having this logic here. But it's worth the usability improvement
389        # so better to keep it than exclude it. Should figure out how I want to isolate
390        # this sort of logic so it's config driven and can be overridden
391
392        extra_body = {}
393        provider_options = {}
394
395        if provider.thinking_level is not None:
396            extra_body["reasoning_effort"] = provider.thinking_level
397
398        if provider.require_openrouter_reasoning:
399            # https://openrouter.ai/docs/use-cases/reasoning-tokens
400            extra_body["reasoning"] = {
401                "exclude": False,
402            }
403
404        if provider.gemini_reasoning_enabled:
405            extra_body["reasoning"] = {
406                "enabled": True,
407            }
408
409        if provider.name == ModelProviderName.openrouter:
410            # Ask OpenRouter to include usage in the response (cost)
411            extra_body["usage"] = {"include": True}
412
413        if provider.anthropic_extended_thinking:
414            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
415
416        if provider.r1_openrouter_options:
417            # Require providers that support the reasoning parameter
418            provider_options["require_parameters"] = True
419            # Prefer R1 providers with reasonable perf/quants
420            provider_options["order"] = ["Fireworks", "Together"]
421            # R1 providers with unreasonable quants
422            provider_options["ignore"] = ["DeepInfra"]
423
424        # Only set this if the goal of this request is to get logprobs.
425        if (
426            provider.logprobs_openrouter_options
427            and self.base_adapter_config.top_logprobs is not None
428        ):
429            # Don't let OpenRouter choose a provider that doesn't support logprobs.
430            provider_options["require_parameters"] = True
431            # DeepInfra silently fails to return logprobs consistently.
432            provider_options["ignore"] = ["DeepInfra"]
433
434        if provider.openrouter_skip_required_parameters:
435            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
436            provider_options["require_parameters"] = False
437
438        # Siliconflow uses a bool flag for thinking, for some models
439        if provider.siliconflow_enable_thinking is not None:
440            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
441
442        if len(provider_options) > 0:
443            extra_body["provider"] = provider_options
444
445        return extra_body
446
447    def litellm_model_id(self) -> str:
448        # The model ID is an interesting combination of format and url endpoint.
449        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL
450
451        if self._litellm_model_id:
452            return self._litellm_model_id
453
454        provider = self.model_provider()
455        if not provider.model_id:
456            raise ValueError("Model ID is required for OpenAI compatible models")
457
458        litellm_provider_name: str | None = None
459        is_custom = False
460        match provider.name:
461            case ModelProviderName.openrouter:
462                litellm_provider_name = "openrouter"
463            case ModelProviderName.openai:
464                litellm_provider_name = "openai"
465            case ModelProviderName.groq:
466                litellm_provider_name = "groq"
467            case ModelProviderName.anthropic:
468                litellm_provider_name = "anthropic"
469            case ModelProviderName.ollama:
470                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
471                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
472                is_custom = True
473            case ModelProviderName.docker_model_runner:
474                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
475                # We want direct control over the requests for features like response_format=json_schema
476                is_custom = True
477            case ModelProviderName.gemini_api:
478                litellm_provider_name = "gemini"
479            case ModelProviderName.fireworks_ai:
480                litellm_provider_name = "fireworks_ai"
481            case ModelProviderName.amazon_bedrock:
482                litellm_provider_name = "bedrock"
483            case ModelProviderName.azure_openai:
484                litellm_provider_name = "azure"
485            case ModelProviderName.huggingface:
486                litellm_provider_name = "huggingface"
487            case ModelProviderName.vertex:
488                litellm_provider_name = "vertex_ai"
489            case ModelProviderName.together_ai:
490                litellm_provider_name = "together_ai"
491            case ModelProviderName.cerebras:
492                litellm_provider_name = "cerebras"
493            case ModelProviderName.siliconflow_cn:
494                is_custom = True
495            case ModelProviderName.openai_compatible:
496                is_custom = True
497            case ModelProviderName.kiln_custom_registry:
498                is_custom = True
499            case ModelProviderName.kiln_fine_tune:
500                is_custom = True
501            case _:
502                raise_exhaustive_enum_error(provider.name)
503
504        if is_custom:
505            if self._api_base is None:
506                raise ValueError(
507                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
508                )
509            # Use openai as it's only used for format, not url
510            litellm_provider_name = "openai"
511
512        # Shouldn't be possible, but keep the type checker happy
513        if litellm_provider_name is None:
514            raise ValueError(
515                f"Could not map the provider name to a valid litellm provider ID for model {provider.model_id}"
516            )
517
518        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
519        return self._litellm_model_id
520
521    async def build_completion_kwargs(
522        self,
523        provider: KilnModelProvider,
524        messages: list[ChatCompletionMessageIncludingLiteLLM],
525        top_logprobs: int | None,
526        skip_response_format: bool = False,
527    ) -> dict[str, Any]:
528        extra_body = self.build_extra_body(provider)
529
530        # Merge all parameters into a single kwargs dict for litellm
531        completion_kwargs = {
532            "model": self.litellm_model_id(),
533            "messages": messages,
534            "api_base": self._api_base,
535            "headers": self._headers,
536            "temperature": self.run_config.temperature,
537            "top_p": self.run_config.top_p,
538            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
539            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
540            # Better to ignore them than to fail the model call.
541            # https://docs.litellm.ai/docs/completion/input
542            "drop_params": True,
543            **extra_body,
544            **self._additional_body_options,
545        }
546
547        tool_calls = await self.litellm_tools()
548        has_tools = len(tool_calls) > 0
549        if has_tools:
550            completion_kwargs["tools"] = tool_calls
551            completion_kwargs["tool_choice"] = "auto"
552
553        if not skip_response_format:
554            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
555            response_format_options = await self.response_format_options()
556
557            # Check for a conflict between tools and response format using tools
558            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good at tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
559            if has_tools and "tools" in response_format_options:
560                raise ValueError(
561                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
562                )
563
564            completion_kwargs.update(response_format_options)
565
566        if top_logprobs is not None:
567            completion_kwargs["logprobs"] = True
568            completion_kwargs["top_logprobs"] = top_logprobs
569
570        return completion_kwargs
571
572    def usage_from_response(self, response: ModelResponse) -> Usage:
573        litellm_usage = response.get("usage", None)
574
575        # LiteLLM isn't consistent in how it returns the cost.
576        cost = response._hidden_params.get("response_cost", None)
577        if cost is None and litellm_usage:
578            cost = litellm_usage.get("cost", None)
579
580        usage = Usage()
581
582        if not litellm_usage and not cost:
583            return usage
584
585        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
586            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
587            usage.output_tokens = litellm_usage.get("completion_tokens", None)
588            usage.total_tokens = litellm_usage.get("total_tokens", None)
589        else:
590            logger.warning(
591                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
592            )
593
594        if isinstance(cost, float):
595            usage.cost = cost
596        elif cost is not None:
597            # None is allowed, but no other types are expected
598            logger.warning(
599                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
600            )
601
602        return usage
603
604    async def cached_available_tools(self) -> list[KilnToolInterface]:
605        if self._cached_available_tools is None:
606            self._cached_available_tools = await self.available_tools()
607        return self._cached_available_tools
608
609    async def litellm_tools(self) -> list[Dict]:
610        available_tools = await self.cached_available_tools()
611
612        # LiteLLM takes the standard OpenAI-compatible tool call format
613        return [await tool.toolcall_definition() for tool in available_tools]
614
615    async def process_tool_calls(
616        self, tool_calls: list[ChatCompletionMessageToolCall] | None
617    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
618        if tool_calls is None:
619            return None, []
620
621        assistant_output_from_toolcall: str | None = None
622        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
623
624        for tool_call in tool_calls:
625            # Kiln "task_response" tool is used for returning structured output via tool calls.
626            # Load the output from the tool call and continue to any remaining tool calls.
627            if tool_call.function.name == "task_response":
628                assistant_output_from_toolcall = tool_call.function.arguments
629                continue
630
631            # Process normal tool calls (not the "task_response" tool)
632            tool_name = tool_call.function.name
633            tool = None
634            for tool_option in await self.cached_available_tools():
635                if await tool_option.name() == tool_name:
636                    tool = tool_option
637                    break
638            if not tool:
639                raise RuntimeError(
640                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
641                )
642
643            # Parse the arguments and validate them against the tool's schema
644            try:
645                parsed_args = json.loads(tool_call.function.arguments)
646            except json.JSONDecodeError:
647                raise RuntimeError(
648                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
649                )
650            try:
651                tool_call_definition = await tool.toolcall_definition()
652                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
653                validate_schema_with_value_error(parsed_args, json_schema)
654            except Exception as e:
655                raise RuntimeError(
656                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
657                ) from e
658
659            result = await tool.run(**parsed_args)
660
661            tool_call_response_messages.append(
662                ChatCompletionToolMessageParam(
663                    role="tool",
664                    tool_call_id=tool_call.id,
665                    content=result,
666                )
667            )
668
669        if (
670            assistant_output_from_toolcall is not None
671            and len(tool_call_response_messages) > 0
672        ):
673            raise RuntimeError(
674                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
675            )
676
677        return assistant_output_from_toolcall, tool_call_response_messages
678
679    def litellm_message_to_trace_message(
680        self, raw_message: LiteLLMMessage
681    ) -> ChatCompletionAssistantMessageParamWrapper:
682        """
683        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
684        """
685        message: ChatCompletionAssistantMessageParamWrapper = {
686            "role": "assistant",
687        }
688        if raw_message.role != "assistant":
689            raise ValueError(
690                "Model returned a message with a role other than assistant. This is not supported."
691            )
692
693        if hasattr(raw_message, "content"):
694            message["content"] = raw_message.content
695        if hasattr(raw_message, "reasoning_content"):
696            message["reasoning_content"] = raw_message.reasoning_content
697        if hasattr(raw_message, "tool_calls"):
698            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
699            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
700            for litellm_tool_call in raw_message.tool_calls or []:
701                # Optional in the SDK for streaming responses, but should never be None at this point.
702                if litellm_tool_call.function.name is None:
703                    raise ValueError(
704                        "The model requested a tool call, without providing a function name (required)."
705                    )
706                open_ai_tool_calls.append(
707                    ChatCompletionMessageToolCallParam(
708                        id=litellm_tool_call.id,
709                        type="function",
710                        function={
711                            "name": litellm_tool_call.function.name,
712                            "arguments": litellm_tool_call.function.arguments,
713                        },
714                    )
715                )
716            if len(open_ai_tool_calls) > 0:
717                message["tool_calls"] = open_ai_tool_calls
718
719        if not message.get("content") and not message.get("tool_calls"):
720            raise ValueError(
721                "Model returned an assistant message, but no content or tool calls. This is not supported."
722            )
723
724        return message
725
726    def all_messages_to_trace(
727        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
728    ) -> list[ChatCompletionMessageParam]:
729        """
730        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
731        """
732        trace: list[ChatCompletionMessageParam] = []
733        for message in messages:
734            if isinstance(message, LiteLLMMessage):
735                trace.append(self.litellm_message_to_trace_message(message))
736            else:
737                trace.append(message)
738        return trace
MAX_CALLS_PER_TURN = 10

MAX_TOOL_CALLS_PER_TURN = 30

ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ChatCompletionMessageParam, LiteLLMMessage]

@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage
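For orientation, here is a minimal sketch of the request shapes the adapter assembles for litellm.acompletion: the "provider/model" id produced by litellm_model_id(), plus the two main structured-output payloads built by json_schema_response_format() and tool_call_params(strict=True). The model name and output schema below are hypothetical placeholders, not values from this module.

# litellm_model_id() composes "<litellm_provider_name>/<provider.model_id>".
model_id = "openrouter" + "/" + "some-lab/some-model"  # hypothetical model_id

# json_schema_response_format(): ask for JSON matching the task's output schema.
output_schema = {  # hypothetical task output schema
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
json_schema_format = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": output_schema},
    }
}

# tool_call_params(strict=True): force a single "task_response" function call instead.
output_schema["additionalProperties"] = False
function_calling_format = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": output_schema,
                "strict": True,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}

Only one of these payloads is merged into the completion kwargs for a given call; build_completion_kwargs() raises if a forced task_response tool would collide with user-provided tools.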
 66class LiteLlmAdapter(BaseAdapter):
 67    def __init__(
 68        self,
 69        config: LiteLlmConfig,
 70        kiln_task: datamodel.Task,
 71        base_adapter_config: AdapterConfig | None = None,
 72    ):
 73        self.config = config
 74        self._additional_body_options = config.additional_body_options
 75        self._api_base = config.base_url
 76        self._headers = config.default_headers
 77        self._litellm_model_id: str | None = None
 78        self._cached_available_tools: list[KilnToolInterface] | None = None
 79
 80        super().__init__(
 81            task=kiln_task,
 82            run_config=config.run_config_properties,
 83            config=base_adapter_config,
 84        )
 85
 86    async def _run_model_turn(
 87        self,
 88        provider: KilnModelProvider,
 89        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
 90        top_logprobs: int | None,
 91        skip_response_format: bool,
 92    ) -> ModelTurnResult:
 93        """
 94        Call the model for a single top level turn: from user message to agent message.
 95
 96        It may make handle iterations of tool calls between the user/agent message if needed.
 97        """
 98
 99        usage = Usage()
100        messages = list(prior_messages)
101        tool_calls_count = 0
102
103        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
104            # Build completion kwargs for tool calls
105            completion_kwargs = await self.build_completion_kwargs(
106                provider,
107                # Pass a copy, as acompletion mutates objects and breaks types.
108                copy.deepcopy(messages),
109                top_logprobs,
110                skip_response_format,
111            )
112
113            # Make the completion call
114            model_response, response_choice = await self.acompletion_checking_response(
115                **completion_kwargs
116            )
117
118            # count the usage
119            usage += self.usage_from_response(model_response)
120
121            # Extract content and tool calls
122            if not hasattr(response_choice, "message"):
123                raise ValueError("Response choice has no message")
124            content = response_choice.message.content
125            tool_calls = response_choice.message.tool_calls
126            if not content and not tool_calls:
127                raise ValueError(
128                    "Model returned an assistant message, but no content or tool calls. This is not supported."
129                )
130
131            # Add message to messages, so it can be used in the next turn
132            messages.append(response_choice.message)
133
134            # Process tool calls if any
135            if tool_calls and len(tool_calls) > 0:
136                (
137                    assistant_message_from_toolcall,
138                    tool_call_messages,
139                ) = await self.process_tool_calls(tool_calls)
140
141                # Add tool call results to messages
142                messages.extend(tool_call_messages)
143
144                # If task_response tool was called, we're done
145                if assistant_message_from_toolcall is not None:
146                    return ModelTurnResult(
147                        assistant_message=assistant_message_from_toolcall,
148                        all_messages=messages,
149                        model_response=model_response,
150                        model_choice=response_choice,
151                        usage=usage,
152                    )
153
154                # If there were tool calls, increment counter and continue
155                if tool_call_messages:
156                    tool_calls_count += 1
157                    continue
158
159            # If no tool calls, return the content as final output
160            if content:
161                return ModelTurnResult(
162                    assistant_message=content,
163                    all_messages=messages,
164                    model_response=model_response,
165                    model_choice=response_choice,
166                    usage=usage,
167                )
168
169            # If we get here with no content and no tool calls, break
170            raise RuntimeError(
171                "Model returned neither content nor tool calls. It must return at least one of these."
172            )
173
174        raise RuntimeError(
175            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
176        )
177
178    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
179        usage = Usage()
180
181        provider = self.model_provider()
182        if not provider.model_id:
183            raise ValueError("Model ID is required for OpenAI compatible models")
184
185        chat_formatter = self.build_chat_formatter(input)
186        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
187
188        prior_output: str | None = None
189        final_choice: Choices | None = None
190        turns = 0
191
192        while True:
193            turns += 1
194            if turns > MAX_CALLS_PER_TURN:
195                raise RuntimeError(
196                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
197                )
198
199            turn = chat_formatter.next_turn(prior_output)
200            if turn is None:
201                # No next turn, we're done
202                break
203
204            # Add messages from the turn to chat history
205            for message in turn.messages:
206                if message.content is None:
207                    raise ValueError("Empty message content isn't allowed")
208                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
209                messages.append({"role": message.role, "content": message.content})  # type: ignore
210
211            skip_response_format = not turn.final_call
212            turn_result = await self._run_model_turn(
213                provider,
214                messages,
215                self.base_adapter_config.top_logprobs if turn.final_call else None,
216                skip_response_format,
217            )
218
219            usage += turn_result.usage
220
221            prior_output = turn_result.assistant_message
222            messages = turn_result.all_messages
223            final_choice = turn_result.model_choice
224
225            if not prior_output:
226                raise RuntimeError("No assistant message/output returned from model")
227
228        logprobs = self._extract_and_validate_logprobs(final_choice)
229
230        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
231        intermediate_outputs = chat_formatter.intermediate_outputs()
232        self._extract_reasoning_to_intermediate_outputs(
233            final_choice, intermediate_outputs
234        )
235
236        if not isinstance(prior_output, str):
237            raise RuntimeError(f"assistant message is not a string: {prior_output}")
238
239        trace = self.all_messages_to_trace(messages)
240        output = RunOutput(
241            output=prior_output,
242            intermediate_outputs=intermediate_outputs,
243            output_logprobs=logprobs,
244            trace=trace,
245        )
246
247        return output, usage
248
249    def _extract_and_validate_logprobs(
250        self, final_choice: Choices | None
251    ) -> ChoiceLogprobs | None:
252        """
253        Extract logprobs from the final choice and validate they exist if required.
254        """
255        logprobs = None
256        if (
257            final_choice is not None
258            and hasattr(final_choice, "logprobs")
259            and isinstance(final_choice.logprobs, ChoiceLogprobs)
260        ):
261            logprobs = final_choice.logprobs
262
263        # Check logprobs worked, if required
264        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
265            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
266
267        return logprobs
268
269    def _extract_reasoning_to_intermediate_outputs(
270        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
271    ) -> None:
272        """Extract reasoning content from model choice and add to intermediate outputs if present."""
273        if (
274            final_choice is not None
275            and hasattr(final_choice, "message")
276            and hasattr(final_choice.message, "reasoning_content")
277        ):
278            reasoning_content = final_choice.message.reasoning_content
279            if reasoning_content is not None:
280                stripped_reasoning_content = reasoning_content.strip()
281                if len(stripped_reasoning_content) > 0:
282                    intermediate_outputs["reasoning"] = stripped_reasoning_content
283
284    async def acompletion_checking_response(
285        self, **kwargs
286    ) -> Tuple[ModelResponse, Choices]:
287        response = await litellm.acompletion(**kwargs)
288        if (
289            not isinstance(response, ModelResponse)
290            or not response.choices
291            or len(response.choices) == 0
292            or not isinstance(response.choices[0], Choices)
293        ):
294            raise RuntimeError(
295                f"Expected ModelResponse with Choices, got {type(response)}."
296            )
297        return response, response.choices[0]
298
299    def adapter_name(self) -> str:
300        return "kiln_openai_compatible_adapter"
301
302    async def response_format_options(self) -> dict[str, Any]:
303        # Unstructured if task isn't structured
304        if not self.has_structured_output():
305            return {}
306
307        structured_output_mode = self.run_config.structured_output_mode
308
309        match structured_output_mode:
310            case StructuredOutputMode.json_mode:
311                return {"response_format": {"type": "json_object"}}
312            case StructuredOutputMode.json_schema:
313                return self.json_schema_response_format()
314            case StructuredOutputMode.function_calling_weak:
315                return self.tool_call_params(strict=False)
316            case StructuredOutputMode.function_calling:
317                return self.tool_call_params(strict=True)
318            case StructuredOutputMode.json_instructions:
319                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
320                return {}
321            case StructuredOutputMode.json_custom_instructions:
322                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
323                return {}
324            case StructuredOutputMode.json_instruction_and_object:
325                # We set response_format to json_object and also set json instructions in the prompt
326                return {"response_format": {"type": "json_object"}}
327            case StructuredOutputMode.default:
328                provider_name = self.run_config.model_provider_name
329                if provider_name == ModelProviderName.ollama:
330                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
331                    return self.json_schema_response_format()
332                elif provider_name == ModelProviderName.docker_model_runner:
333                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
334                    return self.json_schema_response_format()
335                else:
336                    # Default to function calling -- it's older than the other modes. Higher compatibility.
337                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
338                    strict = provider_name == ModelProviderName.openai
339                    return self.tool_call_params(strict=strict)
340            case StructuredOutputMode.unknown:
341                # See above, but this case should never happen.
342                raise ValueError("Structured output mode is unknown.")
343            case _:
344                raise_exhaustive_enum_error(structured_output_mode)
345
346    def json_schema_response_format(self) -> dict[str, Any]:
347        output_schema = self.task.output_schema()
348        return {
349            "response_format": {
350                "type": "json_schema",
351                "json_schema": {
352                    "name": "task_response",
353                    "schema": output_schema,
354                },
355            }
356        }
357
358    def tool_call_params(self, strict: bool) -> dict[str, Any]:
359        # Add additional_properties: false to the schema (OpenAI requires this for some models)
360        output_schema = self.task.output_schema()
361        if not isinstance(output_schema, dict):
362            raise ValueError(
363                "Invalid output schema for this task. Can not use tool calls."
364            )
365        output_schema["additionalProperties"] = False
366
367        function_params = {
368            "name": "task_response",
369            "parameters": output_schema,
370        }
371        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
372        if strict:
373            function_params["strict"] = True
374
375        return {
376            "tools": [
377                {
378                    "type": "function",
379                    "function": function_params,
380                }
381            ],
382            "tool_choice": {
383                "type": "function",
384                "function": {"name": "task_response"},
385            },
386        }
387
388    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
389        # Don't love having this logic here. But it's worth the usability improvement
390        # so better to keep it than exclude it. Should figure out how I want to isolate
391        # this sort of logic so it's config driven and can be overridden
392
393        extra_body = {}
394        provider_options = {}
395
396        if provider.thinking_level is not None:
397            extra_body["reasoning_effort"] = provider.thinking_level
398
399        if provider.require_openrouter_reasoning:
400            # https://openrouter.ai/docs/use-cases/reasoning-tokens
401            extra_body["reasoning"] = {
402                "exclude": False,
403            }
404
405        if provider.gemini_reasoning_enabled:
406            extra_body["reasoning"] = {
407                "enabled": True,
408            }
409
410        if provider.name == ModelProviderName.openrouter:
411            # Ask OpenRouter to include usage in the response (cost)
412            extra_body["usage"] = {"include": True}
413
414        if provider.anthropic_extended_thinking:
415            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
416
417        if provider.r1_openrouter_options:
418            # Require providers that support the reasoning parameter
419            provider_options["require_parameters"] = True
420            # Prefer R1 providers with reasonable perf/quants
421            provider_options["order"] = ["Fireworks", "Together"]
422            # R1 providers with unreasonable quants
423            provider_options["ignore"] = ["DeepInfra"]
424
425        # Only set of this request is to get logprobs.
426        if (
427            provider.logprobs_openrouter_options
428            and self.base_adapter_config.top_logprobs is not None
429        ):
430            # Don't let OpenRouter choose a provider that doesn't support logprobs.
431            provider_options["require_parameters"] = True
432            # DeepInfra silently fails to return logprobs consistently.
433            provider_options["ignore"] = ["DeepInfra"]
434
435        if provider.openrouter_skip_required_parameters:
436            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
437            provider_options["require_parameters"] = False
438
439        # Siliconflow uses a bool flag for thinking, for some models
440        if provider.siliconflow_enable_thinking is not None:
441            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
442
443        if len(provider_options) > 0:
444            extra_body["provider"] = provider_options
445
446        return extra_body
447
448    def litellm_model_id(self) -> str:
449        # The model ID is an interesting combination of format and url endpoint.
450        # It specifics the provider URL/host, but this is overridden if you manually set an api url
451
452        if self._litellm_model_id:
453            return self._litellm_model_id
454
455        provider = self.model_provider()
456        if not provider.model_id:
457            raise ValueError("Model ID is required for OpenAI compatible models")
458
459        litellm_provider_name: str | None = None
460        is_custom = False
461        match provider.name:
462            case ModelProviderName.openrouter:
463                litellm_provider_name = "openrouter"
464            case ModelProviderName.openai:
465                litellm_provider_name = "openai"
466            case ModelProviderName.groq:
467                litellm_provider_name = "groq"
468            case ModelProviderName.anthropic:
469                litellm_provider_name = "anthropic"
470            case ModelProviderName.ollama:
471                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
472                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
473                is_custom = True
474            case ModelProviderName.docker_model_runner:
475                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
476                # We want direct control over the requests for features like response_format=json_schema
477                is_custom = True
478            case ModelProviderName.gemini_api:
479                litellm_provider_name = "gemini"
480            case ModelProviderName.fireworks_ai:
481                litellm_provider_name = "fireworks_ai"
482            case ModelProviderName.amazon_bedrock:
483                litellm_provider_name = "bedrock"
484            case ModelProviderName.azure_openai:
485                litellm_provider_name = "azure"
486            case ModelProviderName.huggingface:
487                litellm_provider_name = "huggingface"
488            case ModelProviderName.vertex:
489                litellm_provider_name = "vertex_ai"
490            case ModelProviderName.together_ai:
491                litellm_provider_name = "together_ai"
492            case ModelProviderName.cerebras:
493                litellm_provider_name = "cerebras"
494            case ModelProviderName.siliconflow_cn:
495                is_custom = True
496            case ModelProviderName.openai_compatible:
497                is_custom = True
498            case ModelProviderName.kiln_custom_registry:
499                is_custom = True
500            case ModelProviderName.kiln_fine_tune:
501                is_custom = True
502            case _:
503                raise_exhaustive_enum_error(provider.name)
504
505        if is_custom:
506            if self._api_base is None:
507                raise ValueError(
508                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
509                )
510            # Use openai as it's only used for format, not url
511            litellm_provider_name = "openai"
512
513        # Sholdn't be possible but keep type checker happy
514        if litellm_provider_name is None:
515            raise ValueError(
516                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
517            )
518
519        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
520        return self._litellm_model_id
521
522    async def build_completion_kwargs(
523        self,
524        provider: KilnModelProvider,
525        messages: list[ChatCompletionMessageIncludingLiteLLM],
526        top_logprobs: int | None,
527        skip_response_format: bool = False,
528    ) -> dict[str, Any]:
529        extra_body = self.build_extra_body(provider)
530
531        # Merge all parameters into a single kwargs dict for litellm
532        completion_kwargs = {
533            "model": self.litellm_model_id(),
534            "messages": messages,
535            "api_base": self._api_base,
536            "headers": self._headers,
537            "temperature": self.run_config.temperature,
538            "top_p": self.run_config.top_p,
539            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
540            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
541            # Better to ignore them than to fail the model call.
542            # https://docs.litellm.ai/docs/completion/input
543            "drop_params": True,
544            **extra_body,
545            **self._additional_body_options,
546        }
547
548        tool_calls = await self.litellm_tools()
549        has_tools = len(tool_calls) > 0
550        if has_tools:
551            completion_kwargs["tools"] = tool_calls
552            completion_kwargs["tool_choice"] = "auto"
553
554        if not skip_response_format:
555            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
556            response_format_options = await self.response_format_options()
557
558            # Check for a conflict between tools and response format using tools
559            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good at tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
560            if has_tools and "tools" in response_format_options:
561                raise ValueError(
562                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
563                )
564
565            completion_kwargs.update(response_format_options)
566
567        if top_logprobs is not None:
568            completion_kwargs["logprobs"] = True
569            completion_kwargs["top_logprobs"] = top_logprobs
570
571        return completion_kwargs
572
573    def usage_from_response(self, response: ModelResponse) -> Usage:
574        litellm_usage = response.get("usage", None)
575
576        # LiteLLM isn't consistent in how it returns the cost.
577        cost = response._hidden_params.get("response_cost", None)
578        if cost is None and litellm_usage:
579            cost = litellm_usage.get("cost", None)
580
581        usage = Usage()
582
583        if not litellm_usage and not cost:
584            return usage
585
586        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
587            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
588            usage.output_tokens = litellm_usage.get("completion_tokens", None)
589            usage.total_tokens = litellm_usage.get("total_tokens", None)
590        else:
591            logger.warning(
592                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
593            )
594
595        if isinstance(cost, float):
596            usage.cost = cost
597        elif cost is not None:
598            # None is allowed, but no other types are expected
599            logger.warning(
600                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
601            )
602
603        return usage
604
605    async def cached_available_tools(self) -> list[KilnToolInterface]:
606        if self._cached_available_tools is None:
607            self._cached_available_tools = await self.available_tools()
608        return self._cached_available_tools
609
610    async def litellm_tools(self) -> list[Dict]:
611        available_tools = await self.cached_available_tools()
612
613        # LiteLLM takes the standard OpenAI-compatible tool call format
614        return [await tool.toolcall_definition() for tool in available_tools]
615
616    async def process_tool_calls(
617        self, tool_calls: list[ChatCompletionMessageToolCall] | None
618    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
619        if tool_calls is None:
620            return None, []
621
622        assistant_output_from_toolcall: str | None = None
623        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
624
625        for tool_call in tool_calls:
626            # Kiln "task_response" tool is used for returning structured output via tool calls.
627            # Load the output from the tool call and skip normal tool handling for it.
628            if tool_call.function.name == "task_response":
629                assistant_output_from_toolcall = tool_call.function.arguments
630                continue
631
632            # Process normal tool calls (not the "task_response" tool)
633            tool_name = tool_call.function.name
634            tool = None
635            for tool_option in await self.cached_available_tools():
636                if await tool_option.name() == tool_name:
637                    tool = tool_option
638                    break
639            if not tool:
640                raise RuntimeError(
641                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
642                )
643
644            # Parse the arguments and validate them against the tool's schema
645            try:
646                parsed_args = json.loads(tool_call.function.arguments)
647            except json.JSONDecodeError:
648                raise RuntimeError(
649                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
650                )
651            try:
652                tool_call_definition = await tool.toolcall_definition()
653                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
654                validate_schema_with_value_error(parsed_args, json_schema)
655            except Exception as e:
656                raise RuntimeError(
657                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
658                ) from e
659
660            result = await tool.run(**parsed_args)
661
662            tool_call_response_messages.append(
663                ChatCompletionToolMessageParam(
664                    role="tool",
665                    tool_call_id=tool_call.id,
666                    content=result,
667                )
668            )
669
670        if (
671            assistant_output_from_toolcall is not None
672            and len(tool_call_response_messages) > 0
673        ):
674            raise RuntimeError(
675                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
676            )
677
678        return assistant_output_from_toolcall, tool_call_response_messages
679
680    def litellm_message_to_trace_message(
681        self, raw_message: LiteLLMMessage
682    ) -> ChatCompletionAssistantMessageParamWrapper:
683        """
684        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
685        """
686        message: ChatCompletionAssistantMessageParamWrapper = {
687            "role": "assistant",
688        }
689        if raw_message.role != "assistant":
690            raise ValueError(
691                "Model returned a message with a role other than assistant. This is not supported."
692            )
693
694        if hasattr(raw_message, "content"):
695            message["content"] = raw_message.content
696        if hasattr(raw_message, "reasoning_content"):
697            message["reasoning_content"] = raw_message.reasoning_content
698        if hasattr(raw_message, "tool_calls"):
699            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
700            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
701            for litellm_tool_call in raw_message.tool_calls or []:
702                # Optional in the SDK for streaming responses, but should never be None at this point.
703                if litellm_tool_call.function.name is None:
704                    raise ValueError(
705                        "The model requested a tool call, without providing a function name (required)."
706                    )
707                open_ai_tool_calls.append(
708                    ChatCompletionMessageToolCallParam(
709                        id=litellm_tool_call.id,
710                        type="function",
711                        function={
712                            "name": litellm_tool_call.function.name,
713                            "arguments": litellm_tool_call.function.arguments,
714                        },
715                    )
716                )
717            if len(open_ai_tool_calls) > 0:
718                message["tool_calls"] = open_ai_tool_calls
719
720        if not message.get("content") and not message.get("tool_calls"):
721            raise ValueError(
722                "Model returned an assistant message, but no content or tool calls. This is not supported."
723            )
724
725        return message
726
727    def all_messages_to_trace(
728        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
729    ) -> list[ChatCompletionMessageParam]:
730        """
731        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
732        """
733        trace: list[ChatCompletionMessageParam] = []
734        for message in messages:
735            if isinstance(message, LiteLLMMessage):
736                trace.append(self.litellm_message_to_trace_message(message))
737            else:
738                trace.append(message)
739        return trace

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
67    def __init__(
68        self,
69        config: LiteLlmConfig,
70        kiln_task: datamodel.Task,
71        base_adapter_config: AdapterConfig | None = None,
72    ):
73        self.config = config
74        self._additional_body_options = config.additional_body_options
75        self._api_base = config.base_url
76        self._headers = config.default_headers
77        self._litellm_model_id: str | None = None
78        self._cached_available_tools: list[KilnToolInterface] | None = None
79
80        super().__init__(
81            task=kiln_task,
82            run_config=config.run_config_properties,
83            config=base_adapter_config,
84        )
config
async def acompletion_checking_response( self, **kwargs) -> Tuple[litellm.types.utils.ModelResponse, litellm.types.utils.Choices]:
284    async def acompletion_checking_response(
285        self, **kwargs
286    ) -> Tuple[ModelResponse, Choices]:
287        response = await litellm.acompletion(**kwargs)
288        if (
289            not isinstance(response, ModelResponse)
290            or not response.choices
291            or len(response.choices) == 0
292            or not isinstance(response.choices[0], Choices)
293        ):
294            raise RuntimeError(
295                f"Expected ModelResponse with Choices, got {type(response)}."
296            )
297        return response, response.choices[0]
def adapter_name(self) -> str:
299    def adapter_name(self) -> str:
300        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
302    async def response_format_options(self) -> dict[str, Any]:
303        # Unstructured if task isn't structured
304        if not self.has_structured_output():
305            return {}
306
307        structured_output_mode = self.run_config.structured_output_mode
308
309        match structured_output_mode:
310            case StructuredOutputMode.json_mode:
311                return {"response_format": {"type": "json_object"}}
312            case StructuredOutputMode.json_schema:
313                return self.json_schema_response_format()
314            case StructuredOutputMode.function_calling_weak:
315                return self.tool_call_params(strict=False)
316            case StructuredOutputMode.function_calling:
317                return self.tool_call_params(strict=True)
318            case StructuredOutputMode.json_instructions:
319                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
320                return {}
321            case StructuredOutputMode.json_custom_instructions:
322                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
323                return {}
324            case StructuredOutputMode.json_instruction_and_object:
325                # We set response_format to json_object and also set json instructions in the prompt
326                return {"response_format": {"type": "json_object"}}
327            case StructuredOutputMode.default:
328                provider_name = self.run_config.model_provider_name
329                if provider_name == ModelProviderName.ollama:
330                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
331                    return self.json_schema_response_format()
332                elif provider_name == ModelProviderName.docker_model_runner:
333                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
334                    return self.json_schema_response_format()
335                else:
336                    # Default to function calling -- it's older than the other modes. Higher compatibility.
337                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
338                    strict = provider_name == ModelProviderName.openai
339                    return self.tool_call_params(strict=strict)
340            case StructuredOutputMode.unknown:
341                # See above, but this case should never happen.
342                raise ValueError("Structured output mode is unknown.")
343            case _:
344                raise_exhaustive_enum_error(structured_output_mode)
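For reference, a minimal sketch of the dict shapes this returns for a few of the modes above (values are illustrative, not captured from a real run):

    json_mode_options = {"response_format": {"type": "json_object"}}   # json_mode / json_instruction_and_object
    prompt_injected_options = {}                                        # json_instructions / json_custom_instructions
    # json_schema (and the Ollama / Docker Model Runner defaults) delegate to
    # json_schema_response_format() below; function calling delegates to tool_call_params().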
def json_schema_response_format(self) -> dict[str, typing.Any]:
346    def json_schema_response_format(self) -> dict[str, Any]:
347        output_schema = self.task.output_schema()
348        return {
349            "response_format": {
350                "type": "json_schema",
351                "json_schema": {
352                    "name": "task_response",
353                    "schema": output_schema,
354                },
355            }
356        }
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
358    def tool_call_params(self, strict: bool) -> dict[str, Any]:
359        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
360        output_schema = self.task.output_schema()
361        if not isinstance(output_schema, dict):
362            raise ValueError(
363                "Invalid output schema for this task. Can not use tool calls."
364            )
365        output_schema["additionalProperties"] = False
366
367        function_params = {
368            "name": "task_response",
369            "parameters": output_schema,
370        }
371        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
372        if strict:
373            function_params["strict"] = True
374
375        return {
376            "tools": [
377                {
378                    "type": "function",
379                    "function": function_params,
380                }
381            ],
382            "tool_choice": {
383                "type": "function",
384                "function": {"name": "task_response"},
385            },
386        }
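As a sketch, with a toy output schema (hypothetical, not a real Kiln task schema), strict function calling yields kwargs of this shape:

    toy_schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
        "additionalProperties": False,
    }
    function_calling_kwargs = {
        "tools": [
            {
                "type": "function",
                "function": {"name": "task_response", "parameters": toy_schema, "strict": True},
            }
        ],
        # Force the model to answer via the task_response function
        "tool_choice": {"type": "function", "function": {"name": "task_response"}},
    }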
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
388    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
389        # Don't love having this logic here. But it's worth the usability improvement
390        # so better to keep it than exclude it. Should figure out how I want to isolate
391        # this sort of logic so it's config driven and can be overridden
392
393        extra_body = {}
394        provider_options = {}
395
396        if provider.thinking_level is not None:
397            extra_body["reasoning_effort"] = provider.thinking_level
398
399        if provider.require_openrouter_reasoning:
400            # https://openrouter.ai/docs/use-cases/reasoning-tokens
401            extra_body["reasoning"] = {
402                "exclude": False,
403            }
404
405        if provider.gemini_reasoning_enabled:
406            extra_body["reasoning"] = {
407                "enabled": True,
408            }
409
410        if provider.name == ModelProviderName.openrouter:
411            # Ask OpenRouter to include usage in the response (cost)
412            extra_body["usage"] = {"include": True}
413
414        if provider.anthropic_extended_thinking:
415            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
416
417        if provider.r1_openrouter_options:
418            # Require providers that support the reasoning parameter
419            provider_options["require_parameters"] = True
420            # Prefer R1 providers with reasonable perf/quants
421            provider_options["order"] = ["Fireworks", "Together"]
422            # R1 providers with unreasonable quants
423            provider_options["ignore"] = ["DeepInfra"]
424
425        # Only set these options when the goal of this request is to get logprobs.
426        if (
427            provider.logprobs_openrouter_options
428            and self.base_adapter_config.top_logprobs is not None
429        ):
430            # Don't let OpenRouter choose a provider that doesn't support logprobs.
431            provider_options["require_parameters"] = True
432            # DeepInfra silently fails to return logprobs consistently.
433            provider_options["ignore"] = ["DeepInfra"]
434
435        if provider.openrouter_skip_required_parameters:
436            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
437            provider_options["require_parameters"] = False
438
439        # Siliconflow uses a bool flag for thinking, for some models
440        if provider.siliconflow_enable_thinking is not None:
441            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
442
443        if len(provider_options) > 0:
444            extra_body["provider"] = provider_options
445
446        return extra_body
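A sketch of the result for a hypothetical OpenRouter reasoning model with the R1 options enabled; which keys appear depends entirely on the KilnModelProvider flags:

    extra_body = {
        "reasoning": {"exclude": False},        # require_openrouter_reasoning
        "usage": {"include": True},             # OpenRouter: include cost in the response
        "provider": {
            "require_parameters": True,         # r1_openrouter_options
            "order": ["Fireworks", "Together"],
            "ignore": ["DeepInfra"],
        },
    }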
def litellm_model_id(self) -> str:
448    def litellm_model_id(self) -> str:
449        # The model ID is an interesting combination of format and url endpoint.
450        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.
451
452        if self._litellm_model_id:
453            return self._litellm_model_id
454
455        provider = self.model_provider()
456        if not provider.model_id:
457            raise ValueError("Model ID is required for OpenAI compatible models")
458
459        litellm_provider_name: str | None = None
460        is_custom = False
461        match provider.name:
462            case ModelProviderName.openrouter:
463                litellm_provider_name = "openrouter"
464            case ModelProviderName.openai:
465                litellm_provider_name = "openai"
466            case ModelProviderName.groq:
467                litellm_provider_name = "groq"
468            case ModelProviderName.anthropic:
469                litellm_provider_name = "anthropic"
470            case ModelProviderName.ollama:
471                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
472                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
473                is_custom = True
474            case ModelProviderName.docker_model_runner:
475                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
476                # We want direct control over the requests for features like response_format=json_schema
477                is_custom = True
478            case ModelProviderName.gemini_api:
479                litellm_provider_name = "gemini"
480            case ModelProviderName.fireworks_ai:
481                litellm_provider_name = "fireworks_ai"
482            case ModelProviderName.amazon_bedrock:
483                litellm_provider_name = "bedrock"
484            case ModelProviderName.azure_openai:
485                litellm_provider_name = "azure"
486            case ModelProviderName.huggingface:
487                litellm_provider_name = "huggingface"
488            case ModelProviderName.vertex:
489                litellm_provider_name = "vertex_ai"
490            case ModelProviderName.together_ai:
491                litellm_provider_name = "together_ai"
492            case ModelProviderName.cerebras:
493                litellm_provider_name = "cerebras"
494            case ModelProviderName.siliconflow_cn:
495                is_custom = True
496            case ModelProviderName.openai_compatible:
497                is_custom = True
498            case ModelProviderName.kiln_custom_registry:
499                is_custom = True
500            case ModelProviderName.kiln_fine_tune:
501                is_custom = True
502            case _:
503                raise_exhaustive_enum_error(provider.name)
504
505        if is_custom:
506            if self._api_base is None:
507                raise ValueError(
508                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
509                )
510            # Use "openai" here since the provider name only determines request format; the URL comes from api_base
511            litellm_provider_name = "openai"
512
513        # Shouldn't be possible, but keep the type checker happy
514        if litellm_provider_name is None:
515            raise ValueError(
516                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
517            )
518
519        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
520        return self._litellm_model_id
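The resulting ID is just the litellm provider prefix joined to the provider's model ID; a sketch with illustrative model names:

    hosted_example = "openrouter" + "/" + "deepseek/deepseek-r1"   # -> "openrouter/deepseek/deepseek-r1"
    custom_example = "openai" + "/" + "my-local-model"             # Ollama, fine-tunes, custom registry:
                                                                   # "openai" sets the request format only;
                                                                   # the host comes from api_base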
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, openai.types.chat.chat_completion_tool_message_param.ChatCompletionToolMessageParam, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
522    async def build_completion_kwargs(
523        self,
524        provider: KilnModelProvider,
525        messages: list[ChatCompletionMessageIncludingLiteLLM],
526        top_logprobs: int | None,
527        skip_response_format: bool = False,
528    ) -> dict[str, Any]:
529        extra_body = self.build_extra_body(provider)
530
531        # Merge all parameters into a single kwargs dict for litellm
532        completion_kwargs = {
533            "model": self.litellm_model_id(),
534            "messages": messages,
535            "api_base": self._api_base,
536            "headers": self._headers,
537            "temperature": self.run_config.temperature,
538            "top_p": self.run_config.top_p,
539            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
540            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
541            # Better to ignore them than to fail the model call.
542            # https://docs.litellm.ai/docs/completion/input
543            "drop_params": True,
544            **extra_body,
545            **self._additional_body_options,
546        }
547
548        tool_calls = await self.litellm_tools()
549        has_tools = len(tool_calls) > 0
550        if has_tools:
551            completion_kwargs["tools"] = tool_calls
552            completion_kwargs["tool_choice"] = "auto"
553
554        if not skip_response_format:
555            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
556            response_format_options = await self.response_format_options()
557
558            # Check for a conflict between tools and response format using tools
559            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, models that are good at tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
560            if has_tools and "tools" in response_format_options:
561                raise ValueError(
562                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
563                )
564
565            completion_kwargs.update(response_format_options)
566
567        if top_logprobs is not None:
568            completion_kwargs["logprobs"] = True
569            completion_kwargs["top_logprobs"] = top_logprobs
570
571        return completion_kwargs
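Putting it together, a sketch of the merged kwargs handed to litellm.acompletion for a structured-output run (all values illustrative):

    completion_kwargs = {
        "model": "openrouter/deepseek/deepseek-r1",
        "messages": [{"role": "user", "content": "..."}],
        "api_base": None,          # only set for custom / OpenAI-compatible endpoints
        "headers": None,
        "temperature": 0.7,
        "top_p": 1.0,
        "drop_params": True,       # silently drop OpenAI params the model doesn't support
        "response_format": {"type": "json_object"},   # from response_format_options()
        # ...plus extra_body, additional_body_options, tools / tool_choice,
        # and logprobs / top_logprobs when requested.
    }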
def usage_from_response( self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage:
573    def usage_from_response(self, response: ModelResponse) -> Usage:
574        litellm_usage = response.get("usage", None)
575
576        # LiteLLM isn't consistent in how it returns the cost.
577        cost = response._hidden_params.get("response_cost", None)
578        if cost is None and litellm_usage:
579            cost = litellm_usage.get("cost", None)
580
581        usage = Usage()
582
583        if not litellm_usage and not cost:
584            return usage
585
586        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
587            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
588            usage.output_tokens = litellm_usage.get("completion_tokens", None)
589            usage.total_tokens = litellm_usage.get("total_tokens", None)
590        else:
591            logger.warning(
592                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
593            )
594
595        if isinstance(cost, float):
596            usage.cost = cost
597        elif cost is not None:
598            # None is allowed, but no other types are expected
599            logger.warning(
600                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
601            )
602
603        return usage
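A sketch of the cost fallback order, using plain dicts in place of a real litellm response (values illustrative):

    hidden_params = {"response_cost": None}     # litellm sometimes omits this
    litellm_usage = {"prompt_tokens": 412, "completion_tokens": 128,
                     "total_tokens": 540, "cost": 0.0021}

    cost = hidden_params.get("response_cost")
    if cost is None:
        cost = litellm_usage.get("cost")        # fall back to the usage block's cost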
async def cached_available_tools(self) -> list[kiln_ai.tools.KilnToolInterface]:
605    async def cached_available_tools(self) -> list[KilnToolInterface]:
606        if self._cached_available_tools is None:
607            self._cached_available_tools = await self.available_tools()
608        return self._cached_available_tools
async def litellm_tools(self) -> list[typing.Dict]:
610    async def litellm_tools(self) -> list[Dict]:
611        available_tools = await self.cached_available_tools()
612
613        # LiteLLM takes the standard OpenAI-compatible tool call format
614        return [await tool.toolcall_definition() for tool in available_tools]
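Each toolcall_definition() is expected to return the standard OpenAI tool shape; a sketch with a hypothetical tool:

    get_weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }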
async def process_tool_calls( self, tool_calls: list[litellm.types.utils.ChatCompletionMessageToolCall] | None) -> tuple[str | None, list[openai.types.chat.chat_completion_tool_message_param.ChatCompletionToolMessageParam]]:
616    async def process_tool_calls(
617        self, tool_calls: list[ChatCompletionMessageToolCall] | None
618    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
619        if tool_calls is None:
620            return None, []
621
622        assistant_output_from_toolcall: str | None = None
623        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
624
625        for tool_call in tool_calls:
626            # Kiln "task_response" tool is used for returning structured output via tool calls.
627            # Load the output from the tool call and skip normal tool handling for it.
628            if tool_call.function.name == "task_response":
629                assistant_output_from_toolcall = tool_call.function.arguments
630                continue
631
632            # Process normal tool calls (not the "task_response" tool)
633            tool_name = tool_call.function.name
634            tool = None
635            for tool_option in await self.cached_available_tools():
636                if await tool_option.name() == tool_name:
637                    tool = tool_option
638                    break
639            if not tool:
640                raise RuntimeError(
641                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
642                )
643
644            # Parse the arguments and validate them against the tool's schema
645            try:
646                parsed_args = json.loads(tool_call.function.arguments)
647            except json.JSONDecodeError:
648                raise RuntimeError(
649                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
650                )
651            try:
652                tool_call_definition = await tool.toolcall_definition()
653                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
654                validate_schema_with_value_error(parsed_args, json_schema)
655            except Exception as e:
656                raise RuntimeError(
657                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
658                ) from e
659
660            result = await tool.run(**parsed_args)
661
662            tool_call_response_messages.append(
663                ChatCompletionToolMessageParam(
664                    role="tool",
665                    tool_call_id=tool_call.id,
666                    content=result,
667                )
668            )
669
670        if (
671            assistant_output_from_toolcall is not None
672            and len(tool_call_response_messages) > 0
673        ):
674            raise RuntimeError(
675                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
676            )
677
678        return assistant_output_from_toolcall, tool_call_response_messages
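The argument handling above boils down to parse-then-validate; a standalone sketch using the jsonschema package as a stand-in for validate_schema_with_value_error (tool name, arguments, and schema are hypothetical):

    import json

    import jsonschema  # stand-in validator, for illustration only

    raw_arguments = '{"city": "Lisbon"}'    # what the model put in tool_call.function.arguments
    parameters_schema = {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }

    parsed_args = json.loads(raw_arguments)               # RuntimeError path 1: invalid JSON
    jsonschema.validate(parsed_args, parameters_schema)   # RuntimeError path 2: schema mismatch
    # result = await tool.run(**parsed_args)              # only runs once the arguments are valid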
def litellm_message_to_trace_message( self, raw_message: litellm.types.utils.Message) -> kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper:
680    def litellm_message_to_trace_message(
681        self, raw_message: LiteLLMMessage
682    ) -> ChatCompletionAssistantMessageParamWrapper:
683        """
684        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
685        """
686        message: ChatCompletionAssistantMessageParamWrapper = {
687            "role": "assistant",
688        }
689        if raw_message.role != "assistant":
690            raise ValueError(
691                "Model returned a message with a role other than assistant. This is not supported."
692            )
693
694        if hasattr(raw_message, "content"):
695            message["content"] = raw_message.content
696        if hasattr(raw_message, "reasoning_content"):
697            message["reasoning_content"] = raw_message.reasoning_content
698        if hasattr(raw_message, "tool_calls"):
699            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
700            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
701            for litellm_tool_call in raw_message.tool_calls or []:
702                # Optional in the SDK for streaming responses, but should never be None at this point.
703                if litellm_tool_call.function.name is None:
704                    raise ValueError(
705                        "The model requested a tool call, without providing a function name (required)."
706                    )
707                open_ai_tool_calls.append(
708                    ChatCompletionMessageToolCallParam(
709                        id=litellm_tool_call.id,
710                        type="function",
711                        function={
712                            "name": litellm_tool_call.function.name,
713                            "arguments": litellm_tool_call.function.arguments,
714                        },
715                    )
716                )
717            if len(open_ai_tool_calls) > 0:
718                message["tool_calls"] = open_ai_tool_calls
719
720        if not message.get("content") and not message.get("tool_calls"):
721            raise ValueError(
722                "Model returned an assistant message, but no content or tool calls. This is not supported."
723            )
724
725        return message

Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
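A sketch of the OpenAI-compatible assistant entry this produces when the model answered with a single tool call (values illustrative):

    assistant_trace_message = {
        "role": "assistant",
        "content": None,                      # allowed to be empty when tool calls are present
        "reasoning_content": "model reasoning text, when the provider returned any",
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Lisbon"}'},
            }
        ],
    }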

def all_messages_to_trace( self, messages: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, openai.types.chat.chat_completion_tool_message_param.ChatCompletionToolMessageParam, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, openai.types.chat.chat_completion_tool_message_param.ChatCompletionToolMessageParam, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]]:
727    def all_messages_to_trace(
728        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
729    ) -> list[ChatCompletionMessageParam]:
730        """
731        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
732        """
733        trace: list[ChatCompletionMessageParam] = []
734        for message in messages:
735            if isinstance(message, LiteLLMMessage):
736                trace.append(self.litellm_message_to_trace_message(message))
737            else:
738                trace.append(message)
739        return trace

Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.